diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 5403360f7e..058387de44 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -41,6 +41,42 @@ def TTIRAttachMetalLayout: Pass<"ttir-attach-metal-layout", "::mlir::ModuleOp"> ]; } +def TTIRGenericLinearizeMemref: Pass<"ttir-generic-linearize-memref", "::mlir::ModuleOp"> { + let summary = "Linearize memref operands for generic ops."; + let description = [{ + This pass takes a nested loop structure over n-dimensional memrefs and linearizes + them into a single dimension. This is useful because circular buffers in metal + are only one-dimensional. + + For example, this pass will convert the following code: + ```mlir + affine.for %arg5 = 0 to 2 { + affine.for %arg6 = 0 to 4 { + %0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + %1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + %2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32> + affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + } + } + ``` + + Into: + ```mlir + %collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_> + %collapse_shape_0 = memref.collapse_shape %arg3 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_> + %collapse_shape_1 = memref.collapse_shape %arg4 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_> + affine.for %arg5 = 0 to 2 { + affine.for %arg6 = 0 to 4 { + %0 = affine.load %collapse_shape[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_> + %1 = affine.load %collapse_shape_0[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_> + %2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> 
!tt.tile<32x32, f32> + affine.store %2, %collapse_shape_1[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_> + } + } + ``` + }]; +} + def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> { let summary = "Tensor tilize all generic ops."; let description = [{ diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index 262b4455b5..449bf98e6d 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms Allocate.cpp Broadcast.cpp Constant.cpp - Generic.cpp + GenericLinearizeMemref.cpp HoistCPUOps.cpp Layout.cpp Transforms.cpp diff --git a/lib/Dialect/TTIR/Transforms/Generic.cpp b/lib/Dialect/TTIR/Transforms/Generic.cpp deleted file mode 100644 index 447b071adc..0000000000 --- a/lib/Dialect/TTIR/Transforms/Generic.cpp +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" - -#include -#include -#include -#include - -namespace mlir::tt::ttir { -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" - -} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Transforms/GenericLinearizeMemref.cpp b/lib/Dialect/TTIR/Transforms/GenericLinearizeMemref.cpp new file mode 100644 index 0000000000..9a2f0f6198 --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/GenericLinearizeMemref.cpp @@ -0,0 +1,121 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" +#include "ttmlir/Utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRGENERICLINEARIZEMEMREF +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +namespace { +class 
TTIRGenericLinearizeMemrefRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + static bool isLinearizedMemref(BlockArgument arg) { + auto memref = mlir::cast(arg.getType()); + if (memref.getShape().size() == 1) { + return true; + } + + return std::all_of(arg.user_begin(), arg.user_end(), [](Operation *user) { + return mlir::isa(user); + }); + } + + static mlir::AffineMap linearizeAffineMap(::mlir::MLIRContext *context, + mlir::AffineMap map, + ArrayRef shape) { + auto evaledShape = ttmlir::utils::evalShape(map, shape); + mlir::AffineExpr indexing = getAffineConstantExpr(0, context); + mlir::AffineExpr volumeExpr = getAffineConstantExpr(1, context); + + for (int i = map.getNumResults() - 1; i >= 0; i--) { + mlir::AffineExpr linearIdx = getAffineDimExpr(i, context); + mlir::AffineExpr dim = getAffineConstantExpr(evaledShape[i], context); + indexing = linearIdx * volumeExpr + indexing; + volumeExpr = volumeExpr * dim; + } + + mlir::AffineMap linearResult = + mlir::AffineMap::get(map.getNumResults(), 0, indexing, context); + return linearResult.compose(map); + } + + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const final { + Block *entry = &op.getRegion().front(); + rewriter.setInsertionPointToStart(entry); + auto args = entry->getArguments(); + if (llvm::all_of(args, isLinearizedMemref)) { + return failure(); + } + + rewriter.modifyOpInPlace(op, [&]() { + for (auto arg : args) { + if (isLinearizedMemref(arg)) { + continue; + } + auto memref = mlir::cast(arg.getType()); + auto shape = memref.getShape(); + auto linearMap = linearizeAffineMap( + rewriter.getContext(), memref.getLayout().getAffineMap(), shape); + SmallVector collapsedDims = { + llvm::to_vector(llvm::seq(0, shape.size()))}; + auto linearizedArg = rewriter.create( + arg.getLoc(), arg, collapsedDims); + rewriter.replaceAllUsesExcept(arg, linearizedArg->getResult(0), + linearizedArg); + for (auto user : linearizedArg->getUsers()) { + if 
(auto load = mlir::dyn_cast(user)) { + load.setMap(linearMap.compose(load.getMap())); + } else if (auto store = mlir::dyn_cast(user)) { + store.setMap(linearMap.compose(store.getMap())); + } + } + } + }); + + return success(); + } +}; +} // namespace + +namespace { +class TTIRGenericLinearizeMemref + : public impl::TTIRGenericLinearizeMemrefBase { +public: + using impl::TTIRGenericLinearizeMemrefBase< + TTIRGenericLinearizeMemref>::TTIRGenericLinearizeMemrefBase; + + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsGreedily(getOperation(), patternSet))) { + signalPassFailure(); + } + } + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } +}; +} // namespace + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp b/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp index 35bf5d10e8..2d9aebba89 100644 --- a/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp +++ b/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp @@ -4,9 +4,11 @@ #include "ttmlir/Dialect/TTMetal/Pipelines/TTMetalPipelines.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/PassManager.h" #include "ttmlir/Conversion/Passes.h" @@ -79,6 +81,9 @@ void createTTIRToTTMetalBackendPipeline( // pm.addPass(mlir::tt::ttir::createTTIRGenericRegion()); if (options.version > 0) { createTTIRBufferizationPipeline(pm); + pm.addPass(mlir::createConvertLinalgToAffineLoopsPass()); + pm.addPass(mlir::tt::ttir::createTTIRGenericLinearizeMemref()); + pm.addPass(mlir::createLowerAffinePass()); } else { 
mlir::tt::ttir::TTIRLayoutOptions layoutOptions; { diff --git a/lib/RegisterAll.cpp b/lib/RegisterAll.cpp index 9bf45b1676..7b210a2826 100644 --- a/lib/RegisterAll.cpp +++ b/lib/RegisterAll.cpp @@ -42,8 +42,9 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry ®istry) { mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect, mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect, - mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, - mlir::tosa::TosaDialect, mlir::vector::VectorDialect, + mlir::affine::AffineDialect, mlir::scf::SCFDialect, + mlir::cf::ControlFlowDialect, mlir::tosa::TosaDialect, + mlir::vector::VectorDialect, mlir::memref::MemRefDialect, mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect, mlir::LLVM::LLVMDialect>(); diff --git a/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir b/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir new file mode 100644 index 0000000000..68a1c450d7 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir @@ -0,0 +1,26 @@ +// RUN: ttmlir-opt --ttir-generic-linearize-memref %s | FileCheck %s + +#l1_ = #tt.memory_space +#map = affine_map<(d0, d1) -> (d0, d1)> +#parallel = #tt.iterator_type + +func.func @add(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> + "ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): + 
// CHECK: = memref.collapse_shape %arg3 + // CHECK: = memref.collapse_shape %arg4 + affine.for %arg5 = 0 to 2 { + affine.for %arg6 = 0 to 4 { + // CHECK: = affine.load %collapse_shape[%{{.*}} * 4 + %{{.*}}] + // CHECK: = affine.load %collapse_shape_0[%{{.*}} * 4 + %{{.*}}] + %0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + %1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + %2 = "ttir.tile_add"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32> + // CHECK: affine.store %2, %collapse_shape_1[%{{.*}} * 4 + %{{.*}}] + affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_> + } + } + }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> () + return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> +}