PR #20426: Layout assignment: Reset memory space in result layout
Imported from GitHub PR #20426

Layout assignment should not set any memory space on any instruction, even if the entry computation layout has a non-default memory space. In one place the memory space was leaking into instruction layouts (causing weight-offloading crashes on real models); this patch addresses that.

Drive-by: introduce a helper function to replace the copy-pasted implementations of resetting the memory space in a layout.
Copybara import of the project:

--
29bfdd8 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Reset memory space and result layout

Merging this change closes #20426

COPYBARA_INTEGRATE_REVIEW=#20426 from jaro-sevcik:scrub-memory-space-in-layout-assignment 29bfdd8
PiperOrigin-RevId: 707185192
jaro-sevcik authored and Google-ML-Automation committed Dec 17, 2024
1 parent 432da09 commit be5b1af
Showing 2 changed files with 82 additions and 29 deletions.
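
For context: the helper introduced by this change, ResetMemorySpaceInLayout, walks every array subshape of a ShapeLayout and rewrites its memory space back to Layout::kDefaultMemorySpace while leaving the minor-to-major order untouched. The following is a minimal illustrative sketch, not part of the commit; it assumes the code sits inside layout_assignment.cc (where the anonymous-namespace helper is visible) and inside a function that returns absl::Status.

// Build a 4x4 f32 layout pinned to host memory, the same way the new test
// constructs its parameter layout.
Shape host_shape = ShapeUtil::MakeShapeWithDenseLayout(
    F32, {4, 4}, {1, 0}, /*tiles=*/{},
    /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
    Layout::kHostMemorySpace);
ShapeLayout layout(host_shape);

// Scrub the memory space. Only the memory_space field changes; the
// minor-to-major order {1, 0} is preserved.
TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(layout));
CHECK_EQ(layout.shape().layout().memory_space(), Layout::kDefaultMemorySpace);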
xla/service/layout_assignment.cc (29 additions, 29 deletions)
@@ -657,6 +657,20 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
   return absl::OkStatus();
 }
 
+absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) {
+  Shape shape = mutable_shape_layout.shape();
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
+      &shape, [](Shape* subshape, const ShapeIndex& shape_index) {
+        if (subshape->has_layout() && subshape->IsArray()) {
+          subshape->mutable_layout()->set_memory_space(
+              Layout::kDefaultMemorySpace);
+        }
+        return absl::OkStatus();
+      }));
+  TF_RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape));
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 absl::Status LayoutAssignment::AddMandatoryConstraints(
@@ -693,27 +707,18 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
            entry_computation_layout_->AnyLayoutSet()) ||
           (conditional_mismatch_.count(constraints->computation()) == 0 &&
            constraints->computation_constraint().parameter_layout_is_set())) {
-        const ShapeLayout& parameter_layout =
+        ShapeLayout parameter_layout =
             constraints->computation_layout().parameter_layout(
                 instruction->parameter_number());
         // Allow some paramter/result layouts to be unset in the entry
         // computation.
         if (parameter_layout.AnyLayoutIsSet()) {
+          // Clear out memory space in layout. Host offloader will do the
+          // analysis later.
+          TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout));
           // Parameter layouts must match the respective layout in
           // ComputationLayout, if there is one.
           Shape param_shape = parameter_layout.shape();
-          // Clear out memory space in layout. Host offloader will do the
-          // analysis later.
-          TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-              &param_shape, [](Shape* subshape, const ShapeIndex& index) {
-                if (!subshape->has_layout() || !subshape->IsArray()) {
-                  return absl::OkStatus();
-                }
-                subshape->mutable_layout()->set_memory_space(
-                    Layout::kDefaultMemorySpace);
-                return absl::OkStatus();
-              }));
-
           TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction));
           if (reverse_computation_order_) {
             TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
@@ -2033,16 +2038,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint(
   // Clear out memory space in layout for entry computation root. Host offloader
   // will do the analysis later and add back the memory space for host outputs.
   if (constraints->computation()->IsEntryComputation()) {
-    Shape result_shape = result_layout.shape();
-    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-        &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) {
-          if (subshape->has_layout() && subshape->IsArray()) {
-            subshape->mutable_layout()->set_memory_space(
-                Layout::kDefaultMemorySpace);
-          }
-          return absl::OkStatus();
-        }));
-    TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape));
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
   }
 
   // Propagate the use constraint of the root instruction up to the logical
@@ -2232,25 +2228,29 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
   // layout constraint.
   if (constraints.ResultLayout() != nullptr &&
       constraints.ResultLayout()->LayoutIsSet()) {
+    ShapeLayout result_layout = *constraints.ResultLayout();
+    // Clear out memory space in layout. Host offloader will do the
+    // analysis later.
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
     // Layout assignment at this point only does minor-to-major assignment so
     // tiling info should be ignored here for comparison.
     VLOG(5) << "Computation result layout needs root copying\n";
-    if (!constraints.ResultLayout()->MatchesLayoutInShape(
+    if (!result_layout.MatchesLayoutInShape(
             computation->root_instruction()->shape(),
             /*minor_to_major_only=*/true)) {
       TF_ASSIGN_OR_RETURN(
           HloInstruction * new_root,
-          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+          CreateCopyWithNewLayout(result_layout.shape(),
                                   computation->root_instruction()));
       computation->set_root_instruction(new_root);
     } else {
       // Copy the tiling info/tail_padding_alignment_in_elements specified in
       // result layout.
-      auto copy_tiling = [&constraints](xla::Shape* subshape,
-                                        const xla::ShapeIndex& index) {
+      auto copy_tiling = [&result_layout](xla::Shape* subshape,
+                                          const xla::ShapeIndex& index) {
         if (subshape->IsArray()) {
-          const Shape& result_shape = ShapeUtil::GetSubshape(
-              constraints.ResultLayout()->shape(), index);
+          const Shape& result_shape =
+              ShapeUtil::GetSubshape(result_layout.shape(), index);
           if (result_shape.layout().tiles_size() != 0) {
             subshape->mutable_layout()->mutable_tiles()->assign(
                 result_shape.layout().tiles().begin(),
xla/service/layout_assignment_test.cc (53 additions, 0 deletions)
@@ -1367,6 +1367,59 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0},
   ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
 }
 
+TEST_F(LayoutAssignmentTest, MemorySpaceRemoved) {
+  const char* module_str = R"(
+HloModule MixedHostDeviceResult
+ENTRY %MixedHostDeviceResult {
+  %p0 = f32[4,4] parameter(0)
+  %d = f32[4,4]{1,0} custom-call(%p0), custom_call_target="MoveToDevice", metadata={preserve_layout=true}
+  ROOT %tuple = (f32[4,4], f32[4,4]) tuple(%p0, %d)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> m,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = m->entry_computation_layout();
+
+  // Set the parameter to be in host memory.
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(
+          F32, {4, 4}, {1, 0}, /*tiles=*/{},
+          /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
+          Layout::kHostMemorySpace));
+  // Set one result component to be in host memory, the other one on device.
+  // Also make sure to request incompatible result layout so that the layout
+  // assignment pass has to copy the layout from the entry computation layout.
+  *computation_layout.mutable_result_layout() =
+      ShapeLayout(ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {1, 0}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kHostMemorySpace),
+           ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {0, 1}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kDefaultMemorySpace)}));
+  AssignLayouts(m.get(), &computation_layout);
+
+  // Verify that the memory space did not leak from the entry computation layout
+  // to the parameter or to the result.
+  Shape result_shape = m->entry_computation()->root_instruction()->shape();
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 0).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 1).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+
+  const HloInstruction* parameter = FindInstruction(m.get(), "p0");
+  EXPECT_EQ(parameter->shape().layout().memory_space(),
+            Layout::kDefaultMemorySpace);
+
+  ExpectTupleLayoutIs(result_shape, {{1, 0}, {0, 1}});
+}
+
 absl::Status AssignLayoutsToComputation(
     HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
   if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
