-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
QAT Support for new Framework with QuantizationModifier Testing (#1763)
* filling out quantization modifer for training * unit tests for quantization modifier oneshot * pytorch tests * deleting debug scripts * add pytorch flag * fix post quant calib * move e2e example * file path issue fix * fix imports * quality
- Loading branch information
Sara Adkins
authored
Oct 16, 2023
1 parent
cb4e02b
commit f889bb8
Showing
9 changed files
with
580 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
integrations/torchvision/modifiers_refactor_example/e2e_recipe.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
test_stage: | ||
quantization_modifiers: | ||
QuantizationModifier: | ||
start: eval(start_quant_epoch) | ||
scheme: | ||
input_activations: | ||
num_bits: 8 | ||
symmetric: False | ||
weights: | ||
num_bits: 4 | ||
symmetric: True | ||
strategy: "channel" | ||
ignore: ['classifier'] | ||
pruning_modifiers: | ||
MagnitudePruningModifier: | ||
init_sparsity: 0.0 | ||
final_sparsity: 0.5 | ||
start: eval(warm_up_epochs) | ||
end: eval(warm_up_epochs + pruning_epochs) | ||
update_frequency: 0.5 | ||
targets: | ||
- features.0.0.weight | ||
- features.1.conv.0.0.weight | ||
- features.1.conv.1.weight | ||
- features.2.conv.0.0.weight | ||
- features.2.conv.1.0.weight | ||
- features.2.conv.2.weight | ||
- features.3.conv.0.0.weight | ||
- features.3.conv.1.0.weight | ||
- features.3.conv.2.weight | ||
- features.4.conv.0.0.weight | ||
- features.4.conv.1.0.weight | ||
- features.4.conv.2.weight | ||
- features.5.conv.0.0.weight | ||
- features.5.conv.1.0.weight | ||
- features.5.conv.2.weight | ||
- features.6.conv.0.0.weight | ||
- features.6.conv.1.0.weight | ||
- features.6.conv.2.weight | ||
- features.7.conv.0.0.weight | ||
- features.7.conv.1.0.weight | ||
- features.7.conv.2.weight | ||
- features.8.conv.0.0.weight | ||
- features.8.conv.1.0.weight | ||
- features.8.conv.2.weight | ||
- features.9.conv.0.0.weight | ||
- features.9.conv.1.0.weight | ||
- features.9.conv.2.weight | ||
- features.10.conv.0.0.weight | ||
- features.10.conv.1.0.weight | ||
- features.10.conv.2.weight | ||
- features.11.conv.0.0.weight | ||
- features.11.conv.1.0.weight | ||
- features.11.conv.2.weight | ||
- features.12.conv.0.0.weight | ||
- features.12.conv.1.0.weight | ||
- features.12.conv.2.weight | ||
- features.13.conv.0.0.weight | ||
- features.13.conv.1.0.weight | ||
- features.13.conv.2.weight | ||
- features.14.conv.0.0.weight | ||
- features.14.conv.1.0.weight | ||
- features.14.conv.2.weight | ||
- features.15.conv.0.0.weight | ||
- features.15.conv.1.0.weight | ||
- features.15.conv.2.weight | ||
- features.16.conv.0.0.weight | ||
- features.16.conv.1.0.weight | ||
- features.16.conv.2.weight | ||
- features.17.conv.0.0.weight | ||
- features.17.conv.1.0.weight | ||
- features.17.conv.2.weight | ||
- features.18.0.weight | ||
- classifier.1.weight | ||
leave_enabled: True |
155 changes: 155 additions & 0 deletions
155
integrations/torchvision/modifiers_refactor_example/e2e_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def main(): | ||
import os | ||
|
||
import datasets | ||
import torch | ||
import torchvision | ||
from torch.nn import CrossEntropyLoss | ||
from torch.optim import Adam | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
|
||
import sparseml.core.session as sml | ||
from sparseml.core.event import EventType | ||
from sparseml.core.framework import Framework | ||
from sparseml.pytorch.utils import ( | ||
ModuleExporter, | ||
get_prunable_layers, | ||
tensor_sparsity, | ||
) | ||
|
||
NUM_LABELS = 3 | ||
BATCH_SIZE = 32 | ||
NUM_EPOCHS = 12 | ||
recipe = "e2e_recipe.yaml" | ||
device = "cuda:0" | ||
|
||
# set up SparseML session | ||
sml.create_session() | ||
session = sml.active_session() | ||
|
||
# download model | ||
model = torchvision.models.mobilenet_v2( | ||
weights=torchvision.models.MobileNet_V2_Weights.DEFAULT | ||
) | ||
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS) | ||
model.to(device) | ||
|
||
# download data | ||
beans_dataset = datasets.load_dataset("beans") | ||
train_folder, _ = os.path.split(beans_dataset["train"][0]["image_file_path"]) | ||
train_path, _ = os.path.split(train_folder) | ||
val_folder, _ = os.path.split(beans_dataset["validation"][0]["image_file_path"]) | ||
val_path, _ = os.path.split(train_folder) | ||
|
||
# dataloaders | ||
imagenet_transform = transforms.Compose( | ||
[ | ||
transforms.Resize( | ||
size=256, | ||
interpolation=transforms.InterpolationMode.BILINEAR, | ||
max_size=None, | ||
antialias=None, | ||
), | ||
transforms.CenterCrop(size=(224, 224)), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | ||
] | ||
) | ||
|
||
train_dataset = torchvision.datasets.ImageFolder( | ||
root=train_path, transform=imagenet_transform | ||
) | ||
train_loader = DataLoader( | ||
train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16 | ||
) | ||
|
||
val_dataset = torchvision.datasets.ImageFolder( | ||
root=val_path, transform=imagenet_transform | ||
) | ||
val_loader = DataLoader( | ||
val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16 | ||
) | ||
|
||
# loss and optimizer | ||
criterion = CrossEntropyLoss() | ||
optimizer = Adam(model.parameters(), lr=8e-3) | ||
|
||
# initialize session | ||
recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5} | ||
_ = session.initialize( | ||
framework=Framework.pytorch, | ||
recipe=recipe, | ||
recipe_args=recipe_args, | ||
model=model, | ||
teacher_model=None, | ||
optimizer=optimizer, | ||
train_data=train_loader, | ||
val_data=val_loader, | ||
start=0.0, | ||
steps_per_epoch=len(train_loader), | ||
) | ||
|
||
# loop through batches | ||
for epoch in range(NUM_EPOCHS): | ||
running_loss = 0.0 | ||
total_correct = 0 | ||
total_predictions = 0 | ||
for step, (inputs, labels) in enumerate(session.state.data.train): | ||
inputs = inputs.to(device) | ||
labels = labels.to(device) | ||
session.state.optimizer.optimizer.zero_grad() | ||
session.event(event_type=EventType.BATCH_START, batch_data=(input, labels)) | ||
|
||
outputs = session.state.model.model(inputs) | ||
loss = criterion(outputs, labels) | ||
loss.backward() | ||
session.event(event_type=EventType.LOSS_CALCULATED, loss=loss) | ||
|
||
session.event(event_type=EventType.OPTIM_PRE_STEP) | ||
session.state.optimizer.optimizer.step() | ||
session.event(event_type=EventType.OPTIM_POST_STEP) | ||
|
||
running_loss += loss.item() | ||
|
||
predictions = outputs.argmax(dim=1) | ||
total_correct += torch.sum(predictions == labels).item() | ||
total_predictions += inputs.size(0) | ||
|
||
session.event(event_type=EventType.BATCH_END) | ||
|
||
loss = running_loss / (step + 1.0) | ||
accuracy = total_correct / total_predictions | ||
print("Epoch: {} Loss: {} Accuracy: {}".format(epoch + 1, loss, accuracy)) | ||
|
||
# finalize session | ||
session.finalize() | ||
|
||
# view sparsities | ||
for (name, layer) in get_prunable_layers(session.state.model.model): | ||
print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}") | ||
|
||
# save sparsified model | ||
save_dir = "e2e_experiment" | ||
exporter = ModuleExporter(model, output_dir=save_dir) | ||
exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth") | ||
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.