diff --git a/_static/img/onnx/custom_addandround.png b/_static/img/onnx/custom_addandround.png
deleted file mode 100644
index d4973ce6c2..0000000000
Binary files a/_static/img/onnx/custom_addandround.png and /dev/null differ
diff --git a/_static/img/onnx/custom_aten_gelu_model.png b/_static/img/onnx/custom_aten_gelu_model.png
deleted file mode 100644
index 63186f3088..0000000000
Binary files a/_static/img/onnx/custom_aten_gelu_model.png and /dev/null differ
diff --git a/beginner_source/onnx/README.txt b/beginner_source/onnx/README.txt
index 6a598f9b9f..96004a239e 100644
--- a/beginner_source/onnx/README.txt
+++ b/beginner_source/onnx/README.txt
@@ -10,5 +10,9 @@ ONNX
    https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
 
 3. onnx_registry_tutorial.py
-   Extending the ONNX Registry
+   Extending the ONNX exporter operator support
    https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
+
+4. export_control_flow_model_to_onnx_tutorial.py
+   Export a model with control flow to ONNX
+   https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html
\ No newline at end of file
diff --git a/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py b/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py
new file mode 100644
index 0000000000..af1f43138c
--- /dev/null
+++ b/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+"""
+`Introduction to ONNX `_ ||
+`Exporting a PyTorch model to ONNX `_ ||
+`Extending the ONNX exporter operator support `_ ||
+**Export a model with control flow to ONNX**
+
+Export a model with control flow to ONNX
+========================================
+
+**Author**: `Xavier Dupré `_.
+"""
+
+
+###############################################################################
+# Overview
+# --------
+#
+# This tutorial demonstrates how to handle control flow logic while exporting
+# a PyTorch model to ONNX. It highlights the challenges of exporting
+# conditional statements directly and provides solutions to circumvent them.
+#
+# Conditional logic cannot be exported into ONNX unless it is refactored
+# to use :func:`torch.cond`. Let's start with a simple model
+# implementing such a test.
+#
+# What you will learn:
+#
+# - How to refactor the model to use :func:`torch.cond` for exporting.
+# - How to export a model with control flow logic to ONNX.
+# - How to optimize the exported model using the ONNX optimizer.
+#
+# Prerequisites
+# ~~~~~~~~~~~~~
+#
+# * ``torch >= 2.6``
+
+
+import torch
+
+###############################################################################
+# Define the Models
+# -----------------
+#
+# Two models are defined:
+#
+# ``ForwardWithControlFlowTest``: A model with a forward method containing an
+# if-else conditional.
+#
+# ``ModelWithControlFlowTest``: A model that incorporates ``ForwardWithControlFlowTest``
+# as part of a simple MLP. The models are tested with
+# a random input tensor to confirm they execute as expected.
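+#
+# Before diving in, here is a minimal sketch of :func:`torch.cond` semantics
+# (illustrative values only, separate from the models below): it takes a boolean
+# predicate, two traceable branch callables that accept the same operands and
+# return tensors of matching shape and dtype, and a tuple of operands.
+#
+# .. code-block:: python
+#
+#     pred = torch.tensor(True)
+#     out = torch.cond(pred, lambda t: t + 1, lambda t: t - 1, (torch.ones(3),))
+#     # out == torch.ones(3) + 1 because pred is True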
+
+class ForwardWithControlFlowTest(torch.nn.Module):
+    def forward(self, x):
+        if x.sum():
+            return x * 2
+        return -x
+
+
+class ModelWithControlFlowTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mlp = torch.nn.Sequential(
+            torch.nn.Linear(3, 2),
+            torch.nn.Linear(2, 1),
+            ForwardWithControlFlowTest(),
+        )
+
+    def forward(self, x):
+        out = self.mlp(x)
+        return out
+
+
+model = ModelWithControlFlowTest()
+
+
+###############################################################################
+# Exporting the Model: First Attempt
+# ----------------------------------
+#
+# Exporting this model using ``torch.export.export`` fails because the control
+# flow logic in the forward pass creates a graph break that the exporter cannot
+# handle. This behavior is expected, as conditional logic not written using
+# :func:`torch.cond` is unsupported.
+#
+# A try-except block is used to capture the expected failure during the export
+# process. If the export unexpectedly succeeds, an ``AssertionError`` is raised.
+
+x = torch.randn(3)
+model(x)
+
+try:
+    torch.export.export(model, (x,), strict=False)
+    raise AssertionError("This export should fail unless PyTorch now supports this model.")
+except Exception as e:
+    print(e)
+
+###############################################################################
+# Using :func:`torch.onnx.export` with JIT Tracing
+# ------------------------------------------------
+#
+# When exporting the model using :func:`torch.onnx.export` with the ``dynamo=True``
+# argument, the exporter defaults to using JIT tracing. This fallback allows
+# the model to be exported, but the resulting ONNX graph may not faithfully represent
+# the original model logic due to the limitations of tracing.
+
+
+onnx_program = torch.onnx.export(model, (x,), dynamo=True)
+print(onnx_program.model)
+
+
+###############################################################################
+# Suggested Patch: Refactoring with :func:`torch.cond`
+# ----------------------------------------------------
+#
+# To make the control flow exportable, we replace the
+# forward method in ``ForwardWithControlFlowTest`` with a refactored version that
+# uses :func:`torch.cond`.
+#
+# Details of the refactoring:
+#
+# * Two helper functions (``identity2`` and ``neg``) represent the branches of the conditional logic.
+# * :func:`torch.cond` is used to specify the condition and the two branches, along with the input arguments.
+# * The updated forward method is then dynamically assigned to the ``ForwardWithControlFlowTest``
+#   instance within the model. A list of submodules is printed to confirm the replacement.
+
+
+def new_forward(x):
+    def identity2(x):
+        return x * 2
+
+    def neg(x):
+        return -x
+
+    return torch.cond(x.sum() > 0, identity2, neg, (x,))
+
+
+print("the list of submodules")
+for name, mod in model.named_modules():
+    print(name, type(mod))
+    if isinstance(mod, ForwardWithControlFlowTest):
+        mod.forward = new_forward
+
+###############################################################################
+# Let's see what the FX graph looks like.
+
+print(torch.export.export(model, (x,), strict=False))
+
+###############################################################################
+# Let's export again.
+
+onnx_program = torch.onnx.export(model, (x,), dynamo=True)
+print(onnx_program.model)
+
+
+###############################################################################
+# We can optimize the model and get rid of the model-local functions created to capture the control flow branches.
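+#
+# Before optimizing, we can sanity-check numerical parity between the exported
+# program and the eager model. This is a minimal sketch; it assumes
+# ``onnx_program`` is callable on tensors, as shown in the other tutorials in
+# this series.
+#
+# .. code-block:: python
+#
+#     for candidate in (x, -x):  # the branch taken depends on the intermediate sum
+#         torch.testing.assert_close(onnx_program(candidate)[0], model(candidate))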
+
+onnx_program.optimize()
+print(onnx_program.model)
+
+###############################################################################
+# Conclusion
+# ----------
+#
+# This tutorial demonstrates the challenges of exporting models with conditional
+# logic to ONNX and presents a practical solution using :func:`torch.cond`.
+# While the default exporters may fail or produce imperfect graphs, refactoring the
+# model's logic ensures compatibility and generates a faithful ONNX representation.
+#
+# By understanding these techniques, we can overcome common pitfalls when
+# working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
+#
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of your interest or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+#    :hidden:
+#
\ No newline at end of file
diff --git a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
index 5a8ac9c538..8de76aa705 100644
--- a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
+++ b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -2,18 +2,19 @@
 """
 `Introduction to ONNX `_ ||
 **Exporting a PyTorch model to ONNX** ||
-`Extending the ONNX Registry `_
+`Extending the ONNX exporter operator support `_ ||
+`Export a model with control flow to ONNX `_
 
 Export a PyTorch model to ONNX
 ==============================
 
-**Author**: `Ti-Tai Wang `_ and `Xavier Dupré `_
+**Author**: `Ti-Tai Wang `_, Justin Chu (justinchu@microsoft.com) and `Thiago Crepaldi `_.
 
 .. note::
-    As of PyTorch 2.1, there are two versions of ONNX Exporter.
+    As of PyTorch 2.5, there are two versions of ONNX Exporter.
 
-    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
-    * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
+    * ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter using ``torch.export`` and Torch FX to capture the graph. It was released with PyTorch 2.5
+    * ``torch.onnx.export`` uses TorchScript and has been available since PyTorch 1.2.0
 
 """
 
@@ -21,7 +22,7 @@
 # In the `60 Minute Blitz `_,
 # we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
 # In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
-# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
 #
 # While PyTorch is great for iterating on the development of models, the model can be deployed to production
 # using different formats, including `ONNX `_ (Open Neural Network Exchange)!
@@ -47,8 +48,7 @@
 #
 # .. code-block:: bash
 #
-#    pip install onnx
-#    pip install onnxscript
+#    pip install --upgrade onnx onnxscript
 #
 # 2. Author a simple image classifier model
 # -----------------------------------------
@@ -62,17 +62,16 @@
 import torch.nn.functional as F
 
 
-class MyModel(nn.Module):
-
+class ImageClassifierModel(nn.Module):
     def __init__(self):
-        super(MyModel, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.conv2 = nn.Conv2d(6, 16, 5)
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
         self.fc2 = nn.Linear(120, 84)
         self.fc3 = nn.Linear(84, 10)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
         x = torch.flatten(x, 1)
@@ -81,6 +80,7 @@ def forward(self, x):
         x = self.fc3(x)
         return x
 
+
 ######################################################################
 # 3. Export the model to ONNX format
 # ----------------------------------
@@ -88,9 +88,19 @@ def forward(self, x):
 # Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
 # Next, we can export the model to ONNX format.
 
-torch_model = MyModel()
-torch_input = torch.randn(1, 1, 32, 32)
-onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
+torch_model = ImageClassifierModel()
+# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
+example_inputs = (torch.randn(1, 1, 32, 32),)
+onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
+
+######################################################################
+# 3.5. (Optional) Optimize the ONNX model
+# ---------------------------------------
+#
+# The ONNX model can be optimized with constant folding and elimination of redundant nodes.
+# The optimization is done in-place, so the original ONNX model is modified.
+
+onnx_program.optimize()
 
 ######################################################################
 # As we can see, we didn't need any code change to the model.
@@ -102,13 +112,14 @@ def forward(self, x):
 # Although having the exported model loaded in memory is useful in many applications,
 # we can save it to disk with the following code:
 
-onnx_program.save("my_image_classifier.onnx")
+onnx_program.save("image_classifier_model.onnx")
 
 ######################################################################
 # You can load the ONNX file back into memory and check if it is well formed with the following code:
 
 import onnx
-onnx_model = onnx.load("my_image_classifier.onnx")
+
+onnx_model = onnx.load("image_classifier_model.onnx")
 onnx.checker.check_model(onnx_model)
 
 ######################################################################
@@ -124,7 +135,7 @@ def forward(self, x):
 #    :align: center
 #
 #
-# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
+# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after
 # clicking the **Open model** button.
 #
 # .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
@@ -155,16 +166,15 @@ def forward(self, x):
 
 import onnxruntime
 
-onnx_input = [torch_input]
-print(f"Input length: {len(onnx_input)}")
-print(f"Sample input: {onnx_input}")
+onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
+print(f"Input length: {len(onnx_inputs)}")
+print(f"Sample input: {onnx_inputs}")
 
-ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+ort_session = onnxruntime.InferenceSession(
+    "./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
+)
 
-def to_numpy(tensor):
-    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-
-onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}
 
 # onnxruntime returns a list of outputs
 onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
 
@@ -179,7 +189,7 @@ def to_numpy(tensor):
 # For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
 # Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.
 
-torch_outputs = torch_model(torch_input)
+torch_outputs = torch_model(*example_inputs)
 
 assert len(torch_outputs) == len(onnxruntime_outputs)
 for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
@@ -209,4 +219,4 @@
 #
 # .. toctree::
 #    :hidden:
-#
\ No newline at end of file
+#
diff --git a/beginner_source/onnx/intro_onnx.py b/beginner_source/onnx/intro_onnx.py
index 194f971261..7172461e9e 100644
--- a/beginner_source/onnx/intro_onnx.py
+++ b/beginner_source/onnx/intro_onnx.py
@@ -1,13 +1,14 @@
 """
 **Introduction to ONNX** ||
 `Exporting a PyTorch model to ONNX `_ ||
-`Extending the ONNX Registry `_
+`Extending the ONNX exporter operator support `_ ||
+`Export a model with control flow to ONNX `_
 
 Introduction to ONNX
 ====================
 
 Authors:
-`Ti-Tai Wang `_ and `Xavier Dupré `_
+`Ti-Tai Wang `_ and `Thiago Crepaldi `_.
 
 `Open Neural Network eXchange (ONNX) `_ is an open standard
 format for representing machine learning models. The ``torch.onnx`` module provides APIs to
@@ -19,21 +20,20 @@ including Microsoft's `ONNX Runtime `_.
 
 .. note::
-    Currently, there are two flavors of ONNX exporter APIs,
-    but this tutorial will focus on the ``torch.onnx.dynamo_export``.
+    Currently, you can choose either `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ or
+    `ExportedProgram <https://pytorch.org/docs/stable/export.html>`_ to export the model to ONNX via the
+    boolean parameter ``dynamo`` in `torch.onnx.export `_.
+    In this tutorial, we will focus on the ``ExportedProgram`` approach.
 
-The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
-bytecode into an `FX graph `_.
-The resulting FX Graph is polished before it is finally translated into an
-`ONNX graph `_.
-
-The main advantage of this approach is that the `FX graph `_ is captured using
-bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
+When setting ``dynamo=True``, the exporter will use `torch.export `_ to capture an ``ExportedProgram``,
+before translating the graph into ONNX representations. This approach is the new and recommended way to export models to ONNX.
+It works with PyTorch 2.0 features more robustly, has better support for newer ONNX opsets, and consumes fewer resources,
+making it possible to export larger models.
 
 Dependencies
 ------------
 
-PyTorch 2.1.0 or newer is required.
+PyTorch 2.5.0 or newer is required.
 
 The ONNX exporter depends on extra Python packages:
 
@@ -58,8 +58,6 @@
     import onnxscript
     print(onnxscript.__version__)
 
-    from onnxscript import opset18  # opset 18 is the latest (and only) supported version for now
-
     import onnxruntime
     print(onnxruntime.__version__)
 
@@ -78,4 +76,4 @@
 
     .. toctree::
        :hidden:
 
-"""
+"""
\ No newline at end of file
diff --git a/beginner_source/onnx/onnx_registry_tutorial.py b/beginner_source/onnx/onnx_registry_tutorial.py
index 56c1d0c99a..63b89675e9 100644
--- a/beginner_source/onnx/onnx_registry_tutorial.py
+++ b/beginner_source/onnx/onnx_registry_tutorial.py
@@ -1,14 +1,14 @@
 # -*- coding: utf-8 -*-
-
 """
 `Introduction to ONNX `_ ||
 `Exporting a PyTorch model to ONNX `_ ||
-**Extending the ONNX Registry**
+**Extending the ONNX exporter operator support** ||
+`Export a model with control flow to ONNX `_
 
-Extending the ONNX Registry
-===========================
+Extending the ONNX Exporter Operator Support
+============================================
 
-**Authors:** Ti-Tai Wang (titaiwang@microsoft.com)
+**Authors:** Ti-Tai Wang (titaiwang@microsoft.com), Justin Chu (justinchu@microsoft.com)
 """
 
@@ -16,288 +16,242 @@
 # Overview
 # --------
 #
-# This tutorial is an introduction to ONNX registry, which empowers users to implement new ONNX operators
-# or even replace existing operators with a new implementation.
+# This tutorial describes how you can create ONNX implementations for unsupported PyTorch operators
+# or replace existing implementations with your own.
+#
+# We will cover three scenarios that require extending the ONNX exporter's operator support:
+#
+# * Overriding the implementation of an existing PyTorch operator
+# * Using custom ONNX operators
+# * Supporting a custom PyTorch operator
+#
+# What you will learn:
+#
+# - How to override or add support for PyTorch operators in ONNX.
+# - How to integrate custom ONNX operators for specialized runtimes.
+# - How to implement and translate custom PyTorch operators to ONNX.
+#
+# Prerequisites
+# ~~~~~~~~~~~~~
+#
+# Before starting this tutorial, make sure you have completed the following prerequisites:
+#
+# * ``torch >= 2.6``
+# * The target PyTorch operator
+# * Completed the
+#   `ONNX Script tutorial `_
+#   before proceeding
+# * The implementation of the operator using `ONNX Script `__
+#
+# Overriding the implementation of an existing PyTorch operator
+# -------------------------------------------------------------
 #
-# During the model export to ONNX, the PyTorch model is lowered to an intermediate
-# representation composed of `ATen operators `_.
-# While ATen operators are maintained by PyTorch core team, it is the responsibility of the ONNX exporter team
-# to independently implement each of these operators to ONNX through `ONNX Script `_.
-# The users can also replace the behavior implemented by the ONNX exporter team with their own implementation
-# to fix bugs or improve performance for a specific ONNX runtime.
+# Although the ONNX exporter team does its best to support all PyTorch operators, some of them
+# might not be supported yet. In this section, we will demonstrate how you can add
+# unsupported PyTorch operators to the ONNX Registry.
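+#
+# For reference, ``torch.ops.aten.add.Tensor(self, other, alpha=1)`` computes
+# ``self + alpha * other``, so any replacement implementation has to honor the
+# ``alpha`` attribute. A quick eager-mode sketch of that contract (plain
+# PyTorch, nothing ONNX-specific):
+#
+# .. code-block:: python
+#
+#     a, b = torch.tensor([1.0]), torch.tensor([2.0])
+#     assert torch.equal(torch.ops.aten.add.Tensor(a, b, alpha=3.0), a + 3.0 * b)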
# -# The ONNX Registry manages the mapping between PyTorch operators and the ONNX operators counterparts and provides -# APIs to extend the registry. +# .. note:: +# The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existing +# PyTorch operator with a custom one. +# Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leverage +# this and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom implementation the same way we would +# if the operator was not implemented by the ONNX exporter. +# +# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message +# similar to: # -# In this tutorial, we will cover three scenarios that require extending the ONNX registry with custom operators: +# .. code-block:: python # -# * Custom operators with existing ONNX Runtime support -# * Custom operators without ONNX Runtime support +# No decompositions registered for [...] # +# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``. +# The operator is of type ````, and this operator is what we will use as the +# target to register our custom implementation. import torch -import onnxruntime import onnxscript -from onnxscript import opset18 # opset 18 is the latest (and only) supported version for now +# Opset 18 is the standard supported version as of PyTorch 2.6 +from onnxscript import opset18 as op + + +# Create a model that uses the operator torch.ops.aten.add.Tensor +class Model(torch.nn.Module): + def forward(self, input_x, input_y): + return torch.ops.aten.add.Tensor(input_x, input_y) + + +# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator. +# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml +# All attributes must be annotated with type hints. +def custom_aten_add(self, other, alpha: float = 1.0): + if alpha != 1.0: + alpha = op.CastLike(alpha, other) + other = op.Mul(other, alpha) + # To distinguish the custom implementation from the builtin one, we switch the order of the inputs + return op.Add(other, self) + + +x = torch.tensor([1.0]) +y = torch.tensor([2.0]) + +# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``. +onnx_program = torch.onnx.export( + Model().eval(), + (x, y), + dynamo=True, + custom_translation_table={ + torch.ops.aten.add.Tensor: custom_aten_add, + }, +) +# Optimize the ONNX graph to remove redundant nodes +onnx_program.optimize() ###################################################################### -# Custom operators with existing ONNX Runtime support -# --------------------------------------------------- -# -# In this case, the user creates a model with standard PyTorch operators, but the ONNX runtime -# (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the -# existing implementation in the ONNX Registry. Another use case is when the user wants to use a custom implementation -# of an existing ONNX operator to fix a bug or improve performance of a specific operator. -# To achieve this, we only need to register the new implementation with the existing ATen fully qualified name. -# -# In the following example, we use the ``com.microsoft.Gelu`` from ONNX Runtime, -# which is not the same ``Gelu`` from ONNX spec. 
-# the namespace ``com.microsoft`` and operator name ``Gelu``.
+# Now let's inspect the model and verify that it is using the custom implementation.
+
+print(onnx_program.model)
+
+######################################################################
+# The translation is using our custom implementation: in node ``node_Add_0``, ``input_y`` now
+# comes first, and ``input_x`` comes second.
 #
-# Before we begin, let's check whether ``aten::gelu.default`` is really supported by the ONNX registry.
+# We can use ONNX Runtime to run the model and verify the results by calling
+# the :class:`torch.onnx.ONNXProgram` directly on the input tensors.
 
-onnx_registry = torch.onnx.OnnxRegistry()
-print(f"aten::gelu.default is supported by ONNX registry: \
-      {onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}")
+result = onnx_program(x, y)[0]
+torch.testing.assert_close(result, torch.tensor([3.0]))
 
 ######################################################################
-# In our example, ``aten::gelu.default`` operator is supported by the ONNX registry,
-# so :meth:`onnx_registry.is_registered_op` returns ``True``.
+# Using custom ONNX operators
+# ---------------------------
+#
+# In this case, we create a model with standard PyTorch operators, but the runtime
+# (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
+# existing implementation.
+#
+# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,
+# which is not the same ``Gelu`` from the ONNX spec.
+
 
-class CustomGelu(torch.nn.Module):
+class GeluModel(torch.nn.Module):
     def forward(self, input_x):
         return torch.ops.aten.gelu(input_x)
 
-# com.microsoft is an official ONNX Runtime namspace
-custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)
 
-# NOTE: The function signature must match the signature of the unsupported ATen operator.
-# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
-# NOTE: All attributes must be annotated with type hints.
-@onnxscript.script(custom_ort)
-def custom_aten_gelu(input_x, approximate: str = "none"):
-    # We know com.microsoft::Gelu is supported by ONNX Runtime
-    # It's only not supported by ONNX
-    return custom_ort.Gelu(input_x)
+# Create a namespace for the custom operator using ONNX Script
+# ``com.microsoft`` is an official ONNX Runtime namespace
+microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
+
+# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+# The function must be scripted using the ``@onnxscript.script()`` decorator when
+# using operators from custom domains. This may be improved in future versions.
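+#
+# As a side note, ``@onnxscript.script()`` turns the Python function below into an
+# ``onnxscript.values.OnnxFunction`` that the exporter can serialize. A minimal
+# sketch of inspecting it (method name assumed from ONNX Script; verify against
+# your installed version):
+#
+# .. code-block:: python
+#
+#     print(type(custom_aten_gelu))                # onnxscript.values.OnnxFunction
+#     print(custom_aten_gelu.to_function_proto())  # the serialized FunctionProto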
+from onnxscript import FLOAT
+
+
+@onnxscript.script(microsoft_op)
+def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
+    return microsoft_op.Gelu(self)
+
 
-onnx_registry = torch.onnx.OnnxRegistry()
-onnx_registry.register_op(
-    namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu)
-export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
-aten_gelu_model = CustomGelu()
-input_gelu_x = torch.randn(3, 3)
+onnx_program = torch.onnx.export(
+    GeluModel().eval(),
+    (x,),
+    dynamo=True,
+    custom_translation_table={
+        torch.ops.aten.gelu.default: custom_aten_gelu,
+    },
+)
 
-onnx_program = torch.onnx.dynamo_export(
-    aten_gelu_model, input_gelu_x, export_options=export_options
-    )
+# Optimize the ONNX graph to remove redundant nodes
+onnx_program.optimize()
 
 ######################################################################
 # Let's inspect the model and verify the model uses op_type ``Gelu``
 # from namespace ``com.microsoft``.
 #
-# .. note::
-#     :func:`custom_aten_gelu` does not exist in the graph because
-#     functions with fewer than three operators are inlined automatically.
-#
-
-# graph node domain is the custom domain we registered
-assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
-# graph node name is the function name
-assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"
+print(onnx_program.model)
 
 ######################################################################
-# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron,
-# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function:
-#
-# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
-#
-# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
-# and compare the results with PyTorch.
-#
-
-onnx_program.save("./custom_gelu_model.onnx")
-ort_session = onnxruntime.InferenceSession(
-    "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
-    )
-
-def to_numpy(tensor):
-    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-
-onnx_input = [input_gelu_x]
-onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
-onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
+# Similar to the previous example, we can use ONNX Runtime to run the model and verify the results.
 
-torch_outputs = aten_gelu_model(input_gelu_x)
+result = onnx_program(x)[0]
+torch.testing.assert_close(result, torch.ops.aten.gelu(x))
 
-assert len(torch_outputs) == len(onnxruntime_outputs)
-for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
-    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
 
 ######################################################################
-# Custom operators without ONNX Runtime support
-# ---------------------------------------------
+# Supporting a custom PyTorch operator
+# ------------------------------------
 #
-# In this case, the operator is not supported by any ONNX runtime, but we
-# would like to use it as custom operator in ONNX graph. Therefore, we need to implement
-# the operator in three places:
-#
-# 1. PyTorch FX graph
-# 2. ONNX Registry
-# 3. ONNX Runtime
+# In this case, the operator is one that the user has implemented and registered with PyTorch.
 #
 # In the following example, we would like to use a custom operator
+# that takes one tensor input, and returns one output. The operator adds
+# the input to itself, and returns the rounded result.
 #
-#
-# Custom Ops Registration in PyTorch FX Graph (Beta)
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-#
-# Firstly, we need to implement the operator in PyTorch FX graph.
-# This can be done by using ``torch._custom_op``.
-#
-
-# NOTE: This is a beta feature in PyTorch, and is subject to change.
-from torch._custom_op import impl as custom_op
-
-@custom_op.custom_op("mylibrary::addandround_op")
-def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
-    ...
-
-@addandround_op.impl_abstract()
-def addandround_op_impl_abstract(tensor_x):
-    return torch.empty_like(tensor_x)
-
-@addandround_op.impl("cpu")
-def addandround_op_impl(tensor_x):
-    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result
+# Firstly, we assume the custom operator is implemented and registered with ``torch.library.custom_op()``.
+# You can refer to `Creating new custom ops in Python `_
+# for a detailed guide on how to create custom operators.
 
-torch._dynamo.allow_in_graph(addandround_op)
 
-class CustomFoo(torch.nn.Module):
-    def forward(self, tensor_x):
-        return addandround_op(tensor_x)
+# Define and use the operator in PyTorch
+@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
+def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
+    return torch.round(input + input)
 
-input_addandround_x = torch.randn(3)
-custom_addandround_model = CustomFoo()
 
-######################################################################
-#
-# Custom Ops Registration in ONNX Registry
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-#
-# For the step 2 and 3, we need to implement the operator in ONNX registry.
-# In this example, we will implement the operator in ONNX registry
-# with the namespace ``test.customop`` and operator name ``CustomOpOne``,
-# and ``CustomOpTwo``. These two ops are registered and built in
-# `cpu_ops.cc `__.
-#
+@add_and_round_op.register_fake
+def _add_and_round_op_fake(tensor_x):
+    return torch.empty_like(tensor_x)
 
-custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)
 
+class AddAndRoundModel(torch.nn.Module):
+    def forward(self, input):
+        return add_and_round_op(input)
 
-# NOTE: The function signature must match the signature of the unsupported ATen operator.
-# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
-# NOTE: All attributes must be annotated with type hints.
-@onnxscript.script(custom_opset)
-def custom_addandround(input_x):
-    # The same as opset18.Add(x, x)
-    add_x = custom_opset.CustomOpOne(input_x, input_x)
-    # The same as opset18.Round(x, x)
-    round_x = custom_opset.CustomOpTwo(add_x)
-    # Cast to FLOAT to match the ONNX type
-    return opset18.Cast(round_x, to=1)
+# Implement the custom operator in ONNX using ONNX Script
+def onnx_add_and_round(input):
+    return op.Round(op.Add(input, input))
 
-onnx_registry = torch.onnx.OnnxRegistry()
-onnx_registry.register_op(
-    namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
-    )
 
-export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
-onnx_program = torch.onnx.dynamo_export(
-    custom_addandround_model, input_addandround_x, export_options=export_options
-    )
-onnx_program.save("./custom_addandround_model.onnx")
+onnx_program = torch.onnx.export(
+    AddAndRoundModel().eval(),
+    (x,),
+    dynamo=True,
+    custom_translation_table={
+        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
+    },
+)
+# Optimize the ONNX graph to remove redundant nodes
+onnx_program.optimize()
+print(onnx_program)
 
 ######################################################################
-# The ``onnx_program`` exposes the exported model as protobuf through ``onnx_program.model_proto``.
-# The graph has one graph nodes for ``custom_addandround``, and inside ``custom_addandround``,
-# there are two function nodes, one for each operator.
+# The translation is using our custom implementation to translate the ``torch.ops.mylibrary.add_and_round_op.default``
+# operator in the :class:`torch.export.ExportedProgram` to the ONNX operators ``Add`` and ``Round``.
 #
 
-assert onnx_program.model_proto.graph.node[0].domain == "test.customop"
-assert onnx_program.model_proto.graph.node[0].op_type == "CustomOpOne"
-assert onnx_program.model_proto.graph.node[1].domain == "test.customop"
-assert onnx_program.model_proto.graph.node[1].op_type == "CustomOpTwo"
+######################################################################
+# Finally, we verify the results.
 
+result = onnx_program(x)[0]
+torch.testing.assert_close(result, add_and_round_op(x))
 
 ######################################################################
-# This is how ``custom_addandround_model`` ONNX graph looks using Netron.
-# We can see the two custom operators we used in the function (``CustomOpOne``, and ``CustomOpTwo``),
-# and they are from module ``test.customop``:
-#
-# .. image:: /_static/img/onnx/custom_addandround.png
-#
-# Custom Ops Registration in ONNX Runtime
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-#
-# To link your custom op library to ONNX Runtime, you need to
-# compile your C++ code into a shared library and link it to ONNX Runtime.
-# Follow the instructions below:
-#
-# 1. Implement your custom op in C++ by following
-#    `ONNX Runtime instructions <`https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__.
-# 2. Download ONNX Runtime source distribution from
-#    `ONNX Runtime releases `__.
-# 3. Compile and link your custom op library to ONNX Runtime, for example:
-#
-# .. code-block:: bash
-#
-#    $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC
-#
-# 4. Run the model with ONNX Runtime Python API and compare the results with PyTorch.
-#
-# .. code-block:: python
-#
-#    ort_session_options = onnxruntime.SessionOptions()
-#
-#    # NOTE: Link the custom op library to ONNX Runtime and replace the path
-#    # with the path to your custom op library
-#    ort_session_options.register_custom_ops_library(
-#        "/path/to/libcustom_op_library.so"
-#    )
-#    ort_session = onnxruntime.InferenceSession(
-#        "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)
-#
-#    def to_numpy(tensor):
-#        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-#
-#    onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x)
-#    onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
-#    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
-#
-#    torch_outputs = custom_addandround_model(input_addandround_x)
-#    torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)
-#
-#    assert len(torch_outputs) == len(onnxruntime_outputs)
-#    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
-#        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
-#
 # Conclusion
 # ----------
 #
-# Congratulations! In this tutorial, we explored the :class:`ONNXRegistry` API and
-# discovered how to create custom implementations for unsupported or existing ATen operators
+# Congratulations! In this tutorial, we explored the ``custom_translation_table`` option and
+# discovered how to create custom implementations for unsupported or existing PyTorch operators
# using ONNX Script.
+#
 # Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
 # providing us with a comprehensive understanding of handling unsupported
 # operators in the ONNX ecosystem.
diff --git a/beginner_source/onnx/onnx_toc.txt b/beginner_source/onnx/onnx_toc.txt
index 674f7752c5..ac293fbedd 100644
--- a/beginner_source/onnx/onnx_toc.txt
+++ b/beginner_source/onnx/onnx_toc.txt
@@ -1,2 +1,3 @@
 | 1. `Exporting a PyTorch model to ONNX `_
-| 2. `Extending the ONNX registry `_
+| 2. `Extending the ONNX exporter operator support `_
+| 3. `Export a model with control flow to ONNX `_
\ No newline at end of file