diff --git a/Makefile b/Makefile
index c8600e281c9..bcfcf77d686 100644
--- a/Makefile
+++ b/Makefile
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
 	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
 endif
 ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
-	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
+	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
 endif
 ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
 	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
diff --git a/integrations/torchvision/modifiers_refactor_example/e2e_recipe.yaml b/integrations/torchvision/modifiers_refactor_example/e2e_recipe.yaml
new file mode 100644
index 00000000000..94b7f289f3a
--- /dev/null
+++ b/integrations/torchvision/modifiers_refactor_example/e2e_recipe.yaml
@@ -0,0 +1,75 @@
+test_stage:
+  quantization_modifiers:
+    QuantizationModifier:
+      start: eval(start_quant_epoch)
+      scheme:
+        input_activations:
+          num_bits: 8
+          symmetric: False
+        weights:
+          num_bits: 4
+          symmetric: True
+          strategy: "channel"
+      ignore: ['classifier']
+  pruning_modifiers:
+    MagnitudePruningModifier:
+      init_sparsity: 0.0
+      final_sparsity: 0.5
+      start: eval(warm_up_epochs)
+      end: eval(warm_up_epochs + pruning_epochs)
+      update_frequency: 0.5
+      targets:
+        - features.0.0.weight
+        - features.1.conv.0.0.weight
+        - features.1.conv.1.weight
+        - features.2.conv.0.0.weight
+        - features.2.conv.1.0.weight
+        - features.2.conv.2.weight
+        - features.3.conv.0.0.weight
+        - features.3.conv.1.0.weight
+        - features.3.conv.2.weight
+        - features.4.conv.0.0.weight
+        - features.4.conv.1.0.weight
+        - features.4.conv.2.weight
+        - features.5.conv.0.0.weight
+        - features.5.conv.1.0.weight
+        - features.5.conv.2.weight
+        - features.6.conv.0.0.weight
+        - features.6.conv.1.0.weight
+        - features.6.conv.2.weight
+        - features.7.conv.0.0.weight
+        - features.7.conv.1.0.weight
+        - features.7.conv.2.weight
+        - features.8.conv.0.0.weight
+        - features.8.conv.1.0.weight
+        - features.8.conv.2.weight
+        - features.9.conv.0.0.weight
+        - features.9.conv.1.0.weight
+        - features.9.conv.2.weight
+        - features.10.conv.0.0.weight
+        - features.10.conv.1.0.weight
+        - features.10.conv.2.weight
+        - features.11.conv.0.0.weight
+        - features.11.conv.1.0.weight
+        - features.11.conv.2.weight
+        - features.12.conv.0.0.weight
+        - features.12.conv.1.0.weight
+        - features.12.conv.2.weight
+        - features.13.conv.0.0.weight
+        - features.13.conv.1.0.weight
+        - features.13.conv.2.weight
+        - features.14.conv.0.0.weight
+        - features.14.conv.1.0.weight
+        - features.14.conv.2.weight
+        - features.15.conv.0.0.weight
+        - features.15.conv.1.0.weight
+        - features.15.conv.2.weight
+        - features.16.conv.0.0.weight
+        - features.16.conv.1.0.weight
+        - features.16.conv.2.weight
+        - features.17.conv.0.0.weight
+        - features.17.conv.1.0.weight
+        - features.17.conv.2.weight
+        - features.18.0.weight
+        - classifier.1.weight
+      leave_enabled: True
diff --git a/integrations/torchvision/modifiers_refactor_example/e2e_test.py b/integrations/torchvision/modifiers_refactor_example/e2e_test.py
new file mode 100644
index 00000000000..27fd697b50c
--- /dev/null
+++ b/integrations/torchvision/modifiers_refactor_example/e2e_test.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def main():
+    import os
+
+    import datasets
+    import torch
+    import torchvision
+    from torch.nn import CrossEntropyLoss
+    from torch.optim import Adam
+    from torch.utils.data import DataLoader
+    from torchvision import transforms
+
+    import sparseml.core.session as sml
+    from sparseml.core.event import EventType
+    from sparseml.core.framework import Framework
+    from sparseml.pytorch.utils import (
+        ModuleExporter,
+        get_prunable_layers,
+        tensor_sparsity,
+    )
+
+    NUM_LABELS = 3
+    BATCH_SIZE = 32
+    NUM_EPOCHS = 12
+    recipe = "e2e_recipe.yaml"
+    device = "cuda:0"
+
+    # set up SparseML session
+    sml.create_session()
+    session = sml.active_session()
+
+    # download model
+    model = torchvision.models.mobilenet_v2(
+        weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
+    )
+    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
+    model.to(device)
+
+    # download data
+    beans_dataset = datasets.load_dataset("beans")
+    train_folder, _ = os.path.split(beans_dataset["train"][0]["image_file_path"])
+    train_path, _ = os.path.split(train_folder)
+    val_folder, _ = os.path.split(beans_dataset["validation"][0]["image_file_path"])
+    val_path, _ = os.path.split(val_folder)
+
+    # dataloaders
+    imagenet_transform = transforms.Compose(
+        [
+            transforms.Resize(
+                size=256,
+                interpolation=transforms.InterpolationMode.BILINEAR,
+                max_size=None,
+                antialias=None,
+            ),
+            transforms.CenterCrop(size=(224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+
+    train_dataset = torchvision.datasets.ImageFolder(
+        root=train_path, transform=imagenet_transform
+    )
+    train_loader = DataLoader(
+        train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16
+    )
+
+    val_dataset = torchvision.datasets.ImageFolder(
+        root=val_path, transform=imagenet_transform
+    )
+    val_loader = DataLoader(
+        val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16
+    )
+
+    # loss and optimizer
+    criterion = CrossEntropyLoss()
+    optimizer = Adam(model.parameters(), lr=8e-3)
+
+    # initialize session
+    recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}
+    _ = session.initialize(
+        framework=Framework.pytorch,
+        recipe=recipe,
+        recipe_args=recipe_args,
+        model=model,
+        teacher_model=None,
+        optimizer=optimizer,
+        train_data=train_loader,
+        val_data=val_loader,
+        start=0.0,
+        steps_per_epoch=len(train_loader),
+    )
+
+    # loop through batches
+    for epoch in range(NUM_EPOCHS):
+        running_loss = 0.0
+        total_correct = 0
+        total_predictions = 0
+        for step, (inputs, labels) in enumerate(session.state.data.train):
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            session.state.optimizer.optimizer.zero_grad()
+            session.event(event_type=EventType.BATCH_START, batch_data=(inputs, labels))
+
+            outputs = session.state.model.model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)
+
+            session.event(event_type=EventType.OPTIM_PRE_STEP)
+            session.state.optimizer.optimizer.step()
+            session.event(event_type=EventType.OPTIM_POST_STEP)
+
+            running_loss += loss.item()
+
+            predictions = outputs.argmax(dim=1)
+            total_correct += torch.sum(predictions == labels).item()
+            total_predictions += inputs.size(0)
+
+            session.event(event_type=EventType.BATCH_END)
+
+        loss = running_loss / (step + 1.0)
+        accuracy = total_correct / total_predictions
+        print("Epoch: {} Loss: {} Accuracy: {}".format(epoch + 1, loss, accuracy))
+
+    # finalize session
+    session.finalize()
+
+    # view sparsities
+    for (name, layer) in get_prunable_layers(session.state.model.model):
+        print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")
+
+    # save sparsified model
+    save_dir = "e2e_experiment"
+    exporter = ModuleExporter(model, output_dir=save_dir)
+    exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
+    exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py
index bdf76131592..c46d5a0febb 100644
--- a/src/sparseml/core/recipe/recipe.py
+++ b/src/sparseml/core/recipe/recipe.py
@@ -184,7 +184,7 @@ def simplify_combine_recipes(
             )
             combined.version = simplified.version
             combined.stages.extend(simplified.stages)
-            combined.args.combine(simplified.args)
+            combined.args.update(simplified.args)
 
         return combined
 
diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/base.py
index 7a229c43a65..479b843815e 100644
--- a/src/sparseml/modifiers/quantization/base.py
+++ b/src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@
 
 from typing import Any, Dict, List, Optional
 
-from sparseml.core import Modifier, State
+from sparseml.core import Event, Modifier, State
 from sparseml.modifiers.quantization.utils.quantization_scheme import (
     QuantizationScheme,
     QuantizationSchemeLoadable,
@@ -104,6 +104,57 @@ def __init__(self, **kwargs):
             )
         if self.model_fuse_fn_kwargs is None:
             self.model_fuse_fn_kwargs = {}
+        if self.ignore is None:
+            self.ignore = []
+
+    def calculate_freeze_bn_stats_epoch(self) -> float:
+        """
+        Get the epoch at which we want to stop updating batch normalization stats
+
+        :return: freeze_bn_stats_epoch if set, else -1
+        """
+        return (
+            self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1
+        )
+
+    def check_should_freeze_bn_stats(self, event: Event) -> bool:
+        """
+        Given the current index, determine if we should freeze batch normalization stats
+
+        :param event: Event to get index from
+        :return: True if stats should be frozen, False otherwise
+        """
+        freeze_epoch = self.calculate_freeze_bn_stats_epoch()
+        if freeze_epoch == -1:
+            return False
+        if event.current_index >= freeze_epoch:
+            return True
+        return False
+
+    def calculate_disable_observer_epoch(self) -> float:
+        """
+        Get the epoch at which we want to disable the quantization observer
+        :return: epoch to disable at, or -1 if it is not set
+        """
+        return (
+            self.disable_quantization_observer_epoch
+            if self.disable_quantization_observer_epoch is not None
+            else -1
+        )
+
+    def check_should_disable_observer(self, event: Event) -> bool:
+        """
+        Given the current index, determine if we should disable the observer
+
+        :param event: Event to get index from
+        :return: True if observer should be disabled, False otherwise
+        """
+        disable_epoch = self.calculate_disable_observer_epoch()
+        if disable_epoch == -1:
+            return False
+        if event.current_index >= disable_epoch:
+            return True
+        return False
 
     def on_initialize_structure(self, state: State, **kwargs):
         pass  # nothing needed for this modifier
diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py
index 30e731c4415..71b7d08d6fe 100644
--- a/src/sparseml/modifiers/quantization/pytorch.py
+++ b/src/sparseml/modifiers/quantization/pytorch.py
@@ -19,10 +19,11 @@
 import torch
 from torch.nn import Module
 
-from sparseml.core import Event, State
+from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
+    freeze_bn_stats,
     fuse_module_conv_bn_relus,
 )
 from sparseml.modifiers.quantization.utils.quantize import (
@@ -40,6 +41,8 @@ class QuantizationModifierPyTorch(QuantizationModifier):
     calibration_dataloader_: Any = None
     calibration_function_: Any = None
     qat_enabled_: bool = False
+    quantization_observer_disabled_: bool = False
+    bn_stats_frozen_: bool = False
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
@@ -51,10 +54,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self.calibration_dataloader_ = state.data.calib
 
         module = state.model.model
-        device = state.hardware.device
-        state.model.model.to(device)
-        module = state.model.model
-        self._enable_module_qat(module)
+
+        if self.calculate_start() == -1:  # one-shot
+            self._enable_module_qat(module)
+            self._disable_quantization_observer(module)
 
         return True
 
@@ -63,21 +66,34 @@ def on_finalize(self, state: State, **kwargs) -> bool:
             state.model.model.to(state.hardware.device)
             state.model.model.apply(torch.quantization.enable_observer)
             self._calibrate_if_possible(state.model.model)
-        state.model.model.apply(torch.quantization.disable_observer)
+        self._disable_quantization_observer(state.model.model)
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
-        pass
+        if not self.qat_enabled_:
+            self._enable_module_qat(state.model.model)
 
     def on_update(self, state: State, event: Event, **kwargs):
-        pass
+        if event.type_ == EventType.BATCH_START:
+            if self.check_should_freeze_bn_stats(event):
+                self._freeze_bn_stats(state.model.model)
+            if self.check_should_disable_observer(event):
+                self._disable_quantization_observer(state.model.model)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        pass
+        self._disable_quantization_observer(state.model.model)
 
     def on_event(self, state: State, event: Event, **kwargs):
         pass
 
+    def _freeze_bn_stats(self, model: Module):
+        model.apply(freeze_bn_stats)
+        self.bn_stats_frozen_ = True
+
+    def _disable_quantization_observer(self, model: Module):
+        model.apply(torch.quantization.disable_observer)
+        self.quantization_observer_disabled_ = True
+
     def _enable_module_qat(self, module: Module):
         # fuse conv-bn-relu blocks prior to quantization emulation
         self._fuse(module)
@@ -164,4 +180,4 @@ def _calibrate(self, module: Module):
         if module_training:
             module.train()
         else:
-            module.apply(torch.quantization.disable_observer)
+            self._disable_quantization_observer(module)
diff --git a/tests/sparseml/modifiers/conf.py b/tests/sparseml/modifiers/conf.py
new file mode 100644
index 00000000000..90ab0c34be6
--- /dev/null
+++ b/tests/sparseml/modifiers/conf.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sparseml.core import State
+from sparseml.core.event import EventType
+from sparseml.core.factory import ModifierFactory
+from sparseml.core.framework import Framework
+from sparseml.core.lifecycle import CallbacksEventLifecycle
+
+
+def setup_modifier_factory():
+    ModifierFactory.refresh()
+    assert ModifierFactory._loaded, "ModifierFactory not loaded"
+
+
+class LifecyleTestingHarness:
+    def __init__(self, model=None, optimizer=None, framework=Framework.pytorch):
+        self.state = State(framework=framework)
+        self.state.update(model=model, optimizer=optimizer, start=0, steps_per_epoch=1)
+
+        self.event_lifecycle = CallbacksEventLifecycle(
+            type_first=EventType.BATCH_START, start=self.state.start_event
+        )
+
+    def update_modifier(self, modifier, event_type):
+        events = self.event_lifecycle.events_from_type(event_type)
+        for event in events:
+            modifier.update_event(self.state, event=event)
+
+    def get_state(self):
+        return self.state
+
+    def trigger_modifier_for_epochs(self, modifier, num_epochs):
+        for _ in range(num_epochs):
+            self.update_modifier(modifier, EventType.BATCH_START)
+            self.update_modifier(modifier, EventType.LOSS_CALCULATED)
+            self.update_modifier(modifier, EventType.OPTIM_PRE_STEP)
+            self.update_modifier(modifier, EventType.OPTIM_POST_STEP)
+            self.update_modifier(modifier, EventType.BATCH_END)
diff --git a/tests/sparseml/modifiers/quantization/test_base.py b/tests/sparseml/modifiers/quantization/test_base.py
new file mode 100644
index 00000000000..cd5fab0e755
--- /dev/null
+++ b/tests/sparseml/modifiers/quantization/test_base.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sparseml.core.event import Event
+from sparseml.core.factory import ModifierFactory
+from sparseml.core.framework import Framework
+from sparseml.modifiers.quantization import QuantizationModifier
+from tests.sparseml.modifiers.conf import setup_modifier_factory
+
+
+def test_quantization_registered():
+    setup_modifier_factory()
+
+    kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)
+    quant_obj = ModifierFactory.create(
+        type_="QuantizationModifier",
+        framework=Framework.general,
+        allow_experimental=False,
+        allow_registered=True,
+        **kwargs,
+    )
+
+    assert isinstance(quant_obj, QuantizationModifier)
+
+
+def test_end_epochs():
+    start = 0.0
+    scheme = dict(
+        input_activations=dict(num_bits=8, symmetric=True),
+        weights=dict(num_bits=6, symmetric=False),
+    )
+
+    disable_quant_epoch, freeze_bn_epoch = None, None
+    obj_modifier = QuantizationModifier(
+        start=start,
+        scheme=scheme,
+        disable_quantization_observer_epoch=disable_quant_epoch,
+        freeze_bn_stats_epoch=freeze_bn_epoch,
+    )
+
+    assert obj_modifier.calculate_disable_observer_epoch() == -1
+    assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1
+
+    for epoch in range(3):
+        event = Event(steps_per_epoch=1, global_step=epoch)
+        assert not obj_modifier.check_should_disable_observer(event)
+        assert not obj_modifier.check_should_freeze_bn_stats(event)
+
+    disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0
+    obj_modifier = QuantizationModifier(
+        start=start,
+        scheme=scheme,
+        disable_quantization_observer_epoch=disable_quant_epoch,
+        freeze_bn_stats_epoch=freeze_bn_epoch,
+    )
+
+    assert obj_modifier.calculate_disable_observer_epoch() == disable_quant_epoch
+    assert obj_modifier.calculate_freeze_bn_stats_epoch() == freeze_bn_epoch
+
+    for epoch in range(4):
+        event = Event(steps_per_epoch=1, global_step=epoch)
+        assert not obj_modifier.check_should_disable_observer(event)
+        assert not obj_modifier.check_should_freeze_bn_stats(event)
+
+    event = Event(steps_per_epoch=1, global_step=4)
+    assert obj_modifier.check_should_disable_observer(event)
+    assert not obj_modifier.check_should_freeze_bn_stats(event)
+
+    for epoch in range(5, 8):
+        event = Event(steps_per_epoch=1, global_step=epoch)
+        assert obj_modifier.check_should_disable_observer(event)
+        assert obj_modifier.check_should_freeze_bn_stats(event)
diff --git a/tests/sparseml/modifiers/quantization/test_pytorch.py b/tests/sparseml/modifiers/quantization/test_pytorch.py
new file mode 100644
index 00000000000..b982c54cc44
--- /dev/null
+++ b/tests/sparseml/modifiers/quantization/test_pytorch.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from sparseml.core import State
+from sparseml.core.event import Event, EventType
+from sparseml.core.factory import ModifierFactory
+from sparseml.core.framework import Framework
+from sparseml.modifiers.quantization import QuantizationModifierPyTorch
+from sparseml.pytorch.sparsification.quantization.quantize import (
+    is_qat_helper_module,
+    is_quantizable_module,
+)
+from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
+from tests.sparseml.pytorch.helpers import ConvNet, LinearNet
+from tests.sparseml.pytorch.sparsification.quantization.test_modifier_quantization import (  # noqa E501
+    _test_qat_wrapped_module,
+    _test_quantized_module,
+)
+
+
+def _test_qat_applied(modifier, model):
+    assert modifier.qat_enabled_
+
+    for name, module in model.named_modules():
+        if is_qat_helper_module(module):
+            # skip helper modules
+            continue
+
+        is_target_submodule = not any(
+            name.startswith(submodule_name) for submodule_name in modifier.ignore
+        )
+        is_included_module_type = any(
+            module_type_name == module.__class__.__name__
+            for module_type_name in modifier.scheme_overrides
+        )
+        is_quantizable = is_included_module_type or is_quantizable_module(
+            module,
+            exclude_module_types=modifier.ignore,
+        )
+
+        if is_target_submodule and is_quantizable:
+            if getattr(module, "wrap_qat", False):
+                _test_qat_wrapped_module(model, name)
+            else:
+                # check each target module is quantized
+                _test_quantized_module(model, modifier, module, name)
+        else:
+            # check all non-target modules are not quantized
+            assert not hasattr(module, "quantization_scheme")
+            assert not hasattr(module, "qconfig")
+
+
+def test_quantization_registered():
+    setup_modifier_factory()
+
+    kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)
+    quant_obj = ModifierFactory.create(
+        type_="QuantizationModifier",
+        framework=Framework.pytorch,
+        allow_experimental=False,
+        allow_registered=True,
+        **kwargs,
+    )
+
+    assert isinstance(quant_obj, QuantizationModifierPyTorch)
+
+
+@pytest.mark.parametrize("model_class", [ConvNet, LinearNet])
+def test_quantization_oneshot(model_class):
+    model = model_class()
+    state = State(framework=Framework.pytorch, start_event=Event())
+    state.update(model=model)
+
+    scheme = dict(
+        input_activations=dict(num_bits=8, symmetric=True),
+        weights=dict(num_bits=4, symmetric=False, strategy="channel"),
+    )
+    kwargs = dict(scheme=scheme)
+
+    modifier = QuantizationModifierPyTorch(**kwargs)
+
+    modifier.initialize(state)
+
+    # for one-shot, we set up quantization on initialization
+    _test_qat_applied(modifier, model)
+
+    # we shouldn't keep updating stats after one-shot
+    assert modifier.quantization_observer_disabled_
+
+    test_start_event = Event(type_=EventType.BATCH_START)
+    test_end_event = Event(type_=EventType.BATCH_END)
+    assert not modifier.should_start(test_start_event)
+    assert not modifier.should_end(test_end_event)
+
+    modifier.finalize(state)
+    assert modifier.finalized
+
+
+@pytest.mark.parametrize("model_class", [ConvNet, LinearNet])
+def test_quantization_training(model_class):
+    start_epoch = 2
+
+    model = model_class()
+    kwargs = dict(
+        start=start_epoch,
+        scheme=dict(
+            input_activations=dict(num_bits=8, symmetric=True),
+            weights=dict(num_bits=4, symmetric=False),
+        ),
+    )
+
+    modifier = QuantizationModifierPyTorch(**kwargs)
+
+    testing_harness = LifecyleTestingHarness(model=model)
+    modifier.initialize(testing_harness.get_state())
+    assert not modifier.qat_enabled_
+
+    testing_harness.trigger_modifier_for_epochs(modifier, start_epoch)
+    assert not modifier.qat_enabled_
+    testing_harness.trigger_modifier_for_epochs(modifier, start_epoch + 1)
+    _test_qat_applied(modifier, model)
+
+    modifier.finalize(testing_harness.get_state())
+    assert modifier.quantization_observer_disabled_
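
Reviewer note: a minimal sketch (not part of the diff) of how the new epoch-gating helpers on the framework-agnostic QuantizationModifier behave. It mirrors tests/sparseml/modifiers/quantization/test_base.py and, like that test, assumes Event(steps_per_epoch=1, global_step=N) yields current_index == N:

    # Minimal sketch: observer-disable and BN-stat-freeze gating by epoch.
    from sparseml.core.event import Event
    from sparseml.modifiers.quantization import QuantizationModifier

    modifier = QuantizationModifier(
        start=0.0,
        scheme=dict(
            input_activations=dict(num_bits=8, symmetric=True),
            weights=dict(num_bits=4, symmetric=False),
        ),
        disable_quantization_observer_epoch=2.0,
        freeze_bn_stats_epoch=3.0,
    )

    # before either threshold, neither gate trips
    for epoch in (0, 1):
        event = Event(steps_per_epoch=1, global_step=epoch)
        assert not modifier.check_should_disable_observer(event)
        assert not modifier.check_should_freeze_bn_stats(event)

    # at epoch 2 the observer gate trips; BN stats stay live until epoch 3
    event = Event(steps_per_epoch=1, global_step=2)
    assert modifier.check_should_disable_observer(event)
    assert not modifier.check_should_freeze_bn_stats(event)

QuantizationModifierPyTorch.on_update then applies these checks on each BATCH_START event, so the same thresholds drive training-time behavior.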