diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD
index 02eeea836544b..7809b63c747f7 100644
--- a/xla/hlo/transforms/BUILD
+++ b/xla/hlo/transforms/BUILD
@@ -1809,11 +1809,14 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
         "//xla/service:host_memory_offload_annotations_hdr",
+        "//xla/tsl/platform:errors",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:string_view",
-        "@tsl//tsl/platform:errors",
+        "@tsl//tsl/platform:statusor",
     ],
 )
 
@@ -1827,9 +1830,10 @@ xla_cc_test(
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/testlib:verified_hlo_module",
         "//xla/service:host_memory_offload_annotations_hdr",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest",
-        "@tsl//tsl/platform:statusor",
         "@tsl//tsl/platform:test_main",
     ],
 )
diff --git a/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc b/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc
index 6846a186c7e69..3c7c89a54cabc 100644
--- a/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc
+++ b/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc
@@ -17,17 +17,100 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/host_memory_offload_annotations.h"
 #include "xla/side_effect_util.h"
+#include "xla/tsl/platform/errors.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
+namespace {
+absl::StatusOr<absl::string_view> GetCustomCallTarget(
+    absl::string_view external_annotation) {
+  if (external_annotation ==
+          host_memory_offload_annotations::kMemoryTargetPinnedHost ||
+      external_annotation ==
+          host_memory_offload_annotations::kMemoryTargetUnpinnedHost) {
+    return host_memory_offload_annotations::kMoveToHostCustomCallTarget;
+  }
+  if (external_annotation ==
+      host_memory_offload_annotations::kMemoryTargetDevice) {
+    return host_memory_offload_annotations::kMoveToDeviceCustomCallTarget;
+  }
+  if (external_annotation ==
+      host_memory_offload_annotations::kMemoryTargetDeviceSram) {
+    return host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget;
+  }
+  return absl::InvalidArgumentError(
+      absl::StrCat("Invalid external annotation: ", external_annotation));
+}
+
+absl::StatusOr<bool>
+ConvertCustomCallWithExternalAnnotationToInternalAnnotation(
+    HloComputation* c, HloInstruction* instruction) {
+  const auto& frontend_attributes = instruction->frontend_attributes();
+  const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr);
+  if (it == frontend_attributes.map().end()) {
+    return false;
+  }
+  // XLA currently does not differentiate between pinned and unpinned host
+  // memory.
+  const bool is_to_host_case =
+      (it->second == host_memory_offload_annotations::kMemoryTargetPinnedHost ||
+       it->second ==
+           host_memory_offload_annotations::kMemoryTargetUnpinnedHost);
+  const bool is_to_device_case =
+      (it->second == host_memory_offload_annotations::kMemoryTargetDevice ||
+       it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram);
+  if (!is_to_host_case && !is_to_device_case) {
+    return false;
+  }
+  const absl::StatusOr<absl::string_view> custom_call_target =
+      GetCustomCallTarget(it->second);
+  TF_RETURN_IF_ERROR(custom_call_target.status());
+  if (is_to_host_case) {
+    VLOG(1) << "Process forward case: " << instruction->ToString();
+    if (instruction->operand_count() != 1) {
+      return Internal(
+          "Custom calls with target %s must have exactly one operand. %s "
+          "has %d.",
+          host_memory_offload_annotations::kDevicePlacement,
+          instruction->name(), instruction->operand_count());
+    }
+    HloInstruction* input = instruction->mutable_operand(0);
+    HloInstruction* move_to_host_custom_call =
+        c->AddInstruction(HloInstruction::CreateCustomCall(
+            input->shape(), {input}, *custom_call_target));
+    if (instruction->has_sharding()) {
+      move_to_host_custom_call->set_sharding(instruction->sharding());
+    }
+    TF_RETURN_IF_ERROR(
+        instruction->ReplaceAllUsesWith(move_to_host_custom_call));
+    TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction));
+    return true;
+  } else if (is_to_device_case) {
+    VLOG(1) << "Process backward case: " << instruction->ToString();
+    HloInstruction* custom_call_operand = instruction->mutable_operand(0);
+    HloInstruction* new_result =
+        c->AddInstruction(HloInstruction::CreateCustomCall(
+            custom_call_operand->shape(), {custom_call_operand},
+            *custom_call_target));
+    TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result));
+    TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction));
+    return true;
+  }
+  return false;
+}
+
+}  // namespace
+
 absl::StatusOr<bool> ConvertMemoryPlacementToInternalAnnotations::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
