diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 48e3883ffdcb1..412a7c188c844 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -169,6 +170,25 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_dump_latency_hiding_schedule(false); opts.set_xla_gpu_enable_latency_hiding_scheduler(false); opts.set_xla_gpu_enable_analytical_latency_estimator(false); + opts.set_xla_gpu_enable_analytical_sol_latency_estimator(false); + auto* sol_estimator_defaults = + opts.mutable_xla_gpu_analytical_latency_estimator_options(); + sol_estimator_defaults->emplace( + "nccl_op_launch_us", + absl::StrCat(static_cast(100.0f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "nic_speed_gbps", + absl::StrCat(static_cast(55.56f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "chunk_prep_us", + absl::StrCat(static_cast(13.34f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "rtt_us", + absl::StrCat(static_cast(68.89f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "chunk_size_bytes", absl::StrCat(kDefaultNcclCostModelChunkSizeBytes)); + sol_estimator_defaults->emplace( + "gpus_per_node", absl::StrCat(kDefaultNcclCostModelGPUsPerNode)); opts.set_xla_gpu_pgle_profile_file_or_directory_path(""); opts.set_xla_gpu_memory_limit_slop_factor(95); opts.set_xla_gpu_enable_highest_priority_async_stream(true); @@ -470,6 +490,17 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; + // Custom "sub-parser" lambda for + // xla_gpu_analytical_latency_estimator_options. 
+ auto setter_for_xla_gpu_analytical_latency_estimator_options = + [debug_options](std::string comma_separated_values) { + google::protobuf::Map* options_map = + debug_options + ->mutable_xla_gpu_analytical_latency_estimator_options(); + parse_xla_backend_extra_options(options_map, comma_separated_values); + return true; + }; + // Custom "sub-parser" lambda for xla_partitioning_algorithm. auto setter_for_xla_partitioning_algorithm = [debug_options](const std::string& value) { @@ -1568,6 +1599,25 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_analytical_latency_estimator(), "Enable analytical latency estimator for latency-hiding scheduler for " "XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_analytical_sol_latency_estimator", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_analytical_sol_latency_estimator), + debug_options->xla_gpu_enable_analytical_sol_latency_estimator(), + "Enable analytical Speed-of-Light latency estimator for latency-hiding " + "scheduler for XLA:GPU, must be used without " + "xla_gpu_enable_analytical_latency_estimator. It can also benefit from " + "user-passed options in xla_gpu_analytical_latency_estimator_options")); + flag_list->push_back(tsl::Flag( + "xla_gpu_analytical_latency_estimator_options", + setter_for_xla_gpu_analytical_latency_estimator_options, "", + "Extra platform-specific options to improve analytical latency " + "estimator precision; comma-separated list of 'key=val' " + "strings (=val may be omitted); no whitespace around commas. "
+ "Available options: " + "--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=55," + "nic_speed_gbps=40,chunk_prep_us=1,rtt_us=2,gpus_per_node=4," + "chunk_size_bytes=1024'")); flag_list->push_back(tsl::Flag( "xla_gpu_pgle_profile_file_or_directory_path", string_setter_for( diff --git a/xla/service/collective_utils.h b/xla/service/collective_utils.h index 916e007dc9b2e..dc69009445686 100644 --- a/xla/service/collective_utils.h +++ b/xla/service/collective_utils.h @@ -32,6 +32,11 @@ constexpr int64_t kDefaultAllGatherCombineThreshold = 30 * 1024 * 1024 + 7; // pass will combine collectives. constexpr int64_t kDefaultReduceScatterCombineThreshold = 30 * 1024 * 1024 + 7; +// Defines the default coefficient for the SoL NCCL collective cost model. +// Note: XLA flags allow a user to override the default values of the model. +constexpr float kDefaultNcclCostModelCoeff = 0.45f; +constexpr int64_t kDefaultNcclCostModelChunkSizeBytes = 4194304; // 4MB +constexpr int64_t kDefaultNcclCostModelGPUsPerNode = 8; } // namespace xla #endif // XLA_SERVICE_COLLECTIVE_UTILS_H_ diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index f8883ef240ad2..b9908ffd7d0ee 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2120,6 +2120,7 @@ cc_library( "//xla/service:p2p_schedule_preparation", "//xla/service:profile_guided_latency_estimator", "//xla/service/gpu/model:analytical_latency_estimator", + "//xla/service/gpu/model:sol_latency_estimator", "//xla/service/gpu/transforms:pgle_accuracy_checker", "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 0067254f72b65..5a5cc36dce644 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -51,6 +51,7 @@ limitations under the License.
#include "xla/service/gpu/flag_utils.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" +#include "xla/service/gpu/model/sol_latency_estimator.h" #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" #include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" @@ -496,6 +497,16 @@ std::unique_ptr GetLatencyEstimator( }, module.entry_computation()); } + + if (options.xla_gpu_enable_analytical_sol_latency_estimator()) { + LOG(INFO) << "Using Speed-of-Light (SoL) analytical latency estimator"; + return std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + module.entry_computation()); + } return gpu_latency_estimator; } diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index e341eb8703c71..4daa21c4c485b 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -43,6 +43,77 @@ cc_library( ], ) +cc_library( + name = "sol_latency_estimator", + srcs = ["sol_latency_estimator.cc"], + hdrs = ["sol_latency_estimator.h"], + deps = [ + ":coalescing_analysis", + ":fusion_analysis_cache", + ":gpu_hlo_cost_analysis", + ":gpu_performance_model", + ":gpu_performance_model_base", + ":hlo_op_profiles", + ":sol_gpu_cost_model", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/hlo/utils:hlo_traversal", + "//xla/service:hlo_cost_analysis", + "//xla/service:latency_hiding_scheduler", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions", + 
"//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + ], +) + +cc_library( + name = "sol_gpu_cost_model", + srcs = ["sol_gpu_cost_model.cc"], + hdrs = ["sol_gpu_cost_model.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "sol_gpu_cost_model_test", + srcs = ["sol_gpu_cost_model_test.cc"], + deps = [ + ":sol_gpu_cost_model", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + ], +) + xla_test( name = "analytical_latency_estimator_test", srcs = ["analytical_latency_estimator_test.cc"], diff --git a/xla/service/gpu/model/sol_gpu_cost_model.cc b/xla/service/gpu/model/sol_gpu_cost_model.cc new file mode 100644 index 0000000000000..e7a64aac68e43 --- /dev/null +++ b/xla/service/gpu/model/sol_gpu_cost_model.cc @@ -0,0 +1,189 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { +namespace gpu { +namespace { +// Constants for NCCL SoL model +constexpr double kHeaderOverhead = 0.025; +constexpr absl::string_view kNcclOpLaunchUs = "nccl_op_launch_us"; +constexpr absl::string_view kNicSpeedGbps = "nic_speed_gbps"; +constexpr absl::string_view kChunkPrepUs = "chunk_prep_us"; +constexpr absl::string_view kRttUs = "rtt_us"; +constexpr absl::string_view kGpusPerNode = "gpus_per_node"; +constexpr absl::string_view kChunkSizeBytes = "chunk_size_bytes"; + +// Returns the number of communicators in the mask. +// For example, if the mask is 0x0, this function returns 1. If the mask is 0x7, +// this function returns 8. +int NumCommunicators(const absl::string_view mask) { + // Assuming the mask is a hexadecimal number + uint64_t mask_value = std::stoul(std::string(mask), nullptr, 16); + int bit_count = absl::popcount(mask_value); // Count set bits + return static_cast(std::pow(2, bit_count)); +} + +// Returns the number of rounds for the given collective type. +int NumRounds(const SolGPUCostModel::CollectiveType& coll_type) { + // AllReduce requires ReduceScatter and AllGather, so it has 2 rounds. 
+ return coll_type == SolGPUCostModel::CollectiveType::kAllReduce ? 2 : 1; +} + +} // namespace + +SolGPUCostModel::Config GetConfig(const HloModule* module) { + SolGPUCostModel::Config config{}; // Zero-init so unparsed fields read as 0. + const auto& extra_options = + module->config() + .debug_options() + .xla_gpu_analytical_latency_estimator_options(); + for (const auto& [option_name, option_value] : extra_options) { + int64_t value; + double value_d; + VLOG(2) << "[SoL] option: " << option_name << " is " << option_value; + if (option_name == kNcclOpLaunchUs && + absl::SimpleAtoi(option_value, &value)) { + config.nccl_op_launch_time = absl::Microseconds(value); + } else if (option_name == kNicSpeedGbps && + absl::SimpleAtod(option_value, &value_d)) { + config.nic_speed_gbps = value_d; + } else if (option_name == kChunkPrepUs && + absl::SimpleAtoi(option_value, &value)) { + config.chunk_prep_time = absl::Microseconds(value); + } else if (option_name == kRttUs && + absl::SimpleAtoi(option_value, &value)) { + config.rtt = absl::Microseconds(value); + } else if (option_name == kGpusPerNode && + absl::SimpleAtoi(option_value, &value)) { + config.gpus_per_node = value; + } else if (option_name == kChunkSizeBytes && + absl::SimpleAtoi(option_value, &value)) { + config.chunk_size_bytes = value; + } + } + return config; +} + +SolGPUCostModel::SolGPUCostModel(const Config& sys_config) + : xla_flag_config_(sys_config) { + VLOG(2) << "[SoL] NIC speed: " << xla_flag_config_.nic_speed_gbps; + VLOG(2) << "[SoL] RTT: " << xla_flag_config_.rtt; + VLOG(2) << "[SoL] Chunk preparation time: " + << xla_flag_config_.chunk_prep_time; + VLOG(2) << "[SoL] NCCL op launch time: " + << xla_flag_config_.nccl_op_launch_time; + VLOG(2) << "[SoL] GPUs per node: " << xla_flag_config_.gpus_per_node; +} + +// This is an insignificant term, and we are making it consistent +// with the existing formula.
+absl::Duration SolGPUCostModel::ChunkPrepLatency( + const int64_t per_gpu_msg_size_bytes) const { + return std::ceil(static_cast(per_gpu_msg_size_bytes) / + xla_flag_config_.chunk_size_bytes) * + xla_flag_config_.chunk_prep_time; +} + +absl::Duration SolGPUCostModel::TransferDuration( + const int64_t per_gpu_msg_size_bytes) const { + // x1e6 to convert secs to microseconds; + // x1024*1024 *1024 to convert Gbytes/sec to bytes/sec + const long double ret = + (1e6 * static_cast(per_gpu_msg_size_bytes)) / + (std::pow(1024.0, 3) * xla_flag_config_.nic_speed_gbps); + return absl::Microseconds(ret * (1 + kHeaderOverhead)); +} + +absl::Duration SolGPUCostModel::RingLatency( + const int64_t buff_size_bytes, const int num_nodes, + const CollectiveType& coll_type, const absl::string_view mask) const { + const int num_gpus = NumGpusPerComm(num_nodes, coll_type, mask); + + int64_t per_gpu_msg_size_bytes; + if (coll_type == CollectiveType::kSendRecv) { + per_gpu_msg_size_bytes = buff_size_bytes; + } else { + per_gpu_msg_size_bytes = buff_size_bytes / num_gpus; + } + + // This is the number of GPUs per communicator per node. We assume that each + // GPU has a NIC, and this is also the number of NICs per communicator per + // node. + // Note that this happens to be the correct value (i.e. 1) for SendRecv. + int num_gpus_per_node = num_gpus / num_nodes; + + // In each channel, consider one GPU next to the Ethernet link. Below is the + // sum of 3 time costs for each piece of data of size + // `per_gpu_msg_size_bytes` + // + // 1. transfer duration defined by the NIC bandwidth, + // 2. chunk preparation latency, and + // 3. RTT + // + // then followed by two factors: + // + // 1. Multiply by `num_gpus - 1`, as `num_gpus - 1` pieces of data will be + // sent over the link in AllGather. + // 2. Divide by `num_gpus_per_node` as there are `num_gpus_per_node` NICs + // and + // GPUs in each node for parallelism.
+ // + // Better estimates of terms like this will come in future versions + // of the SoL model. + absl::Duration ret = TransferDuration(per_gpu_msg_size_bytes) + + ChunkPrepLatency(per_gpu_msg_size_bytes) + + xla_flag_config_.rtt; + ret *= (num_gpus - 1.0) / static_cast(num_gpus_per_node); + // Multiply by the number of rounds, which is different for AllReduce. + ret = ret * NumRounds(coll_type); + + // Time to initiate the collective. + return ret + xla_flag_config_.nccl_op_launch_time; +} + +// Helper functions +int SolGPUCostModel::NumGpusPerComm(int num_nodes, + const CollectiveType& coll_type, + const absl::string_view mask) const { + if (coll_type == CollectiveType::kSendRecv) { + return 2; + } + int num_comms = NumCommunicators(mask); + CHECK_EQ(xla_flag_config_.gpus_per_node % num_comms, 0) + << "GPU_PER_NODE must be divisible by the number of communicators. " + "GPU_PER_NODE: " + << xla_flag_config_.gpus_per_node + << " Number of communicators: " << num_comms + << ". Adjust the number of GPUs per node with the flag " + "gpus_per_node in xla_gpu_analytical_latency_estimator_options."; + return num_nodes * xla_flag_config_.gpus_per_node / num_comms; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/sol_gpu_cost_model.h b/xla/service/gpu/model/sol_gpu_cost_model.h new file mode 100644 index 0000000000000..77a449ae3df7a --- /dev/null +++ b/xla/service/gpu/model/sol_gpu_cost_model.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { +namespace gpu { +inline constexpr absl::string_view kSplitMaskWorldLevel = "0x0"; + +class SolGPUCostModel { + // Speed-of-Light (SoL) analytical cost model for NCCL collectives. + public: + // Tunable system configuration, see + // xla_gpu_analytical_latency_estimator_options + struct Config { + absl::Duration nccl_op_launch_time; + double nic_speed_gbps; // it's GBytes/s, not Gbit/s (ex: 40Gb/s = 5GB/s) + absl::Duration chunk_prep_time; + absl::Duration rtt; + int64_t gpus_per_node; + int64_t chunk_size_bytes; + }; + enum CollectiveAlgorithmType { + RING = 0, + TREE, + }; + enum class CollectiveType { + kAllReduce, + kAllGather, + kReduceScatter, + kSendRecv, + }; + explicit SolGPUCostModel(const Config& sys_config); + + // Returns the latency of a NCCL ring collective. + // + // `buff_size_bytes`: the size of the message to be transferred. + // `num_nodes`: the number of nodes participating in the ring. + // `coll_type`: the type of the collective (eg AllGather). + // `mask`: the mask of the collective (AllWorld 0x0 vs RailAligned 0x7). + absl::Duration RingLatency( + int64_t buff_size_bytes, int num_nodes, const CollectiveType& coll_type, + absl::string_view mask = kSplitMaskWorldLevel) const; + + private: + // Helper functions to estimate the latency subcomponents + absl::Duration ChunkPrepLatency(int64_t per_gpu_msg_size_bytes) const; + + absl::Duration TransferDuration(int64_t per_gpu_msg_size_bytes) const; + // NumGpusPerComm returns GPUs number participating in a given NCCL + // collective operation. 
+ int NumGpusPerComm(int num_nodes, const CollectiveType& coll_type, + absl::string_view mask) const; + + // SoL-related configuration for NCCL cost modelling passed by user as flags. + Config xla_flag_config_; +}; + +// Extract the SoL-related configuration from XLA flags. +SolGPUCostModel::Config GetConfig(const HloModule* module); +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ diff --git a/xla/service/gpu/model/sol_gpu_cost_model_test.cc b/xla/service/gpu/model/sol_gpu_cost_model_test.cc new file mode 100644 index 0000000000000..d7892a13fe713 --- /dev/null +++ b/xla/service/gpu/model/sol_gpu_cost_model_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model.h" + +#include + +#include +#include "absl/time/time.h" +namespace xla { +namespace gpu { +namespace { +constexpr int64_t kTenMB = 10 * 1024 * 1024; // 10MB + +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +struct RingLatencyTestCase { + SolGPUCostModel::CollectiveType collective_type; + absl::Duration expected_latency; +}; + +class SolGPUCostModelTest : public TestWithParam { + protected: + SolGPUCostModelTest() + : model_({ + /*nccl_op_launch_time=*/absl::Microseconds(100), + /*nic_speed_gbps=*/100, + /*chunk_prep_time=*/absl::Microseconds(100), + /*rtt=*/absl::Microseconds(100), + /*gpus_per_node=*/100, + /*chunk_size_bytes=*/4 * 1024 * 1024, + }) {} + SolGPUCostModel model_; +}; + +TEST_P(SolGPUCostModelTest, TestRingLatency) { + const RingLatencyTestCase& test_case = GetParam(); + absl::Duration actual_latency = + absl::Trunc(model_.RingLatency(kTenMB, 1, test_case.collective_type), + absl::Microseconds(1)); + EXPECT_EQ(actual_latency, test_case.expected_latency); +} + +INSTANTIATE_TEST_SUITE_P( + SolGPUCostModelTests, SolGPUCostModelTest, + ValuesIn({ + {SolGPUCostModel::CollectiveType::kAllGather, absl::Microseconds(298)}, + {SolGPUCostModel::CollectiveType::kAllReduce, absl::Microseconds(497)}, + {SolGPUCostModel::CollectiveType::kReduceScatter, + absl::Microseconds(298)}, + {SolGPUCostModel::CollectiveType::kSendRecv, absl::Microseconds(350)}, + })); +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/sol_latency_estimator.cc b/xla/service/gpu/model/sol_latency_estimator.cc new file mode 100644 index 0000000000000..1bcd36c8134f8 --- /dev/null +++ b/xla/service/gpu/model/sol_latency_estimator.cc @@ -0,0 +1,195 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_latency_estimator.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/time/time.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/status.h" + +namespace xla { +namespace gpu { + +namespace { + +int64_t ComputeMessageSize(const HloInstruction& instr, + HloCostAnalysis::ShapeSizeFunction fun) { + int64_t msg_size = 0; + ShapeUtil::ForEachSubshape( + instr.shape(), + [&msg_size, &fun](const Shape& subshape, const ShapeIndex&) { + if (subshape.IsArray()) { + msg_size += fun(subshape); + } + }); + return msg_size; +} + +int GetNumGpus(const HloInstruction& instr) { + const HloInstruction* i = &instr; + if (instr.opcode() == 
HloOpcode::kAsyncStart) { + i = instr.async_wrapped_instruction(); + } + int size = 0; + for (auto& rg : i->replica_groups()) { + size += rg.replica_ids_size(); + } + return size; +} + +/*static*/ absl::Duration ComputeCollectiveTime( + const HloInstruction& instr, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis::ShapeSizeFunction shape_size_fn, + const SolGPUCostModel::Config& sol_flags) { + const int num_nodes = sol_flags.gpus_per_node > 0 ? GetNumGpus(instr) / sol_flags.gpus_per_node : 0; + if (num_nodes <= 1) { // 0 nodes would divide by zero in RingLatency. + VLOG(8) << "Returning only kernel launch overhead for a single node."; + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; + } + + if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { + VLOG(8) << "Returning 0 cost for async done op " << instr.name(); + return absl::ZeroDuration(); + } + SolGPUCostModel sol_model(sol_flags); + const int64_t msg_size = ComputeMessageSize(instr, shape_size_fn); + + switch (instr.opcode()) { + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kAllGather); + } + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kAllReduce); + } + case HloOpcode::kReduceScatter: { + return sol_model.RingLatency( + msg_size, num_nodes, SolGPUCostModel::CollectiveType::kReduceScatter); + } + case HloOpcode::kAsyncStart: { + if (instr.async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return sol_model.RingLatency( + msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kReduceScatter); + } + break; + } + case HloOpcode::kRecv: + case HloOpcode::kSend: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kSendRecv); + } + // note: AllToAll is not yet supported in XLA + default: { + LOG(WARNING) + << "[SoL] Runtime estimate for " << instr.name() + << " not implemented.
Returning only the kernel launch time."; + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; + } + } + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; +} + +} // namespace + +LatencyEstimator::TimeCost SolLatencyEstimator::GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& target) const { + const HloOpcode from_op = from.GetInstr().opcode(); + if (!config_.schedule_send_recvs && + (from_op == HloOpcode::kSend || from_op == HloOpcode::kRecv)) { + return kLowLatency; + } + + if (IsAsyncPair(from, target)) { + double coll_time = absl::ToDoubleMicroseconds(ComputeCollectiveTime( + from.GetInstr(), gpu_info_, shape_size_function_, sol_flags_)); + VLOG(10) << "[SoL] Analytical estimator calculated latency between " + << from.GetInstr().name() << " and " << target.GetInstr().name() + << " to be: " << coll_time << " us."; + return coll_time; + } + return latency_estimator_->GetLatencyBetween(from, target); +} + +LatencyEstimator::TimeCost SolLatencyEstimator::NodeCost( + const HloInstruction* instr) const { + if (hlo_query::IsAsyncCollectiveStartOp(instr, /*include_send_recv=*/true) || + hlo_query::IsAsyncCollectiveDoneOp(instr, /*include_send_recv=*/true)) { + return kLowCost; + } + + absl::Duration total_estimated_time = + GpuPerformanceModel::EstimateRunTimeForInstruction( + instr, gpu_info_, &*cost_analysis_, + GpuPerformanceModelOptions::Default()) + .exec_time; + LatencyEstimator::TimeCost cost_in_us = + absl::ToDoubleMicroseconds(total_estimated_time); + VLOG(10) << "Analytical estimator calculated cost for: " << instr->name() + << ". 
Cost: " << cost_in_us; + return cost_in_us; +} + +SolLatencyEstimator::SolLatencyEstimator( + const SchedulerConfig& config, + std::unique_ptr latency_estimator, + const se::DeviceDescription& gpu_info, + HloCostAnalysis::ShapeSizeFunction shape_size_function, + HloComputation* computation) + : config_(config), + gpu_info_(gpu_info), + latency_estimator_(std::move(latency_estimator)), + shape_size_function_(shape_size_function), + sol_flags_(GetConfig(computation->parent())) { + cost_analysis_.emplace( + GpuHloCostAnalysis::Options{shape_size_function_, + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}, + gpu_info_); + TF_CHECK_OK(computation->Accept(&cost_analysis_.value())); + if (sol_flags_.nccl_op_launch_time == absl::ZeroDuration() || + sol_flags_.nic_speed_gbps == 0 || + sol_flags_.chunk_prep_time == absl::ZeroDuration() || + sol_flags_.rtt == absl::ZeroDuration() || sol_flags_.gpus_per_node == 0) { + LOG(WARNING) << "[SoL] Failed to parse SoL system config options."; + } +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/sol_latency_estimator.h b/xla/service/gpu/model/sol_latency_estimator.h new file mode 100644 index 0000000000000..4f32e9703b0c4 --- /dev/null +++ b/xla/service/gpu/model/sol_latency_estimator.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ +#define XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ + +#include +#include + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +class SolLatencyEstimator : public LatencyEstimator { + public: + // Implementation of SolLatencyEstimator using HloAnalysis and + // GPUPerformanceModel to estimate latencies for instructions. + SolLatencyEstimator(const SchedulerConfig& config, + std::unique_ptr latency_estimator, + const se::DeviceDescription& gpu_info, + HloCostAnalysis::ShapeSizeFunction shape_size_function, + HloComputation* computation); + + TimeCost GetLatencyBetween(const HloGraphNode& from, + const HloGraphNode& target) const override; + TimeCost NodeCost(const HloInstruction* instr) const override; + int CyclesPerMicrosecond() const override { + return latency_estimator_->CyclesPerMicrosecond(); + } + + static constexpr TimeCost kLowCost = 1.0; + static constexpr TimeCost kLowLatency = 1.0; + + private: + const SchedulerConfig config_; + const se::DeviceDescription& gpu_info_; + std::optional cost_analysis_; + std::unique_ptr latency_estimator_; + HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const SolGPUCostModel::Config sol_flags_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ diff --git a/xla/xla.proto b/xla/xla.proto index e4c18638d1dae..413a2c341158d 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -522,6 +522,15 @@ message DebugOptions { // xla_gpu_enable_async_collectives reserved 152, 278, 183, 199, 200, 201, 238; + 
// Enables NCCL Speed-of-Light (SoL) analytical cost model + bool xla_gpu_enable_analytical_sol_latency_estimator = 356; + // Extra platform-specific options to improve analytical latency + // estimator precision; comma-separated list of 'key=val' strings (=val may be + // omitted); no whitespace around commas. Available options: + // --xla_gpu_analytical_latency_estimator_options= + //'nccl_op_launch_us=55,nic_speed_gbps=40, + // chunk_prep_us=1,rtt_us=2,gpus_per_node=4,chunk_size_bytes=1024' + map xla_gpu_analytical_latency_estimator_options = 357; // Size threshold (in bytes) for the GPU collective combiners. int64 xla_gpu_all_reduce_combine_threshold_bytes = 157; int64 xla_gpu_all_gather_combine_threshold_bytes = 212; @@ -1084,7 +1093,7 @@ message DebugOptions { // be deterministic, although with additional overhead. bool xla_gpu_enable_scatter_determinism_expander = 345; - // Next id: 356 + // Next id: 358 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.