Add the workaround flag in thunk
tengyifei committed Feb 21, 2025
1 parent bd52029 commit fa5db92
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 1 addition & 3 deletions torchprime/launcher/Dockerfile
@@ -54,6 +54,4 @@ RUN if [ "$USE_TRANSFORMERS" = "true" ]; then \
     pip install -e /workspaces/torchprime/local_transformers evaluate; \
   fi
 
-# TODO(https://github.com/pytorch/xla/issues/8683): Remove the
-# `--megascale_grpc_enable_xor_tracer=false` flag when libtpu is updated
-ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --megascale_grpc_enable_xor_tracer=false"
+ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
8 changes: 8 additions & 0 deletions torchprime/launcher/thunk.py
@@ -4,6 +4,14 @@
 from datetime import datetime
 from pathlib import Path
 
+# Workaround for MegaScale crash
+#
+# TODO(https://github.com/pytorch/xla/issues/8683): Remove the
+# `--megascale_grpc_enable_xor_tracer=false` flag when libtpu is updated
+xla_flags = os.environ.get("LIBTPU_INIT_ARGS", "")
+xla_flags = f"{xla_flags} --megascale_grpc_enable_xor_tracer=false"
+os.environ["LIBTPU_INIT_ARGS"] = xla_flags
+
 # Get the artifact dir from env var.
 gcs_artifact_dir = os.environ["TORCHPRIME_ARTIFACT_DIR"]
 assert gcs_artifact_dir.startswith(
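For context, a minimal standalone sketch (not part of this commit) of the flag-injection pattern thunk.py now uses: the workaround flag is appended at runtime to whatever LIBTPU_INIT_ARGS the launcher already set, so the Dockerfile no longer needs to hard-code it. The baseline flag value below is illustrative only.

# Sketch of the runtime flag injection pattern used in thunk.py (assumed setup).
import os

# Illustrative baseline set by the launcher environment (not the full flag list).
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_scoped_vmem_limit_kib=98304"

# Append the MegaScale workaround flag, preserving any existing flags.
flags = os.environ.get("LIBTPU_INIT_ARGS", "")
os.environ["LIBTPU_INIT_ARGS"] = f"{flags} --megascale_grpc_enable_xor_tracer=false"

print(os.environ["LIBTPU_INIT_ARGS"])
# --xla_tpu_scoped_vmem_limit_kib=98304 --megascale_grpc_enable_xor_tracer=false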
