diff --git a/src/sparseml/pytorch/image_classification/export.py b/src/sparseml/pytorch/image_classification/export.py index 7403aec9f95..7781ca24b45 100644 --- a/src/sparseml/pytorch/image_classification/export.py +++ b/src/sparseml/pytorch/image_classification/export.py @@ -387,18 +387,12 @@ def export( """ exporter = ModuleExporter(model, save_dir) - # export PyTorch state dict - LOGGER.info(f"exporting pytorch in {save_dir}") - - exporter.export_pytorch( - use_zipfile_serialization_if_available=(use_zipfile_serialization_if_available) - ) - onnx_exported = False - if not val_loader: # create fake data for export val_loader = [[torch.randn(1, 3, image_size, image_size)]] + onnx_exported = False + for batch, data in tqdm( enumerate(val_loader), desc="Exporting samples",