From 9a4749792ab90238817e4092d52ce189e26f1f63 Mon Sep 17 00:00:00 2001 From: Renat Idrisov <4032256+parsifal-47@users.noreply.github.com> Date: Tue, 13 Aug 2024 08:12:09 -0700 Subject: [PATCH] Implement missing triton operations (#159) Tests depend on https://github.com/microsoft/triton-shared/pull/158, once it is merged, I will remove test_precise_math and test_clamp from unsupported list. --------- Co-authored-by: Renat Idrisov --- .../ConversionPatterns.hpp | 106 ++++++++++++++++++ .../TritonArithToLinalg.cpp | 4 + .../TritonToLinalg/TritonToLinalg.cpp | 4 + python/examples/conftest.py | 2 - 4 files changed, 114 insertions(+), 2 deletions(-) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index e706a614..e3c23caf 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -28,6 +28,7 @@ #include "llvm/Support/MathExtras.h" #include +#include #include using namespace mlir; @@ -133,6 +134,22 @@ static Value getTransposedValue(Value source, const Location loc, return transpose; } +// for IntLike and FloatLike types +static std::optional getBitWidth(Type a) { + if (auto type = dyn_cast(a)) { + auto elementType = type.getElementType(); + if (elementType.isIntOrFloat()) { + return type.getElementType().getIntOrFloatBitWidth(); + } + return std::nullopt; + } + + if (a.isIntOrFloat()) + return a.getIntOrFloatBitWidth(); + + return std::nullopt; +} + //===----------------------------------------------------------------------===// // Op Lowering Patterns //===----------------------------------------------------------------------===// @@ -843,6 +860,95 @@ struct BitcastConverter : public OpConversionPattern { } }; +struct FpToFpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundingMode = triton::RoundingMode::RTNE; // default + + auto roundingModeAttr = op.getRounding(); + if (roundingModeAttr.has_value()) { + roundingMode = roundingModeAttr.value(); + } + + assert(roundingMode != triton::RoundingMode::RTZ && + "Rounding Towards Zero is not supported"); + + Type resultType = op.getResult().getType(); + + auto operandWidth = getBitWidth(op.getOperand().getType()); + auto resultWidth = getBitWidth(resultType); + + assert(operandWidth.has_value() && resultWidth.has_value() && + "Not a float-like operand or result"); + + if (operandWidth.value() > resultWidth.value()) { + Value truncatedValue = rewriter.create(op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, truncatedValue); + return success(); + } + + Value extendedValue = rewriter.create(op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, extendedValue); + + return success(); + } +}; + +struct ClampConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL; + + assert(!propagateNan && + "PropagateNan is not supported"); + + Location loc = op.getLoc(); + Value x = adaptor.getOperands()[0]; + Value min = adaptor.getOperands()[1]; + Value max = adaptor.getOperands()[2]; + + Value maxMin = rewriter.create(loc, x, min); + Value clamp = rewriter.create(loc, maxMin, max); + rewriter.replaceOp(op, clamp); + + return success(); + } +}; + +struct PreciseSqrtConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + +struct PreciseDivConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + struct MulHiUIOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp index d5ff37da..9169829d 100644 --- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp +++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp @@ -60,6 +60,10 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index 8a094cae..6ad86786 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -50,6 +50,10 @@ void mlir::triton::populateTritonToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); diff --git a/python/examples/conftest.py b/python/examples/conftest.py index fbb68f88..f82368e0 100644 --- a/python/examples/conftest.py +++ b/python/examples/conftest.py @@ -41,7 +41,6 @@ def device(request): "test_slice", "test_where", "test_math_erf_op", - "test_precise_math", "test_abs_fp8", "test_shapes_as_params", "test_transpose", @@ -79,7 +78,6 @@ def device(request): "test_convertmma2mma", "test_dot_max_num_imprecise_acc", "test_propagate_nan", - "test_clamp", "test_clamp_symmetric", "test_temp_var_in_loop", "test_math_extern"