Skip to content

Commit

Permalink
[XLA:TPU] Reuse same Alias Analysis object in RunMemorySpaceAssignment
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707073266
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
1 parent d903f75 commit 93ec913
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
35 changes: 15 additions & 20 deletions xla/service/memory_space_assignment/memory_space_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,10 @@ void TransformAllocationSequenceToSpill(AllocationSequence& allocations,
} // namespace

absl::StatusOr<MemorySpaceAssignment::AsyncCopyStats>
MemorySpaceAssignment::CalculateAsyncCopyStats() const {
MemorySpaceAssignment::CalculateAsyncCopyStats(
const HloDataflowAnalysis& dataflow_analysis) const {
AsyncCopyStats stats;
int64_t current_copies = 0;
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
HloDataflowAnalysis::Run(*module_));
for (const HloComputation* computation :
module_->MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) {
Expand All @@ -305,7 +304,7 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
HloOpcode::kSlice)) {
current_copies--;
int64_t size =
options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
options_.size_fn(dataflow_analysis.GetUniqueValueAt(instruction));
if (instruction->shape().layout().memory_space() ==
options_.alternate_memory_space) {
++stats.num_prefetches;
Expand Down Expand Up @@ -409,26 +408,23 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
ScheduleAsynchronousCopies();
TF_RETURN_IF_ERROR(SimplifyGraph());
TF_RETURN_IF_ERROR(FixSchedule());
TF_RETURN_IF_ERROR(ExportAndColorBuffers());
TF_ASSIGN_OR_RETURN(auto alias, HloAliasAnalysis::Run(module_));
TF_RETURN_IF_ERROR(ExportAndColorBuffers(*alias));
std::vector<int64_t> alt_mem_bytes_occupied;
// alt_mem_bytes_occupied is used for logging in the RuntimeSimulator below.
// We only populate it in VerifyAndExportHeapSimulatorTrace if the
// RuntimeSimulator is present.
TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace(
*alias,
runtime_simulator.has_value() ? &alt_mem_bytes_occupied : nullptr));
if (runtime_simulator.has_value()) {
float estimated_time = runtime_simulator->SimulateElapsedTime(
module_, allocations_, &alt_mem_bytes_occupied);
VLOG(1) << "Estimated elapsed time with async copies (sec): "
<< estimated_time;
}

if (VLOG_IS_ON(3)) {
LOG(INFO) << "Module after memory space assignment: ";
XLA_LOG_LINES(INFO, module_->ToString());
}
TF_CHECK_OK(module_->schedule().Verify());
TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
TF_ASSIGN_OR_RETURN(AsyncCopyStats stats,
CalculateAsyncCopyStats(alias->dataflow_analysis()));
VLOG(1) << "Maximum number of outstanding async copies/slices: "
<< stats.max_outstanding_async_copies;
VLOG(1) << "Number of prefetches: " << stats.num_prefetches
Expand Down Expand Up @@ -539,15 +535,15 @@ absl::Status MemorySpaceAssignment::Process(
return absl::OkStatus();
}

absl::Status MemorySpaceAssignment::ExportAndColorBuffers() {
absl::Status MemorySpaceAssignment::ExportAndColorBuffers(
const HloAliasAnalysis& alias_analysis) {
VLOG(1) << "Exporting buffers...";
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
absl::flat_hash_map<int64_t, int64_t> seen_buffer_offsets;
VLOG(3) << "Exported alternate memory allocations:";
for (const auto& position_and_chunk : alternate_memory_assignments_) {
const HloPosition& defining_position = position_and_chunk.first;
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
defining_position.instruction, defining_position.index);
auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
Expand Down Expand Up @@ -589,7 +585,7 @@ absl::Status MemorySpaceAssignment::ExportAndColorBuffers() {
for (const auto& defining_position_and_chunk :
preset_assignments_->chunks()) {
const HloPosition& defining_position = defining_position_and_chunk.first;
for (auto& buffer : alias_analysis->ComputeBuffersAt(
for (auto& buffer : alias_analysis.ComputeBuffersAt(
defining_position.instruction, defining_position.index)) {
for (auto& value : buffer->values()) {
for (auto& position : value->positions()) {
Expand Down Expand Up @@ -1049,12 +1045,11 @@ absl::Status MemorySpaceAssignment::FixSchedule() {
}

absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace(
const HloAliasAnalysis& alias_analysis,
std::vector<int64_t>* alt_mem_bytes_occupied) {
VLOG(1) << "Verifying...";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module_));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
HloLiveRange::Run(module_->schedule(), *alias_analysis,
HloLiveRange::Run(module_->schedule(), alias_analysis,
module_->entry_computation()));

BufferIntervalTree interval_tree;
Expand Down Expand Up @@ -1120,7 +1115,7 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace(
const HloPosition& position = position_and_chunk.first;
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
const HloBuffer& buffer =
alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
CHECK(!seen_buffers.contains(buffer.id()))
<< "Multiple preset assignments for the same buffer: "
<< buffer.ToString() << ", pos: " << position.ToString()
Expand Down
7 changes: 5 additions & 2 deletions xla/service/memory_space_assignment/memory_space_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ Useful logging and error messages
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/service/buffer_value.h"
Expand Down Expand Up @@ -305,7 +306,8 @@ class MemorySpaceAssignment {
const HloAliasAnalysis& alias_analysis, const Options& options);

// Calculates asynchronous copy statistics.
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats(
const HloDataflowAnalysis& dataflow_analysis) const;

// Verify that allocations_ are free of overlapping Allocations in time and
// space. This is a post-processing step called after all allocations have
Expand All @@ -318,6 +320,7 @@ class MemorySpaceAssignment {
// If alt_mem_bytes_occupied is not null, it will be populated with the number
// of bytes occupied in the alternate memory space at each instruction time.
absl::Status VerifyAndExportHeapSimulatorTrace(
const HloAliasAnalysis& alias_analysis,
std::vector<int64_t>* alt_mem_bytes_occupied = nullptr);

protected:
Expand Down Expand Up @@ -372,7 +375,7 @@ class MemorySpaceAssignment {

// Export the alternate memory assignments to the PresetAssignments and color
// the HLO graph with the determined memory spaces.
absl::Status ExportAndColorBuffers();
absl::Status ExportAndColorBuffers(const HloAliasAnalysis& alias_analysis);

// Schedules asynchronous copies and ensures that the CopyStarts and their
// corresponding CopyDones follow the same order.
Expand Down

0 comments on commit 93ec913

Please sign in to comment.