Adding benchmarks, integrating OpenBLAS
parsifal-47 committed Aug 15, 2024
1 parent 9a47497 commit bcfe4c5
Showing 15 changed files with 397 additions and 3 deletions.
13 changes: 12 additions & 1 deletion backend/compiler.py
@@ -10,13 +10,21 @@
import functools
from pathlib import Path


def _get_triton_shared_opt_path() -> str:
path = os.getenv("TRITON_SHARED_OPT_PATH", "")
if path == "":
raise Exception("TRITON_SHARED_OPT_PATH is not set.")
return path


# because of the way triton loads backends, this function is duplicated
# in compiler and driver
def _get_triton_shared_use_openblas() -> bool:
use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
return use_blas != ""


def _get_llvm_bin_path(bin_name: str) -> str:
path = os.getenv("LLVM_BINARY_DIR", "")
if path == "":
@@ -32,7 +40,10 @@ def _ttir_to_ttsharedir(mod):
dst_path = os.path.join(tmpdir, "ttshared.mlir")
Path(src_path).write_text(ttir_code)
triton_shared_opt_path = _get_triton_shared_opt_path()
subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "-o", dst_path])
extra_pass = ["--triton-to-linear-algebra-subprograms"] if _get_triton_shared_use_openblas() else []
subprocess.check_call([triton_shared_opt_path, src_path] + extra_pass + \
["--triton-to-linalg-experimental",
"-o", dst_path])
return Path(dst_path).read_text()


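A usage sketch (not part of the diff): the new code path is controlled entirely by environment variables, so enabling OpenBLAS only requires setting TRITON_SHARED_USE_OPENBLAS to any non-empty value. The snippet below mirrors what _ttir_to_ttsharedir now runs when the toggle is set; the tool path and file names are placeholders.

import os
import subprocess

os.environ["TRITON_SHARED_OPT_PATH"] = "/path/to/triton-shared-opt"  # placeholder
os.environ["TRITON_SHARED_USE_OPENBLAS"] = "1"  # any non-empty value enables the extra pass

# Equivalent of the pipeline assembled above: the new pass runs before
# --triton-to-linalg-experimental.
subprocess.check_call([
    os.environ["TRITON_SHARED_OPT_PATH"], "ttir.mlir",
    "--triton-to-linear-algebra-subprograms",
    "--triton-to-linalg-experimental",
    "-o", "ttshared.mlir",
])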
11 changes: 10 additions & 1 deletion backend/driver.py
@@ -12,6 +12,14 @@
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget


# because of the way triton loads backends, this function is duplicated
# in compiler and driver
def _get_triton_shared_use_openblas() -> bool:
use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
return use_blas != ""


# -------------------- Launcher ----------------------------
def _ty_to_cpp(ty):
if ty[0] == '*':
@@ -250,11 +258,12 @@ def launch(
so_path = os.path.join(tmpdir, "kernel.so")
Path(asm_src_path).write_bytes(asm_src)
Path(launcher_src_path).write_text(src)
extra_lib = ["-lopenblas"] if _get_triton_shared_use_openblas() else []
# Compile it together.
subprocess.check_call([
"g++", launcher_src_path, asm_src_path,
f"-I{py_include_dir}", f"-I{include_dir}",
"-shared", "-fPIC", "-o", so_path
"-shared", "-fPIC"] + extra_lib + ["-o", so_path
])

with open(so_path, "rb") as f:
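For reference, when TRITON_SHARED_USE_OPENBLAS is set the launcher link line above effectively gains -lopenblas. A minimal sketch of the resulting compile command, with placeholder paths and assuming a system-installed OpenBLAS that g++ can resolve:

import subprocess

subprocess.check_call([
    "g++", "launcher.cxx", "kernel.s",                                   # placeholder source names
    "-I/path/to/python/include", "-I/path/to/triton/backends/include",   # placeholders
    "-shared", "-fPIC",
    "-lopenblas",                                                        # extra_lib when the env var is set
    "-o", "kernel.so",
])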
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
add_subdirectory(TritonToLinearAlgebraSubprograms)
3 changes: 3 additions & 0 deletions include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinearAlgebraSubprograms)
add_public_tablegen_target(TritonToLinearAlgebraSubprogramsConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h
@@ -0,0 +1,15 @@
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
10 changes: 10 additions & 0 deletions include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.td
@@ -0,0 +1,10 @@
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonToLinearAlgebraSubprograms : Pass<"triton-to-linear-algebra-subprograms", "mlir::ModuleOp"> {
let summary = "Convert Linalg operations to library calls";
}

#endif
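The Pass definition above is what exposes the --triton-to-linear-algebra-subprograms flag that the backend now passes to triton-shared-opt. A quick, hedged way to check that a local build registered the pass (assuming TRITON_SHARED_OPT_PATH points at a build containing this change, and that the tool lists registered passes in its --help output, as mlir-opt-style tools normally do):

import os
import subprocess

opt = os.environ["TRITON_SHARED_OPT_PATH"]
help_text = subprocess.run([opt, "--help"], capture_output=True, text=True).stdout
assert "triton-to-linear-algebra-subprograms" in help_text, "pass not registered"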
25 changes: 25 additions & 0 deletions include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h
@@ -0,0 +1,25 @@
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

#define GEN_PASS_DECL
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"

void populateTritonToLinearAlgebraSubprogramsConversionPatterns(bool pidsToFuncArgs,
bool addptrToLinalg,
bool assertToCf,
RewritePatternSet &patterns);

std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinearAlgebraSubprogramsPass();

} // namespace triton
} // namespace mlir

#endif // TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
add_subdirectory(TritonToLinearAlgebraSubprograms)
21 changes: 21 additions & 0 deletions lib/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,21 @@
add_triton_library(TritonToLinearAlgebraSubprograms
TritonToLinearAlgebraSubprogramsPass.cpp

DEPENDS
TritonToLinearAlgebraSubprogramsConversionPassIncGen

LINK_LIBS PUBLIC
MLIRLinalgTransforms
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
MLIRSupport
TritonIR
TritonTransforms
TritonTilingExtIR
TritonStructuredIR
)
173 changes: 173 additions & 0 deletions lib/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprogramsPass.cpp
@@ -0,0 +1,173 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-to-las"

using namespace mlir;
using namespace triton;

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_TRITONTOLINEARALGEBRASUBPROGRAMS
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"
} // namespace triton
} // namespace mlir

namespace {

struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Value A = op.getA();
Value B = op.getB();
Value C = op.getC();

auto tensorA = cast<RankedTensorType>(A.getType());
auto tensorB = cast<RankedTensorType>(B.getType());
auto tensorC = cast<RankedTensorType>(C.getType());

if (tensorA.getElementType() != tensorB.getElementType() ||
tensorC.getElementType() != tensorB.getElementType()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n");
return failure();
}

if (!tensorA.getElementType().isF32() && !tensorA.getElementType().isF64()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n");
return failure();
}

auto floatType = tensorA.getElementType();

// since tensors are immutable, we need to allocate a buffer for the result
Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C);
auto memrefType = MemRefType::get(tensorC.getShape(), floatType);
Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType);
auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst}, ValueRange{memrefC});

ModuleOp module = op->getParentOfType<ModuleOp>();

auto intType = rewriter.getI32Type();
auto int64Type = rewriter.getI64Type();
auto ptrType = LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space

auto funcType = FunctionType::get(op.getContext(),
{intType, intType, intType, intType, intType, intType, floatType,
ptrType, intType, ptrType, intType, floatType,
ptrType, intType}, {});

bool usingF64 = floatType.isF64();
const char *funcName = usingF64 ? "cblas_dgemm" : "cblas_sgemm";
auto func = module.lookupSymbol<func::FuncOp>(funcName);
if (!func) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
func = rewriter.create<func::FuncOp>(loc, funcName, funcType);
func.setVisibility(SymbolTable::Visibility::Private);
}

auto memrefToPointer = [&rewriter, &loc, &int64Type, &ptrType](Value &memref) {
auto indexPtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, memref);
auto castOp = rewriter.create<arith::IndexCastOp>(loc, int64Type, indexPtr);
return rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, castOp);
};

auto tensorToPointer = [&rewriter, &loc, &memrefToPointer](Value &V, RankedTensorType &T) {
Value memref = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(T.getShape(), T.getElementType()), V);
return memrefToPointer(memref);
};

Value ptrA = tensorToPointer(A, tensorA);
Value ptrB = tensorToPointer(B, tensorB);
Value ptrC = memrefToPointer(memrefC);

int32_t M = tensorA.getShape()[0];
int32_t K = tensorA.getShape()[1];
int32_t N = tensorB.getShape()[1];

Value alpha = rewriter.create<arith::ConstantOp>(loc, floatType, usingF64 ? rewriter.getF64FloatAttr(1.0) : rewriter.getF32FloatAttr(1.0));
Value beta = alpha;

auto constOp = [&rewriter, &loc, &intType](int32_t V) {
return rewriter.create<arith::ConstantOp>(loc, intType, rewriter.getI32IntegerAttr(V));
};
Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
Value LDA = KVal, LDB = NVal, LDC = NVal;

auto funcOp = rewriter.create<func::CallOp>(loc, func, ValueRange{
CblasRowMajor, CblasNoTrans, CblasNoTrans,
MVal, NVal, KVal,
alpha, ptrA, LDA,
ptrB, LDB, beta,
ptrC, LDC
});

auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc,
tensorC, memrefC, true /* restrict */, true /* writable */);
rewriter.replaceOp(op, toTensorOp);
return success();
}
};

class TritonToLinearAlgebraSubprogramsPass
: public triton::impl::TritonToLinearAlgebraSubprogramsBase<TritonToLinearAlgebraSubprogramsPass> {
using TritonToLinearAlgebraSubprogramsBase<
TritonToLinearAlgebraSubprogramsPass>::TritonToLinearAlgebraSubprogramsBase;

public:
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect, math::MathDialect,
affine::AffineDialect, scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect, triton::TritonDialect>();
}

void runOnOperation() override {
auto moduleOp = getOperation();
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());

patterns.add<MatmulConverter>(patterns.getContext());

target.addLegalDialect<
func::FuncDialect, arith::ArithDialect, math::MathDialect,
affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
cf::ControlFlowDialect, tensor::TensorDialect,
bufferization::BufferizationDialect, memref::MemRefDialect, LLVM::LLVMDialect>();

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
triton::createTritonToLinearAlgebraSubprogramsPass() {
return std::make_unique<TritonToLinearAlgebraSubprogramsPass>();
}
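The func.call constructed by MatmulConverter targets the standard CBLAS GEMM entry points, with the integer constants 101 and 111 being CblasRowMajor and CblasNoTrans. As a rough illustration of the argument order being marshalled (layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc), here is a hedged ctypes sketch of the same cblas_sgemm call; the shared-library name is an assumption, and the snippet only demonstrates the ABI the pass emits calls against.

import ctypes
import numpy as np

blas = ctypes.CDLL("libopenblas.so")  # assumed library name; adjust per system

M, K, N = 4, 3, 2
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)  # plays the role of the tt.dot accumulator

CblasRowMajor, CblasNoTrans = 101, 111
blas.cblas_sgemm(
    CblasRowMajor, CblasNoTrans, CblasNoTrans,
    M, N, K,
    ctypes.c_float(1.0),                    # alpha
    A.ctypes.data_as(ctypes.c_void_p), K,   # A, lda
    B.ctypes.data_as(ctypes.c_void_p), N,   # B, ldb
    ctypes.c_float(1.0),                    # beta (the pass sets beta = alpha)
    C.ctypes.data_as(ctypes.c_void_p), N,   # C, ldc
)
# C now holds A @ B plus its previous contents, matching tt.dot semantics.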
39 changes: 39 additions & 0 deletions python/examples/bare_matmul.py
@@ -0,0 +1,39 @@
import torch

import triton
import triton.language as tl
import benchmark


@triton.jit
def bare_matmul(X, Y, Z, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)

offs_x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_y = tl.arange(0, BLOCK_SIZE)

x = tl.load(X + offs_x[:, None])
y = tl.load(Y + offs_y[None, :])

z = tl.dot(x, y)
tl.store(Z + offs_x[:, None] + offs_y[None, :], z)


@benchmark.measure()
def bench_matmul(M, N, K, provider):
device = 'cpu'
dtype = torch.float32
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((K, N), device=device, dtype=dtype)
c = torch.empty((M, N), device=device, dtype=dtype)
if provider == 'torch':
torch.matmul(a, b)
if provider == 'triton':
bare_matmul[(1,)](a, b, c, N)


if __name__ == "__main__":
benchmark.select_cpu_backend()
for X in [2**i for i in range(7, 11, 1)]:
for provider in ['torch', 'triton']:
bench_matmul(X, X, X, provider)
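A possible way to drive this benchmark against the new lowering (a sketch, not part of the commit): the backend reads TRITON_SHARED_OPT_PATH and LLVM_BINARY_DIR as before, and TRITON_SHARED_USE_OPENBLAS switches the Triton provider onto the cblas path. Paths below are placeholders; the benchmark helper module imported above is not shown in this view.

import os
import subprocess
import sys

env = dict(os.environ)
env["TRITON_SHARED_OPT_PATH"] = "/path/to/triton-shared-opt"  # placeholder
env["LLVM_BINARY_DIR"] = "/path/to/llvm/bin"                  # placeholder
env["TRITON_SHARED_USE_OPENBLAS"] = "1"

subprocess.check_call([sys.executable, "python/examples/bare_matmul.py"], env=env)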
(Remaining changed files were not loaded in this view.)
