diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 5bc5b739529ccf..41989596fa963f 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -607,7 +607,6 @@ xla_test( deps = [ ":backend_configs_cc", ":gpu_device_info_for_tests", - ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", ":triton_fusion_analysis", @@ -620,7 +619,6 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/model:indexing_test_utils", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cublas_plugin", diff --git a/xla/service/gpu/ir_emitter_triton.cc b/xla/service/gpu/ir_emitter_triton.cc index a2a78288b670ea..c4ce4acc04f381 100644 --- a/xla/service/gpu/ir_emitter_triton.cc +++ b/xla/service/gpu/ir_emitter_triton.cc @@ -113,7 +113,6 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" @@ -125,7 +124,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" @@ -2282,46 +2280,12 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, return absl::OkStatus(); } -// Computes indexing map from program id into the tile offset for the given -// shape and tile sizes. 
-IndexingMap ComputeProgramIdToOutputTileIndexing( - absl::Span dimensions, absl::Span tile_sizes, - mlir::MLIRContext* mlir_context) { - CHECK_EQ(dimensions.size(), tile_sizes.size()); - - int num_tiles = 1; - std::vector outer_loop_bounds; - outer_loop_bounds.reserve(dimensions.size()); - for (auto [dim_size, tile_size] : llvm::zip(dimensions, tile_sizes)) { - int num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size; - - num_tiles *= num_tiles_per_dim; - outer_loop_bounds.push_back(num_tiles_per_dim); - } - - mlir::AffineExpr program_id = mlir::getAffineDimExpr(0, mlir_context); - - // Delinearize the block id. - auto tile_exprs = - DelinearizeIndex(outer_loop_bounds, program_id, mlir_context); - - // Scale each index by the tile size to produce tile offset. - for (auto [tile_expr, tile_size] : llvm::zip(tile_exprs, tile_sizes)) { - tile_expr = tile_expr * tile_size; - } - - return IndexingMap::FromTensorSizes( - mlir::AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/0, tile_exprs, mlir_context), - /*dim_upper_bounds=*/{num_tiles}, /*symbol_upper_bounds=*/{}); -} - // Computes the base pointer offset for the given pid and shape. // `tile_offset_indexing` is a mapping from // (program_id) -> [tile_offset0, ..., tile_offsetN] -StatusOr ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid, - const Shape& shape, - const IndexingMap& tile_offset_indexing) { +Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid, + const Shape& shape, + const IndexingMap& tile_offset_indexing) { ArrayRef dimension_exprs = tile_offset_indexing.GetAffineMap().getResults(); @@ -2333,25 +2297,10 @@ StatusOr ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid, stride *= shape.dimensions(i); } - // A symbol in an indexing map means that to produce on element of output, we - // need to read all elements of input in the symbol range. Since this function - // computes start of the tile, we need to substitute each symbol with its - // lower bound value. 
We assume here the iteration order is normalized. - // TODO(b/330906085): Support cases when tile offsets are not 0. - for (const Interval& symbol_bound : tile_offset_indexing.GetSymbolBounds()) { - if (symbol_bound.lower != 0) { - return absl::FailedPreconditionError(absl::StrCat( - "Symbol lower bound is not zero. ", tile_offset_indexing.ToString())); - } - } - - std::vector symbol_lower_bounds( - tile_offset_indexing.GetSymbolCount(), - b.create(b.getIndexAttr(0))); - return b.create( - b.getI64Type(), mlir_converter::ApplyAffineExpr(linear_index, pid, - symbol_lower_bounds, b)); + b.getI64Type(), + mlir_converter::ApplyAffineExpr(linear_index, /*dims=*/pid, + /*symbols=*/{}, b)); } absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, @@ -2360,8 +2309,6 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, SymbolicTileAnalysis* analysis, const HloComputation* computation, mlir::triton::FuncOp fn) { - mlir::MLIRContext* mlir_context = analysis->GetMLIRContext(); - const HloInstruction* root = computation->root_instruction(); auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name())); ImplicitLocOpBuilder b(loc, builder); @@ -2412,10 +2359,6 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, analysis->SetTileSizes(output_tile_sizes); - IndexingMap program_id_to_output_tile_indexing = - ComputeProgramIdToOutputTileIndexing(root_shape.dimensions(), - output_tile_sizes, mlir_context); - // block_size must be a power of two. 
int result_block_size = llvm::PowerOf2Ceil(row_len); @@ -2425,14 +2368,12 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, } // Emits load instructions - auto emit_param_load = - [&](const SymbolicTiledHloInstruction& tiled_hlo_instruction) + auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo) -> absl::StatusOr { std::vector tile_sizes, tile_strides, tile_offsets; - for (auto [size, stride, offset] : - llvm::zip(analysis->TileSizes(tiled_hlo_instruction), - analysis->TileStrides(tiled_hlo_instruction), - analysis->TileOffsets(tiled_hlo_instruction))) { + for (auto [size, stride, offset] : llvm::zip( + analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo), + analysis->TileOffsets(tiled_hlo))) { if (size == 1) continue; tile_sizes.push_back(CreateConst(b, b.getI64Type(), size)); @@ -2440,20 +2381,16 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset)); } - IndexingMap program_id_to_input_tile_indexing = - ComposeIndexingMaps(program_id_to_output_tile_indexing, - tiled_hlo_instruction.indexing_map()); - program_id_to_input_tile_indexing.Simplify(GetIndexingMapForInstruction); + TF_ASSIGN_OR_RETURN( + IndexingMap program_id_to_input_tile_indexing, + analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo)); // Manually compute pointer offset to avoid materialized fully parallel // dimensions in the tile. Current codegen tried to avoid size-1 dims. 
- TF_ASSIGN_OR_RETURN( - Value ptr_offset, - ComputeBasePtrOffset(b, pid, tiled_hlo_instruction.hlo()->shape(), - program_id_to_input_tile_indexing)); + Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(), + program_id_to_input_tile_indexing); - auto fn_arg = - fn.getArgument(tiled_hlo_instruction.hlo()->parameter_number()); + auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number()); auto tile_ptr = AddPtr(b, fn_arg, ptr_offset); if (tile_sizes.empty()) { @@ -2476,9 +2413,12 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, EmitTiledScope(b, libdevice_path, device_info, *analysis, emit_param_load, values_out)); - TF_ASSIGN_OR_RETURN(Value ptr_offset, - ComputeBasePtrOffset(b, pid, root_shape, - program_id_to_output_tile_indexing)); + TF_ASSIGN_OR_RETURN( + IndexingMap program_id_to_output_tile_indexing, + analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot())); + + Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape, + program_id_to_output_tile_indexing); Value store_tensor = b.create( /*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()), diff --git a/xla/service/gpu/ir_emitter_triton.h b/xla/service/gpu/ir_emitter_triton.h index 96ca55139bb196..b4306ac8e7c6bc 100644 --- a/xla/service/gpu/ir_emitter_triton.h +++ b/xla/service/gpu/ir_emitter_triton.h @@ -33,7 +33,6 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/status.h" @@ -49,12 +48,6 @@ struct TritonWrapperResult { std::optional cluster_dim; }; -// Computes indexing map from program id into the tile offset for the given -// shape and tile sizes. 
-IndexingMap ComputeProgramIdToOutputTileIndexing( - absl::Span dimensions, absl::Span tile_sizes, - mlir::MLIRContext* mlir_context); - // Compute the launch dimensions for the given Triton MatMul. absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, diff --git a/xla/service/gpu/ir_emitter_triton_test.cc b/xla/service/gpu/ir_emitter_triton_test.cc index d72c283c2dc667..f82029cff87f81 100644 --- a/xla/service/gpu/ir_emitter_triton_test.cc +++ b/xla/service/gpu/ir_emitter_triton_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_triton.h" -#include #include #include #include @@ -42,7 +41,6 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/pattern_matcher.h" @@ -151,31 +149,6 @@ absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck( return absl::OkStatus(); } -TEST_F(TritonTest, ComputeProgramIdToOutputTileIndexing) { - mlir::MLIRContext context; - - auto compute_map = [&](absl::Span dimensions, - absl::Span tile_sizes) { - return ComputeProgramIdToOutputTileIndexing(dimensions, tile_sizes, - &context); - }; - - EXPECT_THAT(compute_map(/*dimensions=*/{9, 17}, /*tile_sizes=*/{5, 10}), - MatchIndexingMap(R"( - (d0) -> ((d0 floordiv 2) * 5, (d0 mod 2) * 10) - domain: - d0 in [0, 3] - )")); - - EXPECT_THAT( - compute_map(/*dimensions=*/{8, 16, 32}, /*tile_sizes=*/{1, 1, 32}), - MatchIndexingMap(R"( - (d0) -> (d0 floordiv 16, d0 mod 16, 0) - domain: - d0 in [0, 127] - )")); -} - TEST_F(TritonFilecheckTest, TestGemm) { const std::string kHloText = R"( HloModule t, is_scheduled=true diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 
e5f9fe6b2dc280..6b3010d7670666 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -606,6 +606,9 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -616,6 +619,7 @@ xla_cc_test( name = "symbolic_tile_analysis_test", srcs = ["symbolic_tile_analysis_test.cc"], deps = [ + ":indexing_test_utils", ":symbolic_tile_analysis", ":symbolic_tiled_hlo_instruction", "//xla/tests:hlo_test_base", diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index 82491b684ff263..6fda8d16e13dd2 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -27,6 +27,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -53,6 +55,40 @@ using ::mlir::AffineMap; using ::mlir::MLIRContext; using ::mlir::SmallVector; +// Computes indexing map from program id into the tile offset for the given +// shape and tile sizes. 
+IndexingMap ComputeProgramIdToOutputTileIndexing( + absl::Span dimensions, absl::Span tile_sizes, + mlir::MLIRContext* mlir_context) { + CHECK_EQ(dimensions.size(), tile_sizes.size()); + + int num_tiles = 1; + std::vector outer_loop_bounds; + outer_loop_bounds.reserve(dimensions.size()); + for (auto [dim_size, tile_size] : llvm::zip(dimensions, tile_sizes)) { + int num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size; + + num_tiles *= num_tiles_per_dim; + outer_loop_bounds.push_back(num_tiles_per_dim); + } + + mlir::AffineExpr program_id = mlir::getAffineDimExpr(0, mlir_context); + + // Delinearize the block id. + auto tile_exprs = + DelinearizeIndex(outer_loop_bounds, program_id, mlir_context); + + // Scale each index by the tile size to produce tile offset. + for (auto [tile_expr, tile_size] : llvm::zip(tile_exprs, tile_sizes)) { + tile_expr = tile_expr * tile_size; + } + + return IndexingMap::FromTensorSizes( + mlir::AffineMap::get( + /*dimCount=*/1, /*symbolCount=*/0, tile_exprs, mlir_context), + /*dim_upper_bounds=*/{num_tiles}, /*symbol_upper_bounds=*/{}); +} + } // namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( @@ -162,24 +198,77 @@ using ::mlir::SmallVector; std::vector SymbolicTileAnalysis::TileOffsets( const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()); + CHECK(tile_parameters_.has_value()) << "SetTileSizes() must be called before " + "TileOffsets()"; return tiled_hlo.TileOffsets(*tile_parameters_); } // TODO(bchetioui): remove dependency on stride and offset parameters. 
std::vector SymbolicTileAnalysis::TileSizes( const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()); + CHECK(tile_parameters_.has_value()) << "SetTileSizes() must be called before " + "TileSizes()"; return tiled_hlo.TileSizes(*tile_parameters_); } std::vector SymbolicTileAnalysis::TileStrides( const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()); + CHECK(tile_parameters_.has_value()) << "SetTileSizes() must be called before " + "TileStrides()"; return tiled_hlo.TileStrides(*tile_parameters_); } +absl::StatusOr +SymbolicTileAnalysis::ComputeBlockIdToTileOffsetIndexing( + const SymbolicTiledHloInstruction& tiled_hlo) const { + CHECK(block_id_to_root_tile_offset_.has_value()) + << "SetTileSizes() must be called before " + "ComputeBlockIdToTileOffsetIndexing()"; + + IndexingMap block_id_to_tile_offset_indexing = ComposeIndexingMaps( + *block_id_to_root_tile_offset_, tiled_hlo.indexing_map()); + + // A symbol in an indexing map means that to produce one element of output, we + // need to read all elements of input in the symbol range. Since this function + // computes start of the tile, we need to substitute each symbol with its + // lower bound value. We assume here the iteration order is normalized. + // TODO(b/330906085): Support cases when tile offsets are not 0. + if (absl::c_any_of(block_id_to_tile_offset_indexing.GetSymbolBounds(), + [](const Interval& symbol_bound) { + return symbol_bound.lower != 0; + })) { + return absl::FailedPreconditionError( + absl::StrCat("Symbol lower bound is not zero. 
", + block_id_to_tile_offset_indexing.ToString())); + } + + std::vector symbol_lower_bounds( + block_id_to_tile_offset_indexing.GetSymbolCount(), + mlir::getAffineConstantExpr(0, context_)); + + mlir::AffineMap simplified_affine_map = + block_id_to_tile_offset_indexing.GetAffineMap().replaceDimsAndSymbols( + /*dimReplacements=*/{}, symbol_lower_bounds, + block_id_to_tile_offset_indexing.GetDimVarsCount(), + /*numResultSyms=*/ + block_id_to_tile_offset_indexing.GetRangeVarsCount()); + + IndexingMap simplified_indexing_map = IndexingMap{ + simplified_affine_map, block_id_to_tile_offset_indexing.GetDimVars(), + block_id_to_tile_offset_indexing.GetRangeVars(), + block_id_to_tile_offset_indexing.GetRTVars()}; + + simplified_indexing_map.Simplify(GetIndexingMapForInstruction); + simplified_indexing_map.RescaleSymbols(); + simplified_indexing_map.RemoveUnusedSymbols(); + + return simplified_indexing_map; +} + void SymbolicTileAnalysis::SetTileSizes(std::vector sizes) { + block_id_to_root_tile_offset_ = ComputeProgramIdToOutputTileIndexing( + GetRoot()->hlo()->shape().dimensions(), sizes, context_); + // TODO(bchetioui): CHECK num parameters somehow? tile_parameters_ = std::vector(std::move(sizes)); } diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h index 2a793882508429..db643f924fe8a7 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/xla/service/gpu/model/symbolic_tile_analysis.h @@ -23,8 +23,10 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" @@ -64,6 +66,13 @@ class SymbolicTileAnalysis { std::vector TileStrides( const SymbolicTiledHloInstruction& tiled_hlo) const; + // Computes the indexing map from block id to tile offset of the tiled HLO + // instruction. The indexing map has the following form: + // + // (block_id) -> (tile_offset0, tile_offset1, ...) + absl::StatusOr ComputeBlockIdToTileOffsetIndexing( + const SymbolicTiledHloInstruction& tiled_hlo) const; + // Populates input tile sizes. This is a prerequisite in order to extract // concrete values using `TileOffsets`, `TileSizes`, and `TileStrides`. void SetTileSizes(std::vector sizes); @@ -99,6 +108,10 @@ class SymbolicTileAnalysis { // computation. The order and type of parameters are as explained in the // documentation of `SymbolicTile`. std::optional> tile_parameters_; + + // Indexing map from block id to root tile offset. Computed from the tile + // parameters in SetTileSizes(). + std::optional block_id_to_root_tile_offset_; }; } // namespace gpu diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 0aba6790cdb3a3..cf01b25844915b 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/strings/string_view.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -66,6 +67,13 @@ ENTRY main { const SymbolicTiledHloInstruction* root = analysis.GetRoot(); + EXPECT_THAT(*analysis.ComputeBlockIdToTileOffsetIndexing(*root), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) + domain: + d0 in [0, 19] + )")); + auto p0_from_subtract0 = root->operand(0); auto p0_from_subtract1 = root->operand(1)->operand(0)->operand(0); @@ -73,9 +81,23 @@ ENTRY main { EXPECT_THAT(analysis.TileSizes(*p0_from_subtract0), ElementsAre(1, 10)); EXPECT_THAT(analysis.TileStrides(*p0_from_subtract0), ElementsAre(1, 1)); + EXPECT_THAT(*analysis.ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract0), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) + domain: + d0 in [0, 19] + )")); + EXPECT_THAT(analysis.TileOffsets(*p0_from_subtract1), ElementsAre(0, 0)); EXPECT_THAT(analysis.TileSizes(*p0_from_subtract1), ElementsAre(1, 97)); EXPECT_THAT(analysis.TileStrides(*p0_from_subtract1), ElementsAre(1, 1)); + + EXPECT_THAT(*analysis.ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract1), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, 0) + domain: + d0 in [0, 19] + )")); } TEST_F(SymbolicTileAnalysisTest, ElementwiseDiamondCSEIsSupported) {