Add pageable/pinned tensor to cuda reliability note in pinmem tutorial #3261

Merged 1 commit on Feb 19, 2025
55 changes: 48 additions & 7 deletions intermediate_source/pinmem_nonblock.py
@@ -108,7 +108,7 @@
#
# .. _pinned_memory_async_sync:
#
# When executing a copy from a host (such as CPU) to a device (such as GPU), the CUDA toolkit offers modalities to do these
# operations synchronously or asynchronously with respect to the host.
#
# In practice, when calling :meth:`~torch.Tensor.to`, PyTorch always makes a call to
@@ -512,12 +512,54 @@ def pin_copy_to_device_nonblocking(*tensors):
#
# Until now, we have operated under the assumption that asynchronous copies from the CPU to the GPU are safe.
# This is generally true because CUDA automatically handles synchronization to ensure that the data being accessed is
# valid at read time **whenever the tensor is in pageable memory**.
#
# However, in other cases we cannot make the same assumption: when a tensor is placed in pinned memory, mutating the
# original copy after calling the host-to-device transfer may corrupt the data received on GPU.
# Similarly, when a transfer is executed in the opposite direction, from GPU to CPU, or, more generally, whenever the
# target device is not a CUDA-enabled GPU (such as MPS), there is no guarantee that the data read on the target device
# is valid without explicit synchronization.
#
# In these scenarios, there is no assurance that the copy will be complete at the time of data access. Consequently,
# the data on the target device might be incomplete or incorrect, effectively rendering it garbage.
#
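# As a quick aside, whether a tensor lives in pinned or pageable memory can be checked with
# :meth:`~torch.Tensor.is_pinned` (a short sanity check added here for illustration):

print(torch.ones(1024, 1024).is_pinned())  # False: pageable memory (the default)
print(torch.ones(1024, 1024, pin_memory=True).is_pinned())  # True: pinned memory

######################################################################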
# Let's first demonstrate this with a pinned-memory tensor:
DELAY = 100000000  # number of GPU cycles for torch.cuda._sleep to spin, keeping the device busy
try:
i = -1
for i in range(100):
    # Create a tensor in pinned memory
cpu_tensor = torch.ones(1024, 1024, pin_memory=True)
torch.cuda.synchronize()
# Send the tensor to CUDA
cuda_tensor = cpu_tensor.to("cuda", non_blocking=True)
torch.cuda._sleep(DELAY)
# Corrupt the original tensor
cpu_tensor.zero_()
assert (cuda_tensor == 1).all()
print("No test failed with non_blocking and pinned tensor")
except AssertionError:
print(f"{i}th test failed with non_blocking and pinned tensor. Skipping remaining tests")

######################################################################
# Using a pageable tensor always works: CUDA first copies the data into an internal pinned staging buffer before
# returning control to the host, so the source can be modified safely once the call returns.
#

i = -1
for i in range(100):
    # Create a pageable tensor (the default for new CPU tensors)
cpu_tensor = torch.ones(1024, 1024)
torch.cuda.synchronize()
# Send the tensor to CUDA
cuda_tensor = cpu_tensor.to("cuda", non_blocking=True)
torch.cuda._sleep(DELAY)
    # Overwrite the original tensor; the staged copy is unaffected
cpu_tensor.zero_()
assert (cuda_tensor == 1).all()
print("No test failed with non_blocking and pageable tensor")

######################################################################
# Now let's demonstrate that CUDA to CPU also fails to produce reliable outputs without synchronization:

tensor = (
torch.arange(1, 1_000_000, dtype=torch.double, device="cuda")
@@ -551,9 +593,8 @@ def pin_copy_to_device_nonblocking(*tensors):


######################################################################
# The same considerations apply to copies from the CPU to non-CUDA devices, such as MPS.
# Generally, asynchronous copies to a device are safe without explicit synchronization only when the target is a
# CUDA-enabled device and the original tensor is in pageable memory.
#
# In summary, copying data from CPU to GPU is safe when using ``non_blocking=True``, but for any other direction,
# ``non_blocking=True`` can still be used but the user must make sure that a device synchronization is executed before
# the data is accessed.
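
######################################################################
# For completeness, a minimal sketch of the safe device-to-host pattern (assuming, as elsewhere in this tutorial, that
# all work runs on the default CUDA stream):

gpu_tensor = torch.ones(1024, 1024, device="cuda")
host_tensor = gpu_tensor.to("cpu", non_blocking=True)
# Without this synchronization, ``host_tensor`` could still hold stale or partially written data
torch.cuda.synchronize()
assert (host_tensor == 1).all()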