diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 5d291ada1957c..954b6e77be3c0 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -860,6 +860,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@stablehlo//:version", diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index 439e3560a44e0..e47a7ed8fd04c 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -161,6 +161,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 594ad973003fd..5852c9a54dcc0 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,6 @@ # PJRT C API changelog +## 0.60 +* Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice``, ``PJRT_AsyncHostToDeviceTransferManager_TransferData`` and ``PJRT_AsyncHostToDeviceTransferManager_Destroy``. ## 0.59 * Added ``PJRT_MemoryDescriptions_Extension``. 
diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 59a9216292019..36d82b0787ba4 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 59 +#define PJRT_API_MINOR 60 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -308,11 +308,14 @@ typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); typedef struct PJRT_Client PJRT_Client; typedef struct PJRT_Device PJRT_Device; typedef struct PJRT_Memory PJRT_Memory; +typedef struct PJRT_ShapeSpec PJRT_ShapeSpec; typedef struct PJRT_DeviceDescription PJRT_DeviceDescription; typedef struct PJRT_TopologyDescription PJRT_TopologyDescription; typedef struct PJRT_Executable PJRT_Executable; typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable; typedef struct PJRT_Buffer PJRT_Buffer; +typedef struct PJRT_AsyncHostToDeviceTransferManager + PJRT_AsyncHostToDeviceTransferManager; // The caller of PJRT_Client_Create can optionally provide a key-value store // accessible across nodes and/or processes. KV store access may be necessary to @@ -593,6 +596,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args, typedef PJRT_Error* PJRT_Client_DefaultDeviceAssignment( PJRT_Client_DefaultDeviceAssignment_Args* args); +struct PJRT_AsyncHostToDeviceTransferManager_Destroy_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Destroy_Args, + transfer_manager); + +// Frees `transfer_manager`. `transfer_manager` can be nullptr. 
+typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + int buffer_index; + const void* data; + int64_t offset; + int64_t transfer_size; + bool is_last_transfer; + PJRT_Event* done_with_h2d_transfer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args, + done_with_h2d_transfer); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); + typedef enum { // Invalid primitive type to serve as default. PJRT_Buffer_Type_INVALID, @@ -820,6 +852,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer); typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( PJRT_Client_CreateViewOfDeviceBuffer_Args* args); +struct PJRT_ShapeSpec { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const int64_t* dims; + size_t num_dims; + PJRT_Buffer_Type element_type; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_ShapeSpec, element_type); + +struct PJRT_Client_CreateBuffersForAsyncHostToDevice_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_Client* client; + PJRT_ShapeSpec* shape_specs; + size_t num_shape_specs; + PJRT_Buffer_MemoryLayout** device_layouts; // optional + size_t num_device_layouts; + PJRT_Memory* memory; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateBuffersForAsyncHostToDevice_Args, + transfer_manager); +typedef PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args); + // -------------------------- Device Descriptions ------------------------------ // Device descriptions may be associated with an 
actual device @@ -2266,10 +2323,14 @@ typedef struct PJRT_Api { _PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Create); _PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyRawToHost); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData); + _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice); } PJRT_Api; enum { - PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_CopyRawToHost) + PJRT_Api_STRUCT_SIZE = + PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice) }; #undef _PJRT_API_STRUCT_FIELD diff --git a/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 33d7d39fca2b4..cefbce152e508 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -274,6 +274,71 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) { api_->PJRT_ExecuteContext_Destroy(&destroy_args); } +TEST_F(PjrtCApiGpuTest, CreateBuffersWithMemorytForH2DAndTransfer) { + xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::F32, /*dimensions=*/{2, 2, 2}, /*minor_to_major=*/{1, 0, 2}); + std::vector float_data = {1, 2, 3, 4, 5, 6, 7, 8}; + + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; + args.struct_size = + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.client = client_; + PJRT_ShapeSpec c_shape_spec; + c_shape_spec.element_type = + pjrt::ConvertToPjRtBufferType(xla::PrimitiveType::F32); + c_shape_spec.dims = host_shape.dimensions().data(); + c_shape_spec.num_dims = host_shape.dimensions().size(); + args.shape_specs = &c_shape_spec; + args.num_shape_specs = 1; + TF_ASSERT_OK_AND_ASSIGN(pjrt::BufferMemoryLayoutData c_layout_data, + ConvertToBufferMemoryLayoutData(host_shape.layout())); + std::vector device_layout_list(1); + device_layout_list[0] = &(c_layout_data.c_layout); + 
args.device_layouts = device_layout_list.data(); + args.num_device_layouts = device_layout_list.size(); + PJRT_Client_AddressableMemories_Args memory_args; + memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; + memory_args.extension_start = nullptr; + memory_args.client = client_; + + PJRT_Error* memory_error = + api_->PJRT_Client_AddressableMemories(&memory_args); + ASSERT_EQ(memory_error, nullptr); + ASSERT_NE(memory_args.addressable_memories, nullptr); + ASSERT_GT(memory_args.num_addressable_memories, 0); + args.memory = memory_args.addressable_memories[0]; + PJRT_Error* error = + api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args); + ASSERT_EQ(error, nullptr); + + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args transfer_args; + transfer_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; + transfer_args.extension_start = nullptr; + transfer_args.transfer_manager = args.transfer_manager; + transfer_args.buffer_index = 0; + transfer_args.data = float_data.data(); + transfer_args.offset = 0; + transfer_args.transfer_size = float_data.size(); + transfer_args.is_last_transfer = true; + + PJRT_Error* transfer_error = + PJRT_AsyncHostToDeviceTransferManager_TransferData(&transfer_args); + ASSERT_EQ(transfer_error, nullptr); + std::unique_ptr done_with_h2d_transfer_event( + transfer_args.done_with_h2d_transfer, MakeEventDeleter(api_)); + + // Destroy the transfer manager. 
+ PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args; + destroy_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.transfer_manager = args.transfer_manager; + LogFatalIfPjrtError( + api_->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api_); +} + absl::StatusOr BuildCreateArg( ::pjrt::PJRT_KeyValueCallbackData* kv_callback_data, std::vector& c_options) { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 857fbc3091b2e..cf92041af497d 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -74,6 +74,20 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) { }; } +PJRT_AsyncHostToDeviceTransferManagerDeleter +MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api) { + return [api]( + PJRT_AsyncHostToDeviceTransferManager* transfer_manager) -> void { + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args; + destroy_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.transfer_manager = transfer_manager; + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api); + }; +} + PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) { return [api](PJRT_Error* error) -> void { PJRT_Error_Destroy_Args destroy_args; @@ -1064,4 +1078,27 @@ PJRT_Profiler_Extension CreatePjrtProfilerExtension( return profiler_extension; } +PJRT_ShapeSpec ConvertToPjRtShapeSpec( + const xla::PjRtClient::ShapeSpec& shape_spec) { + PJRT_ShapeSpec c_shape_spec; + c_shape_spec.struct_size = PJRT_ShapeSpec_STRUCT_SIZE; + c_shape_spec.extension_start = nullptr; + c_shape_spec.element_type = + pjrt::ConvertToPjRtBufferType(shape_spec.element_type); + c_shape_spec.dims = shape_spec.dims.data(); + c_shape_spec.num_dims = shape_spec.dims.size(); + return c_shape_spec; 
+} + +xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec( + PJRT_ShapeSpec c_shape_spec) { + xla::PjRtClient::ShapeSpec shape_spec; + shape_spec.element_type = + pjrt::ConvertFromPjRtBufferType(c_shape_spec.element_type); + + shape_spec.dims = xla::DimensionVector( + c_shape_spec.dims, c_shape_spec.dims + c_shape_spec.num_dims); + return shape_spec; +} + } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index 759569123456e..f530b82f42357 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -66,6 +66,14 @@ using PJRT_ClientDeleter = std::function; // The lifetime of the Api pointed to must be longer than the client. PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api); +using PJRT_AsyncHostToDeviceTransferManagerDeleter = + std::function; + +// Pass in an API pointer; receive a custom deleter for smart pointers. +// The lifetime of the Api pointed to must be longer than the transfer manager. +PJRT_AsyncHostToDeviceTransferManagerDeleter +MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api); + using PJRT_ErrorDeleter = std::function; // Pass in an API pointer; receive a custom deleter for smart pointers. @@ -296,6 +304,12 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable); +PJRT_ShapeSpec ConvertToPjRtShapeSpec( + const xla::PjRtClient::ShapeSpec& shape_spec); + +xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec( + PJRT_ShapeSpec c_shape_spec); + // Creates a PJRT_Profiler_Extension and adds a producer trace with // the given name. 
The created PJRT_Profiler_Extension will be used in argument // structs to pass the producer traceme context id to add a corresponding diff --git a/xla/pjrt/c/pjrt_c_api_test.cc b/xla/pjrt/c/pjrt_c_api_test.cc index 57fe33eb368cf..5fb77870d55a4 100644 --- a/xla/pjrt/c/pjrt_c_api_test.cc +++ b/xla/pjrt/c/pjrt_c_api_test.cc @@ -936,6 +936,12 @@ FieldOffsetsAndSizesForVersion(int major_version, int minor_version) { if (minor_version >= 57) { add_field("PJRT_Buffer_CopyRawToHost", kFnPtrSize); } + if (minor_version >= 58) { + add_field("PJRT_AsyncHostToDeviceTransferManager_Destroy", kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_TransferData", + kFnPtrSize); + add_field("PJRT_Client_CreateBuffersForAsyncHostToDevice", kFnPtrSize); + } return version_offsets_and_sizes; } LOG(FATAL) << "Unsupported API version: " << major_version << "." @@ -1264,6 +1270,17 @@ TEST_F(PjrtCAbiTestBase, FieldOffsetsAndSizes) { {"PJRT_Buffer_CopyRawToHost", {offsetof(PJRT_Api, PJRT_Buffer_CopyRawToHost), sizeof(PJRT_Api::PJRT_Buffer_CopyRawToHost)}}, + {"PJRT_AsyncHostToDeviceTransferManager_Destroy", + {offsetof(PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_Destroy), + sizeof(PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_Destroy)}}, + {"PJRT_AsyncHostToDeviceTransferManager_TransferData", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_TransferData), + sizeof( + PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_TransferData)}}, + {"PJRT_Client_CreateBuffersForAsyncHostToDevice", + {offsetof(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice), + sizeof(PJRT_Api::PJRT_Client_CreateBuffersForAsyncHostToDevice)}}, }; ASSERT_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR); ASSERT_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 7830fed2717cb..506b153f56bf2 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ 
b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -478,6 +479,67 @@ PJRT_Error* PJRT_Client_AddressableMemories( return nullptr; } +PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Client_CreateBuffersForAsyncHostToDevice_Args", + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE, + args->struct_size)); + absl::InlinedVector shape_specs; + shape_specs.reserve(args->num_shape_specs); + for (int i = 0; i < args->num_shape_specs; ++i) { + shape_specs.push_back(pjrt::ConvertFromPjrtShapeSpec(args->shape_specs[i])); + } + std::optional>> + arg_device_layouts; + if (args->num_device_layouts == 0) { + arg_device_layouts = std::nullopt; + } else { + std::vector> device_layouts; + device_layouts.reserve(args->num_device_layouts); + for (int i = 0; i < args->num_device_layouts; ++i) { + std::optional optional_layout; + if (args->device_layouts[i] != nullptr) { + xla::Layout cpp_layout; + PJRT_Buffer_MemoryLayout* layout = args->device_layouts[i]; + switch (layout->type) { + case PJRT_Buffer_MemoryLayout_Type:: + PJRT_Buffer_MemoryLayout_Type_Tiled: { + PJRT_ASSIGN_OR_RETURN(cpp_layout, ConvertToLayout(layout->tiled)); + break; + } + case PJRT_Buffer_MemoryLayout_Type:: + PJRT_Buffer_MemoryLayout_Type_Strides: { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + "PJRT_Buffer_MemoryLayout_Type_Strides is not supported to be " + "converted to a xla::Layout.")); + break; + } + default: { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("Unexpected PJRT_Buffer_MemoryLayout_Type type: ", + layout->type))); + } + } + device_layouts.push_back(cpp_layout); 
+ } else { + device_layouts.push_back(std::nullopt); + } + } + arg_device_layouts = absl::MakeSpan(device_layouts); + } + + PJRT_ASSIGN_OR_RETURN( + std::unique_ptr + transfer_manager, + args->client->client->CreateBuffersForAsyncHostToDevice( + absl::MakeSpan(shape_specs), arg_device_layouts, + args->memory->memory_space)); + args->transfer_manager = new PJRT_AsyncHostToDeviceTransferManager{ + std::move(transfer_manager), args->client}; + return nullptr; +} + // Searches `device_list` for a PJRT_Device* that wraps a provided // `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* // is returned. Otherwise, returns nullptr. @@ -530,6 +592,36 @@ static void PopulatePjrtExecutableAddressableDevices( } } +//-------------------- AsyncHostToDeviceTransferManager --------------------- + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_Destroy_Args", + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE, + args->struct_size)); + delete args->transfer_manager; + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_TransferData_Args", + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE, + args->struct_size)); + xla::PjRtFuture<>::Promise promise = xla::PjRtFuture<>::CreatePromise(); + absl::AnyInvocable on_done_with_d2h_transfer = + [promise]() mutable { promise.Set(); }; + PJRT_RETURN_IF_ERROR( + args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer( + args->buffer_index, args->data, args->offset, args->transfer_size, + args->is_last_transfer, std::move(on_done_with_d2h_transfer))); + args->done_with_h2d_transfer = + new 
PJRT_Event{xla::PjRtFuture<>(std::move(promise))}; + return nullptr; +} + namespace { absl::StatusOr ParseCompileOptions( @@ -2562,6 +2654,12 @@ PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, /*PJRT_ExecuteContext_Create=*/execute_context_create_fn, /*PJRT_ExecuteContext_Destroy=*/pjrt::PJRT_ExecuteContext_Destroy, /*PJRT_Buffer_CopyRawToHost=*/pjrt::PJRT_Buffer_CopyRawToHost, + /*PJRT_AsyncHostToDeviceTransferManager_Destroy=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_Destroy, + /*PJRT_AsyncHostToDeviceTransferManager_TransferData=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_TransferData, + /*PJRT_Client_CreateBuffersForAsyncHostToDevice=*/ + pjrt::PJRT_Client_CreateBuffersForAsyncHostToDevice, }; } diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 9580a29392541..0ebecc0c25173 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -91,6 +91,14 @@ struct PJRT_MemoryDescription { xla::PjRtMemorySpaceDescription memory_space_description; }; +// PJRT_AsyncHostToDeviceTransferManager is owned by its corresponding +// PJRT_Client. +struct PJRT_AsyncHostToDeviceTransferManager { + std::unique_ptr + transfer_manager; + PJRT_Client* client; +}; + // PJRT_DeviceDescriptions are owned by their corresponding PJRT_Device. 
struct PJRT_DeviceDescription { // The xla::PjRtDeviceDescription* is owned transitively by the @@ -254,7 +262,12 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( PJRT_Client_BufferFromHostBuffer_Args* args); PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( PJRT_Client_CreateViewOfDeviceBuffer_Args* args); - +PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); PJRT_Error* PJRT_DeviceDescription_Id(PJRT_DeviceDescription_Id_Args* args); PJRT_Error* PJRT_DeviceDescription_ProcessIndex( PJRT_DeviceDescription_ProcessIndex_Args* args); diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index ca1066d46db6e..8855ef33620e5 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -694,6 +694,210 @@ absl::StatusOr PjRtCApiClient::GetDefaultLayout( return pjrt_xla_layout.xla_layout(); } +class PjRtCApiAsyncHostToDeviceTransferManager + : public PjRtClient::AsyncHostToDeviceTransferManager { + public: + PjRtCApiAsyncHostToDeviceTransferManager( + PjRtCApiClient* client, + PJRT_AsyncHostToDeviceTransferManager* c_transfer_manager) + : c_client_(client), c_transfer_manager_(std::move(c_transfer_manager)) {} + + size_t buffer_count() const override { + LOG(FATAL) << "PJRT C API does not support buffer_count. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + PjRtDevice* device() const override { + LOG(FATAL) << "PJRT C API does not support device. 
Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + std::unique_ptr RetrieveBuffer(int buffer_index) override { + LOG(FATAL) << "PJRT C API does not support RetrieveBuffer. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + absl::Status TransferLiteralToBuffer( + int buffer_index, const LiteralSlice& literal, + absl::AnyInvocable on_done) override { + return Unimplemented( + "PJRT C API does not support TransferLiteralToBuffer. Please report an " + "issue at https://github.com/google/jax/issues if you need this " + "feature."); + } + + size_t buffer_size(int buffer_index) const override { + LOG(FATAL) + << "PJRT C API does not support buffer_size. Please report an " + "issue at https://github.com/google/jax/issues if you need this " + "feature."; + } + + absl::Status TransferRawDataToBuffer( + int buffer_index, absl::string_view data, + absl::AnyInvocable on_done) override { + return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(), + /*is_last_transfer=*/true, + std::move(on_done)); + } + + absl::Status TransferRawDataToSubBuffer( + int buffer_index, const void* data, int64_t offset, int64_t transfer_size, + bool is_last_transfer, absl::AnyInvocable on_done) override { + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + args.buffer_index = buffer_index; + args.data = data; + args.offset = offset; + args.transfer_size = transfer_size; + args.is_last_transfer = is_last_transfer; + const PJRT_Api* api = c_client_->pjrt_c_api(); + RETURN_STATUS_IF_PJRT_ERROR( + api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api); + std::unique_ptr event( + args.done_with_h2d_transfer, ::pjrt::MakeEventDeleter(api)); + if (on_done) { + 
PJRT_Event_OnReady_Args event_args; + event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; + event_args.extension_start = nullptr; + event_args.event = event.get(); + event_args.user_arg = new absl::AnyInvocable( + [on_done = std::move(on_done), + c_api = api](PJRT_Error* error) mutable { + if (error) { + ::pjrt::MakeErrorDeleter(c_api)(error); + } + std::move(on_done)(); + }); + event_args.callback = [](PJRT_Error* error, void* args) { + auto* on_done_with_d2h_transfer = + reinterpret_cast*>(args); + (*on_done_with_d2h_transfer)(error); + delete on_done_with_d2h_transfer; + }; + + RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Event_OnReady(&event_args), api); + } + return absl::OkStatus(); + } + + void SetBufferError(int buffer_index, absl::Status error) override { + LOG(FATAL) << "PJRT C API does not support SetBufferError. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + using TransferMetadata = absl::flat_hash_map; + void AddTransferMetadata(const TransferMetadata& metadata) override { + LOG(FATAL) << "PJRT C API does not support AddTransferMetadata. 
Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + private: + PjRtCApiClient* c_client_; + std::unique_ptr + c_transfer_manager_; +}; + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtMemorySpace* memory_space) { + const PJRT_Api* c_api = pjrt_c_api(); + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; + args.struct_size = + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.client = c_client_.get(); + args.num_shape_specs = shape_specs.size(); + args.shape_specs = new PJRT_ShapeSpec[shape_specs.size()]; + absl::Cleanup cleanup = + absl::MakeCleanup([&args] { delete[] args.shape_specs; }); + const ShapeSpec* iterator = shape_specs.begin(); + for (int i = 0; i < shape_specs.size(); ++i) { + args.shape_specs[i] = pjrt::ConvertToPjRtShapeSpec(*(iterator++)); + } + if (device_layouts.has_value()) { + args.num_device_layouts = device_layouts->size(); + auto device_layout_list = + std::make_unique>( + device_layouts->size()); + for (int i = 0; i < device_layouts->size(); ++i) { + if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { + const Layout& layout = (*device_layouts)[i].value(); + TF_ASSIGN_OR_RETURN(pjrt::BufferMemoryLayoutData c_layout_data, + pjrt::ConvertToBufferMemoryLayoutData(layout)); + device_layout_list->emplace_back(&(c_layout_data.c_layout)); + } else { + device_layout_list->emplace_back(nullptr); + } + } + args.device_layouts = device_layout_list->data(); + } else { + args.num_device_layouts = 0; + args.device_layouts = nullptr; + } + args.memory = + tensorflow::down_cast(memory_space)->c_memory(); + + RETURN_STATUS_IF_PJRT_ERROR( + c_api->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args), c_api); + return std::make_unique( + this, args.transfer_manager); +} + +absl::StatusOr> 
+PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtDevice* device) { + TF_ASSIGN_OR_RETURN(auto memory_space, device->default_memory_space()); + return CreateBuffersForAsyncHostToDevice(shape_specs, device_layouts, + memory_space); +} + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtDevice* device) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, device); +} + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, memory_space); +} + const PJRT_Api* PjRtCApiClient::pjrt_c_api() const { return c_api_; } // --------------------------------- Devices ----------------------------------- diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index 3897b02342716..27fc17799a075 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -318,23 +318,25 @@ class PjRtCApiClient : public PjRtClient { absl::StatusOr GetTopologyDescription() const override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtDevice* device) override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> 
device_layouts, + PjRtMemorySpace* memory_space) override; + + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override { - return Unimplemented( - "PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } + PjRtDevice* device) override; absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override { - return Unimplemented( - "PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } + PjRtMemorySpace* memory_space) override; absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, diff --git a/xla/pjrt/pjrt_c_api_client_test.cc b/xla/pjrt/pjrt_c_api_client_test.cc index 033dbeb130fc8..8749f0778c85c 100644 --- a/xla/pjrt/pjrt_c_api_client_test.cc +++ b/xla/pjrt/pjrt_c_api_client_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -141,6 +142,19 @@ TEST(PjRtCApiClientTest, NonEmptyExecutableFingerprint) { } } +TEST(PjRtCApiClientTest, CreateBuffersForAsyncHostToDeviceWithShape) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::PrimitiveType::F32, /*dimensions=*/{2, 2, 2}, + /*minor_to_major=*/{1, 0, 2}); + std::vector host_shapes = {host_shape}; + auto status_or_transfer_manager = client->CreateBuffersForAsyncHostToDevice( + absl::MakeSpan(host_shapes), client->addressable_devices()[0]); + EXPECT_FALSE(status_or_transfer_manager.ok()); +} + TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { SetUpCpuPjRtApi(); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client,