diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index 7d87331481..ed790a1887 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -50,7 +50,6 @@
     "intermediate_source/flask_rest_api_tutorial",
     "intermediate_source/text_to_speech_with_torchaudio",
     "intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
-    "intermediate_source/torch_export_tutorial", # reenable after 2940 is fixed.
 ]
 
 def tutorial_source_dirs() -> List[Path]:
diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py
index c992eefa9f..3ca6d09a52 100644
--- a/intermediate_source/torch_export_tutorial.py
+++ b/intermediate_source/torch_export_tutorial.py
@@ -45,17 +45,18 @@
 # .. code-block:: python
 #
 #     export(
-#         f: Callable,
+#         mod: torch.nn.Module,
 #         args: Tuple[Any, ...],
 #         kwargs: Optional[Dict[str, Any]] = None,
 #         *,
 #         dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
 #     ) -> ExportedProgram
 #
-# ``torch.export.export()`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
+# ``torch.export.export()`` traces the tensor computation graph from calling ``mod(*args, **kwargs)``
 # and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
-# different inputs. Note that while the output ``ExportedGraph`` is callable and can be
-# called in the same way as the original input callable, it is not a ``torch.nn.Module``.
+# different inputs. To execute the ``ExportedProgram``, we can call ``.module()``
+# on it to return a ``torch.nn.Module`` that is callable, just like the
+# original program.
 # We will detail the ``dynamic_shapes`` argument later in the tutorial.
 
 import torch
@@ -80,30 +81,15 @@ def forward(self, x, y):
 #
 # The ``graph`` attribute is an `FX graph `__
 # traced from the function we exported, that is, the computation graph of all PyTorch operations.
-# The FX graph has some important properties:
+# The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.
 #
-# - The operations are "ATen-level" operations.
-# - The graph is "functionalized", meaning that no operations are mutations.
+# The ``graph_signature`` attribute gives a more detailed description of the
+# input and output nodes in the exported graph, describing which ones are
+# parameters, buffers, user inputs, or user outputs.
 #
-# The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
-# so that it can be ran as a ``torch.nn.Module``.
+# The ``range_constraints`` attribute will be covered later.
 
 print(exported_mod)
-print(exported_mod.graph_module)
-
-######################################################################
-# The printed code shows that FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
-# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
-# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
-# Future uses of input to the original mutating ``relu`` op are replaced by the additional new output
-# of the replacement non-mutating ``relu`` op.
-#
-# Other attributes of interest in ``ExportedProgram`` include:
-#
-# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
-# - ``range_constraints`` -- constraints, covered later
-
-print(exported_mod.graph_signature)
 
 ######################################################################
 # See the ``torch.export`` `documentation `__
@@ -163,32 +149,16 @@ def forward(self, x):
 except Exception:
     tb.print_exc()
 
-######################################################################
-# - unsupported Python language features (e.g. throwing exceptions, match statements)
-
-class Bad4(torch.nn.Module):
-    def forward(self, x):
-        try:
-            x = x + 1
-            raise RuntimeError("bad")
-        except:
-            x = x + 2
-        return x
-
-try:
-    export(Bad4(), (torch.randn(3, 3),))
-except Exception:
-    tb.print_exc()
 
 ######################################################################
 # Non-Strict Export
 # -----------------
 #
-# To trace the program, ``torch.export`` uses TorchDynamo, a byte code analysis
-# engine, to symbolically analyze the Python code and build a graph based on the
-# results. This analysis allows ``torch.export`` to provide stronger guarantees
-# about safety, but not all Python code is supported, causing these graph
-# breaks.
+# To trace the program, ``torch.export`` by default uses TorchDynamo, a
+# bytecode analysis engine, to symbolically analyze the Python code and build a
+# graph based on the results. This analysis allows ``torch.export`` to provide
+# stronger guarantees about safety, but not all Python code is supported,
+# causing these graph breaks.
 #
 # To address this issue, in PyTorch 2.3, we introduced a new mode of
 # exporting called non-strict mode, where we trace through the program using the
@@ -197,16 +167,6 @@ def forward(self, x):
 # ``strict=False`` flag.
 #
 # Looking at some of the previous examples which resulted in graph breaks:
-#
-# - Accessing tensor data with ``.data`` now works correctly
-
-class Bad2(torch.nn.Module):
-    def forward(self, x):
-        x.data[0, 0] = 3
-        return x
-
-bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
-print(bad2_nonstrict.module()(torch.ones(3, 3)))
 
 ######################################################################
 # - Calling unsupported functions (such as many built-in functions) traces
@@ -223,22 +183,6 @@ def forward(self, x):
 print(bad3_nonstrict)
 print(bad3_nonstrict.module()(torch.ones(3, 3)))
 
