Uplift LLVM and Stablehlo (#2211)
### Problem description
Uplift the pinned LLVM and StableHLO versions.

### What's changed
Uplifted LLVM to llvm/llvm-project@a854c26.
Uplifted StableHLO to openxla/stablehlo@4598975.

- `applyPatternsAndFoldGreedily` has been deprecated.
- I've replaced all calls to it with `applyPatternsGreedily`, as the
deprecation notice instructs.
- Looking at the [source of
`applyPatternsAndFoldGreedily`](https://github.com/llvm/llvm-project/blob/26a83994176fcdca6e77be4f221a15f561681621/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h#L186),
it simply calls `applyPatternsGreedily` with a `GreedyRewriteConfig` whose
`fold` flag is set to `true`. However, the [source of
`GreedyRewriteConfig`](https://github.com/llvm/llvm-project/blob/26a83994176fcdca6e77be4f221a15f561681621/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h#L96)
already defaults `fold` to `true`, so the replacement should not change how the
pattern driver behaves (a paraphrase of the deprecated shim is sketched below,
after this list).

- The `FloatType::getF32(...)`-style helpers no longer exist; each float type
is now constructed through its own class, e.g. `Float32Type::get(context)`
(see the example after this list).

- `tosa::MatMulOp` had its quantization info attribute(s) removed upstream in
llvm/llvm-project@f0b8ff1, so the matmul conversion pattern no longer needs its
quantization legality check.

- `populateSCFToEmitCConversionPatterns` now requires a type converter (see the
sketch after this list).
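
For the `applyPatternsGreedily` replacement, here is a minimal paraphrase of
what the deprecated shim in the linked header boils down to (not the exact
upstream code; the `Sketch` suffix is mine and the signature follows the
linked revision):

```cpp
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace sketch {
// Paraphrase of the deprecated applyPatternsAndFoldGreedily shim: it only
// forwards to applyPatternsGreedily with folding enabled, and
// GreedyRewriteConfig already defaults `fold` to true, so swapping call sites
// over should be behavior-preserving.
inline mlir::LogicalResult applyPatternsAndFoldGreedilySketch(
    mlir::Operation *op, const mlir::FrozenRewritePatternSet &patterns,
    mlir::GreedyRewriteConfig config = mlir::GreedyRewriteConfig(),
    bool *changed = nullptr) {
  config.fold = true; // already the default
  return mlir::applyPatternsGreedily(op, patterns, config, changed);
}
} // namespace sketch
```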
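
For the float-type change, this is the replacement pattern used throughout the
PR (the `makeF32`/`makeBF16` helper names are illustrative, not part of the
codebase):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// The FloatType::getF32/getF16/getBF16 helpers are gone; each builtin float
// type now provides its own ::get factory on the context.
mlir::Type makeF32(mlir::MLIRContext *context) {
  return mlir::Float32Type::get(context); // was FloatType::getF32(context)
}

mlir::Type makeBF16(mlir::MLIRContext *context) {
  return mlir::BFloat16Type::get(context); // was FloatType::getBF16(context)
}
```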
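
For the SCF-to-EmitC change, a minimal sketch of the updated call site,
assuming the surrounding pass already owns a `TypeConverter` and a
`RewritePatternSet` as `ConvertTTKernelToEmitCPass` does (the wrapper function
name is illustrative):

```cpp
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

// The SCF -> EmitC patterns now take the type converter so the types they
// produce stay consistent with the rest of the dialect conversion.
void addScfToEmitCPatterns(mlir::TypeConverter &typeConverter,
                           mlir::RewritePatternSet &patterns) {
  // Old call: populateSCFToEmitCConversionPatterns(patterns);
  mlir::populateSCFToEmitCConversionPatterns(patterns, typeConverter);
}
```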
LPanosTT authored Feb 24, 2025
1 parent 2c434ef commit 1d937da
Showing 13 changed files with 25 additions and 44 deletions.
4 changes: 2 additions & 2 deletions env/CMakeLists.txt
@@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.20.0)
 project(ttmlir-toolchain LANGUAGES CXX C)

 set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
-set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
-set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
+set(LLVM_PROJECT_VERSION "a854c266b98468ad4479a7d3c56a3fa76437e30d")
+set(STABLEHLO_VERSION "459897561d365ef97caba46984847f9184d472ec")
 set(SHARDY_VERSION "55f44c23b766be38bccb0b2394b0e8dfba45694e")
 set(LLVM_BUILD_TYPE MinSizeRel CACHE STRING "Build type for LLVM")

18 changes: 9 additions & 9 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
@@ -115,23 +115,23 @@ inline Type dataTypeToElementType(::mlir::MLIRContext *context,
                                   DataType dtype) {
   switch (dtype) {
   case DataType::Float32:
-    return FloatType::getF32(context);
+    return Float32Type::get(context);
   case DataType::Float16:
-    return FloatType::getF16(context);
+    return Float16Type::get(context);
   case DataType::BFloat16:
-    return FloatType::getBF16(context);
+    return BFloat16Type::get(context);
   case DataType::BFP_Float8:
-    return FloatType::getF16(context);
+    return Float16Type::get(context);
   case DataType::BFP_BFloat8:
-    return FloatType::getBF16(context);
+    return BFloat16Type::get(context);
   case DataType::BFP_Float4:
-    return FloatType::getF16(context);
+    return Float16Type::get(context);
   case DataType::BFP_BFloat4:
-    return FloatType::getBF16(context);
+    return BFloat16Type::get(context);
   case DataType::BFP_Float2:
-    return FloatType::getF16(context);
+    return Float16Type::get(context);
   case DataType::BFP_BFloat2:
-    return FloatType::getBF16(context);
+    return BFloat16Type::get(context);
   case DataType::UInt32:
     return IntegerType::get(context, 32,
                             IntegerType::SignednessSemantics::Unsigned);
2 changes: 1 addition & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
@@ -89,7 +89,7 @@ class StablehloTypeConverter : public TypeConverter {
     }
     // Convert 64 bit float element type to 32 bit float.
     else if (isa<FloatType>(elementType)) {
-      elementType = FloatType::getF32(context);
+      elementType = Float32Type::get(context);
       changed = true;
     }
   }
2 changes: 1 addition & 1 deletion lib/Conversion/TTIRToTTMetal/AttachMetalLayout.cpp
@@ -146,7 +146,7 @@ class TTIRAttachMetalLayout
     RewritePatternSet patterns(&getContext());
     patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
2 changes: 1 addition & 1 deletion lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp
@@ -530,7 +530,7 @@ class ConvertTTKernelToEmitCPass
     {
       populateArithToEmitCPatterns(typeConverter, patterns);

-      populateSCFToEmitCConversionPatterns(patterns);
+      populateSCFToEmitCConversionPatterns(patterns, typeConverter);

       populateMemRefToEmitCTypeConversion(typeConverter);
       populateMemRefToEmitCConversionPatterns(patterns, typeConverter);
17 changes: 0 additions & 17 deletions lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp
@@ -130,12 +130,6 @@ class TosaToTTIRMatmulOpConversionPattern
   LogicalResult
   matchAndRewrite(tosa::MatMulOp srcOp, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    LogicalResult legalityResult =
-        checkConversionLegality(srcOp, adaptor, rewriter);
-    if (!legalityResult.succeeded()) {
-      return legalityResult;
-    }
-
     auto outputType = mlir::cast<RankedTensorType>(
         this->getTypeConverter()->convertType(srcOp.getResult().getType()));

@@ -144,17 +138,6 @@ class TosaToTTIRMatmulOpConversionPattern

     return success();
   }
-
-private:
-  LogicalResult
-  checkConversionLegality(tosa::MatMulOp srcOp, Adaptor adaptor,
-                          ConversionPatternRewriter &rewriter) const {
-    if (srcOp.getQuantizationInfo().has_value()) {
-      return rewriter.notifyMatchFailure(
-          srcOp, "TTIR MatmulOp currently doesn't support quantization.");
-    }
-    return success();
-  }
 };
 } // namespace

2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/Broadcast.cpp
@@ -94,7 +94,7 @@ class TTIRImplicitBroadcastFold
     patterns.add<TTIRImplicitBroadcastFoldRewriter>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));

-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/Constant.cpp
@@ -40,7 +40,7 @@ class TTIRConstantAsFill
     RewritePatternSet patterns(&getContext());
     patterns.add<TTIRConstantAsFillRewriter>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp
@@ -70,7 +70,7 @@ static void hoistOperationToFunction(mlir::Operation *opToHoist,
   const llvm::SmallVector<int64_t, 4> ranks = getOperandTensorRanks(opToHoist);
   mlir::MLIRContext *context = sourceModule.getContext();
   mlir::OpBuilder typeBuilder(opToHoist);
-  auto f32Type = mlir::FloatType::getF32(context);
+  auto f32Type = mlir::Float32Type::get(context);

   // Convert operands and gather types for function signature
   llvm::SmallVector<mlir::Type> operandTypes;
7 changes: 3 additions & 4 deletions lib/Dialect/TTIR/Transforms/Layout.cpp
@@ -286,7 +286,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     RewritePatternSet patterns(&getContext());
     patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
@@ -300,8 +300,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     FrozenRewritePatternSet patternSet(std::move(patterns));
     GreedyRewriteConfig config = GreedyRewriteConfig();
     config.useTopDownTraversal = true;
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet,
-                                            config))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet, config))) {
       signalPassFailure();
       return;
     }
@@ -425,7 +424,7 @@ class TTIRSplitCompoundLayout
     RewritePatternSet patterns(&getContext());
     patterns.add<TTIRSplitCompoundLayoutRewriter>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
3 changes: 2 additions & 1 deletion lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -254,7 +254,8 @@ class TTNNCreateInputGenerators

     // Create function type
     //
-    mlir::TypeRange returnTypeRange = mlir::TypeRange(rewriter.getI32Type());
+    mlir::Type i32Type = rewriter.getI32Type();
+    mlir::TypeRange returnTypeRange = mlir::TypeRange(i32Type);
     FunctionType functionType =
         mlir::FunctionType::get(&getContext(), {}, returnTypeRange);

5 changes: 2 additions & 3 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -693,7 +693,7 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
     patterns.add<TTNNLayoutFuncInputOutputTypeRewriter>(
         &getContext(), device.getWorkerGrid());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
       signalPassFailure();
       return;
     }
@@ -711,8 +711,7 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
     FrozenRewritePatternSet patternSet(std::move(patterns));
     GreedyRewriteConfig config = GreedyRewriteConfig();
     config.useTopDownTraversal = true;
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet,
-                                            config))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet, config))) {
       signalPassFailure();
       return;
     }
3 changes: 1 addition & 2 deletions lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -475,8 +475,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
     // This configuration specifies that the rewriter should traverse the IR
     // in a top-down order.
     config.useTopDownTraversal = true;
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) {
+    if (failed(applyPatternsGreedily(getOperation(), patternSet, config))) {
       signalPassFailure();
       return;
     }
