Skip to content

Commit

Permalink
[XLA:TPU] Add support for pinning tensors to device sram via custom calls.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 690686233
  • Loading branch information
subhankarshah authored and Google-ML-Automation committed Dec 21, 2024
1 parent d29d8ea commit 0d642da
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 87 deletions.
8 changes: 6 additions & 2 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1759,11 +1759,14 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:host_memory_offload_annotations_hdr",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -1777,9 +1780,10 @@ xla_cc_test(
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:host_memory_offload_annotations_hdr",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
)
Expand Down
144 changes: 89 additions & 55 deletions xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,100 @@

#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/side_effect_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {

namespace {
// Maps an external memory-placement annotation (the value of the
// kXlaBufferPlacementAttr frontend attribute) to the internal custom-call
// target that implements it. Returns InvalidArgumentError for any annotation
// this pass does not recognize.
absl::StatusOr<absl::string_view> GetCustomCallTarget(
    absl::string_view external_annotation) {
  namespace annotations = host_memory_offload_annotations;
  // XLA currently does not differentiate between pinned and unpinned host
  // memory, so both map to the same MoveToHost target.
  const bool is_host_annotation =
      external_annotation == annotations::kMemoryTargetPinnedHost ||
      external_annotation == annotations::kMemoryTargetUnpinnedHost;
  if (is_host_annotation) {
    return annotations::kMoveToHostCustomCallTarget;
  }
  if (external_annotation == annotations::kMemoryTargetDevice) {
    return annotations::kMoveToDeviceCustomCallTarget;
  }
  if (external_annotation == annotations::kMemoryTargetDeviceSram) {
    return annotations::kPinToDeviceSramCustomCallTarget;
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Invalid external annotation: ", external_annotation));
}

// Rewrites a single "annotate_device_placement" custom call that carries an
// external memory-placement annotation in its frontend attributes into the
// corresponding internal custom call (MoveToHost / MoveToDevice /
// PinToDeviceSram), replacing all uses of `instruction` and removing it.
//
// Returns true if the instruction was rewritten, false if it carried no (or an
// unrelated) placement annotation. Propagates errors from the rewrite, e.g. a
// host-placement custom call with an operand count other than one.
absl::StatusOr<bool>
ConvertCustomCallWithExternalAnnotationToInternalAnnotation(
    HloComputation* c, HloInstruction* instruction) {
  const auto& frontend_attributes = instruction->frontend_attributes();
  const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr);
  if (it == frontend_attributes.map().end()) {
    // No placement annotation on this custom call; nothing to rewrite.
    return false;
  }
  // XLA currently does not differentiate between pinned and unpinned host
  // memory.
  const bool is_to_host_case =
      (it->second == host_memory_offload_annotations::kMemoryTargetPinnedHost ||
       it->second ==
           host_memory_offload_annotations::kMemoryTargetUnpinnedHost);
  const bool is_to_device_case =
      (it->second == host_memory_offload_annotations::kMemoryTargetDevice ||
       it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram);
  if (!is_to_host_case && !is_to_device_case) {
    return false;
  }
  // Idiomatic TF_ASSIGN_OR_RETURN instead of the manual StatusOr local plus
  // TF_RETURN_IF_ERROR(x.status()) pair.
  TF_ASSIGN_OR_RETURN(const absl::string_view custom_call_target,
                      GetCustomCallTarget(it->second));
  if (is_to_host_case) {
    VLOG(1) << "Process forward case: " << instruction->ToString();
    if (instruction->operand_count() != 1) {
      return Internal(
          "Custom calls with target %s must have exactly one operand. %s "
          "has %d.",
          host_memory_offload_annotations::kDevicePlacement,
          instruction->name(), instruction->operand_count());
    }
    HloInstruction* input = instruction->mutable_operand(0);
    HloInstruction* move_to_host_custom_call =
        c->AddInstruction(HloInstruction::CreateCustomCall(
            input->shape(), {input}, custom_call_target));
    // Preserve the sharding of the annotation on the internal custom call.
    if (instruction->has_sharding()) {
      move_to_host_custom_call->set_sharding(instruction->sharding());
    }
    TF_RETURN_IF_ERROR(
        instruction->ReplaceAllUsesWith(move_to_host_custom_call));
    TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction));
    return true;
  }
  // is_to_device_case holds here: exactly one of the two flags is set once the
  // early-return above has fired, so the original trailing `return false` was
  // unreachable and has been dropped.
  VLOG(1) << "Process backward case: " << instruction->ToString();
  HloInstruction* custom_call_operand = instruction->mutable_operand(0);
  HloInstruction* new_result =
      c->AddInstruction(HloInstruction::CreateCustomCall(
          custom_call_operand->shape(), {custom_call_operand},
          custom_call_target));
  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result));
  TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction));
  return true;
}

} // namespace

