Skip to content

Commit

Permalink
Remove TT_TensorMemoryLayout (#2255)
Browse files Browse the repository at this point in the history
### Ticket
closes #596 

### Problem description
TT_TensorMemoryLayout is unneeded since the metal backend doesn't use it.

### What's changed
Since I just merged #2217, which removes TensorMemoryLayout from the metal
flatbuffer as well as TensorMemoryLayout::None from the ttnn flatbuffer, I
thought I would do the same for the compiler.

This PR removes TT_TensorMemoryLayout. With this, TensorMemoryLayout
should be unified end to end.
  • Loading branch information
jnie-TT authored Feb 24, 2025
1 parent 91dd3d6 commit 3af5027
Show file tree
Hide file tree
Showing 21 changed files with 78 additions and 301 deletions.
3 changes: 0 additions & 3 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ ttmlirTTMemorySpaceAttrGet(MlirContext ctx, uint32_t memorySpace);
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTOOBValAttrGet(MlirContext ctx,
uint32_t oobVal);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTTensorMemoryLayoutAttrGet(MlirContext ctx, uint32_t memLayout);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTIteratorTypeAttrGet(MlirContext ctx, uint32_t iteratorType);

Expand Down
20 changes: 0 additions & 20 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,6 @@ def TT_MemorySpace : I32EnumAttr<"MemorySpace", "TT MemorySpace",
let cppNamespace = "::mlir::tt";
}

def TT_TensorMemoryLayoutNone : I32EnumAttrCase<"None", 0, "none">;
def TT_TensorMemoryLayoutInterleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">;
def TT_TensorMemoryLayoutSingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">;
def TT_TensorMemoryLayoutHeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">;
def TT_TensorMemoryLayoutWidthSharded : I32EnumAttrCase<"WidthSharded", 4, "width_sharded">;
def TT_TensorMemoryLayoutBlockSharded : I32EnumAttrCase<"BlockSharded", 5, "block_sharded">;

def TT_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TT TensorMemoryLayout",
[
TT_TensorMemoryLayoutNone,
TT_TensorMemoryLayoutInterleaved,
TT_TensorMemoryLayoutSingleBank,
TT_TensorMemoryLayoutHeightSharded,
TT_TensorMemoryLayoutWidthSharded,
TT_TensorMemoryLayoutBlockSharded,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
}

def TT_Parallel : I32EnumAttrCase<"Parallel", 0, "parallel">;
def TT_Systolic : I32EnumAttrCase<"Systolic", 1, "systolic">;
def TT_Broadcast : I32EnumAttrCase<"Broadcast", 2, "broadcast">;
Expand Down
6 changes: 0 additions & 6 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,6 @@ inline bool isL1MemorySpace(MemorySpace memorySpace) {
return memorySpace == MemorySpace::DeviceL1;
}

inline bool isShardedMemoryLayout(TensorMemoryLayout layout) {
return layout == TensorMemoryLayout::HeightSharded ||
layout == TensorMemoryLayout::WidthSharded ||
layout == TensorMemoryLayout::BlockSharded;
}

inline void printDimensionList(::mlir::AsmPrinter &printer,
::llvm::ArrayRef<int64_t> shape) {
printer.printDimensionList(shape);
Expand Down
23 changes: 5 additions & 18 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,8 @@ def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> {
let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear,
AttrParameter<"OOBVal", "A tracked out of bounds value that fills padding space.">:$oob_val,
AttrParameter<"GridAttr", "The grid shape that this tensor is divided onto.">:$grid,
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref,
DefaultValuedParameter<"TensorMemoryLayout", "TensorMemoryLayout::None", "The layout of the tensor in memory.">:$mem_layout);
let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref (`,` $mem_layout^)? `>`";
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref);
let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref `>`";

