-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding benchmarks, integrating OpenBLAS
- Loading branch information
1 parent
9a47497
commit bcfe4c5
Showing
15 changed files
with
397 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinearAlgebraSubprograms) | ||
add_public_tablegen_target(TritonToLinearAlgebraSubprogramsConversionPassIncGen) |
15 changes: 15 additions & 0 deletions
15
include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H | ||
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H | ||
|
||
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h" | ||
|
||
namespace mlir { | ||
namespace triton { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc" | ||
|
||
} // namespace triton | ||
} // namespace mlir | ||
|
||
#endif |
10 changes: 10 additions & 0 deletions
10
include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES | ||
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def TritonToLinearAlgebraSubprograms : Pass<"triton-to-linear-algebra-subprograms", "mlir::ModuleOp"> { | ||
let summary = "Convert Linalg operations to library calls"; | ||
} | ||
|
||
#endif |
25 changes: 25 additions & 0 deletions
25
...ton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H | ||
#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
namespace mlir { | ||
namespace triton { | ||
|
||
#define GEN_PASS_DECL | ||
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc" | ||
|
||
void populateTritonToLinearAlgebraSubprogramsConversionPatterns(bool pidsToFuncArgs, | ||
bool addptrToLinalg, | ||
bool assertToCf, | ||
RewritePatternSet &patterns); | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinearAlgebraSubprogramsPass(); | ||
|
||
} // namespace triton | ||
} // namespace mlir | ||
|
||
#endif // TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
21 changes: 21 additions & 0 deletions
21
lib/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
add_triton_library(TritonToLinearAlgebraSubprograms | ||
TritonToLinearAlgebraSubprogramsPass.cpp | ||
|
||
DEPENDS | ||
TritonToLinearAlgebraSubprogramsConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRLinalgTransforms | ||
MLIRArithDialect | ||
MLIRDialectUtils | ||
MLIRIR | ||
MLIRMathDialect | ||
MLIRPass | ||
MLIRTensorDialect | ||
MLIRTransforms | ||
MLIRSupport | ||
TritonIR | ||
TritonTransforms | ||
TritonTilingExtIR | ||
TritonStructuredIR | ||
) |
173 changes: 173 additions & 0 deletions
173
lib/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprogramsPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h" | ||
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" | ||
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#include "llvm/Support/Debug.h" | ||
|
||
#define DEBUG_TYPE "linalg-to-las" | ||
|
||
using namespace mlir; | ||
using namespace triton; | ||
|
||
namespace mlir { | ||
namespace triton { | ||
#define GEN_PASS_DEF_TRITONTOLINEARALGEBRASUBPROGRAMS | ||
#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc" | ||
} // namespace triton | ||
} // namespace mlir | ||
|
||
namespace { | ||
|
||
struct MatmulConverter : public OpConversionPattern<triton::DotOp> { | ||
using OpConversionPattern<triton::DotOp>::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
|
||
Value A = op.getA(); | ||
Value B = op.getB(); | ||
Value C = op.getC(); | ||
|
||
auto tensorA = cast<RankedTensorType>(A.getType()); | ||
auto tensorB = cast<RankedTensorType>(B.getType()); | ||
auto tensorC = cast<RankedTensorType>(C.getType()); | ||
|
||
if (tensorA.getElementType() != tensorB.getElementType() || | ||
tensorC.getElementType() != tensorB.getElementType()) { | ||
LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n"); | ||
return failure(); | ||
} | ||
|
||
if (!tensorA.getElementType().isF32() && !tensorA.getElementType().isF64()) { | ||
LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n"); | ||
return failure(); | ||
} | ||
|
||
auto floatType = tensorA.getElementType(); | ||
|
||
// since tensors are immutable, we need to allocate a buffer for the result | ||
Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C); | ||
auto memrefType = MemRefType::get(tensorC.getShape(), floatType); | ||
Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType); | ||
auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst}, ValueRange{memrefC}); | ||
|
||
ModuleOp module = op->getParentOfType<ModuleOp>(); | ||
|
||
auto intType = rewriter.getI32Type(); | ||
auto int64Type = rewriter.getI64Type(); | ||
auto ptrType = LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space | ||
|
||
auto funcType = FunctionType::get(op.getContext(), | ||
{intType, intType, intType, intType, intType, intType, floatType, | ||
ptrType, intType, ptrType, intType, floatType, | ||
ptrType, intType}, {}); | ||
|
||
bool usingF64 = floatType.isF64(); | ||
const char *funcName = usingF64 ? "cblas_dgemm" : "cblas_sgemm"; | ||
auto func = module.lookupSymbol<func::FuncOp>(funcName); | ||
if (!func) { | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(module.getBody()); | ||
func = rewriter.create<func::FuncOp>(loc, funcName, funcType); | ||
func.setVisibility(SymbolTable::Visibility::Private); | ||
} | ||
|
||
auto memrefToPointer = [&rewriter, &loc, &int64Type, &ptrType](Value &memref) { | ||
auto indexPtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, memref); | ||
auto castOp = rewriter.create<arith::IndexCastOp>(loc, int64Type, indexPtr); | ||
return rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, castOp); | ||
}; | ||
|
||
auto tensorToPointer = [&rewriter, &loc, &memrefToPointer](Value &V, RankedTensorType &T) { | ||
Value memref = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(T.getShape(), T.getElementType()), V); | ||
return memrefToPointer(memref); | ||
}; | ||
|
||
Value ptrA = tensorToPointer(A, tensorA); | ||
Value ptrB = tensorToPointer(B, tensorB); | ||
Value ptrC = memrefToPointer(memrefC); | ||
|
||
int32_t M = tensorA.getShape()[0]; | ||
int32_t K = tensorA.getShape()[1]; | ||
int32_t N = tensorB.getShape()[1]; | ||
|
||
Value alpha = rewriter.create<arith::ConstantOp>(loc, floatType, usingF64 ? rewriter.getF64FloatAttr(1.0) : rewriter.getF32FloatAttr(1.0)); | ||
Value beta = alpha; | ||
|
||
auto constOp = [&rewriter, &loc, &intType](int32_t V) { | ||
return rewriter.create<arith::ConstantOp>(loc, intType, rewriter.getI32IntegerAttr(V)); | ||
}; | ||
Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111); | ||
Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K); | ||
Value LDA = KVal, LDB = NVal, LDC = NVal; | ||
|
||
auto funcOp = rewriter.create<func::CallOp>(loc, func, ValueRange{ | ||
CblasRowMajor, CblasNoTrans, CblasNoTrans, | ||
MVal, NVal, KVal, | ||
alpha, ptrA, LDA, | ||
ptrB, LDB, beta, | ||
ptrC, LDC | ||
}); | ||
|
||
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, | ||
tensorC, memrefC, true /* restrict */, true /* writable */); | ||
rewriter.replaceOp(op, toTensorOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
class TritonToLinearAlgebraSubprogramsPass | ||
: public triton::impl::TritonToLinearAlgebraSubprogramsBase<TritonToLinearAlgebraSubprogramsPass> { | ||
using TritonToLinearAlgebraSubprogramsBase< | ||
TritonToLinearAlgebraSubprogramsPass>::TritonToLinearAlgebraSubprogramsBase; | ||
|
||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry | ||
.insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect, math::MathDialect, | ||
affine::AffineDialect, scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect, triton::TritonDialect>(); | ||
} | ||
|
||
void runOnOperation() override { | ||
auto moduleOp = getOperation(); | ||
RewritePatternSet patterns(&getContext()); | ||
ConversionTarget target(getContext()); | ||
|
||
patterns.add<MatmulConverter>(patterns.getContext()); | ||
|
||
target.addLegalDialect< | ||
func::FuncDialect, arith::ArithDialect, math::MathDialect, | ||
affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect, | ||
cf::ControlFlowDialect, tensor::TensorDialect, | ||
bufferization::BufferizationDialect, memref::MemRefDialect, LLVM::LLVMDialect>(); | ||
|
||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> | ||
triton::createTritonToLinearAlgebraSubprogramsPass() { | ||
return std::make_unique<TritonToLinearAlgebraSubprogramsPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
import triton | ||
import triton.language as tl | ||
import benchmark | ||
|
||
|
||
@triton.jit | ||
def bare_matmul(X, Y, Z, BLOCK_SIZE: tl.constexpr): | ||
pid = tl.program_id(0) | ||
|
||
offs_x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
offs_y = tl.arange(0, BLOCK_SIZE) | ||
|
||
x = tl.load(X + offs_x[:, None]) | ||
y = tl.load(Y + offs_y[None, :]) | ||
|
||
z = tl.dot(x, y) | ||
tl.store(Z + offs_x[:, None] + offs_y[None, :], z) | ||
|
||
|
||
@benchmark.measure() | ||
def bench_matmul(M, N, K, provider): | ||
device = 'cpu' | ||
dtype = torch.float32 | ||
a = torch.randn((M, K), device=device, dtype=dtype) | ||
b = torch.randn((K, N), device=device, dtype=dtype) | ||
c = torch.empty((K, N), device=device, dtype=dtype) | ||
if provider == 'torch': | ||
torch.matmul(a, b) | ||
if provider == 'triton': | ||
bare_matmul[(1,)](a, b, c, N) | ||
|
||
|
||
if __name__ == "__main__": | ||
benchmark.select_cpu_backend() | ||
for X in [2**i for i in range(7, 11, 1)]: | ||
for provider in ['torch', 'triton']: | ||
bench_matmul(X, X, X, provider) |
Oops, something went wrong.