diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index afa2404a4bc..a15a1692b4f 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -84,6 +84,7 @@ from transformers.tokenization_utils_base import PaddingStrategy from sparseml.optim import parse_recipe_variables +from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET from sparseml.pytorch.optim import ScheduledModifierManager from sparseml.pytorch.utils import export_onnx from sparseml.transformers.sparsification import Trainer @@ -245,6 +246,7 @@ def export_transformer_to_onnx( trust_remote_code: bool = False, data_args: Optional[Union[Dict[str, Any], str]] = None, one_shot: Optional[str] = None, + opset: int = TORCH_DEFAULT_ONNX_OPSET, ) -> str: """ Exports the saved transformers file to ONNX at batch size 1 using @@ -266,6 +268,7 @@ def export_transformer_to_onnx( :param data_args: additional args to instantiate a `DataTrainingArguments` instance for exporting samples :param one_shot: one shot recipe to be applied before exporting model + :param opset: ONNX opset to export with :return: path to the exported ONNX file """ task = task.replace("_", "-").replace(" ", "-") @@ -398,6 +401,7 @@ def export_transformer_to_onnx( inputs, onnx_file_path, convert_qat=convert_qat, + opset=opset, **kwargs, ) _LOGGER.info(f"ONNX exported to {onnx_file_path}") @@ -562,6 +566,12 @@ def _parse_args() -> argparse.Namespace: action="store_true", help=("Set flag to allow custom models in HF-transformers"), ) + parser.add_argument( + "--opset", + type=int, + default=TORCH_DEFAULT_ONNX_OPSET, + help=f"ONNX opset to export with, default: {TORCH_DEFAULT_ONNX_OPSET}", + ) return parser.parse_args() @@ -577,6 +587,7 @@ def export( trust_remote_code: bool = False, data_args: Optional[str] = None, one_shot: Optional[str] = None, + opset: int = TORCH_DEFAULT_ONNX_OPSET, ): if os.path.exists(model_path): # expand to absolute path to support downstream logic @@ -592,6 +603,7 @@ def export( trust_remote_code=trust_remote_code, data_args=data_args, one_shot=one_shot, + opset=opset, ) deployment_folder_dir = create_deployment_folder( @@ -616,6 +628,7 @@ def main(): trust_remote_code=args.trust_remote_code, data_args=args.data_args, one_shot=args.one_shot, + opset=args.opset, )