From eb1c705cd0eb3a644f20fa0ab8c821aa89e0fbec Mon Sep 17 00:00:00 2001
From: Kyle Sayers <kylesayrs@gmail.com>
Date: Fri, 4 Oct 2024 13:07:58 -0400
Subject: [PATCH 01/12] remove function (#156)

Co-authored-by: Kyle Sayers <kyle@neuralmagic.com>
---
 .../quantization/lifecycle/helpers.py         | 47 -------------------
 1 file changed, 47 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/helpers.py b/src/compressed_tensors/quantization/lifecycle/helpers.py
index 497a9921..9d755328 100644
--- a/src/compressed_tensors/quantization/lifecycle/helpers.py
+++ b/src/compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,62 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
-from typing import Optional
-
-import torch
 from torch.nn import Module
 
 
 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]
 
 
-def update_layer_weight_quant_params(
-    layer: Module,
-    weight: Optional[torch.Tensor] = None,
-    g_idx: Optional[torch.Tensor] = None,
-    reset_obs: bool = False,
-):
-    """
-    Update quantization parameters on layer
-
-    :param layer: input layer
-    :param weight: weight to update quant params with, defaults to layer weight
-    :param g_idx: optional mapping from column index to group index
-    :param reset_obs: reset the observer before calculating quant params,
-        defaults to False
-    """
-    attached_weight = getattr(layer, "weight", None)
-
-    if weight is None:
-        weight = attached_weight
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    if g_idx is None:
-        g_idx = getattr(layer, "weight_g_idx", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    if reset_obs:
-        observer.reset()
-
-    if attached_weight is not None:
-        weight = weight.to(attached_weight.dtype)
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True
 

From c2455b7c5040adc0e7fc6a197144971b2597d607 Mon Sep 17 00:00:00 2001
From: Rahul Tuli <rahul@neuralmagic.com>
Date: Fri, 4 Oct 2024 16:07:11 -0400
Subject: [PATCH 02/12] Revert "Ignore Dense sparsity config (#169)" (#181)

---
 .../compressors/model_compressors/model_compressor.py         | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index ac15fdaa..6473554d 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -242,10 +242,6 @@ def __init__(
         self.sparsity_compressor = None
         self.quantization_compressor = None
 
-        if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
-            # ignore dense sparsity config
-            self.sparsity_config = None
-
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config

From d6d823cec0d5bde2151b98ba15b5771c92fb390e Mon Sep 17 00:00:00 2001
From: Kyle Sayers <kylesayrs@gmail.com>
Date: Mon, 7 Oct 2024 12:55:10 -0400
Subject: [PATCH 03/12] Workaround HF Quantizer `apply_quantization_config`
 misuse (#180)

* workaround hf quantizer apply none

* Add usage comment
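
A minimal sketch of the guarded call (the model below is an illustrative
stand-in; only the None-config path added here is exercised):

    from collections import OrderedDict

    import torch
    from compressed_tensors.quantization import apply_quantization_config

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))  # any torch.nn.Module
    names_to_schemes = apply_quantization_config(model, None)
    # with config=None the call is a no-op and returns an empty mapping
    assert names_to_schemes == OrderedDict()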
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index a66dba92..09281528 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -107,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
 
 
 def apply_quantization_config(
-    model: Module, config: QuantizationConfig, run_compressed: bool = False
-) -> Dict:
+    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config
 
@@ -117,6 +117,10 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
+    # Workaround for when HF Quantizer passes None, see PR #180
+    if config is None:
+        return OrderedDict()
+
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
     # like to keep the original `config` as it is.

From b876a600cf2264e38e505cb8853b89d54110d027 Mon Sep 17 00:00:00 2001
From: dhuangnm <74931910+dhuangnm@users.noreply.github.com>
Date: Tue, 8 Oct 2024 10:41:24 -0400
Subject: [PATCH 04/12] bump up version to 0.7.0 (#186)

Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local>
---
 src/compressed_tensors/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py
index 17e7b592..1d241270 100644
--- a/src/compressed_tensors/version.py
+++ b/src/compressed_tensors/version.py
@@ -17,7 +17,7 @@
 """
 
 
-version_base = "0.6.0"
+version_base = "0.7.0"
 is_release = True  # change to True to set the generated version as a release version
 
 

From b2abe724a5b97821b90fdaa602eebdb1fa8a5eec Mon Sep 17 00:00:00 2001
From: Dipika Sikka <dipikasikka1@gmail.com>
Date: Fri, 11 Oct 2024 14:14:53 -0400
Subject: [PATCH 05/12] [Observer Restructure]: Remove MemoryLess Observer; use
 helper function for dynamic quantization (#187)

* remove memoryless observer; use helper function for dynamic quantization

* update init

* clean-up

* update test case

* fix arg

* validation + update name

* update preset schemes; swap condition check
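
A rough usage sketch of the new helper, mirroring the dynamic preset args from
this change (the activation shape is an illustrative assumption):

    import torch
    from compressed_tensors.quantization.observers.helpers import (
        compute_dynamic_scales_and_zp,
    )
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    args = QuantizationArgs(
        num_bits=8,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    )
    activations = torch.randn(2, 16, 64)  # e.g. (batch, tokens, hidden)
    scale, zero_point = compute_dynamic_scales_and_zp(value=activations, args=args)
    # qparams are derived directly from the observed values; the dynamic path
    # never registers or consults an observer module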
---
 .../quantization/lifecycle/forward.py         | 10 ++--
 .../quantization/lifecycle/initialize.py      | 13 +++--
 .../quantization/observers/__init__.py        |  1 -
 .../quantization/observers/helpers.py         | 42 +++++++++++++-
 .../quantization/observers/memoryless.py      | 56 -------------------
 .../quantization/quant_args.py                | 32 +++++++++--
 .../quantization/quant_scheme.py              |  3 +
 .../lifecycle/test_dynamic_lifecycle.py       |  2 +-
 8 files changed, 85 insertions(+), 74 deletions(-)
 delete mode 100644 src/compressed_tensors/quantization/observers/memoryless.py

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 4dbe4a85..d685c0c0 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -18,7 +18,10 @@
 
 import torch
 from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import calculate_range
+from compressed_tensors.quantization.observers.helpers import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -376,9 +379,8 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - no need to invoke observer
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 49e7b1a9..9b98da33 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -153,12 +153,16 @@ def _initialize_scale_zero_point_observer(
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
+
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+    # no need to register an observer for dynamic quantization
+    if observer:
+        module.register_module(f"{base_name}_observer", observer)
 
+    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
-        return  # no need to register a scale and zero point for a dynamic observer
+        return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
@@ -173,10 +177,7 @@ def _initialize_scale_zero_point_observer(
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
             num_groups = weight_shape[1] // quantization_args.group_size
-            expected_shape = (
-                weight_shape[0],
-                max(num_groups, 1)
-            )
+            expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
diff --git a/src/compressed_tensors/quantization/observers/__init__.py b/src/compressed_tensors/quantization/observers/__init__.py
index 2cb15c96..05b6b367 100644
--- a/src/compressed_tensors/quantization/observers/__init__.py
+++ b/src/compressed_tensors/quantization/observers/__init__.py
@@ -17,6 +17,5 @@
 
 from .helpers import *
 from .base import *
-from .memoryless import *
 from .min_max import *
 from .mse import *
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 13c05991..875a05b3 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -13,18 +13,56 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
     FP8_DTYPE,
     QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+__all__ = [
+    "calculate_qparams",
+    "get_observer_token_count",
+    "calculate_range",
+    "compute_dynamic_scales_and_zp",
+]
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    quantization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} "
+            "must be used for dynamic quantization"
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
diff --git a/src/compressed_tensors/quantization/observers/memoryless.py b/src/compressed_tensors/quantization/observers/memoryless.py
deleted file mode 100644
index ea8a7a01..00000000
--- a/src/compressed_tensors/quantization/observers/memoryless.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Optional, Tuple
-
-import torch
-from compressed_tensors.quantization.observers.base import Observer
-from compressed_tensors.quantization.observers.helpers import calculate_qparams
-from torch import FloatTensor, IntTensor, Tensor
-
-
-__all__ = ["MemorylessObserver"]
-
-
-@Observer.register("memoryless", alias=["dynamic"])
-class MemorylessObserver(Observer):
-    """
-    Implements a quantization observer that sets the scale and
-    zero point based on the latest observed value without tracking state
-    """
-
-    def calculate_qparams(
-        self,
-        observed: Tensor,
-        tensor_id: Optional[Any] = None,
-        reduce_dims: Optional[Tuple[int]] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Returns the min and max values of observed tensor
-
-        :param observed: observed tensor to calculate quantization parameters for
-        :param tensor_id: optional id for tensor; not used for memoryless
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-
-        if not reduce_dims:
-            min_val, max_val = torch.aminmax(observed)
-        else:
-            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
-            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
-
-        return calculate_qparams(min_val, max_val, self.quantization_args)
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 54805c58..c2fc0b6a 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -115,10 +116,10 @@ def get_observer(self):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        # No observer required for the dynamic case
         if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
+            self.observer = None
+            return self.observer
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
@@ -171,6 +172,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -207,6 +210,27 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
                 "activation ordering"
             )
 
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic "
+                    "quantization"
+                )
+            if observer is not None:
+                warnings.warn(
+                    "No observer is used for dynamic quantization, setting to None"
+                )
+                model.observer = None
+
+        # if we have not set an observer and we
+        # are running static quantization, use minmax
+        if not observer and not dynamic:
+            model.observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
         return model
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index b41eaafb..180d0f26 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -122,6 +122,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -164,6 +165,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
index be228451..45d49370 100644
--- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
@@ -73,7 +73,7 @@ def _test_layer_dynamic_quantization_status(
     # check inputs always have an observer if quantized but never scale/zp
     assert not hasattr(module, "input_scale")
     assert not hasattr(module, "input_zero_point")
-    assert hasattr(module, "input_observer") == inputs
+    assert not hasattr(module, "input_observer")
 
     # check weights always have scale/zp and observer only if not frozen
     assert hasattr(module, "weight_scale") == weights

From 506cd36829f6fd386a0a2c75ed8a51cc6dcab619 Mon Sep 17 00:00:00 2001
From: dhuangnm <74931910+dhuangnm@users.noreply.github.com>
Date: Wed, 16 Oct 2024 15:44:43 -0400
Subject: [PATCH 06/12] bump up to 0.7.1 for patch release (#192)

Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local>
---
 src/compressed_tensors/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py
index 1d241270..d0ba0363 100644
--- a/src/compressed_tensors/version.py
+++ b/src/compressed_tensors/version.py
@@ -17,7 +17,7 @@
 """
 
 
-version_base = "0.7.0"
+version_base = "0.7.1"
 is_release = True  # change to True to set the generated version as a release version
 
 

From 955e9068d4bd4cd632cde212117d1bf8c880e2ed Mon Sep 17 00:00:00 2001
From: Dipika Sikka <dipikasikka1@gmail.com>
Date: Fri, 18 Oct 2024 11:15:34 -0400
Subject: [PATCH 07/12] [Observer Restructure]: Separate out scale/zp and
 observer init; separate out calibration from forward pass (#188)

* separate out scale/zp and observer init

* temporary workaround

* separate out calibration from forward pass

* Fix typo

* fix missing import

* fix tests

* update all other tests

* Fix an incorrectly working test

* clean

* update

* clean-up

* fix test case

* remove commented code

* remove TODOs

* remove TODO
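
A rough end-to-end sketch of the reworked lifecycle, based on the updated tests
(the scheme, layer, and tensor shapes below are illustrative assumptions):

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
    from compressed_tensors.quantization.lifecycle.calibration import (
        set_module_for_calibration,
    )
    from compressed_tensors.quantization.lifecycle.forward import (
        calibrate_activations,
        forward_quantize,
    )
    from compressed_tensors.quantization.lifecycle.frozen import (
        freeze_module_quantization,
    )
    from compressed_tensors.quantization.lifecycle.initialize import (
        initialize_module_for_quantization,
    )

    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8),
        input_activations=QuantizationArgs(num_bits=8),
    )
    layer = torch.nn.Linear(4, 4)

    # scales/zero points only; observers are no longer attached here
    initialize_module_for_quantization(layer, scheme)
    # attaches the weight observer and sets weight scale/zero point up front
    set_module_for_calibration(layer)
    # activation observers are created lazily; qparams updated outside forward
    calibrate_activations(
        module=layer,
        value=torch.randn(8, 4),
        base_name="input",
        quantization_args=scheme.input_activations,
    )
    # forward_quantize only applies quant/dequant with the stored qparams
    out = forward_quantize(layer, torch.randn(8, 4), "input", scheme.input_activations)
    # freezing removes any observers attached during calibration
    freeze_module_quantization(layer)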
---
 src/compressed_tensors/quantization/cache.py  |  1 -
 .../quantization/lifecycle/calibration.py     | 12 +++
 .../quantization/lifecycle/forward.py         | 83 +++++++++++++------
 .../quantization/lifecycle/frozen.py          | 10 +--
 .../quantization/lifecycle/initialize.py      | 28 +++----
 .../lifecycle/test_dynamic_lifecycle.py       |  1 +
 .../lifecycle/test_forward.py                 | 50 +++++++----
 .../lifecycle/test_frozen.py                  |  5 +-
 .../lifecycle/test_lifecycle.py               |  5 +-
 9 files changed, 127 insertions(+), 68 deletions(-)

diff --git a/src/compressed_tensors/quantization/cache.py b/src/compressed_tensors/quantization/cache.py
index cc33a48a..312f1c9d 100644
--- a/src/compressed_tensors/quantization/cache.py
+++ b/src/compressed_tensors/quantization/cache.py
@@ -28,7 +28,6 @@ class KVCacheScaleType(Enum):
 
 
 class QuantizedKVParameterCache(HFDyanmicCache):
-
     """
     Quantized KV cache used in the forward call based on HF's dynamic cache.
     Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py
index d444694d..66dc35a0 100644
--- a/src/compressed_tensors/quantization/lifecycle/calibration.py
+++ b/src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -53,7 +53,19 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
 
     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
+        if not hasattr(module, "weight_observer"):
+            from compressed_tensors.quantization.lifecycle.initialize import (
+                initialize_observers,
+            )
+
+            initialize_observers(
+                module=module,
+                base_name="weight",
+                quantization_args=module.quantization_scheme.weights,
+            )
+
         observer = module.weight_observer
+
         g_idx = getattr(module, "weight_g_idx", None)
 
         offloaded = is_module_offloaded(module)
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index d685c0c0..eae641f8 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -38,7 +38,8 @@
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "maybe_calibrate_or_quantize",
+    "forward_quantize",
+    "calibrate_activations",
 ]
 
 
@@ -276,14 +277,24 @@ def wrapped_forward(self, *args, **kwargs):
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
-            input_ = maybe_calibrate_or_quantize(
-                module, input_, "input", scheme.input_activations
-            )
+            # NOTE: will be moved out of compressed-tensors
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.input_activations.dynamic
+            ):
+                calibrate_activations(
+                    module=module,
+                    value=input_,
+                    base_name="input",
+                    quantization_args=scheme.input_activations,
+                )
+
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data = maybe_calibrate_or_quantize(
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
@@ -296,7 +307,19 @@ def wrapped_forward(self, *args, **kwargs):
             # calibrate and (fake) quantize output activations when applicable
             # kv_cache scales updated on model self_attn forward call in
             # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
+
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                calibrate_activations(
+                    module=module,
+                    value=output,
+                    base_name="output",
+                    quantization_args=scheme.output_activations,
+                )
+
+            output = forward_quantize(
                 module, output, "output", scheme.output_activations
             )
 
@@ -356,12 +379,36 @@ def wrapped_forward(self, *args, **kwargs):
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def maybe_calibrate_or_quantize(
+def calibrate_activations(
+    module: Module,
+    value: torch.Tensor,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+):
+    # If empty tensor, can't update zp/scale
+    # Case for MoEs
+    if value.numel() == 0:
+        return
+    # calibration mode - get new quant params from observer
+    if not hasattr(module, f"{base_name}_observer"):
+        from compressed_tensors.quantization.lifecycle import initialize_observers
+
+        initialize_observers(
+            module=module, base_name=base_name, quantization_args=quantization_args
+        )
+
+    observer = getattr(module, f"{base_name}_observer")
+
+    updated_scale, updated_zero_point = observer(value)
+
+    # update scale and zero point
+    update_parameter_data(module, updated_scale, f"{base_name}_scale")
+    update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
+
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value
 
     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -386,22 +433,6 @@ def maybe_calibrate_or_quantize(
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-        if (
-            module.quantization_status == QuantizationStatus.CALIBRATION
-            and base_name != "weight"
-        ):
-            # calibration mode - get new quant params from observer
-            observer = getattr(module, f"{base_name}_observer")
-
-            updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-            # update scale and zero point
-            update_parameter_data(module, updated_scale, f"{base_name}_scale")
-            update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-            scale = updated_scale
-            zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py
index 66356cb7..e4723431 100644
--- a/src/compressed_tensors/quantization/lifecycle/frozen.py
+++ b/src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -41,15 +41,11 @@ def freeze_module_quantization(module: Module):
         return
 
     # delete observers from module if not dynamic
-    if scheme.input_activations and not scheme.input_activations.dynamic:
+    if hasattr(module, "input_observer") and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
-    if scheme.weights and not scheme.weights.dynamic:
+    if hasattr(module, "weight_observer") and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if (
-        scheme.output_activations
-        and not is_kv_cache_quant_scheme(scheme)
-        and not scheme.output_activations.dynamic
-    ):
+    if hasattr(module, "output_observer") and not scheme.output_activations.dynamic:
         delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 9b98da33..68157cb1 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -34,9 +34,7 @@
 from torch.nn import Module, Parameter
 
 
-__all__ = [
-    "initialize_module_for_quantization",
-]
+__all__ = ["initialize_module_for_quantization", "initialize_observers"]
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -74,7 +72,7 @@ def initialize_module_for_quantization(
     else:
 
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +83,7 @@ def initialize_module_for_quantization(
                 weight_shape = None
                 if isinstance(module, torch.nn.Linear):
                     weight_shape = module.weight.shape
-                _initialize_scale_zero_point_observer(
+                _initialize_scale_zero_point(
                     module,
                     "weight",
                     scheme.weights,
@@ -101,7 +99,7 @@ def initialize_module_for_quantization(
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point_observer(
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )
 
@@ -146,21 +144,23 @@ def initialize_module_for_quantization(
                 module._hf_hook.weights_map = new_prefix_dict
 
 
-def _initialize_scale_zero_point_observer(
+def initialize_observers(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
-    force_zero_point: bool = True,
 ):
-
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    # no need to register an observer for dynamic quantization
-    if observer:
-        module.register_module(f"{base_name}_observer", observer)
+    module.register_module(f"{base_name}_observer", observer)
 
-    # no need to register a scale and zero point for a dynamic quantization
+
+def _initialize_scale_zero_point(
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
+):
     if quantization_args.dynamic:
         return
 
diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
index 45d49370..1f88626e 100644
--- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
@@ -23,6 +23,7 @@
 
 
 def test_apply_tinyllama_dynamic_activations():
+    # NOTE: should not calibrate dynamic quant
     quant_config = get_sample_dynamic_tinyllama_quant_config()
     model = get_tinyllama_model()
 
diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py
index b9ee67ff..0730c991 100644
--- a/tests/test_quantization/lifecycle/test_forward.py
+++ b/tests/test_quantization/lifecycle/test_forward.py
@@ -15,13 +15,18 @@
 
 import pytest
 import torch
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.forward import (
+    calibrate_activations,
     dequantize,
-    maybe_calibrate_or_quantize,
+    forward_quantize,
     quantize,
     wrap_module_forward_quantized,
     wrap_module_forward_quantized_attn,
 )
+from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -58,7 +63,7 @@ def test_wrap_module_forward_quantized(create_quantization_scheme):
 @pytest.mark.parametrize(
     "quantization_status", ["initialized", "calibration", "frozen"]
 )
-def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status):
+def test_forward_quantize(create_quantization_scheme, quantization_status):
     num_bits = 8
     quantization_scheme = create_quantization_scheme(
         targets=["*"],
@@ -72,26 +77,41 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_st
     dummy_tensor = torch.randn(8, 4)  # (num_tokens, num_features)
     layer.quantization_status = QuantizationStatus(quantization_status)
 
-    initialize_module_for_quantization(layer, quantization_scheme)
-
     # only calibration updates the scale and zero-point
     if layer.quantization_status == QuantizationStatus.INITIALIZED:
-        out = maybe_calibrate_or_quantize(
-            layer, dummy_tensor, "input", quantization_args
-        )
-        assert torch.allclose(out, dummy_tensor)
+        # Init zp and scales
+        initialize_module_for_quantization(layer, quantization_scheme)
+        # init weight observers; update weight scales/zp
+        set_module_for_calibration(layer)
+        # call quant/dequant on weights
+        out = forward_quantize(layer, layer.weight, "weight", quantization_args)
+        assert torch.allclose(out, layer.weight.data, atol=0.2)
     elif layer.quantization_status == QuantizationStatus.CALIBRATION:
-        out = maybe_calibrate_or_quantize(
-            layer, dummy_tensor, "input", quantization_args
+        # init zp/scales
+        initialize_module_for_quantization(layer, quantization_scheme)
+        #  init weight observers; update weight scales/zp
+        set_module_for_calibration(layer)
+        #  init input observers, update input scales/zp
+        calibrate_activations(
+            module=layer,
+            value=dummy_tensor,
+            base_name="input",
+            quantization_args=quantization_args,
         )
+        # call quant/dequant on inputs
+        out = forward_quantize(layer, dummy_tensor, "input", quantization_args)
         assert torch.allclose(out, dummy_tensor, atol=0.2)
         assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0]
     elif layer.quantization_status == QuantizationStatus.FROZEN:
-        # scale and zero points are empty -- cannot quantize
-        with pytest.raises(Exception):
-            out = maybe_calibrate_or_quantize(
-                layer, layer.weight.data, "input", quantization_args
-            )
+        # init zp and scales
+        initialize_module_for_quantization(layer, quantization_scheme)
+        # init weight observers; update weight scales/zp
+        set_module_for_calibration(layer)
+        # remove weight observers and any input observers
+        freeze_module_quantization(layer)
+        # call quant/dequant on weights
+        out = forward_quantize(layer, layer.weight.data, "weight", quantization_args)
+        assert torch.allclose(out, layer.weight.data, atol=0.2)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_quantization/lifecycle/test_frozen.py b/tests/test_quantization/lifecycle/test_frozen.py
index 056c6089..dddff117 100644
--- a/tests/test_quantization/lifecycle/test_frozen.py
+++ b/tests/test_quantization/lifecycle/test_frozen.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
@@ -34,9 +37,9 @@ def test_set_module_for_calibration(create_quantization_scheme):
 
     initialize_module_for_quantization(layer, quantization_scheme)
     layer.quantization_status = QuantizationStatus("calibration")
+    set_module_for_calibration(layer)
 
     # should have both input and weight observer after initalizing
-    assert hasattr(layer, "input_observer")
     assert hasattr(layer, "weight_observer")
 
     # observers should get deleted after freezing
diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py
index 970e4de9..41c56789 100644
--- a/tests/test_quantization/lifecycle/test_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_lifecycle.py
@@ -59,15 +59,12 @@ def test_lifecyle(create_quantization_scheme):
         expected_layer_keys.remove(key)
     assert len(expected_layer_keys) == 0
 
-    # should have both input and weight observer after initalizing
-    assert hasattr(layer, "input_observer")
-    assert hasattr(layer, "weight_observer")
-
     assert hasattr(layer, "quantization_scheme")
     assert hasattr(layer, "quantization_status")
     assert layer.quantization_status == QuantizationStatus.INITIALIZED
 
     set_module_for_calibration(layer)
+    assert hasattr(layer, "weight_observer")
     assert layer.quantization_status == QuantizationStatus.CALIBRATION
 
     # do a calibration step

From 232e4944b84798bd05fddc18a7752ae2b5d460da Mon Sep 17 00:00:00 2001
From: Alexandre Marques <alexandre@neuralmagic.com>
Date: Fri, 18 Oct 2024 17:00:11 -0400
Subject: [PATCH 08/12] Fix device allocation for MSE observer (#190)

* make sure all tensors are in the same device

* fix initial assignment

* Fix testing
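
The gist of the fix as a standalone illustration (tensor contents are
arbitrary; only the allocation pattern matters):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    observed = torch.randn(16, 4, device=device)
    absolute_min_val = torch.amin(observed, dim=1, keepdim=True)
    absolute_max_val = torch.amax(observed, dim=1, keepdim=True)

    # *_like allocations inherit device and dtype from the observed tensor,
    # so the search buffers no longer default to CPU float32
    best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
    min_val = torch.ones_like(absolute_min_val)
    max_val = torch.zeros_like(absolute_max_val)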
---
 src/compressed_tensors/quantization/observers/mse.py | 6 +++---
 tests/test_quantization/test_observers/test_mse.py   | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py
index ff122a78..24d80584 100644
--- a/src/compressed_tensors/quantization/observers/mse.py
+++ b/src/compressed_tensors/quantization/observers/mse.py
@@ -70,9 +70,9 @@ def calculate_mse_min_max(
             absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
             absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-        best = torch.full(absolute_min_val.shape, float("inf"))
-        min_val = torch.ones(absolute_min_val.shape)
-        max_val = torch.zeros(absolute_max_val.shape)
+        best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
+        min_val = torch.ones_like(absolute_min_val)
+        max_val = torch.zeros_like(absolute_max_val)
         for i in range(int(self.maxshrink * self.grid)):
             p = 1 - i / self.grid
             shrinked_min_val = p * absolute_min_val
diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py
index 76a317f4..659fc573 100644
--- a/tests/test_quantization/test_observers/test_mse.py
+++ b/tests/test_quantization/test_observers/test_mse.py
@@ -27,7 +27,7 @@
     ],
 )
 def test_mse_observer(symmetric, expected_scale, expected_zero_point):
-    tensor = torch.tensor([1, 1, 1, 1, 1])
+    tensor = torch.tensor([1., 1., 1., 1., 1.])
     num_bits = 8
     weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")
 

From d3216bc8d27e90da19b9c05913895fd83c0ac4a3 Mon Sep 17 00:00:00 2001
From: dhuangnm <74931910+dhuangnm@users.noreply.github.com>
Date: Wed, 23 Oct 2024 10:38:28 -0400
Subject: [PATCH 09/12] drop 3.8 and add 3.12 to testing (#196)

Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local>
---
 .github/workflows/trigger-all.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/trigger-all.yml b/.github/workflows/trigger-all.yml
index 02427999..1e53ff30 100644
--- a/.github/workflows/trigger-all.yml
+++ b/.github/workflows/trigger-all.yml
@@ -35,6 +35,6 @@ jobs:
             test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"},
                             {"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"},
                             {"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"},
-                            {"python":"3.8.17","label":"k8s-a100-duo","timeout":"40"}]'
+                            {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'
 
         secrets: inherit

From d3dea3ffd3f81d9840e99dddf242bdde6f22d6af Mon Sep 17 00:00:00 2001
From: Kyle Sayers <kylesayrs@gmail.com>
Date: Wed, 23 Oct 2024 15:18:32 -0400
Subject: [PATCH 10/12] fix test which required accelerate, apply style (#194)

---
 src/compressed_tensors/quantization/lifecycle/frozen.py  | 1 -
 src/compressed_tensors/quantization/observers/helpers.py | 2 +-
 src/compressed_tensors/quantization/observers/mse.py     | 4 +++-
 tests/test_quantization/lifecycle/test_apply.py          | 8 +-------
 tests/test_quantization/test_observers/test_mse.py       | 2 +-
 5 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py
index e4723431..4a65482c 100644
--- a/src/compressed_tensors/quantization/lifecycle/frozen.py
+++ b/src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -14,7 +14,6 @@
 
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module
 
 
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 875a05b3..ec474303 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py
index 24d80584..739e921f 100644
--- a/src/compressed_tensors/quantization/observers/mse.py
+++ b/src/compressed_tensors/quantization/observers/mse.py
@@ -70,7 +70,9 @@ def calculate_mse_min_max(
             absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
             absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-        best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
+        best = torch.full_like(
+            absolute_min_val, torch.finfo(absolute_min_val.dtype).max
+        )
         min_val = torch.ones_like(absolute_min_val)
         max_val = torch.zeros_like(absolute_max_val)
         for i in range(int(self.maxshrink * self.grid)):
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 5f0bd093..dcb980f2 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -236,14 +236,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
 def test_apply_quantization_status(caplog, ignore, should_raise_warning):
     import logging
 
-    from transformers import AutoModelForCausalLM
-
     # load a dense, unquantized tiny llama model
-    model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, device_map="cpu", torch_dtype="auto"
-    )
-
+    model = get_tinyllama_model()
     quantization_config_dict = {
         "quant_method": "sparseml",
         "format": "pack-quantized",
diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py
index 659fc573..098551a2 100644
--- a/tests/test_quantization/test_observers/test_mse.py
+++ b/tests/test_quantization/test_observers/test_mse.py
@@ -27,7 +27,7 @@
     ],
 )
 def test_mse_observer(symmetric, expected_scale, expected_zero_point):
-    tensor = torch.tensor([1., 1., 1., 1., 1.])
+    tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
     num_bits = 8
     weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")
 

From 07abbf3b9d9a7ec1f497a192241030181c480d66 Mon Sep 17 00:00:00 2001
From: Kyle Sayers <kylesayrs@gmail.com>
Date: Wed, 23 Oct 2024 15:25:05 -0400
Subject: [PATCH 11/12] [Bugfix] Move observer and g_idx until after module is
 onloaded (#195)

---
 .../quantization/lifecycle/calibration.py                   | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py
index 66dc35a0..c9e51813 100644
--- a/src/compressed_tensors/quantization/lifecycle/calibration.py
+++ b/src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -64,14 +64,12 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
                 quantization_args=module.quantization_scheme.weights,
             )
 
-        observer = module.weight_observer
-
-        g_idx = getattr(module, "weight_g_idx", None)
-
         offloaded = is_module_offloaded(module)
         if offloaded:
             module._hf_hook.pre_forward(module)
 
+        observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
         scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
         update_parameter_data(module, zero_point, "weight_zero_point")

From 13b5c0ba049163d6f7d310867c4ea283a116bcce Mon Sep 17 00:00:00 2001
From: Rahul Tuli <rahul@neuralmagic.com>
Date: Thu, 24 Oct 2024 08:29:24 -0400
Subject: [PATCH 12/12] Add sparsity structure enum (#197)

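A quick usage sketch, mirroring the doctest added in config/base.py:

    from compressed_tensors.config import SparsityStructure

    assert SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
    assert SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
    assert SparsityStructure(None) == SparsityStructure.UNSTRUCTURED  # None -> unstructured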
---
 src/compressed_tensors/config/base.py | 62 ++++++++++++++++++++++++++-
 tests/test_configs/__init__.py        | 13 ++++++
 tests/test_configs/test_base.py       | 57 ++++++++++++++++++++++++
 3 files changed, 130 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_configs/__init__.py
 create mode 100644 tests/test_configs/test_base.py

diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index ccc3e649..79a4fcdd 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
+from enum import Enum, unique
 from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
 
 
+@unique
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
 
 
+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
diff --git a/tests/test_configs/__init__.py b/tests/test_configs/__init__.py
new file mode 100644
index 00000000..0c44f887
--- /dev/null
+++ b/tests/test_configs/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/test_configs/test_base.py b/tests/test_configs/test_base.py
new file mode 100644
index 00000000..5334ef38
--- /dev/null
+++ b/tests/test_configs/test_base.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from compressed_tensors.config import SparsityStructure
+
+
+def test_sparsity_structure_valid_cases():
+    assert (
+        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4' with TWO_FOUR"
+    assert (
+        SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'unstructured' with UNSTRUCTURED"
+    assert (
+        SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED"
+    assert (
+        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match None with UNSTRUCTURED"
+
+
+def test_sparsity_structure_invalid_case():
+    with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"):
+        SparsityStructure("invalid")
+
+
+def test_sparsity_structure_case_insensitivity():
+    assert (
+        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4' with TWO_FOUR"
+    assert (
+        SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4'.upper() with TWO_FOUR"
+    assert (
+        SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'unstructured'.upper() with UNSTRUCTURED"
+    assert (
+        SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED"
+
+
+def test_sparsity_structure_default_case():
+    assert (
+        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match None with UNSTRUCTURED"