Skip to content

Commit

Permalink
Implement missing triton operations (#159)
Browse files Browse the repository at this point in the history
Tests depend on #158,
once it is merged, I will remove test_precise_math and test_clamp from
unsupported list.

---------

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
  • Loading branch information
parsifal-47 and parsifal-47 authored Aug 13, 2024
1 parent b00515f commit 9a47497
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/Support/MathExtras.h"

#include <numeric>
#include <optional>
#include <type_traits>

using namespace mlir;
Expand Down Expand Up @@ -133,6 +134,22 @@ static Value getTransposedValue(Value source, const Location loc,
return transpose;
}

// for IntLike and FloatLike types
static std::optional<unsigned> getBitWidth(Type a) {
if (auto type = dyn_cast<TensorType>(a)) {
auto elementType = type.getElementType();
if (elementType.isIntOrFloat()) {
return type.getElementType().getIntOrFloatBitWidth();
}
return std::nullopt;
}

if (a.isIntOrFloat())
return a.getIntOrFloatBitWidth();

return std::nullopt;
}

//===----------------------------------------------------------------------===//
// Op Lowering Patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -843,6 +860,95 @@ struct BitcastConverter : public OpConversionPattern<triton::BitcastOp> {
}
};

struct FpToFpConverter : public OpConversionPattern<triton::FpToFpOp> {
using OpConversionPattern<triton::FpToFpOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto roundingMode = triton::RoundingMode::RTNE; // default

auto roundingModeAttr = op.getRounding();
if (roundingModeAttr.has_value()) {
roundingMode = roundingModeAttr.value();
}

assert(roundingMode != triton::RoundingMode::RTZ &&
"Rounding Towards Zero is not supported");

Type resultType = op.getResult().getType();

auto operandWidth = getBitWidth(op.getOperand().getType());
auto resultWidth = getBitWidth(resultType);

assert(operandWidth.has_value() && resultWidth.has_value() &&
"Not a float-like operand or result");

if (operandWidth.value() > resultWidth.value()) {
Value truncatedValue = rewriter.create<arith::TruncFOp>(op.getLoc(), resultType, op.getOperand());
rewriter.replaceOp(op, truncatedValue);
return success();
}

Value extendedValue = rewriter.create<arith::ExtFOp>(op.getLoc(), resultType, op.getOperand());
rewriter.replaceOp(op, extendedValue);

return success();
}
};

struct ClampConverter : public OpConversionPattern<triton::ClampFOp> {
using OpConversionPattern<triton::ClampFOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL;

assert(!propagateNan &&
"PropagateNan is not supported");

Location loc = op.getLoc();
Value x = adaptor.getOperands()[0];
Value min = adaptor.getOperands()[1];
Value max = adaptor.getOperands()[2];

Value maxMin = rewriter.create<arith::MaximumFOp>(loc, x, min);
Value clamp = rewriter.create<arith::MinimumFOp>(loc, maxMin, max);
rewriter.replaceOp(op, clamp);

return success();
}
};

struct PreciseSqrtConverter : public OpConversionPattern<triton::PreciseSqrtOp> {
using OpConversionPattern<triton::PreciseSqrtOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto replacement = rewriter.create<math::SqrtOp>(
op.getLoc(), adaptor.getOperands());

rewriter.replaceOp(op, replacement);
return success();
}
};

struct PreciseDivConverter : public OpConversionPattern<triton::PreciseDivFOp> {
using OpConversionPattern<triton::PreciseDivFOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto replacement = rewriter.create<arith::DivFOp>(
op.getLoc(), adaptor.getOperands());

rewriter.replaceOp(op, replacement);
return success();
}
};

struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
using OpConversionPattern<triton::MulhiUIOp>::OpConversionPattern;

Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
patterns.add<ExpandDimsConverter>(patterns.getContext());
patterns.add<BitcastConverter>(patterns.getContext());
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
patterns.add<FpToFpConverter>(patterns.getContext());
patterns.add<ClampConverter>(patterns.getContext());
patterns.add<MatmulConverter>(patterns.getContext());
patterns.add<SplatConverter>(patterns.getContext());
patterns.add<DenseConstantConverter>(patterns.getContext());
Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
patterns.add<ExpandDimsConverter>(patterns.getContext());
patterns.add<BitcastConverter>(patterns.getContext());
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
patterns.add<FpToFpConverter>(patterns.getContext());
patterns.add<ClampConverter>(patterns.getContext());
patterns.add<AssertConverter>(patterns.getContext());
patterns.add<MatmulConverter>(patterns.getContext());
patterns.add<SplatConverter>(patterns.getContext());
Expand Down
2 changes: 0 additions & 2 deletions python/examples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def device(request):
"test_slice",
"test_where",
"test_math_erf_op",
"test_precise_math",
"test_abs_fp8",
"test_shapes_as_params",
"test_transpose",
Expand Down Expand Up @@ -79,7 +78,6 @@ def device(request):
"test_convertmma2mma",
"test_dot_max_num_imprecise_acc",
"test_propagate_nan",
"test_clamp",
"test_clamp_symmetric",
"test_temp_var_in_loop",
"test_math_extern"
Expand Down

0 comments on commit 9a47497

Please sign in to comment.