Skip to content

Commit

Permalink
test initialize and restructure folders
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Apr 12, 2024
1 parent e1be3be commit 557d119
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
if scheme.input_activations is not None:

_initialize_scale_zero_point_observer(module, "input", scheme.input_activations)

if scheme.weights is not None:
if hasattr(module, "weight"):
_initialize_scale_zero_point_observer(module, "weight", scheme.weights)
Expand Down
75 changes: 75 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import pytest
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear


@pytest.fixture(scope="module")
def create_quantization_scheme():
    """Factory fixture for building ``QuantizationScheme`` objects.

    Module-scoped so the factory callable is constructed once per test
    module; each invocation of the returned callable builds a fresh scheme.
    """

    def _build(
        targets: List[str],
        weights: Optional[QuantizationArgs] = None,
        input_activations: Optional[QuantizationArgs] = None,
        output_activations: Optional[QuantizationArgs] = None,
    ):
        # Forward all arguments straight through to the scheme constructor.
        scheme = QuantizationScheme(
            targets=targets,
            weights=weights,
            input_activations=input_activations,
            output_activations=output_activations,
        )
        return scheme

    return _build


def test_initialize_module_for_quantization(create_quantization_scheme):
    """Check that initialize_module_for_quantization attaches the scheme and
    status attributes to a layer and registers the expected scale and
    zero-point parameters for weights and input activations."""
    num_bits = 8
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    )
    layer = Linear(4, 4)

    # A fresh layer must not carry any quantization state yet.
    assert not hasattr(layer, "quantization_scheme")
    assert not hasattr(layer, "quantization_status")

    # add attributes, zero_points and scale
    initialize_module_for_quantization(layer, quantization_scheme)

    expected_registered_params = {
        "input_scale",
        "input_zero_point",
        "weight_scale",
        "weight_zero_point",
    }
    actual_registered_params = set(layer.state_dict().keys())

    # BUG FIX: the original called .issubset(...) without `assert`, so the
    # boolean result was silently discarded and the check never ran.
    assert expected_registered_params.issubset(actual_registered_params)

    assert hasattr(layer, "quantization_scheme")
    assert hasattr(layer, "quantization_status")

    assert layer.quantization_status == QuantizationStatus.INITIALIZED
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 557d119

Please sign in to comment.