Skip to content

Commit

Permalink
TTIR Bufferization pass (#2264)
Browse files Browse the repository at this point in the history
### What's changed
This change implements a bufferization interface for the ttir generic op
and to_layout op and invokes the relevant passes. This also means that
these ops can now accept memrefs as inputs.

Some excellent documentation on the bufferization pass can be found here
https://mlir.llvm.org/docs/Bufferization/. The motivation in our case is
to leverage this upstream machinery and get "for free":
- Automatic analysis of buffer allocation, aliasing, deallocation
- Type conversion from tensors to buffer objects
- Seamless integration with other dialects used at this stage, memref,
linalg, affine, etc.

### Checklist
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
nsmithtt authored Feb 27, 2025
1 parent 340e05f commit 63bc18c
Show file tree
Hide file tree
Showing 18 changed files with 522 additions and 125 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> {
MetalLayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace);
MetalLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);
MetalLayoutAttr withStreamLayout(::mlir::MLIRContext *context, StreamLayoutAttr stream);
MetalLayoutAttr withStreamMode(::mlir::MLIRContext *context, StreamMode streamMode, std::uint32_t numBuffers);
MetalLayoutAttr withOuterScale(::mlir::MLIRContext *context,
llvm::ArrayRef<int64_t> outerScale,
StreamMode streamMode,
Expand Down
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def TTIR_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTIR_Op<string mnemonic, list<Trait> traits = []> :
Op<TTIR_Dialect, mnemonic, !listconcat([Pure], traits)>;
Op<TTIR_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TTIR traits definition.
Expand All @@ -68,4 +68,10 @@ def TTIR_BinaryIdempotence : TTIR_Trait<"TTIRBinaryIdempotence", [DestinationSty
// GenericRegionOpTrait is a trait that acts as a label for all generic region operations.
def TTIR_GenericRegionOpTrait : TTIR_Trait<"TTIRGenericRegionOpTrait", []>;

//===----------------------------------------------------------------------===//
// TTIR common types.
//===----------------------------------------------------------------------===//

def AnyRankedTensorOrMemRef: AnyTypeOf<[AnyRankedTensor, AnyNon0RankedMemRef]>;

#endif
2 changes: 0 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIRGenericRegionOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ class TTIR_GenericRegionOp<string mnemonic, list<Trait> traits = [TTIR_GenericRe
// TTIR Generic Region Ops (Used in TTMetal Lowering)
//===----------------------------------------------------------------------===//

def AnyRankedTensorOrMemRef: AnyTypeOf<[AnyRankedTensor, AnyNon0RankedMemRef]>;

def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator, TTIR_GenericRegionOpTrait]> {
let summary = "Yield op.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
Expand Down
Loading

0 comments on commit 63bc18c

Please sign in to comment.