diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 6f7c2f53..ccc3e649 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -13,7 +13,7 @@ # limitations under the License. from enum import Enum -from typing import Optional +from typing import List, Optional from compressed_tensors.registry import RegistryMixin from pydantic import BaseModel @@ -37,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel): Base data class for storing sparsity compression parameters :param format: name of compression format + :param targets: List of layer names or layer types that are sparse and should + be considered during compression. By default, assume all layers are targeted + :param ignore: List of layer names (unique) to exclude from targets. Defaults to None :param global_sparsity: average sparsity of the entire model :param sparsity_structure: structure of the sparsity, such as "unstructured", "2:4", "8:16" etc """ format: str + targets: Optional[List[str]] = None + ignore: Optional[List[str]] = None global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured"