Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PJRT:C] Introduce CreateBuffersForAsyncHostToDevice and TransferRawDataToSubBuffer to PJRT C API. #20554

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ xla_cc_test(
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@stablehlo//:version",
Expand Down
1 change: 1 addition & 0 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# PJRT C API changelog
## 0.60
* Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice``, ``PJRT_AsyncHostToDeviceTransferManager_TransferData`` and ``PJRT_AsyncHostToDeviceTransferManager_Destroy``.

## 0.59
* Added ``PJRT_MemoryDescriptions_Extension``.
Expand Down
65 changes: 63 additions & 2 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 59
#define PJRT_API_MINOR 60

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -308,11 +308,14 @@ typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args);
typedef struct PJRT_Client PJRT_Client;
typedef struct PJRT_Device PJRT_Device;
typedef struct PJRT_Memory PJRT_Memory;
typedef struct PJRT_ShapeSpec PJRT_ShapeSpec;
typedef struct PJRT_DeviceDescription PJRT_DeviceDescription;
typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
typedef struct PJRT_Executable PJRT_Executable;
typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable;
typedef struct PJRT_Buffer PJRT_Buffer;
typedef struct PJRT_AsyncHostToDeviceTransferManager
PJRT_AsyncHostToDeviceTransferManager;

// The caller of PJRT_Client_Create can optionally provide a key-value store
// accessible across nodes and/or processes. KV store access may be necessary to
Expand Down Expand Up @@ -593,6 +596,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args,
typedef PJRT_Error* PJRT_Client_DefaultDeviceAssignment(
PJRT_Client_DefaultDeviceAssignment_Args* args);

// Arguments for PJRT_AsyncHostToDeviceTransferManager_Destroy.
struct PJRT_AsyncHostToDeviceTransferManager_Destroy_Args {
// Callers set this to
// PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE.
size_t struct_size;
// Head of the optional extension chain; nullptr when no extensions are used.
PJRT_Extension_Base* extension_start;
// The transfer manager to free. May be nullptr.
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Destroy_Args,
transfer_manager);

// Frees `transfer_manager`. `transfer_manager` can be nullptr.
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy(
PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args);

// Arguments for PJRT_AsyncHostToDeviceTransferManager_TransferData.
struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
// Callers set this to
// PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE.
size_t struct_size;
// Head of the optional extension chain; nullptr when no extensions are used.
PJRT_Extension_Base* extension_start;
// Transfer manager obtained from
// PJRT_Client_CreateBuffersForAsyncHostToDevice.
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
// Selects which of the manager's buffers receives the data.
// NOTE(review): presumably an index into the `shape_specs` array the manager
// was created with -- confirm against the implementation.
int buffer_index;
// Source data to transfer.
// NOTE(review): lifetime requirements for `data` are not stated here;
// presumably it must stay valid until `done_with_h2d_transfer` is ready --
// confirm.
const void* data;
// Position in the destination buffer at which the transfer starts.
// NOTE(review): presumably a byte offset -- confirm.
int64_t offset;
// Amount of data to transfer.
// NOTE(review): presumably a byte count, not an element count -- confirm.
int64_t transfer_size;
// Whether this is the final transfer.
// NOTE(review): unclear from this header whether "last" is scoped to
// `buffer_index` or to the whole manager -- confirm.
bool is_last_transfer;
PJRT_Event* done_with_h2d_transfer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args,
done_with_h2d_transfer);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args);

typedef enum {
// Invalid primitive type to serve as default.
PJRT_Buffer_Type_INVALID,
Expand Down Expand Up @@ -820,6 +852,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer);
typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer(
PJRT_Client_CreateViewOfDeviceBuffer_Args* args);

// Describes the dimensions and element type of one buffer to be created by
// PJRT_Client_CreateBuffersForAsyncHostToDevice.
struct PJRT_ShapeSpec {
// Callers set this to PJRT_ShapeSpec_STRUCT_SIZE.
size_t struct_size;
// Head of the optional extension chain; nullptr when no extensions are used.
PJRT_Extension_Base* extension_start;
// Array of `num_dims` dimension sizes. Not owned by this struct.
const int64_t* dims;
// Number of entries in `dims`.
size_t num_dims;
// Element type of the buffer.
PJRT_Buffer_Type element_type;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_ShapeSpec, element_type);

