[XLA:LatencyHidingScheduler] Do not schedule a ready annotated group if doing so would cause an overlap limit to be crossed. Wait until the respective resources are released.

Move the initialization of `scheduling_instruction_crosses_overlap_limit_` to `DefaultSchedulerCore::InitializeScheduler`, as we now need it when scheduling annotation groups and it must be available before the first call to `FindAndExtractBestNodeAvailable`.
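
To make the check concrete, below is a minimal, self-contained C++ sketch of the two predicates this change touches. It is paraphrased from the diff that follows; `Instr`, `SchedState`, and the free functions are illustrative stand-ins, not the real types from latency_hiding_scheduler.h.

// Minimal, self-contained sketch of the overlap-limit checks (illustrative
// only; the real types live in xla/service/latency_hiding_scheduler.h).
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

// Stand-in for an HLO instruction and its resource usage.
struct Instr {
  std::string name;
  std::map<std::string, int64_t> resources_needed;  // resource -> count
};

// Stand-in for SchedulingState.
struct SchedState {
  // Per-resource concurrency limit, e.g. {"collective-permute", 1}.
  std::map<std::string, int64_t> max_concurrent_resource;
  // Resources that currently have open (in-flight) occupiers.
  std::map<std::string, std::set<std::string>> occupiers_in_flight;
};

// Mirrors scheduling_instruction_crosses_overlap_limit_: returns true if
// scheduling `instr` now would exceed the limit of a resource that already
// has occupiers in flight.
bool InstructionCrossesLimit(const SchedState& s, const Instr& instr) {
  for (const auto& [resource, limit] : s.max_concurrent_resource) {
    auto it = s.occupiers_in_flight.find(resource);
    if (it == s.occupiers_in_flight.end() || it->second.empty()) continue;
    auto needed_it = instr.resources_needed.find(resource);
    const int64_t needed =
        needed_it == instr.resources_needed.end() ? 0 : needed_it->second;
    if (limit < needed) return true;
  }
  return false;
}

// Mirrors the new SchedulingAnnotationCrossesOverlapLimit: a ready annotation
// group is held back if any instruction in it would cross a limit; the
// scheduler then keeps scheduling non-annotated ops until the occupying
// resources are released.
bool GroupCrossesLimit(const SchedState& s, const std::vector<Instr>& group) {
  for (const Instr& i : group) {
    if (InstructionCrossesLimit(s, i)) return true;
  }
  return false;
}

In the new test (latency_hiding_scheduler_test.cc below), with `collective_permute_overlap_limit = 1` and cp2 in flight, the group annotated with `_scheduling_group_id="0"` stays in `ready_annotations` until cp2 is closed.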

PiperOrigin-RevId: 705975287
seherellis authored and Google-ML-Automation committed Dec 17, 2024
1 parent d80d15e commit 6d52a3a
Showing 3 changed files with 106 additions and 33 deletions.
91 changes: 58 additions & 33 deletions xla/service/latency_hiding_scheduler.cc
@@ -1320,29 +1320,6 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable(
}
absl::InlinedVector<std::pair<HloGraphNode*, SkipNodeReason>, 2>
skipped_nodes_and_reasons;
if (!scheduling_instruction_crosses_overlap_limit_) {
scheduling_instruction_crosses_overlap_limit_ =
[](const SchedulingState& sched_state, const HloGraphNode* node) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.empty()) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}
return false;
};
}
VLOG(2) << "Current time: " << sched_state.current_time;
ReadySetLt ready_lt{&sched_state, target_scheduling_rule_,
early_target_scheduling_rule_};
@@ -2285,6 +2262,29 @@ absl::Status DefaultSchedulerCore::InitializeScheduler(
if (VLOG_IS_ON(2)) {
annotation_tracker_->PrintAnnotationSets(2);
}
if (!scheduling_instruction_crosses_overlap_limit_) {
scheduling_instruction_crosses_overlap_limit_ =
[](const SchedulingState& sched_state, const HloGraphNode* node) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.empty()) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}
return false;
};
}
return absl::OkStatus();
}

@@ -2303,6 +2303,17 @@ absl::Status DefaultSchedulerCore::SchedulingStep(
return absl::OkStatus();
}

bool DefaultSchedulerCore::SchedulingAnnotationCrossesOverlapLimit(
const SchedulingState& sched_state, int64_t annotation) {
for (const HloInstruction* instr :
annotation_tracker_->GetInstructions(annotation)) {
if (scheduling_instruction_crosses_overlap_limit_(
sched_state, &sched_state.sched_graph.GetNode(instr))) {
return true;
}
}
return false;
}
absl::StatusOr<std::vector<HloInstruction*>>
DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
const HloSchedule& module_schedule = computation->parent()->schedule();
@@ -2369,16 +2380,30 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
return absl::StrJoin(sched_state.ready_set, "\n", LogFormatter());
}());
if (!sched_state.ready_annotations.empty()) {
// TODO (sacer): If more than one annotations are ready, decide which one
// to schedule next with a heuristic.
int64_t annotation = sched_state.ready_annotations.back();
sched_state.ready_annotations.pop_back();
VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------";
sched_state.ongoing_annotation = annotation;
TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state));
VLOG(2) << "------- END ANNOTATION: " << annotation << " --------";
sched_state.ongoing_annotation = -1;
continue;
// Pick the first ready annotation whose scheduling will not cross the
// overlap limit. If there is no such annotation, continue with scheduling
// non-annotated ops.
int64_t annotation_index = -1;
for (int64_t i = 0; i < sched_state.ready_annotations.size(); ++i) {
if (SchedulingAnnotationCrossesOverlapLimit(
sched_state, sched_state.ready_annotations[i])) {
continue;
}
annotation_index = i;
break;
}
if (annotation_index != -1) {
std::swap(sched_state.ready_annotations[annotation_index],
sched_state.ready_annotations.back());
int64_t annotation = sched_state.ready_annotations.back();
sched_state.ready_annotations.pop_back();
VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------";
sched_state.ongoing_annotation = annotation;
TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state));
VLOG(2) << "------- END ANNOTATION: " << annotation << " --------";
sched_state.ongoing_annotation = -1;
continue;
}
}
TF_RETURN_IF_ERROR(SchedulingStep(&sched_state));
}
2 changes: 2 additions & 0 deletions xla/service/latency_hiding_scheduler.h
@@ -1051,6 +1051,8 @@ class DefaultSchedulerCore : public SchedulerCore {
this->config_.memory_limit = new_limit;
}
int64_t GetRerunTimes() override { return config_.rerun; }
bool SchedulingAnnotationCrossesOverlapLimit(
const SchedulingState& sched_state, int64_t annotation);

protected:
virtual void LogInstruction(const HloInstruction* instr) const;
46 changes: 46 additions & 0 deletions xla/service/latency_hiding_scheduler_test.cc
@@ -3817,4 +3817,50 @@ ENTRY main {
}
}

TEST_F(LatencyHidingSchedulerTest, SchedulingAnnotationCrossesOverlapLimit) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
ENTRY entry {
p0 = f32[16,64,256]{2,1,0} parameter(0)
p1 = f32[128,2048,2048]{2,1,0} parameter(1)
cp1s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"}
cp1d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="0"}
cp2s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}
cp2d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp2s)
slice = f32[16,64,256]{2,1,0} slice(cp1d), slice={[0:16], [0:64], [0:256]}
c1 = f32[16,256,256]{2,1,0} convolution(p0, p0),
window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"}
c2 = f32[16,256,256]{2,1,0} convolution(p0, slice),
window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[128,2048,2048]{2,1,0}) tuple(c1, c2, cp2d)
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 1;
EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config,
std::make_unique<TestLatencyEstimator>())
.ok());
EXPECT_TRUE(hlo_module->has_entry_computation());

std::vector<HloInstruction*> new_instruction_sequence =
module_schedule.sequence(hlo_module->entry_computation()).instructions();
if (VLOG_IS_ON(1)) {
for (auto* new_i : new_instruction_sequence) {
VLOG(1) << new_i->ToString();
}
}

// With a collective-permute overlap limit of 1, the scheduling group with
// annotation 0 cannot be scheduled right when it becomes ready, because cp2
// is still in flight (its overlap is open) at that point. cp1 can only be
// scheduled after cp2 is closed (in the reverse scheduling order), so cp1d
// precedes cp2s in the final sequence.
EXPECT_LT(GetIndex(new_instruction_sequence, "cp1d"),
GetIndex(new_instruction_sequence, "cp2s"));
}

} // namespace xla
