PyTorch/XLA 2.3 Release Notes
Highlights
We are excited to announce the release of PyTorch/XLA 2.3! PyTorch/XLA 2.3 adds experimental support for SPMD auto-sharding on a single TPU host, which lets users shard their models on TPU with a single config change (see the sketch below). We also add experimental support for Pallas custom kernels for inference, enabling users to take advantage of popular custom kernels such as flash attention and paged attention on TPU.
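As a rough illustration, here is a minimal sketch of enabling auto-sharding on a single TPU host; the model, shapes, and training-free step are placeholders, and the `xr.use_spmd(auto=True)` entry point reflects the experimental auto-sharding API as we understand it:

```python
# Minimal sketch: experimental SPMD auto-sharding on a single TPU host.
# The model and shapes below are placeholders.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd(auto=True)  # enable SPMD mode with automatic sharding

device = xm.xla_device()
model = nn.Linear(1024, 1024).to(device)
inputs = torch.randn(8, 1024, device=device)

# The auto-sharding pass decides how to partition parameters and activations.
loss = model(inputs).sum()
loss.backward()
xm.mark_step()  # materialize the traced graph on the TPU devices
```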
Stable Features
PJRT
- Experimental GPU PJRT Plugin (#6240)
- Define PJRT plugin interface in C++ (#6360)
- Add limit to max inflight TPU computations (#6533)
- Remove TPU_C_API device type (#6435)
GSPMD
Torch Compile
- Support activation sharding within torch.compile (#6524)
- Do not cache FX input args in dynamo bridge to avoid memory leak (#6553)
- Ignore non-XLA nodes and their direct dependents (#6170)
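For context, these changes live in the Dynamo bridge behind the `openxla` backend; a minimal sketch of that torch.compile path, with a placeholder model and shapes:

```python
# Minimal sketch: run a placeholder model through torch.compile with the XLA backend.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
).to(device)

compiled = torch.compile(model, backend="openxla")  # Dynamo traces into XLA
out = compiled(torch.randn(4, 128, device=device))
```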
Export
- Support implicit broadcasting with unbounded dynamism (#6219)
- Support multiple StableHLO Composite outputs (#6295)
- Add support for dynamism for add (#6443)
- Enable unbounded dynamism on conv, softmax, addmm, slice (#6494)
- Handle constant variable (#6510)
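These items all target the torch.export → StableHLO path; a minimal sketch of that workflow, using a placeholder module with static shapes (dynamism-specific options are omitted):

```python
# Minimal sketch: export a placeholder module and lower it to StableHLO.
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y  # relies on implicit broadcasting between the operands

exported = export(AddModule(), (torch.randn(4, 8), torch.randn(8)))
shlo = exported_program_to_stablehlo(exported)
print(shlo.get_stablehlo_text())  # inspect the generated StableHLO module
```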
Beta Features
CoreAtenOpSet
Support all Core ATen ops used by torch.export
- Lower reflection_pad1d, reflection_pad1d_backward, reflection_pad3d and reflection_pad3d_backward (#6588)
- Lower replication_pad3d and replication_pad3d_backward (#6566)
- Lower the embedding op (#6495)
- Lowering for _pdist_forward (#6507)
- Support mixed precision for torch.where (#6303)
Benchmark
- Unify PyTorch/XLA and PyTorch torchbench model configuration using the same torchbench.yaml (#6881)
- Align model data precision settings with the PyTorch HUD (#6447, #6518, #6555)
- Fix the configuration of some torchbench models to make them runnable with XLA (#6509, #6542, #6558, #6612)
FSDP via SPMD
Distributed Checkpoint
Usability
GPU
- Fix global_device_count() and local_device_count() for a single process on CUDA (#6022)
- Automatically use XLA:GPU if on a GPU machine (#6605)
- Add SPMD on GPU instructions (#6684)
- Build XLA:GPU as a separate Plugin (#6825)
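For reference, the XLA:GPU plugin can still be selected explicitly through the usual PJRT device variable when automatic detection is not wanted; a minimal sketch, with illustrative values and shapes:

```python
# Minimal sketch: explicitly target XLA:GPU through the PJRT plugin.
import os
os.environ.setdefault("PJRT_DEVICE", "CUDA")  # set before importing torch_xla

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # resolves to an XLA:GPU device under PJRT
x = torch.randn(2, 2, device=device)
print(x.device, xm.xla_device_hw(device))  # e.g. xla:0 CUDA
```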
Distributed
Experimental Features
Pallas
- Introduce Flash Attention kernel using Pallas (#6827)
- Support Flash Attention kernel with causal mask (#6837)
- Support Flash Attention kernel with torch.compile (#6875)
- Support Pallas kernel (#6340)
- Support programmatically extracting the payload from Pallas kernel (#6696)
- Support Pallas kernel with torch.compile (#6477)
- Introduce helper to convert Pallas kernel to PyTorch/XLA callable (#6713)
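A minimal sketch of calling the flash attention wrapper listed above on TPU; the shapes and dtype are placeholders, and the causal flag mirrors the causal-mask support added in this release:

```python
# Minimal sketch: experimental Pallas flash attention kernel on TPU.
# Shapes are placeholders: (batch, num_heads, seq_len, head_dim).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
q = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
k = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
v = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)

out = flash_attention(q, k, v, causal=True)  # causal-mask variant
xm.mark_step()
```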
GSPMD Auto-Sharding
Input Output Aliasing
- Support torch.compile for dynamo_set_buffer_donor
- Use XLA’s new API to alias graph input and output (#6855)
While Loop
Bug Fixes and Improvements
- Propagate requires_grad to the AllReduce output (#6326)
- Avoid fallback for avg_pool (#6409)
- Fix output tensor shape for argmin and argmax where keepdim=True and dim=None (#6536)
- Fix preserve_rng_state for activation checkpointing (#4690)
- Allow int data-type for Embedding indices (#6718)
- Don't terminate the whole process when compilation fails (#6707)
- Fix an incorrect assert on frame count for PT_XLA_DEBUG=1 (#6466)
- Refactor nms into the TorchVision variant (#6814)