Commit 4e725ff
Merge branch 'main' into kylesayrs/model_compressor-typechecking-import
kylesayrs authored Jan 14, 2025
2 parents bed83b6 + e05d8f4 commit 4e725ff
Showing 37 changed files with 2,253 additions and 291 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/build.yml
@@ -76,13 +76,9 @@ jobs:
        with:
          ref: ${{ inputs.gitref }}

-      - name: install testmo
-        uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0
-
      - name: create testmo run
        id: create_testmo_run
-        uses: neuralmagic/nm-actions/actions/testmo-run-create@v1.2.0
-        if: success()
+        uses: neuralmagic/nm-actions/actions/testmo-run-create@v1.11.0
        with:
          testmo_url: https://neuralmagic.testmo.net
          testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }}
@@ -142,8 +138,8 @@ jobs:

      - name: report build status to testmo
        id: report_build
-        uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.2.0
-        if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }}
+        uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.11.0
+        if: success() || failure()
        with:
          testmo_url: https://neuralmagic.testmo.net
          testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }}
7 changes: 2 additions & 5 deletions .github/workflows/test.yml
@@ -90,9 +90,6 @@ jobs:
        with:
          venv: TEST

-      - name: install testmo
-        uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0
-
      - name: download whl
        id: download
        uses: actions/download-artifact@v4
@@ -108,8 +105,8 @@

      - name: report test results
        id: report_test
-        uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.2.0
-        if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }}
+        uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.11.0
+        if: (success() || failure()) && inputs.testmo_run_id != ''
        with:
          testmo_url: https://neuralmagic.testmo.net
          testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }}
7 changes: 2 additions & 5 deletions .github/workflows/upload.yml
@@ -69,12 +69,9 @@ jobs:
        with:
          python-version: 3.10.12

-      - name: install testmo
-        uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0
-
      - name: complete testmo run
-        uses: neuralmagic/nm-actions/actions/testmo-run-complete@v1.2.0
-        if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }}
+        uses: neuralmagic/nm-actions/actions/testmo-run-complete@v1.11.0
+        if: (success() || failure()) && inputs.testmo_run_id != ''
        with:
          testmo_url: https://neuralmagic.testmo.net
          testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }}
38 changes: 32 additions & 6 deletions setup.py
@@ -1,11 +1,11 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,7 +15,33 @@
import os
from setuptools import setup, find_packages
from typing import List, Dict, Tuple
-from utils.artifacts import get_release_and_version


+def get_release_and_version(package_path: str) -> Tuple[bool, str, str, str, str]:
+    """
+    Load version and release info from the compressed-tensors package
+    """
+    # compressed-tensors/src/compressed_tensors/version.py always exists; default source of truth
+    version_path = os.path.join(package_path, "version.py")
+
+    # exec() cannot set local variables, so collect them manually
+    locals_dict = {}
+    exec(open(version_path).read(), globals(), locals_dict)
+    is_release = locals_dict.get("is_release", False)
+    version = locals_dict.get("version", "unknown")
+    version_major = locals_dict.get("version_major", "unknown")
+    version_minor = locals_dict.get("version_minor", "unknown")
+    version_bug = locals_dict.get("version_bug", "unknown")
+
+    print(f"Loaded version {version} from {version_path}")
+
+    return (
+        is_release,
+        version,
+        version_major,
+        version_minor,
+        version_bug,
+    )


package_path = os.path.join(
@@ -35,7 +61,7 @@
    _PACKAGE_NAME = "compressed-tensors"
else:
    _PACKAGE_NAME = "compressed-tensors-nightly"


def _setup_long_description() -> Tuple[str, str]:
    return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
@@ -44,7 +70,7 @@ def _setup_packages() -> List:
    return find_packages(
        "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
    )

def _setup_install_requires() -> List:
    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]

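For context, a minimal sketch of how setup.py can consume the now-inlined helper above; the src-layout path and the print call are illustrative assumptions, while the release/nightly naming mirrors the _PACKAGE_NAME logic shown in this diff.

# Hypothetical consumer of get_release_and_version(); the package_path
# construction is an assumption based on a standard src layout.
import os

package_path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "src", "compressed_tensors"
)
is_release, version, version_major, version_minor, version_bug = (
    get_release_and_version(package_path)
)

# Nightly builds get a distinct package name, as in the diff's else branch.
_PACKAGE_NAME = "compressed-tensors" if is_release else "compressed-tensors-nightly"
print(f"building {_PACKAGE_NAME} {version}")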
@@ -17,8 +17,9 @@
import operator
import os
import re
+from contextlib import contextmanager
from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
@@ -38,6 +39,7 @@
    apply_quantization_config,
    load_pretrained_quantization,
)
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
    is_module_quantized,
@@ -104,7 +106,6 @@ def from_pretrained(
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-
        return cls.from_compression_config(compression_config)

    @classmethod
@@ -137,7 +138,7 @@ def from_compression_config(
                format, **sparsity_config
            )
        if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +194,7 @@ def parse_sparsity_config(

        if is_compressed_tensors_config(compression_config):
            s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None

        return compression_config.get(SPARSITY_CONFIG_NAME, None)

@@ -214,7 +215,7 @@ def parse_quantization_config(

        if is_compressed_tensors_config(compression_config):
            q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None

        quantization_config = deepcopy(compression_config)
        quantization_config.pop(SPARSITY_CONFIG_NAME, None)
@@ -282,8 +283,14 @@ def compress(
        )

        if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
            )

        # HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,44 @@ def decompress(self, model_path: str, model: Module):
        :param model: pytorch model to load decompressed weights into
        """
        model_path = get_safetensors_folder(model_path)
-        if self.sparsity_compressor is not None:
+        sparse_decompressed = False
+
+        if (
+            self.sparsity_compressor is not None
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            # Sparse decompression is applied on the model_path
            dense_gen = self.sparsity_compressor.decompress(model_path)
            self._replace_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True

        if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )

            dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
            self._replace_weights(dense_gen, model)

-            def update_status(module):
+            def freeze_quantization_status(module):
                module.quantization_status = QuantizationStatus.FROZEN

-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)

    def update_config(self, save_directory: str):
@@ -367,12 +395,26 @@ def update_config(self, save_directory: str):
        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
        """
        Replace the weights of the model with the
        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator: a generator that yields tuples of
+            (name, data), where 'name' is the parameter name and 'data'
+            is the updated param data
+        :param model: the model whose weights are to be updated
        """
        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
            split_name = name.split(".")
            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
            module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
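To make the expected generator contract concrete, a minimal sketch follows; the `compressor` and `model` objects, the dotted parameter name, and the tensor shape are illustrative assumptions, not code from this commit.

# Hypothetical feed for _replace_weights(); each yielded item pairs a dotted
# parameter name with its decompressed tensor.
import torch

def dense_weight_generator():
    yield "model.layers.0.self_attn.q_proj.weight", torch.zeros(16, 16)

# Parameters whose attribute name is missing on the resolved module are
# skipped by the hasattr() guard shown above.
compressor._replace_weights(dense_weight_generator(), model)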


def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
@@ -402,3 +444,23 @@ def new_dtype_byte_size(dtype):
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
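A brief usage sketch of the context manager above; constructing a bare QuantizationConfig this way is an assumption made only to demonstrate the save/restore behavior.

# Hypothetical usage of override_quantization_status().
from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

config = QuantizationConfig.model_validate({"config_groups": {}})
original = config.quantization_status

with override_quantization_status(config, QuantizationStatus.FROZEN):
    # Inside the block the status reads as FROZEN, so callers such as
    # apply_quantization_config() will not re-initialize or re-quantize.
    assert config.quantization_status == QuantizationStatus.FROZEN

# On exit the original status is restored, even if the block raised.
assert config.quantization_status == original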
40 changes: 35 additions & 5 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
# limitations under the License.

import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm
@@ -113,30 +118,55 @@

    def decompress(
        self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
        names_to_scheme: Dict[str, QuantizationArgs],
        device: str = "cpu",
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Reads a compressed state dict located at path_to_model_or_tensors
        and returns a generator for sequentially decompressing back to a
        dense state dict

        :param path_to_model_or_tensors: path to compressed safetensors model (directory
            with one or more safetensors files) or compressed tensors file
        :param names_to_scheme: quantization args for each quantized weight
        :param device: optional device to load intermediate weights into
        :return: generator over the dense state dict
        """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name in weight_mappings.keys():
            weight_data = {}
            for param_name, safe_path in weight_mappings[weight_name].items():
                full_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    weight_data[param_name] = f.get_tensor(full_name)
            if "weight_scale" in weight_data:
                quant_args = names_to_scheme[weight_name]
                decompressed = self.decompress_weight(
                    compressed_data=weight_data, quantization_args=quant_args
                )
                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
+
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
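To show the two dispatch paths side by side, a hedged sketch of driving decompress from a caller; the `compressor` and `model` objects and the names_to_scheme contents are assumptions, not code from this commit.

# Hypothetical driver for the decompress() generator defined above.
from compressed_tensors.quantization import QuantizationArgs

names_to_scheme = {"model.layers.0.mlp.down_proj": QuantizationArgs(num_bits=8)}

# Path input -> _decompress_from_path() reads safetensors files from disk.
for name, tensor in compressor.decompress(
    "path/to/compressed_model", names_to_scheme=names_to_scheme
):
    print(name, tuple(tensor.shape))

# Dict input (e.g. a sparse-decompressed state dict, as model_compressor.py
# now passes) -> _decompress_from_state_dict().
for name, tensor in compressor.decompress(
    model.state_dict(), names_to_scheme=names_to_scheme
):
    print(name, tuple(tensor.shape))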