Integrating OpenBLAS for gemm #163

Open
wants to merge 1 commit into base: main
Conversation

parsifal-47 (Contributor)

Hi All,
I tried playing with different vectorization/parallelization passes in MLIR and it didn't help much. OpenBLAS helped speed up matmul. I was initially looking at MLAS from onnxruntime, but it is not exposed externally and there are no plans to do so at the moment, as per microsoft/onnxruntime#21644 (comment).
In this request I integrated only one function, but we can bring in more intrinsics if we see it as beneficial.

OpenBLAS speeds up regular Triton matrix multiplication by about 3x, while bare_matmul (without splitting into blocks) performs roughly the same with or without OpenBLAS.

I am not 100% sure this is a valuable addition; let me know what you think. Any thoughts and ideas are welcome.

Below are the benchmarks:

$ cat /proc/cpuinfo | grep "model name" | head -1
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz

# No modifications:
$ python test_vec_add.py
bench_vecadd(4194304, 'torch') {}, 20 times, all results in seconds
Avg=0.065914, min=0.065114, 20pp=0.065659, 50pp=0.065868, 90pp=0.066209, max=0.067179
bench_vecadd(4194304, 'triton') {}, 20 times, all results in seconds
Avg=0.205786, min=0.130382, 20pp=0.131764, 50pp=0.134355, 90pp=0.144148, max=1.554073
bench_vecadd(8388608, 'torch') {}, 20 times, all results in seconds
Avg=0.109349, min=0.107811, 20pp=0.108896, 50pp=0.109233, 90pp=0.109751, max=0.113477
bench_vecadd(8388608, 'triton') {}, 20 times, all results in seconds
Avg=0.299948, min=0.296363, 20pp=0.296488, 50pp=0.296826, 90pp=0.305465, max=0.309868
bench_vecadd(16777216, 'torch') {}, 20 times, all results in seconds
Avg=0.253312, min=0.247950, 20pp=0.252528, 50pp=0.253736, 90pp=0.254402, max=0.257595
bench_vecadd(16777216, 'triton') {}, 20 times, all results in seconds
Avg=0.638587, min=0.624750, 20pp=0.632055, 50pp=0.641087, 90pp=0.645670, max=0.646468


# --convert-linalg-to-affine-loops -> --convert-linalg-to-parallel-loops

$ python test_vec_add.py
bench_vecadd(4194304, 'torch') {}, 20 times, all results in seconds
Avg=0.065570, min=0.064641, 20pp=0.064925, 50pp=0.065599, 90pp=0.066048, max=0.068330
bench_vecadd(4194304, 'triton') {}, 20 times, all results in seconds
Avg=0.205205, min=0.130430, 20pp=0.130840, 50pp=0.134660, 90pp=0.146575, max=1.540994
bench_vecadd(8388608, 'torch') {}, 20 times, all results in seconds
Avg=0.109661, min=0.107363, 20pp=0.109184, 50pp=0.109529, 90pp=0.110048, max=0.114105
bench_vecadd(8388608, 'triton') {}, 20 times, all results in seconds
Avg=0.305995, min=0.298448, 20pp=0.305863, 50pp=0.306140, 90pp=0.306550, max=0.310609
bench_vecadd(16777216, 'torch') {}, 20 times, all results in seconds
Avg=0.254508, min=0.250666, 20pp=0.254253, 50pp=0.254437, 90pp=0.255282, max=0.259281
bench_vecadd(16777216, 'triton') {}, 20 times, all results in seconds
Avg=0.640633, min=0.628056, 20pp=0.630829, 50pp=0.643688, 90pp=0.647847, max=0.651734

# + "--affine-loop-tile", "--affine-loop-unroll", "--affine-super-vectorize", "--canonicalize"

$ python test_vec_add.py
bench_vecadd(4194304, 'torch') {}, 20 times, all results in seconds
Avg=0.065643, min=0.064924, 20pp=0.065124, 50pp=0.065437, 90pp=0.066262, max=0.067815
bench_vecadd(4194304, 'triton') {}, 20 times, all results in seconds
Avg=0.205465, min=0.130641, 20pp=0.130748, 50pp=0.134490, 90pp=0.146844, max=1.545270
bench_vecadd(8388608, 'torch') {}, 20 times, all results in seconds
Avg=0.109424, min=0.107926, 20pp=0.109237, 50pp=0.109433, 90pp=0.109909, max=0.110055
bench_vecadd(8388608, 'triton') {}, 20 times, all results in seconds
Avg=0.305716, min=0.297619, 20pp=0.305858, 50pp=0.306157, 90pp=0.307869, max=0.310986
bench_vecadd(16777216, 'torch') {}, 20 times, all results in seconds
Avg=0.254895, min=0.253817, 20pp=0.254431, 50pp=0.254680, 90pp=0.255207, max=0.259531
bench_vecadd(16777216, 'triton') {}, 20 times, all results in seconds
Avg=0.640136, min=0.628890, 20pp=0.629453, 50pp=0.643262, 90pp=0.647509, max=0.650128

+ "--linalg-generalize-named-ops", "--convert-elementwise-to-linalg",
                               "--linalg-fuse-elementwise-ops",

$ python test_vec_add.py
bench_vecadd(4194304, 'torch') {}, 20 times, all results in seconds
Avg=0.065553, min=0.064732, 20pp=0.065037, 50pp=0.065551, 90pp=0.066201, max=0.067161
bench_vecadd(4194304, 'triton') {}, 20 times, all results in seconds
Avg=0.213996, min=0.139990, 20pp=0.143892, 50pp=0.144243, 90pp=0.145541, max=1.541608
bench_vecadd(8388608, 'torch') {}, 20 times, all results in seconds
Avg=0.109342, min=0.108888, 20pp=0.109016, 50pp=0.109260, 90pp=0.109881, max=0.110115
bench_vecadd(8388608, 'triton') {}, 20 times, all results in seconds
Avg=0.321371, min=0.314948, 20pp=0.315855, 50pp=0.324178, 90pp=0.325304, max=0.328541
bench_vecadd(16777216, 'torch') {}, 20 times, all results in seconds
Avg=0.254597, min=0.253265, 20pp=0.253919, 50pp=0.254303, 90pp=0.255454, max=0.259235
bench_vecadd(16777216, 'triton') {}, 20 times, all results in seconds
Avg=0.677508, min=0.666039, 20pp=0.667093, 50pp=0.676524, 90pp=0.686813, max=0.689968

$ export TRITON_SHARED_USE_OPENBLAS=
$ python test_matmul.py
bench_matmul(256, 256, 256, 'torch') {}, 20 times, all results in seconds
Avg=0.000979, min=0.000889, 20pp=0.000893, 50pp=0.000895, 90pp=0.000936, max=0.002362
bench_matmul(256, 256, 256, 'triton') {}, 20 times, all results in seconds
Avg=0.103045, min=0.031649, 20pp=0.031727, 50pp=0.031764, 90pp=0.031896, max=1.457253
bench_matmul(384, 384, 384, 'torch') {}, 20 times, all results in seconds
Avg=0.002137, min=0.002078, 20pp=0.002078, 50pp=0.002082, 90pp=0.002093, max=0.003153
bench_matmul(384, 384, 384, 'triton') {}, 20 times, all results in seconds
Avg=0.106997, min=0.103862, 20pp=0.106970, 50pp=0.107079, 90pp=0.107910, max=0.110387
bench_matmul(512, 512, 512, 'torch') {}, 20 times, all results in seconds
Avg=0.004650, min=0.004612, 20pp=0.004620, 50pp=0.004622, 90pp=0.004644, max=0.005120
bench_matmul(512, 512, 512, 'triton') {}, 20 times, all results in seconds
Avg=0.250993, min=0.250219, 20pp=0.250416, 50pp=0.250718, 90pp=0.252334, max=0.253804
bench_matmul(640, 640, 640, 'torch') {}, 20 times, all results in seconds
Avg=0.007586, min=0.007550, 20pp=0.007560, 50pp=0.007565, 90pp=0.007603, max=0.007778
bench_matmul(640, 640, 640, 'triton') {}, 20 times, all results in seconds
Avg=0.482236, min=0.472270, 20pp=0.475301, 50pp=0.485550, 90pp=0.486853, max=0.489880
bench_matmul(768, 768, 768, 'torch') {}, 20 times, all results in seconds
Avg=0.011860, min=0.011636, 20pp=0.011804, 50pp=0.011818, 90pp=0.011949, max=0.012337
bench_matmul(768, 768, 768, 'triton') {}, 20 times, all results in seconds
Avg=0.827043, min=0.812784, 20pp=0.814049, 50pp=0.828914, 90pp=0.840402, max=0.841761

$ python bare_matmul.py
bench_matmul(128, 128, 128, 'torch') {}, 20 times, all results in seconds
Avg=0.000282, min=0.000225, 20pp=0.000226, 50pp=0.000230, 90pp=0.000252, max=0.001159
bench_matmul(128, 128, 128, 'triton') {}, 20 times, all results in seconds
Avg=0.070500, min=0.000823, 20pp=0.000849, 50pp=0.000910, 90pp=0.001053, max=1.392515
bench_matmul(256, 256, 256, 'torch') {}, 20 times, all results in seconds
Avg=0.000948, min=0.000893, 20pp=0.000896, 50pp=0.000921, 90pp=0.000956, max=0.001581
bench_matmul(256, 256, 256, 'triton') {}, 20 times, all results in seconds
Avg=0.039535, min=0.001713, 20pp=0.001747, 50pp=0.001880, 90pp=0.002178, max=0.754220
bench_matmul(512, 512, 512, 'torch') {}, 20 times, all results in seconds
Avg=0.005258, min=0.005229, 20pp=0.005233, 50pp=0.005242, 90pp=0.005281, max=0.005486
bench_matmul(512, 512, 512, 'triton') {}, 20 times, all results in seconds
Avg=0.044105, min=0.005717, 20pp=0.005738, 50pp=0.005789, 90pp=0.007190, max=0.765825
bench_matmul(1024, 1024, 1024, 'torch') {}, 20 times, all results in seconds
Avg=0.024517, min=0.024354, 20pp=0.024435, 50pp=0.024483, 90pp=0.024669, max=0.024701
bench_matmul(1024, 1024, 1024, 'triton') {}, 20 times, all results in seconds
Avg=0.060229, min=0.021727, 20pp=0.021837, 50pp=0.021867, 90pp=0.022531, max=0.783773

$ export TRITON_SHARED_USE_OPENBLAS=1
$ python test_matmul.py
bench_matmul(256, 256, 256, 'torch') {}, 20 times, all results in seconds
Avg=0.000976, min=0.000888, 20pp=0.000893, 50pp=0.000895, 90pp=0.000955, max=0.002353
bench_matmul(256, 256, 256, 'triton') {}, 20 times, all results in seconds
Avg=0.087064, min=0.011839, 20pp=0.011897, 50pp=0.011988, 90pp=0.013468, max=1.507936
bench_matmul(384, 384, 384, 'torch') {}, 20 times, all results in seconds
Avg=0.002139, min=0.002078, 20pp=0.002080, 50pp=0.002087, 90pp=0.002114, max=0.003100
bench_matmul(384, 384, 384, 'triton') {}, 20 times, all results in seconds
Avg=0.037183, min=0.036923, 20pp=0.037050, 50pp=0.037152, 90pp=0.037378, max=0.037735
bench_matmul(512, 512, 512, 'torch') {}, 20 times, all results in seconds
Avg=0.004644, min=0.004612, 20pp=0.004623, 50pp=0.004631, 90pp=0.004643, max=0.004906
bench_matmul(512, 512, 512, 'triton') {}, 20 times, all results in seconds
Avg=0.084316, min=0.082523, 20pp=0.082774, 50pp=0.083433, 90pp=0.085976, max=0.089577
bench_matmul(640, 640, 640, 'torch') {}, 20 times, all results in seconds
Avg=0.007621, min=0.007578, 20pp=0.007598, 50pp=0.007604, 90pp=0.007632, max=0.007830
bench_matmul(640, 640, 640, 'triton') {}, 20 times, all results in seconds
Avg=0.164944, min=0.162469, 20pp=0.162578, 50pp=0.163313, 90pp=0.167781, max=0.171945
bench_matmul(768, 768, 768, 'torch') {}, 20 times, all results in seconds
Avg=0.011879, min=0.011835, 20pp=0.011855, 50pp=0.011867, 90pp=0.011912, max=0.011990
bench_matmul(768, 768, 768, 'triton') {}, 20 times, all results in seconds
Avg=0.284718, min=0.278498, 20pp=0.283348, 50pp=0.284959, 90pp=0.287150, max=0.290185

$ python bare_matmul.py
bench_matmul(128, 128, 128, 'torch') {}, 20 times, all results in seconds
Avg=0.000281, min=0.000224, 20pp=0.000225, 50pp=0.000231, 90pp=0.000253, max=0.001170
bench_matmul(128, 128, 128, 'triton') {}, 20 times, all results in seconds
Avg=0.073045, min=0.000952, 20pp=0.000991, 50pp=0.001004, 90pp=0.001045, max=1.441713
bench_matmul(256, 256, 256, 'torch') {}, 20 times, all results in seconds
Avg=0.003961, min=0.000894, 20pp=0.000896, 50pp=0.000898, 90pp=0.013107, max=0.023987
bench_matmul(256, 256, 256, 'triton') {}, 20 times, all results in seconds
Avg=0.042639, min=0.001894, 20pp=0.001989, 50pp=0.002191, 90pp=0.002414, max=0.811260
bench_matmul(512, 512, 512, 'torch') {}, 20 times, all results in seconds
Avg=0.005257, min=0.005227, 20pp=0.005233, 50pp=0.005241, 90pp=0.005255, max=0.005526
bench_matmul(512, 512, 512, 'triton') {}, 20 times, all results in seconds
Avg=0.047079, min=0.006016, 20pp=0.006043, 50pp=0.006064, 90pp=0.007346, max=0.819966
bench_matmul(1024, 1024, 1024, 'torch') {}, 20 times, all results in seconds
Avg=0.024563, min=0.024479, 20pp=0.024518, 50pp=0.024554, 90pp=0.024602, max=0.024773
bench_matmul(1024, 1024, 1024, 'triton') {}, 20 times, all results in seconds
Avg=0.065703, min=0.022247, 20pp=0.024235, 50pp=0.024251, 90pp=0.025087, max=0.850108

Comment on lines 11 to 24

offs_x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_y = tl.arange(0, BLOCK_SIZE)

x = tl.load(X + offs_x[:, None])
y = tl.load(Y + offs_y[None, :])

z = tl.dot(x, y)
tl.store(Z + offs_x[:, None] + offs_y[None, :], z)

Collaborator @nhat-nguyen, Aug 15, 2024

I believe there's a mistake in this triton program. We're always doing matmul of size [BLOCK_SIZE, 1] @ [1, BLOCK_SIZE] which is $[N, 1] @ [1, N]$ in the triton version. But the torch version is doing $[N, N] @ [N, N]$. Perhaps this is why bare_matmul doesn't show any speed up at all?
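
For illustration only, a minimal sketch of what a full-tile [N, N] @ [N, N] version of this kernel could look like, assuming N == BLOCK_SIZE and a single program instance (hypothetical names; the actual fix is in the updated PR diff):

import triton
import triton.language as tl

@triton.jit
def bare_matmul_full(X, Y, Z, N, BLOCK_SIZE: tl.constexpr):
    offs = tl.arange(0, BLOCK_SIZE)
    # 2D row-major offsets so each load yields a [BLOCK_SIZE, BLOCK_SIZE] tile,
    # instead of the [BLOCK_SIZE, 1] and [1, BLOCK_SIZE] operands above.
    x = tl.load(X + offs[:, None] * N + offs[None, :])
    y = tl.load(Y + offs[:, None] * N + offs[None, :])
    z = tl.dot(x, y)
    tl.store(Z + offs[:, None] * N + offs[None, :], z)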

Contributor Author

oh, thanks a lot for noticing, let me retry and get back to you!

Contributor Author

thanks again for noticing, I have fixed the kernel and the difference is very notable:

> python bare_matmul.py
bench_matmul(128, 'test') {}, 20 times, all results in seconds
Avg=0.035736, min=0.003327, 20pp=0.004988, 50pp=0.005001, 90pp=0.005019, max=0.621418
bench_matmul(128, 'torch') {}, 20 times, all results in seconds
Avg=0.000406, min=0.000218, 20pp=0.000219, 50pp=0.000219, 90pp=0.000403, max=0.002592
bench_matmul(128, 'triton') {}, 20 times, all results in seconds
Avg=0.001747, min=0.001088, 20pp=0.001148, 50pp=0.001160, 90pp=0.003361, max=0.003988
bench_matmul(256, 'test') {}, 20 times, all results in seconds
Avg=0.049110, min=0.005234, 20pp=0.005498, 50pp=0.006050, 90pp=0.019442, max=0.808195
bench_matmul(256, 'torch') {}, 20 times, all results in seconds
Avg=0.002304, min=0.000897, 20pp=0.000945, 50pp=0.002550, 90pp=0.003497, max=0.005217
bench_matmul(256, 'triton') {}, 20 times, all results in seconds
Avg=0.003742, min=0.003167, 20pp=0.003176, 50pp=0.003181, 90pp=0.005196, max=0.006016
bench_matmul(512, 'test') {}, 20 times, all results in seconds
Avg=0.076103, min=0.023856, 20pp=0.029637, 50pp=0.030622, 90pp=0.062977, max=0.842385
bench_matmul(512, 'torch') {}, 20 times, all results in seconds
Avg=0.005920, min=0.004630, 20pp=0.004641, 50pp=0.005528, 90pp=0.007751, max=0.011366
bench_matmul(512, 'triton') {}, 20 times, all results in seconds
Avg=0.011808, min=0.011642, 20pp=0.011654, 50pp=0.011680, 90pp=0.011719, max=0.014201
> export TRITON_SHARED_USE_OPENBLAS=
> python bare_matmul.py
bench_matmul(128, 'test') {}, 20 times, all results in seconds
Avg=0.075897, min=0.006302, 20pp=0.006325, 50pp=0.006347, 90pp=0.006507, max=1.396113
bench_matmul(128, 'torch') {}, 20 times, all results in seconds
Avg=0.000225, min=0.000218, 20pp=0.000218, 50pp=0.000219, 90pp=0.000230, max=0.000269
bench_matmul(128, 'triton') {}, 20 times, all results in seconds
Avg=0.005761, min=0.005674, 20pp=0.005698, 50pp=0.005723, 90pp=0.005954, max=0.005984
bench_matmul(256, 'test') {}, 20 times, all results in seconds
Avg=0.083720, min=0.045208, 20pp=0.045586, 50pp=0.045785, 90pp=0.046000, max=0.805519
bench_matmul(256, 'torch') {}, 20 times, all results in seconds
Avg=0.000898, min=0.000891, 20pp=0.000896, 50pp=0.000896, 90pp=0.000901, max=0.000925
bench_matmul(256, 'triton') {}, 20 times, all results in seconds
Avg=0.044448, min=0.043970, 20pp=0.044233, 50pp=0.044460, 90pp=0.044648, max=0.045219
bench_matmul(512, 'test') {}, 20 times, all results in seconds
Avg=0.395702, min=0.351327, 20pp=0.356543, 50pp=0.360086, 90pp=0.362602, max=1.096930
bench_matmul(512, 'torch') {}, 20 times, all results in seconds
Avg=0.004640, min=0.004615, 20pp=0.004621, 50pp=0.004629, 90pp=0.004654, max=0.004780
bench_matmul(512, 'triton') {}, 20 times, all results in seconds
Avg=0.348616, min=0.346798, 20pp=0.347046, 50pp=0.347344, 90pp=0.353371, max=0.354604

Comment on lines 118 to 129
Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
Value LDA = KVal, LDB = NVal, LDC = NVal;

auto funcOp = rewriter.create<func::CallOp>(loc, func, ValueRange{
CblasRowMajor, CblasNoTrans, CblasNoTrans,
MVal, NVal, KVal,
alpha, ptrA, LDA,
ptrB, LDB, beta,
ptrC, LDC
});

Collaborator @nhat-nguyen, Aug 15, 2024

I think it would be useful to document the function signature here for future reference, e.g. a link to https://github.com/OpenMathLib/OpenBLAS/blob/3ee9e9d8d050dfe6bb2733a8a87540eb7f70ff56/cblas.h#L62 would help explain where the 101 and 111 come from.
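
For reference, the constants and argument order come from the CBLAS enums and the cblas_sgemm prototype. A quick ctypes sketch (illustrative only, assuming an OpenBLAS shared library is discoverable on the system):

import ctypes
import ctypes.util
import numpy as np

# Assumption: OpenBLAS is installed and findable; adjust the library lookup if needed.
blas = ctypes.CDLL(ctypes.util.find_library("openblas"))

CblasRowMajor = 101  # enum CBLAS_ORDER     { CblasRowMajor = 101, CblasColMajor = 102 }
CblasNoTrans = 111   # enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 }

# void cblas_sgemm(Order, TransA, TransB, M, N, K,
#                  alpha, A, lda, B, ldb, beta, C, ldc)
c_float_p = ctypes.POINTER(ctypes.c_float)
blas.cblas_sgemm.restype = None
blas.cblas_sgemm.argtypes = [
    ctypes.c_int, ctypes.c_int, ctypes.c_int,   # Order, TransA, TransB
    ctypes.c_int, ctypes.c_int, ctypes.c_int,   # M, N, K
    ctypes.c_float, c_float_p, ctypes.c_int,    # alpha, A, lda
    c_float_p, ctypes.c_int,                    # B, ldb
    ctypes.c_float, c_float_p, ctypes.c_int,    # beta, C, ldc
]

M = N = K = 8
a = np.random.rand(M, K).astype(np.float32)
b = np.random.rand(K, N).astype(np.float32)
c = np.zeros((M, N), dtype=np.float32)

# Row-major, no transpose: lda = K, ldb = N, ldc = N, matching the converter above.
blas.cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
                 1.0, a.ctypes.data_as(c_float_p), K,
                 b.ctypes.data_as(c_float_p), N,
                 0.0, c.ctypes.data_as(c_float_p), N)
assert np.allclose(c, a @ b, atol=1e-4)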

Contributor Author

done

registry
.insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect, math::MathDialect,
affine::AffineDialect, scf::SCFDialect, tensor::TensorDialect, LLVM::LLVMDialect, triton::TritonDialect>();
}
Collaborator

We're missing the BufferizationDialect here. Without it we cannot run this pass individually.

Contributor Author

thank you, fixed!

Comment on lines 29 to 39
if provider == 'torch':
    torch.matmul(a, b)
if provider == 'triton':
    bare_matmul[(1,)](a, b, c, N)

Collaborator

I think it would be useful to compare the results between torch and triton.

Contributor Author

great idea, I added a "test" provider which compares the results
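
For context, a sketch of what such a comparison provider could look like, as a fragment extending the benchmark snippet quoted above and reusing its names (hypothetical; the actual change is in the PR diff):

if provider == 'test':
    # Run both implementations on the same inputs and check they agree.
    expected = torch.matmul(a, b)
    bare_matmul[(1,)](a, b, c, N)
    torch.testing.assert_close(c, expected, rtol=1e-4, atol=1e-4)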

Comment on lines 37 to 38
struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;

Collaborator

I discussed this internally, and we thought that mapping from linalg.matmul is a more suitable approach. The reason is we can optionally choose to do tiling of the matmul before deciding to delegate to the library call. Mapping straight from tt.dot means we lose all the benefits of linalg.

Contributor Author

Actually, I started by mapping linalg.matmul -> gemm, but the issue is:
tt.dot computes A*B + C, and gemm does the same, yet tt.dot currently gets converted into D = linalg.matmul(A, B); fadd(D, C).
At that point the two operations need to be folded back into one, which is doable, but more complicated.
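
To make the folding concrete, here is a small NumPy/SciPy illustration (a stand-in, not the MLIR lowering) of how a single gemm call subsumes the matmul + add pair via its alpha/beta arguments:

import numpy as np
from scipy.linalg.blas import sgemm

M, N, K = 64, 48, 32
a = np.random.rand(M, K).astype(np.float32)
b = np.random.rand(K, N).astype(np.float32)
c = np.random.rand(M, N).astype(np.float32)

# What the current lowering emits, conceptually: a matmul followed by an add.
two_ops = a @ b + c

# What gemm computes in one call: C := alpha*A*B + beta*C.
one_op = sgemm(alpha=1.0, a=a, b=b, beta=1.0, c=c)

assert np.allclose(two_ops, one_op, atol=1e-3)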

Could you please elaborate on tiling?
I was thinking this is a parameter we set externally, and regardless of its value there is no difference in the result between tt.dot -> gemm and tt.dot -> linalg.matmul + linalg.fadd -> gemm.

Collaborator

The linalg dialect comes with a set of powerful transformations that can be used to "tile" linalg operations. Tiling, at a high level, means that given a linalg operation that operates on a tensor, we can transform it to instead operate on subsets of the initial tensor and generate the necessary code around it.

For instance, assume we have a highly optimized library implementation of a matmul with shape $[512, 512] @ [512, 512]$, and our input matmul is $[1024, 512] @ [512, 512]$. We can decide to split the left-hand side matrix into two halves by row and perform two separate library calls. The linalg transformation allows us to do this fairly easily.
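
As a plain-Python illustration of that idea (the gemm_512 helper below is hypothetical, standing in for a fixed-shape library kernel):

import numpy as np

def gemm_512(a, b):
    # Stand-in for a library matmul that only accepts [512, 512] @ [512, 512].
    assert a.shape == (512, 512) and b.shape == (512, 512)
    return a @ b

a = np.random.rand(1024, 512).astype(np.float32)  # [1024, 512]
b = np.random.rand(512, 512).astype(np.float32)   # [512, 512]

# Tile the left-hand side by rows into two [512, 512] blocks and issue two calls.
c = np.vstack([gemm_512(a[:512], b), gemm_512(a[512:], b)])

assert np.allclose(c, a @ b, rtol=1e-4, atol=1e-3)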

So converting to the library call from linalg.matmul instead of tt.dot allows us to:

  • perform any transformation necessary before deciding to delegate to the library
  • adhere to the principle of gradual lowering in MLIR: we gradually lower ops into lower levels of representation and avoid raising the abstraction level.
    • I think lowering to the library call first and having to map it back to linalg.matmul in case we want to do the transformation doesn't match this philosophy. We probably don't want to make another pass "BLASToLinalg" either.

In summary, between having to fold matmul + add to make a gemm library call and having to map the library call back to matmul, I'd prefer the former approach. We will also be able to compose your new pass better with the rest of the project, as opposed to it being a one-off code path.

Contributor Author

I mostly agree with you; I wasn't proposing to lower to gemm and then back to linalg.matmul, that would be odd. I am trying to understand what the best point for the integration is.

We have tiling at the level of the Triton kernel; what is good about this tiling is that we can autotune it for the best performance. I am not sure how we decide at the linalg level which tiling is best, and I do not see where exactly it happens if not in --convert-linalg-to-loops; if that is the case, we lose the ability to recognize the optimization after that.

And I do not propose to give up on the tiling provided by linalg; I am just thinking that if we have an optimized library call for some of the cases, we can use it. The benefit of gradual lowering in MLIR is that we can recognize optimizations and macro-operations at a high level, unlike pure LLVM. If it doesn't look good to have a function call at the linalg level, we could introduce an additional "LAS" dialect with a set of optimized primitives, a "triton-to-las" pass which converts eligible operations to this dialect, and a "las-to-llvm" pass which generates the library calls.

I think best performance and proper abstractions are not mutually exclusive in this case. I see two possible solutions:

  1. fold linalg.matmul + linalg.fadd -> gemm; that would be less code to change.
  2. introduce a new dialect, lower from Triton to it, and lower from that dialect to LLVM as a separate pass; if we assume our library outperforms the alternative, it should be faster, and we still keep autotuned tiling plus tiling for the rest of the code.

Let me know what you think and thank you for your time spent reviewing the PRs!

Collaborator

> I am not sure how we decide at the linalg level which tiling is best, and I do not see where exactly it happens if not in --convert-linalg-to-loops; if that is the case, we lose the ability to recognize the optimization after that.

This tiling pass does not exist at the moment, but leaving everything in linalg before converting to loops gives us the ability to do additional tiling if we ever decide to add such a pass.

> And I do not propose to give up on the tiling provided by linalg; I am just thinking that if we have an optimized library call for some of the cases, we can use it.

I agree with this. And so, the decision to use a highly optimized library call should not be made right at the beginning of the compilation pipeline. OpenBLAS is probably always faster for matmul in isolation, but in some other contexts the matmul may very well be fused with other ops and perform better that way. In such cases, it is beneficial to delay converting to the library call until as late as possible.

> I think best performance and proper abstractions are not mutually exclusive in this case. I see two possible solutions.

I'd prefer approach 1) (combining linalg.matmul and linalg.add back into a gemm) because 2) would mean we end up with two compilation pipelines. With approach 2) we also lose the ability to tile the matmul unless we map it back to linalg.matmul. I'm suggesting this in the spirit of making everything as flexible as possible.

An alternative is probably to introduce an aggregate linalg.gemm, similar to linalg.softmax in the upstream MLIR project (whether it's worth it is another question).

Contributor Author

sounds good!

@parsifal-47 (Contributor Author)

@nhat-nguyen I think I resolved all your comments + added a lit test to showcase the conversion, let me know what you think, thank you!

@parsifal-47 (Contributor Author)

@nhat-nguyen could you please take a look when you have a chance

@parsifal-47 (Contributor Author)

@nhat-nguyen please take a look when you have a chance

@parsifal-47 changed the title from "Adding benchmarks, integrating OpenBLAS" to "Integrating OpenBLAS for gemm" on Jan 4, 2025
@parsifal-47 (Contributor Author)

Dropping benchmarks because they are covered by:
#209
