diff --git a/mypy.ini b/mypy.ini index 5ab02361d615c..b66464d21f54a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -165,6 +165,9 @@ ignore_missing_imports = True [mypy-tensorboard.*] ignore_missing_imports = True +[mypy-onnx.*] +ignore_missing_imports = True + [mypy-matplotlib.*] ignore_missing_imports = True @@ -298,14 +301,5 @@ ignore_missing_imports = True # Third party dependencies that are optional. # -[mypy-onnx.*] -ignore_missing_imports = True - -[mypy-onnxruntime.*] -ignore_missing_imports = True - -[mypy-onnxscript.*] -ignore_missing_imports = True - [mypy-redis] ignore_missing_imports = True \ No newline at end of file diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index 18b1a05a70046..e8e564d68040d 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -203,5 +203,222 @@ def test_serialize_succeeds_when_model_greater_than_2gb(self): serializer.serialize(onnx_program, io.BytesIO()) +class TestONNXExportWithDynamo(common_utils.TestCase): + def test_args_normalization_with_no_kwargs(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), + ( + torch.randn(1, 1, 2), + torch.randn(1, 1, 2), + ), + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_is_tensor_not_tuple(self): + exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),)) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModel(), torch.randn(1, 1, 2), dynamo=True + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_kwargs(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_empty_dict_at_the_tail(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": { + 0: torch.export.Dim("customx_dim_0"), + 1: 
torch.export.Dim("customx_dim_1"), + 2: torch.export.Dim("customx_dim_2"), + }, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, + "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, + }, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": { + 0: torch.export.Dim("customx_dim_0"), + 1: torch.export.Dim("customx_dim_1"), + 2: torch.export.Dim("customx_dim_2"), + }, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "x": [0, 1, 2], + "b": [0, 1, 2], + }, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_supports_partial_dynamic_shapes(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": None, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "b": [0, 1, 2], + }, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_shapes_hit_constraints_in_dynamo(self): + # SampleModelTwoInputs has constraints becuse of add of two inputs, + # so the two input shapes are related. 
+ with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "Constraints violated", + ): + _ = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(2, 2, 3), torch.randn(2, 2, 3)), + dynamic_axes={ + "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"}, + "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"}, + }, + dynamo=True, + ) + + def test_saved_f_exists_after_export(self): + with common_utils.TemporaryFileName(suffix=".onnx") as path: + _ = torch.onnx.export( + SampleModel(), torch.randn(1, 1, 2), path, dynamo=True + ) + self.assertTrue(os.path.exists(path)) + + def test_raises_error_when_input_is_script_module(self): + class ScriptModule(torch.jit.ScriptModule): + def forward(self, x): + return x + + with self.assertRaisesRegex( + TypeError, + "Dynamo export does not support ScriptModule or ScriptFunction.", + ): + _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/exporter/README.md b/test/onnx/exporter/README.md deleted file mode 100644 index 7ad65ca338b1d..0000000000000 --- a/test/onnx/exporter/README.md +++ /dev/null @@ -1 +0,0 @@ -Directory for all ExportedProgram exporter logic. diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py deleted file mode 100644 index 157ea1197b634..0000000000000 --- a/test/onnx/exporter/test_api.py +++ /dev/null @@ -1,120 +0,0 @@ -# Owner(s): ["module: onnx"] -"""Simple API tests for the ONNX exporter.""" - -from __future__ import annotations - -import os - -import torch -from torch.onnx._internal import exporter -from torch.testing._internal import common_utils - - -class SampleModel(torch.nn.Module): - def forward(self, x): - y = x + 1 - z = y.relu() - return (y, z) - - -class SampleModelTwoInputs(torch.nn.Module): - def forward(self, x, b): - y = x + b - z = y.relu() - return (y, z) - - -class SampleModelForDynamicShapes(torch.nn.Module): - def forward(self, x, b): - return x.relu(), b.sigmoid() - - -class TestExportAPIDynamo(common_utils.TestCase): - """Tests for the ONNX exporter API when dynamo=True.""" - - def test_args_normalization_with_no_kwargs(self): - onnx_program = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_args_normalization_with_kwargs(self): - onnx_program = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_args_normalization_with_empty_dict_at_the_tail(self): - onnx_program = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): - onnx_program = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ - "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, - "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, - }, - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): - onnx_program = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ 
- "x": [0, 1, 2], - "b": [0, 1, 2], - }, - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_dynamic_axes_supports_partial_dynamic_shapes(self): - onnx_program = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ - "b": [0, 1, 2], - }, - dynamo=True, - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - def test_saved_f_exists_after_export(self): - with common_utils.TemporaryFileName(suffix=".onnx") as path: - _ = torch.onnx.export( - SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True - ) - self.assertTrue(os.path.exists(path)) - - def test_export_supports_script_module(self): - class ScriptModule(torch.nn.Module): - def forward(self, x): - return x - - onnx_program = torch.onnx.export( - torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),), dynamo=True - ) - assert onnx_program - exporter.verify_onnx_program(onnx_program) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 5433540aeb2b6..0f5c2dc5a358a 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -286,25 +286,6 @@ def onerror(modname): # do not get imported by public code. private_allowlist = { "torch._inductor.codegen.cuda.cuda_kernel", - # TODO(#133647): Remove the onnx._internal entries after - # onnx and onnxscript are installed in CI. - "torch.onnx._internal.exporter", - "torch.onnx._internal.exporter._analysis", - "torch.onnx._internal.exporter._building", - "torch.onnx._internal.exporter._capture_strategies", - "torch.onnx._internal.exporter._compat", - "torch.onnx._internal.exporter._core", - "torch.onnx._internal.exporter._decomp", - "torch.onnx._internal.exporter._dispatching", - "torch.onnx._internal.exporter._fx_passes", - "torch.onnx._internal.exporter._ir_passes", - "torch.onnx._internal.exporter._isolated", - "torch.onnx._internal.exporter._onnx_program", - "torch.onnx._internal.exporter._registration", - "torch.onnx._internal.exporter._reporting", - "torch.onnx._internal.exporter._schemas", - "torch.onnx._internal.exporter._tensors", - "torch.onnx._internal.exporter._verification", "torch.onnx._internal.fx._pass", "torch.onnx._internal.fx.analysis", "torch.onnx._internal.fx.analysis.unsupported_nodes", diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index b1bf988f1976e..0cb563cd5cc08 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,63 +1,4 @@ # mypy: allow-untyped-defs -from __future__ import annotations - - -__all__ = [ - # Modules - "symbolic_helper", - "utils", - "errors", - # All opsets - "symbolic_caffe2", - "symbolic_opset7", - "symbolic_opset8", - "symbolic_opset9", - "symbolic_opset10", - "symbolic_opset11", - "symbolic_opset12", - "symbolic_opset13", - "symbolic_opset14", - "symbolic_opset15", - "symbolic_opset16", - "symbolic_opset17", - "symbolic_opset18", - "symbolic_opset19", - "symbolic_opset20", - # Enums - "ExportTypes", - "OperatorExportTypes", - "TrainingMode", - "TensorProtoDataType", - "JitScalarType", - # Public functions - "export", - "export_to_pretty_string", - "is_in_onnx_export", - "select_model_mode_for_export", - "register_custom_op_symbolic", - "unregister_custom_op_symbolic", - "disable_log", - "enable_log", - # Errors - "CheckerError", # Backwards compatibility - # Dynamo Exporter - "DiagnosticOptions", - "ExportOptions", - "ONNXProgram", - "ONNXProgramSerializer", - "ONNXRuntimeOptions", - 
"InvalidExportOptionsError", - "OnnxExporterError", - "OnnxRegistry", - "dynamo_export", - "enable_fake_mode", - # DORT / torch.compile - "is_onnxrt_backend_supported", -] - -from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING - -import torch from torch import _C from torch._C import _onnx as _C_onnx from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode @@ -75,6 +16,7 @@ _optimize_graph, _run_symbolic_function, _run_symbolic_method, + export, export_to_pretty_string, is_in_onnx_export, register_custom_op_symbolic, @@ -120,8 +62,58 @@ ) -if TYPE_CHECKING: - import os +__all__ = [ + # Modules + "symbolic_helper", + "utils", + "errors", + # All opsets + "symbolic_caffe2", + "symbolic_opset7", + "symbolic_opset8", + "symbolic_opset9", + "symbolic_opset10", + "symbolic_opset11", + "symbolic_opset12", + "symbolic_opset13", + "symbolic_opset14", + "symbolic_opset15", + "symbolic_opset16", + "symbolic_opset17", + "symbolic_opset18", + "symbolic_opset19", + "symbolic_opset20", + # Enums + "ExportTypes", + "OperatorExportTypes", + "TrainingMode", + "TensorProtoDataType", + "JitScalarType", + # Public functions + "export", + "export_to_pretty_string", + "is_in_onnx_export", + "select_model_mode_for_export", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", + "disable_log", + "enable_log", + # Errors + "CheckerError", # Backwards compatibility + # Dynamo Exporter + "DiagnosticOptions", + "ExportOptions", + "ONNXProgram", + "ONNXProgramSerializer", + "ONNXRuntimeOptions", + "InvalidExportOptionsError", + "OnnxExporterError", + "OnnxRegistry", + "dynamo_export", + "enable_fake_mode", + # DORT / torch.compile + "is_onnxrt_backend_supported", +] # Set namespace for exposed private names ExportTypes.__module__ = "torch.onnx" @@ -145,257 +137,6 @@ producer_version = _C_onnx.PRODUCER_VERSION -def export( - model: torch.nn.Module - | torch.export.ExportedProgram - | torch.jit.ScriptModule - | torch.jit.ScriptFunction, - args: tuple[Any, ...], - f: str | os.PathLike | None = None, - *, - kwargs: dict[str, Any] | None = None, - export_params: bool = True, - verbose: bool | None = None, - input_names: Sequence[str] | None = None, - output_names: Sequence[str] | None = None, - opset_version: int | None = None, - dynamic_axes: Mapping[str, Mapping[int, str]] - | Mapping[str, Sequence[int]] - | None = None, - keep_initializers_as_inputs: bool = False, - dynamo: bool = False, - # Dynamo only options - external_data: bool = True, - dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, - report: bool = False, - verify: bool = False, - profile: bool = False, - dump_exported_program: bool = False, - artifacts_dir: str | os.PathLike = ".", - fallback: bool = False, - # Deprecated options - training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, - operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, - do_constant_folding: bool = True, - custom_opsets: Mapping[str, int] | None = None, - export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, - autograd_inlining: bool = True, - **_: Any, # ignored options -) -> Any | None: - r"""Exports a model into ONNX format. - - Args: - model: The model to be exported. - args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the - exported model; any Tensor arguments will become inputs of the exported model, - in the order they occur in the tuple. - f: Path to the output ONNX model file. E.g. "model.onnx". 
- kwargs: Optional example keyword inputs. - export_params: If false, parameters (weights) will not be exported. - verbose: Whether to enable verbose logging. - input_names: names to assign to the input nodes of the graph, in order. - output_names: names to assign to the output nodes of the graph, in order. - opset_version: The version of the - `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/main/docs/Operators.md>`_ - to target. Must be >= 7. - dynamic_axes: - - By default the exported model will have the shapes of all input and output tensors - set to exactly match those given in ``args``. To specify axes of tensors as - dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: - - * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or - ``output_names``. - * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a - list, each element is an axis index. - - For example:: - - class SumModule(torch.nn.Module): - def forward(self, x): - return torch.sum(x, dim=1) - - torch.onnx.export( - SumModule(), - (torch.ones(2, 2),), - "onnx.pb", - input_names=["x"], - output_names=["sum"] - ) - - Produces:: - - input { - name: "x" - ... - shape { - dim { - dim_value: 2 # axis 0 - } - dim { - dim_value: 2 # axis 1 - ... - output { - name: "sum" - ... - shape { - dim { - dim_value: 2 # axis 0 - ... - - While:: - - torch.onnx.export( - SumModule(), - (torch.ones(2, 2),), - "onnx.pb", - input_names=["x"], - output_names=["sum"], - dynamic_axes={ - # dict value: manually named axes - "x": {0: "my_custom_axis_name"}, - # list value: automatic names - "sum": [0], - } - ) - - Produces:: - - input { - name: "x" - ... - shape { - dim { - dim_param: "my_custom_axis_name" # axis 0 - } - dim { - dim_value: 2 # axis 1 - ... - output { - name: "sum" - ... - shape { - dim { - dim_param: "sum_dynamic_axes_1" # axis 0 - ... - - keep_initializers_as_inputs: If True, all the - initializers (typically corresponding to model weights) in the - exported graph will also be added as inputs to the graph. If False, - then initializers are not added as inputs to the graph, and only - the user inputs are added as inputs. - - Set this to True if you intend to supply model weights at runtime. - Set it to False if the weights are static to allow for better optimizations - (e.g. constant folding) by backends/runtimes. - - dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. - external_data: Whether to save the model weights as an external data file. - This is required for models with large weights that exceed the ONNX file size limit (2GB). - When False, the weights are saved in the ONNX file with the model architecture. - dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to - :func:`torch.export.export` for more details. - report: Whether to generate a markdown report for the export process. - verify: Whether to verify the exported model using ONNX Runtime. - profile: Whether to profile the export process. - dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. - This is useful for debugging the exporter. - artifacts_dir: The directory to save the debugging artifacts like the report and the serialized - exported program. - fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. - - training: Deprecated option. Instead, set the training mode of the model before exporting. - operator_export_type: Deprecated option. Only ONNX is supported.
- do_constant_folding: Deprecated option. The exported graph is always optimized. - custom_opsets: Deprecated. - A dictionary: - - * KEY (str): opset domain name - * VALUE (int): opset version - - If a custom opset is referenced by ``model`` but not mentioned in this dictionary, - the opset version is set to 1. Only custom opset domain name and version should be - indicated through this argument. - export_modules_as_functions: Deprecated option. - - Flag to enable - exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the - particular types of modules to export as local functions in ONNX. - This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because - ``opset_version`` < 15 implies IR version < 8, which means no local function support. - Module variables will be exported as function attributes. There are two categories of function - attributes. - - 1. Annotated attributes: class variables that have type annotations via - `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ - will be exported as attributes. - Annotated attributes are not used inside the subgraph of ONNX local function because - they are not created by PyTorch JIT tracing, but they may be used by consumers - to determine whether or not to replace the function with a particular fused kernel. - - 2. Inferred attributes: variables that are used by operators inside the module. Attribute names - will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from - python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. - - * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. - * ``True``: export all ``nn.Module`` forward calls as local function nodes. - * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, - only if the type of the ``nn.Module`` is found in the set. - autograd_inlining: Deprecated. - Flag used to control whether to inline autograd functions. - Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
- """ - if dynamo is True or isinstance(model, torch.export.ExportedProgram): - from torch.onnx._internal import exporter - - if isinstance(args, torch.Tensor): - args = (args,) - return exporter.export_compat( - model, - args, - f, - kwargs=kwargs, - export_params=export_params, - verbose=verbose, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - external_data=external_data, - dynamic_shapes=dynamic_shapes, - report=report, - verify=verify, - profile=profile, - dump_exported_program=dump_exported_program, - artifacts_dir=artifacts_dir, - fallback=fallback, - ) - else: - from torch.onnx.utils import export - - export( - model, - args, - f, # type: ignore[arg-type] - kwargs=kwargs, - export_params=export_params, - verbose=verbose is True, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - training=training, - operator_export_type=operator_export_type, - do_constant_folding=do_constant_folding, - custom_opsets=custom_opsets, - export_modules_as_functions=export_modules_as_functions, - autograd_inlining=autograd_inlining, - ) - return None - - @_deprecation.deprecated( since="1.12.0", removed_in="2.0", instructions="use `torch.onnx.export` instead" ) diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py deleted file mode 100644 index d08d16012213c..0000000000000 --- a/torch/onnx/_internal/_lazy_import.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Utility to lazily import modules.""" -# mypy: allow-untyped-defs -from __future__ import annotations - -import importlib -from typing import Any, TYPE_CHECKING - - -class _LazyModule: - """Lazily import a module.""" - - def __init__(self, module_name: str) -> None: - self._name = module_name - self._module: Any = None - - def __repr__(self) -> str: - return f"" - - def __getattr__(self, attr): - if self._module is None: - self._module = importlib.import_module(".", self._name) - return getattr(self._module, attr) - - -# Import the following modules during type checking to enable code intelligence features, -# such as auto-completion in tools like pylance, even when these modules are not explicitly -# imported in user code. -# NOTE: Add additional used imports here. 
-if TYPE_CHECKING: - import onnx - import onnxscript - - onnxscript_ir = onnxscript.ir - -else: - onnx = _LazyModule("onnx") - onnxscript = _LazyModule("onnxscript") - onnxscript_ir = _LazyModule("onnxscript.ir") diff --git a/torch/onnx/_internal/exporter/__init__.py b/torch/onnx/_internal/exporter/__init__.py deleted file mode 100644 index 3bf21aa01dd41..0000000000000 --- a/torch/onnx/_internal/exporter/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -__all__ = [ - "ONNXRegistry", - "ONNXProgram", - "analyze", - "export", - "exported_program_to_ir", - "verify_onnx_program", - "export_compat", -] - -from ._analysis import analyze -from ._compat import export_compat -from ._core import export, exported_program_to_ir -from ._onnx_program import ONNXProgram -from ._registration import ONNXRegistry -from ._verification import verify_onnx_program diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py deleted file mode 100644 index 43a6519482663..0000000000000 --- a/torch/onnx/_internal/exporter/_analysis.py +++ /dev/null @@ -1,250 +0,0 @@ -"""Compatibility analyzer for PyTorch models.""" - -# mypy: allow-untyped-defs -# flake8: noqa: B950 We do not need flake8 as it complains line length -from __future__ import annotations - -import dataclasses -import textwrap -import traceback -from collections import defaultdict -from typing import TYPE_CHECKING - -import onnxscript - -import torch -import torch._export.serde.schema -from torch.export import graph_signature -from torch.onnx._internal.exporter import _dispatching, _registration - - -if TYPE_CHECKING: - import torch.fx - - -@dataclasses.dataclass -class ModelInfo: - """Information about the model.""" - - parameter_count: defaultdict[torch.dtype, int] = dataclasses.field( - default_factory=lambda: defaultdict(int) - ) - buffer_count: defaultdict[torch.dtype, int] = dataclasses.field( - default_factory=lambda: defaultdict(int) - ) - fx_node_count: int = 0 - fx_node_op_count: defaultdict[str, int] = dataclasses.field( - default_factory=lambda: defaultdict(int) - ) - fx_node_target_count: defaultdict[str, int] = dataclasses.field( - default_factory=lambda: defaultdict(int) - ) - dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field( - default_factory=list - ) - inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( - default_factory=dict - ) - outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( - default_factory=dict - ) - - -def _count_weights( - exported_program: torch.export.ExportedProgram, -) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]: - """Count the size of the parameters in the exported program.""" - - parameter_count: defaultdict[torch.dtype, int] = defaultdict(int) - buffer_count: defaultdict[torch.dtype, int] = defaultdict(int) - for parameter in exported_program.parameters(): - dtype = parameter.dtype - parameter_count[dtype] += parameter.numel() - - for buffer in exported_program.buffers(): - dtype = buffer.dtype - buffer_count[dtype] += buffer.numel() - - return parameter_count, buffer_count - - -def _format_model_info(model_info: ModelInfo) -> str: - """Format the information about the model.""" - lines = [ - textwrap.dedent( - f"""\ - PyTorch ONNX Conversion Analysis - - ## Model Information - - The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters). 
- Number of parameters per dtype: - ```python - {model_info.parameter_count} - ``` - Number of buffers per dtype: - ```python - {model_info.buffer_count} - ``` - """ - ), - "Inputs:", - *[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()], - "", - "Outputs:", - *[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()], - "", - f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:", - ] - for op, count in model_info.fx_node_op_count.items(): - lines.append(f"- `{op}`: {count}") - lines.append("\n") - lines.append("Of the call_function nodes, the counts of operators used are:\n") - sorted_targets = sorted( - model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True - ) - for target, count in sorted_targets: - lines.append(f"- `{target}`: {count}") - - lines.append("") - lines.append("## ONNX Conversion Information") - lines.append("") - - if model_info.dispatch_failures: - lines.append( - "The model contains operators the dispatcher could not find registered ONNX decompositions for. " - "This may be due to missing implementations, decompositions not registered " - "correctly, or a bug in the dispatcher." - ) - lines.append("") - lines.append("Errors grouped by operator:\n") - - target_to_nodes = defaultdict(list) - for node, _ in model_info.dispatch_failures: - target_to_nodes[str(node.target)].append(node) - - target_to_messages = {} - for node, message in model_info.dispatch_failures: - if str(node.target) not in target_to_messages: - target_to_messages[str(node.target)] = message - - for target, nodes in sorted( - target_to_nodes.items(), key=lambda x: x[0], reverse=True - ): - message = textwrap.indent( - f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. 
All nodes: `{nodes}`", - " ", - ) - lines.append(f"- `{target}`: {message}") - else: - lines.append("All operators in the model have registered ONNX decompositions.") - - return "\n".join(lines) - - -def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]: - """Get the input and output specs of the exported program.""" - - nodes: dict[str, torch.fx.Node] = { - node.name: node for node in exported_program.graph.nodes - } - user_inputs = [ - spec - for spec in exported_program.graph_signature.input_specs - if spec.kind == graph_signature.InputKind.USER_INPUT - ] - user_outputs = [ - spec - for spec in exported_program.graph_signature.output_specs - if spec.kind == graph_signature.OutputKind.USER_OUTPUT - ] - inputs: dict[str, torch._export.serde.schema.TensorMeta] = {} - outputs: dict[str, torch._export.serde.schema.TensorMeta] = {} - for spec in user_inputs: - if isinstance(spec.arg, graph_signature.ConstantArgument): - continue - name = spec.arg.name - # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type - inputs[name] = nodes[name].meta["tensor_meta"] - for spec in user_outputs: - if isinstance(spec.arg, graph_signature.ConstantArgument): - continue - name = spec.arg.name - outputs[name] = nodes[name].meta["tensor_meta"] - return inputs, outputs - - -def _count_fx_targets( - exported_program: torch.export.ExportedProgram, -) -> defaultdict[str, int]: - """Count the number of targets for each node in the exported program.""" - fx_node_target_count: defaultdict[str, int] = defaultdict(int) - for node in exported_program.graph.nodes: - if node.op == "call_function": - fx_node_target_count[str(node.target)] += 1 - return fx_node_target_count - - -def analyze( - exported_program: torch.export.ExportedProgram, - registry: _registration.ONNXRegistry | None = None, - file=None, -) -> None: - """Analyze the compatibility of the exported program.""" - # Get basic information about the model - model_info = ModelInfo() - model_info.parameter_count, model_info.buffer_count = _count_weights( - exported_program - ) - model_info.fx_node_count = len(exported_program.graph.nodes) - model_info.fx_node_target_count = _count_fx_targets(exported_program) - inputs, outputs = _get_io_specs(exported_program) - model_info.inputs = inputs - model_info.outputs = outputs - - if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops # noqa: F401 - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) - - # Try to find ops for every node in the graph - for node in exported_program.graph.nodes: - model_info.fx_node_op_count[node.op] += 1 - if node.op == "call_function": - try: - onnx_function, message = _dispatching.dispatch(node, registry) - except Exception as e: - message = "Critical Error in dispatcher:\n" - formatted_exception = "\n".join( - traceback.format_exception(type(e), e, e.__traceback__) - ) - message += f"```pytb\n{formatted_exception}\n```\n" - onnx_function = None - if onnx_function is None: - model_info.dispatch_failures.append((node, message)) - - # Print the results - report = _format_model_info(model_info) - print(report, file=file, flush=True) - - -def compare_ops( - program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram -) -> tuple[set[str], set[str]]: - """Compare and get unique ops in two exported programs. 
- - Args: - program_a: The first exported program. - program_b: The second exported program. - - Returns: - A tuple of two sets, where the first set contains the unique ops in the first program - and the second set contains the unique ops in the second program. - """ - program_a_ops = set(_count_fx_targets(program_a)) - program_b_ops = set(_count_fx_targets(program_b)) - return program_a_ops - program_b_ops, program_b_ops - program_a_ops diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py deleted file mode 100644 index ddda83c3718ab..0000000000000 --- a/torch/onnx/_internal/exporter/_building.py +++ /dev/null @@ -1,516 +0,0 @@ -"""NOTES: - -We need a typing module that will handling Python to ONNX type promotion for use. -For example, if we have torch.ops.aten.add(Tensor, 1.0), we need to promote 1.0 -to the same type as Tensor. The same thing needs to work for -torch.ops.aten.add(1.0, Tensor) as well, which means we need a mechanism to` -""" - -# mypy: allow-untyped-defs -# mypy: disable-error-code=union-attr -from __future__ import annotations - -import copy -import inspect -import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union - -import onnxscript -from onnxscript import evaluator, ir -from onnxscript.ir import convenience as ir_convenience - -import torch -from torch.onnx._internal.exporter import _schemas, _tensors, errors - - -if TYPE_CHECKING: - import onnx - - -logger = logging.getLogger(__name__) - -# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes -ValidAttributeType = Union[ - ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None -] - -AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType] - - -# Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value -def _construct_named_inputs_and_attrs( - signature: _schemas.OpSignature, - args: Sequence[AllowedArgType], - kwargs: Mapping[str, AllowedArgType], -) -> tuple[dict[str, AllowedArgType], dict[str, ValidAttributeType]]: - """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. - - This function uses the OpSignature to determine which argument in args and kwargs corresponds to - which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are - stored in named_attrs. If an _optional input_ is not provided, it is filled with None. - - Args: - signature: The OpSignature for the node. - args: The positional arguments for the node. - kwargs: The keyword arguments for the node. - - Returns: - A tuple of two mappings: named_inputs and named_attrs. - - Raises: - ValueError: If a required parameter is not provided. - """ - # 1. Construct the (named_inputs, named_attrs) mapping based on (args, kwargs) and the signature. - # a. Loop over all parameters in the signature and args together - # b. Depending on param.is_input, Record named_inputs[param.name] = arg or named_attrs[param.name] = arg - # c. Handle kwargs as well - # d. 
Fill in None if the input is not provided - named_inputs = {} - named_attrs = {} - reversed_args_stack = list(reversed(args)) - for param in signature.params: - if isinstance(param, _schemas.Parameter): - # Handle inputs - if reversed_args_stack: - # First exhaust the positional arguments - if param.variadic: - # Handle variadic arguments - named_inputs[param.name] = tuple(args) - reversed_args_stack.clear() - else: - named_inputs[param.name] = reversed_args_stack.pop() # type: ignore[assignment] - elif param.name in kwargs: - named_inputs[param.name] = kwargs[param.name] # type: ignore[assignment] - elif param.required: - raise ValueError( - f"Required parameter '{param.name}' is not provided. " - f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." - ) - else: - logger.debug( - "Optional parameter '%s' is not provided. Added as None. Signature: %s", - param.name, - signature, - ) - named_inputs[param.name] = None # type: ignore[assignment] - else: - # Handle attributes - attribute: ValidAttributeType | ir.Attr - assert isinstance( - param, _schemas.AttributeParameter - ), f"Expected AttributeParameter, got {type(param)}" - if reversed_args_stack: - # First exhaust the positional arguments - attribute = reversed_args_stack.pop() # type: ignore[assignment] - elif param.name in kwargs: - attribute = kwargs[param.name] # type: ignore[assignment] - elif param.default is not None: - attribute = param.default - else: - attribute = None - - if attribute is None: - if param.required: - raise ValueError( - f"Required attribute '{param.name}' is not provided. " - f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." - ) - else: - logger.debug( - "Optional attribute '%s' is None. Dropped. Signature: %s", - param.name, - signature, - ) - continue - - if isinstance(attribute, ir.Attr): - # Turn the attribute from an default value into an actual parameter for the node - attr_copied = copy.copy(attribute) - # Make sure the name is the same as the parameter name and not the name of the default parameter - attr_copied.name = param.name - attribute = attr_copied - - if isinstance(attribute, int) and param.type == ir.AttributeType.FLOAT: - # Convert the attribute to float if needed. This happens in PyTorch - # where an attribute marked as float can be passed as an int. - attribute = float(attribute) - named_attrs[param.name] = attribute - return named_inputs, named_attrs # type: ignore[return-value] - - -def _resolve_parameter_dtypes( - signature: _schemas.OpSignature, named_inputs: Mapping[str, AllowedArgType] -) -> Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol]: - """Determine which parameter takes which type. - - Handle non-tensor input corner cases and type promotion. - - Requires: - All ir.Value in name_inputs should have type set. Their type should be - compatible with the type_constraint of the corresponding parameter in the signature. - - Args: - signature: The OpSignature for the node. - named_inputs: The mapping of parameter names to their arguments. - - Returns: - A mapping of Constraint names to ir.TypeProtocol. - """ - # a. Create type_binding: dict[str, ir.TypeProtocol] - # b. Iterate over all named_inputs - # b0. Find the corresponding parameter in the signature - # b1. If the argument is a Python constant, skip. - # b2. If the argument is a ir.Value, Bind {constraint: arg.type}. 
- type_binding = {} - for name, arg in named_inputs.items(): - param = signature.params_map[name] - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" - if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)): - # Skip the Python constants because we do not know what dtype they should take yet - continue - elif isinstance(arg, ir.Value): - if arg.type is None: - # Skip the ir.Value if the type is not set - continue - # NOTE: We assume arg.type is compatible with the type_constraint - assert arg.type is not None, f"Expected type to be set for {arg}" - # TODO(justinchuby): Implement type promotion logic here. - type_binding[param.type_constraint] = arg.type - return type_binding - - -def _process_python_constants_and_sequences( - signature: _schemas.OpSignature, - named_inputs: dict[str, AllowedArgType], - type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], - constant_farm: dict[ - tuple[ - bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], - ir.DataType, - ], - ir.Value, - ], - opset: onnxscript.values.Opset, -) -> dict[str, ir.Value | None]: - """Convert Python constants to Constant nodes and list to Sequence nodes based on the dtype information. - - The added constants will be replacing values in named_inputs in place. - - Args: - signature: The OpSignature for the node. - named_inputs: The mapping of parameter names to their arguments. - type_binding: A mapping of Constraint names to ir.DataType. - constant_farm: A dictionary of {(py_value, ir.DataType): ir.Value} to store the deduplicated constants. - opset: The Opset to use for creating Constant nodes. - - Returns: - None - """ - # 3. Convert Python constants to Constant nodes based on the dtype information; - # construct sequences - # a. Iterate over all parameters in the signature the second time - # b. If the parameter is in to_resolve_type: - # - If param.constraint in type_binding, - # Get the constant from constant_farm (deduplicated); - # otherwise set named_inputs[param.name] = Constant(value, dtype=type_binding[param.constraint]) - # - Otherwise, set named_inputs[param.name] = Constant(value) - for name, arg in named_inputs.items(): - param = signature.params_map[name] - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" - - if isinstance(arg, ir.Value): - # TODO(justinchuby): Cast the ir.Value here if needed - continue - if ( - isinstance(arg, Sequence) - and len(arg) > 0 - and all(isinstance(val, ir.Value) for val in arg) - ): - # Skip the sequence of ir.Value. This is a variadic input or a Sequence input - # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants - # like `Max(0, ir.Value())` - # We need to convert the Python constants to Constant nodes - # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float] - continue - # if param.variadic: - # # FXIME: Handle variadic inputs and sequence inputs differently - # raise NotImplementedError - # TODO: Find a way to recursively build constants. Maybe extract the logic out. 
- # FIXME: I am here - - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" - - if param.type_constraint in type_binding: - # A known dtype is available - dtype = type_binding[param.type_constraint].dtype - elif len(param.type_constraint.allowed_types) == 1: - # Only one type is allowed - dtype = next(iter(param.type_constraint.allowed_types)).dtype - else: - # No dtype information available. Infer from the Python constant - if isinstance(arg, bool): - dtype = ir.DataType.BOOL - elif isinstance(arg, float): - dtype = ir.DataType.FLOAT - elif isinstance(arg, int): - dtype = ir.DataType.INT64 - elif isinstance(arg, str): - dtype = ir.DataType.STRING - elif isinstance(arg, (tuple, list)) and all( - isinstance(val, int) for val in arg - ): - dtype = ir.DataType.INT64 - elif isinstance(arg, (tuple, list)) and any( - isinstance(val, float) for val in arg - ): - # NOTE: if any float is present, the dtype is float - dtype = ir.DataType.FLOAT - elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - dtype = arg.dtype - elif arg is None: - dtype = ir.DataType.UNDEFINED - else: - raise TypeError( - f"Constant input '{arg}' of type '{type(arg)}' is not supported" - ) - - if arg is None: - constant_value = None - elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - # Deduplicate the constants - if isinstance(arg, (tuple, list)): - # Make the arg hashable - arg = tuple(arg) # noqa: PLW2901 - constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] - if constant_value is None: - constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] - constant_value = opset.Constant(value=constant_tensor) - constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] - else: - constant_value = opset.Constant(value=arg) - - named_inputs[param.name] = constant_value - return named_inputs # type: ignore[return-value] - - -def _construct_node( - signature: _schemas.OpSignature, - named_inputs: Mapping[str, ir.Value | None], - named_attrs: Mapping[str, ValidAttributeType], - opset: onnxscript.values.Opset, -) -> ir.Node: - """Construct the node with the inputs and attributes. - - Variadic inputs are flattened. - - Args: - signature: The OpSignature for the node. - named_inputs: The mapping of parameter names to their arguments. When we - do not have the schema of an operator, we do not know the names of - the inputs, in which case the names can be anything because they - are not used in this function. The data structure is passed in for - consistency with the other functions. - named_attrs: The mapping of attribute names to their values. 
- """ - inputs: list[Any] = [] - # Flatten variadic inputs - for value in named_inputs.values(): - if isinstance(value, Sequence): - inputs.extend(value) - else: - inputs.append(value) - - # Construct and filter out None attributes - attributes = [ - attr - for attr in ir_convenience.convert_attributes(named_attrs) - if attr.value is not None - ] - outputs = [_tensors.SymbolicTensor(opset) for _ in signature.outputs] - return ir.Node( - signature.domain, - signature.name, - inputs=inputs, - attributes=attributes, - outputs=outputs, - ) - - -class OpRecorder(evaluator.Evaluator): - """An onnxscript Evaluator that captures the graph into torchscript.""" - - def __init__( - self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value] - ): - self.nodes: list[ir.Node] = [] - self.opset = opset - self.functions: dict[ir.OperatorIdentifier, onnxscript.OnnxFunction] = {} - self.constant_farm = constant_farm - - def _call_op( - self, - op_signature: _schemas.OpSignature, - named_inputs: dict[str, AllowedArgType], - named_attrs: dict[str, ValidAttributeType], - ) -> Sequence[_tensors.SymbolicTensor]: - """Record nodes for the given opschema and arguments. - - Args: - op_signature: The OpSchema containing the node signature. - named_inputs: The mapping of parameter names to their arguments. - named_attrs: The mapping of attribute names to their values. - """ - type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) - try: - converted_named_inputs = _process_python_constants_and_sequences( - op_signature, named_inputs, type_binding, self.constant_farm, self.opset - ) - except Exception as e: - raise errors.GraphConstructionError( - f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " - f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." - ) from e - - try: - self.nodes.append( - node := _construct_node( - op_signature, converted_named_inputs, named_attrs, self.opset - ) - ) - except Exception as e: - raise errors.GraphConstructionError( - f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. " - f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, " - f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." - ) from e - return node.outputs # type: ignore[return-value] - - def eval( - self, - schema: onnx.defs.OpSchema, - args: Sequence[AllowedArgType], # type: ignore[override] - kwargs: Mapping[str, AllowedArgType], - ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor]: - try: - op_signature = _schemas.OpSignature.from_opschema(schema) - named_inputs, named_attrs = _construct_named_inputs_and_attrs( - op_signature, args, kwargs - ) - # TODO(justinchuby): Handle cast - if schema.name == "CastLike": - assert len(named_inputs) == 2 - # Skip CastLike if the input and output types are the same - src_input = named_inputs["input"] - target_type = named_inputs["target_type"] - - if ( - isinstance(src_input, ir.Value) - and isinstance(target_type, ir.Value) - and src_input.dtype is not None - and target_type.dtype is not None - ): - # dtypes are available - if src_input.dtype == target_type.dtype: - # Same type. 
No cast needed - return src_input # type: ignore[return-value] - else: - # Create a Cast node - return self.opset.Cast(src_input, to=target_type.dtype) # type: ignore[union-attr,return-value] - - outputs = self._call_op(op_signature, named_inputs, named_attrs) - if len(outputs) == 1: - return outputs[0] - return outputs - except Exception as e: - raise errors.GraphConstructionError( - f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}." - ) from e - - def eval_function( # type: ignore[override] - self, - function: onnxscript.OnnxFunction, - args: Sequence[AllowedArgType], - kwargs: Mapping[str, AllowedArgType], - ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int: - try: - # Special cases for handling IsScalar and Rank - if function.name == "IsScalar": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], _tensors.SymbolicTensor): - if args[0].rank is not None: - return args[0].rank == 0 - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): - return False - else: - # Python constants are scalars - return True - if function.name == "Rank": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], _tensors.SymbolicTensor): - if args[0].rank is not None: - return args[0].rank - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): - if all(isinstance(arg, (int, float)) for arg in args[0]): - return 1 - else: - # Fall to call add_function_call - pass - else: - # Python constants are scalars - return 0 - - # NOTE: signature is written to function in the registration process - # TODO: Upstream signature to ONNX Function - if hasattr(function, "signature"): - op_signature = function.signature - else: - op_signature = _schemas.OpSignature.from_function( - function, function.function_ir.domain, function.name - ) - - named_inputs, named_attrs = _construct_named_inputs_and_attrs( - op_signature, args, kwargs - ) - - # NOTE: We need to call traceable functions after the _construct_named_inputs_and_attrs - # call because it will filter out the unexpected kwargs for us. - if function.traceable: - # Trace the function call instead of adding the function as a node - return function.function(**named_inputs, **named_attrs) - - outputs = self._call_op(op_signature, named_inputs, named_attrs) - - self.functions[(function.function_ir.domain, function.name, "")] = function - if len(outputs) == 1: - return outputs[0] - return outputs - except Exception as e: - try: - source_file = inspect.getsourcefile(function.function) - _, lineno = inspect.getsourcelines(function.function) - except Exception: - source_file = lineno = None - raise errors.GraphConstructionError( - f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}." - + f" The function is defined at '{source_file}:{lineno}'." 
- if source_file - else "" - ) from e diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py deleted file mode 100644 index dc511491d6b41..0000000000000 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ /dev/null @@ -1,335 +0,0 @@ -"""Strategies for capturing ExportedPrograms.""" - -# mypy: allow-untyped-defs -from __future__ import annotations - -import abc -import dataclasses -import datetime -import pathlib -from typing import Any, Callable, TYPE_CHECKING - -import torch -from torch._export import converter as _torchscript_converter -from torch.utils import _pytree - - -if TYPE_CHECKING: - import os - - -def _verbose_printer(verbose: bool | None) -> Callable[..., None]: - """Prints messages based on `verbose`.""" - if verbose is False: - return lambda *_, **__: None - return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) - - -def _take_first_line(text: str) -> str: - """Take the first line of a text.""" - lines = text.split("\n", maxsplit=1) - first_line = lines[0] - if len(lines) > 1: - first_line += "[...]" - return first_line - - -@dataclasses.dataclass -class Result: - exported_program: torch.export.ExportedProgram | None - strategy: str - exception: Exception | None = None - - @property - def success(self) -> bool: - return self.exported_program is not None - - -class CaptureStrategy(abc.ABC): - """Strategy for capturing a module as ExportedProgram. - - To use a strategy, create an instance and call it with the model, args, kwargs, and dynamic_shapes. - Example:: - - strategy = TorchExportStrategy(verbose=True) - result = strategy(model, args, kwargs, dynamic_shapes) - """ - - def __init__( - self, - *, - verbose: bool = False, - dump: bool = False, - artifacts_dir: str | os.PathLike = ".", - timestamp: str | None = None, - ): - """Initialize the strategy. - - Args: - verbose: Whether to print verbose messages. - dump: Whether to dump the intermediate artifacts to a file. 
- """ - self._verbose_print = _verbose_printer(verbose) - self._dump = dump - self._artifacts_dir = pathlib.Path(artifacts_dir) - self._timestamp = timestamp or datetime.datetime.now().strftime( - "%Y-%m-%d_%H-%M-%S-%f" - ) - - def __call__( - self, - model: torch.nn.Module | torch.jit.ScriptFunction, - args: tuple[Any, ...], - kwargs: dict[str, Any] | None, - dynamic_shapes, - ) -> Result: - self._enter(model) - if kwargs is None: - kwargs = {} - try: - exported_program = self._capture(model, args, kwargs, dynamic_shapes) - except Exception as e: - self._failure(model, e) - return Result( - exported_program=None, - strategy=self.__class__.__name__, - exception=e, - ) - self._success(model) - return Result(exported_program, strategy=self.__call__.__name__) - - @abc.abstractmethod - def _capture( - self, model, args, kwargs, dynamic_shapes - ) -> torch.export.ExportedProgram: - raise NotImplementedError - - def _enter(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: - return - - def _success(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: - return - - def _failure( - self, model: torch.nn.Module | torch.jit.ScriptFunction, e: Exception - ) -> None: - return - - -class TorchExportStrategy(CaptureStrategy): - def _capture( - self, model, args, kwargs, dynamic_shapes - ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) - - def _enter(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export`..." - ) - - def _success(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export`... ✅" - ) - - def _failure(self, model, e) -> None: - del e # Unused - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export`... ❌" - ) - - -class TorchExportNonStrictStrategy(CaptureStrategy): - def _capture( - self, model, args, kwargs, dynamic_shapes - ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False - ) - - def _enter(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`..." - ) - - def _success(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ✅" - ) - - def _failure(self, model, e) -> None: - del e # Unused - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ❌" - ) - - -class JitTraceConvertStrategy(CaptureStrategy): - def _capture( - self, model, args, kwargs, dynamic_shapes - ) -> torch.export.ExportedProgram: - del dynamic_shapes # Unused - - flattened_args, spec = _pytree.tree_flatten((args, kwargs)) - flattened_args = tuple(flattened_args) - - # Since torch.jit.trace only accepts Tensors as inputs, we filter - # out non-Tensor arguments and reconstruct the arguments after entering - # the WrappedModel. 
- tensor_placeholder = object() - non_tensor_args = [ - arg if not isinstance(arg, torch.Tensor) else tensor_placeholder - for arg in flattened_args - ] - tensor_args = tuple( - arg for arg in flattened_args if isinstance(arg, torch.Tensor) - ) - - class WrappedModel(torch.nn.Module): - """Wrap the model so that it takes flattened arguments.""" - - def __init__(self, m): - super().__init__() - self.model = m - - def forward(self, *_args): - # Take the non-Tensor arguments list as a starting point and - # replace the tensor_placeholder with the actual tensor arguments - # from _args. - reconstructed_flattened_args = non_tensor_args.copy() - _args_iter = iter(_args) - for i, arg in enumerate(reconstructed_flattened_args): - if arg is tensor_placeholder: - reconstructed_flattened_args[i] = next(_args_iter) - # Unflatten the arguments and kwargs to pass to the model. - unflattened_args, unflattened_kwargs = _pytree.tree_unflatten( - reconstructed_flattened_args, spec - ) - results = self.model(*unflattened_args, **unflattened_kwargs) - if not isinstance(results, tuple): - results = (results,) - flattened_results, _ = _pytree.tree_flatten(results) - if len(flattened_results) == 1: - return flattened_results[0] - return tuple(flattened_results) - - jit_model = torch.jit.trace( - WrappedModel(model), - example_inputs=tensor_args, - check_trace=False, - strict=False, - ) - if self._dump: - program_path = self._artifacts_dir / f"onnx_export_{self._timestamp}.pt" - try: - torch.jit.save(jit_model, program_path) - except Exception as e: - self._verbose_print( - f"Failed to save Torch Script model due to an error: {e}" - ) - else: - self._verbose_print( - f"Torch Script model has been saved to '{program_path}'." - ) - return _torchscript_converter.TS2EPConverter( - jit_model, flattened_args - ).convert() - - def _enter(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with Torch Script..." - ) - - def _success(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with Torch Script... ✅" - ) - - def _failure(self, model, e) -> None: - del e # Unused - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with Torch Script... ❌" - ) - - -class LegacyDynamoStrategy(CaptureStrategy): - """Strategy implemented by the ONNX team using internal dynamo APIs and custom fx passes.""" - - def _capture( - self, model, args, kwargs, dynamic_shapes - ) -> torch.export.ExportedProgram: - # NOTE: Import here to prevent circular dependency - from torch.onnx._internal.fx import diagnostics, passes - - graph_module, _ = torch._dynamo.export( - model, - tracing_mode="symbolic", - dynamic_shapes=dynamic_shapes, - )( - *args, - **kwargs, - ) - torch._dynamo.reset() - - diagnostic_context = diagnostics.DiagnosticContext( - "torch.onnx.export", - torch.__version__, - ) - - flattened_args, _ = _pytree.tree_flatten((args, kwargs)) - flattened_args = tuple(flattened_args) - - # ONNX does not support views and mutations. - # Functionalize to get a semantically equivalent graph without mutations. - graph_module = passes.Functionalize( - diagnostic_context, - graph_module, - enable_dynamic_axes=bool(dynamic_shapes), - ).run(*flattened_args) - - # Input mutations are detected and distilled after `Functionalize` pass. - # Remove them since ONNX inference does not need them. 
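For intuition about the Functionalize pass invoked above, here is a rough standalone sketch using the public torch.func/make_fx APIs (a stand-in for the internal pass, not the pass itself; the RemoveInputMutation call resumes the deleted code below): in-place ops are rewritten into pure ones so the captured graph carries no mutations.

    import torch
    from torch.func import functionalize
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        y = x.clone()
        y.add_(1)  # in-place mutation of an intermediate
        return y

    # The traced graph contains the functional aten.add instead of aten.add_.
    gm = make_fx(functionalize(f))(torch.zeros(2))
    print(gm.graph)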
- graph_module = passes.RemoveInputMutation(diagnostic_context, graph_module).run( - *flattened_args - ) - - # Use torch.export to recapture the GraphModule into an ExportedProgram. - return torch.export.export(graph_module, flattened_args) - - def _enter(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with internal Dynamo apis..." - ) - - def _success(self, model) -> None: - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with internal Dynamo apis... ✅" - ) - - def _failure(self, model, e) -> None: - del e # Unused - model_repr = _take_first_line(repr(model)) - self._verbose_print( - f"Obtain model graph for `{model_repr}` with internal Dynamo apis... ❌" - ) - - -CAPTURE_STRATEGIES = ( - TorchExportStrategy, - TorchExportNonStrictStrategy, - JitTraceConvertStrategy, - LegacyDynamoStrategy, -) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py deleted file mode 100644 index 642f768d7285c..0000000000000 --- a/torch/onnx/_internal/exporter/_compat.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Compatibility functions for the torch.onnx.export API.""" - -# mypy: allow-untyped-defs -# mypy: disable-error-code=attr-defined -from __future__ import annotations - -import inspect -import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING - -import onnx - -import torch -import torch.export -from torch.onnx._internal.exporter import _core, _onnx_program - - -if TYPE_CHECKING: - import os - -logger = logging.getLogger(__name__) - - -def _signature(model) -> inspect.Signature: - should_be_callable = getattr(model, "forward", model) - if callable(should_be_callable): - return inspect.signature(should_be_callable) - raise ValueError("model has no forward method and is not callable") - - -def _from_dynamic_axes_to_dynamic_shapes( - model, - dynamic_axes=None, - input_names: Sequence[str] | None = None, -) -> dict[str, Any] | None: - """ - - dynamic_axes examples: - (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} - (2) dynamic_axes = {"x": [0], "y": [1]} - - these will be converted to dynamic_shapes respectively: - (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} - (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names - - """ - # https://github.com/pytorch/pytorch/pull/128371 - # 1. 
The function does not need to provide dynamic_shapes to torch.export.export
-    if dynamic_axes is None:
-        return None
-
-    if input_names is None:
-        input_names = []
-
-    sig = _signature(model)
-    if len(input_names) > len(sig.parameters):
-        raise ValueError(
-            f"Number of input names ({len(input_names)}) should not be greater than "
-            f"the number of model inputs ({len(sig.parameters)})"
-        )
-    input_names_to_model_inputs = {}
-    for idx, param_name in enumerate(sig.parameters):
-        if idx < len(input_names):
-            input_names_to_model_inputs[input_names[idx]] = param_name
-        else:
-            input_names_to_model_inputs[param_name] = param_name
-
-    # NOTE: torch.export.export does not support input names assignment,
-    # so we need to map input names to model inputs to create dynamic_shapes
-    # for the exported program
-    dynamic_shapes_to_exported_program = {}
-    for input_name, axes in dynamic_axes.items():
-        # input_name can come either from input_names or from the model inputs
-        if input_name not in input_names_to_model_inputs:
-            raise ValueError(
-                f"dynamic axis: {input_name} is not found in the input names: {input_names}"
-            )
-        model_input_name = input_names_to_model_inputs[input_name]
-        if isinstance(axes, dict):
-            dynamic_shapes_to_exported_program[model_input_name] = {
-                k: torch.export.Dim(v) for k, v in axes.items()
-            }
-        elif isinstance(axes, list):
-            dynamic_shapes_to_exported_program[model_input_name] = {
-                k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes
-            }
-        else:
-            raise TypeError(
-                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
-            )
-    # torch.export.export needs static dims to be present in dynamic_shapes
-    # for all input tensors, so we need to add them with None
-    for input_name in sig.parameters:
-        if input_name not in dynamic_shapes_to_exported_program:
-            dynamic_shapes_to_exported_program[input_name] = None  # type: ignore[assignment]
-
-    return dynamic_shapes_to_exported_program
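A minimal sketch of the equivalence this conversion establishes, following the docstring examples above (ToyModel is hypothetical; dict-form axes keep the given names, list-form axes get auto-generated `<input>_dim_<axis>` names, and unspecified inputs are filled with None):

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x, b):
            return x.relu(), b.sigmoid()

    # dynamic_axes = {"x": {0: "batch"}, "b": [1]} is converted to:
    dynamic_shapes = {
        "x": {0: torch.export.Dim("batch")},
        "b": {1: torch.export.Dim("b_dim_1")},
    }
    ep = torch.export.export(
        ToyModel(),
        (torch.randn(2, 3), torch.randn(2, 3)),
        dynamic_shapes=dynamic_shapes,
    )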
-def _get_torch_export_args(
-    args: tuple[Any, ...],
-    kwargs: dict[str, Any] | None,
-) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
-    """Obtain the arguments for torch.export.export from the torch.onnx.export-style input arguments."""
-    if not kwargs and args and isinstance(args[-1], dict):
-        kwargs = args[-1]
-        args = args[:-1]
-    return args, kwargs
-
-
-def _convert_version(path: str | os.PathLike, opset_version: int) -> None:
-    """Convert the ONNX file to a specific opset version."""
-    model = onnx.load(path, load_external_data=False)
-    model = onnx.version_converter.convert_version(model, opset_version)
-    onnx.save(model, path)
-
-
-def export_compat(
-    model: torch.nn.Module
-    | torch.export.ExportedProgram
-    | torch.jit.ScriptModule
-    | torch.jit.ScriptFunction,
-    args: tuple[Any, ...],
-    f: str | os.PathLike | None = None,
-    *,
-    kwargs: dict[str, Any] | None = None,
-    export_params: bool = True,
-    verbose: bool | None = None,
-    input_names: Sequence[str] | None = None,
-    output_names: Sequence[str] | None = None,
-    opset_version: int | None = None,
-    dynamic_axes: Mapping[str, Mapping[int, str]]
-    | Mapping[str, Sequence[int]]
-    | None = None,
-    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
-    keep_initializers_as_inputs: bool = False,
-    external_data: bool = True,
-    report: bool = False,
-    verify: bool = False,
-    profile: bool = False,
-    dump_exported_program: bool = False,
-    artifacts_dir: str | os.PathLike = ".",
-    fallback: bool = False,
-    **_,
-) -> _onnx_program.ONNXProgram | None:
-    if isinstance(model, torch.export.ExportedProgram):
-        # When the model is already an ExportedProgram, the args, kwargs, and
-        # dynamic_shapes are not used
-        dynamic_shapes = dynamic_shapes or {}
-    else:
-        args, kwargs = _get_torch_export_args(args, kwargs)
-        if dynamic_shapes is None and dynamic_axes is not None:
-            dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
-                model, dynamic_axes, input_names
-            )
-
-    should_convert_version = False
-
-    try:
-        onnx_program = _core.export(
-            model,
-            args,
-            kwargs,
-            registry=None,
-            dynamic_shapes=dynamic_shapes,
-            input_names=input_names,
-            output_names=output_names,
-            profile=profile,
-            report=report,
-            verify=verify,
-            dump_exported_program=dump_exported_program,
-            artifacts_dir=artifacts_dir,
-            verbose=verbose,
-        )
-
-        if f is not None:
-            # Save the initializers as external data (the default) to reduce
-            # the size of the ONNX file
-            onnx_program.save(
-                f,
-                include_initializers=export_params,
-                keep_initializers_as_inputs=keep_initializers_as_inputs,
-                external_data=external_data,
-            )
-            if (
-                opset_version is not None
-                and opset_version != onnx_program.model.opset_imports.get("")
-            ):
-                should_convert_version = True
-
-    except Exception as e:
-        if fallback:
-            if verbose is not False:
-                print(
-                    "[torch.onnx] Falling back to legacy torch.onnx.export due "
-                    f"to the following error: {e}",
-                )
-            torch.onnx.utils.export(
-                model,  # type: ignore[arg-type]
-                args,
-                f,  # type: ignore[arg-type]
-                kwargs=kwargs,
-                export_params=export_params,
-                input_names=input_names,
-                output_names=output_names,
-                opset_version=17,  # TODO(justinchuby): Hard coded to 17 for now
-                dynamic_axes=dynamic_axes,
-                keep_initializers_as_inputs=keep_initializers_as_inputs,
-            )
-            onnx_program = None
-            if opset_version is None:
-                opset_version = 18
-            if opset_version != 17:
-                should_convert_version = True
-        else:
-            raise
-
-    if f is not None and should_convert_version:
-        assert opset_version is not None
-        if verbose is not False:
-            print(
-                f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..."
-            )
-        _convert_version(f, opset_version)
-
-    return onnx_program
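For context, a sketch of how this compat layer is typically reached from the public API (the exact `dynamo=True`/`fallback=True` routing is an assumption based on this file; MyModel is hypothetical):

    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return x + 1

    onnx_program = torch.onnx.export(
        MyModel(),
        (torch.randn(2, 3),),
        dynamo=True,    # route through the new exporter (export_compat above)
        fallback=True,  # fall back to the legacy exporter on failure
    )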
diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py
deleted file mode 100644
index 3d28a0544a80e..0000000000000
--- a/torch/onnx/_internal/exporter/_core.py
+++ /dev/null
@@ -1,1344 +0,0 @@
-# mypy: allow-untyped-defs
-# flake8: noqa: B950 We disable B950 because it complains about line length
-from __future__ import annotations
-
-import ctypes
-import datetime
-import inspect
-import itertools
-import logging
-import operator
-import pathlib
-import textwrap
-import traceback
-import typing
-from typing import Any, Callable, Literal, Sequence
-
-import onnx
-
-import onnxscript
-import onnxscript.evaluator
-import onnxscript.function_libs
-import onnxscript.function_libs.torch_lib
-import onnxscript.function_libs.torch_lib.registration
-from onnxscript import ir
-from onnxscript.ir import convenience as ir_convenience
-
-import torch
-import torch.fx
-from torch.export import graph_signature
-from torch.onnx._internal.exporter import (
-    _analysis,
-    _building,
-    _capture_strategies,
-    _dispatching,
-    _fx_passes,
-    _ir_passes,
-    _isolated,
-    _onnx_program,
-    _registration,
-    _reporting,
-    _tensors,
-    _verification,
-    errors,
-)
-
-
-if typing.TYPE_CHECKING:
-    import os
-
-    import numpy as np
-
-
-# Define utilities to convert PyTorch data types so users do not need to specify them manually
-_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
-    torch.bfloat16: ir.DataType.BFLOAT16,
-    torch.bool: ir.DataType.BOOL,
-    torch.complex128: ir.DataType.COMPLEX128,
-    torch.complex64: ir.DataType.COMPLEX64,
-    torch.float16: ir.DataType.FLOAT16,
-    torch.float32: ir.DataType.FLOAT,
-    torch.float64: ir.DataType.DOUBLE,
-    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
-    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
-    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
-    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
-    torch.int16: ir.DataType.INT16,
-    torch.int32: ir.DataType.INT32,
-    torch.int64: ir.DataType.INT64,
-    torch.int8: ir.DataType.INT8,
-    torch.uint8: ir.DataType.UINT8,
-}
-_BLUE = "\033[96m"
-_END = "\033[0m"
-
-_STEP_ONE_ERROR_MESSAGE = textwrap.dedent(
-    f"""\
-    Failed to export the model with torch.export. {_BLUE}This is step 1/2{_END} of exporting the model to ONNX. Next steps:
-    - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
-    - Debug `torch.export.export` and submit a PR to PyTorch.
-    - Create an issue in the PyTorch GitHub repository against the {_BLUE}*torch.export*{_END} component and attach the full error stack as well as reproduction scripts."""
-)
-
-_STEP_TWO_ERROR_MESSAGE = textwrap.dedent(
-    f"""\
-    Failed to convert the exported program to an ONNX model. {_BLUE}This is step 2/2{_END} of exporting the model to ONNX. Next steps:
-    - If there is a missing ONNX function, implement it and register it to the registry.
-    - If there is an internal error during ONNX conversion, debug the error and submit a PR to PyTorch.
-    - Save the ExportedProgram as a pt2 file and create an error report with `export(..., report=True)`. Create an issue in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component. Attach the pt2 model and the error report."""
-)
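A quick sanity check of the table above (assumes this module's `torch` and `ir` imports; note that bfloat16 and the float8 variants map to dedicated ONNX data types rather than being widened to larger floats):

    assert _TORCH_DTYPE_TO_ONNX[torch.bfloat16] == ir.DataType.BFLOAT16
    assert _TORCH_DTYPE_TO_ONNX[torch.float8_e5m2] == ir.DataType.FLOAT8E5M2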
-
-logger = logging.getLogger(__name__)
-
-
-def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
-    return _TORCH_DTYPE_TO_ONNX[dtype]
-
-
-class TorchTensor(ir.Tensor):
-    def __init__(self, tensor: torch.Tensor, name: str | None = None):
-        # Pass the tensor as the raw data to ir.Tensor's constructor
-        super().__init__(
-            tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name
-        )
-
-    def __array__(self, dtype: Any = None) -> np.ndarray:
-        # numpy() calls __array__ in ir.Tensor
-        if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).__array__(dtype)
-        if self.dtype in {
-            ir.DataType.FLOAT8E4M3FN,
-            ir.DataType.FLOAT8E4M3FNUZ,
-            ir.DataType.FLOAT8E5M2,
-            ir.DataType.FLOAT8E5M2FNUZ,
-        }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).__array__(dtype)
-        return self.raw.__array__(dtype)
-
-    def tobytes(self) -> bytes:
-        # Implement tobytes to support native PyTorch types so we can use types like bfloat16.
-        # Reading from memory directly is also more efficient because
-        # it avoids copying to a NumPy array
-        tensor = self.raw.detach().cpu().contiguous()
-        return bytes(
-            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
-                tensor.data_ptr()
-            )
-        )
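A standalone sketch of the zero-copy read that `tobytes` above performs, checked against the NumPy round trip:

    import ctypes
    import torch

    t = torch.arange(4, dtype=torch.float32).detach().cpu().contiguous()
    raw = bytes(
        (ctypes.c_ubyte * t.element_size() * t.numel()).from_address(t.data_ptr())
    )
    assert raw == t.numpy().tobytes()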
-
-
-# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L56C1-L62C19
-# class InputKind(Enum):
-#     USER_INPUT = auto()
-#     PARAMETER = auto()
-#     BUFFER = auto()
-#     CONSTANT_TENSOR = auto()
-#     CUSTOM_OBJ = auto()
-#     TOKEN = auto()
-
-# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L89C1-L96C19
-# class OutputKind(Enum):
-#     USER_OUTPUT = auto()
-#     LOSS_OUTPUT = auto()
-#     BUFFER_MUTATION = auto()
-#     GRADIENT_TO_PARAMETER = auto()
-#     GRADIENT_TO_USER_INPUT = auto()
-#     USER_INPUT_MUTATION = auto()
-#     TOKEN = auto()
-
-
-def _set_shape_types(
-    values: Sequence[ir.Value],
-    meta_vals: Sequence[torch.Tensor],
-    complex_to_float: bool = True,
-) -> None:
-    if not isinstance(meta_vals, Sequence):
-        logger.warning(
-            "Expected meta_vals to be a sequence, but got %s. There may be an internal error.",
-            meta_vals,
-        )
-        meta_vals = (meta_vals,)
-    for value, meta_val in zip(values, meta_vals):
-        _set_shape_type(value, meta_val, complex_to_float=complex_to_float)
-
-
-def _set_shape_type(
-    value: ir.Value,
-    meta_val: torch.Tensor | tuple[torch.Tensor],
-    complex_to_float: bool,
-) -> None:
-    # TODO: Consider using meta["tensor_meta"] for this? Would it be faster?
-    if isinstance(meta_val, tuple):
-        logger.warning("Setting shape and type of tensors is not supported yet")
-    if isinstance(meta_val, torch.Tensor):
-        # FIXME: Consider shape for complex values
-        dims = []
-        for dim in meta_val.shape:
-            if isinstance(dim, int):
-                dims.append(dim)
-            else:
-                dims.append(str(dim.node))
-        value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
-        if complex_to_float:
-            if meta_val.dtype == torch.complex64:
-                value.dtype = ir.DataType.FLOAT
-                # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-                dims.append(2)
-            elif meta_val.dtype == torch.complex128:
-                value.dtype = ir.DataType.DOUBLE
-                # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-                dims.append(2)
-
-        value.shape = ir.Shape(dims)
-    elif isinstance(meta_val, (int, torch.SymInt)):
-        # aten::sym_size outputs an int, not a tensor, which stands
-        # for the size of one dim. We treat it as a scalar.
-        value.dtype = ir.DataType.INT64
-        value.shape = ir.Shape([])
-    elif isinstance(meta_val, (bool, torch.SymBool)):
-        value.dtype = ir.DataType.BOOL
-        value.shape = ir.Shape([])
-    elif isinstance(meta_val, (float, torch.SymFloat)):
-        value.dtype = ir.DataType.FLOAT
-        value.shape = ir.Shape([])
-    else:
-        pass
-
-
-def _get_qualified_module_name(cls: Any) -> str:
-    if isinstance(cls, str):
-        return cls
-    module = cls.__module__
-    if module is None or module == str.__class__.__module__:
-        return cls.__name__
-    return module + "." + cls.__name__
-
-
-def _get_node_namespace(node: torch.fx.Node) -> tuple[str, list[str], list[str]]:
-    """Get the namespace and scope of the node.
-
-    Example::
-
-        {
-            'L__self__': ('', <class 'torchvision.models.resnet.ResNet'>),
-            'L__self___avgpool': ('avgpool', <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>)
-        }
-
-    Will yield
-
-    namespace: ": torchvision.models.resnet.ResNet/avgpool: torch.nn.modules.pooling.AdaptiveAvgPool2d/node_name: node_target"
-    class_hierarchy: ["torchvision.models.resnet.ResNet", "torch.nn.modules.pooling.AdaptiveAvgPool2d", <node_target>]
-    name_scopes: ["", "avgpool", <node_name>]
-
-    Args:
-        node: The node to get the namespace and scope of.
-
-    Returns:
-        (namespace, class_hierarchy, name_scope)
-    """
-    nn_module_stack = node.meta.get("nn_module_stack")
-    logger.debug("%s", nn_module_stack)
-    if nn_module_stack is None:
-        logger.warning(
-            "nn_module_stack not found for node '%s'. Skip adding metadata...",
-            node.name,
-        )
-        return f"{node.name}: {node.target}", [str(node.target)], [node.name]
-    namespaces = []
-    class_hierarchy = []
-    name_scopes = []
-    for name, nn_module in nn_module_stack.values():
-        name_scopes.append(name)
-        nn_module_name = _get_qualified_module_name(nn_module)
-        class_hierarchy.append(nn_module_name)
-        namespaces.append(f"{name}: {_get_qualified_module_name(nn_module)}")
-    namespaces.append(f"{node.name}: {node.target}")
-    class_hierarchy.append(str(node.target))
-    name_scopes.append(node.name)
-
-    return "/".join(namespaces), class_hierarchy, name_scopes
-
-
-def _set_node_metadata(fx_node: torch.fx.Node, ir_node: ir.Node) -> None:
-    """Adds namespace and other node metadata to the ONNX node."""
-    namespace, class_hierarchy, name_scopes = _get_node_namespace(fx_node)
-    ir_node.metadata_props["namespace"] = namespace
-    ir_node.metadata_props["pkg.torch.onnx.class_hierarchy"] = repr(class_hierarchy)
-    ir_node.metadata_props["pkg.torch.onnx.name_scopes"] = repr(name_scopes)
-    ir_node.metadata_props["pkg.torch.onnx.fx_node"] = str(fx_node.format_node())
-    ir_node.metadata_props["pkg.torch.onnx.stack_trace"] = fx_node.meta.get(
-        "stack_trace", ""
-    )
-
-
-def _handle_getitem_node(
-    node: torch.fx.Node, node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]]
-) -> ir.Value:
-    """Handle a getitem node.
-
-    Add the value the node retrieves to the mapping under the node's name, then return the value.
-
-    There are two cases for this node:
-    1. The output is a Sequence (traced): we can simply get the value from the sequence.
-    2. The output is produced by a SplitToSequence node: we need to get the value
-       from the sequence value with SequenceAt.
-    This function only handles the first case.
-    """
-    assert len(node.all_input_nodes) == 1
-    source = node.all_input_nodes[0]
-    source_outputs = node_name_to_values[source.name]
-    assert isinstance(
-        source_outputs, Sequence
-    ), f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}"
-    index = typing.cast(int, node.args[1])
-    value = source_outputs[index]
-    # Save the getitem value to the values mapping in case
-    # it is one of the graph outputs
-    node_name_to_values[node.name] = value
-    # Rename the value to match the getitem node's name
-    value.name = node.name
-    return value
-
-
-def _handle_call_function_node(
-    graph: ir.Graph,
-    node: torch.fx.Node,
-    node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]],
-) -> None:
-    """Handle a call_function node.
-
-    Args:
-        graph: The ONNX graph at construction.
-        node: The FX node to translate.
-        node_name_to_values: A mapping of FX node names to their produced ir.Value.
-    """
-    if node.target == operator.getitem:
-        _handle_getitem_node(node, node_name_to_values)
-    # Add op to the graph
-    op = str(node.target)
-    fx_inputs, attributes, input_names, output_names = _get_inputs_and_attributes(node)
-    inputs: list[ir.Value | None] = []
-    for i, input_ in enumerate(fx_inputs):
-        if input_ is None:
-            inputs.append(None)
-        elif hasattr(input_, "name"):
-            if isinstance(input_, torch.fx.Node) and input_.target == operator.getitem:
-                actual_input = _handle_getitem_node(input_, node_name_to_values)
-                inputs.append(actual_input)
-            else:
-                value = node_name_to_values[input_.name]
-                assert not isinstance(value, Sequence)
-                inputs.append(value)
-        else:
-            attributes[f"arg_{i}"] = input_
-
-    outputs = [ir.Value(name=name) for name in output_names]
-    if len(outputs) > 1:
-        _set_shape_types(outputs, node.meta["val"], complex_to_float=False)
-        node_name_to_values[node.name] = outputs
-    else:
-        _set_shape_type(outputs[0], node.meta["val"], complex_to_float=False)
-        node_name_to_values[node.name] = outputs[0]
-    ir_node = ir.Node(
-        "pkg.torch.ops",
-        op,
-        inputs,
-        attributes=ir_convenience.convert_attributes(attributes),
-        outputs=outputs,
-        name=node.name,
-    )
-    ir_node.meta["node"] = node
-    ir_node.metadata_props["pkg.torch.onnx.input_names"] = repr(input_names)
-    # Record the nn.Module stack for the node
-    _set_node_metadata(node, ir_node)
-
-    graph.append(ir_node)
-
-
-def _convert_fx_arg_to_onnx_arg(
-    arg, node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]]
-) -> Any:
-    """Convert an FX argument to an ONNX compatible argument.
-
-    This function
-    - Converts a torch dtype to an integer
-    - Converts a torch device/memory_format/layout to a string
-    - Converts a torch.fx.Node to an ir.Value
-    - Converts a sequence of torch.fx.Node to a sequence of ir.Value
-    """
-    if arg is None:
-        # None arguments are not modified because when the arg is an ONNX input
-        # we need to preserve the None value; when the arg is an ONNX attribute,
-        # we want to drop the value.
- # The actual dropping of a None attribute value is done by OpRecorder - return None - if hasattr(arg, "name"): - if isinstance(arg, torch.fx.Node) and arg.target == operator.getitem: - source = arg.all_input_nodes[0] - source_outputs = node_name_to_values[source.name] - if isinstance(source_outputs, Sequence): - # If the node is getting an input from another node, get the actual value the node is retrieving - return _handle_getitem_node(arg, node_name_to_values) - else: - # `source_outputs` is a sequence(tensor()) value and we need to - # use SequenceAt to get the value. This is handled by torchlib - pass - # If the input is a node, get the value from the mapping - return node_name_to_values[arg.name] - if isinstance(arg, (list, tuple)): - return [_convert_fx_arg_to_onnx_arg(elem, node_name_to_values) for elem in arg] - if isinstance(arg, (torch.device, torch.memory_format, torch.layout)): - return str(arg) - if isinstance(arg, torch.dtype): - return _torch_dtype_to_onnx_dtype(arg) - # Maybe a Python value - return arg - - -def _get_onnxscript_opset(opset_version: int) -> onnxscript.values.Opset: - return onnxscript.values.Opset("", opset_version) - - -def _handle_call_function_node_with_lowering( - model: ir.Model, - node: torch.fx.Node, - node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], - constant_farm: dict[Any, ir.Value], - registry: _registration.ONNXRegistry, - opset: onnxscript.values.Opset, -) -> None: - if node.target == operator.getitem: - source = node.all_input_nodes[0] - source_outputs = node_name_to_values[source.name] - if isinstance(source_outputs, Sequence): - _handle_getitem_node(node, node_name_to_values) - return - else: - # `source_outputs` is a sequence(tensor()) value and we need to - # use SequenceAt to get the value. This is handled by torchlib - pass - - # Find the matching ONNX overload for the node - # NOTE: Create different registries for different ONNX opset versions - # TODO: Log the message here to expose false positives - onnx_function, message = _dispatching.dispatch(node, registry) - - if onnx_function is None: - # TODO(justinchuby): Fall back to ATen op or do something else? - raise errors.DispatchError( - f"No ONNX function found for {node.target!r}. Failure message: {message}" - ) - - # Map FX inputs to ONNX inputs and fill optional inputs. - # torch_args and torch_kwargs are for op-level validation - fx_args = node.args - fx_kwargs = node.kwargs - - # Replace the input FX nodes with ONNX values - onnx_args = [ - _convert_fx_arg_to_onnx_arg(input_, node_name_to_values) for input_ in fx_args - ] - - onnx_kwargs = {} - for key, value in fx_kwargs.items(): - onnx_kwargs[key] = _convert_fx_arg_to_onnx_arg(value, node_name_to_values) - if key == "dtype" and onnx_kwargs[key] is None: - # Set dtype to -1 if it is None - onnx_kwargs[key] = -1 - - with onnxscript.evaluator.default_as( - tracer := _building.OpRecorder(opset, constant_farm) - ): - try: - outputs = onnx_function(*onnx_args, **onnx_kwargs) - except Exception as e: - raise errors.GraphConstructionError( - f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" - ) from e - - # NOTE: Instead of using the output names from node.target._schema, - # we always use the index if there are more than one outputs so the - # names can be programmatically reconstructed. This is useful for - # comparing values from the ONNX graph with those from the FX graph. 
-    #
-    # When there are multiple outputs, the output names will be
-    # node_name__0, node_name__1, etc.
-    if isinstance(outputs, Sequence):
-        _set_shape_types(outputs, node.meta["val"], complex_to_float=True)
-        node_name_to_values[node.name] = outputs
-        for i, output in enumerate(outputs):
-            output.name = f"{node.name}__{i}"
-    else:
-        _set_shape_type(outputs, node.meta["val"], complex_to_float=True)
-        node_name_to_values[node.name] = outputs
-        outputs.name = node.name
-
-    for ir_node in tracer.nodes:
-        ir_node.meta["node"] = node
-        # Record the nn.Module stack for the node
-        _set_node_metadata(node, ir_node)
-
-    # Add the traced nodes to the graph
-    model.graph.extend(tracer.nodes)
-    # Add the defined functions to the model
-    for identifier, onnxscript_function in tracer.functions.items():
-        if identifier in model.functions:
-            continue
-        # TODO: Get IR function directly when onnxscript is updated
-        proto = onnxscript_function.to_function_proto()
-        ir_function = ir.serde.deserialize_function(proto)
-        model.functions[identifier] = ir_function
-        if ir_function.domain not in model.opset_imports:
-            # FIXME: Record the correct opset version of the function
-            model.opset_imports[ir_function.domain] = 1
-
-
-def _handle_placeholder_node(
-    node: torch.fx.Node,
-    node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]],
-    *,
-    lower: str,
-    opset: onnxscript.values.Opset,
-) -> None:
-    # Placeholder nodes are user inputs
-    # We need to create a new tensor for each user input
-    # and add it to the graph's inputs
-    name = node.name
-    input_ = _tensors.SymbolicTensor(opset, name=name)
-    input_.meta["node"] = node
-    _set_shape_type(input_, node.meta["val"], complex_to_float=lower != "none")
-    node_name_to_values[name] = input_
-    # The inputs will be added to the graph later
-
-
-def _add_nodes(
-    exported_program: torch.export.ExportedProgram,
-    model: ir.Model,
-    lower: Literal["at_conversion", "post_conversion", "none"],
-    registry: _registration.ONNXRegistry,
-) -> dict[str, ir.Value | Sequence[ir.Value]]:
-    node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] = {}
-    constant_farm: dict[Any, ir.Value] = {}
-    opset = _get_onnxscript_opset(registry.opset_version)
-    for node in exported_program.graph.nodes:
-        logger.debug(
-            "%s", (node.name, node.args, node.target, node.op, node.type, node.kwargs)
-        )
-        try:
-            if node.op == "placeholder":
-                _handle_placeholder_node(
-                    node,
-                    node_name_to_values,
-                    lower=lower,
-                    opset=opset,
-                )
-            elif node.op == "call_function":
-                if lower == "at_conversion":
-                    _handle_call_function_node_with_lowering(
-                        model,
-                        node,
-                        node_name_to_values,
-                        constant_farm,
-                        registry=registry,
-                        opset=opset,
-                    )
-                else:
-                    # No lowering
-                    _handle_call_function_node(model.graph, node, node_name_to_values)
-        except Exception as e:
-            raise errors.OnnxConversionError(
-                f"Error when translating node {node.format_node()}. See the stack trace for more information."
-            ) from e
-    return node_name_to_values
-
-
-def _torch_version_integer() -> int:
-    return int(torch.__version__.replace(".", "").split("dev")[0])
-
-
-def _get_inputs_and_attributes(
-    node: torch.fx.Node,
-) -> tuple[list[torch.fx.Node | None], dict[str, Any], list[str], list[str]]:
-    """Find the inputs and attributes, filling in kwargs that were not provided with default values.
- - Returns: - (inputs, attributes, input_names, output_names) - """ - if inspect.isbuiltin(node.target) or isinstance(node.target, str): - inputs = list(node.args) - return inputs, {}, [], [node.name] # type: ignore[return-value] - - # The target should be an ATen operator now - assert hasattr( - node.target, "_schema" - ), f"The target should be an ATen operator now, but node target {node.target} has no schema" - node_schema: torch.FunctionSchema = node.target._schema - - # This function assumes the order of arguments in FX op is the - # same as the order of arguments in TorchScript op. - inputs: list[Any] = [] # type: ignore[no-redef] - input_names: list[str] = [] - attributes: dict[str, Any] = {} - - if inspect.isbuiltin(node.target): - inputs = list(node.args) - else: - for arg, schema_arg in zip(node.args, node_schema.arguments): - if arg is None or isinstance(arg, torch.fx.Node): - inputs.append(arg) - input_names.append(schema_arg.name) - elif isinstance(arg, Sequence) and all( - elem is None or isinstance(elem, torch.fx.Node) for elem in arg - ): - inputs.extend(arg) - input_names.extend([schema_arg.name] * len(arg)) - elif isinstance(arg, torch.device): - attributes[schema_arg.name] = str(arg) - elif isinstance(arg, torch.dtype): - attributes[schema_arg.name] = _torch_dtype_to_onnx_dtype(arg) - else: - attributes[schema_arg.name] = arg - for schema_arg in node_schema.arguments: - if schema_arg.name not in node.kwargs: - continue - kwarg = node.kwargs[schema_arg.name] - if schema_arg.name in { - "layout", - "device", - "requires_grad", - "memory_format", - "implicit", - } or isinstance(kwarg, torch.device): - attr = str(kwarg) - elif isinstance(kwarg, torch.dtype): - attr = _torch_dtype_to_onnx_dtype(kwarg) # type: ignore[assignment] - else: - attr = kwarg # type: ignore[assignment] - - attributes[schema_arg.name] = attr - - output_names = [f"{node.name}_{output.name}" for output in node_schema.returns] - - return inputs, attributes, input_names, output_names # type: ignore[return-value] - - -def _maybe_start_profiler(should_profile: bool) -> Any: - if should_profile: - import pyinstrument # type: ignore[import-not-found] - - profiler = pyinstrument.Profiler(async_mode="disabled") - profiler.start() - return profiler - return None - - -def _maybe_stop_profiler_and_get_result(profiler) -> str | None: - if profiler is None: - return None - profiler.stop() - return profiler.output_text(unicode=True) - - -def _format_exception(e: Exception) -> str: - """Format the full traceback as Python would show it.""" - return "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) - - -def _summarize_exception_stack(e: BaseException) -> str: - """Format the exception stack by showing the text of each exception.""" - causes = [e] - while e.__cause__ is not None: - causes.append(e.__cause__) - e = e.__cause__ - return ( - "\n\n## Exception summary\n\n" - + "⬆️\n".join([f"{type(e)}: {e}\n" for e in reversed(causes)]) - + "\n(Refer to the full stack trace above for more information.)" - ) - - -def _format_exceptions_for_all_strategies( - results: list[_capture_strategies.Result], -) -> str: - """Format all the exceptions from the capture strategies.""" - return "\n".join( - [ - f"# ⚠️ Errors from strategy '{result.strategy}': -----------------------\n\n" - f"{_format_exception(result.exception)}\n" - for result in results - if result.exception is not None - ] - ) - - -def exported_program_to_ir( - exported_program: torch.export.ExportedProgram, - *, - registry: _registration.ONNXRegistry 
-    | None = None,
-    lower: Literal["at_conversion", "post_conversion", "none"] = "at_conversion",
-) -> ir.Model:
-    """Convert an exported program to an ONNX IR model.
-
-    Reference:
-    - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html
-
-    Args:
-        exported_program: The exported program to convert.
-        lower: Whether to lower the graph to core ONNX operators.
-            at_conversion: Lower when translating the FX graph to ONNX IR.
-            post_conversion: Use an IR pass to lower the graph.
-            none: Do not lower the graph.
-        registry: The registry of all ONNX Script decompositions.
-    """
-    if registry is None:
-        # Trigger op registration
-        from onnxscript.function_libs.torch_lib import ops  # noqa: F401
-
-        del ops
-        registry = _registration.ONNXRegistry.from_torchlib(
-            onnxscript.function_libs.torch_lib.registration.default_registry  # type: ignore[arg-type]
-        )
-    if lower != "none":
-        exported_program = _prepare_exported_program_for_export(
-            exported_program, registry=registry
-        )
-    return _exported_program_to_onnx_program(
-        exported_program, registry=registry, lower=lower
-    ).model
-
-
-def _prepare_exported_program_for_export(
-    exported_program: torch.export.ExportedProgram,
-    *,
-    registry: _registration.ONNXRegistry,
-) -> torch.export.ExportedProgram:
-    """Decompose and apply pre-export transformations to the exported program."""
-    # Decompose the graph given the implemented torch ops in ONNX
-    exported_program = _fx_passes.decompose_with_registry(exported_program, registry)
-
-    graph_module = exported_program.graph_module
-    # Include explicit type promotion nodes
-    graph_module = _fx_passes.insert_type_promotion_nodes(graph_module)
-    graph_module = _fx_passes.remove_assertion_nodes(graph_module)
-    # TODO(justinchuby): Reassigning the graph module to save some runtime.
-    # If this does not work, we need to retrace the module with torch.export
-    exported_program._graph_module = graph_module
-    return exported_program
-
-
-def _exported_program_to_onnx_program(
-    exported_program: torch.export.ExportedProgram,
-    *,
-    registry: _registration.ONNXRegistry,
-    lower: Literal["at_conversion", "post_conversion", "none"] = "at_conversion",
-) -> _onnx_program.ONNXProgram:
-    """Convert an exported program to an ONNX Program.
-
-    The exported_program field in the returned ONNXProgram is one that is after
-    decompositions have been applied.
-
-    Reference:
-    - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html
-
-    Args:
-        exported_program: The exported program to convert. The exported program
-            should be the one that is after decompositions have been applied.
-        lower: Whether to lower the graph to core ONNX operators.
-            at_conversion: Lower when translating the FX graph to ONNX IR.
-            post_conversion: Use an IR pass to lower the graph.
-            none: Do not lower the graph.
-        registry: The registry of all ONNX Script decompositions.
- """ - model = ir.Model( - graph=ir.Graph( - [], - [], - nodes=[], - opset_imports={ - "": registry.opset_version, - }, - name="main_graph", - metadata_props={ - "pkg.torch.export.ExportedProgram.graph_signature": str( - exported_program.graph_signature - ), - "pkg.torch.export.ExportedProgram.range_constraints": str( - exported_program.range_constraints - ), - }, - ), - ir_version=9, - producer_name="torch", - producer_version=torch.__version__, - ) - - if lower == "none": - # Add the opset import for the torch ops - model.opset_imports["pkg.torch.ops"] = _torch_version_integer() - # NOTE: Function domains are added when translating nodes when lower="at_conversion" - - # 1. Add all nodes to the graph and create a dictionary of values - values = _add_nodes(exported_program, model, lower=lower, registry=registry) - - # 2. Add user inputs and all parameters/buffers to the graph. - # Since the node names and the tensor names are different, we need to rename - # the nodes to match the tensor names later. For now we will just use the node names. - user_inputs = [ - spec - for spec in exported_program.graph_signature.input_specs - if spec.kind == graph_signature.InputKind.USER_INPUT - ] - non_user_inputs = [ - spec - for spec in exported_program.graph_signature.input_specs - if spec.kind != graph_signature.InputKind.USER_INPUT - ] - - for spec in itertools.chain(user_inputs, non_user_inputs): - # Put the user inputs first and then the parameters/buffers - if isinstance(spec.arg, graph_signature.ConstantArgument): - logger.debug("Skipping constant argument %s", spec.arg) - continue - value_name = spec.arg.name - input_kind = spec.kind - persistent = spec.persistent - value = values[value_name] - - assert not isinstance( - value, Sequence - ), f"Input '{value_name}' should not be a sequence. This is unexpected." - - value.metadata_props[ - "pkg.torch.export.graph_signature.InputSpec.kind" - ] = input_kind.name - value.metadata_props[ - "pkg.torch.export.graph_signature.InputSpec.persistent" - ] = str(persistent) - - if input_kind == graph_signature.InputKind.USER_INPUT: - # Add only user inputs to the graph - # Subsequent passes can decide if they want to add initializers as inputs - model.graph.inputs.append(value) - else: - model.graph.initializers[value_name] = value - - # 3. Add user outputs to the graph and assign metadata to all outputs - user_outputs = [ - spec - for spec in exported_program.graph_signature.output_specs - if spec.kind == graph_signature.OutputKind.USER_OUTPUT - ] - non_user_outputs = [ - spec - for spec in exported_program.graph_signature.output_specs - if spec.kind != graph_signature.OutputKind.USER_OUTPUT - ] - for spec in itertools.chain(user_outputs, non_user_outputs): - if isinstance(spec.arg, graph_signature.ConstantArgument): - logger.warning("Skipping constant argument %s", spec.arg) - continue - value_name = spec.arg.name - output_kind = spec.kind - value = values[value_name] - - if not isinstance(value, (ir.Value, Sequence)): - raise TypeError( - f"Output '{value_name}' should be an ir.Value. Actual type is '{type(value)}': {value!r}. " - "This may be due to an incorrect implementation of the ONNX function that produced this output." - ) - - # The output value may be a sequence, meaning the operator has multiple outputs - _values = (value,) if not isinstance(value, Sequence) else value - - if len(_values) > 1: - logger.warning( - "Model output '%s' has multiple values: %s (output spec: %s). 
Please make sure this is expected.", - value_name, - _values, - spec, - ) - - for value in _values: - value.metadata_props[ - "pkg.torch.export.graph_signature.OutputSpec.kind" - ] = output_kind.name - if output_kind == graph_signature.OutputKind.USER_OUTPUT: - model.graph.outputs.append(value) - - # 4. Rename the initializers to match the tensor names - for name, param_name in itertools.chain( - exported_program.graph_signature.inputs_to_parameters.items(), - exported_program.graph_signature.inputs_to_buffers.items(), - exported_program.graph_signature.inputs_to_lifted_tensor_constants.items(), - ): - initializer = model.graph.initializers.pop(name) - initializer.name = param_name - # Record the original name so users can search the metadata and correspond - # with the FX graph - initializer.metadata_props["pkg.torch.onnx.original_node_name"] = name - model.graph.initializers[param_name] = initializer - - # 5. Add initializers to the graph - # ExportedProgram stores parameters and buffers in state_dict, - # but non_persistent_buffers and lifted_tensor_constants are not there - # so we need to get them from the name_* apis. - for name, torch_tensor in itertools.chain( - exported_program.named_parameters(), - exported_program.named_buffers(), - exported_program.constants.items(), - ): - initializer = model.graph.initializers.get(name) # type: ignore[assignment] - if initializer is None: - logger.warning("Tensor '%s' is not one of the initializers", name) - continue - if not isinstance(torch_tensor, torch.Tensor): - raise NotImplementedError( - f"Tensor '{name}' should be a torch.Tensor. Actual type is '{type(torch_tensor)}': {torch_tensor!r}. " - "This is unexpected and not yet supported." - ) - ir_tensor = TorchTensor(torch_tensor, name=name) - initializer.const_value = ir_tensor - _set_shape_type( - initializer, - torch_tensor, - complex_to_float=lower != "none", - ) - - # TODO: Decide if we should keep mutated buffers as inputs/outputs - - return _onnx_program.ONNXProgram(model, exported_program) - - -def _verbose_printer(verbose: bool | None) -> Callable[..., None]: - """Prints messages based on `verbose`.""" - if verbose is False: - return lambda *_, **__: None - return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) - - -def export( - model: torch.nn.Module - | torch.export.ExportedProgram - | torch.fx.GraphModule - | torch.jit.ScriptModule - | torch.jit.ScriptFunction, - args: tuple[Any, ...], - kwargs: dict[str, Any] | None = None, - *, - registry: _registration.ONNXRegistry | None = None, - dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, - input_names: Sequence[str] | None = None, - output_names: Sequence[str] | None = None, - report: bool = False, - verify: bool = False, - profile: bool = False, - dump_exported_program: bool = False, - artifacts_dir: str | os.PathLike = ".", - verbose: bool | None = None, -) -> _onnx_program.ONNXProgram: - """Export a PyTorch model to ONNXProgram. - - Args: - model: The model to export. This can be a PyTorch nn.Module or an ExportedProgram. - args: The arguments to pass to the model. - kwargs: The keyword arguments to pass to the model. - registry: The registry of all ONNX decompositions. - dynamic_shapes: Dynamic shapes in the graph. - input_names: If provided, rename the inputs. - output_names: If provided, rename the outputs. - report: Whether to generate an error report if the export fails. - verify: Whether to verify the ONNX model after exporting. - profile: Whether to profile the export process. 
When report is True,
-            the profile result will be saved in the report. Otherwise, the profile
-            result will be printed.
-        dump_exported_program: Whether to save the exported program to a file.
-        artifacts_dir: The directory to save the exported program and error reports.
-        verbose: Whether to print verbose messages. If None (default), some messages will be printed.
-
-    Returns:
-        The ONNXProgram with the exported IR graph.
-
-    Raises:
-        TorchExportError: If the export process fails with torch.export.
-        OnnxConversionError: If the ExportedProgram to ONNX translation fails.
-    """
-    # Set up the error reporting facilities
-    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
-    profiler = _maybe_start_profiler(profile)
-
-    # Create the artifacts directory if it does not exist
-    artifacts_dir = pathlib.Path(artifacts_dir)
-    if report or profile or dump_exported_program:
-        artifacts_dir.mkdir(parents=True, exist_ok=True)
-
-    verbose_print = _verbose_printer(verbose)
-    export_status = _reporting.ExportStatus()
-    failed_results: list[_capture_strategies.Result] = []
-
-    program: torch.export.ExportedProgram | None = None
-    # Step 1: Export the model with torch.export.export if the model is not already an ExportedProgram
-    if isinstance(model, torch.export.ExportedProgram):
-        program = model
-        export_status.torch_export = True
-    else:
-        # Convert an nn.Module to an ExportedProgram
-        # Try everything 🐰 (all paths for getting an ExportedProgram)
-        # When the input is a JIT module, the last strategy will succeed, so that case is handled
-        result: _capture_strategies.Result | None = None
-        for strategy_class in _capture_strategies.CAPTURE_STRATEGIES:
-            strategy = strategy_class(  # type: ignore[abstract]
-                verbose=verbose is not False,  # Treat None as verbose
-                dump=dump_exported_program,
-                artifacts_dir=artifacts_dir,
-                timestamp=timestamp,
-            )
-            result = strategy(model, args, kwargs, dynamic_shapes=dynamic_shapes)
-
-            # Record the status
-            if strategy_class is _capture_strategies.TorchExportStrategy:
-                export_status.torch_export = result.success
-            elif strategy_class is _capture_strategies.TorchExportNonStrictStrategy:
-                export_status.torch_export_non_strict = result.success
-            elif strategy_class is _capture_strategies.JitTraceConvertStrategy:
-                export_status.torch_jit = result.success
-
-            if result.exported_program is not None:
-                program = result.exported_program
-                break
-            else:
-                failed_results.append(result)
-
-        assert result is not None
-        if result.exported_program is None:
-            # If all strategies fail, produce an error report and raise the first error
-            profile_result = _maybe_stop_profiler_and_get_result(profiler)
-
-            if report:
-                report_path = artifacts_dir / _reporting.construct_report_file_name(
-                    timestamp, export_status
-                )
-
-                try:
-                    _reporting.create_torch_export_error_report(
-                        report_path,
-                        _format_exceptions_for_all_strategies(failed_results),
-                        export_status=export_status,
-                        profile_result=profile_result,
-                    )
-                except Exception as e_report:
-                    verbose_print(
-                        f"Failed to save error report due to an error: {e_report}"
-                    )
-            else:
-                report_path = None
-
-            first_error = failed_results[0].exception
-            assert first_error is not None
-
-            # NOTE: We only throw the torch.export (first) exception because we want to
-            # focus on the torch.export.export error. Errors from other strategies like
-            # torch.jit.trace are due to the fallback and can be confusing to users.
-            # We save all errors in the error report.
- raise errors.TorchExportError( - _STEP_ONE_ERROR_MESSAGE - + ( - f"\nError report has been saved to '{report_path}'." - if report - else "" - ) - + _summarize_exception_stack(first_error) - ) from first_error - - assert program is not None - - if dump_exported_program: - verbose_print("Dumping ExportedProgram because `dump_exported_program=True`...") - program_path = artifacts_dir / f"onnx_export_{timestamp}.pt2" - try: - torch.export.save(program, program_path) - except Exception as e: - verbose_print(f"Failed to save ExportedProgram due to an error: {e}") - else: - verbose_print(f"ExportedProgram has been saved to '{program_path}'.") - - # Step 2: Convert the exported program to an ONNX model - verbose_print("Translate the graph into ONNX...") - - # Step 2a: Decompose the exported program and insert type promotion nodes - try: - # Build the ONNX function registry - if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) - - # Process the exported program to run decompositions and type promotions etc. - decomposed_program = _prepare_exported_program_for_export( - program, registry=registry - ) - except Exception as e: - export_status.onnx_translation = False - verbose_print("Translate the graph into ONNX... ❌") - profile_result = _maybe_stop_profiler_and_get_result(profiler) - - if report: - report_path = artifacts_dir / _reporting.construct_report_file_name( - timestamp, export_status - ) - - # Run the analysis to get the error report - try: - _reporting.create_onnx_export_report( - report_path, - f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", - program, - export_status=export_status, - profile_result=profile_result, - registry=registry, - ) - except Exception: - logger.exception("Failed to save report due to an error.") - else: - report_path = None - - raise errors.OnnxConversionError( - _STEP_TWO_ERROR_MESSAGE - + (f"\nError report has been saved to '{report_path}'." if report else "") - + _summarize_exception_stack(e) - ) from e - - # Step 2b: Translate the decomposed program to ONNX and produce ONNXProgram - if report or profile: - pre_decomp_unique_ops, post_decomp_unique_ops = _analysis.compare_ops( - program, decomposed_program - ) - else: - pre_decomp_unique_ops = None - post_decomp_unique_ops = None - - try: - # Convert the exported program to an ONNX model - onnx_program = _exported_program_to_onnx_program( - decomposed_program, registry=registry - ) - - # Run the ONNX passes - if input_names: - _ir_passes.rename_inputs(onnx_program.model, input_names) - if output_names: - _ir_passes.rename_outputs(onnx_program.model, output_names) - - # TODO(justinchuby): Remove the hack - _ir_passes.add_torchlib_common_imports(onnx_program.model) - - export_status.onnx_translation = True - verbose_print("Translate the graph into ONNX... ✅") - except Exception as e: - export_status.onnx_translation = False - verbose_print("Translate the graph into ONNX... 
❌") - profile_result = _maybe_stop_profiler_and_get_result(profiler) - - if report: - report_path = artifacts_dir / _reporting.construct_report_file_name( - timestamp, export_status - ) - - try: - assert pre_decomp_unique_ops is not None - assert post_decomp_unique_ops is not None - - # Run the analysis to get the error report - _reporting.create_onnx_export_report( - report_path, - f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", - program, - decomp_comparison=_reporting.format_decomp_comparison( - pre_decomp_unique_ops, post_decomp_unique_ops - ), - export_status=export_status, - profile_result=profile_result, - registry=registry, - ) - verbose_print(f"Export report has been saved to '{report_path}'.") - except Exception: - logger.exception("Failed to save report due to an error.") - else: - report_path = None - - raise errors.OnnxConversionError( - _STEP_TWO_ERROR_MESSAGE - + (f"\nError report has been saved to '{report_path}'." if report else "") - + _summarize_exception_stack(e) - ) from e - - profile_result = _maybe_stop_profiler_and_get_result(profiler) - - if not verify: - # Return if verification is not requested - if report: - try: - assert pre_decomp_unique_ops is not None - assert post_decomp_unique_ops is not None - report_path = artifacts_dir / _reporting.construct_report_file_name( - timestamp, export_status - ) - _reporting.create_onnx_export_report( - report_path, - "No errors" - if not failed_results - else _format_exceptions_for_all_strategies(failed_results), - onnx_program.exported_program, - profile_result=profile_result, - export_status=export_status, - decomp_comparison=_reporting.format_decomp_comparison( - pre_decomp_unique_ops, post_decomp_unique_ops - ), - registry=registry, - ) - verbose_print(f"Export report has been saved to '{report_path}'.") - except Exception: - logger.exception("Failed to save report due to an error.") - elif profile and profile_result is not None: - verbose_print("Profile result:") - verbose_print(profile_result) - return onnx_program - - # Step 3: (verify=True) Check the ONNX model with ONNX checker - try: - verbose_print("Run `onnx.checker` on the ONNX model...") - - # TODO: Handle when model is >2GB - - model_proto = onnx_program.model_proto - byte_size = model_proto.ByteSize() - if byte_size < 2 * 1024 * 1024 * 1024: - # The checker may segfault so we need to run it in a separate process - _isolated.safe_call( - onnx.checker.check_model, onnx_program.model_proto, full_check=True # type: ignore[attr-defined] - ) - export_status.onnx_checker = True - verbose_print("Run `onnx.checker` on the ONNX model... ✅") - else: - verbose_print( - f"Run `onnx.checker` on the ONNX model... ⚠️ Skipped because model is too large ({byte_size})." - ) - except Exception as e: - export_status.onnx_checker = False - verbose_print("Run `onnx.checker` on the ONNX model... 
❌") - if report: - try: - assert pre_decomp_unique_ops is not None - assert post_decomp_unique_ops is not None - report_path = artifacts_dir / _reporting.construct_report_file_name( - timestamp, export_status - ) - _reporting.create_onnx_export_report( - report_path, - f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", - onnx_program.exported_program, - decomp_comparison=_reporting.format_decomp_comparison( - pre_decomp_unique_ops, post_decomp_unique_ops - ), - export_status=export_status, - profile_result=profile_result, - model=onnx_program.model, - registry=registry, - ) - verbose_print(f"Export report has been saved to '{report_path}'.") - except Exception: - logger.exception("Failed to save report due to an error.") - logger.warning( - "Conversion successful but the ONNX model fails ONNX checker. " # noqa: G004 - "Please create an issue " - f"in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component and " - "attach the full error stack as well as reproduction scripts. ", - exc_info=e, - ) - return onnx_program - - # Step 4: (verify=True) Execute the model with ONNX Runtime - try: - verbose_print("Execute the model with ONNX Runtime...") - verification_results = _verification.verify_onnx_program(onnx_program) - verbose_print("Execute the model with ONNX Runtime... ✅") - export_status.onnx_runtime = True - onnx_runtime_error_message = None - except Exception as e: - verbose_print("Execute the model with ONNX Runtime... ❌") - export_status.onnx_runtime = False - onnx_runtime_error_message = _format_exception(e) - verification_message = None - - else: - # Step 5: (verify=True) Validate the output values - verbose_print("Verify output accuracy...") - export_status.output_accuracy = True - for verification_result in verification_results: - # TODO(justinchuby): The threshold is arbitrary right now - if verification_result.absolute_difference >= 5e-3: - logger.warning( - "Output '%s' has a large absolute difference of %f. ", - verification_result.name, - verification_result.absolute_difference, - ) - export_status.output_accuracy = False - if verification_result.relative_difference >= 1e-1: - logger.warning( - "Output '%s' has a large relative difference of %f. ", - verification_result.name, - verification_result.relative_difference, - ) - export_status.output_accuracy = False - if export_status.output_accuracy: - verbose_print("Verify output accuracy... ✅") - else: - verbose_print("Verify output accuracy... 
❌") - verification_message = _reporting.format_verification_infos( - verification_results - ) - - if report: - try: - assert pre_decomp_unique_ops is not None - assert post_decomp_unique_ops is not None - - traceback_lines = [] - if failed_results: - traceback_lines.append( - _format_exceptions_for_all_strategies(failed_results) - ) - if onnx_runtime_error_message: - traceback_lines.append( - "# ⚠️ ONNX Runtime error -----------------------" - ) - traceback_lines.append(onnx_runtime_error_message) - if not traceback_lines: - traceback_lines.append("No errors") - - report_path = artifacts_dir / _reporting.construct_report_file_name( - timestamp, export_status - ) - _reporting.create_onnx_export_report( - report_path, - "\n\n".join(traceback_lines), - onnx_program.exported_program, - profile_result=profile_result, - export_status=export_status, - decomp_comparison=_reporting.format_decomp_comparison( - pre_decomp_unique_ops, post_decomp_unique_ops - ), - model=onnx_program.model, - registry=registry, - verification_result=verification_message, - ) - verbose_print(f"Export report has been saved to '{report_path}'.") - except Exception: - logger.exception("Failed to save report due to an error.") - - # Release the inference session created during verification - onnx_program.release() - return onnx_program diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py deleted file mode 100644 index 3797a6d1fbd8e..0000000000000 --- a/torch/onnx/_internal/exporter/_decomp.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Build decomp table from PyTorch.""" - -# mypy: allow-untyped-defs -from __future__ import annotations - -from typing import Callable, TYPE_CHECKING - -import torch -import torch._ops - - -if TYPE_CHECKING: - from torch.onnx._internal.exporter import _registration - - -def get_onnx_implemented_overloads( - registry: _registration.ONNXRegistry, -) -> list[torch._ops.OperatorBase]: - """ - Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. - - Args: - registry: The ONNX registry for PyTorch. - - Returns: - A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. - """ - registered_ops: list[torch._ops.OperatorBase] = [] - for op_namespace in (torch.ops.aten, torch.ops.prims): - op_names = dir(op_namespace) - for op_name in op_names: - op_overload_packet = getattr(op_namespace, op_name) - if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): - continue - - for overload_name in op_overload_packet.overloads(): - op_overload = getattr(op_overload_packet, overload_name) - if registry.is_registered(op_overload): - registered_ops.append(op_overload) - return registered_ops - - -def create_onnx_friendly_decomposition_table( - registry, -) -> dict[torch._ops.OperatorBase, Callable]: - """ - This function creates a dictionary of op overloads and their decomposition functions - for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, - its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's - built-in aten-to-aten decomposition. - - Args: - registry: The ONNX registry for PyTorch. - - Returns: - Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding - decomposition functions. 
- """ - decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} - onnx_registered_ops = set(get_onnx_implemented_overloads(registry)) - - # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single - # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your - # definitions in a single TORCH_LIBRARY block. - for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined] - # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX - # symbolic function. - # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is - # not exportable anyways. - if op_overload in onnx_registered_ops: - continue - decomposition_table[op_overload] = decomp_fn - - return decomposition_table diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py deleted file mode 100644 index b8aecfaa93793..0000000000000 --- a/torch/onnx/_internal/exporter/_dispatching.py +++ /dev/null @@ -1,345 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import logging -from typing import Sequence - -import onnxscript -from onnxscript import ir - -import torch -import torch.fx -from torch.onnx._internal.exporter import _registration, _schemas - - -logger = logging.getLogger(__name__) - -# Define utilities to convert PyTorch data types so users do not need to specify manually -_TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = { - torch.bfloat16: ir.DataType.BFLOAT16, - torch.bool: ir.DataType.BOOL, - torch.complex128: ir.DataType.DOUBLE, - torch.complex64: ir.DataType.FLOAT, - torch.float16: ir.DataType.FLOAT16, - torch.float32: ir.DataType.FLOAT, - torch.float64: ir.DataType.DOUBLE, - torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, - torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, - torch.float8_e5m2: ir.DataType.FLOAT8E5M2, - torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, - torch.int16: ir.DataType.INT16, - torch.int32: ir.DataType.INT32, - torch.int64: ir.DataType.INT64, - torch.int8: ir.DataType.INT8, - torch.uint8: ir.DataType.UINT8, -} - - -def _torch_dtype_to_onnx_compatible_dtype(dtype: torch.dtype) -> ir.DataType: - return _TORCH_DTYPE_TO_ONNX_COMPATIBLE[dtype] - - -def _attribute_type_compatible_with_arg( - attr: _schemas.AttributeParameter, - value: ir.Value | int | float | bool | Sequence[int] | Sequence[float] | None, -) -> bool: - """Check if the attribute type is compatible with the argument.""" - if isinstance(value, bool): - return attr.type is ir.AttributeType.INT - if isinstance(value, str): - return attr.type is ir.AttributeType.STRING - if isinstance(value, int): - return attr.type in {ir.AttributeType.INT, ir.AttributeType.FLOAT} - if isinstance(value, float): - return attr.type is ir.AttributeType.FLOAT - if isinstance(value, complex): - return False - if isinstance(value, Sequence): - if attr.type is ir.AttributeType.INTS: - return all(isinstance(i, int) for i in value) - if attr.type is ir.AttributeType.FLOATS: - return all(isinstance(i, (int, float)) for i in value) - if isinstance(value, torch.dtype): - return attr.type is ir.AttributeType.INT - if isinstance(value, (torch.device, torch.memory_format, torch.layout)): - return attr.type is ir.AttributeType.STRING - if value is None and not attr.required: - # An optional attribute is not supplied - return True - return False - - -def _param_type_compatible_with_arg( - param: _schemas.Parameter, - value: ir.TypeProtocol - | 
str - | int - | float - | complex - | Sequence[int] - | Sequence[float] - | None, - assigned_types: dict[str, ir.TypeProtocol], -) -> bool: - # Handle Python types first - if isinstance(value, bool): # noqa: SIM102 - if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}: - return True - if isinstance(value, int) and param.type_constraint.allowed_types & { - ir.TensorType(ir.DataType.INT4), - ir.TensorType(ir.DataType.INT8), - ir.TensorType(ir.DataType.INT16), - ir.TensorType(ir.DataType.INT32), - ir.TensorType(ir.DataType.INT64), - # Int inputs can be casted to a float too - ir.TensorType(ir.DataType.FLOAT8E4M3FN), - ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), - ir.TensorType(ir.DataType.FLOAT8E5M2), - ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), - ir.TensorType(ir.DataType.FLOAT16), - ir.TensorType(ir.DataType.FLOAT), - ir.TensorType(ir.DataType.DOUBLE), - }: - return True - if isinstance(value, float) and param.type_constraint.allowed_types & { - ir.TensorType(ir.DataType.FLOAT8E4M3FN), - ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), - ir.TensorType(ir.DataType.FLOAT8E5M2), - ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), - ir.TensorType(ir.DataType.FLOAT16), - ir.TensorType(ir.DataType.FLOAT), - ir.TensorType(ir.DataType.DOUBLE), - }: - return True - if isinstance(value, complex) and param.type_constraint.allowed_types & { - ir.TensorType(ir.DataType.FLOAT), - ir.TensorType(ir.DataType.DOUBLE), - ir.TensorType(ir.DataType.COMPLEX64), - ir.TensorType(ir.DataType.COMPLEX128), - }: - return True - if isinstance(value, str): # noqa: SIM102 - if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}: - return True - if isinstance(value, (list, tuple)): - if param.type_constraint.allowed_types & { - ir.TensorType(ir.DataType.INT32), - ir.TensorType(ir.DataType.INT64), - ir.TensorType(ir.DataType.FLOAT), - ir.TensorType(ir.DataType.DOUBLE), - ir.SequenceType(ir.TensorType(ir.DataType.INT32)), - ir.SequenceType(ir.TensorType(ir.DataType.INT64)), - ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), - ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), - } and all(isinstance(i, (int)) for i in value): - # We will just allow any fx node and trust that the overload handles it - return True - if param.type_constraint.allowed_types & { - ir.TensorType(ir.DataType.FLOAT), - ir.TensorType(ir.DataType.DOUBLE), - ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), - ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), - } and all(isinstance(i, (int, float)) for i in value): - # We will just allow any fx node and trust that the overload handles it - return True - if value is None and not param.required: - # An optional parameter is not supplied - return True - - if not isinstance(value, ir.TypeProtocol): - return False - - # Then check tensor types - if param.type_constraint.name in assigned_types: - # If a typevar is already bound, check if the value has the same type - assigned_type = assigned_types[param.type_constraint.name] - return assigned_type == value - # If the typevar is not bound, bind it to the value type - if value in param.type_constraint.allowed_types: - # TODO: Maybe just check dtype? 
Being more strict here for now
-        assigned_types[param.type_constraint.name] = value
-        return True
-    return False
-
-
-def _get_type_from_tensor(
-    tensor: torch.Tensor | Sequence[torch.Tensor],
-) -> ir.TypeProtocol:
-    if isinstance(tensor, torch.Tensor):
-        return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype))
-    first_tensor = next((item for item in tensor if item is not None), None)
-    if first_tensor is None:
-        return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED))
-    return ir.SequenceType(
-        ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(first_tensor.dtype))
-    )
-
-
-def _get_first_tensor_in_node_list(
-    nodes: Sequence[torch.fx.Node | None],
-) -> torch.Tensor | None:
-    for node in nodes:
-        if (
-            node is not None
-            and "val" in node.meta
-            and isinstance(node.meta["val"], torch.Tensor)
-        ):
-            return node.meta["val"]
-    return None
-
-
-def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]:
-    # FIXME: node.target may not have a schema
-    torch_schema: torch.FunctionSchema = node.target._schema  # type: ignore[union-attr]
-    node_args = {}
-    for arg, schema_arg in zip(node.args, torch_schema.arguments):
-        node_args[schema_arg.name] = arg
-
-    node_args.update(node.kwargs)
-    return node_args
-
-
-def get_matching_overload(
-    node: torch.fx.Node,
-    overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction],
-) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]:
-    """Get the overload that matches the node's arguments.
-
-    Args:
-        node: The node to match.
-        overloads: The overloads to match against.
-
-    Returns:
-        A tuple containing the matched overload and a string describing the reason for failure or success.
-    """
-    named_args = _get_named_fx_node_args(node)
-    # FIXME: node.target may be a builtin and not have a schema
-    # FIXME: Handle when we don't know the names of the arguments
-    schema_args: dict[str, torch.Argument] = {
-        arg.name: arg
-        for arg in node.target._schema.arguments  # type: ignore[union-attr]
-    }
-    failure_messages: list[str] = []
-    for overload in overloads:
-        assigned_types: dict[str, ir.TypeProtocol] = {}
-        fail_reason = ""
-        if not hasattr(overload, "signature"):
-            # When an overload does not have a signature, we assume it is a custom op and should be matched
-            return (
-                overload,
-                "The overload does not have a signature. Assuming it is a custom op and matching it.",
-            )
-        for param in overload.signature:
-            if param.name not in schema_args and param.required:
-                # We don't need to handle variadic inputs as there are none.
-                # A required parameter is not supplied.
- fail_reason = "Required parameter not supplied" - break - - # Get the argument - if param.name in named_args: - # Provided in Node args - arg = named_args[param.name] - elif ( - param.name in schema_args - and schema_args[param.name].has_default_value() - ): - # Provided in schema args - arg = schema_args[param.name].default_value - elif param.has_default(): - # Provided in the ONNX op definition - arg = param.default - else: - fail_reason = "Parameter not provided" - break - - if isinstance(param, _schemas.Parameter): - if isinstance(arg, torch.Tensor): - arg = _get_type_from_tensor(arg) # type: ignore[assignment] - if isinstance(arg, (list, tuple)) and any( - isinstance(t, torch.fx.Node) for t in arg - ): - first_tensor = _get_first_tensor_in_node_list(arg) - assert first_tensor is not None - # FIXME: Handle symfloat here - arg = ir.SequenceType(_get_type_from_tensor(first_tensor)) # type: ignore[assignment] - elif isinstance(arg, torch.fx.Node): - meta_val = arg.meta["val"] - arg = _get_type_from_tensor(meta_val) # type: ignore[assignment] - # TODO: Handle None attributes - # FIXME: Handle symfloat etc. - # Handle tensors and Python values - if not _param_type_compatible_with_arg(param, arg, assigned_types): # type: ignore[arg-type] - fail_reason = ( - f"Parameter type not compatible with argument: param=`{param}`, " - f"assigned_types=`{assigned_types}`, arg=`{arg}`" - ) - break - elif isinstance(param, _schemas.AttributeParameter): - if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type] - fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`" - break - if not fail_reason: - return overload, "Successfully matched overload" - else: - failure_messages.append( - f"- Failed to match overload `{overload}`: {fail_reason}" - ) - return ( - None, - f"All overloads did not match the node `{node.format_node()}`.\n" - + "\n".join(failure_messages), - ) - - -def _arg_has_complex_dtype(arg) -> bool: - """Check if the node has complex dtype recursively.""" - if ( - isinstance(arg, torch.fx.Node) - and "val" in arg.meta - and isinstance(arg.meta["val"], torch.Tensor) - and torch.is_complex(arg.meta["val"]) - ): - return True - elif isinstance(arg, list): - return any(_arg_has_complex_dtype(item) for item in arg) - return False - - -def dispatch( - node: torch.fx.Node, registry: _registration.ONNXRegistry -) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: - """Dispatch a node to an ONNX function based on the node's target and the ONNX registry. - - Args: - node: The node to dispatch. - registry: The ONNX registry to use for dispatching. - - Returns: - A tuple containing the matched ONNX function and a string describing the reason for failure or success. - """ - # TODO: Handle when node does not have a target - decomp_metas = registry.get_decomps(node.target) # type: ignore[arg-type] - # Determine if the node has complex inputs. 
- is_complex = any(_arg_has_complex_dtype(arg) for arg in node.args) or any( - _arg_has_complex_dtype(arg) for arg in node.kwargs.values() - ) - if is_complex: - decomp_metas = [decomp for decomp in decomp_metas if decomp.is_complex] - if not decomp_metas: - return None, "No decompositions registered for the complex-valued input" - else: - decomp_metas = [decomp for decomp in decomp_metas if not decomp.is_complex] - if not decomp_metas: - return None, "No decompositions registered for the real-valued input" - - if len(decomp_metas) == 1: - return ( - decomp_metas[0].onnx_function, - "Fast path: Only one decomposition is defined", - ) - - overload, message = get_matching_overload( - node, [decomp.onnx_function for decomp in decomp_metas] - ) - return overload, message diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py deleted file mode 100644 index 2feae57b5d708..0000000000000 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ /dev/null @@ -1,72 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import torch -import torch.export -import torch.fx -from torch.onnx._internal.exporter import _decomp, _registration -from torch.onnx._internal.fx import diagnostics, passes - - -_ATEN_ASSERTION_TARGETS = frozenset( - { - torch.ops.aten.sym_constrain_range_for_size.default, - torch.ops.aten._assert_async.msg, - } -) - - -def decompose_with_registry( - exported_program: torch.export.ExportedProgram, registry: _registration.ONNXRegistry -) -> torch.export.ExportedProgram: - """Decompose the exported program with the given registry. - - This function is needed so it shows clearly on the profiler results. - """ - decomp_table = _decomp.create_onnx_friendly_decomposition_table(registry) - onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry)) - # Try to preserve some known CompositeImplicitAutograd ops - aten = torch.ops.aten - to_preserve = { - aten._upsample_bilinear2d_aa.default, - aten._upsample_nearest_exact1d.vec, - aten._upsample_nearest_exact2d.vec, - aten._upsample_nearest_exact3d.vec, - aten.group_norm.default, - aten.linear.default, - aten.upsample_bilinear2d.default, - aten.upsample_bilinear2d.vec, - aten.upsample_linear1d.default, - aten.upsample_linear1d.vec, - aten.upsample_nearest1d.default, - aten.upsample_nearest1d.vec, - aten.upsample_nearest2d.default, - aten.upsample_nearest2d.vec, - aten.upsample_nearest3d.default, - aten.upsample_nearest3d.vec, - aten.upsample_trilinear3d.default, - aten.upsample_trilinear3d.vec, - } - # We can only preserve implemented ops - can_preserve = tuple(to_preserve.intersection(onnx_registered_ops)) - return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve) - - -def insert_type_promotion_nodes( - graph_module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """Inplace pass to insert explicit type promotion nodes.""" - diagnostic_context = diagnostics.DiagnosticContext( - "torch.onnx.export", - torch.__version__, - ) - return passes.InsertTypePromotion(diagnostic_context, graph_module).run() - - -def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Remove all assertion and check nodes from the FX graph""" - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target in _ATEN_ASSERTION_TARGETS: - graph_module.graph.erase_node(node) - graph_module.recompile() - return graph_module diff --git a/torch/onnx/_internal/exporter/_ir_passes.py 
b/torch/onnx/_internal/exporter/_ir_passes.py
deleted file mode 100644
index 7e8748443e2b1..0000000000000
--- a/torch/onnx/_internal/exporter/_ir_passes.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import logging
-from typing import Sequence
-
-from onnxscript import ir
-
-
-logger = logging.getLogger(__name__)
-
-
-def rename_inputs(model: ir.Model, new_names: Sequence[str]) -> None:
-    # TODO: Ensure the names do not have duplicates
-    for input, new_name in zip(model.graph.inputs, new_names):
-        input.metadata_props["pkg.torch.onnx.original_node_name"] = str(input.name)
-        input.name = new_name
-
-
-def rename_outputs(model: ir.Model, new_names: Sequence[str]) -> None:
-    for output, new_name in zip(model.graph.outputs, new_names):
-        output.metadata_props["pkg.torch.onnx.original_node_name"] = str(output.name)
-        output.name = new_name
-
-
-def add_torchlib_common_imports(model: ir.Model) -> None:
-    """Hack to add torchlib common imports to the model."""
-
-    try:
-        # TODO(justinchuby): Remove this hack once onnxscript is improved
-        from onnxscript.function_libs.torch_lib.ops import common as common_ops
-
-        model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
-        rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
-        is_scalar_func = ir.serde.deserialize_function(
-            common_ops.IsScalar.to_function_proto()
-        )
-        model.functions[rank_func.identifier()] = rank_func
-        model.functions[is_scalar_func.identifier()] = is_scalar_func
-    except Exception:
-        logger.exception("Failed to add torchlib common imports to the model.")
diff --git a/torch/onnx/_internal/exporter/_isolated.py b/torch/onnx/_internal/exporter/_isolated.py
deleted file mode 100644
index 4a5c5fcdf793c..0000000000000
--- a/torch/onnx/_internal/exporter/_isolated.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Isolated calls to methods that may segfault."""
-
-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import multiprocessing
-import os
-import warnings
-from typing import Callable
-
-
-_IS_WINDOWS = os.name == "nt"
-
-
-def _call_function_and_return_exception(func, args, kwargs):
-    """Call function and return an exception if there is one."""
-
-    try:
-        return func(*args, **kwargs)
-    except Exception as e:
-        return e
-
-
-def safe_call(func: Callable, *args, **kwargs):
-    """Call a function in a separate process.
-
-    Args:
-        func: The function to call.
-        args: The positional arguments to pass to the function.
-        kwargs: The keyword arguments to pass to the function.
-
-    Returns:
-        The return value of the function.
-
-    Raises:
-        Exception: If the function raised an exception.
-    """
-    if _IS_WINDOWS:
-        # On Windows, we cannot create a new process with fork.
-        warnings.warn(
-            f"A new process is not created for {func} on Windows.", stacklevel=1
-        )
-        return func(*args, **kwargs)
-
-    with multiprocessing.get_context("fork").Pool(1) as pool:
-        # It is important to fork a process here to prevent the main logic from
-        # running again when the user does not place it under an `if __name__ == "__main__":`
-        # block.
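A hedged usage sketch for `safe_call` (the wrapped call is an illustration; `onnx.checker.check_model` is one API that can hard-crash on malformed models):

```python
import onnx

# The check runs in a forked child: a segfault kills only the child,
# while an ordinary exception is captured and re-raised in the parent.
safe_call(onnx.checker.check_model, "model.onnx", full_check=True)
```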
- result = pool.apply_async( - _call_function_and_return_exception, (func, args, kwargs) - ) - result = result.get(timeout=5) - if isinstance(result, Exception): - raise result - return result diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py deleted file mode 100644 index 51e20207877be..0000000000000 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ /dev/null @@ -1,288 +0,0 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code="attr-defined,name-defined" -from __future__ import annotations - - -__all__ = ["ONNXProgram"] - -import gc -import logging -import os -import pathlib -import tempfile -import textwrap -from typing import Callable, IO, Sequence, TYPE_CHECKING - -import torch -from torch.onnx._internal import _lazy_import -from torch.utils import _pytree as pytree - - -onnx = _lazy_import.onnx -ir = _lazy_import.onnxscript_ir - - -if TYPE_CHECKING: - import onnxruntime as ort - -logger = logging.getLogger(__name__) - - -def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: - """Initialize an ONNX Runtime inference session with the specified model.""" - import onnxruntime as ort - - session_options = ort.SessionOptions() - session_options.log_severity_level = 3 # 3: Error - possible_providers = ( - "CUDAExecutionProvider", - "CPUExecutionProvider", - ) - available_providers = set(ort.get_available_providers()) - providers = [ - provider for provider in possible_providers if provider in available_providers - ] - return ort.InferenceSession( - model, providers=providers, sess_options=session_options - ) - - -class ONNXProgram: - """A substitute class for `torch.onnx.ONNXProgram`.""" - - def __init__(self, model: ir.Model, exported_program: torch.export.ExportedProgram): - self.model: ir.Model = model - self.exported_program = exported_program - self._inference_session: ort.InferenceSession | None = None - self._tempdir: tempfile.TemporaryDirectory | None = None - - def __repr__(self) -> str: - return f"""\ -ONNXProgram( - model= -{textwrap.indent(str(self.model), ' ' * 8)} - , - exported_program= -{textwrap.indent(str(self.exported_program), ' ' * 8)} -) -""" - - def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]: - """Run the ONNX model with the same arguments you would provide to the GraphModule.""" - import onnxruntime as ort - - flatten_args = _process_args(args, kwargs) - - if self._inference_session is None: - self.initialize_inference_session() - - assert self._inference_session is not None - - # We don't expect non-tensor as inputs - ort_input = { - k.name: v.numpy(force=True) - for k, v in zip(self.model.graph.inputs, flatten_args) - } - run_options = ort.RunOptions() - run_options.log_severity_level = 3 # 3: Error - logger.debug("Running the inference session with %s arguments.", len(ort_input)) - outputs = self._inference_session.run(None, ort_input, run_options=run_options) - logger.debug("Inference session run completed.") - # TODO(justinchuby): Maybe output complex tensors as needed - return tuple(torch.from_numpy(output) for output in outputs) - - @property - def model_proto(self) -> onnx.ModelProto: - """Compatibility property for `torch.onnx.ONNXProgram.model_proto`.""" - return ir.serde.serialize_model(self.model) - - def save( - self, - destination: str | os.PathLike | IO[bytes], - *, - include_initializers: bool = True, - keep_initializers_as_inputs: bool = False, - external_data: bool | None = None, - **_, - ): - """Save the ONNX model to the specified destination. 
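A sketch of the two save paths this method supports (file names are illustrative); the 2GB cutoff below reflects protobuf's limit on serialized message size:

```python
# Fits under 2GB: one self-contained file.
onnx_program.save("model.onnx")

# external_data=True (or a >2GB proto): weights are written to
# "large_model.onnx.data" next to the model file.
onnx_program.save("large_model.onnx", external_data=True)
```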
-
-        When `external_data` is `True` or the model is larger than 2GB,
-        the weights are saved as external data in a separate file.
-
-        Args:
-            destination: The path to save the ONNX model to.
-            include_initializers: Whether to include the initializers in the saved model.
-            keep_initializers_as_inputs: Whether to keep the initializers as inputs in the saved model.
-                If `True`, the initializers are added as inputs to the model, which means they can be overwritten
-                by providing the initializers as model inputs.
-            external_data: Whether to save the weights as external data in a separate file.
-
-        Raises:
-            TypeError: If `external_data` is `True` and `destination` is not a file path.
-        """
-        if not include_initializers:
-            self.model.graph.initializers.clear()
-            logger.warning(
-                "The initializers have been removed from the model. This is destructive. "
-                "Developers: Please implement ir.Model copy() and remove initializers on the copied model."
-            )
-        if keep_initializers_as_inputs:
-            self.model.graph.inputs.extend(self.model.graph.initializers.values())  # type: ignore[arg-type]
-            logger.warning(
-                "The initializers have been added as inputs to the model. This is destructive. "
-                "Developers: Please implement ir.Model copy() and remove initializers on the copied model."
-            )
-        proto = ir.serde.serialize_model(self.model)
-        byte_size = proto.ByteSize()
-        model_too_large = (byte_size) >= 1 << 31
-        if external_data or model_too_large:
-            # TODO: Create an IR pass to handle external tensors conversion
-            if model_too_large:
-                logger.warning(
-                    "The serialized ONNX model is larger than 2GB (%s). "
-                    "Saving the weights as external data in a separate file.",
-                    byte_size,
-                )
-            if not isinstance(destination, (str, os.PathLike)):
-                raise TypeError(
-                    "Saving the weights as external data is only supported when destination is a file path"
-                )
-            destination_path = pathlib.Path(destination)
-            # Store the external data next to the model file
-            data_path = f"{destination_path.name}.data"
-            onnx.save_model(
-                proto,
-                destination,
-                save_as_external_data=True,
-                location=data_path,
-            )
-        else:
-            onnx.save_model(proto, destination)
-
-    def initialize_inference_session(
-        self,
-        initializer: Callable[
-            [str | bytes], ort.InferenceSession
-        ] = _ort_session_initializer,
-    ) -> None:
-        """Initialize the ONNX Runtime inference session.
-
-        Args:
-            initializer: The function to initialize the ONNX Runtime inference
-                session with the specified model. By default, it uses the
-                :func:`_ort_session_initializer` function.
-        """
-        # TODO(justinchuby): Allow different inference options
-        logger.debug("Initializing the inference session.")
-        proto = ir.serde.serialize_model(self.model)
-        byte_size = proto.ByteSize()
-        model_too_large = (byte_size) >= 1 << 31
-
-        if model_too_large:
-            logger.debug(
-                "The serialized ONNX model is larger than 2GB (%s).", byte_size
-            )
-            # Save the model to a temporary file if too large
-            self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
-            model_path = os.path.join(self._tempdir.name, "model.onnx")
-            data_path = "model.onnx.data"
-            onnx.save_model(
-                proto,
-                model_path,
-                save_as_external_data=True,
-                location=data_path,
-            )
-            model = model_path
-        else:
-            model = proto.SerializeToString()  # type: ignore[assignment]
-
-        self._inference_session = initializer(model)
-        logger.debug("Inference session initialized.")
-
-    def release(self) -> None:
-        """Release the inference session.
-
-        You may call this method to release the resources used by the inference session.
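A minimal lifecycle sketch: the ONNX Runtime session is created lazily on the first call, and `release()` lets any temporary model files be cleaned up:

```python
# First call builds the inference session (and, for >2GB models,
# a temporary external-data file on disk).
outputs = onnx_program(torch.randn(1, 1, 2))

# Drop the session so the temporary files can be deleted.
onnx_program.release()
```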
- """ - # Release the inference session first so that the model file can be deleted - if self._inference_session is not None: - self._inference_session = None - gc.collect() - if self._tempdir is not None: - self._tempdir.cleanup() - self._tempdir = None - - -def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: - """Process input arguments for the ONNX model.""" - args = _flatten_inputs(args, kwargs) - args = _remove_none_from_inputs(args) - args = _remove_non_tensor(args) - args = _convert_complex_to_real_representation(args) - return args - - -def _flatten_inputs(model_args, model_kwargs): - flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs)) - return flattened_args - - -def _remove_none_from_inputs(model_args): - return tuple(arg for arg in model_args if arg is not None) - - -def _remove_non_tensor(model_args): - """Remove the non-tensor input arguments. - - Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). - - Specifically, it does put the input into graph with an empty node, but consumed by no ones. - The concrete value is embedded into the graph as a constant arg of a target node. Meta - suggests in this case that one should rewrite the model code to make it tensor if the - input value is supposed to change at runtime. We might need to further investigate - the feasibility of that suggestion. - - For example, - - def func(x, b=1.0): - y = x + b - z = y.relu() - return (y, z) - - x = torch.randn(1, 1, 2, dtype=torch.float32) - gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") - - # class GraphModule(torch.nn.Module): - # def forward(self, x, b): - # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) - # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b - # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None - - # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() - # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) - # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) - - Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as - it's ignored in ONNX graph. Thus, we delete the useless input here. - - """ - - return tuple( - arg for arg in model_args if not isinstance(arg, (int, float, bool, str)) - ) - - -def _convert_complex_to_real_representation(model_args): - """Convert complex dtype tensors to real representation tensors. - - ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors - to real representation tensors (i.e., float dtype tensors with an extra dimension - representing the real and imaginary parts of the complex number). - """ - return tuple( - torch.view_as_real(arg.resolve_conj()) - if isinstance(arg, torch.Tensor) and arg.is_complex() - else arg - for arg in model_args - ) diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py deleted file mode 100644 index b649188c264ea..0000000000000 --- a/torch/onnx/_internal/exporter/_registration.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Module for handling ATen to ONNX functions registration. - -https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py -""" - -# NOTE: Why do we need a different registry than the one in torchlib? 
-# The registry in torchlib is used to register functions that are already implemented in
-# torchlib, and is designed to be a static singleton. It does not take into account custom ops or different
-# opsets etc. The registry implemented for the exporter is designed to be modifiable at
-# export time by users, and is designed with dispatching in mind.

-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import dataclasses
-import logging
-import math
-import operator
-import types
-import typing
-from typing import Callable, Literal, Mapping, Union
-from typing_extensions import TypeAlias
-
-import torch
-import torch._ops
-from torch.onnx._internal.exporter import _schemas
-
-
-if typing.TYPE_CHECKING:
-    import onnxscript
-    from onnxscript.function_libs.torch_lib import registration as torchlib_registration
-
-_DEFAULT_OPSET_VERSION = 18
-
-
-TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable]
-
-logger = logging.getLogger(__name__)
-
-
-@dataclasses.dataclass(frozen=True)
-class OnnxDecompMeta:
-    """A wrapper of onnx-script function with additional metadata.
-
-    onnx_function: The onnx-script function from torchlib.
-    fx_target: The PyTorch node callable target.
-    is_custom: Whether the function is a custom function.
-    is_complex: Whether the function is a function that handles complex valued inputs.
-    device: The device the function is registered to. If None, it is registered to all devices.
-    """
-
-    onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction
-    fx_target: TorchOp
-    is_custom: bool = False
-    is_complex: bool = False
-    device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
-
-
-def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
-    """Obtain the torch op from <namespace>::<op_name>[.<overload>]"""
-    # TODO(justinchuby): Handle arbitrary custom ops
-    namespace, opname_overload = qualified_name.split("::")
-    op_name, *maybe_overload = opname_overload.split(".", 1)
-    if namespace == "_operator":
-        # Builtin functions
-        return getattr(operator, op_name)
-    if namespace == "math":
-        return getattr(math, op_name)
-    if namespace == "torchvision":
-        try:
-            import torchvision.ops  # type: ignore[import-untyped]
-        except ImportError:
-            logger.warning("torchvision is not installed. Skipping %s", qualified_name)
-            return None
-        try:
-            return getattr(torchvision.ops, op_name)
-        except AttributeError:
-            logger.warning("Failed to find torchvision op '%s'", qualified_name)
-            return None
-        except Exception:
-            logger.exception("Failed to find torchvision op '%s'", qualified_name)
-    try:
-        op_packet = getattr(getattr(torch.ops, namespace), op_name)
-        if maybe_overload:
-            overload = maybe_overload[0]
-        elif "default" in op_packet._overload_names or "" in op_packet._overload_names:
-            # Has a default overload
-            overload = "default"
-        else:
-            logger.warning(
-                "'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.",
-                qualified_name,
-                stacklevel=1,
-            )
-            return None
-
-        return getattr(op_packet, overload)  # type: ignore[call-overload]
-    except AttributeError:
-        if qualified_name.endswith("getitem"):
-            # This is a special case where we registered the function incorrectly,
-            # but for BC reasons (pt<=2.4) we need to keep it.
-            return None
-        logger.info("'%s' is not found in this version of PyTorch.", qualified_name)
-        return None
-    except Exception:
-        logger.exception("Failed to find torch op '%s'", qualified_name)
-        return None
-
-
-class ONNXRegistry:
-    """Registry for ONNX functions.
-
-    The registry maintains a mapping from qualified names to symbolic functions under a
-    fixed opset version. It supports registering custom onnx-script functions and allows
-    the dispatcher to dispatch calls to the appropriate function.
-
-    """
-
-    def __init__(self) -> None:
-        """Initializes the registry"""
-
-        # TODO: Design multi-opset version support
-        self._opset_version = _DEFAULT_OPSET_VERSION
-
-        self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}
-
-    @property
-    def opset_version(self) -> int:
-        """The ONNX opset version the exporter should target.
-
-        Defaults to the latest supported ONNX opset version: 18.
-        The default version will increment over time as ONNX continues to evolve.
-        """
-
-        return self._opset_version
-
-    @classmethod
-    def from_torchlib(
-        cls,
-        torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction]
-        | None = None,
-    ) -> ONNXRegistry:
-        """Populates the registry with ATen functions from torchlib.
-
-        Args:
-            torchlib_registry: The torchlib registry to use for populating the registry.
-        """
-        registry = cls()
-        if torchlib_registry is None:
-            from onnxscript.function_libs.torch_lib import (
-                registration as torchlib_registration,
-            )
-
-            torchlib_registry = torchlib_registration.default_registry  # type: ignore[assignment]
-        for qualified_name, aten_overloads_func in torchlib_registry.items():  # type: ignore[union-attr]
-            try:
-                # NOTE: This is heavily guarded with try-except because we don't want
-                # to fail the entire registry population if one function fails.
-                if qualified_name.startswith("internal::"):
-                    # Skip the custom defined internal functions
-                    continue
-                target = _get_overload(qualified_name)
-                if target is None:
-                    continue
-                for overload_func in aten_overloads_func.overloads:
-                    overload_func.signature = _schemas.OpSignature.from_function(
-                        overload_func,
-                        overload_func.function_ir.domain,
-                        overload_func.name,
-                    )
-                    onnx_decomposition = OnnxDecompMeta(
-                        onnx_function=overload_func,
-                        fx_target=target,
-                        is_custom=False,
-                        is_complex=False,
-                    )
-                    registry._register(target, onnx_decomposition)
-
-                for complex_func in aten_overloads_func.complex:
-                    complex_func.signature = _schemas.OpSignature.from_function(
-                        complex_func,
-                        complex_func.function_ir.domain,
-                        complex_func.name,
-                    )
-                    onnx_decomposition = OnnxDecompMeta(
-                        onnx_function=complex_func,
-                        fx_target=target,
-                        is_custom=False,
-                        is_complex=True,
-                    )
-                    registry._register(target, onnx_decomposition)
-            except Exception:
-                logger.exception("Failed to register '%s'. Skipped", qualified_name)
-                continue
-        return registry
-
-    def _register(
-        self,
-        target: TorchOp,
-        onnx_decomposition: OnnxDecompMeta,
-    ) -> None:
-        """Registers an OnnxDecompMeta to an operator.
-
-        Args:
-            target: The PyTorch node callable target.
-            onnx_decomposition: The OnnxDecompMeta to register.
-        """
-        target_or_name: str | TorchOp
-        if isinstance(target, torch._ops.OpOverload):
-            # Get the qualified name of the aten op because torch._ops.OpOverload lookup in
-            # a dictionary is unreliable for some reason.
-            target_or_name = target.name()
-        else:
-            target_or_name = target
-        if onnx_decomposition.is_custom:
-            self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition)
-        else:
-            self.functions.setdefault(target_or_name, []).append(onnx_decomposition)
-
-    def register_op(
-        self,
-        target: TorchOp,
-        function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
-        is_complex: bool = False,
-    ) -> None:
-        """Registers a custom operator: torch.ops....
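To make the registration flow concrete, a hedged sketch of supplying a custom implementation (the target op and function body are assumptions for illustration):

```python
import onnxscript
import torch
from onnxscript import opset18 as op

@onnxscript.script()
def custom_gelu(x):
    # Illustrative tanh approximation built from ONNX primitives.
    return 0.5 * x * (1.0 + op.Tanh(0.797885 * (x + 0.044715 * x * x * x)))

registry = ONNXRegistry.from_torchlib()
# Custom functions are inserted at the front of the dispatch list,
# so they take precedence over the builtin torchlib implementation.
registry.register_op(torch.ops.aten.gelu.default, custom_gelu)
```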
- - Args: - target: The PyTorch node callable target. - function: The onnx-script function to register. - is_complex: Whether the function is a function that handles complex valued inputs. - """ - onnx_decomposition = OnnxDecompMeta( - onnx_function=function, - fx_target=target, - is_custom=True, - is_complex=is_complex, - ) - self._register(target, onnx_decomposition) - - def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]: - """Returns a list of OnnxDecompMeta for the given op: torch.ops.... - - The list is ordered by the time of registration. The custom operators should come - first in the list. - - Args: - target: The PyTorch node callable target. - Returns: - A list of OnnxDecompMeta corresponding to the given name, or None if - the name is not in the registry. - """ - target_or_name: str | TorchOp - if isinstance(target, torch._ops.OpOverload): - # Get the qualified name of the aten op because torch._ops.OpOverload lookup in - # a dictionary is unreliable for some reason. - target_or_name = target.name() - else: - target_or_name = target - decomps = self.functions.get(target_or_name, []) - return sorted(decomps, key=lambda x: x.is_custom, reverse=True) - - def is_registered(self, target: TorchOp) -> bool: - """Returns whether the given op is registered: torch.ops.... - - Args: - target: The PyTorch node callable target. - - Returns: - True if the given op is registered, otherwise False. - """ - return bool(self.get_decomps(target)) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/torch/onnx/_internal/exporter/_reporting.py b/torch/onnx/_internal/exporter/_reporting.py deleted file mode 100644 index 55a77a90ec4b9..0000000000000 --- a/torch/onnx/_internal/exporter/_reporting.py +++ /dev/null @@ -1,193 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import dataclasses -import re -from typing import TYPE_CHECKING - -from torch.onnx._internal.exporter import _analysis, _registration, _verification - - -if TYPE_CHECKING: - import os - - from onnxscript import ir - - import torch - - -@dataclasses.dataclass -class ExportStatus: - # Whether torch.export.export.export() succeeds - torch_export: bool | None = None - # Whether torch.export.export.export(..., strict=False) succeeds - torch_export_non_strict: bool | None = None - # Whether torch.jit.trace succeeds - torch_jit: bool | None = None - # Whether ONNX translation succeeds - onnx_translation: bool | None = None - # Whether ONNX model passes onnx.checker.check_model - onnx_checker: bool | None = None - # Whether ONNX model runs successfully with ONNX Runtime - onnx_runtime: bool | None = None - # Whether the output of the ONNX model is accurate - output_accuracy: bool | None = None - - -def _status_emoji(status: bool | None) -> str: - if status is None: - return "⚪" - return "✅" if status else "❌" - - -def _format_export_status(status: ExportStatus) -> str: - return ( - f"```\n" - f"{_status_emoji(status.torch_export)} Obtain model graph with `torch.export.export`\n" - f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n" - f"{_status_emoji(status.torch_jit)} Obtain model graph with `torch.jit.trace`\n" - f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n" - f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n" - f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n" - 
f"{_status_emoji(status.output_accuracy)} Validate model output accuracy\n" - f"```\n\n" - ) - - -def _strip_color_from_string(text: str) -> str: - # This regular expression matches ANSI escape codes - # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788 - ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") - return ansi_escape.sub("", text) - - -def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str: - # Adapted from https://github.com/pytorch/pytorch/pull/128476 - # to remove colors - # Even though we can call graph_module.print_readable directly, since the - # colored option was added only recently, we can't guarantee that the - # version of PyTorch used by the user has this option. Therefore, we - # still call str(ExportedProgram) - text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n" - return text - - -def construct_report_file_name(timestamp: str, status: ExportStatus) -> str: - # Status could be None. So we need to check for False explicitly. - if not (status.torch_export or status.torch_export_non_strict or status.torch_jit): - # All strategies failed - postfix = "pt_export" - elif status.onnx_translation is False: - postfix = "conversion" - elif status.onnx_checker is False: - postfix = "checker" - elif status.onnx_runtime is False: - postfix = "runtime" - elif status.output_accuracy is False: - postfix = "accuracy" - elif status.torch_export is False or status.torch_export_non_strict is False: - # Some strategies failed - postfix = "strategies" - else: - postfix = "success" - return f"onnx_export_{timestamp}_{postfix}.md" - - -def format_decomp_comparison( - pre_decomp_unique_ops: set[str], - post_decomp_unique_ops: set[str], -) -> str: - """Format the decomposition comparison result. - - Args: - unique_ops_in_a: The unique ops in the first program. - unique_ops_in_b: The unique ops in the second program. - - Returns: - The formatted comparison result. - """ - return ( - f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n" - f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n" - ) - - -def format_verification_infos( - verification_infos: list[_verification.VerificationInfo], -) -> str: - """Format the verification result. - - Args: - verification_infos: The verification result. - - Returns: - The formatted verification result. 
- """ - return "\n".join( - f"`{info.name}`: `abs_diff={info.absolute_difference:e}`, `rel_diff={info.relative_difference:e}`" - for info in verification_infos - ) - - -def create_torch_export_error_report( - filename: str | os.PathLike, - formatted_traceback: str, - *, - export_status: ExportStatus, - profile_result: str | None, -): - with open(filename, "w", encoding="utf-8") as f: - f.write("# PyTorch ONNX Conversion Error Report\n\n") - f.write(_format_export_status(export_status)) - f.write("Error message:\n\n") - f.write("```pytb\n") - f.write(formatted_traceback) - f.write("```\n\n") - if profile_result is not None: - f.write("## Profiling result\n\n") - f.write("```\n") - f.write(profile_result) - f.write("```\n") - - -def create_onnx_export_report( - filename: str | os.PathLike, - formatted_traceback: str, - program: torch.export.ExportedProgram, - *, - decomp_comparison: str | None = None, - export_status: ExportStatus, - profile_result: str | None, - model: ir.Model | None = None, - registry: _registration.ONNXRegistry | None = None, - verification_result: str | None = None, -): - with open(filename, "w", encoding="utf-8") as f: - f.write("# PyTorch ONNX Conversion Report\n\n") - f.write(_format_export_status(export_status)) - f.write("## Error messages\n\n") - f.write("```pytb\n") - f.write(formatted_traceback) - f.write("\n```\n\n") - f.write("## Exported program\n\n") - f.write(_format_exported_program(program)) - if model is not None: - f.write("## ONNX model\n\n") - f.write("```python\n") - f.write(str(model)) - f.write("\n```\n\n") - f.write("## Analysis\n\n") - _analysis.analyze(program, file=f, registry=registry) - if decomp_comparison is not None: - f.write("\n## Decomposition comparison\n\n") - f.write(decomp_comparison) - f.write("\n") - if verification_result is not None: - f.write("\n## Verification results\n\n") - f.write(verification_result) - f.write("\n") - if profile_result is not None: - f.write("\n## Profiling result\n\n") - f.write("```\n") - f.write(profile_result) - f.write("```\n") diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py deleted file mode 100644 index 8ad10cd7a871e..0000000000000 --- a/torch/onnx/_internal/exporter/_schemas.py +++ /dev/null @@ -1,548 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import collections.abc -import dataclasses -import inspect -import logging -import types -import typing -from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union - -import onnx - -import onnxscript -from onnxscript import ir - - -logger = logging.getLogger(__name__) - - -# A special value to indicate that the default value is not specified -class _Empty: - def __repr__(self): - return "_EMPTY_DEFAULT" - - -_EMPTY_DEFAULT = _Empty() - -# Map from python type to corresponding ONNX AttributeProto type -_PY_TYPE_TO_ATTR_TYPE = { - float: ir.AttributeType.FLOAT, - int: ir.AttributeType.INT, - str: ir.AttributeType.STRING, - bool: ir.AttributeType.INT, - ir.Tensor: ir.AttributeType.TENSOR, - ir.TensorProtocol: ir.AttributeType.TENSOR, - ir.Graph: ir.AttributeType.GRAPH, - ir.GraphProtocol: ir.AttributeType.GRAPH, -} - -# Map from python type to corresponding ONNX AttributeProto type, -# for repeated (i.e., list of) values -_LIST_TYPE_TO_ATTR_TYPE = { - float: ir.AttributeType.FLOATS, - int: ir.AttributeType.INTS, - str: ir.AttributeType.STRINGS, - bool: ir.AttributeType.INTS, - ir.Tensor: ir.AttributeType.TENSORS, - ir.TensorProtocol: ir.AttributeType.TENSORS, - 
ir.Graph: ir.AttributeType.GRAPHS,
-    ir.GraphProtocol: ir.AttributeType.GRAPHS,
-}
-
-_ALL_VALUE_TYPES = (
-    {ir.TensorType(dtype) for dtype in ir.DataType}
-    | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
-    | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
-)
-
-# TypeAnnotationValue represents the (value of) valid type-annotations recognized
-# by ONNX Script. Currently, it supports
-# - float, int, str (primitive attribute types)
-# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
-# - Tensor types
-# - Sequence[Tensor] types
-# - Union of above 2
-# - TypeVars with above bounds
-# - Above types with annotation attached
TypeAnnotationValue = Any
-
-
-@dataclasses.dataclass(frozen=True)
-class TypeConstraintParam:
-    """Type constraint for a parameter.
-
-    Attributes:
-        name: Name of the parameter. E.g. "TFloat"
-        allowed_types: Allowed types for the parameter.
-    """
-
-    name: str
-    allowed_types: set[ir.TypeProtocol]
-    description: str = ""
-
-    def __hash__(self) -> int:
-        return hash((self.name, tuple(self.allowed_types)))
-
-    def __str__(self) -> str:
-        allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
-        return f"{self.name}={allowed_types_str}"
-
-    @classmethod
-    def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
-        return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)
-
-    @classmethod
-    def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
-        return cls(name, _ALL_VALUE_TYPES, description)  # type: ignore[arg-type]
-
-
-@dataclasses.dataclass(frozen=True)
-class Parameter:
-    """A formal parameter of an operator."""
-
-    name: str
-    type_constraint: TypeConstraintParam
-    required: bool
-    variadic: bool
-    default: Any = _EMPTY_DEFAULT
-    # TODO: Add other properties too
-
-    def __str__(self) -> str:
-        type_str = self.type_constraint.name
-        if self.has_default():
-            return f"{self.name}: {type_str} = {self.default}"
-        return f"{self.name}: {type_str}"
-
-    def has_default(self) -> bool:
-        return self.default is not _EMPTY_DEFAULT
-
-
-@dataclasses.dataclass(frozen=True)
-class AttributeParameter:
-    name: str
-    type: ir.AttributeType
-    required: bool
-    default: ir.Attr | None = None
-
-    def __str__(self) -> str:
-        type_str = self.type.name
-        if self.has_default():
-            return f"{self.name}: {type_str} = {self.default}"
-        return f"{self.name}: {type_str}"
-
-    def has_default(self) -> bool:
-        return self.default is not None
-
-
-def _get_type_from_str(
-    type_str: str,
-) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
-    """Convert a type_str from ONNX Opschema to ir.TypeProtocol.
-
-    A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
-    """
-
-    # TODO: Upstream this to IR
-
-    # Split the type_str into its type constructors and dtype
-    # 1. Remove the ending ")"
-    striped = type_str.rstrip(")")
-    # 2.
Split the type_str by "(" - type_parts = striped.split("(") - - # Convert the dtype to ir.DataType - dtype = ir.DataType[type_parts[-1].upper()] - - # Create a place holder type first - type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) - - # Construct the type - for type_part in reversed(type_parts[:-1]): - if type_part == "tensor": - type_ = ir.TensorType(dtype) - elif type_part == "seq": - type_ = ir.SequenceType(type_) - elif type_part == "optional": - type_ = ir.OptionalType(type_) - else: - raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") - return type_ # type: ignore[return-value] - - -def _convert_formal_parameter( - param: onnx.defs.OpSchema.FormalParameter, - type_constraints: Mapping[str, TypeConstraintParam], -) -> Parameter: - """Convert a formal parameter from ONNX Opschema to Parameter.""" - if param.type_str in type_constraints: - type_constraint = type_constraints[param.type_str] - else: - # param.type_str can be a plain type like 'int64'. - type_constraint = TypeConstraintParam( - name=param.name, - allowed_types={_get_type_from_str(param.type_str)}, - ) - return Parameter( - name=param.name, - type_constraint=type_constraint, - required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, - variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, - ) - - -def _is_optional(type_: type) -> bool: - """Returns whether a type_ is an Optional.""" - origin_type = typing.get_origin(type_) - if origin_type is Union and type(None) in typing.get_args(type_): - # Python < 3.10 - return True - if origin_type is Optional: - # Python >= 3.10 - return True - if ( - hasattr(types, "UnionType") - and origin_type is types.UnionType - and type(None) in typing.get_args(type_) - ): - # Python >= 3.10 - return True - return False - - -def _get_attr_type(type_: type) -> ir.AttributeType: - """Obtain the type of the attribute from a Python class.""" - try: - if type_ in _PY_TYPE_TO_ATTR_TYPE: - return _PY_TYPE_TO_ATTR_TYPE[type_] - origin_type = typing.get_origin(type_) - if origin_type is None: - return ir.AttributeType.UNDEFINED - if origin_type in ( - collections.abc.Sequence, - Sequence, - typing.List, - list, - typing.Tuple, - tuple, - ): - inner_type = typing.get_args(type_)[0] - if inner_type in _LIST_TYPE_TO_ATTR_TYPE: - return _LIST_TYPE_TO_ATTR_TYPE[inner_type] - except TypeError: - logger.warning("TypeError when checking %s.", type_, exc_info=True) - return ir.AttributeType.UNDEFINED - - -def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: - """Returns the name of the type constraint for a given type annotation. - - Args: - type_: A Python type. - - Returns: - The name of the type constraint if it is a TypeVar. - - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. 
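A few illustrative mappings under these rules (a sketch; `TFloat` stands for an arbitrary TypeVar):

```python
from typing import Optional, Sequence, TypeVar

TFloat = TypeVar("TFloat")

_get_type_constraint_name(TFloat)            # -> "TFloat"
_get_type_constraint_name(Optional[TFloat])  # -> "TFloat"
_get_type_constraint_name(Sequence[TFloat])  # -> "Sequence_TFloat"
_get_type_constraint_name(int)               # -> None (not a TypeVar)
```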
- """ - if isinstance(type_, TypeVar): - return type_.__name__ - if _is_optional(type_): - subtypes = typing.get_args(type_) - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = _get_type_constraint_name(subtype) - return type_param_name if type_param_name else None - origin_type = typing.get_origin(type_) - if isinstance(origin_type, type) and issubclass(origin_type, Sequence): - subtypes = typing.get_args(type_) - type_param_name = _get_type_constraint_name(subtypes[0]) - return f"Sequence_{type_param_name}" if type_param_name else None - return None - - -def _get_allowed_types_from_type_annotation( - type_: TypeAnnotationValue, -) -> set[ir.TypeProtocol]: - """Obtain the allowed types from a type annotation.""" - if type_ is onnxscript.onnx_types.TensorType: - # Any tensor type - return {ir.TensorType(dtype) for dtype in ir.DataType} - - allowed_types: set[ir.TypeProtocol] - - if isinstance(type_, TypeVar): - allowed_types = set() - if constraints := type_.__constraints__: - for constraint in constraints: - allowed_types.update( - _get_allowed_types_from_type_annotation(constraint) - ) - else: - bound = type_.__bound__ - if bound is None: - allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] - else: - allowed_types.update(_get_allowed_types_from_type_annotation(bound)) - return allowed_types - if hasattr(type_, "dtype"): - # A single tensor type like INT64, FLOAT, etc. - return {ir.TensorType(ir.DataType(type_.dtype))} - if _is_optional(type_): - allowed_types = set() - subtypes = typing.get_args(type_) - for subtype in subtypes: - if subtype is type(None): - continue - allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) - # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. - return allowed_types - - origin_type = typing.get_origin(type_) - if origin_type is Union: - allowed_types = set() - subtypes = typing.get_args(type_) - for subtype in subtypes: - assert subtype is not type( - None - ), "Union should not contain None type because it is handled by _is_optional." - allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) - return allowed_types - - if isinstance(origin_type, type) and issubclass(origin_type, Sequence): - subtypes = typing.get_args(type_) - return { - ir.SequenceType(t) - for t in _get_allowed_types_from_type_annotation(subtypes[0]) - } - - # Allow everything by default - return _ALL_VALUE_TYPES # type: ignore[return-value] - - -@dataclasses.dataclass -class OpSignature: - """Schema for an operator. - - Attributes: - domain: Domain of the operator. E.g. "". - name: Name of the operator. E.g. "Add". - overload: Overload name of the operator. - params: Input parameters. When the op is an ONNX function definition, - the order is according to the function signature. This mean we can - interleave ONNX inputs and ONNX attributes in the list. - outputs: Output parameters. 
- """ - - domain: str - name: str - overload: str - params: Sequence[Parameter | AttributeParameter] - outputs: Sequence[Parameter] - params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( - init=False, repr=False - ) - - def __post_init__(self): - self.params_map = {param.name: param for param in self.params} - - def get(self, name: str) -> Parameter | AttributeParameter: - return self.params_map[name] - - def __contains__(self, name: str) -> bool: - return name in self.params_map - - def __iter__(self) -> Iterator[Parameter | AttributeParameter]: - return iter(self.params) - - def __str__(self) -> str: - domain = self.domain or "''" - # TODO: Double check the separator for overload - overload = f"::{self.overload}" if self.overload else "" - params = ", ".join(str(param) for param in self.params) - outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) - type_constraints = {} - for param in self.params: - if isinstance(param, Parameter): - type_constraints[param.type_constraint.name] = param.type_constraint - for param in self.outputs: - type_constraints[param.type_constraint.name] = param.type_constraint - type_constraints_str = ", ".join( - str(type_constraint) for type_constraint in type_constraints.values() - ) - return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" - - @classmethod - def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature: - """Produce an OpSignature from an ONNX Opschema.""" - type_constraints = { - constraint.type_param_str: TypeConstraintParam( - name=constraint.type_param_str, - allowed_types={ - _get_type_from_str(type_str) - for type_str in constraint.allowed_type_strs - }, - description=constraint.description, - ) - for constraint in opschema.type_constraints - } - - params = [ - _convert_formal_parameter(param, type_constraints) - for param in opschema.inputs - ] - - for param in opschema.attributes.values(): - default_attr = ( - ir.serde.deserialize_attribute(param.default_value) - if param.default_value is not None - else None - ) - if default_attr is not None: - # Set the name of the default attribute because it may have a different name from the parameter - default_attr.name = param.name - params.append( - AttributeParameter( - name=param.name, - type=ir.AttributeType(param.type), # type: ignore[arg-type] - required=param.required, - default=default_attr, # type: ignore[arg-type] - ) - ) - - outputs = [ - _convert_formal_parameter(param, type_constraints) - for param in opschema.outputs - ] - - return cls( - domain=opschema.domain, - name=opschema.name, - overload="", - params=params, - outputs=outputs, - ) - - @classmethod - def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "" - ) -> OpSignature: - """Produce an OpSignature from a function using type annotation.""" - - py_signature = inspect.signature(func) - # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases - # https://github.com/python/cpython/issues/102405 - type_hints = typing.get_type_hints(func) - - params = [] - # Create a mapping from type to a unique name - type_constraints: dict[str, TypeConstraintParam] = {} - - for param in py_signature.parameters.values(): - if param.name not in type_hints: - logger.warning( - "Missing annotation for parameter '%s' from %s. 
-
-    @classmethod
-    def from_function(
-        cls, func, domain: str, name: str | None = None, overload: str = ""
-    ) -> OpSignature:
-        """Produce an OpSignature from a function using type annotation."""
-
-        py_signature = inspect.signature(func)
-        # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
-        # https://github.com/python/cpython/issues/102405
-        type_hints = typing.get_type_hints(func)
-
-        params = []
-        # Create a mapping from type to a unique name
-        type_constraints: dict[str, TypeConstraintParam] = {}
-
-        for param in py_signature.parameters.values():
-            if param.name not in type_hints:
-                logger.warning(
-                    "Missing annotation for parameter '%s' from %s. Treating as an Input.",
-                    param.name,
-                    py_signature,
-                )
-                type_constraints[param.name] = TypeConstraintParam.any_value(
-                    f"T_{param.name}"
-                )
-            else:
-                type_ = type_hints[param.name]
-                if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
-                    # Construct the default attribute
-                    if param.default is not inspect.Parameter.empty:
-                        # TODO: Use ir_convenience instead to handle int as float
-                        default = ir.Attr(param.name, attr_type, param.default)
-                    else:
-                        default = None
-                    params.append(
-                        AttributeParameter(
-                            name=param.name,
-                            type=attr_type,
-                            required=param.default is inspect.Parameter.empty,
-                            default=default,
-                        )
-                    )
-                else:
-                    # Obtain the type constraint from the type annotation
-
-                    # 1. Get a type constraint name from the type annotation
-                    # If the type annotation is a TypeVar or Optional[TypeVar], get its name
-                    # Otherwise, name it T_{param.name}
-                    type_constraint_name = _get_type_constraint_name(type_)
-                    if type_constraint_name is None:
-                        type_constraint_name = f"T_{param.name}"
-
-                    # 2. If the type constraint param is already initialized, use it
-                    if type_constraint_name in type_constraints:
-                        type_constraint = type_constraints[type_constraint_name]
-                    else:
-                        # 3. Otherwise, create a new TypeConstraintParam
-                        type_constraint = TypeConstraintParam(
-                            name=type_constraint_name,
-                            allowed_types=_get_allowed_types_from_type_annotation(
-                                type_
-                            ),
-                        )
-                        type_constraints[type_constraint_name] = type_constraint
-                    # 4. Create Parameter
-                    params.append(
-                        Parameter(  # type: ignore[arg-type]
-                            name=param.name,
-                            type_constraint=type_constraint,
-                            required=param.default is inspect.Parameter.empty,
-                            # TODO: Handle variadic
-                            variadic=False,
-                            default=param.default
-                            if param.default is not inspect.Parameter.empty
-                            else _EMPTY_DEFAULT,
-                        )
-                    )
-
-        return_type = type_hints.get("return")
-
-        outputs = []
-        if return_type is None:
-            # No returns
-            pass
-        else:
-            if typing.get_origin(return_type) is tuple:
-                # Multiple returns
-                return_types = typing.get_args(return_type)
-            else:
-                return_types = [return_type]  # type: ignore[assignment]
-
-            for i, return_type_i in enumerate(return_types):
-                if (
-                    return_param_name := _get_type_constraint_name(return_type_i)
-                ) in type_constraints:
-                    type_constraint = type_constraints[return_param_name]
-                else:
-                    return_param_name = f"TReturn{i}"
-                    type_constraint = TypeConstraintParam(
-                        name=return_param_name,
-                        allowed_types=_get_allowed_types_from_type_annotation(
-                            return_type_i
-                        ),
-                    )
-                    type_constraints[return_param_name] = type_constraint
-                outputs.append(
-                    Parameter(
-                        name=return_param_name,
-                        type_constraint=type_constraint,
-                        required=True,
-                        variadic=False,
-                        default=_EMPTY_DEFAULT,
-                    )
-                )
-
-        return cls(
-            domain=domain,
-            name=name or func.__name__,
-            overload=overload,
-            params=params,
-            outputs=outputs,
-        )
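
# NOTE (illustrative sketch, not part of the diff): a plausible input for the
# annotation-driven path above; scaled_add/ScaledAdd are made-up names. Per the
# code above, x and y would share one type constraint named "TFloat", while
# alpha, carrying a plain Python float annotation, would become an attribute.
#
#   from typing import TypeVar
#
#   from onnxscript.onnx_types import DOUBLE, FLOAT
#
#   TFloat = TypeVar("TFloat", FLOAT, DOUBLE)
#
#   def scaled_add(x: TFloat, y: TFloat, alpha: float = 1.0) -> TFloat:
#       ...  # only the signature matters to OpSignature.from_function
#
#   # signature = OpSignature.from_function(scaled_add, domain="custom", name="ScaledAdd")
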
diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py
deleted file mode 100644
index cfe8f7dc2a661..0000000000000
--- a/torch/onnx/_internal/exporter/_tensors.py
+++ /dev/null
@@ -1,98 +0,0 @@
-"""Subclass of ir.Value that supports Python operators."""
-
-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import onnxscript
-from onnxscript import ir
-
-
-class SymbolicTensor(ir.Value):
-    """A subclass of ir.Value that supports Python operators."""
-
-    def __init__(
-        self,
-        opset: onnxscript.values.Opset,
-        name: str | None = None,
-        shape: ir.Shape | None = None,
-        type: ir.TypeProtocol | None = None,
-        doc_string: str | None = None,
-        const_value: ir.TensorProtocol | None = None,
-    ):
-        super().__init__(
-            name=name,
-            shape=shape,
-            type=type,
-            doc_string=doc_string,
-            const_value=const_value,
-        )
-        self._opset = opset
-
-    @property
-    def rank(self) -> int | None:
-        if self.shape is None:
-            return None
-        return len(self.shape)
-
-    # TODO: Implement indexing
-
-    def __mod__(self, other):
-        if self.dtype in {
-            ir.DataType.FLOAT,
-            ir.DataType.DOUBLE,
-            ir.DataType.FLOAT16,
-            ir.DataType.BFLOAT16,
-        }:
-            return self._opset.Mod(self, other, fmod=1)
-        return self._opset.Mod(self, other)
-
-    def __ne__(self, other):
-        return self._opset.Not(self._opset.Equal(self, other))
-
-    def __neg__(self):
-        return self._opset.Neg(self)
-
-    def __add__(self, other):
-        return self._opset.Add(self, other)
-
-    def __radd__(self, other):
-        return self._opset.Add(other, self)
-
-    def __rand__(self, other):
-        return self._opset.And(other, self)
-
-    def __mul__(self, other):
-        return self._opset.Mul(self, other)
-
-    def __rmul__(self, other):
-        return self._opset.Mul(other, self)
-
-    def __matmul__(self, other):
-        return self._opset.MatMul(self, other)
-
-    def __pow__(self, other):
-        return self._opset.Pow(self, other)
-
-    def __sub__(self, other):
-        return self._opset.Sub(self, other)
-
-    def __rsub__(self, other):
-        return self._opset.Sub(other, self)
-
-    def __truediv__(self, other):
-        return self._opset.Div(self, other)
-
-    def __lt__(self, other):
-        return self._opset.Less(self, other)
-
-    def __le__(self, other):
-        return self._opset.LessOrEqual(self, other)
-
-    def __eq__(self, other):
-        return self._opset.Equal(self, other)
-
-    def __ge__(self, other):
-        return self._opset.GreaterOrEqual(self, other)
-
-    def __gt__(self, other):
-        return self._opset.Greater(self, other)
diff --git a/torch/onnx/_internal/exporter/_verification.py b/torch/onnx/_internal/exporter/_verification.py
deleted file mode 100644
index 00822ca8991b3..0000000000000
--- a/torch/onnx/_internal/exporter/_verification.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, TYPE_CHECKING
-
-import torch
-from torch.utils import _pytree as pytree
-
-
-if TYPE_CHECKING:
-    from torch.onnx._internal.exporter import _onnx_program
-
-
-@dataclasses.dataclass
-class VerificationInfo:
-    name: str
-    absolute_difference: float
-    relative_difference: float
-    expected_dtype: torch.dtype
-    actual_dtype: torch.dtype
-    # NOTE: We don't need to include shape because the expected shape is already known
-    # and checked by the runtime
-
-
-def _compare_tensors(
-    expected: torch.Tensor,
-    actual: torch.Tensor,
-) -> tuple[float, float]:
-    # Move tensors to the same device
-    expected = expected.detach().cpu()
-    actual = actual.detach().cpu()
-    absolute_difference = torch.abs(expected - actual).max().item()
-    eps = 1e-7
-    relative_difference = (
-        (torch.abs(expected - actual) / (torch.abs(expected) + eps)).max().item()
-    )
-    return absolute_difference, relative_difference
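
# NOTE (illustrative sketch, not part of the diff): a standalone numeric example
# of the two metrics _compare_tensors above reports per output:
#
#   import torch
#
#   expected = torch.tensor([1.0, 2.0, 4.0])
#   actual = torch.tensor([1.0, 2.5, 4.0])
#   eps = 1e-7  # guards against division by zero, as in _compare_tensors
#   print(torch.abs(expected - actual).max().item())  # absolute difference: 0.5
#   print((torch.abs(expected - actual) / (torch.abs(expected) + eps)).max().item())  # relative: ~0.25
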
" - "Please provide arguments to verify the ONNX program." - ) - args, kwargs = exported_program.example_inputs - if args is None: - args = () - if kwargs is None: - kwargs = {} - torch_module = exported_program.module() - torch_outputs, _ = pytree.tree_flatten(torch_module(*args, **kwargs)) - onnx_outputs = onnx_program(*args, **kwargs) - results = [] - for torch_output, onnx_output, output_val in zip( - torch_outputs, onnx_outputs, onnx_program.model.graph.outputs - ): - name = output_val.name - absolute_difference, relative_difference = _compare_tensors( - torch_output, onnx_output - ) - results.append( - VerificationInfo( - name=str(name), - absolute_difference=absolute_difference, - relative_difference=relative_difference, - expected_dtype=torch_output.dtype, - actual_dtype=onnx_output.dtype, - ) - ) - return results diff --git a/torch/onnx/_internal/exporter/errors.py b/torch/onnx/_internal/exporter/errors.py deleted file mode 100644 index a70eccf3a5633..0000000000000 --- a/torch/onnx/_internal/exporter/errors.py +++ /dev/null @@ -1,30 +0,0 @@ -class ExporterError(RuntimeError): - """Error during export.""" - - -class TorchExportError(ExporterError): - """Error during torch.export.export.""" - - -class OnnxConversionError(ExporterError): - """Error during ONNX conversion.""" - - -class DispatchError(OnnxConversionError): - """Error during ONNX Funtion dispatching.""" - - -class GraphConstructionError(OnnxConversionError): - """Error during graph construction.""" - - -class OnnxCheckerError(ExporterError): - """Error during ONNX model checking.""" - - -class OnnxRuntimeError(ExporterError): - """Error during ONNX Runtime execution.""" - - -class OnnxValidationError(ExporterError): - """Output value mismatch.""" diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index d380bb6ef8e62..8247ce3384660 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -554,7 +554,7 @@ def run( ) with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"): - diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph) # type: ignore[attr-defined] + diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph) return onnxscript_graph @@ -655,7 +655,7 @@ def call_function( # function signature in OpSchema, and find the best matched overload. symbolic_fn = onnxfunction_dispatcher.dispatch( node=node, - onnx_args=onnx_args, # type: ignore[arg-type] + onnx_args=onnx_args, onnx_kwargs=onnx_kwargs, diagnostic_context=self.diagnostic_context, ) @@ -781,7 +781,7 @@ def call_module( outputs: onnxscript_graph_building.TorchScriptTensor | tuple[ onnxscript_graph_building.TorchScriptTensor, ... 
diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
index d380bb6ef8e62..8247ce3384660 100644
--- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py
+++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
@@ -554,7 +554,7 @@ def run(
         )

         with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"):
-            diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)  # type: ignore[attr-defined]
+            diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)

         return onnxscript_graph

@@ -655,7 +655,7 @@ def call_function(
         # function signature in OpSchema, and find the best matched overload.
         symbolic_fn = onnxfunction_dispatcher.dispatch(
             node=node,
-            onnx_args=onnx_args,  # type: ignore[arg-type]
+            onnx_args=onnx_args,
             onnx_kwargs=onnx_kwargs,
             diagnostic_context=self.diagnostic_context,
         )
@@ -781,7 +781,7 @@ def call_module(
         outputs: onnxscript_graph_building.TorchScriptTensor | tuple[
             onnxscript_graph_building.TorchScriptTensor, ...
-        ] = parent_onnxscript_graph.add_module_call(  # type: ignore[assignment]
+        ] = parent_onnxscript_graph.add_module_call(
             unique_module_name, sub_onnxscript_graph, onnx_args
         )
diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py
index 8d01cf01c4ef1..a8f5c352fbe8e 100644
--- a/torch/onnx/_internal/fx/serialization.py
+++ b/torch/onnx/_internal/fx/serialization.py
@@ -61,7 +61,7 @@ def _create_tensor_proto_with_external_data(
     tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
     tensor_proto.name = name
-    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]
+    tensor_proto.data_type = scalar_type.onnx_type()
     tensor_proto.dims.extend(tensor.shape)
     tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index c70a90e1e01f7..6a6526b18ee99 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -172,7 +172,10 @@ def _get_torch_export_args(


 def export(
-    model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
+    model: torch.nn.Module
+    | torch.jit.ScriptModule
+    | torch.jit.ScriptFunction
+    | torch.export.ExportedProgram,
     args: tuple[Any, ...] | torch.Tensor,
     f: str | None = None,
     *,
@@ -188,11 +191,13 @@ def export(
     dynamic_axes: Mapping[str, Mapping[int, str]]
     | Mapping[str, Sequence[int]]
    | None = None,
+    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
     keep_initializers_as_inputs: bool | None = None,
     custom_opsets: Mapping[str, int] | None = None,
     export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
-    autograd_inlining: bool = True,
-) -> None:
+    autograd_inlining: bool | None = True,
+    dynamo: bool = False,
+) -> torch.onnx.ONNXProgram | None:
     r"""Exports a model into ONNX format.

     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
@@ -486,6 +491,8 @@ def forward(self, x):
         autograd_inlining: Flag used to control whether to inline autograd functions.
             Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.

+        dynamo: Whether to export the model with Dynamo instead of TorchScript.
+
     Raises:
         :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
         :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@@ -508,29 +515,65 @@ def forward(self, x):
         )
         args = (args,) if isinstance(args, torch.Tensor) else args

-    if kwargs is not None:
-        args = args + (kwargs,)
-
-    _export(
-        model,
-        args,
-        f,
-        export_params,
-        verbose,
-        training,
-        input_names,
-        output_names,
-        operator_export_type=operator_export_type,
-        opset_version=opset_version,
-        do_constant_folding=do_constant_folding,
-        dynamic_axes=dynamic_axes,
-        keep_initializers_as_inputs=keep_initializers_as_inputs,
-        custom_opsets=custom_opsets,
-        export_modules_as_functions=export_modules_as_functions,
-        autograd_inlining=autograd_inlining,
-    )
+    if dynamo:
+        if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
+            raise TypeError(
+                "Dynamo export does not support ScriptModule or ScriptFunction."
+            )
+        # TODO(justinchuby): Remove the warning once logic migration is done
+        warnings.warn(
+            "export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, "
+            "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
+            "autograd_inlining are not supported for dynamo export at the moment."
+        )
+        args, kwargs = _get_torch_export_args(args, kwargs)
+        if isinstance(model, torch.export.ExportedProgram):
+            exported_program = model
+        else:
+            if dynamic_shapes is None and dynamic_axes is not None:
+                dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
+                    model, dynamic_axes, input_names
+                )
+            exported_program = torch.export.export(
+                model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes  # type: ignore[arg-type]
+            )
+        if kwargs is None:
+            # TODO(justinchuby): dynamo_export requires kwargs to be unpacked. Once migration is done
+            # we can pass kwargs as None
+            kwargs = {}
+        onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
+        if f is not None:
+            onnx_program.save(f)
+        return onnx_program

-    return None
+    else:
+        # TorchScript export path
+        if f is None:
+            raise ValueError("Export destination must be specified when dynamo=False.")
+        if kwargs is not None:
+            args = args + (kwargs,)
+
+        _export(
+            model,
+            args,
+            f,
+            export_params,
+            verbose,
+            training,
+            input_names,
+            output_names,
+            operator_export_type=operator_export_type,
+            opset_version=opset_version,
+            do_constant_folding=do_constant_folding,
+            dynamic_axes=dynamic_axes,
+            keep_initializers_as_inputs=keep_initializers_as_inputs,
+            custom_opsets=custom_opsets,
+            export_modules_as_functions=export_modules_as_functions,
+            autograd_inlining=autograd_inlining,
+        )
+
+        return None


 def _is_constant_tensor_list(node):
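
# NOTE (illustrative sketch, not part of the diff): hedged usage of the two
# branches added above; M and "m.onnx" are made-up. With dynamo=True an
# ONNXProgram is returned and f may be omitted; with the default dynamo=False,
# f is required and the function returns None.
#
#   import torch
#
#   class M(torch.nn.Module):
#       def forward(self, x):
#           return x + 1.0
#
#   # New Dynamo-backed path; a legacy dynamic_axes spec is converted to
#   # dynamic_shapes for torch.export (the branch also emits a blanket warning
#   # about TorchScript-only options).
#   onnx_program = torch.onnx.export(
#       M(), (torch.randn(2, 3),), dynamic_axes={"x": [0]}, dynamo=True
#   )
#   # onnx_program.save("m.onnx")
#
#   # Legacy TorchScript path: unchanged behavior, destination required.
#   torch.onnx.export(M(), (torch.randn(2, 3),), "m.onnx")
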
@@ -1488,7 +1531,7 @@ def _export(
     custom_opsets=None,
     add_node_names=True,
     onnx_shape_inference=True,
-    export_modules_as_functions: Any = False,
+    export_modules_as_functions=False,
     autograd_inlining=True,
 ):
     assert GLOBALS.in_onnx_export is False
@@ -1517,7 +1560,9 @@ def _export(
             f"Exporting to ONNX opset version {opset_version} is not supported "
             f"by 'torch.onnx.export()'. "
             f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
-            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
+            f"To use a newer opset version, consider 'torch.onnx.dynamo_export()'. "
+            f"Note that dynamo_export() is in preview. Please report errors with "
+            f"dynamo_export() as GitHub issues to https://github.com/pytorch/pytorch/issues.",
             category=errors.OnnxExporterWarning,
         )
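
# NOTE (illustrative sketch, not part of the diff): the dynamo branch added in
# utils.py above bridges the legacy dynamic_axes spec to the dynamic_shapes spec
# consumed by torch.export via _from_dynamic_axes_to_dynamic_shapes. The Dim
# name below is an assumption; the helper may generate different names.
#
#   import torch
#
#   # Legacy TorchScript-style spec: axis index -> free-form name.
#   dynamic_axes = {"x": {0: "batch_size"}}
#
#   # torch.export-style equivalent the dynamo path consumes.
#   dynamic_shapes = {"x": {0: torch.export.Dim("batch_size")}}
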