Skip to content

Commit

Permalink
Title: New NCCL Collectives Latency Estimator
Browse files Browse the repository at this point in the history
Description:

This PR introduces a new analytical latency estimator for NCCL collectives, enabled via the following flags:

--xla_gpu_enable_analytical_sol_latency_estimator \
--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=<value>,nic_speed_gbps=<value>,chunk_prep_us=<value>,rtt_us=<value>,gpus_per_node=<value>,chunk_size_bytes=<value>'

Replace <value> with an appropriate number for your system (e.g., nccl_op_launch_us=XX). This estimator should improve accuracy and performance, especially for large-scale distributed training.

PiperOrigin-RevId: 707261072
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
1 parent 2445c22 commit 09bc536
Show file tree
Hide file tree
Showing 11 changed files with 748 additions and 1 deletion.
50 changes: 50 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -169,6 +170,25 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_dump_latency_hiding_schedule(false);
opts.set_xla_gpu_enable_latency_hiding_scheduler(false);
opts.set_xla_gpu_enable_analytical_latency_estimator(false);
opts.set_xla_gpu_enable_analytical_sol_latency_estimator(false);
auto* sol_estimator_defaults =
opts.mutable_xla_gpu_analytical_latency_estimator_options();
sol_estimator_defaults->emplace(
"nccl_op_launch_us",
absl::StrCat(static_cast<int>(100.0f * kDefaultNcclCostModelCoeff)));
sol_estimator_defaults->emplace(
"nic_speed_gbps",
absl::StrCat(static_cast<int>(55.56f * kDefaultNcclCostModelCoeff)));
sol_estimator_defaults->emplace(
"chunk_prep_us",
absl::StrCat(static_cast<int>(13.34f * kDefaultNcclCostModelCoeff)));
sol_estimator_defaults->emplace(
"rtt_us",
absl::StrCat(static_cast<int>(68.89f * kDefaultNcclCostModelCoeff)));
sol_estimator_defaults->emplace(
"chunk_size_bytes", absl::StrCat(kDefaultNcclCostModelChunkSizeBytes));
sol_estimator_defaults->emplace(
"gpus_per_node", absl::StrCat(kDefaultNcclCostModelGPUsPerNode));
opts.set_xla_gpu_pgle_profile_file_or_directory_path("");
opts.set_xla_gpu_memory_limit_slop_factor(95);
opts.set_xla_gpu_enable_highest_priority_async_stream(true);
Expand Down Expand Up @@ -470,6 +490,17 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
return true;
};

// Custom "sub-parser" lambda for
// xla_gpu_analytical_latency_estimator_options.
auto setter_for_xla_gpu_analytical_latency_estimator_options =
[debug_options](std::string comma_separated_values) {
google::protobuf::Map<std::string, std::string>* options_map =
debug_options
->mutable_xla_gpu_analytical_latency_estimator_options();
parse_xla_backend_extra_options(options_map, comma_separated_values);
return true;
};

// Custom "sub-parser" lambda for xla_partitioning_algorithm.
auto setter_for_xla_partitioning_algorithm =
[debug_options](const std::string& value) {
Expand Down Expand Up @@ -1568,6 +1599,25 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_enable_analytical_latency_estimator(),
"Enable analytical latency estimator for latency-hiding scheduler for "
"XLA:GPU"));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_analytical_sol_latency_estimator",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_analytical_sol_latency_estimator),
debug_options->xla_gpu_enable_analytical_sol_latency_estimator(),
"Enable analytical Speed-of-Light latency estimator for latency-hiding "
"scheduler for XLA:GPU, must be used without "
"xla_gpu_enable_analytical_latency_estimator. It can also benefit from "
"user-passed options in xla_gpu_analytical_latency_estimator_options"));
// Registers the free-form "key=val,key=val" option string that tunes the SoL
// analytical latency estimator. The example in the help text must use the
// exact keys the cost model parses (the time-valued keys end in `_us`);
// unrecognized keys have no effect on the model.
flag_list->push_back(tsl::Flag(
    "xla_gpu_analytical_latency_estimator_options",
    setter_for_xla_gpu_analytical_latency_estimator_options, "",
    "Extra platform-specific options to improve analytical latency "
    "estimator precision; comma-separated list of 'key=val' "
    "strings (=val may be omitted); no whitespace around commas. "
    "Available options: "
    "--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=55,"
    "nic_speed_gbps=40,chunk_prep_us=1,rtt_us=2,gpus_per_node=4,"
    "chunk_size_bytes=1024'"));
