Skip to content

Commit

Permalink
[XLA:GPU] Delete no-op logic for constructing collective combiner keys.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714965332
  • Loading branch information
allanrenucci authored and Google-ML-Automation committed Jan 13, 2025
1 parent 8d8d97a commit 8e3c814
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 44 deletions.
6 changes: 3 additions & 3 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3129,10 +3129,10 @@ cc_library(
"//xla/hlo/transforms/collectives:all_gather_combiner",
"//xla/service:hlo_domain_map",
"//xla/stream_executor:device_description",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -3167,10 +3167,10 @@ cc_library(
"//xla/service:hlo_domain_map",
"//xla/service:reduce_scatter_combiner",
"//xla/stream_executor:device_description",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -3203,10 +3203,10 @@ cc_library(
"//xla/hlo/transforms/collectives:all_reduce_combiner",
"//xla/service:hlo_domain_map",
"//xla/stream_executor:device_description",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
18 changes: 4 additions & 14 deletions xla/service/gpu/all_gather_combiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.

#include "xla/service/gpu/all_gather_combiner.h"

#include <cstdint>
#include <optional>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
Expand All @@ -29,7 +27,7 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_collective_combiner_utils.h"
#include "xla/service/hlo_domain_map.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::gpu {

Expand All @@ -38,23 +36,15 @@ namespace {
std::optional<AllGatherCombiner::GroupKey> PipelinedCombinerKey(
const HloInstruction* instruction, const HloDomainMap& domain_map,
bool combine_by_dim, bool combine_different_dtypes) {
auto combined_key = AllGatherCombiner::CombineKey(
instruction, domain_map, combine_by_dim, combine_different_dtypes);
if (!combined_key.has_value()) {
return std::nullopt;
}
auto backend_config = instruction->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
return std::nullopt;
}
bool is_pipelined =
backend_config->collective_backend_config().is_pipelined();
if (!is_pipelined) {
if (!backend_config->collective_backend_config().is_pipelined()) {
return std::nullopt;
}
AllGatherCombiner::GetGroupKeyExtraArgs(*combined_key)
.append(" " + std::to_string(static_cast<int64_t>(is_pipelined)));
return combined_key.value();
return AllGatherCombiner::CombineKey(instruction, domain_map, combine_by_dim,
combine_different_dtypes);
}

} // namespace
Expand Down
16 changes: 3 additions & 13 deletions xla/service/gpu/all_reduce_combiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.

#include "xla/service/gpu/all_reduce_combiner.h"

#include <cstdint>
#include <optional>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
Expand All @@ -29,30 +27,22 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_collective_combiner_utils.h"
#include "xla/service/hlo_domain_map.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::gpu {

namespace {

std::optional<AllReduceCombiner::GroupKey> PipelinedCombinerKey(
const HloInstruction* instruction, const HloDomainMap& domain_map) {
auto combined_key = AllReduceCombiner::CombineKey(instruction, domain_map);
if (!combined_key.has_value()) {
return std::nullopt;
}
auto backend_config = instruction->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
return std::nullopt;
}
bool is_pipelined =
backend_config->collective_backend_config().is_pipelined();
if (!is_pipelined) {
if (!backend_config->collective_backend_config().is_pipelined()) {
return std::nullopt;
}
AllReduceCombiner::GetGroupKeyExtraArgs(*combined_key)
.append(" " + std::to_string(static_cast<int64_t>(is_pipelined)));
return combined_key.value();
return AllReduceCombiner::CombineKey(instruction, domain_map);
}

} // namespace
Expand Down
18 changes: 4 additions & 14 deletions xla/service/gpu/reduce_scatter_combiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.

#include "xla/service/gpu/reduce_scatter_combiner.h"

#include <cstdint>
#include <optional>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
Expand All @@ -28,31 +26,23 @@ limitations under the License.
#include "xla/service/gpu/gpu_collective_combiner_utils.h"
#include "xla/service/hlo_domain_map.h"
#include "xla/service/reduce_scatter_combiner.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

std::optional<ReduceScatterCombiner::GroupKey> PipelinedCombinerKey(
const HloInstruction* instruction, const HloDomainMap& domain_map,
bool combine_by_dim) {
auto combined_key = ReduceScatterCombiner::CombineKey(instruction, domain_map,
combine_by_dim);
if (!combined_key.has_value()) {
return std::nullopt;
}
auto backend_config = instruction->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
return std::nullopt;
}
bool is_pipelined =
backend_config->collective_backend_config().is_pipelined();
if (!is_pipelined) {
if (!backend_config->collective_backend_config().is_pipelined()) {
return std::nullopt;
}
ReduceScatterCombiner::GetGroupKeyExtraArgs(*combined_key)
.append(" " + std::to_string(static_cast<int64_t>(is_pipelined)));
return combined_key.value();
return ReduceScatterCombiner::CombineKey(instruction, domain_map,
combine_by_dim);
}

} // namespace
Expand Down

0 comments on commit 8e3c814

Please sign in to comment.