Integrating OpenBLAS for gemm #163
base: main
Conversation
(force-pushed from b2418d2 to bcfe4c5)
python/examples/bare_matmul.py (outdated)
offs_x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_y = tl.arange(0, BLOCK_SIZE)

x = tl.load(X + offs_x[:, None])
y = tl.load(Y + offs_y[None, :])

z = tl.dot(x, y)
tl.store(Z + offs_x[:, None] + offs_y[None, :], z)
I believe there's a mistake in this Triton program. We're always doing a matmul of size [BLOCK_SIZE, 1] @ [1, BLOCK_SIZE], which is an outer product rather than a full [BLOCK_SIZE, BLOCK_SIZE] @ [BLOCK_SIZE, BLOCK_SIZE] matmul.
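Something along these lines would exercise full tiles (a rough sketch only, assuming X, Y and Z are contiguous row-major BLOCK_SIZE x BLOCK_SIZE buffers; the actual fix may differ):

```python
import triton
import triton.language as tl


@triton.jit
def bare_matmul(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    # 2D offsets of a row-major BLOCK_SIZE x BLOCK_SIZE tile: row * stride + col
    offs = tl.arange(0, BLOCK_SIZE)
    offs_2d = offs[:, None] * BLOCK_SIZE + offs[None, :]

    # Load full tiles instead of a single column / single row
    x = tl.load(X + offs_2d)
    y = tl.load(Y + offs_2d)

    # [BLOCK_SIZE, BLOCK_SIZE] @ [BLOCK_SIZE, BLOCK_SIZE]
    # (BLOCK_SIZE must be a power of two and >= 16 for tl.dot)
    z = tl.dot(x, y)
    tl.store(Z + offs_2d, z)
```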
oh, thanks a lot for noticing, let me retry and get back to you!
Thanks again for noticing! I have fixed the kernel, and the difference is very notable:
> python bare_matmul.py
bench_matmul(128, 'test') {}, 20 times, all results in seconds
Avg=0.035736, min=0.003327, 20pp=0.004988, 50pp=0.005001, 90pp=0.005019, max=0.621418
bench_matmul(128, 'torch') {}, 20 times, all results in seconds
Avg=0.000406, min=0.000218, 20pp=0.000219, 50pp=0.000219, 90pp=0.000403, max=0.002592
bench_matmul(128, 'triton') {}, 20 times, all results in seconds
Avg=0.001747, min=0.001088, 20pp=0.001148, 50pp=0.001160, 90pp=0.003361, max=0.003988
bench_matmul(256, 'test') {}, 20 times, all results in seconds
Avg=0.049110, min=0.005234, 20pp=0.005498, 50pp=0.006050, 90pp=0.019442, max=0.808195
bench_matmul(256, 'torch') {}, 20 times, all results in seconds
Avg=0.002304, min=0.000897, 20pp=0.000945, 50pp=0.002550, 90pp=0.003497, max=0.005217
bench_matmul(256, 'triton') {}, 20 times, all results in seconds
Avg=0.003742, min=0.003167, 20pp=0.003176, 50pp=0.003181, 90pp=0.005196, max=0.006016
bench_matmul(512, 'test') {}, 20 times, all results in seconds
Avg=0.076103, min=0.023856, 20pp=0.029637, 50pp=0.030622, 90pp=0.062977, max=0.842385
bench_matmul(512, 'torch') {}, 20 times, all results in seconds
Avg=0.005920, min=0.004630, 20pp=0.004641, 50pp=0.005528, 90pp=0.007751, max=0.011366
bench_matmul(512, 'triton') {}, 20 times, all results in seconds
Avg=0.011808, min=0.011642, 20pp=0.011654, 50pp=0.011680, 90pp=0.011719, max=0.014201
> export TRITON_SHARED_USE_OPENBLAS=
> python bare_matmul.py
bench_matmul(128, 'test') {}, 20 times, all results in seconds
Avg=0.075897, min=0.006302, 20pp=0.006325, 50pp=0.006347, 90pp=0.006507, max=1.396113
bench_matmul(128, 'torch') {}, 20 times, all results in seconds
Avg=0.000225, min=0.000218, 20pp=0.000218, 50pp=0.000219, 90pp=0.000230, max=0.000269
bench_matmul(128, 'triton') {}, 20 times, all results in seconds
Avg=0.005761, min=0.005674, 20pp=0.005698, 50pp=0.005723, 90pp=0.005954, max=0.005984
bench_matmul(256, 'test') {}, 20 times, all results in seconds
Avg=0.083720, min=0.045208, 20pp=0.045586, 50pp=0.045785, 90pp=0.046000, max=0.805519
bench_matmul(256, 'torch') {}, 20 times, all results in seconds
Avg=0.000898, min=0.000891, 20pp=0.000896, 50pp=0.000896, 90pp=0.000901, max=0.000925
bench_matmul(256, 'triton') {}, 20 times, all results in seconds
Avg=0.044448, min=0.043970, 20pp=0.044233, 50pp=0.044460, 90pp=0.044648, max=0.045219
bench_matmul(512, 'test') {}, 20 times, all results in seconds
Avg=0.395702, min=0.351327, 20pp=0.356543, 50pp=0.360086, 90pp=0.362602, max=1.096930
bench_matmul(512, 'torch') {}, 20 times, all results in seconds
Avg=0.004640, min=0.004615, 20pp=0.004621, 50pp=0.004629, 90pp=0.004654, max=0.004780
bench_matmul(512, 'triton') {}, 20 times, all results in seconds
Avg=0.348616, min=0.346798, 20pp=0.347046, 50pp=0.347344, 90pp=0.353371, max=0.354604
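(For context: the reported numbers are per-call wall-clock times in seconds, and "20pp"/"50pp"/"90pp" are percentiles. A timing harness along the following lines would produce this kind of output; this is only an illustrative sketch, not the repository's actual benchmark helper.)

```python
import time
import statistics


def bench(fn, n_iters=20):
    # Run fn() n_iters times and report avg / min / percentiles / max in seconds.
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    times.sort()
    pp = lambda p: times[min(len(times) - 1, int(p / 100.0 * len(times)))]
    print(f"Avg={statistics.mean(times):.6f}, min={times[0]:.6f}, "
          f"20pp={pp(20):.6f}, 50pp={pp(50):.6f}, 90pp={pp(90):.6f}, max={times[-1]:.6f}")
```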
// Calls cblas_sgemm, which computes C := alpha * op(A) * op(B) + beta * C.
// Signature (see OpenBLAS cblas.h): cblas_sgemm(Order, TransA, TransB, M, N, K,
//   alpha, A, lda, B, ldb, beta, C, ldc).
// CblasRowMajor = 101 and CblasNoTrans = 111 are the CBLAS_ORDER / CBLAS_TRANSPOSE
// enum values from cblas.h.
Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
Value LDA = KVal, LDB = NVal, LDC = NVal;

auto funcOp = rewriter.create<func::CallOp>(loc, func, ValueRange{
    CblasRowMajor, CblasNoTrans, CblasNoTrans,
    MVal, NVal, KVal,
    alpha, ptrA, LDA,
    ptrB, LDB, beta,
    ptrC, LDC
});
I think it would be useful to document the function signature here for future reference, e.g. a link to https://github.com/OpenMathLib/OpenBLAS/blob/3ee9e9d8d050dfe6bb2733a8a87540eb7f70ff56/cblas.h#L62 would help explain where the 101 and 111 come from.
done
registry
    .insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect, math::MathDialect,
            affine::AffineDialect, scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect, triton::TritonDialect>();
}
We're missing the BufferizationDialect here. Without it we cannot run this pass individually.
thank you, fixed!
python/examples/bare_matmul.py (outdated)
if provider == 'torch':
    torch.matmul(a, b)
if provider == 'triton':
    bare_matmul[(1,)](a, b, c, N)
I think it would be useful to compare the results between torch and triton.
Great idea. I added a "test" provider which compares the results.
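For reference, the comparison could be as simple as the following sketch (hypothetical; the actual "test" provider in the PR may differ, and it assumes the bare_matmul kernel and tensors from this example):

```python
import torch


def check_matmul(a, b, c, N):
    # Run the Triton kernel and compare against the torch reference result.
    expected = torch.matmul(a, b)
    bare_matmul[(1,)](a, b, c, N)  # assumes the kernel defined in this example
    torch.testing.assert_close(c, expected, rtol=1e-3, atol=1e-3)
```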
struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
  using OpConversionPattern<triton::DotOp>::OpConversionPattern;
I discussed this internally, and we thought that mapping from linalg.matmul is a more suitable approach. The reason is that we can optionally choose to tile the matmul before deciding to delegate to the library call. Mapping straight from tt.dot means we lose all the benefits of linalg.
Actually, I started by making linalg.matmul -> gemm, but the issue is: tt.dot is A*B + C, and gemm is the same, but tt.dot gets converted into D = linalg.matmul(A, B); fadd(D, C). At that point two operations need to be folded back into one, which is doable but more complicated.
Could you please elaborate on tiling? I was thinking this is a parameter we set externally, and regardless of its value there is no difference in result between tt.dot -> gemm vs tt.dot -> linalg.matmul + linalg.fadd -> gemm.
The linalg dialect comes with a set of powerful transformations that can be used to "tile" linalg operations. Tiling at a high level means that, given a linalg operation that operates on a tensor, we can transform it to instead operate on subsets of the initial tensor and generate the necessary code for that.
For instance, assume we have a highly optimized library implementation of a matmul for one particular shape; tiling lets us break a larger matmul into blocks of that shape and delegate each block to the library (a rough sketch of the idea follows below).
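Conceptually (outside of MLIR), tiling a matmul looks like this NumPy sketch, where each inner block product is exactly the kind of fixed-shape operation that could be handed off to an optimized library call (illustrative only, not triton-shared code):

```python
import numpy as np


def tiled_matmul(A, B, tile=128):
    # Compute C = A @ B by iterating over tile x tile sub-blocks; each
    # block-level product could be delegated to an optimized fixed-shape kernel.
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                C[i:i+tile, j:j+tile] += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
    return C
```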
So converting to the library call from linalg.matmul instead of tt.dot allows us to:
- perform any transformation necessary before deciding to delegate to the library
- adhere to the principle of gradual lowering in MLIR: we gradually lower ops into lower levels of representation and avoid raising abstraction.
- I think lowering to the library call first and then having to map it back to linalg.matmul in case we want to do the transformation doesn't match this philosophy. We probably don't want to make another "BLASToLinalg" pass either.
In summary, between having to fold the matmul + add into a gemm library call and having to map the library call back to matmul, I'd prefer the former approach. We will also be able to compose your new pass better with the rest of the project, as opposed to it being a one-off code path.
I mostly agree with you. I wasn't proposing to lower to gemm and then back to linalg.matmul, that would be odd; I am trying to understand what the best integration point is.
We have tiling at the level of the Triton kernel, and what is good about that tiling is that we can autotune it for the best performance. I am not sure how we decide at the linalg level which tiling is best, and I do not see where exactly that happens if not in --convert-linalg-to-loops; if that is the case, we lose the ability to recognize the optimization after that.
And I do not propose giving up on the tiling provided by linalg; I am just thinking that if we have an optimized library call for some of the cases, we can use it. The benefit of gradual lowering in MLIR is that we can recognize optimizations and macro-operations at a high level, unlike in pure LLVM. If a function call at the linalg level doesn't look right, we can introduce an additional "LAS" dialect with a set of optimized primitives, together with a "triton-to-las" pass that converts eligible operations to this dialect and a "las-to-llvm" pass that generates the library calls.
I think best performance and proper abstractions are not mutually exclusive in this case. I see two possible solutions:
- fold linalg.matmul + linalg.fadd -> gemm, which would be less code to change;
- introduce a new dialect, lower from Triton to it, and lower from it to LLVM as a separate pass. If we assume that our library outperforms the alternative, this should be faster, and we still get autotuned tiling plus tiling for the rest of the code.
Let me know what you think, and thank you for your time spent reviewing the PRs!
> I am not sure how we decide at the linalg level which tiling is best, and I do not see where exactly that happens if not in --convert-linalg-to-loops; if that is the case, we lose the ability to recognize the optimization after that.
This tiling pass does not exist at the moment, but leaving everything in linalg before converting to loops gives us the ability to do additional tiling if we ever decide to add such a pass.
> And I do not propose giving up on the tiling provided by linalg; I am just thinking that if we have an optimized library call for some of the cases, we can use it.
I agree with this. And so, the decision to use a highly optimized library call should not be made right at the beginning of the compilation pipeline. OpenBLAS is probably always faster for a matmul in isolation, but in other contexts the matmul may very well be fused with other ops and perform better that way. In such cases, it is beneficial to delay converting to the library call as late as possible.
> I think best performance and proper abstractions are not mutually exclusive in this case. I see two possible solutions
I'd prefer approach 1) (combining linalg.matmul and linalg.add back into a gemm), because 2) would mean we end up with two compilation pipelines. With approach 2) we also lose the ability to tile the matmul unless we map it back to linalg.matmul. I'm suggesting this in the spirit of making everything as flexible as possible.
An alternative is probably to introduce an aggregate linalg.gemm, similarly to linalg.softmax in the upstream MLIR project (whether it's worth it is another question).
sounds good!
(force-pushed from bcfe4c5 to 55b4b42)
@nhat-nguyen I think I resolved all your comments and added a lit test to showcase the conversion. Let me know what you think, thank you!
@nhat-nguyen could you please take a look when you have a chance
@nhat-nguyen please take a look when you have a chance
(force-pushed from c7948fc to 140acaf)
Dropping benchmarks because they are covered by:
(force-pushed from 140acaf to 260c07f)
Hi All,
I tried playing with different vectorization/parallelization passes in MLIR and it didn't help much. OpenBLAS helped to speed up matmul. I was initially looking at MLAS from onnxruntime, but it is not exposed externally and there are no plans to do so at the moment as of: microsoft/onnxruntime#21644 (comment).
In this request I integrated only one function, but we can bring in more intrinsics if we see it as beneficial.
OpenBLAS speeds up the regular Triton matrix multiplication by about 3x, while bare_matmul without splitting into blocks performs roughly the same with or without OpenBLAS.
I am not 100% sure this is a valuable addition; let me know what you think. Any thoughts and ideas are welcome.
Below are the benchmarks: