diff --git a/backend/compiler.py b/backend/compiler.py
index 3c51068b..ea5f2b73 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -12,6 +12,7 @@
 import functools
 from pathlib import Path
 
+
 def _get_triton_shared_opt_path() -> str:
     path = os.getenv("TRITON_SHARED_OPT_PATH", "")
     if path == "":
@@ -19,6 +20,13 @@ def _get_triton_shared_opt_path() -> str:
     return path
 
 
+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 def _get_llvm_bin_path(bin_name: str) -> str:
     path = os.getenv("LLVM_BINARY_DIR", "")
     if path == "":
@@ -42,7 +50,9 @@ def _ttir_to_ttsharedir(mod):
         dst_path = os.path.join(tmpdir, "ttshared.mlir")
         Path(src_path).write_text(ttir_code)
         triton_shared_opt_path = _get_triton_shared_opt_path()
-        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "--mlir-print-debuginfo", "-o", dst_path])
+        extra_pass = ["--linalg-to-linear-algebra-subprograms"] if _get_triton_shared_use_openblas() else []
+        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental"] + \
+            extra_pass + ["--mlir-print-debuginfo", "-o", dst_path])
         _dump_ir_if_needed([src_path])
         return Path(dst_path).read_text()
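The environment toggle is read twice on purpose: `backend/compiler.py` decides the pass pipeline while `backend/driver.py` decides the link line, and Triton loads the two modules independently. A minimal usage sketch (the variable name is from this diff; the surrounding workflow is assumed):

```python
import os

# Any non-empty value enables the OpenBLAS path; set it before the first
# kernel is compiled, since compiler.py and driver.py each probe it on
# their own (hence the duplicated _get_triton_shared_use_openblas helper).
os.environ["TRITON_SHARED_USE_OPENBLAS"] = "1"

# From here on, _ttir_to_ttsharedir() appends
# --linalg-to-linear-algebra-subprograms to the triton-shared-opt pipeline,
# and the launcher build below adds -lopenblas when linking kernel.so.
```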
diff --git a/backend/driver.py b/backend/driver.py
index 41b0d26c..32862b37 100644
--- a/backend/driver.py
+++ b/backend/driver.py
@@ -12,6 +12,14 @@
 from triton.backends.driver import DriverBase
 from triton.backends.compiler import GPUTarget
 
+
+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 # -------------------- Launcher ----------------------------
 def _ty_to_cpp(ty):
     if ty[0] == '*':
@@ -253,11 +261,12 @@ def launch(
         so_path = os.path.join(tmpdir, "kernel.so")
         Path(asm_src_path).write_bytes(asm_src)
         Path(launcher_src_path).write_text(src)
+        extra_lib = ["-lopenblas"] if _get_triton_shared_use_openblas() else []
         # Compile it together.
         subprocess.check_call([
             "g++", "-std=c++17", launcher_src_path, asm_src_path,
             f"-I{py_include_dir}", f"-I{include_dir}", f"-L{py_lib_dir}",
-            "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
+            "-shared", "-fPIC"] + extra_lib + ["-o", so_path
         ])
 
         with open(so_path, "rb") as f:
diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt
index 45da8aca..ed596c7a 100644
--- a/include/triton-shared/Conversion/CMakeLists.txt
+++ b/include/triton-shared/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)
diff --git a/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
new file mode 100644
index 00000000..280821e7
--- /dev/null
+++ b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name LinalgToLinearAlgebraSubprograms)
+add_public_tablegen_target(LinalgToLinearAlgebraSubprogramsConversionPassIncGen)
diff --git a/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h
new file mode 100644
index 00000000..8329fa16
--- /dev/null
+++ b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h
@@ -0,0 +1,25 @@
+#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
+#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"
+
+void populateLinalgToLinearAlgebraSubprogramsConversionPatterns(bool pidsToFuncArgs,
+                                                                bool addptrToLinalg,
+                                                                bool assertToCf,
+                                                                RewritePatternSet &patterns);
+
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgToLinearAlgebraSubprogramsPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
diff --git a/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h
new file mode 100644
index 00000000..4e0b0cf1
--- /dev/null
+++ b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h
@@ -0,0 +1,15 @@
+#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
+#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.td b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.td
new file mode 100644
index 00000000..56291706
--- /dev/null
+++ b/include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.td
@@ -0,0 +1,10 @@
+#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
+#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LinalgToLinearAlgebraSubprograms : Pass<"linalg-to-linear-algebra-subprograms", "mlir::ModuleOp"> {
+  let summary = "Convert Linalg operations to library calls";
+}
+
+#endif
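The flag declared in Passes.td can also be exercised directly with triton-shared-opt, outside the Python driver. A sketch mirroring the updated `_ttir_to_ttsharedir()` (the file names are hypothetical; `TRITON_SHARED_OPT_PATH` is the same variable compiler.py consults):

```python
import os
import subprocess

# Run the same pipeline compiler.py builds when the OpenBLAS toggle is set.
opt = os.environ["TRITON_SHARED_OPT_PATH"]     # path to triton-shared-opt
subprocess.check_call([
    opt, "kernel.ttir.mlir",                   # hypothetical input file
    "--triton-to-linalg-experimental",
    "--linalg-to-linear-algebra-subprograms",  # the pass added in this diff
    "--mlir-print-debuginfo",
    "-o", "kernel.ttshared.mlir",
])
```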
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 45da8aca..ed596c7a 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)
diff --git a/lib/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt b/lib/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
new file mode 100644
index 00000000..effdb90e
--- /dev/null
+++ b/lib/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_triton_library(LinalgToLinearAlgebraSubprograms
+  LinalgToLinearAlgebraSubprogramsPass.cpp
+
+  DEPENDS
+  LinalgToLinearAlgebraSubprogramsConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRLinalgTransforms
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMathDialect
+  MLIRPass
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRSupport
+  TritonIR
+  TritonTransforms
+  TritonTilingExtIR
+  TritonStructuredIR
+)
diff --git a/lib/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprogramsPass.cpp b/lib/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprogramsPass.cpp
new file mode 100644
index 00000000..aa22cc97
--- /dev/null
+++ b/lib/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprogramsPass.cpp
@@ -0,0 +1,239 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-to-las"
+
+using namespace mlir;
+using namespace triton;
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_LINALGTOLINEARALGEBRASUBPROGRAMS
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+namespace {
+
+struct MatmulConverter : public OpConversionPattern<linalg::MatmulOp> {
+  using OpConversionPattern<linalg::MatmulOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (op.getInputs().size() != 2) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Cannot replace, must be exactly two input matrices\n");
+      return failure();
+    }
+
+    Operation *resultOp;
+
+    Value A = op.getInputs()[0];
+    Value B = op.getInputs()[1];
+    Value C;
+
+    Value matmulResult = op.getResults()[0];
+    bool otherUsers = false;
+    bool found = false;
+
+    for (Operation *user : matmulResult.getUsers()) {
+      if (auto addFOp = dyn_cast<arith::AddFOp>(user)) {
+        if (!found) {
+          found = true;
+          C = addFOp.getLhs() == matmulResult ? addFOp.getRhs()
+                                              : addFOp.getLhs();
+          resultOp = addFOp;
+          continue;
+        }
+      }
+      otherUsers = true;
+    }
+
+    bool replacingFOp = true;
+    if (otherUsers || !found) {
+      C = op.getOutputs()[0];
+      resultOp = op;
+      replacingFOp = false;
+    }
+
+    auto tensorA = cast<RankedTensorType>(A.getType());
+    auto tensorB = cast<RankedTensorType>(B.getType());
+    auto tensorC = cast<RankedTensorType>(C.getType());
+
+    if (tensorA.getElementType() != tensorB.getElementType() ||
+        tensorC.getElementType() != tensorB.getElementType()) {
+      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n");
+      return failure();
+    }
+
+    if (!tensorA.getElementType().isF32() &&
+        !tensorA.getElementType().isF64()) {
+      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n");
+      return failure();
+    }
+
+    auto floatType = tensorA.getElementType();
+
+    // since tensors are immutable, we need to allocate a buffer for the result
+    Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(
+        loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C);
+    auto memrefType = MemRefType::get(tensorC.getShape(), floatType);
+    Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType);
+    auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst},
+                                                  ValueRange{memrefC});
+
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    auto intType = rewriter.getI32Type();
+    auto int64Type = rewriter.getI64Type();
+    auto ptrType = LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space
+
+    auto funcType = FunctionType::get(
+        op.getContext(),
+        {intType, intType, intType, intType, intType, intType, floatType,
+         ptrType, intType, ptrType, intType, floatType, ptrType, intType},
+        {});
+
+    bool usingF64 = floatType.isF64();
"cblas_dgemm" : "cblas_sgemm"; + auto func = module.lookupSymbol(funcName); + if (!func) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + func = rewriter.create(loc, funcName, funcType); + func.setVisibility(SymbolTable::Visibility::Private); + } + + auto memrefToPointer = [&rewriter, &loc, &int64Type, &ptrType](Value &memref) { + auto indexPtr = rewriter.create(loc, memref); + auto castOp = rewriter.create(loc, int64Type, indexPtr); + return rewriter.create(loc, ptrType, castOp); + }; + + auto tensorToPointer = [&rewriter, &loc, &memrefToPointer](Value &V, RankedTensorType &T) { + if (auto tensorOp = V.getDefiningOp()) { + Value ref = tensorOp.getMemref(); + return memrefToPointer(ref); + } + + Value memref = rewriter.create(loc, MemRefType::get(T.getShape(), T.getElementType()), V); + return memrefToPointer(memref); + }; + + Value ptrA = tensorToPointer(A, tensorA); + Value ptrB = tensorToPointer(B, tensorB); + Value ptrC = memrefToPointer(memrefC); + + int32_t M = tensorA.getShape()[0]; + int32_t K = tensorA.getShape()[1]; + int32_t N = tensorB.getShape()[1]; + + Value alpha = rewriter.create(loc, floatType, usingF64 ? rewriter.getF64FloatAttr(1.0) : rewriter.getF32FloatAttr(1.0)); + Value beta = alpha; + + auto constOp = [&rewriter, &loc, &intType](int32_t V) { + return rewriter.create(loc, intType, rewriter.getI32IntegerAttr(V)); + }; + + // constants below are from OpenBLAS library, check variable names for interpretation + // for more information check: https://github.com/OpenMathLib/OpenBLAS/blob/develop/cblas.h + Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111); + Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K); + Value LDA = KVal, LDB = NVal, LDC = NVal; + + auto funcOp = rewriter.create(loc, func, ValueRange{ + CblasRowMajor, CblasNoTrans, CblasNoTrans, + MVal, NVal, KVal, + alpha, ptrA, LDA, + ptrB, LDB, beta, + ptrC, LDC + }); + + auto toTensorOp = rewriter.create(loc, + tensorC, memrefC, true /* restrict */, true /* writable */); + + if (!replacingFOp) { + rewriter.replaceOp(op, toTensorOp); + } else { + rewriter.eraseOp(op); + rewriter.replaceOp(resultOp, toTensorOp); + } + + return success(); + } +}; + +class LinalgToLinearAlgebraSubprogramsPass + : public triton::impl::LinalgToLinearAlgebraSubprogramsBase { + using LinalgToLinearAlgebraSubprogramsBase< + LinalgToLinearAlgebraSubprogramsPass>::LinalgToLinearAlgebraSubprogramsBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + patterns.add(patterns.getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, memref::MemRefDialect, LLVM::LLVMDialect>(); + + + target.addDynamicallyLegalOp([](linalg::MatmulOp op) { + Value A = op.getInputs()[0]; + Value B = op.getInputs()[1]; + + auto tensorA = cast(A.getType()); + auto tensorB = cast(B.getType()); + + if (tensorA.getElementType() != tensorB.getElementType()) { + // no need to replace if types are different + return true; + } + + if (!tensorA.getElementType().isF32() && !tensorA.getElementType().isF64()) { + // unsupported types + return true; + } + + return false; // MatmulOp is 
+      return false; // MatmulOp is illegal, and transformation is needed
+    });
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+triton::createLinalgToLinearAlgebraSubprogramsPass() {
+  return std::make_unique<LinalgToLinearAlgebraSubprogramsPass>();
+}
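For reference, the emitted call implements the standard GEMM update with the converter's fixed configuration: row-major layout (101), no transposition (111), alpha = beta = 1.0, and leading dimensions lda = K, ldb = N, ldc = N. A NumPy sketch of those semantics (shapes are hypothetical):

```python
import numpy as np

# cblas_sgemm(CblasRowMajor=101, CblasNoTrans=111, CblasNoTrans=111,
#             M, N, K, alpha, A, lda=K, B, ldb=N, beta, C, ldc=N)
# computes C <- alpha * (A @ B) + beta * C in place.
M, K, N = 128, 128, 128
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.random.rand(M, N).astype(np.float32)
alpha = beta = 1.0  # the converter hard-codes both to 1.0
expected = alpha * (A @ B) + beta * C
```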
diff --git a/test/Conversion/LinalgToLinearAlgebraSubprograms/matmul.mlir b/test/Conversion/LinalgToLinearAlgebraSubprograms/matmul.mlir
new file mode 100644
index 00000000..5709ce6d
--- /dev/null
+++ b/test/Conversion/LinalgToLinearAlgebraSubprograms/matmul.mlir
@@ -0,0 +1,66 @@
+// RUN: triton-shared-opt --linalg-to-linear-algebra-subprograms %s | FileCheck %s
+
+module {
+  func.func @bare_matmul(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
+    %c128_i32 = arith.constant 128 : i32
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = arith.muli %arg9, %c128_i32 : i32
+    %1 = arith.index_cast %0 : i32 to index
+    %2 = arith.muli %arg10, %c128_i32 : i32
+    %3 = arith.index_cast %2 : i32 to index
+    %4 = arith.index_cast %arg5 : i32 to index
+    %5 = arith.muli %1, %4 : index
+    %6 = arith.addi %5, %3 : index
+    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%6], sizes: [128, 128], strides: [%4, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+    %alloc = memref.alloc() : memref<128x128xf32>
+    memref.copy %reinterpret_cast, %alloc : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
+    %7 = bufferization.to_tensor %alloc restrict writable : memref<128x128xf32>
+    %8 = arith.index_cast %arg4 : i32 to index
+    %9 = arith.muli %1, %8 : index
+    %10 = arith.addi %9, %3 : index
+    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%10], sizes: [128, 128], strides: [%8, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+    %alloc_1 = memref.alloc() : memref<128x128xf32>
+    memref.copy %reinterpret_cast_0, %alloc_1 : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
+    %11 = bufferization.to_tensor %alloc_1 restrict writable : memref<128x128xf32>
+    %12 = tensor.empty() : tensor<128x128xf32>
+    %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %14 = linalg.matmul ins(%7, %11 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%13 : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %reinterpret_cast_2 = memref.reinterpret_cast %arg2 to offset: [%10], sizes: [128, 128], strides: [%8, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+    bufferization.materialize_in_destination %14 in writable %reinterpret_cast_2 : (tensor<128x128xf32>, memref<128x128xf32, strided<[?, 1], offset: ?>>) -> ()
+    return
+  }
+}
+
+// CHECK: module {
+// CHECK: func.func private @cblas_sgemm(i32, i32, i32, i32, i32, i32, f32, !llvm.ptr, i32, !llvm.ptr, i32, f32, !llvm.ptr, i32)
+// CHECK: func.func @bare_matmul(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
+// CHECK: [[C128_I32:%.+]] = arith.constant 128 : i32
+// CHECK: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: [[VAR_0:%.+]] = arith.muli %arg9, [[C128_I32]] : i32
+// CHECK: [[VAR_1:%.+]] = arith.index_cast [[VAR_0]] : i32 to index
+// CHECK: [[VAR_2:%.+]] = arith.muli %arg10, [[C128_I32]] : i32
+// CHECK: [[VAR_3:%.+]] = arith.index_cast [[VAR_2]] : i32 to index
+// CHECK: [[VAR_4:%.+]] = arith.index_cast %arg5 : i32 to index
+// CHECK: [[VAR_5:%.+]] = arith.muli [[VAR_1]], [[VAR_4]] : index
+// CHECK: [[VAR_6:%.+]] = arith.addi [[VAR_5]], [[VAR_3]] : index
+// CHECK: [[REINTERPRET_CAST:%.+]] = memref.reinterpret_cast %arg0 to offset: [[VAR_6]]{{.*}} : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+// CHECK: [[ALLOC:%.+]] = memref.alloc() : memref<128x128xf32>
+// CHECK: memref.copy [[REINTERPRET_CAST]], [[ALLOC]]{{.*}} : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
+// CHECK: [[VAR_7:%.+]] = bufferization.to_tensor [[ALLOC]] restrict writable{{.*}} : memref<128x128xf32>
+// CHECK: [[VAR_8:%.+]] = arith.index_cast %arg4 : i32 to index
+// CHECK: [[VAR_9:%.+]] = arith.muli [[VAR_1]], [[VAR_8]] : index
+// CHECK: [[VAR_10:%.+]] = arith.addi [[VAR_9]], [[VAR_3]] : index
+// CHECK: [[REINTERPRET_CAST_0:%.+]] = memref.reinterpret_cast %arg1 to offset: [[VAR_10]]{{.*}} : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+// CHECK: [[ALLOC_1:%.+]] = memref.alloc() : memref<128x128xf32>
+// CHECK: memref.copy [[REINTERPRET_CAST_0]], [[ALLOC_1]]{{.*}} : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
+// CHECK: [[VAR_11:%.+]] = bufferization.to_tensor [[ALLOC_1]] restrict writable{{.*}} : memref<128x128xf32>
+// CHECK: [[VAR_12:%.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: [[VAR_13:%.+]] = linalg.fill ins([[CST]] : f32) outs([[VAR_12]] : tensor<128x128xf32>) -> tensor<128x128xf32>
+// CHECK: [[VAR_14:%.+]] = linalg.matmul ins([[VAR_7]], [[VAR_11]] : tensor<128x128xf32>, tensor<128x128xf32>) outs([[VAR_13]] : tensor<128x128xf32>) -> tensor<128x128xf32>
+// CHECK: [[REINTERPRET_CAST_2:%.+]] = memref.reinterpret_cast %arg2 to offset: [[VAR_10]]{{.*}} : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
+// CHECK: bufferization.materialize_in_destination [[VAR_14]] in writable [[REINTERPRET_CAST_2]] : (tensor<128x128xf32>, memref<128x128xf32, strided<[?, 1], offset: ?>>) -> ()
+// CHECK: return
+// CHECK: }
+// CHECK: %cst_5 = arith.constant 1.000000e+00 : f32
+// CHECK: call @cblas_sgemm{{.*}} : (i32, i32, i32, i32, i32, i32, f32, !llvm.ptr, i32, !llvm.ptr, i32, f32, !llvm.ptr, i32) -> ()
+// CHECK: }
\ No newline at end of file
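The private `@cblas_sgemm` declaration checked above follows OpenBLAS's C ABI. A ctypes sketch of the same call, useful for sanity-checking the constants and leading dimensions (assumes libopenblas.so is installed and loadable):

```python
import ctypes
import numpy as np

blas = ctypes.CDLL("libopenblas.so")  # assumption: OpenBLAS on the loader path

M = N = K = 128
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

CBLAS_ROW_MAJOR, CBLAS_NO_TRANS = 101, 111  # the constants the pass materializes
f32p = ctypes.POINTER(ctypes.c_float)
blas.cblas_sgemm(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
                 M, N, K, ctypes.c_float(1.0),
                 A.ctypes.data_as(f32p), K,   # lda = K
                 B.ctypes.data_as(f32p), N,   # ldb = N
                 ctypes.c_float(1.0),
                 C.ctypes.data_as(f32p), N)   # ldc = N
assert np.allclose(C, A @ B, rtol=1e-4)
```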
diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h
index 82ba4f39..dc9333e6 100644
--- a/tools/RegisterTritonSharedDialects.h
+++ b/tools/RegisterTritonSharedDialects.h
@@ -18,6 +18,7 @@
 #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h"
 #include "triton-shared/Conversion/TritonToLinalg/Passes.h"
 #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h"
+#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h"
 #include "triton-shared/Conversion/TritonToStructured/Passes.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
@@ -44,6 +45,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestMembarPass();
   mlir::triton::registerTritonToLinalgPass();
   mlir::triton::registerTritonToLinalgExperimentalPass();
+  mlir::triton::registerLinalgToLinearAlgebraSubprogramsPass();
   mlir::triton::registerTritonToStructuredPass();
   mlir::triton::registerTritonArithToLinalgPasses();
   mlir::triton::registerConvertTritonToTritonGPUPass();