@@ -36,60 +119,11 @@ absl::StatusOr<bool> ConvertMemoryPlacementToInternalAnnotations::Run(
     for (HloInstruction* instruction : c->MakeInstructionPostOrder()) {
       if (instruction->IsCustomCall(
               host_memory_offload_annotations::kDevicePlacement)) {
-        const auto& frontend_attributes = instruction->frontend_attributes();
-        const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr);
-        if (it == frontend_attributes.map().end()) {
-          continue;
-        }
-        // XLA currently does not differentiate between pinned and unpinned host
-        // memory.
-        const bool is_to_host_case =
-            (it->second ==
-                 host_memory_offload_annotations::kMemoryTargetPinnedHost ||
-             it->second ==
-                 host_memory_offload_annotations::kMemoryTargetUnpinnedHost);
-        const bool is_to_device_case =
-            (it->second ==
-             host_memory_offload_annotations::kMemoryTargetDevice);
-        if (!is_to_host_case && !is_to_device_case) {
-          continue;
-        }
-        if (is_to_host_case) {
-          VLOG(1) << "Process forward case: " << instruction->ToString();
-          if (instruction->operand_count() != 1) {
-            return Internal(
-                "Custom calls with target %s must have exactly one operand. %s "
%s " - "has %d.", - host_memory_offload_annotations::kDevicePlacement, - instruction->name(), instruction->operand_count()); - } - HloInstruction* input = instruction->mutable_operand(0); - HloInstruction* move_to_host_custom_call = - c->AddInstruction(HloInstruction::CreateCustomCall( - input->shape(), {input}, - host_memory_offload_annotations:: - kMoveToHostCustomCallTarget)); - if (instruction->has_sharding()) { - move_to_host_custom_call->set_sharding(instruction->sharding()); - } - TF_RETURN_IF_ERROR( - instruction->ReplaceAllUsesWith(move_to_host_custom_call)); - TF_RETURN_IF_ERROR( - c->RemoveInstructionAndUnusedOperands(instruction)); - changed = true; - } else if (is_to_device_case) { - VLOG(1) << "Process backward case: " << instruction->ToString(); - HloInstruction* custom_call_operand = instruction->mutable_operand(0); - HloInstruction* new_result = - c->AddInstruction(HloInstruction::CreateCustomCall( - custom_call_operand->shape(), {custom_call_operand}, - host_memory_offload_annotations:: - kMoveToDeviceCustomCallTarget)); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); - TF_RETURN_IF_ERROR( - c->RemoveInstructionAndUnusedOperands(instruction)); - changed = true; - } + TF_ASSIGN_OR_RETURN( + auto result, + ConvertCustomCallWithExternalAnnotationToInternalAnnotation( + c, instruction)); + changed |= result; } } } diff --git a/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc b/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc index d7746a4d97142..db122ae9db5ed 100644 --- a/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc +++ b/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc @@ -20,12 +20,13 @@ #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -509,5 +510,35 @@ TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, EXPECT_EQ(move_to_host_count, 1); } +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertPinToDeviceSramTest) { + constexpr absl::string_view hlo_string = R"( + HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true} + + ENTRY main.8 { + Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"} + constant.2 = s32[] constant(2) + broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={} + multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707} + custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="device_sram"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + ROOT multiply.7 = s32[8,2]{1,0} 
+  } // main.8 )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  bool changed =
+      ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value();
+  EXPECT_TRUE(changed);
+  XLA_VLOG_LINES(1, module->ToString());
+  int64_t pin_to_device_sram_count = 0;
+  for (auto* c : module->computations()) {
+    for (auto* instr : c->instructions()) {
+      pin_to_device_sram_count += instr->IsCustomCall(
+          host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget);
+    }
+  }
+  EXPECT_EQ(pin_to_device_sram_count, 1);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/xla/service/host_memory_offload_annotations.h b/xla/service/host_memory_offload_annotations.h
index a0b7e3decaea3..42cde9221f5aa 100644
--- a/xla/service/host_memory_offload_annotations.h
+++ b/xla/service/host_memory_offload_annotations.h
@@ -26,10 +26,13 @@ inline const absl::string_view kDevicePlacement = "annotate_device_placement";
 inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host";
 inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host";
 inline const absl::string_view kMemoryTargetDevice = "device";
+inline const absl::string_view kMemoryTargetDeviceSram = "device_sram";
 
 // Internal annotations:
 inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost";
 inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice";
+inline const absl::string_view kPinToDeviceSramCustomCallTarget =
+    "PinToDeviceSram";
 
 }  // namespace host_memory_offload_annotations
 }  // namespace xla
diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD
index e578b83ff46e1..14d50dd1bd659 100644
--- a/xla/service/memory_space_assignment/BUILD
+++ b/xla/service/memory_space_assignment/BUILD
@@ -564,6 +564,7 @@ cc_library(
         "//xla/hlo/utils:hlo_live_range",
         "//xla/service:buffer_value",
         "//xla/service:call_graph",
+        "//xla/service:computation_layout",
         "//xla/service:hlo_buffer",
         "//xla/service:hlo_proto_cc",
         "//xla/service:hlo_value",
diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc
index 1ca59a0364f0f..db75f8f481ad9 100644
--- a/xla/service/memory_space_assignment/algorithm.cc
+++ b/xla/service/memory_space_assignment/algorithm.cc
@@ -56,6 +56,7 @@ limitations under the License.
#include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" +#include "xla/service/computation_layout.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_buffer.h" @@ -266,34 +267,78 @@ bool IsCrossProgramPrefetchCandidate(const HloValue& value, }); } -struct CrossProgramPrefetchBufferSortValues { - int64_t latest_use = 0; - int64_t use_size = 0; +bool IsUserAnnotatedCrossProgramPrefetch(const HloValue& value, + const Options& options) { + const HloInstruction* defining_instruction = value.defining_instruction(); + if (defining_instruction->parent() != + defining_instruction->GetModule()->entry_computation() || + defining_instruction->opcode() != HloOpcode::kParameter) { + return false; + } + const ComputationLayout& entry_computation_layout = + defining_instruction->GetModule()->entry_computation_layout(); + if (defining_instruction->parameter_number() >= + entry_computation_layout.parameter_count()) { + return false; + } + const Shape& shape = + entry_computation_layout + .parameter_layout(defining_instruction->parameter_number()) + .shape(); + return shape.has_layout() && + shape.layout().memory_space() == options.alternate_memory_space; +} + +MsaBufferInterval CreateMsaBufferInterval(const HloBuffer& buffer, + const HloValue* value, + const HloLiveRange& hlo_live_range, + const Options& options) { + MsaBufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + interval.need_allocation = true; + return interval; +} + +struct CrossProgramPrefetches { + std::vector prefetches; + std::vector candidates; }; -std::vector FindCrossProgramPrefetchCandidates( +CrossProgramPrefetches FindCrossProgramPrefetches( const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, const Options& options) { - std::vector candidates; + CrossProgramPrefetches cross_program_prefetches; for (const HloBuffer& buffer : alias_analysis.buffers()) { CHECK_GE(buffer.values().size(), 1); const HloValue* value = buffer.values().at(0); - MsaBufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; - if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) { - candidates.push_back(interval); + MsaBufferInterval buffer_interval = + CreateMsaBufferInterval(buffer, value, hlo_live_range, options); + if (IsUserAnnotatedCrossProgramPrefetch(*value, options)) { + cross_program_prefetches.prefetches.push_back(buffer_interval); + } else if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, + options)) { + cross_program_prefetches.candidates.push_back(buffer_interval); } else if (MemorySpaceAssignmentUtils:: DoesCrossProgramPrefetchBufferMatchAnyFilter( - options.msa_sort_order_overrides, interval)) { - candidates.push_back(interval); + options.msa_sort_order_overrides, buffer_interval)) { + cross_program_prefetches.candidates.push_back(buffer_interval); } } + for (auto& prefetch : cross_program_prefetches.prefetches) { + VLOG(3) << "User annotated cross-program prefetch: " + << prefetch.buffer->ToString(); + } + + for (auto& prefetch : 
   DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator(
       hlo_live_range, options.msa_sort_order_overrides);
   BufferIntervalComparator* comparator =
@@ -301,16 +346,18 @@ std::vector<MsaBufferInterval> FindCrossProgramPrefetchCandidates(
       options.buffer_interval_comparator ? options.buffer_interval_comparator
                                          : &default_comparator);
-  absl::c_sort(candidates, comparator->GetComparisonFunctor());
+  absl::c_sort(cross_program_prefetches.candidates,
+               comparator->GetComparisonFunctor());
 
-  VLOG(3) << "Cross-program prefetch candidates: " << candidates.size()
+  VLOG(3) << "Cross-program prefetch candidates: "
+          << cross_program_prefetches.candidates.size()
           << ". Sorting criteria: " << comparator->DescribeComparisonCriteria();
-  for (auto& candidate : candidates) {
+  for (auto& candidate : cross_program_prefetches.candidates) {
     VLOG(3) << "Cross-program prefetch candidate. Sorting criteria: "
             << comparator->CriteriaToString(candidate)
             << ". Candidate: " << candidate.buffer->ToString();
   }
-  return candidates;
+  return cross_program_prefetches;
 }
 
 }  // namespace
@@ -1638,11 +1685,27 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
   }
 
   VLOG(1) << "Memory pressure = " << memory_pressure_;
+  CrossProgramPrefetches cross_program_prefetches =
+      FindCrossProgramPrefetches(alias_analysis_, hlo_live_range_, options_);
+  // Crash if cross program prefetch is disabled and user has requested
+  // cross program prefetch.
+  CHECK(options_.enable_cross_program_prefetch ||
+        cross_program_prefetches.prefetches.empty())
+      << "Cross program prefetch is disabled but user has requested cross "
+         "program prefetch.";
+  // Crash if number of user requested cross program prefetches is greater than
+  // the maximum number of cross program prefetches allowed.
+  CHECK(cross_program_prefetches.prefetches.size() <=
+        options().max_cross_program_prefetches)
+      << "Number of user requested cross program prefetches is greater than "
+         "the maximum number of cross program prefetches allowed.";
+  // Allocate user requested cross program prefetches first.
+  for (auto& prefetch : cross_program_prefetches.prefetches) {
+    HloModule* module = prefetch.buffer->instruction()->GetModule();
+    AllocateCrossProgramPrefetchBuffer(module, prefetch);
+  }
   if (options_.enable_cross_program_prefetch) {
-    std::vector<MsaBufferInterval> prefetch_candidates =
-        FindCrossProgramPrefetchCandidates(alias_analysis_, hlo_live_range_,
-                                           options_);
-    for (auto& prefetch_candidate : prefetch_candidates) {
+    for (auto& prefetch_candidate : cross_program_prefetches.candidates) {
       HloModule* module =
           prefetch_candidate.buffer->instruction()->GetModule();
       if (0 <= options().max_cross_program_prefetches &&
           options().max_cross_program_prefetches <=
@@ -3247,6 +3310,9 @@ void SetDefaultMemorySpace(const HloValue* value, const Options& options) {
     }
     shape->mutable_layout()->set_memory_space(options.default_memory_space);
   }
+  HloModule* module = value->defining_instruction()->GetModule();
+  module->mutable_config().SetComputationLayoutIfExists(
+      module->entry_computation()->ComputeProgramShape());
 }
 
 }  // namespace
diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 6ec82accd7f6b..8f80b978757e0 100644
--- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -5493,6 +5493,11 @@ TEST_F(MemorySpaceAssignmentTest,
       /*minor_to_major=*/{1, 0}, /*tiles=*/{},
       /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
       kAlternateMemorySpace);
+  Shape shape_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout(
+      F32, {2, 3},
+      /*minor_to_major=*/{1, 0}, /*tiles=*/{},
+      /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
+      kDefaultMemorySpace);
   // p0 is in the default memory space.
   HloInstruction* p0 =
       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
@@ -5533,13 +5538,14 @@ TEST_F(MemorySpaceAssignmentTest,
   options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
     return true;
   };
+  XLA_VLOG_LINES(3, module->ToString());
   std::unique_ptr<PresetAssignments> preset_assignments =
       AssignMemorySpace(module.get(), options);
-
+  XLA_VLOG_LINES(3, module->ToString());
   // Ensure that p1 is in the alternate memory and add, which has p1 as an
   // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
-  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
+  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_default_mem));
+  EXPECT_THAT(add, op::Add(op::Negate(), op::CopyDone()));
   // Make sure add is still in the alternate memory space.
   EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
 
@@ -5548,6 +5554,7 @@ TEST_F(MemorySpaceAssignmentTest,
   // alternate memory space are left to BufferAssignment to be allocated.
   for (const auto& position_and_chunk : preset_assignments->chunks()) {
     const HloPosition& position = position_and_chunk.first;
+    XLA_VLOG_LINES(3, position.instruction->ToString());
     EXPECT_NE(position.instruction, p1);
     EXPECT_NE(position.instruction, add);
   }
@@ -10129,8 +10136,10 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
 }
 
 TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) {
-  // This test is for checking if the cross-program-prefetched buffer is freed
-  // after its last use and there is an end-of-program prefetch.
+  // This test is the same as above, but with an override to cross-program
+  // prefetch parameter0 as opposed to p0, and limiting the max alternate
+  // memory size to 256 bytes so that both p0 and p1 cannot be assigned to
+  // alternate memory and priority is given to p0.
   absl::string_view hlo_string = R"(
   HloModule cross_program_prefetch, is_scheduled=true
 
@@ -10218,6 +10227,203 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) {
   EXPECT_TRUE(has_zero_offset_allocations);
 }
 