-######################################################################
-# - Unsupported Python language features (such as throwing exceptions, match
-# statements) now also get traced through.
-
-class Bad4(torch.nn.Module):
-    def forward(self, x):
-        try:
-            x = x + 1
-            raise RuntimeError("bad")
-        except:
-            x = x + 2
-        return x
-
-bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
-print(bad4_nonstrict.module()(torch.ones(3, 3)))
-
 ######################################################################
 # However, there are still some features that require rewrites to the original
@@ -252,17 +196,16 @@ def forward(self, x):
 # But these need to be expressed using control flow ops. For example,
 # we can fix the control flow example above using the ``cond`` op, like so:
 
-from functorch.experimental.control_flow import cond
-
 class Bad1Fixed(torch.nn.Module):
     def forward(self, x):
         def true_fn(x):
             return torch.sin(x)
         def false_fn(x):
             return torch.cos(x)
-        return cond(x.sum() > 0, true_fn, false_fn, [x])
+        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
 
 exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
+print(exported_bad1_fixed)
 print(exported_bad1_fixed.module()(torch.ones(3, 3)))
 print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
@@ -280,25 +223,27 @@ def false_fn(x):
 # For more details about ``cond``, check out the `cond documentation `__.
 
 ######################################################################
-# ..
-#     [NOTE] map is not documented at the moment
-#     We can also use ``map``, which applies a function across the first dimension
-#     of the first tensor argument.
-#
-#     from functorch.experimental.control_flow import map
-#
-#     def map_example(xs):
-#         def map_fn(x, const):
-#             def true_fn(x):
-#                 return x + const
-#             def false_fn(x):
-#                 return x - const
-#             return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
-#         return control_flow.map(map_fn, xs, torch.tensor([2.0]))
-#
-#     exported_map_example= export(map_example, (torch.randn(4, 3),))
-#     inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
-#     print(exported_map_example(inp))
+# We can also use ``map``, which applies a function across the first dimension
+# of the first tensor argument.
+
+from torch._higher_order_ops.map import map as torch_map
+
+class MapModule(torch.nn.Module):
+    def forward(self, xs, y, z):
+        def body(x, y, z):
+            return x + y + z
+
+        return torch_map(body, xs, y, z)
+
+inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
+exported_map_example = export(MapModule(), inps)
+print(exported_map_example)
+print(exported_map_example.module()(*inps))
+
+######################################################################
+# Other control flow ops include ``while_loop``, ``associative_scan``, and
+# ``scan``. For more documentation on each operator, please refer to
+# `this page `__.
 
 ######################################################################
 # Constraints/Dynamic Shapes
@@ -337,7 +282,10 @@ def forward(
 model = DynamicModel()
 ep = export(model, (w, x, y, z))
 model(w, x, torch.randn(3, 4), torch.randn(12))
-ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+try:
+    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Basic concepts: symbols and guards
@@ -466,7 +414,10 @@ def forward(
 # static guard is emitted on a dynamically-marked dimension:
 
 dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
@@ -476,7 +427,10 @@ def forward(
 dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
 dynamic_shapes["x"] = (Dim.STATIC,)
 dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
@@ -494,7 +448,7 @@ def __init__(self):
 
     def forward(self, w, x, y, z):
         assert w.shape[0] <= 512
-        torch._check(x.shape[0] >= 16)
+        torch._check(x.shape[0] >= 4)
         if w.shape[0] == x.shape[0] + 2:
             x0 = x + y
             x1 = self.l(w)
@@ -510,8 +464,10 @@ def forward(self, w, x, y, z):
     "y": (Dim.AUTO, Dim.AUTO),
     "z": (Dim.AUTO,),
 }
-ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
-print(ep)
+try:
+    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``,
@@ -540,7 +496,10 @@ def forward(self, w, x, y, z):
         "input": (Dim.AUTO, Dim.STATIC),
     },
 )
-ep.module()(torch.randn(2, 4))
+try:
+    ep.module()(torch.randn(2, 4))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Named Dims
@@ -594,14 +553,17 @@ def forward(self, x, y):
         return w + torch.ones(4)
 
 dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
-ep = export(
-    Foo(),
-    (torch.randn(6, 4), torch.randn(6, 4)),
-    dynamic_shapes={
-        "x": (dx, d1),
-        "y": (dy, d1),
-    },
-)
+try:
+    ep = export(
+        Foo(),
+        (torch.randn(6, 4), torch.randn(6, 4)),
+        dynamic_shapes={
+            "x": (dx, d1),
+            "y": (dy, d1),
+        },
+    )
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
@@ -743,7 +705,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps)
+try:
+    export(Foo(), inps)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
@@ -755,7 +720,7 @@ class Foo(torch.nn.Module):
     def forward(self, x, y):
         a = x.item()
         torch._check(a >= 0)
