Skip to content

Commit

Permalink
Prerequisite changes for d2m new lowering (#2198)
Browse files Browse the repository at this point in the history
This PR has a few miscellaneous changes to prep for metal backend
lowering flow:

- initialize tensor inputs as tilized
- Setup boilerplate pass structure
  • Loading branch information
nsmithtt authored Feb 19, 2025
1 parent bc853e1 commit 8944ab5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
8 changes: 5 additions & 3 deletions lib/Conversion/TTIRToTTMetal/AttachMetalLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ class TTIRLayoutTensorTypeConverter : public TypeConverter {
uint32_t numBuffers;
std::tie(streamMode, numBuffers) =
StreamLayoutAttr::getDefaults(initMemorySpace);
auto tileType = TileType::get(ctx, type.getElementType());

return layout.withOuterScale(ctx, outerScale, streamMode, numBuffers);
return layout.withElementType(ctx, tileType)
.withOuterScale(ctx, outerScale, streamMode, numBuffers);
}();

return RankedTensorType::get(type.getShape(), type.getElementType(),
newLayout);
return RankedTensorType::get(newLayout.getShardShape(false),
newLayout.getElementType(), newLayout);
});
}
};
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/TTIR/Transforms/Layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ inline Location appendInputSuffix(Location loc, int64_t operandIndex) {
// To layout pass
//===----------------------------------------------------------------------===//

namespace {
class TTIRLayoutTensorTypeConverter : public TypeConverter {
public:
TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace,
Expand All @@ -56,7 +57,9 @@ class TTIRLayoutTensorTypeConverter : public TypeConverter {
});
}
};
} // namespace

namespace {
class TTIRLayoutTensorTypeRewriter : public RewritePattern {
public:
TTIRLayoutTensorTypeRewriter(const TypeConverter &converter, MLIRContext *ctx)
Expand Down Expand Up @@ -124,6 +127,7 @@ class TTIRLayoutTensorTypeRewriter : public RewritePattern {

const TypeConverter *converter;
};
} // namespace

static std::optional<Value>
createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
Expand Down
22 changes: 13 additions & 9 deletions lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,20 @@ void createTTIRToTTMetalBackendPipeline(
mlir::tt::ttir::createTTIRAttachMetalLayout(attachMetalLayoutOptions));
// TODO(#1951): replace with TTIRToGeneric implemented as a converter:
// pm.addPass(mlir::tt::ttir::createTTIRGenericRegion());
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
{
layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceL1;
layoutOptions.defaultDeviceMemoryLayout =
mlir::tt::TensorMemoryLayout::None;
if (options.version > 0) {

} else {
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
{
layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceL1;
layoutOptions.defaultDeviceMemoryLayout =
mlir::tt::TensorMemoryLayout::None;
}
pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
pm.addPass(mlir::tt::ttir::createTTIRAllocate());
pm.addPass(createConvertTTIRToTTMetalPass());
}
pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
pm.addPass(mlir::tt::ttir::createTTIRAllocate());
pm.addPass(createConvertTTIRToTTMetalPass());
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8944ab5

Please sign in to comment.