Skip to content

Commit

Permalink
Semi-structured 2:4 sparsity via SparseSemiStructuredTensor #4
Browse files Browse the repository at this point in the history
magic_wand semi_structured_sparse_tensor_linear branch integrates 2:4 semi-structured sparsity into SparseTensor. This PR adds a new sparsity config for 2:4 sparsity to neuralmagic-vllm, using the SparseTensor 2:4 support.

This PR also refactors the sparse linear method into a separate file, vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py, which supports all sparsity formats.
  • Loading branch information
afeldman-nm authored and robertgshaw2-redhat committed Feb 22, 2024
1 parent 5344a01 commit 81dba47
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 66 deletions.
12 changes: 12 additions & 0 deletions examples/offline_inference_semi_structured_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from vllm import LLM, SamplingParams

model = LLM("nm-testing/zephyr-50sparse-24",
sparsity="semi_structured_sparse_w16a16",
enforce_eager=True,
dtype="float16",
tensor_parallel_size=1,
max_model_len=1024)

sampling_params = SamplingParams(max_tokens=100, temperature=0)
outputs = model.generate("Hello my name is", sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode

def _verify_sparsity(self) -> None:
supported_sparsity = ["sparse_w16a16"]
supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]

if self.quantization is not None:
raise ValueError("Both sparsity and quantization detected. Only "
Expand Down
34 changes: 27 additions & 7 deletions vllm/model_executor/layers/parameters/sparsity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
import torch

from magic_wand import SparseTensor, SparseBitmaskStorageFormat
from typing import Type
from magic_wand import (SparseTensor, CompressedStorageFormat,
SparseBitmaskStorageFormat)


class SparseParameter(SparseTensor):

@staticmethod
def __new__(
cls,
shape: torch.Size,
dtype: torch.dtype,
):
def __new__(cls,
shape: torch.Size,
dtype: torch.dtype,
storage_format_cls: Type[
CompressedStorageFormat] = SparseBitmaskStorageFormat):
assert torch.__version__ > (1,
10), "SparseTensor requires PyTorch 1.11+"

self = torch.Tensor._make_wrapper_subclass(cls,
size=shape,
dtype=dtype,
requires_grad=False)
self.storage_format_cls = SparseBitmaskStorageFormat
self.storage_format_cls = storage_format_cls
self.compressed_data = None
self.dense_data = None
self._is_param = True

return self

def has_compressed_data(self) -> bool:
return (self.compressed_data is not None)

def get_dense_data(self) -> torch.Tensor:
if self.dense_data is not None:
raise ValueError(
Expand All @@ -39,6 +45,20 @@ def _unpack(self) -> torch.Tensor:
dtype=self.dtype,
device="cuda")

@classmethod
def _copy(cls, arg0, arg1):
assert arg0.shape == arg1.shape

if arg0.has_compressed_data():
arg0.compressed_data.copy_(arg1)
else:
arg0.compressed_data = arg0.storage_format_cls.compress(arg1)

return arg0

def copy_(self, src, non_blocking=False):
return SparseParameter._copy(self, src)

def pack(self) -> None:
if self.dense_data is None:
raise ValueError("Called pack() but dense_data does not exist.")
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/sparsity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config
from vllm.model_executor.layers.sparsity.semi_structured_sparse_w16a16 import SemiStructuredSparseW16A16Config

_SPARSITY_CONFIG_REGISTRY = {
"sparse_w16a16": SparseW16A16Config,
"semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
}


Expand Down
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/sparsity/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
from typing import Any, Dict, List

import torch
from typing import Type

from vllm.model_executor.layers.linear import LinearMethodBase
from magic_wand import CompressedStorageFormat


class SparsityConfig(ABC):
"""Base class for sparsity configs."""

@abstractmethod
def get_storage_format_cls(self) -> Type[CompressedStorageFormat]:
"""Sparse representation format"""
raise NotImplementedError

@abstractmethod
def get_name(self) -> str:
"""Name of the sparse method."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch

from typing import Any, Dict, List, Type
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (CompressedStorageFormat,
SparseSemiStructuredStorageFormat)


class SemiStructuredSparseW16A16Config(SparsityConfig):
"""Config class for SemiStructuredSparseW16A16."""

def __init__(self) -> None:
pass

def __repr__(self) -> str:
return "SemiStructuredSparseW16A16Config()"

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseSemiStructuredStorageFormat

@classmethod
def get_name(cls) -> str:
return "semi_structured_sparse_w16a16"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
# TODO: Update after checks on more GPUs
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["sparsity_config.json"]

@classmethod
def from_config(
cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
return cls()

def get_linear_method(self) -> "SparseW16A16LinearMethod":
return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
67 changes: 9 additions & 58 deletions vllm/model_executor/layers/sparsity/sparse_w16a16.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Type

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter

from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)


class SparseW16A16Config(SparsityConfig):
Expand All @@ -21,6 +21,10 @@ def __init__(self) -> None:
def __repr__(self) -> str:
return "SparseW16A16Config()"

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseBitmaskStorageFormat

@classmethod
def get_name(cls) -> str:
return "sparse_w16a16"
Expand All @@ -43,57 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config":
return cls()

def get_linear_method(self) -> "SparseW16A16LinearMethod":
return SparseW16A16LinearMethod(self)


class SparseW16A16LinearMethod(LinearMethodBase):
"""Linear method for Sparse W16A16.
Args:
sparsity_config: The sparse config.
"""

def __init__(self, sparsity_config: SparseW16A16Config):
self.sparsity_config = sparsity_config

def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
weight = SparseParameter(
shape=torch.Size(
(output_size_per_partition, input_size_per_partition)),
dtype=params_dtype,
)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

return {"weight": weight}

def apply_weights(
self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sparse_weight = weights["weight"]

# Uncompress to dense
dense_weight = sparse_weight.to_dense()

# # Uncomment to verify sparsity
# density = torch.count_nonzero(
# dense_weight).item() / dense_weight.numel()
# print(f"sparsity = {1.0 - density}")

# Standard matrix multiply
if bias is not None:
output = F.linear(x, dense_weight, bias)
else:
output = F.linear(x, dense_weight)

return output
return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
55 changes: 55 additions & 0 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any, Dict, Optional, Type

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter
from magic_wand import (CompressedStorageFormat,
SparseSemiStructuredStorageFormat)


class SparseW16A16LinearMethod(LinearMethodBase):
"""Linear method for Sparse W16A16.
Args:
sparsity_config: The sparse config.
"""
storage_format_cls: Type[CompressedStorageFormat] = None

def __init__(self, sparsity_config: SparsityConfig,
storage_format_cls: Type[CompressedStorageFormat]):
self.sparsity_config = sparsity_config
self.storage_format_cls = storage_format_cls

def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
weight = SparseParameter(shape=torch.Size(
(output_size_per_partition, input_size_per_partition)),
dtype=params_dtype,
storage_format_cls=self.storage_format_cls)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

return {"weight": weight}

def apply_weights(
self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sparse_weight = weights["weight"]

if self.storage_format_cls == SparseSemiStructuredStorageFormat:
output = F.linear(x, sparse_weight, bias)
return output
else:

# Standard matrix multiply
# Uncompress to dense
output = F.linear(x, sparse_weight.to_dense(), bias)
return output

0 comments on commit 81dba47

Please sign in to comment.