diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 4619d581..66f8cd39 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Optional, Union import torch +from compressed_tensors.utils import Aliasable from pydantic import BaseModel, Field, field_validator, model_validator @@ -53,17 +54,30 @@ class QuantizationStrategy(str, Enum): TOKEN = "token" -class ActivationOrdering(str, Enum): +class ActivationOrdering(Aliasable, str, Enum): """ Enum storing strategies for activation ordering Group: reorder groups and weight\n - Weight: only reorder weight, not groups. Slightly lower latency and - accuracy compared to group actorder\n + Weight: only reorder weight, not groups. Slightly lower accuracy but also lower + latency when compared to group actorder\n + Dynamic: alias for Group\n + Static: alias for Weight\n """ GROUP = "group" WEIGHT = "weight" + # aliases + DYNAMIC = "dynamic" + STATIC = "static" + + @property + @staticmethod + def aliases(self) -> Dict[str, str]: + return { + "dynamic": "group", + "static": "weight", + } class QuantizationArgs(BaseModel, use_enum_values=True): diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 3a8152da..36b88604 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: return model + """ Pre-Set Quantization Scheme Args """ diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index e1587ada..9fb9f971 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from abc import abstractmethod +from typing import Any, Dict, Optional import torch from transformers import AutoConfig @@ -24,6 +25,7 @@ "tensor_follows_mask_structure", "replace_module", "is_compressed_tensors_config", + "Aliasable", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -119,3 +121,36 @@ def is_compressed_tensors_config(compression_config: Any) -> bool: return isinstance(compression_config, CompressedTensorsConfig) except ImportError: return False + + +class Aliasable: + """ + A mixin for enums to allow aliasing of enum members + + Example: + >>> class MyClass(Aliasable, int, Enum): + >>> ... + """ + + @property + @staticmethod + @abstractmethod + def aliases(self) -> Dict[str, str]: + raise NotImplementedError() + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value or ( + self.aliases.get(self.value, self.value) + == self.aliases.get(other.value, other.value) + ) + else: + self_value = self.aliases.get(self.value, self.value) + other_value = self.aliases.get(other, other) + return self_value == other_value + + return False + + def __hash__(self): + canonical_value = self.aliases.get(self.value, self.value) + return hash(canonical_value) diff --git a/tests/test_quantization/test_quant_args.py b/tests/test_quantization/test_quant_args.py index b0972125..3a42a626 100644 --- a/tests/test_quantization/test_quant_args.py +++ b/tests/test_quantization/test_quant_args.py @@ -83,14 +83,28 @@ def test_actorder(): # test group inference with actorder args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.GROUP) assert args.strategy == QuantizationStrategy.GROUP + args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.DYNAMIC) + assert args.strategy == QuantizationStrategy.GROUP # test invalid pairings + with pytest.raises(ValueError): + QuantizationArgs(group_size=None, actorder="group") with pytest.raises(ValueError): QuantizationArgs(group_size=None, actorder="weight") + with pytest.raises(ValueError): + QuantizationArgs(group_size=None, actorder="static") + with pytest.raises(ValueError): + QuantizationArgs(group_size=-1, actorder="group") with pytest.raises(ValueError): QuantizationArgs(group_size=-1, actorder="weight") + with pytest.raises(ValueError): + QuantizationArgs(group_size=-1, actorder="static") + with pytest.raises(ValueError): + QuantizationArgs(strategy="tensor", actorder="group") with pytest.raises(ValueError): QuantizationArgs(strategy="tensor", actorder="weight") + with pytest.raises(ValueError): + QuantizationArgs(strategy="tensor", actorder="static") # test boolean and none defaulting assert ( @@ -101,6 +115,38 @@ def test_actorder(): assert QuantizationArgs(group_size=1, actorder=None).actorder is None +def test_actorder_aliases(): + assert ( + ActivationOrdering.GROUP + == ActivationOrdering.DYNAMIC + == ActivationOrdering.GROUP + ) + assert ( + ActivationOrdering.WEIGHT + == ActivationOrdering.STATIC + == ActivationOrdering.WEIGHT + ) + + assert ActivationOrdering.GROUP == "dynamic" == ActivationOrdering.GROUP + assert ActivationOrdering.DYNAMIC == "dynamic" == ActivationOrdering.DYNAMIC + assert ActivationOrdering.GROUP == "group" == ActivationOrdering.GROUP + assert ActivationOrdering.DYNAMIC == "group" == ActivationOrdering.DYNAMIC + + assert ActivationOrdering.WEIGHT == "static" == ActivationOrdering.WEIGHT + assert ActivationOrdering.STATIC == "static" == ActivationOrdering.STATIC + assert ActivationOrdering.WEIGHT == "weight" == ActivationOrdering.WEIGHT + assert ActivationOrdering.STATIC == "weight" == ActivationOrdering.STATIC + + assert ActivationOrdering.WEIGHT != "dynamic" != ActivationOrdering.WEIGHT + assert ActivationOrdering.STATIC != "dynamic" != ActivationOrdering.STATIC + assert ActivationOrdering.WEIGHT != "group" != ActivationOrdering.WEIGHT + assert ActivationOrdering.STATIC != "group" != ActivationOrdering.STATIC + assert ActivationOrdering.GROUP != "static" != ActivationOrdering.GROUP + assert ActivationOrdering.DYNAMIC != "static" != ActivationOrdering.DYNAMIC + assert ActivationOrdering.GROUP != "weight" != ActivationOrdering.GROUP + assert ActivationOrdering.DYNAMIC != "weight" != ActivationOrdering.DYNAMIC + + def test_invalid(): with pytest.raises(ValidationError): QuantizationArgs(type="invalid")