let extraClassDeclaration = [{
static MetalLayoutAttr get(::mlir::MLIRContext *context,
Expand All @@ -331,29 +330,25 @@ def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> {
MemorySpace memorySpace = MemorySpace::System,
GridAttr grid = {},
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}},
OOBVal oobVal = OOBVal::Undef,
TensorMemoryLayout memLayout = TensorMemoryLayout::None);
OOBVal oobVal = OOBVal::Undef);
static MetalLayoutAttr get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace = MemorySpace::System,
GridAttr grid = {},
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}},
OOBVal oobVal = OOBVal::Undef,
TensorMemoryLayout memLayout = TensorMemoryLayout::None);
OOBVal oobVal = OOBVal::Undef);
static MetalLayoutAttr get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace,
GridAttr grid,
Type elementType,
TensorMemoryLayout memLayout = TensorMemoryLayout::None);
Type elementType);
MetalLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid, ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
MetalLayoutAttr withGrid(::mlir::MLIRContext *context,
RankedTensorType ty,
GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
MetalLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType);
MetalLayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace);
MetalLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
MetalLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);
MetalLayoutAttr withStreamLayout(::mlir::MLIRContext *context, StreamLayoutAttr stream);
MetalLayoutAttr withOuterScale(::mlir::MLIRContext *context,
Expand All @@ -365,10 +360,6 @@ def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> {
MemorySpace getMemorySpace() const;
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool hasShardedTensorMemoryLayout() const;
bool hasInterleavedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
Expand Down Expand Up @@ -437,10 +428,6 @@ def TT_MemorySpaceAttr : EnumAttr<TT_Dialect, TT_MemorySpace, "memory_space"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_TensorMemoryLayoutAttr : EnumAttr<TT_Dialect, TT_TensorMemoryLayout, "tensor_memory_layout"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_OOBValAttr : EnumAttr<TT_Dialect, TT_OOBVal, "oob_val"> {
let assemblyFormat = "`<` $value `>`";
}
Expand Down
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI
bool isGridChange;
bool isFormatChange;
bool isMemorySpaceChange;
bool isMemoryLayoutChange;
};

// Returns booleans indicating if the op changes layout, grid, format, memory space or memory layout.
Expand Down
4 changes: 0 additions & 4 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
"::mlir::tt::MemorySpace",
/*default=*/"::mlir::tt::MemorySpace::DeviceDRAM",
"Set the default memory space for layout pass to prefer for operation operands, if not constrained">,
Option<"defaultDeviceMemoryLayout", "default-device-memory-layout",
"::mlir::tt::TensorMemoryLayout",
/*default=*/"::mlir::tt::TensorMemoryLayout::Interleaved",
"Set the default memory layout for layout pass to prefer for operation operands that are on device, if not constrained">
];
}

Expand Down
10 changes: 5 additions & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout",
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 4, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 5, "block_sharded">;
def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 0, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 1, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 2, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 3, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 4, "block_sharded">;

def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout",
[
Expand Down
10 changes: 0 additions & 10 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@ namespace mlir::tt::ttnn::utils {
mlir::tt::ttnn::BufferType
toTTNNBufferType(const mlir::tt::MemorySpace memorySpace);

// Map tt::TensorMemoryLayout to ttnn::TensorMemoryLayout
//
ttnn::TensorMemoryLayout
toTTNNTensorMemoryLayout(const tt::TensorMemoryLayout ttTensorMemoryLayout);

// Map ttnn::BufferType to tt::MemorySpace
//
mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout(
const ::mlir::tt::ttnn::TensorMemoryLayout ttnnTensorMemoryLayout);

// Map ttnn::BufferType to tt::MemorySpace
//
mlir::tt::MemorySpace
Expand Down
11 changes: 2 additions & 9 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,12 @@ MlirAttribute ttmlirTTSystemDescAttrGet(

MlirAttribute ttmlirTTMetalLayoutAttrGet(MlirContext ctx, MlirAffineMap linear,
unsigned oobVal, MlirAttribute grid,
MlirType memref, unsigned memLayout) {
MlirType memref) {
mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr);
return wrap(MetalLayoutAttr::get(unwrap(ctx), affineMap,
static_cast<OOBVal>(oobVal),
mlir::cast<GridAttr>(unwrap(grid)),
mlir::cast<MemRefType>(unwrap(memref)),
static_cast<TensorMemoryLayout>(memLayout)));
mlir::cast<MemRefType>(unwrap(memref))));
}

MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx,
Expand All @@ -140,12 +139,6 @@ MlirAttribute ttmlirTTOOBValAttrGet(MlirContext ctx, uint32_t oobVal) {
return wrap(OOBValAttr::get(unwrap(ctx), static_cast<tt::OOBVal>(oobVal)));
}

MlirAttribute ttmlirTTTensorMemoryLayoutAttrGet(MlirContext ctx,
uint32_t memLayout) {
return wrap(TensorMemoryLayoutAttr::get(
unwrap(ctx), static_cast<tt::TensorMemoryLayout>(memLayout)));
}

MlirAttribute ttmlirTTIteratorTypeAttrGet(MlirContext ctx,
uint32_t iteratorType) {
return wrap(IteratorTypeAttr::get(
Expand Down
4 changes: 0 additions & 4 deletions lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,9 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
static_cast<int>(components.isGridChange ||
components.isMemorySpaceChange) +
static_cast<int>(components.isFormatChange)) > 1;
assert(!components.isMemoryLayoutChange &&
"Memory layout is not used in direct to metal path");
assert(!isCompound && "Only one change is allowed");

assert(!isCompound && "Only one change is allowed");
assert(!components.isMemoryLayoutChange &&
"Tensor memory layout shouldn't change in metal backend");
if (components.isMemorySpaceChange) {
if (inputLayout.isSystemMemorySpace()) {
assert(outputLayout.isDeviceMemorySpace());
Expand Down
67 changes: 15 additions & 52 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ MetalLayoutAttr MetalLayoutAttr::get(
::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape,
Type elementType, MemorySpace memorySpace, GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals,
OOBVal oobVal, TensorMemoryLayout memLayout) {
OOBVal oobVal) {
if (not grid) {
grid = GridAttr::get(context, tensorShape.size());
}
Expand All @@ -496,28 +496,27 @@ MetalLayoutAttr MetalLayoutAttr::get(
auto shardShape = calculateLogicalShardShape(tensorShape, linear, grid);
auto memref = buildMemRef<MemorySpace, MemorySpaceAttr>(
context, shardShape, elementType, memorySpace);
return get(context, linear, oobVal, grid, memref, memLayout);
return get(context, linear, oobVal, grid, memref);
}

MetalLayoutAttr MetalLayoutAttr::get(
::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace,
GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals,
OOBVal oobVal, TensorMemoryLayout memLayout) {
OOBVal oobVal) {
assert(ty);
return get(context, ty.getShape(), ty.getElementType(), memorySpace, grid,
collapseIntervals, oobVal, memLayout);
collapseIntervals, oobVal);
}

MetalLayoutAttr MetalLayoutAttr::get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace, GridAttr grid,
Type elementType,
TensorMemoryLayout memLayout) {
Type elementType) {
assert(ty);
assert(grid);
return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}},
OOBVal::Undef, memLayout);
OOBVal::Undef);
}