flag_list->push_back(tsl::Flag(
"xla_gpu_pgle_profile_file_or_directory_path",
string_setter_for(
Expand Down
5 changes: 5 additions & 0 deletions xla/service/collective_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ constexpr int64_t kDefaultAllGatherCombineThreshold = 30 * 1024 * 1024 + 7;
// pass will combine collectives.
constexpr int64_t kDefaultReduceScatterCombineThreshold = 30 * 1024 * 1024 + 7;

// Defines the default coefficient for the SoL NCCL collective cost model.
// Note: XLA flags allow a user to override the default values of the model.
constexpr float kDefaultNcclCostModelCoeff = 0.45f;
constexpr int64_t kDefaultNcclCostModelChunkSizeBytes = 4194304; // 4MB
constexpr int64_t kDefaultNcclCostModelGPUsPerNode = 8;
} // namespace xla

#endif // XLA_SERVICE_COLLECTIVE_UTILS_H_
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2120,6 +2120,7 @@ cc_library(
"//xla/service:p2p_schedule_preparation",
"//xla/service:profile_guided_latency_estimator",
"//xla/service/gpu/model:analytical_latency_estimator",
"//xla/service/gpu/model:sol_latency_estimator",
"//xla/service/gpu/transforms:pgle_accuracy_checker",
"//xla/service/gpu/transforms:schedule_postprocessing",
"//xla/service/gpu/transforms:scheduling_instruction_annotator",
Expand Down
11 changes: 11 additions & 0 deletions xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ limitations under the License.
#include "xla/service/gpu/flag_utils.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/analytical_latency_estimator.h"
#include "xla/service/gpu/model/sol_latency_estimator.h"
#include "xla/service/gpu/transforms/pgle_accuracy_checker.h"
#include "xla/service/gpu/transforms/schedule_postprocessing.h"
#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
Expand Down Expand Up @@ -496,6 +497,16 @@ std::unique_ptr<LatencyEstimator> GetLatencyEstimator(
},
module.entry_computation());
}

if (options.xla_gpu_enable_analytical_sol_latency_estimator()) {
LOG(INFO) << "Using Speed-of-Light (SoL) analytical latency estimator";
return std::make_unique<SolLatencyEstimator>(
config, std::move(gpu_latency_estimator), gpu_device_info,
[input_pointer_size = pointer_size](const Shape& shape) {
return GetSizeOfShape(shape, input_pointer_size);
},
module.entry_computation());
}
return gpu_latency_estimator;
}

Expand Down
71 changes: 71 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,77 @@ cc_library(
],
)

cc_library(
name = "sol_latency_estimator",
srcs = ["sol_latency_estimator.cc"],
hdrs = ["sol_latency_estimator.h"],
deps = [
":coalescing_analysis",
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_performance_model",
":gpu_performance_model_base",
":hlo_op_profiles",
":sol_gpu_cost_model",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/analysis:hlo_dataflow_analysis",
"//xla/hlo/analysis:indexing_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:hlo_cost_analysis",
"//xla/service:latency_hiding_scheduler",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions",
"//xla/service/gpu/fusions:fusion_emitter",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
],
)

cc_library(
name = "sol_gpu_cost_model",
srcs = ["sol_gpu_cost_model.cc"],
hdrs = ["sol_gpu_cost_model.h"],
deps = [
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/numeric:bits",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
],
)

