# -*- coding: utf-8 -*-

"""
Dynamic Compilation Control with ``torch.compiler.set_stance``
=========================================================================
**Author:** `William Wen <https://github.com/williamwen42>`_
"""

######################################################################
# ``torch.compiler.set_stance`` is a ``torch.compiler`` API that
# enables you to change the behavior of ``torch.compile`` across different
# calls to your model without having to reapply ``torch.compile`` to your model.
#
# This recipe provides some examples of how to use ``torch.compiler.set_stance``.
#
# .. contents::
#     :local:
#
# Prerequisites
# ---------------
#
# - ``torch >= 2.6``
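
######################################################################
# As an optional sanity check (a small sketch, not part of the original
# prerequisites list), you can verify that the installed build exposes
# ``torch.compiler.set_stance`` before running the rest of the recipe.

import torch

# ``torch.compiler.set_stance`` first shipped in torch 2.6.
if not hasattr(torch.compiler, "set_stance"):
    raise RuntimeError("This recipe requires torch >= 2.6.")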

######################################################################
# Description
# -----------
# ``torch.compiler.set_stance`` can be used as a decorator, context manager, or raw function
# to change the behavior of ``torch.compile`` across different calls to your model.
#
# In the example below, the ``"force_eager"`` stance ignores all ``torch.compile`` directives.

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp)) # compiled, prints 1

######################################################################
# Sample decorator usage


@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp)) # not compiled, prints -1

######################################################################
# Sample context manager usage

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1

######################################################################
# Sample raw function usage

torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp)) # compiled, prints 1

######################################################################
# The ``torch.compile`` stance can only be changed **outside** of any ``torch.compile`` region.
# Attempting to do otherwise will result in an error.


@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
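
######################################################################
# To get the intended effect, change the stance from outside the compiled
# region instead - for example, at the call site, as in the context manager
# usage shown earlier (a minimal sketch repeating that pattern):

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # stance changed outside the compiled region - OK, prints -1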

######################################################################
# Other stances include:
#
# - ``"default"``: The default stance, used for normal compilation.
# - ``"eager_on_recompile"``: Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
# - ``"fail_on_recompile"``: Raise an error when recompiling a function.
#
# See the ``torch.compiler.set_stance`` `doc page <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__
# for more stances and options. More stances/options may also be added in the future.

######################################################################
# Examples
# --------

######################################################################
# Preventing recompilation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some models do not expect any recompilations - for example, you may always have inputs with the same shape.
# Since recompilations can be expensive, we may wish to raise an error when a recompile is attempted,
# so that we can detect and fix such cases. The ``"fail_on_recompile"`` stance can be used for this.


@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
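
######################################################################
# If the recompilation is expected because input shapes genuinely vary, one
# option (a sketch that goes slightly beyond this recipe, using the standard
# ``dynamic=True`` argument of ``torch.compile``) is to compile with dynamic
# shapes so that a single compiled artifact can serve multiple input shapes.


@torch.compile(dynamic=True)
def my_big_dynamic_model(x):
    return torch.relu(x)


# first compilation, with symbolic shapes
my_big_dynamic_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    # a different shape should be served by the same compiled code
    my_big_dynamic_model(torch.randn(4))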

######################################################################
# If erroring out is too disruptive, we can use ``"eager_on_recompile"`` instead,
# which causes ``torch.compile`` to fall back to eager execution rather than raising an error.
# This may be useful if we don't expect recompilations to happen frequently, but
# when one is required, we'd rather pay the cost of running eagerly than the cost of recompiling.


@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3))) # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1


######################################################################
# Measuring performance gains
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``torch.compiler.set_stance`` can be used to compare eager vs. compiled performance
# without having to define a separate eager model.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements, so this example requires a CUDA-capable device.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])


######################################################################
# Crashing sooner
# ~~~~~~~~~~~~~~~~~
#
# Running an eager iteration before a compiled iteration using the ``"force_eager"`` stance
# can help us catch errors unrelated to ``torch.compile`` before attempting a very long compilation.


@torch.compile
def my_humongous_model(x):
    # intentional error: torch.sin takes a single tensor argument
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
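
######################################################################
# Once the underlying eager error is fixed (a hypothetical corrected model is
# sketched below - ``torch.sin`` takes a single tensor argument), it becomes
# worthwhile to pay the compilation cost.


@torch.compile
def my_fixed_model(x):
    return torch.sin(x) + x


with torch.compiler.set_stance("force_eager"):
    print(my_fixed_model(torch.randn(3)))  # eager sanity check passes
print(my_fixed_model(torch.randn(3)))  # the compiled call now runs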

######################################################################
# Conclusion
# --------------
# In this recipe, we have learned how to use the ``torch.compiler.set_stance`` API
# to modify the behavior of ``torch.compile`` across different calls to a model
# without needing to reapply it. The recipe demonstrates using
# ``torch.compiler.set_stance`` as a decorator, context manager, or raw function
# to control compilation stances such as ``"force_eager"``, ``"default"``,
# ``"eager_on_recompile"``, and ``"fail_on_recompile"``.
#
# For more information, see: `torch.compiler.set_stance API documentation <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__.