// From the logical shape of the tensor and the affine map of the layout,
Expand Down Expand Up @@ -620,28 +619,6 @@ mlir::Type MetalLayoutAttr::getScalarElementType() const {
return elementType;
}

bool MetalLayoutAttr::hasShardedTensorMemoryLayout() const {
return (getMemLayout() == TensorMemoryLayout::HeightSharded or
getMemLayout() == TensorMemoryLayout::WidthSharded or
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool MetalLayoutAttr::hasInterleavedTensorMemoryLayout() const {
return (getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool MetalLayoutAttr::hasShardedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::HeightSharded or
getMemLayout() == TensorMemoryLayout::WidthSharded or
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

bool MetalLayoutAttr::hasInterleavedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::Interleaved);
}

bool MetalLayoutAttr::isTiled() const {
return ::mlir::isa<::mlir::tt::TileType>(getElementType());
}
Expand All @@ -667,7 +644,7 @@ MetalLayoutAttr MetalLayoutAttr::withGrid(
::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {
return get(context, tensorShape, getElementType(), getMemorySpace(), grid,
collapseIntervals, getOobVal(), getMemLayout());
collapseIntervals, getOobVal());
}

MetalLayoutAttr MetalLayoutAttr::withGrid(
Expand All @@ -683,27 +660,15 @@ MetalLayoutAttr MetalLayoutAttr::withElementType(::mlir::MLIRContext *context,
return MetalLayoutAttr::get(
context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(context, getShardShape(),
elementType, getMemorySpace()),
getMemLayout());
elementType, getMemorySpace()));
}

MetalLayoutAttr MetalLayoutAttr::withMemorySpace(::mlir::MLIRContext *context,
MemorySpace memorySpace) {
return MetalLayoutAttr::get(
context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(context, getShardShape(),
getElementType(), memorySpace),
getMemLayout());
}

MetalLayoutAttr
MetalLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context,
TensorMemoryLayout memLayout) {
return MetalLayoutAttr::get(
context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(
context, getShardShape(), getElementType(), getMemorySpace()),
memLayout);
getElementType(), memorySpace));
}

MetalLayoutAttr
Expand All @@ -712,18 +677,16 @@ MetalLayoutAttr::withShardShape(::mlir::MLIRContext *context,
return MetalLayoutAttr::get(
context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(
context, shardShape, getElementType(), getMemorySpace()),
getMemLayout());
context, shardShape, getElementType(), getMemorySpace()));
}

// TODO(vroubtsovTT): remove this, it's difficult/unsafe to use
MetalLayoutAttr MetalLayoutAttr::withStreamLayout(::mlir::MLIRContext *context,
StreamLayoutAttr layout) {
return MetalLayoutAttr::get(
context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(
context, getShardShape(), getElementType(), getMemorySpace(), layout),
getMemLayout());
return MetalLayoutAttr::get(context, getLinear(), getOobVal(), getGrid(),
buildMemRef<MemorySpace, MemorySpaceAttr>(
context, getShardShape(), getElementType(),
getMemorySpace(), layout));
}

MetalLayoutAttr MetalLayoutAttr::withOuterScale(
Expand All @@ -747,7 +710,7 @@ MetalLayoutAttr MetalLayoutAttr::withOuterScale(
context, fullShape, getElementType(), getMemorySpace(), fullLayout);

return MetalLayoutAttr::get(context, getLinear(), getOobVal(), getGrid(),
fullMemRef, getMemLayout());
fullMemRef);
}

MemorySpace MetalLayoutAttr::getMemorySpace() const {
Expand Down
5 changes: 1 addition & 4 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1316,10 +1316,7 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
inputLayout.getElementType() != outputLayout.getElementType();
bool isMemorySpaceChange =
inputLayout.getMemorySpace() != outputLayout.getMemorySpace();
bool isMemoryLayoutChange =
inputLayout.getMemLayout() != outputLayout.getMemLayout();
return {isLayoutChange, isGridChange, isFormatChange, isMemorySpaceChange,
isMemoryLayoutChange};
return {isLayoutChange, isGridChange, isFormatChange, isMemorySpaceChange};
}

::mlir::LogicalResult
Expand Down
Loading

0 comments on commit 3af5027

Please sign in to comment.