Lifecycle tests #6

Closed
wants to merge 17 commits
28 changes: 15 additions & 13 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -21,7 +21,13 @@
from torch.nn import Module


__all__ = ["wrap_module_forward_quantized"]
__all__ = [
"wrap_module_forward_quantized",
"quantize",
"dequantize",
"fake_quantize",
"maybe_calibrate_or_quantize",
]


def quantize(
@@ -67,16 +73,15 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@wraps(forward_func_orig) # ensures docstring, names, etc are propagated
def wrapped_forward(self, *args, **kwargs):
input_ = args[0]

if scheme.input_activations is not None:
# calibrate and (fake) quantize input activations when applicable
input_ = _maybe_calibrate_or_quantize(
input_ = maybe_calibrate_or_quantize(
module, input_, "input", scheme.input_activations
)

if scheme.weights is not None:
# calibrate and (fake) quantize weights when applicable
self.weight.data = _maybe_calibrate_or_quantize(
self.weight.data = maybe_calibrate_or_quantize(
module, self.weight, "weight", scheme.weights
)

@@ -87,10 +92,9 @@ def wrapped_forward(self, *args, **kwargs):

if scheme.output_activations is not None:
# calibrate and (fake) quantize output activations when applicable
output = _maybe_calibrate_or_quantize(
output = maybe_calibrate_or_quantize(
module, output, "output", scheme.output_activations
)

return output

# bind wrapped forward to module class so reference to `self` is correct
@@ -99,22 +103,17 @@ def wrapped_forward(self, *args, **kwargs):
setattr(module, "forward", bound_wrapped_forward)


def _maybe_calibrate_or_quantize(
def maybe_calibrate_or_quantize(
module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
# only run quantized for the included stages
if module.quantization_status not in {
QuantizationStatus.CALIBRATION,
QuantizationStatus.FROZEN,
}:
if module.quantization_status == QuantizationStatus.INITIALIZED:
return value

scale = getattr(module, f"{base_name}_scale")
# zero_point = getattr(module, f"{base_name}_zero_point").data
zero_point = getattr(module, f"{base_name}_zero_point")

print(scale, zero_point)

if module.quantization_status == QuantizationStatus.CALIBRATION:
# get observer and get new quant params from observation
observer = getattr(module, f"{base_name}_observer")
@@ -124,4 +123,7 @@ def _maybe_calibrate_or_quantize(
scale.data = updated_scale
zero_point.data = updated_zero_point

if scale.data.numel() < 1 and zero_point.data.numel() < 1:
raise ValueError("scale and zero_points are empty.")

return fake_quantize(value, scale, zero_point, args)
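
To make the reworked lifecycle concrete, here is a minimal sketch that chains the three stages the way the new tests do (imports, signatures, and expected behavior are taken from the test files in this PR; the sketch itself is illustrative, not part of the diff):

import torch
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.forward import maybe_calibrate_or_quantize
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear

scheme = QuantizationScheme(
    targets=["*"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
)
layer = Linear(4, 4)

# INITIALIZED: empty scale/zero_point params and observers are attached,
# and maybe_calibrate_or_quantize passes values through unchanged
initialize_module_for_quantization(layer, scheme)
x = torch.randn(2, 4)
assert torch.allclose(
    maybe_calibrate_or_quantize(layer, x, "input", scheme.input_activations), x
)

# CALIBRATION: the observer updates scale/zero_point, then fake-quantizes
set_module_for_calibration(layer)
x_q = maybe_calibrate_or_quantize(layer, x, "input", scheme.input_activations)

# FROZEN: observers are deleted; the calibrated scale/zero_point keep being used
freeze_module_quantization(layer)
x_q2 = maybe_calibrate_or_quantize(layer, x, "input", scheme.input_activations)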
6 changes: 5 additions & 1 deletion src/sparsetensors/quantization/lifecycle/frozen.py
@@ -28,9 +28,13 @@ def freeze_module_quantization(module: Module):
return

# delete observers from module
submodule_names_to_delete = set()
for submodule_name, _ in module.named_modules():
if "." not in submodule_name and submodule_name.endswith("_observer"):
# delete any observers that belong directly to this module
delattr(module, submodule_name)
submodule_names_to_delete.add(submodule_name)

for submodule_name in submodule_names_to_delete:
delattr(module, submodule_name)

module.quantization_status = QuantizationStatus.FROZEN
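
The collect-then-delete rewrite above matters because `named_modules()` walks the module's internal `_modules` dict, and `delattr` on an observer (assuming observers are registered as submodules) removes an entry from that dict mid-iteration. A plain-dict sketch of the failure mode being avoided:

# deleting from a dict while iterating over it raises RuntimeError;
# module._modules behaves the same way under named_modules()
observers = {"input_observer": object(), "weight_observer": object()}
try:
    for name in observers:
        del observers[name]
except RuntimeError as err:
    print(err)  # dictionary changed size during iteration

# collect first, then delete -- the pattern used in freeze_module_quantization
to_delete = [name for name in observers if name.endswith("_observer")]
for name in to_delete:
    del observers[name]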
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
@@ -33,8 +33,8 @@

def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
if scheme.input_activations is not None:

_initialize_scale_zero_point_observer(module, "input", scheme.input_activations)

if scheme.weights is not None:
if hasattr(module, "weight"):
_initialize_scale_zero_point_observer(module, "weight", scheme.weights)
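
For reference, a rough sketch of what `_initialize_scale_zero_point_observer` presumably attaches for each `base_name`, inferred from the new tests below and the `numel() < 1` guard added in forward.py; the empty initial tensors, the dtype choices, and the `make_observer` helper are assumptions for illustration, not the library's actual code:

import torch
from torch.nn import Module, Parameter

def make_observer(args: "QuantizationArgs") -> Module:
    # hypothetical stand-in for however the library builds its observers
    return Module()

def _initialize_scale_zero_point_observer_sketch(
    module: Module, base_name: str, args: "QuantizationArgs"
) -> None:
    # empty placeholder params; calibration later fills them in, and
    # forward.py raises ValueError if they are still empty when FROZEN
    module.register_parameter(
        f"{base_name}_scale", Parameter(torch.empty(0), requires_grad=False)
    )
    module.register_parameter(
        f"{base_name}_zero_point", Parameter(torch.empty(0), requires_grad=False)
    )
    # attach an observer submodule, e.g. "input_observer" or "weight_observer"
    setattr(module, f"{base_name}_observer", make_observer(args))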
37 changes: 37 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/conftest.py
@@ -0,0 +1,37 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import pytest
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme


@pytest.fixture
def create_quantization_scheme():
def quantization_scheme(
targets: List[str],
weights: Optional[QuantizationArgs] = None,
input_activations: Optional[QuantizationArgs] = None,
output_activations: Optional[QuantizationArgs] = None,
):
return QuantizationScheme(
targets=targets,
weights=weights,
input_activations=input_activations,
output_activations=output_activations,
)

return quantization_scheme
41 changes: 41 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_calibration.py
@@ -0,0 +1,41 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch.nn import Linear


@pytest.mark.parametrize("quantization_status", ["INITIALIZED", "CALIBRATION"])
def test_set_module_for_calibration(create_quantization_scheme, quantization_status):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
)

layer = Linear(4, 4)
layer.quantization_status = QuantizationStatus(quantization_status)
layer.quantization_scheme = quantization_scheme

if layer.quantization_status == QuantizationStatus.INITIALIZED:
set_module_for_calibration(layer)
assert layer.quantization_status == QuantizationStatus.CALIBRATION
else:
with pytest.raises(TypeError):
set_module_for_calibration(layer)
80 changes: 80 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_forward.py
@@ -0,0 +1,80 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch
from sparsetensors.quantization.lifecycle.forward import (
maybe_calibrate_or_quantize,
wrap_module_forward_quantized,
)
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch.nn import Linear


def test_wrap_module_forward_quantized(create_quantization_scheme):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
)
layer = Linear(4, 4)

func_forward = layer.forward.__func__

# check that the forward call is overwritten
wrap_module_forward_quantized(layer, quantization_scheme)

assert not func_forward == layer.forward.__func__


@pytest.mark.parametrize(
"quantization_status", ["INITIALIZED", "CALIBRATION", "FROZEN"]
)
def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
)
quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=False)
layer = Linear(4, 4)
layer.weight.data *= 100

initialize_module_for_quantization(layer, quantization_scheme)
layer.quantization_status = QuantizationStatus(quantization_status)

if layer.quantization_status == QuantizationStatus.INITIALIZED:
out = maybe_calibrate_or_quantize(
layer, layer.weight.data, "input", quantization_args
)
assert torch.allclose(out, layer.weight.data)
elif layer.quantization_status == QuantizationStatus.CALIBRATION:
out = maybe_calibrate_or_quantize(
layer, layer.weight.data, "input", quantization_args
)
assert not torch.allclose(out, layer.weight.data)

elif layer.quantization_status == QuantizationStatus.FROZEN:
# scale and zero points are empty -- cannot quantize
with pytest.raises(ValueError):
out = maybe_calibrate_or_quantize(
layer, layer.weight.data, "input", quantization_args
)
47 changes: 47 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_frozen.py
@@ -0,0 +1,47 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch.nn import Linear


def test_freeze_module_quantization(create_quantization_scheme):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
)

layer = Linear(4, 4)

initialize_module_for_quantization(layer, quantization_scheme)
layer.quantization_status = QuantizationStatus("CALIBRATION")

# should have both input and weight observers after initializing
assert hasattr(layer, "input_observer")
assert hasattr(layer, "weight_observer")

# observers should get deleted after freezing
freeze_module_quantization(layer)
assert not hasattr(layer, "input_observer")
assert not hasattr(layer, "weight_observer")

assert layer.quantization_status == QuantizationStatus("FROZEN")
54 changes: 54 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_initialize.py
@@ -0,0 +1,54 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from torch.nn import Linear


def test_initialize_module_for_quantization(create_quantization_scheme):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
)
layer = Linear(4, 4)

assert not hasattr(layer, "quantization_scheme")
assert not hasattr(layer, "quantization_status")

# add attributes, zero_points and scale
initialize_module_for_quantization(layer, quantization_scheme)

expected_registered_params = {
"input_scale",
"input_zero_point",
"weight_scale",
"weight_zero_point",
}
actual_registered_params = set(layer.state_dict().keys())

assert expected_registered_params.issubset(actual_registered_params)

assert hasattr(layer, "quantization_scheme")
assert hasattr(layer, "quantization_status")

assert layer.quantization_status == QuantizationStatus.INITIALIZED