// Arguments for PJRT_Client_CreateBuffersForAsyncHostToDevice.
struct PJRT_Client_CreateBuffersForAsyncHostToDevice_Args {
// Callers set this to
// PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE.
size_t struct_size;
// Head of the optional extension chain; nullptr when no extensions are used.
PJRT_Extension_Base* extension_start;
// Client on which the buffers are created.
PJRT_Client* client;
// Array of `num_shape_specs` shape descriptions.
// NOTE(review): presumably one buffer is created per entry -- confirm.
PJRT_ShapeSpec* shape_specs;
size_t num_shape_specs;
// Array of `num_device_layouts` layout pointers.
// NOTE(review): presumably parallel to `shape_specs` when provided; how
// "not provided" is signalled (nullptr / zero count) should be confirmed.
PJRT_Buffer_MemoryLayout** device_layouts; // optional
size_t num_device_layouts;
// Memory in which the buffers are created.
PJRT_Memory* memory;
// Resulting transfer manager. Release it with
// PJRT_AsyncHostToDeviceTransferManager_Destroy.
PJRT_AsyncHostToDeviceTransferManager* transfer_manager; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateBuffersForAsyncHostToDevice_Args,
transfer_manager);
typedef PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice(
PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args);

// -------------------------- Device Descriptions ------------------------------

// Device descriptions may be associated with an actual device
Expand Down Expand Up @@ -2266,10 +2323,14 @@ typedef struct PJRT_Api {
_PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Create);
_PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Destroy);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyRawToHost);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData);
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice);
} PJRT_Api;

enum {
PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_CopyRawToHost)
PJRT_Api_STRUCT_SIZE =
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice)
};

#undef _PJRT_API_STRUCT_FIELD
Expand Down
65 changes: 65 additions & 0 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,71 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) {
api_->PJRT_ExecuteContext_Destroy(&destroy_args);
}

// Exercises the async host-to-device path end to end through the PJRT_Api
// function table:
//   1. creates one F32[2,2,2] buffer (with an explicit device layout) in the
//      client's first addressable memory,
//   2. transfers the whole host array into it in a single, final transfer,
//   3. destroys the transfer manager.
//
// NOTE(review): the test name contains a typo ("Memoryt"); it is kept as-is so
// existing test filters keep matching.
TEST_F(PjrtCApiGpuTest, CreateBuffersWithMemorytForH2DAndTransfer) {
  xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, /*dimensions=*/{2, 2, 2}, /*minor_to_major=*/{1, 0, 2});
  std::vector<float> float_data = {1, 2, 3, 4, 5, 6, 7, 8};

  PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args;
  args.struct_size =
      PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.client = client_;

  // Fill in every field of the shape spec, including the struct header, so the
  // callee never reads indeterminate values.
  PJRT_ShapeSpec c_shape_spec;
  c_shape_spec.struct_size = PJRT_ShapeSpec_STRUCT_SIZE;
  c_shape_spec.extension_start = nullptr;
  c_shape_spec.element_type =
      pjrt::ConvertToPjRtBufferType(xla::PrimitiveType::F32);
  c_shape_spec.dims = host_shape.dimensions().data();
  c_shape_spec.num_dims = host_shape.dimensions().size();
  args.shape_specs = &c_shape_spec;
  args.num_shape_specs = 1;

  TF_ASSERT_OK_AND_ASSIGN(pjrt::BufferMemoryLayoutData c_layout_data,
                          ConvertToBufferMemoryLayoutData(host_shape.layout()));
  std::vector<PJRT_Buffer_MemoryLayout*> device_layout_list(1);
  device_layout_list[0] = &(c_layout_data.c_layout);
  args.device_layouts = device_layout_list.data();
  args.num_device_layouts = device_layout_list.size();

  // Use the first addressable memory of the client as the destination.
  PJRT_Client_AddressableMemories_Args memory_args;
  memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE;
  memory_args.extension_start = nullptr;
  memory_args.client = client_;

  PJRT_Error* memory_error =
      api_->PJRT_Client_AddressableMemories(&memory_args);
  ASSERT_EQ(memory_error, nullptr);
  ASSERT_NE(memory_args.addressable_memories, nullptr);
  ASSERT_GT(memory_args.num_addressable_memories, 0);
  args.memory = memory_args.addressable_memories[0];

  PJRT_Error* error =
      api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args);
  ASSERT_EQ(error, nullptr);

  PJRT_AsyncHostToDeviceTransferManager_TransferData_Args transfer_args;
  transfer_args.struct_size =
      PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
  transfer_args.extension_start = nullptr;
  transfer_args.transfer_manager = args.transfer_manager;
  transfer_args.buffer_index = 0;
  transfer_args.data = float_data.data();
  transfer_args.offset = 0;
  // The transfer size is a byte count, so scale the element count by the
  // element size; passing the bare element count would copy only a fraction
  // of the buffer.
  transfer_args.transfer_size = float_data.size() * sizeof(float);
  transfer_args.is_last_transfer = true;

  // Dispatch through the PJRT_Api function table (like every other call in
  // this test) so the table entry itself is what gets exercised.
  PJRT_Error* transfer_error =
      api_->PJRT_AsyncHostToDeviceTransferManager_TransferData(&transfer_args);
  ASSERT_EQ(transfer_error, nullptr);
  // Take ownership of the returned event so it is released on every exit path.
  std::unique_ptr<PJRT_Event, PJRT_EventDeleter> done_with_h2d_transfer_event(
      transfer_args.done_with_h2d_transfer, MakeEventDeleter(api_));

  // Destroy the transfer manager.
  PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args;
  destroy_args.struct_size =
      PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE;
  destroy_args.extension_start = nullptr;
  destroy_args.transfer_manager = args.transfer_manager;
  LogFatalIfPjrtError(
      api_->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api_);
}

