[ONNX] Update API to torch.onnx.export(..., dynamo=True) #3223

Open · wants to merge 26 commits into base: main

Changes from 15 commits (26 commits total):
54a48a5
Fix torchrl scripts for PT 2.6 TorchRL>=0.6 (#3199)
vmoens Dec 20, 2024
d9660ce
update api
titaiwangms Jan 8, 2025
b476c7e
add torch.cond
titaiwangms Jan 24, 2025
aa3fc9e
add torch.compiler.set_stance tutorial (#3225)
williamwen42 Jan 13, 2025
d3fec71
Revert "add torch.compiler.set_stance tutorial (#3225)" (#3231)
williamwen42 Jan 13, 2025
56c0006
update registry
titaiwangms Jan 24, 2025
7c77db6
fix
titaiwangms Jan 24, 2025
c29c22b
address formatting
titaiwangms Jan 24, 2025
6bca4e4
reformatting
titaiwangms Jan 24, 2025
6e416d1
words
titaiwangms Jan 27, 2025
846cf83
removed printout and algin titile format
titaiwangms Jan 27, 2025
1a2cc7a
refactor intro_onnx and simple_example
titaiwangms Jan 27, 2025
7dc050f
revert those in 2.6 but not yet cherry-picks
titaiwangms Jan 27, 2025
e104e01
add coding head
titaiwangms Jan 27, 2025
f620302
add space
titaiwangms Jan 27, 2025
eddc1a1
Merge branch 'main' into titaiwang/dynamo_true_api
svekars Jan 27, 2025
c33929a
Merge branch 'main' into titaiwang/dynamo_true_api
svekars Jan 27, 2025
159d6b0
Merge branch 'main' into titaiwang/dynamo_true_api
svekars Jan 28, 2025
3027ebe
address reviews
titaiwangms Jan 28, 2025
8adb7f3
Merge branch 'main' into titaiwang/dynamo_true_api
titaiwangms Jan 28, 2025
eb89a8a
Remove dot for consistency
svekars Jan 28, 2025
3a1d4a3
fix ci
titaiwangms Jan 28, 2025
1cf9731
Merge branch 'main' into titaiwang/dynamo_true_api
titaiwangms Jan 29, 2025
d3bb7e7
Merge branch 'main' into titaiwang/dynamo_true_api
svekars Feb 3, 2025
b64a434
fix misspelled words
titaiwangms Feb 3, 2025
d769747
Merge branch 'main' into titaiwang/dynamo_true_api
titaiwangms Feb 19, 2025
Binary file removed _static/img/onnx/custom_addandround.png
Binary file removed _static/img/onnx/custom_aten_gelu_model.png
6 changes: 5 additions & 1 deletion beginner_source/onnx/README.txt
@@ -10,5 +10,9 @@ ONNX
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

3. onnx_registry_tutorial.py
Extending the ONNX Registry
Extending the ONNX exporter operator support
https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html

4. export_control_flow_model_to_onnx_tutorial.py
Export a model with control flow to ONNX
https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html
172 changes: 172 additions & 0 deletions beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py
@@ -0,0 +1,172 @@
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
**Export a model with control flow to ONNX**

Export a model with control flow to ONNX
========================================

**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
"""


###############################################################################
# Overview
# --------
#
# This tutorial demonstrates how to handle control flow logic while exporting
# a PyTorch model to ONNX. It highlights the challenges of exporting
# conditional statements directly and provides solutions to circumvent them.
#
# Conditional logic cannot be exported to ONNX unless it is refactored
# to use :func:`torch.cond`. Let's start with a simple model
# that implements a conditional test.

import torch

###############################################################################
# Define the Models
# -----------------
#
# Two models are defined:
#
# * ``ForwardWithControlFlowTest``: a model whose forward method contains an
#   if-else conditional.
#
# * ``ModelWithControlFlowTest``: a model that incorporates
#   ``ForwardWithControlFlowTest`` as part of a simple multi-layer perceptron
#   (MLP). The models are tested with a random input tensor to confirm they
#   execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
def forward(self, x):
if x.sum():
return x * 2
return -x


class ModelWithControlFlowTest(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(3, 2),
torch.nn.Linear(2, 1),
ForwardWithControlFlowTest(),
)

def forward(self, x):
out = self.mlp(x)
return out


model = ModelWithControlFlowTest()


###############################################################################
# Exporting the Model: First Attempt
# ----------------------------------
#
# Exporting this model using torch.export.export fails because the control
# flow logic in the forward pass creates a graph break that the exporter cannot
# handle. This behavior is expected, as conditional logic not written using
# torch.cond is unsupported.
#
# A try-except block is used to capture the expected failure during the export
# process. If the export unexpectedly succeeds, an AssertionError is raised.

x = torch.randn(3)
model(x)

try:
torch.export.export(model, (x,), strict=False)
raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
print(e)

###############################################################################
# Using torch.onnx.export with JIT Tracing
# ----------------------------------------
#
# When exporting the model using ``torch.onnx.export`` with the ``dynamo=True``
# argument, the exporter falls back to JIT tracing when the graph cannot be
# captured with ``torch.export``. This fallback allows the model to export, but
# the resulting ONNX graph may not faithfully represent the original model
# logic, because tracing records only the branch taken for the example input.


onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
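The limitation of tracing can be reproduced with plain ``torch.jit.trace``, independently of the ONNX exporter. This is a sketch (``Gate`` is a hypothetical module mirroring ``ForwardWithControlFlowTest``): only the branch taken for the example input is recorded, so the traced module silently returns the wrong result for inputs that would take the other branch.

```python
import torch


class Gate(torch.nn.Module):
    # Hypothetical module with data-dependent control flow.
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x


gate = Gate()
pos, neg = torch.ones(3), -torch.ones(3)

# Tracing records only the branch taken for the example input (here: x * 2),
# and emits a TracerWarning about the data-dependent condition.
traced = torch.jit.trace(gate, (pos,))

print(gate(neg))    # eager:  tensor([1., 1., 1.])
print(traced(neg))  # traced: tensor([-2., -2., -2.]) -- wrong branch baked in
```

This is exactly the kind of silent divergence the ``torch.cond`` refactoring below avoids.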


###############################################################################
# Suggested Patch: Refactoring with torch.cond
# --------------------------------------------
#
# To make the control flow exportable, we refactor the forward method of
# ``ForwardWithControlFlowTest`` to use :func:`torch.cond`.
#
# Details of the refactoring:
#
# * Two helper functions (``identity2`` and ``neg``) implement the two branches
#   of the conditional logic.
# * ``torch.cond`` specifies the condition and the two branches, along with the
#   input arguments.
# * The updated forward method is then dynamically assigned to the
#   ``ForwardWithControlFlowTest`` instance within the model. A list of
#   submodules is printed to confirm the replacement.

def new_forward(x):
def identity2(x):
return x * 2

def neg(x):
return -x

return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("The list of submodules:")
for name, mod in model.named_modules():
print(name, type(mod))
if isinstance(mod, ForwardWithControlFlowTest):
mod.forward = new_forward

###############################################################################
# Let's see what the FX graph looks like.

print(torch.export.export(model, (x,), strict=False))

###############################################################################
# Let's export again.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)


###############################################################################
# We can optimize the model and get rid of the local functions that were created to capture the control flow branches.

onnx_program.optimize()
print(onnx_program.model)

###############################################################################
# Conclusion
# ----------
#
# This tutorial demonstrates the challenges of exporting models with conditional
# logic to ONNX and presents a practical solution using torch.cond.
# While the default exporters may fail or produce imperfect graphs, refactoring the
# model's logic ensures compatibility and generates a faithful ONNX representation.
#
# By understanding these techniques, we can overcome common pitfalls when
# working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
# :hidden:
66 changes: 38 additions & 28 deletions beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -2,26 +2,27 @@
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
**Exporting a PyTorch model to ONNX** ||
`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

Export a PyTorch model to ONNX
==============================

**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre>`_
**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_, Justin Chu (justinchu@microsoft.com) and `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.

.. note::
As of PyTorch 2.1, there are two versions of ONNX Exporter.
As of PyTorch 2.5, there are two versions of ONNX Exporter.

* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
* ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter using ``torch.export`` and Torch FX to capture the graph. It was released with PyTorch 2.5
* ``torch.onnx.export`` uses TorchScript and has been available since PyTorch 1.2.0

"""

###############################################################################
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
#
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
@@ -47,8 +48,7 @@
#
# .. code-block:: bash
#
# pip install onnx
# pip install onnxscript
# pip install --upgrade onnx onnxscript
#
# 2. Author a simple image classifier model
# -----------------------------------------
@@ -62,17 +62,16 @@
import torch.nn.functional as F


class MyModel(nn.Module):

class ImageClassifierModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
def forward(self, x: torch.Tensor):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
@@ -81,16 +80,27 @@ def forward(self, x):
x = self.fc3(x)
return x


######################################################################
# 3. Export the model to ONNX format
# ----------------------------------
#
# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
# Next, we can export the model to ONNX format.

torch_model = MyModel()
torch_input = torch.randn(1, 1, 32, 32)
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
torch_model = ImageClassifierModel()
# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
example_inputs = (torch.randn(1, 1, 32, 32),)
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)

######################################################################
# 3.5. (Optional) Optimize the ONNX model
# ---------------------------------------
#
# The ONNX model can be optimized with constant folding, and elimination of redundant nodes.
# The optimization is done in-place, so the original ONNX model is modified.

onnx_program.optimize()

######################################################################
# As we can see, we didn't need any code change to the model.
@@ -102,13 +112,14 @@ def forward(self, x):
# Although having the exported model loaded in memory is useful in many applications,
# we can save it to disk with the following code:

onnx_program.save("my_image_classifier.onnx")
onnx_program.save("image_classifier_model.onnx")

######################################################################
# You can load the ONNX file back into memory and check if it is well formed with the following code:

import onnx
onnx_model = onnx.load("my_image_classifier.onnx")

onnx_model = onnx.load("image_classifier_model.onnx")
onnx.checker.check_model(onnx_model)

######################################################################
Expand All @@ -124,7 +135,7 @@ def forward(self, x):
# :align: center
#
#
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after
# clicking the **Open model** button.
#
# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
@@ -155,16 +166,15 @@ def forward(self, x):

import onnxruntime

onnx_input = [torch_input]
print(f"Input length: {len(onnx_input)}")
print(f"Sample input: {onnx_input}")
onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
print(f"Input length: {len(onnx_inputs)}")
print(f"Sample input: {onnx_inputs}")

ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
ort_session = onnxruntime.InferenceSession(
"./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
)

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_input = {k.name: v for k, v in zip(ort_session.get_inputs(), onnx_inputs)}

# onnxruntime returns a list of outputs
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
@@ -179,7 +189,7 @@ def to_numpy(tensor):
# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
# Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.

torch_outputs = torch_model(torch_input)
torch_outputs = torch_model(*example_inputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
@@ -209,4 +219,4 @@ def to_numpy(tensor):
#
# .. toctree::
# :hidden:
#
#