Skip to content

Commit

Permalink
Convert tt.func and tt.return (#4)
Browse files Browse the repository at this point in the history
Summary: This is stll a kind of the boilerplate and basic lowering for the first milestone (compiling vector addition). This PR firstly lowers `tt.func` and `tt.return`.


Test Plan: It can safely compile an empty kernel.

```
@triton.jit
def add_kernel(x_ptr,  y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    return
```

> TRITON_ENABLE_LLVM_DEBUG=1 TRITON_CPU_BACKEND=1 python3 empty_kerne.py

```
//===-------------------------------------------===//
Legalizing operation : 'tt.func'(0x73be2a0) {
  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'tt.func -> ()' {
Trying to match "(anonymous namespace)::FuncOpConversion"
    ** Insert  : 'llvm.func'(0x6c04c70)
    ** Insert Block into : 'llvm.func'(0x6c04c70)
    ** Insert Block into : 'llvm.func'(0x6c04c70)
    ** Erase   : 'tt.func'(0x73be2a0)
"(anonymous namespace)::FuncOpConversion" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'llvm.func'(0x6c04c70) {
    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

...

//===-------------------------------------------===//
Legalizing operation : 'tt.return'(0x73efeb0) {
  "tt.return"() : () -> ()

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'tt.return -> ()' {
Trying to match "(anonymous namespace)::ReturnOpConversion"
    ** Insert  : 'llvm.return'(0x73c0f00)
    ** Replace : 'tt.return'(0x73efeb0)
"(anonymous namespace)::ReturnOpConversion" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'llvm.return'(0x73c0f00) {
      "llvm.return"() : () -> ()

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
```
  • Loading branch information
minjang authored May 13, 2024
1 parent ee43271 commit 25e6cfc
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace triton {
// Some populate* functions have name collisions with the ones for GPUs.
namespace cpu {

constexpr int patternBenefitDefault = 1;
constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

} // namespace cpu
} // namespace triton
} // namespace mlir

#endif
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter {

TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr);

Type convertTritonPointerType(triton::PointerType type);
};

#endif
26 changes: 26 additions & 0 deletions include/triton/Conversion/TritonCPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace LLVM {

// TODO: Not sure we need this for CPU backends.
inline bool isKernel(FunctionOpInterface funcOp) {
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

} // namespace LLVM
} // namespace mlir

#endif
2 changes: 2 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
add_triton_library(TritonCPUToLLVM
ControlFlowOpToLLVM.cpp
FuncOpToLLVM.cpp
TypeConverter.cpp
TritonCPUToLLVM.cpp

Expand Down
37 changes: 37 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"
#include "llvm/Support/ErrorHandling.h"

namespace {

using namespace mlir;
using namespace mlir::triton;

struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
if (funcOp->hasAttr("cpu.kernel")) {
if (op.getNumOperands() > 0) {
return rewriter.notifyMatchFailure(
op, "Kernel functions do not support return with operands");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
} else {
llvm_unreachable("Not implemented");
}
return success();
}
};

} // namespace

void mlir::triton::cpu::populateControlFlowOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ReturnOpConversion>(typeConverter, benefit);
}
54 changes: 54 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "mlir/Support/LogicalResult.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"

namespace mlir {
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
}

namespace {

using namespace mlir;
using namespace mlir::triton;

struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
: ConvertOpToLLVMPattern(converter, benefit) {}

LogicalResult
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!LLVM::isKernel(funcOp)) {
llvm_unreachable("Not implemented");
}

LLVM::LLVMFuncOp newFuncOp =
*mlir::convertFuncOpToLLVMFuncOp(funcOp, rewriter, *getTypeConverter());
if (!newFuncOp) {
return failure();
}

auto ctx = funcOp->getContext();
if (LLVM::isKernel(funcOp)) {
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("cpu.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
} else {
llvm_unreachable("Not implemented");
}

rewriter.eraseOp(funcOp);
return success();
}
};

} // namespace

void mlir::triton::cpu::populateFuncOpConversionPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<FuncOpConversion>(typeConverter, benefit);
}
26 changes: 25 additions & 1 deletion lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonCPUToLLVM/Passes.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"
Expand Down Expand Up @@ -71,7 +72,30 @@ struct ConvertTritonCPUToLLVM
TritonCPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget convTarget(*context);

// TODO:
// Lower functions
{
mlir::LowerToLLVMOptions option(context);
TritonCPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
RewritePatternSet funcPatterns(context);
mlir::triton::cpu::populateFuncOpConversionPattern(
typeConverter, funcPatterns,
mlir::triton::cpu::patternBenefitDefault);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
}

RewritePatternSet patterns(context);
int benefit =
mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions;
mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter,
patterns, benefit);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};

Expand Down
15 changes: 15 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::triton;
Expand All @@ -9,8 +10,22 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> std::optional<Type> {
return convertTritonPointerType(type);
});

// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}

Type TritonCPUToLLVMTypeConverter::convertTritonPointerType(
triton::PointerType type) {
auto ctx = type.getContext();
auto pointeeType = type.getPointeeType();
if (pointeeType.isa<RankedTensorType>()) {
llvm_unreachable("Not implemented");
}
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
}
4 changes: 2 additions & 2 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def make_llir(src, metadata, options):

@staticmethod
def make_exe(src, metadata, options):
# Right now, src is just TTIR. Extract kernel name from tt.func.
names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src))
# Just a quick hack while developing the backend.
names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src))
assert len(names) == 1
metadata["name"] = names[0]

Expand Down

0 comments on commit 25e6cfc

Please sign in to comment.