From 9edb9fb59243697beb98b8e752bf3981bbcb485f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 6 Jan 2025 07:58:02 -0800 Subject: [PATCH 1/4] new branch --- intermediate_source/torch_export_tutorial.py | 471 +++++++++++-------- 1 file changed, 274 insertions(+), 197 deletions(-) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index dc5e226f86..0450bae482 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -3,7 +3,7 @@ """ torch.export Tutorial =================================================== -**Author:** William Wen, Zhengxu Chen, Angela Yi +**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan """ ###################################################################### @@ -11,7 +11,7 @@ # .. warning:: # # ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility -# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3. +# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.5. # # :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into # standardized model representations, intended @@ -190,7 +190,7 @@ def forward(self, x): # about safety, but not all Python code is supported, causing these graph # breaks. # -# To address this issue, in PyTorch 2.3, we introduced a new mode of +# To address this issue, in PyTorch 2.5, we introduced a new mode of # exporting called non-strict mode, where we trace through the program using the # Python interpreter executing it exactly as it would in eager mode, allowing us # to skip over unsupported Python features. This is done through adding a @@ -304,237 +304,314 @@ def false_fn(x): # Constraints/Dynamic Shapes # -------------------------- # -# Ops can have different specializations/behaviors for different tensor shapes, so by default, -# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective -# example inputs given to the initial ``torch.export.export()`` call. -# If we try to run the ``ExportedProgram`` in the example below with a tensor -# with a different shape, we get an error: +# This section covers dynamic behavior and representation of exported programs. Dynamic behavior is +# subjective to the particular model being exported, so for the most part of this tutorial, we'll focus +# on this particular toy model (with the sample input shapes annotated): -class MyModule2(torch.nn.Module): +class DynamicModel(torch.nn.Module): def __init__(self): super().__init__() - self.lin = torch.nn.Linear(100, 10) + self.l = torch.nn.Linear(5, 3) + + def forward( + self, + w: torch.Tensor, # [6, 5] + x: torch.Tensor, # [4] + y: torch.Tensor, # [8, 4] + z: torch.Tensor, # [32] + ): + x0 = x + y # output shape: [8, 4] + x1 = self.l(w) # [6, 3] + x2 = x0.flatten() # [32] + x3 = x2 + z # [32] + return x1, x3 + +###################################################################### +# By default, ``torch.export`` produces a static program. One clear consequence of this is that at runtime, +# the program won't work on inputs with different shapes, even if they're valid in eager mode. 
+ +w = torch.randn(6, 5) +x = torch.randn(4) +y = torch.randn(8, 4) +z = torch.randn(32) +model = DynamicModel() +ep = export(model, (w, x, y, z)) +model(w, x, torch.randn(3, 4), torch.randn(12)) +ep.module()(w, x, torch.randn(3, 4), torch.randn(12)) + +###################################################################### +# To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with +# dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic behavior is specified +# at a input dimension-level; for each input we can specify a tuple of values: + +from torch.export.dynamic_shapes import Dim + +dynamic_shapes = { + "w": (Dim.AUTO, Dim.AUTO), + "x": (Dim.AUTO,), + "y": (Dim.AUTO, Dim.AUTO), + "z": (Dim.AUTO,), +} +ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) - def forward(self, x, y): - return torch.nn.functional.relu(self.lin(x + y), inplace=True) +###################################################################### +# Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails, +# and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is +# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the +# 0/1 specialization section). +# +# Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit +# what's called "guards"; basically boolean condition that are required to be true for the program to be valid. +# When guards involve symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid; +# i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards +# and producing a final program representation that adheres to all of these guards. Before we see this "final representation" in +# an ``ExportedProgram``, let's look at the guards emitted by the toy model we're tracing. +# +# Here, each forward input tensor is annotated with the symbol allocated at the start of tracing: -mod2 = MyModule2() -exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100))) +class DynamicModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = torch.nn.Linear(5, 3) -try: - exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100)) -except Exception: - tb.print_exc() + def forward( + self, + w: torch.Tensor, # [s0, s1] + x: torch.Tensor, # [s2] + y: torch.Tensor, # [s3, s4] + z: torch.Tensor, # [s5] + ): + x0 = x + y # guard: s2 == s4 + x1 = self.l(w) # guard: s1 == 5 + x2 = x0.flatten() + x3 = x2 + z # guard: s3 * s4 == s5 + return x1, x3 ###################################################################### -# We can relax this constraint using the ``dynamic_shapes`` argument of -# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim`` -# (`documentation `__), -# which dimensions of the input tensors are dynamic. +# Let's understand each of the operations and the emitted guards: # -# For each tensor argument of the input callable, we can specify a mapping from the dimension -# to a ``torch.export.Dim``. -# A ``torch.export.Dim`` is essentially a named symbolic integer with optional -# minimum and maximum bounds. +# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor. 
``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``. +# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is considered static, and so this is a matmul between a dynamic input (``w: [s0, s1]``), and a statically-shaped tensor. This emits the guard ``s1 == 5``. +# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes) +# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``. # -# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping -# from the input callable's tensor argument names, to dimension --> dim mappings as described above. -# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is -# assumed to be static. +# Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes +# subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid: # -# The first argument of ``torch.export.Dim`` is the name for the symbolic integer, used for debugging. -# Then we can specify an optional minimum and maximum bound (inclusive). Below, we show a usage example. +# - ``w: [s0, 5]`` +# - ``x: [s2]`` +# - ``y: [s3, s2]`` +# - ``z: [s2*s3]`` # -# In the example below, our input -# ``inp1`` has an unconstrained first dimension, but the size of the second -# dimension must be in the interval [4, 18]. - -from torch.export import Dim - -inp1 = torch.randn(10, 10, 2) +# And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the +# corresponding inputs: -class DynamicShapesExample1(torch.nn.Module): - def forward(self, x): - x = x[:, 2:] - return torch.relu(x) - -inp1_dim0 = Dim("inp1_dim0") -inp1_dim1 = Dim("inp1_dim1", min=4, max=18) -dynamic_shapes1 = { - "x": {0: inp1_dim0, 1: inp1_dim1}, -} - -exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1) - -print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2))) - -try: - exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2)) -except Exception: - tb.print_exc() - -try: - exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2)) -except Exception: - tb.print_exc() - -try: - exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3)) -except Exception: - tb.print_exc() +print(ep) ###################################################################### -# Note that if our example inputs to ``torch.export`` do not satisfy the constraints -# given by ``dynamic_shapes``, then we get an error. - -inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18) -dynamic_shapes1_bad = { - "x": {0: inp1_dim0, 1: inp1_dim1_bad}, -} +# Another feature to notice is the range_constraints field above, which contains a valid range for each symbol. This isn't +# so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has +# a generic bound, but this will come up later. +# +# So far, because we've been exporting this toy model, this experience has not been representative of how hard +# it typically is to debug dynamic shapes guards & issues. 
In most cases it isn't obvious what guards are being emitted, +# and which operations and parts of user code are responsible. For this toy model we pinpoint the exact lines, and the guards +# are rather intuitive. +# +# In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment +# variable ``TORCH_LOGS="+dynamic"``, or interactively with ``torch._logging.set_logs(dynamic=10)``: -try: - export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad) -except Exception: - tb.print_exc() +torch._logging.set_logs(dynamic=10) +ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) ###################################################################### -# We can enforce that equalities between dimensions of different tensors -# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication: - -inp2 = torch.randn(4, 8) -inp3 = torch.randn(8, 2) - -class DynamicShapesExample2(torch.nn.Module): - def forward(self, x, y): - return x @ y - -inp2_dim0 = Dim("inp2_dim0") -inner_dim = Dim("inner_dim") -inp3_dim1 = Dim("inp3_dim1") - -dynamic_shapes2 = { - "x": {0: inp2_dim0, 1: inner_dim}, - "y": {0: inner_dim, 1: inp3_dim1}, -} - -exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2) - -print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4))) +# This spits out quite a handful, even with this simple toy model. But looking through the logs we can see the lines relevant +# to what we described above; e.g. the allocation of symbols: -try: - exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2)) -except Exception: - tb.print_exc() +""" +I1210 16:20:19.720000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +I1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +V1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:6535] [1/0] runtime_assert True == True [statically known] +I1210 16:20:19.727000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +I1210 16:20:19.729000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +I1210 16:20:19.731000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +I1210 16:20:19.734000 3417744 
torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +""" ###################################################################### -# We can also describe one dimension in terms of other. There are some -# restrictions to how detailed we can specify one dimension in terms of another, -# but generally, those in the form of ``A * Dim + B`` should work. - -class DerivedDimExample1(torch.nn.Module): - def forward(self, x, y): - return x + y[1:] - -foo = DerivedDimExample1() - -x, y = torch.randn(5), torch.randn(6) -dimx = torch.export.Dim("dimx", min=3, max=6) -dimy = dimx + 1 -derived_dynamic_shapes1 = ({0: dimx}, {0: dimy}) - -derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1) - -print(derived_dim_example1.module()(torch.randn(4), torch.randn(5))) - -try: - derived_dim_example1.module()(torch.randn(4), torch.randn(6)) -except Exception: - tb.print_exc() - - -class DerivedDimExample2(torch.nn.Module): - def forward(self, z, y): - return z[1:] + y[1::3] - -foo = DerivedDimExample2() - -z, y = torch.randn(4), torch.randn(10) -dx = torch.export.Dim("dx", min=3, max=6) -dz = dx + 1 -dy = dx * 3 + 1 -derived_dynamic_shapes2 = ({0: dz}, {0: dy}) +# Or the guards emitted: -derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2) -print(derived_dim_example2.module()(torch.randn(7), torch.randn(19))) +""" +I1210 16:20:19.743000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" +I1210 16:20:19.754000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" +I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" +""" ###################################################################### -# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints -# are necessary. We can do this by relaxing all constraints (recall that if we -# do not provide constraints for a dimension, the default behavior is to constrain -# to the exact shape value of the example input) and letting ``torch.export`` -# error out. - -inp4 = torch.randn(8, 16) -inp5 = torch.randn(16, 32) +# Next to the ``[guard added]`` messages, we also see the responsible user lines of code - luckily here the model is simple enough. +# In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations +# or operator decompositions that complicate where and what guards are emitted. 
In such cases the best way to dig deeper and investigate +# is to follow the logs' suggestion, and re-run with environment variable ``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."``, to further +# attribute the guard of interest. +# +# ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing this 2 other options are available: +# ``Dim.DYNAMIC``, and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all +# ways except one: it raises an error when specializing to a constant; designed to maintain dynamism. See for example what happens when a +# static guard is emitted on a dynamically-marked dimension: -class DynamicShapesExample3(torch.nn.Module): - def forward(self, x, y): - if x.shape[0] <= 16: - return x @ y[:, :16] - return y +dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC) +export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) -dynamic_shapes3 = { - "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())}, - "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())}, -} +###################################################################### +# Static guards also aren't always inherent to the model; they can also come from user-specifications. In fact, a common pitfall leading to shape +# specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is +# raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``: -try: - export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3) -except Exception: - tb.print_exc() +dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO) +dynamic_shapes["x"] = (Dim.STATIC,) +dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC) +export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) ###################################################################### -# We can see that the error message gives us suggested fixes to our -# dynamic shape constraints. Let us follow those suggestions (exact -# suggestions may differ slightly): - -def suggested_fixes(): - inp4_dim1 = Dim('shared_dim') - # suggested fixes below - inp4_dim0 = Dim('inp4_dim0', max=16) - inp5_dim1 = Dim('inp5_dim1', min=17) - inp5_dim0 = inp4_dim1 - # end of suggested fixes - return { - "x": {0: inp4_dim0, 1: inp4_dim1}, - "y": {0: inp5_dim0, 1: inp5_dim1}, - } +# Here you might ask why export "specializes"; why we resolve this static/dynamic conflict by going with the static route. The answer is because +# of the symbolic shapes system described above, of symbols and guards. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile +# treating this shape as a concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to +# specialization. +# +# One feature of export is that during tracing, statements like asserts, ``torch._checks()``, and ``if/else`` conditions will also emit guards. 
+# See what happens when we augment the existing model with such statements: -dynamic_shapes3_fixed = suggested_fixes() -exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed) -print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64))) +class DynamicModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = torch.nn.Linear(5, 3) + + def forward(self, w, x, y, z): + assert w.shape[0] <= 512 + torch._check(x.shape[0] >= 16) + if w.shape[0] == x.shape[0] + 2: + x0 = x + y + x1 = self.l(w) + x2 = x0.flatten() + x3 = x2 + z + return x1, x3 + else: + return w + +dynamic_shapes = { + "w": (Dim.AUTO, Dim.AUTO), + "x": (Dim.AUTO,), + "y": (Dim.AUTO, Dim.AUTO), + "z": (Dim.AUTO,), +} +ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes) +print(ep) + +###################################################################### +# Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``, +# and ``s2`` now contains lower and upper bounds, reflected in ``range_constraints``. +# +# For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that +# got emitted from tracing. The answer is that export is guided by the sample inputs provided by tracing, and specializes on the branches taken. +# If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch. +# Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches +# alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above. +# +# Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. +# The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that +# don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should +# specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens +# at runtime when we export this linear layer: + +ep = export( + torch.nn.Linear(4, 3), + (torch.randn(1, 4),), + dynamic_shapes={ + "input": (Dim.AUTO, Dim.STATIC), + }, +) +ep.module()(torch.randn(2, 4)) + +###################################################################### +# So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the +# low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic +# dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards +# and presenting what export believes is the overall dynamic behavior of the program. 
The drawback of this design appears once the user has stronger expectations or +# beliefs about the dynamic behavior of these models - maybe there is a strong desire on dynamism and specializations on particular dimensions are to be avoided at +# all costs, or maybe we just want to catch changes in dynamic behavior with changes to the original model code, or possibly underlying decompositions or meta-kernels. +# These changes won't be detected and the ``export()`` call will most likely succeed, unless tests are in place that check the resulting ``ExportedProgram`` representation. +# +# For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named ``Dims``: + +dx = Dim("dx", min=4, max=256) +dh = Dim("dh", max=512) +dynamic_shapes = { + "x": (dx, None), + "y": (2 * dx, dh), +} ###################################################################### -# Note that in the example above, because we constrained the value of ``x.shape[0]`` in -# ``dynamic_shapes_example3``, the exported program is sound even though there is a -# raw ``if`` statement. +# This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, min/max bounds on those symbols, and places restrictions on the +# dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic +# specifications given. For example, in the above specification, the following is asserted: # -# If you want to see why ``torch.export`` generated these constraints, you can -# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``, -# or use ``torch._logging.set_logs``. - -import logging -torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO) -exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed) +# - ``x.shape[0]`` is to have range ``[4, 256]``, and related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``. +# - ``x.shape[1]`` is static. +# - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension. +# +# In this design, we allow relations between dimensions to be specified with univariate linear expressions: ``A * dim + B`` can be specified for any dimension. This allows users +# to specify more complex constraints like integer divisibility for dynamic dimensions: -# reset to previous values -torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING) +dx = Dim("dx", min=4, max=512) +dynamic_shapes = { + "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4. +} ###################################################################### -# We can view an ``ExportedProgram``'s symbolic shape ranges using the -# ``range_constraints`` field. +# One common issue with this specification style (before ``Dim.AUTO`` was introduced), is that the specification would often be mismatched with what was produced by model tracing. +# That would lead to ``ConstraintViolation`` errors and export suggested fixes - see for example with this model & specification, where the model inherently requires equality between +# dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static. 
-print(exported_dynamic_shapes_example3.range_constraints) +class Foo(torch.nn.Module): + def forward(self, x, y): + w = x + y + return w + torch.ones(4) + +dx, dy, d1 = torch.export.dims("dx", "dy", "d1") +ep = export( + Foo(), + (torch.randn(6, 4), torch.randn(6, 4)), + dynamic_shapes={ + "x": (dx, d1), + "y": (dy, d1), + }, +) + +###################################################################### +# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards. +# +# Lastly, there's couple nice-to-knows about the options for specification: +# +# - ``None`` is a good option for static behavior: +# - ``dynamic_shapes=None`` (default) exports with the entire model being static. +# - specifying ``None`` at an input-level exports with all tensor dimensions static, and alternatively is also required for non-tensor inputs. +# - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``. +# - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification. +# +# These options are combined in the inputs & dynamic shapes spec below: + +inputs = ( + torch.randn(4, 4), + torch.randn(3, 3), + 16, + False, +) +dynamic_shapes = { + "tensor_0": (Dim.AUTO, None), + "tensor_1": None, + "int_val": None, + "bool_val": None, +} ###################################################################### # Custom Ops From fd678819fdaf3afcefbf978c43ee74f549fb6c24 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 8 Jan 2025 07:20:06 -0800 Subject: [PATCH 2/4] changes --- intermediate_source/torch_export_tutorial.py | 65 ++++++++++++-------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 0450bae482..8000ca4604 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -190,7 +190,7 @@ def forward(self, x): # about safety, but not all Python code is supported, causing these graph # breaks. # -# To address this issue, in PyTorch 2.5, we introduced a new mode of +# To address this issue, in PyTorch 2.3, we introduced a new mode of # exporting called non-strict mode, where we trace through the program using the # Python interpreter executing it exactly as it would in eager mode, allowing us # to skip over unsupported Python features. This is done through adding a @@ -306,7 +306,7 @@ def false_fn(x): # # This section covers dynamic behavior and representation of exported programs. Dynamic behavior is # subjective to the particular model being exported, so for the most part of this tutorial, we'll focus -# on this particular toy model (with the sample input shapes annotated): +# on this particular toy model (with the resulting tensor shapes annotated): class DynamicModel(torch.nn.Module): def __init__(self): @@ -320,14 +320,14 @@ def forward( y: torch.Tensor, # [8, 4] z: torch.Tensor, # [32] ): - x0 = x + y # output shape: [8, 4] + x0 = x + y # [8, 4] x1 = self.l(w) # [6, 3] x2 = x0.flatten() # [32] x3 = x2 + z # [32] return x1, x3 ###################################################################### -# By default, ``torch.export`` produces a static program. One clear consequence of this is that at runtime, +# By default, ``torch.export`` produces a static program. 
One consequence of this is that at runtime, # the program won't work on inputs with different shapes, even if they're valid in eager mode. w = torch.randn(6, 5) @@ -339,6 +339,9 @@ def forward( model(w, x, torch.randn(3, 4), torch.randn(12)) ep.module()(w, x, torch.randn(3, 4), torch.randn(12)) +# Basic concepts: symbols and guards +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ###################################################################### # To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with # dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic behavior is specified @@ -357,7 +360,8 @@ def forward( ###################################################################### # Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails, # and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is -# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the +# `allocated `, +# taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the # 0/1 specialization section). # # Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit @@ -383,7 +387,7 @@ def forward( ): x0 = x + y # guard: s2 == s4 x1 = self.l(w) # guard: s1 == 5 - x2 = x0.flatten() + x2 = x0.flatten() # no guard added here x3 = x2 + z # guard: s3 * s4 == s5 return x1, x3 @@ -425,26 +429,28 @@ def forward( ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) ###################################################################### -# This spits out quite a handful, even with this simple toy model. But looking through the logs we can see the lines relevant -# to what we described above; e.g. the allocation of symbols: +# This spits out quite a handful, even with this simple toy model. The log lines here have been cut short at front and end +# to ignore unnecessary info, but looking through the logs we can see the lines relevant to what we described above; +# e.g. 
the allocation of symbols: """ -I1210 16:20:19.720000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" -I1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" -V1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:6535] [1/0] runtime_assert True == True [statically known] -I1210 16:20:19.727000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" -I1210 16:20:19.729000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" -I1210 16:20:19.731000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" -I1210 16:20:19.734000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" +create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) +create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) +runtime_assert True == True [statically known] +create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) +create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) +create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) +create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) """ ###################################################################### -# Or the guards emitted: +# The lines with `create_symbol` show when a new symbol has been allocated, and the logs also identify the tensor variable names +# and dimensions they've been allocated for. 
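#
# The symbols that survive this process are the ones that show up again in the printed ``ExportedProgram`` and in its
# ``range_constraints`` field, so a quick programmatic cross-check of what the logs report is to print that field
# afterwards (a minimal sketch, using the ``ep`` returned by the ``export()`` call above):

print(ep.range_constraints)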
In other lines we can also see the guards emitted: """ -I1210 16:20:19.743000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" -I1210 16:20:19.754000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" -I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" +runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" +runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" +runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" """ ###################################################################### @@ -456,14 +462,14 @@ def forward( # # ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing this 2 other options are available: # ``Dim.DYNAMIC``, and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all -# ways except one: it raises an error when specializing to a constant; designed to maintain dynamism. See for example what happens when a +# ways except one: it raises an error when specializing to a constant; this is designed to maintain dynamism. See for example what happens when a # static guard is emitted on a dynamically-marked dimension: dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC) export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) ###################################################################### -# Static guards also aren't always inherent to the model; they can also come from user-specifications. In fact, a common pitfall leading to shape +# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape # specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is # raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``: @@ -473,12 +479,12 @@ def forward( export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) ###################################################################### -# Here you might ask why export "specializes"; why we resolve this static/dynamic conflict by going with the static route. The answer is because +# Here you might ask why export "specializes", i.e. 
why we resolve this static/dynamic conflict by going with the static route. The answer is because # of the symbolic shapes system described above, of symbols and guards. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile # treating this shape as a concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to # specialization. # -# One feature of export is that during tracing, statements like asserts, ``torch._checks()``, and ``if/else`` conditions will also emit guards. +# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards. # See what happens when we augment the existing model with such statements: class DynamicModel(torch.nn.Module): @@ -516,7 +522,10 @@ def forward(self, w, x, y, z): # If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch. # Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches # alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above. -# + +# 0/1 specialization +# ^^^^^^^^^^^^^^^^^^ + # Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. # The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that # don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should @@ -532,6 +541,9 @@ def forward(self, w, x, y, z): ) ep.module()(torch.randn(2, 4)) +# Named Dims +# ^^^^^^^^^^ + ###################################################################### # So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the # low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic @@ -567,6 +579,9 @@ def forward(self, w, x, y, z): "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4. } +# Constraint violations, suggested fixes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ###################################################################### # One common issue with this specification style (before ``Dim.AUTO`` was introduced), is that the specification would often be mismatched with what was produced by model tracing. # That would lead to ``ConstraintViolation`` errors and export suggested fixes - see for example with this model & specification, where the model inherently requires equality between @@ -594,7 +609,7 @@ def forward(self, x, y): # # - ``None`` is a good option for static behavior: # - ``dynamic_shapes=None`` (default) exports with the entire model being static. -# - specifying ``None`` at an input-level exports with all tensor dimensions static, and alternatively is also required for non-tensor inputs. +# - specifying ``None`` at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs. # - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``. 
# - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification. # From 82efe6b3611f7db238b5f05247f8e40d7e8ae501 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 8 Jan 2025 23:07:32 +0700 Subject: [PATCH 3/4] Update torch_export_tutorial.py --- intermediate_source/torch_export_tutorial.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 8000ca4604..c69baf6451 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -339,6 +339,7 @@ def forward( model(w, x, torch.randn(3, 4), torch.randn(12)) ep.module()(w, x, torch.randn(3, 4), torch.randn(12)) +###################################################################### # Basic concepts: symbols and guards # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -360,7 +361,7 @@ def forward( ###################################################################### # Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails, # and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is -# `allocated `, +# `allocated `_, # taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the # 0/1 specialization section). # @@ -523,6 +524,7 @@ def forward(self, w, x, y, z): # Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches # alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above. +###################################################################### # 0/1 specialization # ^^^^^^^^^^^^^^^^^^ @@ -541,6 +543,7 @@ def forward(self, w, x, y, z): ) ep.module()(torch.randn(2, 4)) +###################################################################### # Named Dims # ^^^^^^^^^^ @@ -579,6 +582,7 @@ def forward(self, w, x, y, z): "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4. } +###################################################################### # Constraint violations, suggested fixes # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From b720e328d1a68c02bad0b2d6f961965577414313 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 8 Jan 2025 23:12:10 +0700 Subject: [PATCH 4/4] Update torch_export_tutorial.py --- intermediate_source/torch_export_tutorial.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index c69baf6451..9acacf5362 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -342,8 +342,7 @@ def forward( ###################################################################### # Basic concepts: symbols and guards # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -###################################################################### +# # To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with # dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. 
Dynamic behavior is specified # at a input dimension-level; for each input we can specify a tuple of values: @@ -527,7 +526,7 @@ def forward(self, w, x, y, z): ###################################################################### # 0/1 specialization # ^^^^^^^^^^^^^^^^^^ - +# # Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. # The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that # don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should @@ -546,8 +545,7 @@ def forward(self, w, x, y, z): ###################################################################### # Named Dims # ^^^^^^^^^^ - -###################################################################### +# # So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the # low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic # dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards @@ -585,8 +583,7 @@ def forward(self, w, x, y, z): ###################################################################### # Constraint violations, suggested fixes # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -###################################################################### +# # One common issue with this specification style (before ``Dim.AUTO`` was introduced), is that the specification would often be mismatched with what was produced by model tracing. # That would lead to ``ConstraintViolation`` errors and export suggested fixes - see for example with this model & specification, where the model inherently requires equality between # dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.
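######################################################################
# For this example, the suggested fixes amount to sharing one symbol between dimension 0 of ``x`` and ``y``,
# and making dimension 1 static (the model adds ``torch.ones(4)``, so that dimension must stay 4). The exact
# wording of the suggestions can vary between versions, but a minimal sketch of a corrected specification,
# reusing the ``Foo`` model and ``dx`` name from above, looks like this, and the subsequent ``export()`` call succeeds:

dx = Dim("dx")
dynamic_shapes = {
    "x": (dx, None),  # dimension 1 left static
    "y": (dx, None),  # reusing ``dx`` encodes the requirement x.shape[0] == y.shape[0]
}
ep = export(
    Foo(),
    (torch.randn(6, 4), torch.randn(6, 4)),
    dynamic_shapes=dynamic_shapes,
)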