-        torch._check(a <= y.shape[0])
+        torch._check(a < y.shape[0])
         return y[a]
 
 inps = (
@@ -787,7 +752,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps, strict=False)
+try:
+    export(Foo(), inps, strict=False)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # For these errors, some basic options you have are:
@@ -818,28 +786,26 @@ def forward(self, x, y):
 # Custom Ops
 # ----------
 #
-# ``torch.export`` can export PyTorch programs with custom operators.
-#
-# Currently, the steps to register a custom op for use by ``torch.export`` are:
+# ``torch.export`` can export PyTorch programs with custom operators. Please
+# refer to `this page `__
+# on how to author a custom operator in either C++ or Python.
 #
-# - Define the custom op using ``torch.library`` (`reference `__)
-# as with any other custom op
+# The following is an example of registering a custom operator in Python to be
+# used by ``torch.export``. The important thing to note is that the custom op
+# must have a `FakeTensor kernel `__.
 
 @torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
-def custom_op(input: torch.Tensor) -> torch.Tensor:
+def custom_op(x: torch.Tensor) -> torch.Tensor:
     print("custom_op called!")
     return torch.relu(x)
 
-######################################################################
-# - Define a ``"Meta"`` implementation of the custom op that returns an empty
-# tensor with the same shape as the expected output
-
-@custom_op.register_fake
+@custom_op.register_fake
 def custom_op_meta(x):
+    # Returns an empty tensor with the same shape as the expected output
     return torch.empty_like(x)
 
 ######################################################################
-# - Call the custom op from the code you want to export using ``torch.ops``
+# Here is an example of exporting a program with the custom op.
 
 class CustomOpExample(torch.nn.Module):
     def forward(self, x):
@@ -848,30 +814,27 @@ def forward(self, x):
         x = torch.cos(x)
         return x
 
-######################################################################
-# - Export the code as before
-
 exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
-exported_custom_op_example.graph_module.print_readable()
+print(exported_custom_op_example)
 print(exported_custom_op_example.module()(torch.randn(3, 3)))
 
 ######################################################################
-# Note in the above outputs that the custom op is included in the exported graph.
-# And when we call the exported graph as a function, the original custom op is called,
-# as evidenced by the ``print`` call.
-#
-# If you have a custom operator implemented in C++, please refer to
-# `this document `__
-# to make it compatible with ``torch.export``.
+# Note that in the ``ExportedProgram``, the custom operator is included in the graph.
 
 ######################################################################
-# Decompositions
-# --------------
+# IR/Decompositions
+# -----------------
 #
-# The graph produced by ``torch.export`` by default returns a graph containing
-# only functional ATen operators. This functional ATen operator set (or "opset") contains around 2000
-# operators, all of which are functional, that is, they do not
-# mutate or alias inputs. You can find a list of all ATen operators
+# The graph produced by ``torch.export`` contains only
+# `ATen operators `__, which are the
+# basic unit of computation in PyTorch. As there are over 3000 ATen operators,
+# export provides a way to narrow down the operator set used in the graph based
+# on certain characteristics, creating different IRs.
+#
+# By default, export produces the most generic IR which contains all ATen
+# operators, including both functional and non-functional operators. A functional
+# operator is one that does not contain any mutations or aliasing of the inputs.
+# You can find a list of all ATen operators
 # `here `__
 # and you can inspect if an operator is functional by checking
 # ``op._schema.is_mutable``, for example:
@@ -880,77 +843,78 @@ def forward(self, x):
 print(torch.ops.aten.add.Tensor._schema.is_mutable)
 print(torch.ops.aten.add_.Tensor._schema.is_mutable)
 
 ######################################################################
-# By default, the environment in which you want to run the exported graph
-# should support all ~2000 of these operators.
-# However, you can use the following API on the exported program
-# if your specific environment is only able to support a subset of
-# the ~2000 operators.
-#
-# .. code-block:: python
-#
-#     def run_decompositions(
-#         self: ExportedProgram,
-#         decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
-#     ) -> ExportedProgram
-#
-# ``run_decompositions`` takes in a decomposition table, which is a mapping of
-# operators to a function specifying how to reduce, or decompose, that operator
-# into an equivalent sequence of other ATen operators.
-#
-# The default decomposition table for ``run_decompositions`` is the
-# `Core ATen decomposition table `__
-# which will decompose the all ATen operators to the
-# `Core ATen Operator Set `__
-# which consists of only ~180 operators.
+# This generic IR can be used to train in eager PyTorch Autograd. This IR can be
+# more explicitly reached through the API ``torch.export.export_for_training``,
+# which was introduced in PyTorch 2.5, but calling ``torch.export.export``
+# should produce the same graph as of PyTorch 2.6.
 
