Skip to content

Commit

Permalink
Add pageable/pinned tensor to cuda reliability note in pinmem tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 24, 2025
1 parent 5786e97 commit a707b7f
Showing 1 changed file with 47 additions and 6 deletions.
53 changes: 47 additions & 6 deletions intermediate_source/pinmem_nonblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,54 @@ def pin_copy_to_device_nonblocking(*tensors):
#
# Until now, we have operated under the assumption that asynchronous copies from the CPU to the GPU are safe.
# This is generally true because CUDA automatically handles synchronization to ensure that the data being accessed is
# valid at read time.
# However, this guarantee does not extend to transfers in the opposite direction, from GPU to CPU.
# Without explicit synchronization, these transfers offer no assurance that the copy will be complete at the time of
# data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage:
# valid at read time __whenever the tensor is in pageable memory__.
#
# However, in other cases we cannot make the same asusmption: when a tensor is placed in pinned memory, mutating the
# original copy after calling the host-to-device transfer may corrupt the data received on GPU.
# Similarly, when a transfer is achieved in the opposite direction, from GPU to CPU, or from any device that is not CPU
# or GPU to any device that is not a CUDA-handled GPU (e.g., MPS), there is no guarantee that the data read on GPU is
# valid without explicit synchronization.
#
# In these scenarios, these transfers offer no assurance that the copy will be complete at the time of
# data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage.
#
# Let's first demonstrate this with a pinned-memory tensor:
DELAY = 100000000
try:
i = -1
for i in range(100):
# Create a tensor in pin-memory
cpu_tensor = torch.ones(1024, 1024, pin_memory=True)
torch.cuda.synchronize()
# Send the tensor to CUDA
cuda_tensor = cpu_tensor.to("cuda", non_blocking=True)
torch.cuda._sleep(DELAY)
# Corrupt the original tensor
cpu_tensor.zero_()
assert (cuda_tensor == 1).all()
print("No test failed with non_blocking and pinned tensor")
except AssertionError:
print(f"{i}th test failed with non_blocking and pinned tensor. Skipping remaining tests")

######################################################################
# Using a pageable tensor always works:
#

i = -1
for i in range(100):
# Create a tensor in pin-memory
cpu_tensor = torch.ones(1024, 1024)
torch.cuda.synchronize()
# Send the tensor to CUDA
cuda_tensor = cpu_tensor.to("cuda", non_blocking=True)
torch.cuda._sleep(DELAY)
# Corrupt the original tensor
cpu_tensor.zero_()
assert (cuda_tensor == 1).all()
print("No test failed with non_blocking and pageable tensor")

######################################################################
# Now let's demonstrate that CUDA to CPU also fails to produce reliable outputs without synchronization:

tensor = (
torch.arange(1, 1_000_000, dtype=torch.double, device="cuda")
Expand Down Expand Up @@ -551,9 +593,8 @@ def pin_copy_to_device_nonblocking(*tensors):


######################################################################
# The same considerations apply to copies from the CPU to non-CUDA devices, such as MPS.
# Generally, asynchronous copies to a device are safe without explicit synchronization only when the target is a
# CUDA-enabled device.
# CUDA-enabled device and the original tensor is in pageable memory.
#
# In summary, copying data from CPU to GPU is safe when using ``non_blocking=True``, but for any other direction,
# ``non_blocking=True`` can still be used but the user must make sure that a device synchronization is executed before
Expand Down

0 comments on commit a707b7f

Please sign in to comment.