do record_data_in_stream step in copy_data_to_device (#956)
Summary: Pull Request resolved: #956

Reviewed By: galrotem

Differential Revision: D67719965

fbshipit-source-id: 71dde3aaf42f70bd6fa79ce5634f4ccea3d4e6e2
JKSenthil authored and facebook-github-bot committed Dec 31, 2024
1 parent c8a8e76 commit de119c5
Showing 2 changed files with 39 additions and 14 deletions.
torchtnt/framework/auto_unit.py (25 additions, 11 deletions)
@@ -33,7 +33,7 @@
 from torchtnt.framework.state import ActivePhase, EntryPoint, State
 from torchtnt.framework.unit import EvalUnit, PredictUnit, TPredictData, TrainUnit
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
+from torchtnt.utils.device import copy_data_to_device
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.precision import (
@@ -191,6 +191,12 @@ def __init__(
             enabled=self.precision is not None,
         )
 
+        # main stream responsible for computation on the device
+        self._default_stream: Optional[torch.cuda.streams.Stream] = (
+            torch.cuda.current_stream()
+            if (self.device.type == "cuda" and enable_prefetch)
+            else None
+        )
         # cuda stream to use for moving data to device
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
             torch.cuda.Stream()
@@ -215,7 +221,10 @@ def __init__(
         self._enable_prefetch = enable_prefetch
 
     def move_data_to_device(
-        self, state: State, data: TData, non_blocking: bool
+        self,
+        state: State,
+        data: TData,
+        non_blocking: bool,
     ) -> TData:
         """
         The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``/``predict_step``.
@@ -230,8 +239,18 @@ def move_data_to_device(
         Returns:
             A batch of data which is on the device
 
+        Note:
+            If overriding, ensure that tensors are recorded on the compute stream to prevent the CUDA caching allocator
+            from overwriting the underlying data before the compute stream has a chance to use it. If using
+            `copy_data_to_device`, you can pass `stream_to_record=self._default_stream` as an argument.
         """
-        return copy_data_to_device(data, self.device, non_blocking=non_blocking)
+        return copy_data_to_device(
+            data,
+            self.device,
+            non_blocking=non_blocking,
+            stream_to_record=self._default_stream,
+        )
 
     def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None:
         """Prefetch the next batch on a separate CUDA stream."""
@@ -256,7 +275,9 @@ def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None
                 state, f"{self.__class__.__name__}.{phase}.move_data_to_device"
             ):
                 self._phase_to_next_batch[active_phase] = self.move_data_to_device(
-                    state, next_batch, non_blocking=non_blocking
+                    state,
+                    next_batch,
+                    non_blocking=non_blocking,
                 )
 
     def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
@@ -281,13 +302,6 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
             self._is_last_batch = False
             raise StopIteration
 
-        if self._prefetch_stream:
-            with get_timing_context(
-                state, f"{self.__class__.__name__}.record_data_in_stream"
-            ):
-                # record the batch in the current stream
-                record_data_in_stream(batch, torch.cuda.current_stream())
-
         # prefetch the next batch
         self._prefetch_next_batch(state, data)
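
For illustration, the underlying CUDA pattern that this commit folds into copy_data_to_device is a side-stream copy followed by record_stream on the compute stream, which keeps the caching allocator from recycling the tensor's memory before the compute stream has used it. The following is a minimal standalone sketch, not code from this commit; it assumes a CUDA device is available and uses illustrative sizes:

    import torch

    # Streams: the default (compute) stream and a side stream for copies.
    compute_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()

    # Pinned host memory is required for a truly asynchronous copy.
    cpu_batch = torch.randn(1024, 1024, pin_memory=True)

    with torch.cuda.stream(copy_stream):
        # Enqueue the host-to-device copy on the side stream.
        gpu_batch = cpu_batch.to("cuda", non_blocking=True)
        # Mark the tensor as in use by the compute stream so the CUDA
        # caching allocator will not reuse its memory prematurely.
        gpu_batch.record_stream(compute_stream)

    # The compute stream still has to wait for the copy to complete.
    compute_stream.wait_stream(copy_stream)
    loss = gpu_batch.sum()
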
torchtnt/utils/device.py (14 additions, 3 deletions)
@@ -13,7 +13,7 @@
 import subprocess
 from collections import defaultdict
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, Mapping, TypeVar
+from typing import Any, Dict, Mapping, Optional, TypeVar
 
 import torch
 from typing_extensions import Protocol, runtime_checkable, TypedDict
@@ -56,12 +56,20 @@ def _is_named_tuple(x: T) -> bool:
     return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
 
 
-def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any) -> T:
+def copy_data_to_device(
+    data: T,
+    device: torch.device,
+    stream_to_record: Optional[torch.cuda.Stream] = None,
+    *args: Any,
+    **kwargs: Any,
+) -> T:
     """Function that recursively copies data to a torch.device.
 
     Args:
         data: The data to copy to device
         device: The device to which the data should be copied
+        stream_to_record: The CUDA stream on which the data should be recorded. Useful if this function is called
+            on a side stream and the data is expected to be used on the main stream.
         args: positional arguments that will be passed to the `to` call
         kwargs: keyword arguments that will be passed to the `to` call
@@ -116,7 +124,10 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
         return new_data_class
     elif hasattr(data, "to"):
         # pyre-ignore Undefined attribute [16]: `Variable[T]` has no attribute `to`
-        return data.to(device, *args, **kwargs)
+        gpu_data = data.to(device, *args, **kwargs)
+        if stream_to_record is not None and hasattr(gpu_data, "record_stream"):
+            gpu_data.record_stream(stream_to_record)
+        return gpu_data
 
     return data
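
A short usage sketch of the updated signature (assuming a CUDA device; the variable names are illustrative): a caller that prefetches on a side stream passes the compute stream as stream_to_record, and non_blocking=True is forwarded to the underlying `to` call via **kwargs:

    import torch
    from torchtnt.utils.device import copy_data_to_device

    device = torch.device("cuda")
    compute_stream = torch.cuda.current_stream()
    prefetch_stream = torch.cuda.Stream()

    # Pinned host tensor, so the non-blocking copy is asynchronous.
    batch = torch.randn(8, 16, pin_memory=True)

    with torch.cuda.stream(prefetch_stream):
        # The tensor is copied on the side stream; since it has a
        # record_stream method, copy_data_to_device records it on the
        # compute stream before returning (per the diff above).
        gpu_batch = copy_data_to_device(
            batch,
            device,
            stream_to_record=compute_stream,
            non_blocking=True,
        )

    # Synchronize before the compute stream consumes the prefetched batch.
    compute_stream.wait_stream(prefetch_stream)
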
