diff --git a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h new file mode 100644 index 00000000..d2212eb3 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h @@ -0,0 +1,36 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +// Some populate* functions have name collisions with the ones for GPUs. +namespace cpu { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h index 57f4ce78..8ed9e6d4 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h +++ b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h @@ -15,6 +15,8 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); }; #endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/Utility.h b/include/triton/Conversion/TritonCPUToLLVM/Utility.h new file mode 100644 index 00000000..08d3b5e0 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/Utility.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace LLVM { + +// TODO: Not sure we need this for CPU backends. +inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +} // namespace LLVM +} // namespace mlir + +#endif diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt index 0dfa7cb5..17511562 100644 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,4 +1,6 @@ add_triton_library(TritonCPUToLLVM + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp TypeConverter.cpp TritonCPUToLLVM.cpp diff --git a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 00000000..a270c0d6 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,37 @@ +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("cpu.kernel")) { + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + llvm_unreachable("Not implemented"); + } + return success(); + } +}; + +} // namespace + +void mlir::triton::cpu::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 00000000..9ecd4703 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "mlir/Support/LogicalResult.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!LLVM::isKernel(funcOp)) { + llvm_unreachable("Not implemented"); + } + + LLVM::LLVMFuncOp newFuncOp = + *mlir::convertFuncOpToLLVMFuncOp(funcOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("cpu.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + } else { + llvm_unreachable("Not implemented"); + } + + rewriter.eraseOp(funcOp); + return success(); + } +}; + +} // namespace + +void mlir::triton::cpu::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp index 3646c92d..28d320df 100644 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp @@ -14,6 +14,7 @@ #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" #include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -71,7 +72,30 @@ struct ConvertTritonCPUToLLVM TritonCPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMConversionTarget convTarget(*context); - // TODO: + // Lower functions + { + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::cpu::populateFuncOpConversionPattern( + typeConverter, funcPatterns, + mlir::triton::cpu::patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + RewritePatternSet patterns(context); + int benefit = + mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; + mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, + patterns, benefit); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); } }; diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp index 6a5ac668..e8ca0810 100644 --- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp @@ -1,6 +1,7 @@ #include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/MLIRTypes.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::triton; @@ -9,8 +10,22 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + // Internally store bfloat16 as int16 addConversion([&](BFloat16Type type) -> std::optional { return IntegerType::get(type.getContext(), 16); }); } + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (pointeeType.isa()) { + llvm_unreachable("Not implemented"); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 47b44ca8..84564cab 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -137,8 +137,8 @@ def make_llir(src, metadata, options): @staticmethod def make_exe(src, metadata, options): - # Right now, src is just TTIR. Extract kernel name from tt.func. - names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) + # Just a quick hack while developing the backend. + names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) assert len(names) == 1 metadata["name"] = names[0]