xla_cc_test(
name = "sol_gpu_cost_model_test",
srcs = ["sol_gpu_cost_model_test.cc"],
deps = [
":sol_gpu_cost_model",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)

xla_test(
name = "analytical_latency_estimator_test",
srcs = ["analytical_latency_estimator_test.cc"],
Expand Down
189 changes: 189 additions & 0 deletions xla/service/gpu/model/sol_gpu_cost_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/model/sol_gpu_cost_model.h"

#include <cmath>
#include <cstdint>
#include <string>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/bits.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/hlo_module.h"

namespace xla {
namespace gpu {
namespace {
// Constants for NCCL SoL model

// Fractional overhead applied on top of the raw transfer time: the transfer
// duration is multiplied by (1 + kHeaderOverhead).
constexpr double kHeaderOverhead = 0.025;

// Keys recognized in the `xla_gpu_analytical_latency_estimator_options` map.
// Each one populates the correspondingly named field of
// SolGPUCostModel::Config in GetConfig().
// NCCL op launch time, in microseconds (integer).
constexpr absl::string_view kNcclOpLaunchUs = "nccl_op_launch_us";
// NIC speed (parsed as a floating-point value).
constexpr absl::string_view kNicSpeedGbps = "nic_speed_gbps";
// Per-chunk preparation time, in microseconds (integer).
constexpr absl::string_view kChunkPrepUs = "chunk_prep_us";
// Round-trip time, in microseconds (integer).
constexpr absl::string_view kRttUs = "rtt_us";
// Number of GPUs per node (integer).
constexpr absl::string_view kGpusPerNode = "gpus_per_node";
// Chunk size, in bytes (integer).
constexpr absl::string_view kChunkSizeBytes = "chunk_size_bytes";

// Returns the number of communicators encoded by `mask`, computed as
// 2^(number of set bits in the mask).
// For example, if the mask is 0x0, this function returns 1. If the mask is
// 0x7, this function returns 8.
//
// `mask` is interpreted as a hexadecimal number. std::stoull throws
// std::invalid_argument / std::out_of_range if it cannot be parsed.
int NumCommunicators(const absl::string_view mask) {
  // Parse into a full 64-bit value so the parsed width matches `uint64_t`;
  // `unsigned long` (std::stoul) is only 32 bits wide on LLP64 platforms.
  const uint64_t mask_value = std::stoull(std::string(mask), nullptr, 16);
  const int bit_count = absl::popcount(mask_value);  // Count set bits
  // The result must fit in a (32-bit) int.
  CHECK_LT(bit_count, 31) << "Communicator mask has too many bits set: "
                          << mask;
  // Exact integer power of two; no floating-point round trip needed.
  return 1 << bit_count;
}

// Number of communication rounds modeled for a collective of kind
// `coll_type`.
int NumRounds(const SolGPUCostModel::CollectiveType& coll_type) {
  // An AllReduce is a ReduceScatter followed by an AllGather, hence two
  // rounds; every other collective is modeled with a single round.
  if (coll_type == SolGPUCostModel::CollectiveType::kAllReduce) {
    return 2;
  }
  return 1;
}

} // namespace

// Builds the SoL cost model configuration from the
// `xla_gpu_analytical_latency_estimator_options` debug option of `module`.
//
// Every recognized key overrides the corresponding field of a
// default-constructed SolGPUCostModel::Config. Options that are unknown, fail
// to parse, or carry a non-positive value for a field the model divides by
// (`nic_speed_gbps`, `gpus_per_node`, `chunk_size_bytes`) are ignored with a
// warning, leaving the previous value of that field in place.
SolGPUCostModel::Config GetConfig(const HloModule* module) {
  SolGPUCostModel::Config config;
  const auto& extra_options =
      module->config()
          .debug_options()
          .xla_gpu_analytical_latency_estimator_options();
  for (const auto& [option_name, option_value] : extra_options) {
    VLOG(2) << "[SoL] option: " << option_name << " is " << option_value;
    int64_t value = 0;
    double value_d = 0.0;
    bool valid = false;
    if (option_name == kNcclOpLaunchUs) {
      valid = absl::SimpleAtoi(option_value, &value);
      if (valid) config.nccl_op_launch_time = absl::Microseconds(value);
    } else if (option_name == kNicSpeedGbps) {
      // Used as a divisor in TransferDuration, so it must be a positive,
      // finite number.
      valid = absl::SimpleAtod(option_value, &value_d) &&
              std::isfinite(value_d) && value_d > 0.0;
      if (valid) config.nic_speed_gbps = value_d;
    } else if (option_name == kChunkPrepUs) {
      valid = absl::SimpleAtoi(option_value, &value);
      if (valid) config.chunk_prep_time = absl::Microseconds(value);
    } else if (option_name == kRttUs) {
      valid = absl::SimpleAtoi(option_value, &value);
      if (valid) config.rtt = absl::Microseconds(value);
    } else if (option_name == kGpusPerNode) {
      // Used as a divisor when sizing communicators, so it must be positive.
      valid = absl::SimpleAtoi(option_value, &value) && value > 0;
      if (valid) config.gpus_per_node = value;
    } else if (option_name == kChunkSizeBytes) {
      // Used as a divisor in ChunkPrepLatency, so it must be positive.
      valid = absl::SimpleAtoi(option_value, &value) && value > 0;
      if (valid) config.chunk_size_bytes = value;
    } else {
      // Most likely a typo in the flag; surface it instead of silently
      // running with the default value.
      LOG(WARNING) << "[SoL] Ignoring unknown option '" << option_name
                   << "' in xla_gpu_analytical_latency_estimator_options.";
      continue;
    }
    if (!valid) {
      LOG(WARNING) << "[SoL] Ignoring invalid value '" << option_value
                   << "' for option '" << option_name << "'.";
    }
  }
  return config;
}

// Stores `sys_config` as the model configuration and, at VLOG level 2, dumps
// every field of it so a run can be matched to the parameters it used.
SolGPUCostModel::SolGPUCostModel(const Config& sys_config)
    : xla_flag_config_(sys_config) {
  VLOG(2) << "[SoL] NIC speed: " << xla_flag_config_.nic_speed_gbps;
  VLOG(2) << "[SoL] RTT: " << xla_flag_config_.rtt;
  VLOG(2) << "[SoL] Chunk preparation time: "
          << xla_flag_config_.chunk_prep_time;
  VLOG(2) << "[SoL] NCCL op launch time: "
          << xla_flag_config_.nccl_op_launch_time;
  VLOG(2) << "[SoL] GPUs per node: " << xla_flag_config_.gpus_per_node;
  // Also part of the model (see ChunkPrepLatency), so log it with the rest.
  VLOG(2) << "[SoL] Chunk size (bytes): " << xla_flag_config_.chunk_size_bytes;
}

// Chunk preparation cost for a message of `per_gpu_msg_size_bytes`: one
// `chunk_prep_time` per chunk of `chunk_size_bytes`, with a trailing partial
// chunk counted as a whole one.
// This is an insignificant term, and we are making it consistent with the
// existing formula.
absl::Duration SolGPUCostModel::ChunkPrepLatency(
    const int64_t per_gpu_msg_size_bytes) const {
  const auto msg_bytes = static_cast<double>(per_gpu_msg_size_bytes);
  const auto num_chunks =
      std::ceil(msg_bytes / xla_flag_config_.chunk_size_bytes);
  return num_chunks * xla_flag_config_.chunk_prep_time;
}

// Time to push `per_gpu_msg_size_bytes` through one NIC, including the
// fractional `kHeaderOverhead` on top of the raw transfer time.
//
// The configured NIC speed is scaled by 1024^3 to obtain bytes per second.
// NOTE(review): the option is named `nic_speed_gbps`, which could be read as
// gigabits per second, while this formula treats it as 2^30 bytes per second
// -- confirm the intended unit.
absl::Duration SolGPUCostModel::TransferDuration(
    const int64_t per_gpu_msg_size_bytes) const {
  // 1024 * 1024 * 1024: scales the configured speed to bytes/sec. Spelled as a
  // compile-time constant rather than a runtime std::pow call.
  constexpr double kSpeedUnitInBytes = 1024.0 * 1024.0 * 1024.0;
  // x1e6 to convert seconds to microseconds.
  const long double ret =
      (1e6 * static_cast<long double>(per_gpu_msg_size_bytes)) /
      (kSpeedUnitInBytes * xla_flag_config_.nic_speed_gbps);
  return absl::Microseconds(ret * (1 + kHeaderOverhead));
}

// Estimates the latency of a ring-based collective.
//
// `buff_size_bytes` is the total buffer size, `num_nodes` the number of
// participating nodes (must be positive), `coll_type` the kind of collective
// and `mask` the hexadecimal communicator mask used to derive the number of
// GPUs per communicator.
//
// Returns the modeled transfer + chunk preparation + RTT cost, scaled by the
// ring length and the number of rounds, plus the NCCL op launch time.
absl::Duration SolGPUCostModel::RingLatency(
    const int64_t buff_size_bytes, const int num_nodes,
    const CollectiveType& coll_type, const absl::string_view mask) const {
  // `num_nodes` is used as a divisor below; fail with a readable message
  // instead of an integer division by zero.
  CHECK_GT(num_nodes, 0) << "RingLatency requires a positive number of nodes.";
  const int num_gpus = NumGpusPerComm(num_nodes, coll_type, mask);

  int64_t per_gpu_msg_size_bytes;
  if (coll_type == CollectiveType::kSendRecv) {
    per_gpu_msg_size_bytes = buff_size_bytes;
  } else {
    per_gpu_msg_size_bytes = buff_size_bytes / num_gpus;
  }

  // This is the number of GPUs per communicator per node. We assume that each
  // GPU has a NIC, and this is also the number of NICs per communicator per
  // node.
  // The integer division truncates to 0 when there are more nodes than GPUs
  // in the communicator (e.g. SendRecv, which always has 2 GPUs, across more
  // than 2 nodes). At least one GPU/NIC per node always takes part, so clamp
  // to 1 rather than dividing by zero below and returning an infinite
  // duration.
  int num_gpus_per_node = num_gpus / num_nodes;
  if (num_gpus_per_node < 1) {
    num_gpus_per_node = 1;
  }

  // In each channel, consider one GPU next to the Ethernet link. Below is the
  // sum of 3 time costs for each piece of data of size
  // `per_gpu_msg_size_bytes`
  //
  // 1. transfer duration defined by the NIC bandwidth,
  // 2. chunk preparation latency, and
  // 3. RTT
  //
  // then followed by two factors:
  //
  // 1. Multiply by `num_gpus - 1`, as `num_gpus - 1` pieces of data will be
  //    sent over the link in AllGather.
  // 2. Divide by `num_gpus_per_node` as there are `num_gpus_per_node` NICs
  //    and GPUs in each node for parallelism.
  //
  // Better estimates of terms like this will come in future versions
  // of the SoL model.
  absl::Duration ret = TransferDuration(per_gpu_msg_size_bytes) +
                       ChunkPrepLatency(per_gpu_msg_size_bytes) +
                       xla_flag_config_.rtt;
  ret *= (num_gpus - 1.0) / static_cast<long double>(num_gpus_per_node);
  // Multiply by the number of rounds, which is different for AllReduce.
  ret = ret * NumRounds(coll_type);

  // Time to initiate the collective.
  return ret + xla_flag_config_.nccl_op_launch_time;
}

// Helper functions

// Returns the number of GPUs that take part in a single communicator.
//
// SendRecv always involves exactly two GPUs. For any other collective, the
// `num_nodes * gpus_per_node` GPUs are split evenly across the communicators
// encoded by `mask`; `gpus_per_node` must be positive and divisible by the
// number of communicators.
int SolGPUCostModel::NumGpusPerComm(int num_nodes,
                                    const CollectiveType& coll_type,
                                    const absl::string_view mask) const {
  if (coll_type == CollectiveType::kSendRecv) {
    return 2;
  }
  // A zero value would satisfy the divisibility check below and yield zero
  // GPUs per communicator, which callers then divide by.
  CHECK_GT(xla_flag_config_.gpus_per_node, 0)
      << "GPU_PER_NODE must be positive. Adjust the number of GPUs per node "
         "with the flag gpus_per_node in "
         "xla_gpu_analytical_latency_estimator_options.";
  const int num_comms = NumCommunicators(mask);
  CHECK_EQ(xla_flag_config_.gpus_per_node % num_comms, 0)
      << "GPU_PER_NODE must be divisible by the number of communicators. "
         "GPU_PER_NODE: "
      << xla_flag_config_.gpus_per_node
      << " Number of communicators: " << num_comms
      << ". Adjust the number of GPUs per node with the flag "
         "gpus_per_node in xla_gpu_analytical_latency_estimator_options.";
  return num_nodes * xla_flag_config_.gpus_per_node / num_comms;
}

} // namespace gpu
} // namespace xla
Loading

0 comments on commit 09bc536

Please sign in to comment.