// Pass entry point: scans every computation for "annotate_device_placement"
// custom calls and rewrites each into its internal annotation via
// ConvertCustomCallWithExternalAnnotationToInternalAnnotation.
//
// NOTE(review): this span in the scraped diff interleaved the deleted
// pre-commit body with the new code; reconstructed here as the coherent
// post-commit function. The computation iteration source was hidden behind a
// collapsed diff hunk — confirm it matches upstream.
absl::StatusOr<bool> ConvertMemoryPlacementToInternalAnnotations::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* c : module->computations()) {
    for (HloInstruction* instruction : c->MakeInstructionPostOrder()) {
      if (instruction->IsCustomCall(
              host_memory_offload_annotations::kDevicePlacement)) {
        TF_ASSIGN_OR_RETURN(
            auto result,
            ConvertCustomCallWithExternalAnnotationToInternalAnnotation(
                c, instruction));
        changed |= result;
      }
    }
  }
  return changed;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -509,5 +510,35 @@ TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest,
EXPECT_EQ(move_to_host_count, 1);
}

TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest,
       ConvertPinToDeviceSramTest) {
  // An "annotate_device_placement" custom call carrying
  // _xla_buffer_placement="device_sram" must be rewritten into exactly one
  // internal PinToDeviceSram custom call.
  constexpr absl::string_view hlo_string = R"(
HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY main.8 {
Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"}
constant.2 = s32[] constant(2)
broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={}
multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707}
custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708}
custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="device_sram"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708}
ROOT multiply.7 = s32[8,2]{1,0} multiply(custom-call.6, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=709}
} // main.8 )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  bool changed =
      ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value();
  EXPECT_TRUE(changed);
  XLA_VLOG_LINES(1, module->ToString());
  // Renamed from the mangled `pin_todevice_sramcount` to follow snake_case.
  int64_t pin_to_device_sram_count = 0;
  for (auto* c : module->computations()) {
    for (auto* instr : c->instructions()) {
      pin_to_device_sram_count += instr->IsCustomCall(
          host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget);
    }
  }
  EXPECT_EQ(pin_to_device_sram_count, 1);
}

} // namespace
} // namespace xla
3 changes: 3 additions & 0 deletions xla/service/host_memory_offload_annotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ inline const absl::string_view kDevicePlacement = "annotate_device_placement";
inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host";
inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host";
inline const absl::string_view kMemoryTargetDevice = "device";
inline const absl::string_view kMemoryTargetDeviceSram = "device_sram";

// Internal annotations:
inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost";
inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice";
inline const absl::string_view kPinToDeviceSramCustomCallTarget =
"PinToDeviceSram";

} // namespace host_memory_offload_annotations
} // namespace xla
Expand Down
1 change: 1 addition & 0 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ cc_library(
"//xla/hlo/utils:hlo_live_range",
"//xla/service:buffer_value",
"//xla/service:call_graph",
"//xla/service:computation_layout",
"//xla/service:hlo_buffer",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_value",
Expand Down
Loading

0 comments on commit 0d642da

Please sign in to comment.