-class M(torch.nn.Module):
-    def __init__(self):
+class DecompExample(torch.nn.Module):
+    def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(3, 4)
+        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
+        self.bn = torch.nn.BatchNorm2d(3)
 
     def forward(self, x):
-        return self.linear(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return (x,)
+
+ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
+print(ep_for_training.graph)
+
+######################################################################
+# We can then lower this exported program to an operator set which only contains
+# functional ATen operators through the API ``run_decompositions``, which
+# decomposes the ATen operators into the ones specified in the decomposition
+# table, and functionalizes the graph. By specifying an empty set, we only
+# perform functionalization, and do not apply any additional decompositions.
+# This results in an IR which contains ~2000 operators (instead of the 3000
+# operators above), and is ideal for inference cases.
 
-ep = export(M(), (torch.randn(2, 3),))
-print(ep.graph)
+ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
+print(ep_for_inference.graph)
 
-core_ir_ep = ep.run_decompositions()
-print(core_ir_ep.graph)
+######################################################################
+# As we can see, the previously mutable operator
+# ``torch.ops.aten.add_.default`` has now been replaced with
+# ``torch.ops.aten.add.default``, a functional operator.
 
 ######################################################################
-# Notice that after running ``run_decompositions`` the
-# ``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
-# Opset, has been replaced with ``torch.ops.aten.permute.default`` which is part
-# of the Core ATen Opset.
-#
-# Most ATen operators already have decompositions, which are located
-# `here `__.
-# If you would like to use some of these existing decomposition functions,
-# you can pass in a list of operators you would like to decompose to the
-# `get_decompositions `__
-# function, which will return a decomposition table using existing
-# decomposition implementations.
+# We can also further lower this exported program to an operator set which only
+# contains the
+# `Core ATen Operator Set `__,
+# which is a collection of only ~180 operators. This IR is optimal for backends
+# that do not want to reimplement all ATen operators.
 
-class M(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear = torch.nn.Linear(3, 4)
+from torch.export import default_decompositions
 
-    def forward(self, x):
-        return self.linear(x)
+core_aten_decomp_table = default_decompositions()
+core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
+print(core_aten_ep.graph)
+
+######################################################################
+# We now see that ``torch.ops.aten.conv2d.default`` has been decomposed
+# into ``torch.ops.aten.convolution.default``. This is because ``convolution``
+# is a more "core" operator, as operations like ``conv1d`` and ``conv2d`` can be
+# implemented using the same op.
+
+######################################################################
+# We can also specify our own decomposition behaviors:
+
+my_decomp_table = torch.export.default_decompositions()
 
-ep = export(M(), (torch.randn(2, 3),))
-print(ep.graph)
+def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
+    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
 
-from torch._decomp import get_decompositions
-decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
-core_ir_ep = ep.run_decompositions(decomp_table)
-print(core_ir_ep.graph)
+my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
+my_ep = ep_for_training.run_decompositions(my_decomp_table)
+print(my_ep.graph)
 
 ######################################################################
-# If there is no existing decomposition function for an ATen operator that you would
-# like to decompose, feel free to send a pull request into PyTorch
-# implementing the decomposition!
+# Notice that instead of ``torch.ops.aten.conv2d.default`` being decomposed
+# into ``torch.ops.aten.convolution.default``, it is now decomposed into
+# ``torch.ops.aten.convolution.default`` and ``torch.ops.aten.mul.Tensor``,
+# which matches our custom decomposition rule.
 
 ######################################################################
 # ExportDB
@@ -1024,18 +988,18 @@ def forward(self, x):
 ######################################################################
 # .. code-block:: python
 #
-#     import torch._export
 #     import torch._inductor
 #
 #     # Note: these APIs are subject to change
-#     # Compile the exported program to a .so using ``AOTInductor``
+#     # Compile the exported program to a PT2 archive using ``AOTInductor``
 #     with torch.no_grad():
-#         so_path = torch._inductor.aot_compile(ep.module(), [inp])
+#         pt2_path = torch._inductor.aoti_compile_and_package(ep)
 #
-#     # Load and run the .so file in Python.
+#     # Load and run the PT2 archive in Python.
 #     # To load and run it in a C++ environment, see:
 #     # https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
-#     res = torch._export.aot_load(so_path, device="cuda")(inp)
+#     aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
+#     res = aoti_compiled(inp)
 
 ######################################################################
 # Conclusion