Skip to content

Commit

Permalink
Add sparsity structure enum (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli authored Oct 24, 2024
1 parent 07abbf3 commit 13b5c0b
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 2 deletions.
62 changes: 60 additions & 2 deletions src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from enum import Enum, unique
from typing import List, Optional

from compressed_tensors.registry import RegistryMixin
from pydantic import BaseModel


__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]


@unique
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
Expand All @@ -32,6 +33,63 @@ class CompressionFormat(Enum):
marlin_24 = "marlin-24"


@unique
class SparsityStructure(Enum):
"""
An enumeration to represent different sparsity structures.
Attributes
----------
TWO_FOUR : str
Represents a 2:4 sparsity structure.
ZERO_ZERO : str
Represents a 0:0 sparsity structure.
UNSTRUCTURED : str
Represents an unstructured sparsity structure.
Examples
--------
>>> SparsityStructure('2:4')
<SparsityStructure.TWO_FOUR: '2:4'>
>>> SparsityStructure('unstructured')
<SparsityStructure.UNSTRUCTURED: 'unstructured'>
>>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
True
>>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
True
>>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
True
>>> SparsityStructure('invalid')
Traceback (most recent call last):
...
ValueError: invalid is not a valid SparsityStructure
"""

TWO_FOUR = "2:4"
UNSTRUCTURED = "unstructured"
ZERO_ZERO = "0:0"

def __new__(cls, value):
obj = object.__new__(cls)
obj._value_ = value.lower() if value is not None else value
return obj

@classmethod
def _missing_(cls, value):
# Handle None and case-insensitive values
if value is None:
return cls.UNSTRUCTURED
for member in cls:
if member.value == value.lower():
return member
raise ValueError(f"{value} is not a valid {cls.__name__}")


class SparsityCompressionConfig(RegistryMixin, BaseModel):
"""
Base data class for storing sparsity compression parameters
Expand Down
13 changes: 13 additions & 0 deletions tests/test_configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
57 changes: 57 additions & 0 deletions tests/test_configs/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from compressed_tensors.config import SparsityStructure


def test_sparsity_structure_valid_cases():
assert (
SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
), "Failed to match '2:4' with TWO_FOUR"
assert (
SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED
), "Failed to match 'unstructured' with UNSTRUCTURED"
assert (
SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED"
assert (
SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
), "Failed to match None with UNSTRUCTURED"


def test_sparsity_structure_invalid_case():
with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"):
SparsityStructure("invalid")


def test_sparsity_structure_case_insensitivity():
assert (
SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
), "Failed to match '2:4' with TWO_FOUR"
assert (
SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR
), "Failed to match '2:4'.upper() with TWO_FOUR"
assert (
SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED
), "Failed to match 'unstructured'.upper() with UNSTRUCTURED"
assert (
SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED
), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED"


def test_sparsity_structure_default_case():
assert (
SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
), "Failed to match None with UNSTRUCTURED"

0 comments on commit 13b5c0b

Please sign in to comment.