Update outdated docs for torch logs #3127

Open · wants to merge 6 commits into main
Changes from 4 commits
124 changes: 87 additions & 37 deletions recipes_source/torch_logs.py
@@ -31,52 +31,102 @@
# variable setting is shown for each example.

import torch

import logging
import sys


def env_setup():
    """Set up the environment to run the example. Exit cleanly if CUDA is not available."""
    if not torch.cuda.is_available():
        print("CUDA is not available. Exiting.")
        sys.exit(0)

    if torch.cuda.get_device_capability() < (7, 0):
        print("Skipping because torch.compile is not supported on this device.")
        sys.exit(0)


def separator(name):
    """Print a separator and reset dynamo between each example."""
    print(f"\n{'='*20} {name} {'='*20}")
    torch._dynamo.reset()


def run_debugging_suite():
    """Run the complete debugging suite with all of the logging options."""
    env_setup()

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (
        torch.ones(2, 2, device="cuda"),
        torch.zeros(2, 2, device="cuda"),
    )

    logging_scenarios = [
        # View dynamo tracing; TORCH_LOGS="+dynamo"
        ("Dynamo Tracing", {"dynamo": logging.DEBUG}),
        # View the traced graph; TORCH_LOGS="graph"
        ("Traced Graph", {"graph": True}),
        # View fusion decisions; TORCH_LOGS="fusion"
        ("Fusion Decisions", {"fusion": True}),
        # View the output code generated by inductor; TORCH_LOGS="output_code"
        ("Output Code", {"output_code": True}),
    ]

    for name, log_config in logging_scenarios:
        separator(name)
        torch._logging.set_logs(**log_config)
        try:
            result = fn(*inputs)
            print(f"Function output shape: {result.shape}")
        except Exception as e:
            print(f"Error during {name}: {str(e)}")


run_debugging_suite()

separator("")
######################################################################
# Using ``TORCH_TRACE``/``tlparse`` to produce compilation reports (for PyTorch 2)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we use ``TORCH_TRACE`` and ``tlparse`` to produce compilation reports.
#
# 1. Generate the raw trace logs by running the following command:
#
#    .. code-block:: bash
#
#       TORCH_TRACE="/tmp/tracedir" python script.py
#
#    Replace ``/tmp/tracedir`` with the path to the directory where you want to
#    store the trace logs, and replace ``script.py`` with the name of your script.
#
# 2. Install ``tlparse`` by running:
#
#    .. code-block:: bash
#
#       pip install tlparse
#
# 3. Pass the trace log directory to ``tlparse`` to generate compilation reports:
#
#    .. code-block:: bash
#
#       tlparse /tmp/tracedir
#
#    This opens the generated HTML report in your browser.
#
# By default, reports generated by ``tlparse`` are stored in the ``tl_out`` directory.
# You can change that by running:
#
# .. code-block:: bash
#
#    tlparse /tmp/tracedir -o output_dir/
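#
# For reference, ``script.py`` in step 1 can be any program that triggers
# ``torch.compile``; ``TORCH_TRACE`` records compilation artifacts while it runs.
# A minimal sketch (the function and tensor shapes below are illustrative
# placeholders, not part of this recipe):
#
# .. code-block:: python
#
#    import torch
#
#    @torch.compile()
#    def fn(x, y):
#        return x + y + 2
#
#    fn(torch.ones(2, 2), torch.zeros(2, 2))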

######################################################################
# Conclusion