Skip to content

Commit

Permalink
[XLA:GPU] Move code to compute block id to tile offset indexing map to SymbolicTileAnalysis.
Browse files Browse the repository at this point in the history

The block id to tile offset map is relevant for both Triton Emitter and Cost Model, so SymbolicTileAnalysis looks like a better place to have it ATM.

PiperOrigin-RevId: 621831590
  • Loading branch information
olegshyshkov authored and copybara-github committed Apr 4, 2024
1 parent b879cfa commit 9b34f10
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 121 deletions.
2 changes: 0 additions & 2 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,6 @@ xla_test(
deps = [
":backend_configs_cc",
":gpu_device_info_for_tests",
":ir_emission_utils",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
Expand All @@ -620,7 +619,6 @@ xla_test(
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service/gpu/model:indexing_test_utils",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
Expand Down
104 changes: 22 additions & 82 deletions xla/service/gpu/ir_emitter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ limitations under the License.
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
Expand All @@ -125,7 +124,6 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h"
Expand Down Expand Up @@ -2282,46 +2280,12 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
return absl::OkStatus();
}

// Computes indexing map from program id into the tile offset for the given
// shape and tile sizes.
//
// The returned map has the form
//   (program_id) -> (tile_offset_0, ..., tile_offset_N)
// where program_id ranges over [0, num_tiles) and each offset is the tile
// index along that dimension scaled by the tile size.
IndexingMap ComputeProgramIdToOutputTileIndexing(
    absl::Span<const int64_t> dimensions, absl::Span<const int64_t> tile_sizes,
    mlir::MLIRContext* mlir_context) {
  CHECK_EQ(dimensions.size(), tile_sizes.size());

  // Use int64_t for the tile counts: both the per-dimension ceil-division and
  // the running product are 64-bit quantities; accumulating them in `int`
  // could overflow for large shapes.
  int64_t num_tiles = 1;
  std::vector<int64_t> outer_loop_bounds;
  outer_loop_bounds.reserve(dimensions.size());
  for (auto [dim_size, tile_size] : llvm::zip(dimensions, tile_sizes)) {
    // Ceil division: number of tiles needed to cover `dim_size`.
    int64_t num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size;

    num_tiles *= num_tiles_per_dim;
    outer_loop_bounds.push_back(num_tiles_per_dim);
  }

  mlir::AffineExpr program_id = mlir::getAffineDimExpr(0, mlir_context);

  // Delinearize the block id into one tile index per dimension.
  auto tile_exprs =
      DelinearizeIndex(outer_loop_bounds, program_id, mlir_context);

  // Scale each index by the tile size to produce the tile offset.
  for (auto [tile_expr, tile_size] : llvm::zip(tile_exprs, tile_sizes)) {
    tile_expr = tile_expr * tile_size;
  }

  return IndexingMap::FromTensorSizes(
      mlir::AffineMap::get(
          /*dimCount=*/1, /*symbolCount=*/0, tile_exprs, mlir_context),
      /*dim_upper_bounds=*/{num_tiles}, /*symbol_upper_bounds=*/{});
}

// Computes the base pointer offset for the given pid and shape.
// `tile_offset_indexing` is a mapping from
// (program_id) -> [tile_offset0, ..., tile_offsetN]
StatusOr<Value> ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid,
const Shape& shape,
const IndexingMap& tile_offset_indexing) {
Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid,
const Shape& shape,
const IndexingMap& tile_offset_indexing) {
ArrayRef<mlir::AffineExpr> dimension_exprs =
tile_offset_indexing.GetAffineMap().getResults();

Expand All @@ -2333,25 +2297,10 @@ StatusOr<Value> ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid,
stride *= shape.dimensions(i);
}

// A symbol in an indexing map means that to produce on element of output, we
// need to read all elements of input in the symbol range. Since this function
// computes start of the tile, we need to substitute each symbol with its
// lower bound value. We assume here the iteration order is normalized.
// TODO(b/330906085): Support cases when tile offsets are not 0.
for (const Interval& symbol_bound : tile_offset_indexing.GetSymbolBounds()) {
if (symbol_bound.lower != 0) {
return absl::FailedPreconditionError(absl::StrCat(
"Symbol lower bound is not zero. ", tile_offset_indexing.ToString()));
}
}

std::vector<Value> symbol_lower_bounds(
tile_offset_indexing.GetSymbolCount(),
b.create<ma::ConstantOp>(b.getIndexAttr(0)));

return b.create<ma::IndexCastUIOp>(
b.getI64Type(), mlir_converter::ApplyAffineExpr(linear_index, pid,
symbol_lower_bounds, b));
b.getI64Type(),
mlir_converter::ApplyAffineExpr(linear_index, /*dims=*/pid,
/*symbols=*/{}, b));
}

absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
Expand All @@ -2360,8 +2309,6 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
SymbolicTileAnalysis* analysis,
const HloComputation* computation,
mlir::triton::FuncOp fn) {
mlir::MLIRContext* mlir_context = analysis->GetMLIRContext();

const HloInstruction* root = computation->root_instruction();
auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name()));
ImplicitLocOpBuilder b(loc, builder);
Expand Down Expand Up @@ -2412,10 +2359,6 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,

analysis->SetTileSizes(output_tile_sizes);

IndexingMap program_id_to_output_tile_indexing =
ComputeProgramIdToOutputTileIndexing(root_shape.dimensions(),
output_tile_sizes, mlir_context);

// block_size must be a power of two.
int result_block_size = llvm::PowerOf2Ceil(row_len);