absl::StatusOr<PJRT_Client_Create_Args> BuildCreateArg(
::pjrt::PJRT_KeyValueCallbackData* kv_callback_data,
std::vector<PJRT_NamedValue>& c_options) {
Expand Down
37 changes: 37 additions & 0 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) {
};
}

// Builds a deleter that releases a PJRT_AsyncHostToDeviceTransferManager by
// calling PJRT_AsyncHostToDeviceTransferManager_Destroy on `api`. Any error
// returned by the destroy call is passed to LogFatalIfPjrtError. The returned
// callable captures `api` by pointer.
PJRT_AsyncHostToDeviceTransferManagerDeleter
MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api) {
  auto destroy_manager = [api](PJRT_AsyncHostToDeviceTransferManager* manager) {
    PJRT_AsyncHostToDeviceTransferManager_Destroy_Args args;
    args.transfer_manager = manager;
    args.extension_start = nullptr;
    args.struct_size =
        PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE;
    PJRT_Error* destroy_error =
        api->PJRT_AsyncHostToDeviceTransferManager_Destroy(&args);
    pjrt::LogFatalIfPjrtError(destroy_error, api);
  };
  return destroy_manager;
}

PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) {
return [api](PJRT_Error* error) -> void {
PJRT_Error_Destroy_Args destroy_args;
Expand Down Expand Up @@ -1064,4 +1078,27 @@ PJRT_Profiler_Extension CreatePjrtProfilerExtension(
return profiler_extension;
}

// Converts an xla::PjRtClient::ShapeSpec into its C API counterpart.
//
// The returned struct does not copy the dimensions: its `dims` pointer refers
// to the storage of `shape_spec.dims`, so `shape_spec` must stay alive (and
// unmodified) for as long as the result is used.
PJRT_ShapeSpec ConvertToPjRtShapeSpec(
    const xla::PjRtClient::ShapeSpec& shape_spec) {
  const auto& dimensions = shape_spec.dims;

  PJRT_ShapeSpec result;
  result.struct_size = PJRT_ShapeSpec_STRUCT_SIZE;
  result.extension_start = nullptr;
  result.dims = dimensions.data();
  result.num_dims = dimensions.size();
  result.element_type = pjrt::ConvertToPjRtBufferType(shape_spec.element_type);
  return result;
}

// Converts a C API PJRT_ShapeSpec into an xla::PjRtClient::ShapeSpec.
//
// The `num_dims` entries starting at `c_shape_spec.dims` are copied into the
// result, so the result does not depend on the C struct's dimension array
// after this call returns.
xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec(
    PJRT_ShapeSpec c_shape_spec) {
  const int64_t* const dims_begin = c_shape_spec.dims;
  const int64_t* const dims_end = dims_begin + c_shape_spec.num_dims;

  xla::PjRtClient::ShapeSpec result;
  result.dims = xla::DimensionVector(dims_begin, dims_end);
  result.element_type =
      pjrt::ConvertFromPjRtBufferType(c_shape_spec.element_type);
  return result;
}

} // namespace pjrt
14 changes: 14 additions & 0 deletions xla/pjrt/c/pjrt_c_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ using PJRT_ClientDeleter = std::function<void(PJRT_Client*)>;
// The lifetime of the Api pointed to must be longer than the client.
PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api);

