Skip to content

Commit

Permalink
[pjrt] Removed unused prefer_to_retain_reference argument from RecordUsage
Browse files Browse the repository at this point in the history

It was always set to false by the callers.

PiperOrigin-RevId: 713277020
  • Loading branch information
superbobry authored and Google-ML-Automation committed Jan 8, 2025
1 parent b89b28f commit 0ffff6c
Showing 1 changed file with 8 additions and 59 deletions.
67 changes: 8 additions & 59 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,32 +347,11 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
// after the usage of device_buffer was enqueued.
// usage_stream: the stream the operation using device_buffer
// was enqueued on.
// prefer_to_retain_reference: relevant only for the compute synchronous
// allocation model. If true, retain a reference
// to device_buffer until after the operation
// completes. If false then the compute stream
// will have to be synchronized past event before
// device_buffer can be freed.
//
// prefer_to_retain_reference encodes a heuristic set by the caller for the
// compute synchronous model:
//
// Generally when a buffer is the destination of a copy to a device, it will
// subsequently be used on the device's compute stream before being freed. In
// that case, there is no need to retain a reference to the buffer. If the
// buffer is freed before being used on the compute stream, the free will be
// delayed until the host knows that event has completed, but this is expected
// to be uncommon.
//
// When a buffer is the source of a copy from a device, we need to either retain
// a reference to the buffer until the copy completes or serialize the compute
// stream behind the copy. It is often better to retain a reference since while
// that keeps memory alive longer, it avoids stalling the compute stream.
void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
LocalDeviceState* buffer_local_device,
LocalDeviceState* stream_local_device,
std::shared_ptr<BufferSequencingEvent> event,
se::Stream* usage_stream, bool prefer_to_retain_reference,
se::Stream* usage_stream,
std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
buffers_to_release = nullptr) {
tsl::profiler::TraceMe traceme("RecordUsage");
Expand All @@ -382,11 +361,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
(stream_local_device != buffer_local_device) ||
// In the synchronous allocation model, always retain a reference.
(stream_local_device->allocation_model() ==
LocalDeviceState::kSynchronous) ||
// In the compute synchronous model, use the caller's heuristic.
(stream_local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized &&
prefer_to_retain_reference);
LocalDeviceState::kSynchronous);
if (retain_buffer_until_completion) {
if (buffers_to_release) {
buffers_to_release->push_back(device_buffer.buffer());
Expand Down Expand Up @@ -415,15 +390,8 @@ absl::Status AddDestinationBufferSynchronization(
}
definition_event->SetSequencingEvent(std::move(event_or).value(),
copy_stream);
// prefer_to_retain_reference=false means don't retain a memory reference
// until the transfer is complete when using the ComputeSynchronized
// allocation model. This is a heuristic because in the common case
// destination buffers will be used on the compute stream and therefore don't
// require any synchronization before being freed. If the buffer is allocated
// and never used, the free will take longer and this is assumed to be ok.
RecordUsage(std::move(device_buffer), local_device, local_device,
definition_event, copy_stream,
/*prefer_to_retain_reference=*/false);
definition_event, copy_stream);
return absl::OkStatus();
}

Expand Down Expand Up @@ -583,16 +551,9 @@ AllocateDestinationBuffer(

if (on_device_shape.IsTuple()) {
// Add a usage hold for the tuple table write and immediately convert it to
// the appropriate form of synchronization. prefer_to_retain_reference=false
// means don't retain a memory reference until the transfer is complete when
// using the ComputeSynchronized allocation model. This is a heuristic
// because in the common case destination buffers will be used on the
// compute stream and therefore don't require any synchronization before
// being freed. If the buffer is allocated and never used, the free will
// take longer and this is assumed to be ok.
// the appropriate form of synchronization.
RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
definition_events.back(), tuple_table_stream,
/*prefer_to_retain_reference=*/false);
definition_events.back(), tuple_table_stream);
}

return py_buffer;
Expand Down Expand Up @@ -1954,8 +1915,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
std::move(async_copy_to_device));

RecordUsage(std::move(dst_device_buffer), transfer_local_device,
transfer_local_device, copy_event, transfer_stream,
/*prefer_to_retain_reference=*/false);
transfer_local_device, copy_event, transfer_stream);

return std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>(
Expand Down Expand Up @@ -2039,12 +1999,6 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace(
std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;

// prefer_to_retain_reference=*/true means that, when using the
// ComputeSynchronized allocation model, retain a reference to the
// src_device_buffer until the copy completes. This is a heuristic; the
// alternative is to ensure, before freeing the buffer, that the compute
// stream is synchronized past the transfer, but it seems better to hold onto
// the buffer too long than to stall the compute stream.
src_device_buffer.ConvertUsageHold(transfer_stream, event,
/*reference_held=*/true);

Expand Down Expand Up @@ -2340,7 +2294,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
memory_space);
RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
definition_event, local_device->compute_stream(),
/*prefer_to_retain_reference=*/false, &buffers_to_release);
&buffers_to_release);
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
}

Expand Down Expand Up @@ -3118,14 +3072,9 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
buffers_to_release));

for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
// prefer_to_retain_reference=false because when using the
// ComputeSynchronized allocation model we don't need to retain a reference
// to the device_buffer during execution because by definition the compute
// stream is synchronized past the execution.
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
RecordUsage(std::move(b), device_state, device_state, definition_event,
stream,
/*prefer_to_retain_reference=*/false, &buffers_to_release);
stream, &buffers_to_release);
} else {
CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
b.ConfirmDonation();
Expand Down

0 comments on commit 0ffff6c

Please sign in to comment.