Skip to content

Commit

Permalink
Generic region to loops
Browse files Browse the repository at this point in the history
- `linalg.generic` to affine passes
- Linearize memref accesses pass

Closes #1910
Closes #1911
  • Loading branch information
nsmithtt committed Feb 27, 2025
1 parent 63bc18c commit e7993d9
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 19 deletions.
36 changes: 36 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ def TTIRAttachMetalLayout: Pass<"ttir-attach-metal-layout", "::mlir::ModuleOp">
];
}

def TTIRGenericLinearizeMemref: Pass<"ttir-generic-linearize-memref", "::mlir::ModuleOp"> {
let summary = "Linearize memref operands for generic ops.";
let description = [{
This pass takes a nested loop structure over n-dimensional memrefs and linearizes
them into a single dimension. This is a useful because circular buffers in metal
are only one-dimensional.

Example, this pass will convert the following code:
```mlir
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
%0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
}
}
```

Into:
```mlir
%collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
%collapse_shape_0 = memref.collapse_shape %arg3 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
%collapse_shape_1 = memref.collapse_shape %arg4 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
%0 = affine.load %collapse_shape[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %collapse_shape_0[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
affine.store %2, %collapse_shape_1[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
}
}
```
}];
}

def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
let summary = "Tensor tilize all generic ops.";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Allocate.cpp
Broadcast.cpp
Constant.cpp
Generic.cpp
GenericLinearizeMemref.cpp
HoistCPUOps.cpp
Layout.cpp
Transforms.cpp
Expand Down
16 changes: 0 additions & 16 deletions lib/Dialect/TTIR/Transforms/Generic.cpp

This file was deleted.

121 changes: 121 additions & 0 deletions lib/Dialect/TTIR/Transforms/GenericLinearizeMemref.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <numeric>

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Utils.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRGENERICLINEARIZEMEMREF
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

namespace {
class TTIRGenericLinearizeMemrefRewriter
: public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;

static bool isLinearizedMemref(BlockArgument arg) {
auto memref = mlir::cast<MemRefType>(arg.getType());
if (memref.getShape().size() == 1) {
return true;
}

return std::all_of(arg.user_begin(), arg.user_end(), [](Operation *user) {
return mlir::isa<memref::CollapseShapeOp>(user);
});
}

static mlir::AffineMap linearizeAffineMap(::mlir::MLIRContext *context,
mlir::AffineMap map,
ArrayRef<int64_t> shape) {
auto evaledShape = ttmlir::utils::evalShape(map, shape);
mlir::AffineExpr indexing = getAffineConstantExpr(0, context);
mlir::AffineExpr volumeExpr = getAffineConstantExpr(1, context);

for (int i = map.getNumResults() - 1; i >= 0; i--) {
mlir::AffineExpr linearIdx = getAffineDimExpr(i, context);
mlir::AffineExpr dim = getAffineConstantExpr(evaledShape[i], context);
indexing = linearIdx * volumeExpr + indexing;
volumeExpr = volumeExpr * dim;
}

mlir::AffineMap linearResult =
mlir::AffineMap::get(map.getNumResults(), 0, indexing, context);
return linearResult.compose(map);
}

LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const final {
Block *entry = &op.getRegion().front();
rewriter.setInsertionPointToStart(entry);
auto args = entry->getArguments();
if (llvm::all_of(args, isLinearizedMemref)) {
return failure();
}

rewriter.modifyOpInPlace(op, [&]() {
for (auto arg : args) {
if (isLinearizedMemref(arg)) {
continue;
}
auto memref = mlir::cast<MemRefType>(arg.getType());
auto shape = memref.getShape();
auto linearMap = linearizeAffineMap(
rewriter.getContext(), memref.getLayout().getAffineMap(), shape);
SmallVector<ReassociationIndices, 4> collapsedDims = {
llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))};
auto linearizedArg = rewriter.create<memref::CollapseShapeOp>(
arg.getLoc(), arg, collapsedDims);
rewriter.replaceAllUsesExcept(arg, linearizedArg->getResult(0),
linearizedArg);
for (auto user : linearizedArg->getUsers()) {
if (auto load = mlir::dyn_cast<affine::AffineLoadOp>(user)) {
load.setMap(linearMap.compose(load.getMap()));
} else if (auto store = mlir::dyn_cast<affine::AffineStoreOp>(user)) {
store.setMap(linearMap.compose(store.getMap()));
}
}
}
});

return success();
}
};
} // namespace

namespace {
class TTIRGenericLinearizeMemref
: public impl::TTIRGenericLinearizeMemrefBase<TTIRGenericLinearizeMemref> {
public:
using impl::TTIRGenericLinearizeMemrefBase<
TTIRGenericLinearizeMemref>::TTIRGenericLinearizeMemrefBase;

void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRGenericLinearizeMemrefRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
signalPassFailure();
}
}
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::tt::ttir::TTIRDialect>();
registry.insert<mlir::tt::TTDialect>();
registry.insert<mlir::arith::ArithDialect>();
}
};
} // namespace

} // namespace mlir::tt::ttir
5 changes: 5 additions & 0 deletions lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "ttmlir/Dialect/TTMetal/Pipelines/TTMetalPipelines.h"

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

#include "ttmlir/Conversion/Passes.h"
Expand Down Expand Up @@ -79,6 +81,9 @@ void createTTIRToTTMetalBackendPipeline(
// pm.addPass(mlir::tt::ttir::createTTIRGenericRegion());
if (options.version > 0) {
createTTIRBufferizationPipeline(pm);
pm.addPass(mlir::createConvertLinalgToAffineLoopsPass());
pm.addPass(mlir::tt::ttir::createTTIRGenericLinearizeMemref());
pm.addPass(mlir::createLowerAffinePass());
} else {
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
{
Expand Down
5 changes: 3 additions & 2 deletions lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::affine::AffineDialect, mlir::scf::SCFDialect,
mlir::cf::ControlFlowDialect, mlir::tosa::TosaDialect,
mlir::vector::VectorDialect, mlir::memref::MemRefDialect,
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect,
mlir::LLVM::LLVMDialect>();

Expand Down
26 changes: 26 additions & 0 deletions test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: ttmlir-opt --ttir-generic-linearize-memref %s | FileCheck %s

#l1_ = #tt.memory_space<l1>
#map = affine_map<(d0, d1) -> (d0, d1)>
#parallel = #tt.iterator_type<parallel>

func.func @add(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
"ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array<i32: 2, 0, 1>, operand_cb_mapping = array<i64>}> ({
^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>):
// CHECK: = memref.collapse_shape %arg3
// CHECK: = memref.collapse_shape %arg4
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
// CHECK: = affine.load %collapse_shape[%{{.*}} * 4 + %{{.*}}]
// CHECK: = affine.load %collapse_shape_0[%{{.*}} * 4 + %{{.*}}]
%0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_add"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
// CHECK: affine.store %2, %collapse_shape_1[%{{.*}} * 4 + %{{.*}}]
affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
}
}
}) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> ()
return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
}

0 comments on commit e7993d9

Please sign in to comment.