+TEST_F(MemorySpaceAssignmentTest, UserAnnotatedCrossProgramPrefetchNoReuse) {
+  // This test is the same as above, but with a user directive to cross-program
+  // prefetch parameter0 as opposed to p0, and limiting the max alternate
+  // memory size to 256 bytes so that both p0 and p1 cannot be assigned to
+  // alternate memory and priority is given to p0.
+  absl::string_view hlo_string = R"(
+  HloModule cross_program_prefetch, is_scheduled=true, entry_computation_layout={(f32[8,8]{1,0:S(1)}, f32[8,2]{1,0})->f32[8,2]{1,0}}
+
+  ENTRY CrossProgramPrefetch {
+    p0 = f32[8,8]{1,0:S(1)} parameter(0)
+    p1 = f32[8,2]{1,0} parameter(1)
+    dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    negate.1 = f32[8,2]{1,0} negate(dot)
+    negate.2 = f32[8,2]{1,0} negate(negate.1)
+    negate.3 = f32[8,2]{1,0} negate(negate.2)
+    negate.4 = f32[8,2]{1,0} negate(negate.3)
+    negate.5 = f32[8,2]{1,0} negate(negate.4)
+    negate.6 = f32[8,2]{1,0} negate(negate.5)
+    negate.7 = f32[8,2]{1,0} negate(negate.6)
+    negate.8 = f32[8,2]{1,0} negate(negate.7)
+    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  auto options = DefaultMemorySpaceOptions();
+  options.max_size_in_bytes = 256;
+  auto preset_assignments = AssignMemorySpace(module.get(), options,
+                                              /*max_prefetch_interval=*/5,
+                                              /*min_prefetch_interval=*/2);
+
+  auto cross_program_prefetches = module->CrossProgramPrefetches();
+  EXPECT_EQ(cross_program_prefetches.size(), 1);
+  EXPECT_EQ(cross_program_prefetches[0].parameter, 0);
+  EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({}));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
+      HloDataflowAnalysis::Run(*module));
+  LOG(ERROR) << "module: " << module->ToString();
+  const HloValue& cross_program_prefetched_value =
+      dataflow_analysis->GetValueDefinedAt(
+          module->entry_computation()->parameter_instruction(0), {});
+  // Expect that there are two prefetches that use this value, one is the
+  // cross-program prefetch, the other is the end-of-program prefetch.
+  auto is_cross_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           use.instruction->cross_program_prefetch_index().has_value();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
+                             is_cross_program_prefetch),
+            1);
+  auto is_end_of_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           !use.instruction->cross_program_prefetch_index().has_value();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
+                             is_end_of_program_prefetch),
+            1);
+  // Also verify that the copy-done for the end-of-program prefetch is the last
+  // instruction in schedule.
+  const HloInstruction* last_instruction =
+      module->schedule()
+          .sequence(module->entry_computation())
+          .instructions()[module->entry_computation()->instruction_count() - 1];
+  EXPECT_THAT(last_instruction, op::CopyDone());
+  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
+  // Cross program prefetch would use offset 0 because that's the first
+  // assignment. Since we are freeing the cross-program prefetch buffer, we
+  // would also expect to see some of the intermediate computations (one of the
+  // negate ops) to also get 0 offset allocations.
+  bool has_zero_offset_allocations = false;
+  for (auto pos_and_chunk : preset_assignments->chunks()) {
+    if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
+        pos_and_chunk.second.offset == 0) {
+      has_zero_offset_allocations = true;
+    }
+  }
+  EXPECT_TRUE(has_zero_offset_allocations);
+  XLA_VLOG_LINES(3, module->ToString());
+  bool found = false;
+  for (auto* c : module->computations()) {
+    for (auto* instr : c->instructions()) {
+      if (instr->name() == "p0") {
+        found = true;
+        EXPECT_EQ(instr->shape().layout().memory_space(), 0);
+        EXPECT_EQ(module->entry_computation_layout()
+                      .parameter_layout(0)
+                      .shape()
+                      .layout()
+                      .memory_space(),
+                  0);
+      }
+    }
+  }
+  EXPECT_TRUE(found);
+}
+
+TEST_F(MemorySpaceAssignmentTest,
+       UserAnnotatedCrossProgramPrefetchWithoutPropagationToParameterNoReuse) {
+  // This test is the same as above, but the S(1) memory space specified in the
+  // layout to cross-program prefetch p0 is only present in the entry
+  // computation layout and has not been propagated to the parameter
+  // instruction. This still works the same as the previous test.
+  absl::string_view hlo_string = R"(
+  HloModule cross_program_prefetch, is_scheduled=true, entry_computation_layout={(f32[8,8]{1,0:S(1)}, f32[8,2]{1,0})->f32[8,2]{1,0}}
+
+  ENTRY CrossProgramPrefetch {
+    p0 = f32[8,8]{1,0} parameter(0)
+    p1 = f32[8,2]{1,0} parameter(1)
+    dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    negate.1 = f32[8,2]{1,0} negate(dot)
+    negate.2 = f32[8,2]{1,0} negate(negate.1)
+    negate.3 = f32[8,2]{1,0} negate(negate.2)
+    negate.4 = f32[8,2]{1,0} negate(negate.3)
+    negate.5 = f32[8,2]{1,0} negate(negate.4)
+    negate.6 = f32[8,2]{1,0} negate(negate.5)
+    negate.7 = f32[8,2]{1,0} negate(negate.6)
+    negate.8 = f32[8,2]{1,0} negate(negate.7)
+    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  auto options = DefaultMemorySpaceOptions();
+  options.max_size_in_bytes = 256;
+  auto preset_assignments = AssignMemorySpace(module.get(), options,
+                                              /*max_prefetch_interval=*/5,
+                                              /*min_prefetch_interval=*/2);
+
+  auto cross_program_prefetches = module->CrossProgramPrefetches();
+  EXPECT_EQ(cross_program_prefetches.size(), 1);
+  EXPECT_EQ(cross_program_prefetches[0].parameter, 0);
+  EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({}));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
+      HloDataflowAnalysis::Run(*module));
+  LOG(ERROR) << "module: " << module->ToString();
+  const HloValue& cross_program_prefetched_value =
+      dataflow_analysis->GetValueDefinedAt(
+          module->entry_computation()->parameter_instruction(0), {});
+  // Expect that there are two prefetches that use this value, one is the
+  // cross-program prefetch, the other is the end-of-program prefetch.
+  auto is_cross_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           use.instruction->cross_program_prefetch_index().has_value();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
+                             is_cross_program_prefetch),
+            1);
+  auto is_end_of_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           !use.instruction->cross_program_prefetch_index().has_value();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
+                             is_end_of_program_prefetch),
+            1);
+  // Also verify that the copy-done for the end-of-program prefetch is the last
+  // instruction in schedule.
+  const HloInstruction* last_instruction =
+      module->schedule()
+          .sequence(module->entry_computation())
+          .instructions()[module->entry_computation()->instruction_count() - 1];
+  EXPECT_THAT(last_instruction, op::CopyDone());
+  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
+  // Cross program prefetch would use offset 0 because that's the first
+  // assignment. Since we are freeing the cross-program prefetch buffer, we
+  // would also expect to see some of the intermediate computations (one of the
+  // negate ops) to also get 0 offset allocations.
+  bool has_zero_offset_allocations = false;
+  for (auto pos_and_chunk : preset_assignments->chunks()) {
+    if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
+        pos_and_chunk.second.offset == 0) {
+      has_zero_offset_allocations = true;
+    }
+  }
+  EXPECT_TRUE(has_zero_offset_allocations);
+  XLA_VLOG_LINES(3, module->ToString());
+  bool found = false;
+  for (auto* c : module->computations()) {
+    for (auto* instr : c->instructions()) {
+      if (instr->name() == "p0") {
+        found = true;
+        EXPECT_EQ(instr->shape().layout().memory_space(), 0);
+        EXPECT_EQ(module->entry_computation_layout()
+                      .parameter_layout(0)
+                      .shape()
+                      .layout()
+                      .memory_space(),
+                  0);
+      }
+    }
+  }
+  EXPECT_TRUE(found);
+}
+
 TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) {
   // This test is for checking if the cross-program-prefetched buffer is freed
   // after its last use and there is an end-of-program prefetch.