Skip to content

Commit

Permalink
Add SparsityStructure Enum
Browse files Browse the repository at this point in the history
Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
  • Loading branch information
rahul-tuli committed Oct 22, 2024
1 parent 4baf16d commit 43cb1d7
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
iter_named_leaf_modules,
)

from llmcompressor.transformers.compression.sparsity_config import SparsityStructure

__all__ = ["infer_quantization_format"]


Expand Down Expand Up @@ -35,7 +37,10 @@ def infer_quantization_format(

if save_compressed:
weight_args, input_args = _get_unique_quant_args(model)
is_24_structure = sparsity_structure is not None and sparsity_structure == "2:4"
is_24_structure = (
SparsityStructure(sparsity_structure).value
== SparsityStructure.TWO_FOUR.value
)
is_weight_only = len(input_args) == 0 and len(weight_args) > 0

if is_weight_only: # w4a16 and w8a16
Expand Down
61 changes: 58 additions & 3 deletions src/llmcompressor/transformers/compression/sparsity_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum, unique
from typing import Dict, Optional

from compressed_tensors import CompressionFormat, SparsityCompressionConfig
Expand All @@ -13,13 +14,67 @@
)


@unique
class SparsityStructure(Enum):
    """
    An enumeration to represent different sparsity structures.

    Lookup by value is case-insensitive for strings, and ``None`` resolves
    to ``UNSTRUCTURED``. Any other value (an unknown string or a non-string)
    raises ``ValueError``.

    Attributes
    ----------
    TWO_FOUR : str
        Represents a 2:4 sparsity structure.
    UNSTRUCTURED : str
        Represents an unstructured sparsity structure.

    Examples
    --------
    >>> SparsityStructure('2:4')
    <SparsityStructure.TWO_FOUR: '2:4'>
    >>> SparsityStructure('unstructured')
    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
    True
    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
    True
    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
    True
    >>> SparsityStructure('invalid')
    Traceback (most recent call last):
        ...
    ValueError: invalid is not a valid SparsityStructure
    """

    TWO_FOUR = "2:4"
    UNSTRUCTURED = "unstructured"

    def __new__(cls, value):
        # Invoked once per member while the class is being built. Member
        # values are stored lowercased so that ``_missing_`` can match
        # lookups after lowercasing the requested value.
        obj = object.__new__(cls)
        obj._value_ = value.lower() if isinstance(value, str) else value
        return obj

    @classmethod
    def _missing_(cls, value):
        """
        Resolve a lookup value that did not match any member exactly.

        :param value: the value passed to ``SparsityStructure(value)``
        :return: ``UNSTRUCTURED`` for ``None``, the member whose value equals
            the lowercased string, or ``None`` for non-string input so that
            ``Enum`` raises its standard ``ValueError``
        :raises ValueError: if ``value`` is a string matching no member
        """
        if value is None:
            return cls.UNSTRUCTURED
        if not isinstance(value, str):
            # Only strings can be case-normalized. Returning None hands the
            # failure back to Enum, which raises ValueError, instead of
            # leaking an AttributeError from ``value.lower()``.
            return None

        normalized = value.lower()
        for member in cls:
            if member.value == normalized:
                return member
        raise ValueError(f"{value} is not a valid {cls.__name__}")


class SparsityConfigMetadata:
"""
Class of helper functions for filling out a SparsityCompressionConfig with readable
metadata from the model
"""

SPARSITY_THRESHOLD: float = 0.4
SPARSITY_THRESHOLD: float = 0.5

@staticmethod
def infer_global_sparsity(
Expand Down Expand Up @@ -66,7 +121,7 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
if model and sparsity_structure is None:
sparsity_structure = infer_sparsity_structure_from_model(model)

return sparsity_structure or "unstructured"
return SparsityStructure(sparsity_structure).value

@staticmethod
def from_pretrained(
Expand Down Expand Up @@ -101,7 +156,7 @@ def from_pretrained(
# compression
format = CompressionFormat.dense.value
if compress:
if sparsity_structure == "2:4":
if sparsity_structure == SparsityStructure.TWO_FOUR.value:
format = CompressionFormat.sparse_24.value
else:
format = CompressionFormat.sparse_bitmask.value
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from llmcompressor.transformers.compression.sparsity_config import SparsityStructure


def test_sparsity_structure_valid_cases():
    """Each accepted input resolves to the expected SparsityStructure member."""
    # (input value, expected member, message shown on failure)
    expectations = (
        ("2:4", SparsityStructure.TWO_FOUR, "Failed to match '2:4' with TWO_FOUR"),
        (
            "unstructured",
            SparsityStructure.UNSTRUCTURED,
            "Failed to match 'unstructured' with UNSTRUCTURED",
        ),
        (
            "UNSTRUCTURED",
            SparsityStructure.UNSTRUCTURED,
            "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED",
        ),
        (
            None,
            SparsityStructure.UNSTRUCTURED,
            "Failed to match None with UNSTRUCTURED",
        ),
    )
    for raw_value, expected_member, failure_message in expectations:
        assert SparsityStructure(raw_value) == expected_member, failure_message


def test_sparsity_structure_invalid_case():
    """An unrecognized string is rejected with a descriptive ValueError."""
    expected_message = "invalid is not a valid SparsityStructure"
    unknown_value = "invalid"
    with pytest.raises(ValueError, match=expected_message):
        SparsityStructure(unknown_value)


def test_sparsity_structure_case_insensitivity():
    """Lookups succeed regardless of the letter case of the input string."""
    two_four = SparsityStructure.TWO_FOUR
    unstructured = SparsityStructure.UNSTRUCTURED

    # NOTE: "2:4" contains no cased characters, so "2:4".upper() is still
    # "2:4"; the second entry exercises the same input as the first.
    scenarios = [
        ("2:4", two_four, "Failed to match '2:4' with TWO_FOUR"),
        ("2:4".upper(), two_four, "Failed to match '2:4'.upper() with TWO_FOUR"),
        (
            "unstructured".upper(),
            unstructured,
            "Failed to match 'unstructured'.upper() with UNSTRUCTURED",
        ),
        (
            "UNSTRUCTURED".lower(),
            unstructured,
            "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED",
        ),
    ]
    for candidate, expected_member, failure_message in scenarios:
        assert SparsityStructure(candidate) == expected_member, failure_message


def test_sparsity_structure_default_case():
    """A missing (None) structure falls back to UNSTRUCTURED."""
    resolved = SparsityStructure(None)
    failure_message = "Failed to match None with UNSTRUCTURED"
    assert resolved == SparsityStructure.UNSTRUCTURED, failure_message

0 comments on commit 43cb1d7

Please sign in to comment.