Commit 3463b49

Fix export tutorial
angelayi committed Oct 30, 2024
1 parent e05e623 commit 3463b49
Showing 1 changed file with 29 additions and 48 deletions.
77 changes: 29 additions & 48 deletions intermediate_source/torch_export_tutorial.py
@@ -163,22 +163,6 @@ def forward(self, x):
except Exception:
tb.print_exc()

######################################################################
# - unsupported Python language features (e.g. throwing exceptions, match statements)

class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x

try:
export(Bad4(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()

######################################################################
# Non-Strict Export
@@ -197,16 +181,6 @@ def forward(self, x):
# ``strict=False`` flag.
#
# Looking at some of the previous examples which resulted in graph breaks:
#
# - Accessing tensor data with ``.data`` now works correctly

class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x

bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
print(bad2_nonstrict.module()(torch.ones(3, 3)))

######################################################################
# - Calling unsupported functions (such as many built-in functions) traces
@@ -223,22 +197,6 @@ def forward(self, x):
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
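
######################################################################
# (The definition of ``Bad3`` is collapsed in this hunk. As a hedged sketch of
# the idea, not the original code: the module calls a Python builtin such as
# ``id``, which non-strict export traces through, baking the result in as a
# constant.)

class Bad3Sketch(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)  # builtin call; result is specialized at trace time

bad3_sketch_nonstrict = export(Bad3Sketch(), (torch.randn(3, 3),), strict=False)
print(bad3_sketch_nonstrict.module()(torch.ones(3, 3)))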

######################################################################
# - Unsupported Python language features (such as throwing exceptions, match
# statements) now also get traced through.

class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x

bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
print(bad4_nonstrict.module()(torch.ones(3, 3)))


######################################################################
# However, there are still some features that require rewrites to the original
@@ -349,7 +307,7 @@ def forward(self, x, y):
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval [4, 18].

from torch.export import Dim
from torch.export.dynamic_shapes import Dim

inp1 = torch.randn(10, 10, 2)

@@ -358,7 +316,7 @@ def forward(self, x):
x = x[:, 2:]
return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim0 = Dim("inp1_dim0", max=50)
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
"x": {0: inp1_dim0, 1: inp1_dim1},
@@ -479,9 +437,7 @@ def forward(self, z, y):

class DynamicShapesExample3(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] <= 16:
return x @ y[:, :16]
return y
return x @ y

dynamic_shapes3 = {
"x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
@@ -536,6 +492,31 @@ def suggested_fixes():

print(exported_dynamic_shapes_example3.range_constraints)
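
######################################################################
# ``range_constraints`` maps each symbolic shape symbol to its inferred value
# range, so the constraints can also be inspected one at a time (the exact
# symbol names and range format vary across PyTorch versions):

for symbol, value_range in exported_dynamic_shapes_example3.range_constraints.items():
    print(symbol, value_range)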

######################################################################
# In PyTorch v2.5, we also introduced an automatic way of determining dynamic
# shapes. If you don't know whether a tensor's dimensions are dynamic, or how
# the dynamic shapes of different inputs relate to each other, you can mark
# dimensions with ``Dim.AUTO``, and export will infer the dynamism of those
# dimensions. Going back to the previous example, we can rewrite it as follows:

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

class DynamicShapesExample3(torch.nn.Module):
def forward(self, x, y):
return x @ y

dynamic_shapes3_2 = {
"x": {i: Dim.AUTO for i in range(inp4.dim())},
"y": {i: Dim.AUTO for i in range(inp5.dim())},
}

exported_dynamic_shapes_example_3_2 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_2)
print(exported_dynamic_shapes_example_3_2)
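
######################################################################
# Because every dimension was marked ``Dim.AUTO``, the exported program should
# accept any input shapes that keep the matmul valid (that is,
# ``x.shape[1] == y.shape[0]``). A quick check under that assumption:

print(exported_dynamic_shapes_example_3_2.module()(torch.randn(4, 8), torch.randn(8, 10)).shape)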

######################################################################
# Custom Ops
# ----------
@@ -548,7 +529,7 @@ def suggested_fixes():
# as with any other custom op

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
def custom_op(x: torch.Tensor) -> torch.Tensor:
print("custom_op called!")
return torch.relu(x)
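
######################################################################
# (A hedged addition, not part of this diff: for ``torch.export`` to trace
# through a call it cannot see into, the custom op generally needs a
# FakeTensor kernel describing its output metadata. A minimal sketch, assuming
# the output has the same shape and dtype as the input:)

@custom_op.register_fake
def _(x):
    return torch.empty_like(x)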