Expand All @@ -2425,35 +2368,29 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
}

// Emits load instructions
auto emit_param_load =
[&](const SymbolicTiledHloInstruction& tiled_hlo_instruction)
auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo)
-> absl::StatusOr<Value> {
std::vector<Value> tile_sizes, tile_strides, tile_offsets;
for (auto [size, stride, offset] :
llvm::zip(analysis->TileSizes(tiled_hlo_instruction),
analysis->TileStrides(tiled_hlo_instruction),
analysis->TileOffsets(tiled_hlo_instruction))) {
for (auto [size, stride, offset] : llvm::zip(
analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo),
analysis->TileOffsets(tiled_hlo))) {
if (size == 1) continue;

tile_sizes.push_back(CreateConst(b, b.getI64Type(), size));
tile_strides.push_back(CreateConst(b, b.getI64Type(), stride));
tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset));
}

IndexingMap program_id_to_input_tile_indexing =
ComposeIndexingMaps(program_id_to_output_tile_indexing,
tiled_hlo_instruction.indexing_map());
program_id_to_input_tile_indexing.Simplify(GetIndexingMapForInstruction);
TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_input_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo));

// Manually compute pointer offset to avoid materialized fully parallel
// dimensions in the tile. Current codegen tried to avoid size-1 dims.
TF_ASSIGN_OR_RETURN(
Value ptr_offset,
ComputeBasePtrOffset(b, pid, tiled_hlo_instruction.hlo()->shape(),
program_id_to_input_tile_indexing));
Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(),
program_id_to_input_tile_indexing);

auto fn_arg =
fn.getArgument(tiled_hlo_instruction.hlo()->parameter_number());
auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number());
auto tile_ptr = AddPtr(b, fn_arg, ptr_offset);

if (tile_sizes.empty()) {
Expand All @@ -2476,9 +2413,12 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
EmitTiledScope(b, libdevice_path, device_info, *analysis,
emit_param_load, values_out));

TF_ASSIGN_OR_RETURN(Value ptr_offset,
ComputeBasePtrOffset(b, pid, root_shape,
program_id_to_output_tile_indexing));
TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_output_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot()));

Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape,
program_id_to_output_tile_indexing);

Value store_tensor = b.create<mt::MakeTensorPtrOp>(
/*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()),
Expand Down
7 changes: 0 additions & 7 deletions xla/service/gpu/ir_emitter_triton.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ limitations under the License.
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/hlo_module_config.h"
#include "xla/status.h"
Expand All @@ -49,12 +48,6 @@ struct TritonWrapperResult {
std::optional<se::ClusterDim> cluster_dim;
};

// Computes indexing map from program id into the tile offset for the given
// shape and tile sizes.
IndexingMap ComputeProgramIdToOutputTileIndexing(
absl::Span<const int64_t> dimensions, absl::Span<const int64_t> tile_sizes,
mlir::MLIRContext* mlir_context);

// Compute the launch dimensions for the given Triton MatMul.
absl::StatusOr<LaunchDimensions> GetMatMulLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion,
Expand Down
27 changes: 0 additions & 27 deletions xla/service/gpu/ir_emitter_triton_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License.

#include "xla/service/gpu/ir_emitter_triton.h"

#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
Expand All @@ -42,7 +41,6 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/indexing_test_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/pattern_matcher.h"
Expand Down Expand Up @@ -151,31 +149,6 @@ absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck(
return absl::OkStatus();
}

// Verifies that ComputeProgramIdToOutputTileIndexing delinearizes a linear
// program id into per-dimension tile offsets (tile index scaled by tile
// size) and derives the program-id domain from the total tile count.
TEST_F(TritonTest, ComputeProgramIdToOutputTileIndexing) {
mlir::MLIRContext context;

// Convenience wrapper: builds the indexing map for a given output shape and
// tile sizes using the shared MLIR context.
auto compute_map = [&](absl::Span<const int64_t> dimensions,
absl::Span<const int64_t> tile_sizes) {
return ComputeProgramIdToOutputTileIndexing(dimensions, tile_sizes,
&context);
};

// 9x17 with 5x10 tiles -> ceil(9/5)=2 by ceil(17/10)=2 tiles, 4 programs.
// Offsets are the delinearized tile indices scaled by the tile sizes.
EXPECT_THAT(compute_map(/*dimensions=*/{9, 17}, /*tile_sizes=*/{5, 10}),
MatchIndexingMap(R"(
(d0) -> ((d0 floordiv 2) * 5, (d0 mod 2) * 10)
domain:
d0 in [0, 3]
)"));

// Size-1 tiles contribute the raw index; a tile covering the whole trailing
// dimension (32/32) yields a constant 0 offset. 8*16 = 128 programs.
EXPECT_THAT(
compute_map(/*dimensions=*/{8, 16, 32}, /*tile_sizes=*/{1, 1, 32}),
MatchIndexingMap(R"(
(d0) -> (d0 floordiv 16, d0 mod 16, 0)
domain:
d0 in [0, 127]
)"));
}

TEST_F(TritonFilecheckTest, TestGemm) {
const std::string kHloText = R"(
HloModule t, is_scheduled=true
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand All @@ -616,6 +619,7 @@ xla_cc_test(
name = "symbolic_tile_analysis_test",
srcs = ["symbolic_tile_analysis_test.cc"],
deps = [
":indexing_test_utils",
":symbolic_tile_analysis",
":symbolic_tiled_hlo_instruction",
"//xla/tests:hlo_test_base",
Expand Down
Loading

0 comments on commit 9b34f10

Please sign in to comment.