
PyTorch/XLA 2.3 Release Notes

Released by lsy323 on 08 Apr 21:31
Commit: 6f93cc1

Highlights

We are excited to announce the release of PyTorch/XLA 2.3! PyTorch/XLA 2.3 offers experimental support for SPMD auto-sharding on a single TPU host, which lets users shard their models on TPU with a single config change. We also add experimental support for Pallas custom kernels for inference, enabling users to run popular custom kernels such as flash attention and paged attention on TPU.

Stable Features

PJRT

  • Experimental GPU PJRT Plugin (#6240)
  • Define PJRT plugin interface in C++ (#6360)
  • Add limit to max inflight TPU computations (#6533)
  • Remove TPU_C_API device type (#6435)
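The in-flight limit (#6533) caps how many computations can be dispatched to the TPU before earlier ones finish. The sketch below illustrates that back-pressure idea in plain Python with a bounded semaphore. It is not the PJRT implementation (which lives in C++); the class `InflightLimiter`, its `max_inflight` parameter and `fake_computation` are hypothetical names used only for illustration.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class InflightLimiter:
    """Hypothetical sketch: cap how many computations may be in flight at once."""

    def __init__(self, max_inflight: int):
        self._slots = threading.BoundedSemaphore(max_inflight)
        self._lock = threading.Lock()
        self.current = 0  # computations currently running
        self.peak = 0     # highest concurrency observed

    def run(self, fn, *args):
        # Block the caller until a slot frees up, so dispatched-but-unfinished
        # work never exceeds max_inflight.
        with self._slots:
            with self._lock:
                self.current += 1
                self.peak = max(self.peak, self.current)
            try:
                return fn(*args)
            finally:
                with self._lock:
                    self.current -= 1


def fake_computation(x):
    time.sleep(0.01)  # stand-in for device execution time
    return x * x


limiter = InflightLimiter(max_inflight=2)
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(lambda x: limiter.run(fake_computation, x), range(10)))
```

Even with eight submitting threads, `limiter.peak` never exceeds 2, which is the property such a limit is meant to guarantee.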

GSPMD

  • Introduce global mesh (#6498)
  • Introduce xla_distribute_module for DTensor integration (#6683)
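A global mesh lets code that shards tensors look up one shared device layout instead of passing a mesh object around. The toy model below loosely mirrors the shape of torch_xla's `Mesh(device_ids, mesh_shape, axis_names)`; the `set_global_mesh` / `get_global_mesh` helpers and `device_at` method here are illustrative assumptions written in pure Python, so check the SPMD documentation for the real API.

```python
class Mesh:
    """Toy logical mesh: device ids laid out row-major over mesh_shape."""

    def __init__(self, device_ids, mesh_shape, axis_names):
        total = 1
        for d in mesh_shape:
            total *= d
        if total != len(device_ids):
            raise ValueError("mesh_shape must cover every device exactly once")
        if len(axis_names) != len(mesh_shape):
            raise ValueError("need one axis name per mesh dimension")
        self.device_ids = list(device_ids)
        self.mesh_shape = tuple(mesh_shape)
        self.axis_names = tuple(axis_names)

    def device_at(self, *coords):
        # Row-major flattening of mesh coordinates into a device index.
        idx = 0
        for c, size in zip(coords, self.mesh_shape):
            idx = idx * size + c
        return self.device_ids[idx]


_GLOBAL_MESH = None  # hypothetical process-wide default mesh


def set_global_mesh(mesh):
    global _GLOBAL_MESH
    _GLOBAL_MESH = mesh


def get_global_mesh():
    return _GLOBAL_MESH


# Example: 8 devices arranged as a 2 x 4 ("data", "model") mesh.
mesh = Mesh(list(range(8)), (2, 4), ("data", "model"))
set_global_mesh(mesh)
```

After `set_global_mesh(mesh)`, any later code can call `get_global_mesh().device_at(1, 2)` and get device 6 without the mesh being threaded through its arguments.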

Torch Compile

  • Support activation sharding within torch.compile (#6524)
  • Do not cache FX input args in dynamo bridge to avoid memory leak (#6553)
  • Ignore non-XLA nodes and their direct dependents (#6170)
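The rule in #6170 can be read as a simple graph filter: drop nodes that are not on the XLA device, plus any node that directly consumes one of their outputs. The sketch below is a simplified model of that rule on a dictionary-based graph, not the dynamo bridge code; `partition_nodes` and its input format are hypothetical.

```python
def partition_nodes(nodes, deps):
    """
    nodes: dict of node name -> device string, e.g. "xla" or "cpu".
    deps:  dict of node name -> list of input node names.
    Returns (kept, ignored), where ignored = non-XLA nodes plus the nodes
    that *directly* consume a non-XLA node's output.
    """
    non_xla = {n for n, dev in nodes.items() if dev != "xla"}
    direct_dependents = {
        n for n, inputs in deps.items() if any(i in non_xla for i in inputs)
    }
    ignored = non_xla | direct_dependents
    kept = [n for n in nodes if n not in ignored]
    return kept, ignored


# "b" runs on CPU, "c" reads b's output, "d" reads c's output.
nodes = {"a": "xla", "b": "cpu", "c": "xla", "d": "xla"}
deps = {"c": ["b"], "d": ["c"]}
kept, ignored = partition_nodes(nodes, deps)
```

Only the direct consumer `c` is dropped alongside `b`; `d` is one hop further away and stays in the XLA partition.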

Export

  • Support implicit broadcasting with unbounded dynamism (#6219)
  • Support multiple StableHLO Composite outputs (#6295)
  • Add dynamism support for add (#6443)
  • Enable unbounded dynamism on conv, softmax, addmm, slice (#6494)
  • Handle constant variables (#6510)
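Implicit broadcasting with unbounded dynamism means a dimension's size may be unknown at export time. The sketch below extends NumPy-style broadcasting rules with `None` for an unbounded dimension. It is a simplified shape-inference model under stated assumptions: when a dynamic dimension meets a static size greater than 1, it assumes the runtime size will be either 1 or that size, which a real lowering would need to check at run time. `broadcast_dim` and `broadcast_shapes` are hypothetical helpers, not the StableHLO export code.

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None marks an unbounded (dynamic) dimension


def broadcast_dim(a: Dim, b: Dim) -> Dim:
    if a is None and b is None:
        return None
    if a is None:
        # A static 1 stretches to the dynamic size; a static size > 1 pins
        # the result (assuming the dynamic size turns out to be 1 or equal).
        return None if b == 1 else b
    if b is None:
        return None if a == 1 else a
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"incompatible dims {a} and {b}")


def broadcast_shapes(x: Sequence[Dim], y: Sequence[Dim]) -> Tuple[Dim, ...]:
    # Right-align the shapes by left-padding the shorter one with 1s.
    n = max(len(x), len(y))
    xs = [1] * (n - len(x)) + list(x)
    ys = [1] * (n - len(y)) + list(y)
    return tuple(broadcast_dim(a, b) for a, b in zip(xs, ys))
```

For example, a batch-dynamic `(None, 3)` tensor broadcast against a `(3,)` bias keeps its dynamic batch dimension: the result is `(None, 3)`.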

Beta Features

CoreAtenOpSet

Support all Core Aten Ops used by torch.export

  • Lower reflection_pad1d, reflection_pad1d_backward, reflection_pad3d and reflection_pad3d_backward (#6588)
  • Lower replication_pad3d and replication_pad3d_backward (#6566)
  • Lower the embedding op (#6495)
  • Lowering for _pdist_forward (#6507)
  • Support mixed precision for torch.where (#6303)
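The reflection and replication padding ops differ only in how they fill the border. The pure-Python sketch below shows the 1-D semantics on a list (the 3-D variants apply the same rule along each padded dimension). It is an illustration of the op semantics, not the XLA lowering; `reflection_pad1d` and `replication_pad1d` here are plain helper functions, not torch calls.

```python
def reflection_pad1d(xs, left, right):
    """Mirror values around each edge, excluding the edge element itself."""
    n = len(xs)
    if left >= n or right >= n:
        raise ValueError("reflection padding must be smaller than the input length")
    head = [xs[i] for i in range(left, 0, -1)]
    tail = [xs[n - 2 - i] for i in range(right)]
    return head + list(xs) + tail


def replication_pad1d(xs, left, right):
    """Repeat the edge element on each side."""
    return [xs[0]] * left + list(xs) + [xs[-1]] * right


reflected = reflection_pad1d([1, 2, 3, 4], 2, 1)    # [3, 2, 1, 2, 3, 4, 3]
replicated = replication_pad1d([1, 2, 3, 4], 2, 1)  # [1, 1, 1, 2, 3, 4, 4]
```

The length check mirrors the usual constraint on reflection padding: each pad width must be smaller than the input size along that dimension.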

Benchmark

FSDP via SPMD

  • Make FSDPv2 use the global mesh API (#6500)
  • Enable auto-wrapping (#6499)

Distributed Checkpoint

  • Add process group documentation for SPMD (#6469)

Usability

GPU

  • Fix global_device_count(), local_device_count() for single process on CUDA (#6022)
  • Automatically use XLA:GPU if on a GPU machine (#6605)
  • Add SPMD on GPU instructions (#6684)
  • Build XLA:GPU as a separate Plugin (#6825)

Distributed

  • Support tensor bucketing for all-gather and reduce-scatter for ZeRO1 (#6025)
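Tensor bucketing coalesces many small tensors so that one collective (all-gather or reduce-scatter) handles a whole bucket instead of launching one per tensor. The greedy sketch below groups tensors by byte size under a cap. It is an assumption-labeled illustration: `bucket_tensors` and `bucket_cap_bytes` are hypothetical names, and the ZeRO1 implementation may choose buckets differently.

```python
def bucket_tensors(sizes_bytes, bucket_cap_bytes):
    """Greedily group tensor indices into buckets of at most bucket_cap_bytes.

    Tensors keep their original order; a tensor larger than the cap gets a
    bucket of its own. Each resulting bucket would map to one collective call.
    """
    buckets, current, current_bytes = [], [], 0
    for idx, size in enumerate(sizes_bytes):
        if current and current_bytes + size > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


buckets = bucket_tensors([4, 4, 4, 10, 2, 2], bucket_cap_bytes=8)
```

Six tensors collapse into four collectives here, and the oversized 10-byte tensor travels alone rather than blocking the cap for its neighbours.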

Experimental Features

Pallas

  • Introduce Flash Attention kernel using Pallas (#6827)
  • Support Flash Attention kernel with causal mask (#6837)
  • Support Flash Attention kernel with torch.compile (#6875)
  • Support Pallas kernel (#6340)
  • Support programmatically extracting the payload from Pallas kernel (#6696)
  • Support Pallas kernel with torch.compile (#6477)
  • Introduce helper to convert Pallas kernel to PyTorch/XLA callable (#6713)
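Flash attention avoids materialising the full attention score matrix by streaming over key blocks and maintaining an online softmax (a running max, a running normaliser and a running weighted sum). The pure-Python sketch below shows that idea with a causal mask and checks it against naive attention. It is a reference model of the algorithm only: it does not use Pallas or a TPU, and `block` is a hypothetical tile size.

```python
import math


def naive_causal_attention(q, k, v):
    """Reference: q, k, v are lists of row vectors (seq_len x d)."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only attends to keys 0..i.
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z for c in range(len(v[0]))])
    return out


def blockwise_causal_attention(q, k, v, block=2):
    """Stream over key blocks with an online softmax. Blocks entirely after
    query i are never visited, which is where the causal mask saves work."""
    d, dv = len(q[0]), len(v[0])
    out = []
    for i, qi in enumerate(q):
        m = -math.inf        # running max of scores seen so far
        l = 0.0              # running sum of exp(score - m)
        acc = [0.0] * dv     # running exp-weighted sum of value rows
        for start in range(0, i + 1, block):
            end = min(start + block, i + 1)  # clip the block at the causal boundary
            scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in range(start, end)]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)  # rescale earlier partial results
            l *= scale
            acc = [a * scale for a in acc]
            for s, j in zip(scores, range(start, end)):
                p = math.exp(s - m_new)
                l += p
                for c in range(dv):
                    acc[c] += p * v[j][c]
            m = m_new
        out.append([a / l for a in acc])
    return out


q = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4], [-0.2, 0.7], [0.9, 0.0]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.6, 0.2], [0.1, -0.5], [0.3, 0.3]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [0.3, 0.7]]
ref = naive_causal_attention(q, k, v)
fast = blockwise_causal_attention(q, k, v, block=2)
```

Both functions produce the same outputs up to floating-point rounding; the blockwise version only ever holds one block of scores per query at a time.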

GSPMD Auto-Sharding

  • Support auto-sharding for single host TPU (#6719)
  • Auto construct auto-sharding mesh ids (#6770)

Input Output Aliasing

  • Support torch.compile for dynamo_set_buffer_donor
  • Use XLA’s new API to alias graph input and output (#6855)
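Input/output aliasing lets a compiled graph write an output into the buffer of an input that the caller has donated, saving an allocation and a copy. The sketch below plans such aliases under the assumption that an output can reuse a donated input only when shape and dtype match exactly, and that each donated buffer backs at most one output. `Buf` and `plan_aliases` are hypothetical helpers, not torch_xla or XLA APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Buf:
    shape: tuple
    dtype: str


def plan_aliases(inputs, outputs, donors):
    """Map output index -> donated input index whose buffer it can reuse."""
    free = list(donors)  # donated inputs not yet claimed by an output
    aliases = {}
    for o, out in enumerate(outputs):
        for i in free:
            if inputs[i] == out:  # same shape and dtype
                aliases[o] = i
                free.remove(i)  # safe: we stop iterating right after
                break
    return aliases


inputs = [Buf((4, 4), "f32"), Buf((8,), "f32"), Buf((4, 4), "bf16")]
outputs = [Buf((4, 4), "f32"), Buf((4, 4), "bf16"), Buf((8,), "f32")]
aliases = plan_aliases(inputs, outputs, donors=[0, 2])  # {0: 0, 1: 2}
```

Input 1 matches the last output but was not donated, so that output still gets a fresh buffer.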

While Loop

  • Support torch._higher_order_ops.while_loop with simple examples (#6532, #6603)
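torch._higher_order_ops.while_loop takes a condition function, a body function and a tuple of carried values; under XLA the loop can be compiled as a single loop construct instead of being unrolled. The plain-Python model below shows the semantics usually described for it. It is a sketch only: `while_loop_reference` is a hypothetical name, and the real op additionally expects the body to return values with the same structure, shapes and dtypes as the carried inputs.

```python
def while_loop_reference(cond_fn, body_fn, carried_inputs):
    """Eager model of a functional while loop: body_fn maps the carried tuple
    to a new tuple of the same structure, and the loop stops once cond_fn
    returns False."""
    carried = tuple(carried_inputs)
    while cond_fn(*carried):
        carried = tuple(body_fn(*carried))
    return carried


# Example: count down from 5 while accumulating a running sum.
def cond_fn(i, total):
    return i > 0


def body_fn(i, total):
    return i - 1, total + i


result = while_loop_reference(cond_fn, body_fn, (5, 0))  # (0, 15)
```

Because the loop state is passed explicitly rather than mutated, the whole loop is a pure function of its inputs, which is what makes it traceable into one graph.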

Bug Fixes and Improvements

  • Propagate requires_grad to the AllReduce output (#6326)
  • Avoid fallback for avg_pool (#6409)
  • Fix output tensor shape for argmin and argmax where keepdim=True and dim=None (#6536)
  • Fix preserve_rng_state for activation checkpointing (#4690)
  • Allow int data-type for Embedding indices (#6718)
  • Don't terminate the whole process when Compile fails (#6707)
  • Fix an incorrect assert on frame count for PT_XLA_DEBUG=1 (#6466)
  • Refactor nms into TorchVision variant (#6814)