PGTransport in-place transfers #118
Labels: checkpoint (related to checkpointing/recovery/healing), process_group (related to ProcessGroups and collectives), python
Currently PGTransport allocates new tensors and copies them to the CPU -- this is memory inefficient and slow, since we have to limit the number of tensors transferred at once and pay for a GPU->CPU->GPU copy. A better solution would be to transfer directly into the existing GPU tensors in place.
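A minimal sketch of the receive side, assuming a `torch.distributed` process group with the NCCL backend. It only shows the idea of writing into the destination tensors' existing GPU storage; the function and variable names are illustrative, not the PGTransport API.

```python
import torch
import torch.distributed as dist


def recv_tensors_inplace(
    tensors: list[torch.Tensor],
    src_rank: int,
    group: dist.ProcessGroup | None = None,
) -> None:
    """Post async recvs that write directly into the provided GPU tensors,
    avoiding the extra allocation and GPU->CPU->GPU staging copy."""
    work = [dist.irecv(t, src=src_rank, group=group) for t in tensors]
    for w in work:
        w.wait()
```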
This requires matching the tensors between the local state_dict and the remote state_dict. That's a bit tricky to do in the general case of arbitrary Python objects, but it should be fine with dictionaries. I'm not sure whether PyTree ordering is guaranteed or whether we need some custom mapping logic.
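One possible approach, sketched below with PyTorch's private pytree utility: flatten both state_dicts and pair leaves positionally, treating treespec equality as the precondition for a safe 1:1 mapping. If the specs differ (or if leaf ordering turns out not to be stable across processes), we'd need explicit key-path-based mapping instead. `paired_leaves` is a hypothetical helper, not existing torchft code.

```python
import torch
import torch.utils._pytree as pytree


def paired_leaves(local_state_dict, remote_state_dict):
    """Return (local_leaf, remote_leaf) pairs in flattened-leaf order."""
    local_leaves, local_spec = pytree.tree_flatten(local_state_dict)
    remote_leaves, remote_spec = pytree.tree_flatten(remote_state_dict)
    # Positional matching is only safe if both sides flatten to the same
    # structure; otherwise fall back to custom key-based mapping logic.
    if local_spec != remote_spec:
        raise ValueError("local and remote state_dict structures do not match")
    return list(zip(local_leaves, remote_leaves))


if __name__ == "__main__":
    local = {"layer.weight": torch.zeros(2, 2), "step": 0}
    remote = {"layer.weight": torch.ones(2, 2), "step": 10}
    for dst, src in paired_leaves(local, remote):
        print(type(dst), type(src))
```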
Relevant code: https://github.com/pytorch/torchft/blob/main/torchft/checkpointing/pg_transport.py
We should also add a benchmark to test this with PGNCCL.
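A rough benchmark sketch (illustrative, not the torchft benchmark suite): compare an in-place NCCL recv into a preallocated GPU tensor against allocating a fresh tensor and taking the extra GPU->CPU->GPU hop. Launch with something like `torchrun --nproc_per_node=2 bench.py`; the payload size and script name are arbitrary.

```python
import time
import torch
import torch.distributed as dist


def bench(recv_inplace: bool, payload: torch.Tensor) -> float:
    rank = dist.get_rank()
    torch.cuda.synchronize()
    start = time.perf_counter()
    if rank == 0:
        dist.send(payload, dst=1)
    else:
        if recv_inplace:
            dist.recv(payload, src=0)            # write into existing GPU storage
        else:
            staging = torch.empty_like(payload)  # fresh allocation, like today
            dist.recv(staging, src=0)
            payload.copy_(staging.cpu().cuda())  # extra GPU->CPU->GPU round trip
    torch.cuda.synchronize()
    return time.perf_counter() - start


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    payload = torch.randn(64 * 1024 * 1024, device="cuda")  # ~256 MB of fp32
    for inplace in (False, True):
        elapsed = bench(inplace, payload)
        if dist.get_rank() == 1:
            print(f"inplace={inplace}: {elapsed:.3f}s")
    dist.destroy_process_group()
```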