diff --git a/xla/service/layout_assignment.cc b/xla/service/layout_assignment.cc
index eef57904b2f29..20b49e5c6f001 100644
--- a/xla/service/layout_assignment.cc
+++ b/xla/service/layout_assignment.cc
@@ -657,6 +657,20 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
   return absl::OkStatus();
 }
 
+absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) {
+  Shape shape = mutable_shape_layout.shape();
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
+      &shape, [](Shape* subshape, const ShapeIndex& shape_index) {
+        if (subshape->has_layout() && subshape->IsArray()) {
+          subshape->mutable_layout()->set_memory_space(
+              Layout::kDefaultMemorySpace);
+        }
+        return absl::OkStatus();
+      }));
+  TF_RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape));
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 absl::Status LayoutAssignment::AddMandatoryConstraints(
@@ -693,27 +707,18 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
          entry_computation_layout_->AnyLayoutSet()) ||
         (conditional_mismatch_.count(constraints->computation()) == 0 &&
          constraints->computation_constraint().parameter_layout_is_set())) {
-      const ShapeLayout& parameter_layout =
+      ShapeLayout parameter_layout =
           constraints->computation_layout().parameter_layout(
               instruction->parameter_number());
       // Allow some paramter/result layouts to be unset in the entry
       // computation.
       if (parameter_layout.AnyLayoutIsSet()) {
+        // Clear out memory space in layout. Host offloader will do the
+        // analysis later.
+        TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout));
         // Parameter layouts must match the respective layout in
         // ComputationLayout, if there is one.
         Shape param_shape = parameter_layout.shape();
-        // Clear out memory space in layout. Host offloader will do the
-        // analysis later.
-        TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-            &param_shape, [](Shape* subshape, const ShapeIndex& index) {
-              if (!subshape->has_layout() || !subshape->IsArray()) {
-                return absl::OkStatus();
-              }
-              subshape->mutable_layout()->set_memory_space(
-                  Layout::kDefaultMemorySpace);
-              return absl::OkStatus();
-            }));
-
         TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction));
         if (reverse_computation_order_) {
           TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
@@ -2033,16 +2038,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint(
   // Clear out memory space in layout for entry computation root. Host offloader
   // will do the analysis later and add back the memory space for host outputs.
   if (constraints->computation()->IsEntryComputation()) {
-    Shape result_shape = result_layout.shape();
-    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-        &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) {
-          if (subshape->has_layout() && subshape->IsArray()) {
-            subshape->mutable_layout()->set_memory_space(
-                Layout::kDefaultMemorySpace);
-          }
-          return absl::OkStatus();
-        }));
-    TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape));
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
   }
 
   // Propagate the use constraint of the root instruction up to the logical
@@ -2232,25 +2228,29 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
   // layout constraint.
   if (constraints.ResultLayout() != nullptr &&
       constraints.ResultLayout()->LayoutIsSet()) {
+    ShapeLayout result_layout = *constraints.ResultLayout();
+    // Clear out memory space in layout. Host offloader will do the
+    // analysis later.
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
     // Layout assignment at this point only does minor-to-major assignment so
     // tiling info should be ignored here for comparison.
     VLOG(5) << "Computation result layout needs root copying\n";
-    if (!constraints.ResultLayout()->MatchesLayoutInShape(
+    if (!result_layout.MatchesLayoutInShape(
             computation->root_instruction()->shape(),
             /*minor_to_major_only=*/true)) {
       TF_ASSIGN_OR_RETURN(
           HloInstruction * new_root,
-          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+          CreateCopyWithNewLayout(result_layout.shape(),
                                   computation->root_instruction()));
       computation->set_root_instruction(new_root);
     } else {
       // Copy the tiling info/tail_padding_alignment_in_elements specified in
       // result layout.
-      auto copy_tiling = [&constraints](xla::Shape* subshape,
-                                        const xla::ShapeIndex& index) {
+      auto copy_tiling = [&result_layout](xla::Shape* subshape,
+                                          const xla::ShapeIndex& index) {
         if (subshape->IsArray()) {
-          const Shape& result_shape = ShapeUtil::GetSubshape(
-              constraints.ResultLayout()->shape(), index);
+          const Shape& result_shape =
+              ShapeUtil::GetSubshape(result_layout.shape(), index);
           if (result_shape.layout().tiles_size() != 0) {
             subshape->mutable_layout()->mutable_tiles()->assign(
                 result_shape.layout().tiles().begin(),
diff --git a/xla/service/layout_assignment_test.cc b/xla/service/layout_assignment_test.cc
index 3cd4a872bff55..e8e9cb7685b04 100644
--- a/xla/service/layout_assignment_test.cc
+++ b/xla/service/layout_assignment_test.cc
@@ -1367,6 +1367,59 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0},
   ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
 }
 
+TEST_F(LayoutAssignmentTest, MemorySpaceRemoved) {
+  const char* module_str = R"(
+HloModule MixedHostDeviceResult
+
+ENTRY %MixedHostDeviceResult {
+  %p0 = f32[4,4] parameter(0)
+  %d = f32[4,4]{1,0} custom-call(%p0), custom_call_target="MoveToDevice", metadata={preserve_layout=true}
+  ROOT %tuple = (f32[4,4], f32[4,4]) tuple(%p0, %d)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> m,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = m->entry_computation_layout();
+
+  // Set the parameter to be in host memory.
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(
+          F32, {4, 4}, {1, 0}, /*tiles=*/{},
+          /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
+          Layout::kHostMemorySpace));
+  // Set one result component to be in host memory, the other one on device.
+  // Also make sure to request incompatible result layout so that the layout
+  // assignment pass has to copy the layout from the entry computation layout.
+  *computation_layout.mutable_result_layout() =
+      ShapeLayout(ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {1, 0}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kHostMemorySpace),
+           ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {0, 1}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kDefaultMemorySpace)}));
+  AssignLayouts(m.get(), &computation_layout);
+
+  // Verify that the memory space did not leak from the entry computation layout
+  // to the parameter or to the result.
+  Shape result_shape = m->entry_computation()->root_instruction()->shape();
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 0).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 1).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+
+  const HloInstruction* parameter = FindInstruction(m.get(), "p0");
+  EXPECT_EQ(parameter->shape().layout().memory_space(),
+            Layout::kDefaultMemorySpace);
+
+  ExpectTupleLayoutIs(result_shape, {{1, 0}, {0, 1}});
+}
+
 absl::Status AssignLayoutsToComputation(
     HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
   if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
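
Reviewer note: a minimal, hypothetical usage sketch of the new ResetMemorySpaceInLayout helper, assuming only the ShapeUtil/ShapeLayout/Layout APIs already used in this patch. The ResetMemorySpaceDemo wrapper is illustrative and not part of the change (the helper itself lives in an anonymous namespace in layout_assignment.cc):

// Hypothetical sketch: build a tuple layout whose array leaves start out in
// host memory, then reset them to the default memory space via the helper.
absl::Status ResetMemorySpaceDemo() {
  Shape host_shape = ShapeUtil::MakeShapeWithDenseLayout(
      F32, {4, 4}, {1, 0}, /*tiles=*/{},
      /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
      Layout::kHostMemorySpace);
  ShapeLayout layout(ShapeUtil::MakeTupleShape({host_shape, host_shape}));
  TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(layout));
  // Every array leaf is back in the default memory space; the minor-to-major
  // order {1, 0} is left untouched.
  CHECK_EQ(layout.shape().tuple_shapes(0).layout().memory_space(),
           Layout::kDefaultMemorySpace);
  return absl::OkStatus();
}

This is the sequence the three call sites above now share: parameter layouts in AddMandatoryConstraints and the entry result layout in PropagateResultConstraint and AssignLayouts each strip the memory space before any layout comparison, so host-memory annotations in the entry computation layout no longer leak into instruction layouts; per the comments in the patch, the host offloader re-adds memory spaces in its own analysis later.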