using PJRT_AsyncHostToDeviceTransferManagerDeleter =
std::function<void(PJRT_AsyncHostToDeviceTransferManager*)>;

// Pass in an API pointer; receive a custom deleter for smart pointers.
// The lifetime of the Api pointed to must be longer than the transfer manager.
PJRT_AsyncHostToDeviceTransferManagerDeleter
MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api);

using PJRT_ErrorDeleter = std::function<void(PJRT_Error*)>;

// Pass in an API pointer; receive a custom deleter for smart pointers.
Expand Down Expand Up @@ -296,6 +304,12 @@ absl::Span<PJRT_DeviceDescription* const> DeviceDescriptions(
absl::StatusOr<xla::CompiledMemoryStats> GetCompiledMemoryStats(
const PJRT_Api* api, PJRT_Executable* executable);

// Converts `shape_spec` to the C API PJRT_ShapeSpec. The result's `dims`
// pointer refers to `shape_spec.dims` rather than to a copy, so `shape_spec`
// must outlive the returned struct.
PJRT_ShapeSpec ConvertToPjRtShapeSpec(
const xla::PjRtClient::ShapeSpec& shape_spec);

// Converts a C API PJRT_ShapeSpec to xla::PjRtClient::ShapeSpec, copying the
// dimension array into the result.
xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec(
PJRT_ShapeSpec c_shape_spec);

// Creates a PJRT_Profiler_Extension and adds a producer trace with
// the given name. The created PJRT_Profiler_Extension will be used in argument
// structs to pass the producer traceme context id to add a corresponding
Expand Down
17 changes: 17 additions & 0 deletions xla/pjrt/c/pjrt_c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,12 @@ FieldOffsetsAndSizesForVersion(int major_version, int minor_version) {
if (minor_version >= 57) {
add_field("PJRT_Buffer_CopyRawToHost", kFnPtrSize);
}
if (minor_version >= 58) {
add_field("PJRT_AsyncHostToDeviceTransferManager_Destroy", kFnPtrSize);
add_field("PJRT_AsyncHostToDeviceTransferManager_TransferData",
kFnPtrSize);
add_field("PJRT_Client_CreateBuffersForAsyncHostToDevice", kFnPtrSize);
}
return version_offsets_and_sizes;
}
LOG(FATAL) << "Unsupported API version: " << major_version << "."
Expand Down Expand Up @@ -1264,6 +1270,17 @@ TEST_F(PjrtCAbiTestBase, FieldOffsetsAndSizes) {
{"PJRT_Buffer_CopyRawToHost",
{offsetof(PJRT_Api, PJRT_Buffer_CopyRawToHost),
sizeof(PJRT_Api::PJRT_Buffer_CopyRawToHost)}},
{"PJRT_AsyncHostToDeviceTransferManager_Destroy",
{offsetof(PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_Destroy),
sizeof(PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_Destroy)}},
{"PJRT_AsyncHostToDeviceTransferManager_TransferData",
{offsetof(PJRT_Api,
PJRT_AsyncHostToDeviceTransferManager_TransferData),
sizeof(
PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_TransferData)}},
{"PJRT_Client_CreateBuffersForAsyncHostToDevice",
{offsetof(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice),
sizeof(PJRT_Api::PJRT_Client_CreateBuffersForAsyncHostToDevice)}},
};
ASSERT_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR);
ASSERT_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR);
Expand Down
Loading
Loading