diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 622a21e49d..eca7269d0e 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -33,7 +33,7 @@
 from torchtnt.framework.state import ActivePhase, EntryPoint, State
 from torchtnt.framework.unit import EvalUnit, PredictUnit, TPredictData, TrainUnit
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
+from torchtnt.utils.device import copy_data_to_device
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.precision import (
@@ -191,6 +191,12 @@ def __init__(
             enabled=self.precision is not None,
         )
 
+        # main stream responsible for computation on the device
+        self._default_stream: Optional[torch.cuda.streams.Stream] = (
+            torch.cuda.current_stream()
+            if (self.device.type == "cuda" and enable_prefetch)
+            else None
+        )
         # cuda stream to use for moving data to device
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
             torch.cuda.Stream()
@@ -215,7 +221,10 @@
         self._enable_prefetch = enable_prefetch
 
     def move_data_to_device(
-        self, state: State, data: TData, non_blocking: bool
+        self,
+        state: State,
+        data: TData,
+        non_blocking: bool,
     ) -> TData:
         """
         The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``/``predict_step``.
@@ -230,8 +239,18 @@ def move_data_to_device(
         Returns:
             A batch of data which is on the device
+
+        Note:
+            If overriding, ensure that tensors are recorded on the compute stream to prevent the CUDA caching allocator from
+            overwriting the underlying data before the compute stream has a chance to use it. If using `copy_data_to_device`,
+            you can pass `stream_to_record=self._default_stream` as an argument.
""" - return copy_data_to_device(data, self.device, non_blocking=non_blocking) + return copy_data_to_device( + data, + self.device, + non_blocking=non_blocking, + stream_to_record=self._default_stream, + ) def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None: """Prefetch the next batch on a separate CUDA stream.""" @@ -256,7 +275,9 @@ def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None state, f"{self.__class__.__name__}.{phase}.move_data_to_device" ): self._phase_to_next_batch[active_phase] = self.move_data_to_device( - state, next_batch, non_blocking=non_blocking + state, + next_batch, + non_blocking=non_blocking, ) def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData: @@ -281,13 +302,6 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData: self._is_last_batch = False raise StopIteration - if self._prefetch_stream: - with get_timing_context( - state, f"{self.__class__.__name__}.record_data_in_stream" - ): - # record the batch in the current stream - record_data_in_stream(batch, torch.cuda.current_stream()) - # prefetch the next batch self._prefetch_next_batch(state, data) diff --git a/torchtnt/utils/device.py b/torchtnt/utils/device.py index 282e8d7ce4..77d1184fd5 100644 --- a/torchtnt/utils/device.py +++ b/torchtnt/utils/device.py @@ -13,7 +13,7 @@ import subprocess from collections import defaultdict from dataclasses import fields, is_dataclass -from typing import Any, Dict, Mapping, TypeVar +from typing import Any, Dict, Mapping, Optional, TypeVar import torch from typing_extensions import Protocol, runtime_checkable, TypedDict @@ -56,12 +56,20 @@ def _is_named_tuple(x: T) -> bool: return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields") -def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any) -> T: +def copy_data_to_device( + data: T, + device: torch.device, + stream_to_record: Optional[torch.cuda.Stream] = None, + *args: Any, + **kwargs: Any, +) -> T: """Function that recursively copies data to a torch.device. Args: data: The data to copy to device device: The device to which the data should be copied + stream_to_record: The CUDA stream to which the data should be recorded. Useful if this function is called + on side stream, and the data is expected to be used on the main stream. args: positional arguments that will be passed to the `to` call kwargs: keyword arguments that will be passed to the `to` call @@ -116,7 +124,10 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any return new_data_class elif hasattr(data, "to"): # pyre-ignore Undefined attribute [16]: `Variable[T]` has no attribute `to` - return data.to(device, *args, **kwargs) + gpu_data = data.to(device, *args, **kwargs) + if stream_to_record is not None and hasattr(gpu_data, "record_stream"): + gpu_data.record_stream(stream_to_record) + return gpu_data return data