Integrating OpenBLAS for gemm #163

Open · wants to merge 1 commit into base: main

12 changes: 11 additions & 1 deletion backend/compiler.py
@@ -12,13 +12,21 @@
 import functools
 from pathlib import Path


 def _get_triton_shared_opt_path() -> str:
     path = os.getenv("TRITON_SHARED_OPT_PATH", "")
     if path == "":
         raise Exception("TRITON_SHARED_OPT_PATH is not set.")
     return path


+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 def _get_llvm_bin_path(bin_name: str) -> str:
     path = os.getenv("LLVM_BINARY_DIR", "")
     if path == "":
@@ -42,7 +50,9 @@ def _ttir_to_ttsharedir(mod):
         dst_path = os.path.join(tmpdir, "ttshared.mlir")
         Path(src_path).write_text(ttir_code)
         triton_shared_opt_path = _get_triton_shared_opt_path()
-        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "--mlir-print-debuginfo", "-o", dst_path])
+        extra_pass = ["--linalg-to-linear-algebra-subprograms"] if _get_triton_shared_use_openblas() else []
+        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental"] +
+                              extra_pass + ["--mlir-print-debuginfo", "-o", dst_path])
         _dump_ir_if_needed([src_path])
         return Path(dst_path).read_text()

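Since the gating is purely environment-driven, the effective pipeline is easy to reproduce by hand. A minimal sketch of the invocation this change constructs (input and output file names here are placeholders, not from the PR):

import os
import subprocess

# Any non-empty value opts in; unset (the default) leaves the pipeline as before.
os.environ["TRITON_SHARED_USE_OPENBLAS"] = "1"

# Mirrors _ttir_to_ttsharedir: the new pass slots in between the linalg
# lowering and --mlir-print-debuginfo.
triton_shared_opt = os.environ["TRITON_SHARED_OPT_PATH"]
subprocess.check_call([
    triton_shared_opt, "kernel.ttir.mlir",
    "--triton-to-linalg-experimental",
    "--linalg-to-linear-algebra-subprograms",
    "--mlir-print-debuginfo",
    "-o", "ttshared.mlir",
])
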
11 changes: 10 additions & 1 deletion backend/driver.py
@@ -12,6 +12,14 @@
 from triton.backends.driver import DriverBase
 from triton.backends.compiler import GPUTarget


+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 # -------------------- Launcher ----------------------------
 def _ty_to_cpp(ty):
     if ty[0] == '*':
@@ -253,11 +261,12 @@ def launch(
         so_path = os.path.join(tmpdir, "kernel.so")
         Path(asm_src_path).write_bytes(asm_src)
         Path(launcher_src_path).write_text(src)
+        extra_lib = ["-lopenblas"] if _get_triton_shared_use_openblas() else []
         # Compile it together.
         subprocess.check_call([
             "g++", "-std=c++17", launcher_src_path, asm_src_path,
             f"-I{py_include_dir}", f"-I{include_dir}", f"-L{py_lib_dir}",
-            "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
+            "-shared", "-fPIC"] + extra_lib + ["-o", so_path
         ])

         with open(so_path, "rb") as f:
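Because the launcher now links -lopenblas whenever the variable is set, it may be worth confirming up front that the cblas entry point the pass emits actually resolves. A quick ctypes sketch (the library lookup and the test sizes are illustrative; the exact soname varies by distribution):

import ctypes
import ctypes.util

# Locate OpenBLAS and declare the cblas_sgemm prototype explicitly.
libname = ctypes.util.find_library("openblas") or "libopenblas.so"
blas = ctypes.CDLL(libname)

c_int, c_float = ctypes.c_int, ctypes.c_float
fptr = ctypes.POINTER(c_float)
blas.cblas_sgemm.restype = None
blas.cblas_sgemm.argtypes = [c_int, c_int, c_int,    # order, transA, transB
                             c_int, c_int, c_int,    # M, N, K
                             c_float, fptr, c_int,   # alpha, A, lda
                             fptr, c_int,            # B, ldb
                             c_float, fptr, c_int]   # beta, C, ldc

# Same call shape the new pass emits: CblasRowMajor (101), CblasNoTrans (111),
# computing C <- 1.0 * A @ B + 1.0 * C.
M = N = K = 2
A = (c_float * 4)(1, 2, 3, 4)   # [[1, 2], [3, 4]]
B = (c_float * 4)(5, 6, 7, 8)   # [[5, 6], [7, 8]]
C = (c_float * 4)(1, 1, 1, 1)
blas.cblas_sgemm(101, 111, 111, M, N, K, 1.0, A, K, B, N, 1.0, C, N)
print(list(C))  # [20.0, 23.0, 44.0, 51.0]
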
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)

3 changes: 3 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name LinalgToLinearAlgebraSubprograms)
add_public_tablegen_target(LinalgToLinearAlgebraSubprogramsConversionPassIncGen)

25 changes: 25 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h
@@ -0,0 +1,25 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

#define GEN_PASS_DECL
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"

void populateLinalgToLinearAlgebraSubprogramsConversionPatterns(
    bool pidsToFuncArgs, bool addptrToLinalg, bool assertToCf,
    RewritePatternSet &patterns);

std::unique_ptr<OperationPass<ModuleOp>> createLinalgToLinearAlgebraSubprogramsPass();

} // namespace triton
} // namespace mlir

#endif // LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H

15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h
@@ -0,0 +1,15 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H

#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif

10 changes: 10 additions & 0 deletions include/triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.td
@@ -0,0 +1,10 @@
#ifndef LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
#define LINALG_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def LinalgToLinearAlgebraSubprograms : Pass<"linalg-to-linear-algebra-subprograms", "mlir::ModuleOp"> {
  let summary = "Convert Linalg operations to library calls";
}

#endif

1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(LinalgToLinearAlgebraSubprograms)

21 changes: 21 additions & 0 deletions lib/Conversion/LinalgToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,21 @@
add_triton_library(LinalgToLinearAlgebraSubprograms
  LinalgToLinearAlgebraSubprogramsPass.cpp

  DEPENDS
  LinalgToLinearAlgebraSubprogramsConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRLinalgTransforms
  MLIRArithDialect
  MLIRDialectUtils
  MLIRIR
  MLIRMathDialect
  MLIRPass
  MLIRTensorDialect
  MLIRTransforms
  MLIRSupport
  TritonIR
  TritonTransforms
  TritonTilingExtIR
  TritonStructuredIR
)

239 changes: 239 additions & 0 deletions lib/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprogramsPass.cpp
@@ -0,0 +1,239 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/LinalgToLinearAlgebraSubprograms.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-to-las"

using namespace mlir;
using namespace triton;

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_LINALGTOLINEARALGEBRASUBPROGRAMS
#include "triton-shared/Conversion/LinalgToLinearAlgebraSubprograms/Passes.h.inc"
} // namespace triton
} // namespace mlir

namespace {

struct MatmulConverter : public OpConversionPattern<linalg::MatmulOp> {
  using OpConversionPattern<linalg::MatmulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    if (op.getInputs().size() != 2) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Cannot replace, must be exactly two input matrices\n");
      return failure();
    }

    Operation *resultOp;

    Value A = op.getInputs()[0];
    Value B = op.getInputs()[1];
    Value C;

    Value matmulResult = op.getResults()[0];
    bool otherUsers = false;
    bool found = false;

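    // Look for a single arith.addf that consumes the matmul result: that
    // pattern is "D = A*B + C", which a single gemm call can compute by
    // accumulating into C with beta = 1 instead of a separate addition.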
    for (Operation *user : matmulResult.getUsers()) {
      if (auto addFOp = dyn_cast<arith::AddFOp>(user)) {
        if (!found) {
          found = true;
          C = addFOp.getLhs() == matmulResult ? addFOp.getRhs()
                                              : addFOp.getLhs();
          resultOp = addFOp;
          continue;
        }
      }
      otherUsers = true;
    }

    bool replacingFOp = true;
    if (otherUsers || !found) {
      C = op.getOutputs()[0];
      resultOp = op;
      replacingFOp = false;
    }

    auto tensorA = cast<RankedTensorType>(A.getType());
    auto tensorB = cast<RankedTensorType>(B.getType());
    auto tensorC = cast<RankedTensorType>(C.getType());

    if (tensorA.getElementType() != tensorB.getElementType() ||
        tensorC.getElementType() != tensorB.getElementType()) {
      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n");
      return failure();
    }

    if (!tensorA.getElementType().isF32() &&
        !tensorA.getElementType().isF64()) {
      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n");
      return failure();
    }

    auto floatType = tensorA.getElementType();

    // since tensors are immutable, we need to allocate a buffer for the result
    Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(
        loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C);
    auto memrefType = MemRefType::get(tensorC.getShape(), floatType);
    Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType);
    auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst},
                                                  ValueRange{memrefC});

    ModuleOp module = op->getParentOfType<ModuleOp>();

    auto intType = rewriter.getI32Type();
    auto int64Type = rewriter.getI64Type();
    auto ptrType =
        LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space

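    // The callee type mirrors the CBLAS gemm C prototype:
    //   cblas_{s,d}gemm(Order, TransA, TransB, M, N, K,
    //                   alpha, A, lda, B, ldb, beta, C, ldc)
    // with the enums and dimensions as i32 and the arrays as !llvm.ptr.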
    auto funcType = FunctionType::get(
        op.getContext(),
        {intType, intType, intType, intType, intType, intType, floatType,
         ptrType, intType, ptrType, intType, floatType, ptrType, intType},
        {});

    bool usingF64 = floatType.isF64();
    const char *funcName = usingF64 ? "cblas_dgemm" : "cblas_sgemm";
    auto func = module.lookupSymbol<func::FuncOp>(funcName);
    if (!func) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      func = rewriter.create<func::FuncOp>(loc, funcName, funcType);
      func.setVisibility(SymbolTable::Visibility::Private);
    }

    auto memrefToPointer = [&rewriter, &loc, &int64Type,
                            &ptrType](Value &memref) {
      auto indexPtr =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, memref);
      auto castOp = rewriter.create<arith::IndexCastOp>(loc, int64Type, indexPtr);
      return rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, castOp);
    };

    auto tensorToPointer = [&rewriter, &loc,
                            &memrefToPointer](Value &V, RankedTensorType &T) {
      if (auto tensorOp = V.getDefiningOp<bufferization::ToTensorOp>()) {
        Value ref = tensorOp.getMemref();
        return memrefToPointer(ref);
      }

      Value memref = rewriter.create<bufferization::ToMemrefOp>(
          loc, MemRefType::get(T.getShape(), T.getElementType()), V);
      return memrefToPointer(memref);
    };

    Value ptrA = tensorToPointer(A, tensorA);
    Value ptrB = tensorToPointer(B, tensorB);
    Value ptrC = memrefToPointer(memrefC);

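    // GEMM dimensions come from the static tensor shapes: A is MxK, B is KxN.
    // With row-major layout the leading dimensions are the row strides, so
    // lda = K and ldb = ldc = N.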
    int32_t M = tensorA.getShape()[0];
    int32_t K = tensorA.getShape()[1];
    int32_t N = tensorB.getShape()[1];

    Value alpha = rewriter.create<arith::ConstantOp>(
        loc, floatType,
        usingF64 ? rewriter.getF64FloatAttr(1.0) : rewriter.getF32FloatAttr(1.0));
    Value beta = alpha;

    auto constOp = [&rewriter, &loc, &intType](int32_t V) {
      return rewriter.create<arith::ConstantOp>(loc, intType,
                                                rewriter.getI32IntegerAttr(V));
    };

    // the constants below come from the OpenBLAS library; the variable names
    // give their interpretation, for more information see:
    // https://github.com/OpenMathLib/OpenBLAS/blob/develop/cblas.h
    Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
    Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
    Value LDA = KVal, LDB = NVal, LDC = NVal;

    auto funcOp = rewriter.create<func::CallOp>(
        loc, func,
        ValueRange{CblasRowMajor, CblasNoTrans, CblasNoTrans, MVal, NVal, KVal,
                   alpha, ptrA, LDA, ptrB, LDB, beta, ptrC, LDC});

    auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
        loc, tensorC, memrefC, true /* restrict */, true /* writable */);

    if (!replacingFOp) {
      rewriter.replaceOp(op, toTensorOp);
    } else {
      rewriter.eraseOp(op);
      rewriter.replaceOp(resultOp, toTensorOp);
    }

    return success();
  }
};

class LinalgToLinearAlgebraSubprogramsPass
    : public triton::impl::LinalgToLinearAlgebraSubprogramsBase<
          LinalgToLinearAlgebraSubprogramsPass> {
  using LinalgToLinearAlgebraSubprogramsBase<
      LinalgToLinearAlgebraSubprogramsPass>::LinalgToLinearAlgebraSubprogramsBase;

public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, func::FuncDialect,
                    arith::ArithDialect, math::MathDialect,
                    bufferization::BufferizationDialect, affine::AffineDialect,
                    scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());

    patterns.add<MatmulConverter>(patterns.getContext());

    target.addLegalDialect<
        func::FuncDialect, arith::ArithDialect, math::MathDialect,
        affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
        cf::ControlFlowDialect, tensor::TensorDialect,
        bufferization::BufferizationDialect, memref::MemRefDialect,
        LLVM::LLVMDialect>();

    target.addDynamicallyLegalOp<linalg::MatmulOp>([](linalg::MatmulOp op) {
      Value A = op.getInputs()[0];
      Value B = op.getInputs()[1];

      auto tensorA = cast<RankedTensorType>(A.getType());
      auto tensorB = cast<RankedTensorType>(B.getType());

      if (tensorA.getElementType() != tensorB.getElementType()) {
        // no need to replace if types are different
        return true;
      }

      if (!tensorA.getElementType().isF32() &&
          !tensorA.getElementType().isF64()) {
        // unsupported types
        return true;
      }

      return false; // MatmulOp is illegal, and transformation is needed
    });

    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
triton::createLinalgToLinearAlgebraSubprogramsPass() {
  return std::make_unique<LinalgToLinearAlgebraSubprogramsPass>();
}
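
To make the intended semantics of the rewrite concrete: with alpha = beta = 1, one gemm call computes exactly the matmul-plus-add pattern the converter matches. A small NumPy sketch of that equivalence (illustrative only, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 3, 5
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)
C = rng.standard_normal((M, N)).astype(np.float32)

# What the matched IR computes: linalg.matmul followed by arith.addf.
ir_result = (A @ B) + C

# What the emitted call computes: cblas_sgemm with alpha = beta = 1,
# accumulating into the buffer that received a copy of C.
gemm_result = 1.0 * (A @ B) + 1.0 * C

assert np.allclose(ir_result, gemm_result)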