Adjust the operand order of tt.dot to linalg.matmul converter #191
Conversation
thanks for your pr! could you help me understand why swapping the order is necessary? does having the rhs as dps init like we have currently introduce any incorrect codegen? |
This is the FileCheck test case of the dot arith-to-linalg lowering. I don't know why this bug does not cause a mismatch error in the triton-shared CPU backend, but it did in our custom SPIR-V backend. |
@microsoft-github-policy-service agree
|
@MercuryChen Would you be able to share the buggy IR after the |
Thanks for your reply!
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%41 = arith.addi %arg18, %16 : index
%42 = arith.remsi %41, %22 : index
%43 = arith.subi %41, %42 : index
%44 = arith.addi %42, %c64 : index
%45 = arith.minsi %44, %22 : index
%46 = arith.subi %45, %42 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%47 = arith.subi %c64, %46 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%48 = arith.remsi %arg17, %18 : index
%49 = arith.addi %20, %48 : index
%50 = arith.subi %49, %arg17 : index
%51 = arith.divsi %50, %18 : index
%reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%52 = arith.subi %c32, %51 : index
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%53 = arith.muli %arg15, %c16_i32 : i32
%54 = arith.subi %arg5, %53 : i32
%55 = arith.index_cast %54 : i32 to index
%56 = arith.minsi %55, %c16 : index
%57 = arith.maxsi %56, %c0 : index
%alloc_6 = memref.alloc() : memref<32x16xf32>
%58 = arith.cmpi slt, %57, %c16 : index
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
}
%59 = arith.minsi %51, %c32 : index
%60 = arith.subi %c32, %59 : index
%subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x64xf32>
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
}
%61 = arith.minsi %46, %c64 : index
%62 = arith.subi %c64, %61 : index
%subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%arg16: memref<32x64xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%65 = arith.addf %in, %in_17 : f32
linalg.yield %65 : f32
}
%63 = arith.addi %arg17, %c16 : index
%64 = arith.addi %arg18, %26 : index
scf.yield %alloc_16, %63, %64 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.maxsi %32, %14 : index
%34 = arith.subi %33, %14 : index
%35 = arith.addi %16, %c64 : index
%36 = arith.minsi %35, %22 : index
%37 = arith.maxsi %36, %16 : index
%38 = arith.subi %37, %16 : index
%39 = arith.minsi %34, %c32 : index
%40 = arith.minsi %38, %c64 : index
%subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
return
}
}
The IR before the change:
...
// just replace the `out` from %alloc_16 to %arg16
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_16, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_16: memref<32x64xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%65 = arith.addf %in, %in_17 : f32
linalg.yield %65 : f32
}
...
And the correct IR before bufferization:
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
%2 = arith.addi %arg3, %c31_i32 : i32
%3 = arith.divsi %2, %c32_i32 : i32
%4 = arith.addi %arg4, %c63_i32 : i32
%5 = arith.divsi %4, %c64_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %arg12, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %arg12, %10 : i32
%12 = arith.addi %8, %11 : i32
%13 = arith.remsi %arg12, %6 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c32_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.muli %14, %c64_i32 : i32
%18 = arith.index_cast %17 : i32 to index
%19 = arith.index_cast %arg3 : i32 to index
%20 = arith.index_cast %arg6 : i32 to index
%21 = arith.muli %16, %20 : index
%22 = arith.muli %19, %20 : index
%23 = arith.index_cast %arg7 : i32 to index
%24 = arith.index_cast %arg4 : i32 to index
%25 = arith.addi %arg5, %c15_i32 : i32
%26 = arith.divsi %25, %c16_i32 : i32
%27 = arith.muli %arg7, %c16_i32 : i32
%28 = arith.index_cast %27 : i32 to index
%29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index) : i32 {
%43 = arith.addi %arg18, %18 : index
%44 = arith.remsi %43, %24 : index
%45 = arith.subi %43, %44 : index
%46 = arith.addi %44, %c64 : index
%47 = arith.minsi %46, %24 : index
%48 = arith.subi %47, %44 : index
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%49 = arith.subi %c64, %48 : index
%reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%50 = arith.remsi %arg17, %20 : index
%51 = arith.addi %22, %50 : index
%52 = arith.subi %51, %arg17 : index
%53 = arith.divsi %52, %20 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%54 = arith.subi %c32, %53 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%55 = arith.muli %arg15, %c16_i32 : i32
%56 = arith.subi %arg5, %55 : i32
%57 = arith.index_cast %56 : i32 to index
%58 = arith.minsi %57, %c16 : index
%59 = arith.maxsi %58, %c0 : index
%alloc = memref.alloc() : memref<32x16xf32>
%60 = arith.cmpi slt, %59, %c16 : index
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
}
%61 = arith.minsi %53, %c32 : index
%62 = arith.subi %c32, %61 : index
%subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
%alloc_8 = memref.alloc() : memref<16x64xf32>
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
}
%64 = arith.minsi %48, %c64 : index
%65 = arith.subi %c64, %64 : index
%subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%arg16 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_13: f32, %out: f32):
%71 = arith.addf %in, %in_13 : f32
linalg.yield %71 : f32
} -> tensor<32x64xf32>
%69 = arith.addi %arg17, %c16 : index
%70 = arith.addi %arg18, %28 : index
scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
}
%30 = arith.index_cast %arg8 : i32 to index
%31 = arith.muli %16, %30 : index
%32 = arith.addi %31, %18 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%33 = arith.addi %16, %c32 : index
%34 = arith.minsi %33, %19 : index
%35 = arith.maxsi %34, %16 : index
%36 = arith.subi %35, %16 : index
%37 = arith.addi %18, %c64 : index
%38 = arith.minsi %37, %24 : index
%39 = arith.maxsi %38, %18 : index
%40 = arith.subi %39, %18 : index
%41 = arith.minsi %36, %c32 : index
%42 = arith.minsi %40, %c64 : index
%extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
%subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
return
}
}
As the IR shows, in tensor semantics everything is SSA, so maybe the proper way is to create a new empty tensor for the DPS init. |
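As a minimal sketch of that suggestion (illustrative only; the `@acc_step` wrapper, the value names, and the 32x64 tile size are assumptions rather than code from this PR), the accumulate step could take a fresh `tensor.empty` as its DPS init instead of reusing one of its inputs:

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>

// Accumulate step with a fresh tensor.empty as the DPS init, so the result
// is not tied to either input once the IR is bufferized.
func.func @acc_step(%acc: tensor<32x64xf32>, %prod: tensor<32x64xf32>) -> tensor<32x64xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %sum = linalg.generic {indexing_maps = [#map, #map, #map],
                         iterator_types = ["parallel", "parallel"]}
      ins(%acc, %prod : tensor<32x64xf32>, tensor<32x64xf32>)
      outs(%empty : tensor<32x64xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  } -> tensor<32x64xf32>
  return %sum : tensor<32x64xf32>
}
```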
Thanks for your quick response. Would you mind also sharing the IR before bufferization? Sorry, I should have asked in the previous reply. 😄 |
This seems like a good idea to try too. I would be curious to see what the bufferization result looks like if you manually create a linalg.generic {addf} with an empty tensor as its |
Sorry, my mistake: I accidentally closed the PR. To use the new empty tensor as |
What if you simply use an empty tensor as out without explicitly inserting linalg.copy? What does the bufferization output look like? The reason I'm asking is that there seems to be a fundamental issue with how we use the out params, which could produce incorrect codegen in various scenarios. While swapping the order of the two operands would fix this matmul bug, it would be great if we could figure out the correct bufferization behaviour. |
Yes, we do not need to insert it explicitly:
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
%2 = arith.addi %arg3, %c31_i32 : i32
%3 = arith.divsi %2, %c32_i32 : i32
%4 = arith.addi %arg4, %c63_i32 : i32
%5 = arith.divsi %4, %c64_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %arg12, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %arg12, %10 : i32
%12 = arith.addi %8, %11 : i32
%13 = arith.remsi %arg12, %6 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c32_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.muli %14, %c64_i32 : i32
%18 = arith.index_cast %17 : i32 to index
%19 = arith.index_cast %arg3 : i32 to index
%20 = arith.index_cast %arg6 : i32 to index
%21 = arith.muli %16, %20 : index
%22 = arith.muli %19, %20 : index
%23 = arith.index_cast %arg7 : i32 to index
%24 = arith.index_cast %arg4 : i32 to index
%25 = arith.addi %arg5, %c15_i32 : i32
%26 = arith.divsi %25, %c16_i32 : i32
%27 = arith.muli %arg7, %c16_i32 : i32
%28 = arith.index_cast %27 : i32 to index
%29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index) : i32 {
%43 = arith.addi %arg18, %18 : index
%44 = arith.remsi %43, %24 : index
%45 = arith.subi %43, %44 : index
%46 = arith.addi %44, %c64 : index
%47 = arith.minsi %46, %24 : index
%48 = arith.subi %47, %44 : index
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%49 = arith.subi %c64, %48 : index
%reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%50 = arith.remsi %arg17, %20 : index
%51 = arith.addi %22, %50 : index
%52 = arith.subi %51, %arg17 : index
%53 = arith.divsi %52, %20 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%54 = arith.subi %c32, %53 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%55 = arith.muli %arg15, %c16_i32 : i32
%56 = arith.subi %arg5, %55 : i32
%57 = arith.index_cast %56 : i32 to index
%58 = arith.minsi %57, %c16 : index
%59 = arith.maxsi %58, %c0 : index
%alloc = memref.alloc() : memref<32x16xf32>
%60 = arith.cmpi slt, %59, %c16 : index
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
}
%61 = arith.minsi %53, %c32 : index
%62 = arith.subi %c32, %61 : index
%subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
%alloc_8 = memref.alloc() : memref<16x64xf32>
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
}
%64 = arith.minsi %48, %c64 : index
%65 = arith.subi %c64, %64 : index
%subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%manual_empty = tensor.empty() : tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) {
^bb0(%in: f32, %in_13: f32, %out: f32):
%71 = arith.addf %in, %in_13 : f32
linalg.yield %71 : f32
} -> tensor<32x64xf32>
%69 = arith.addi %arg17, %c16 : index
%70 = arith.addi %arg18, %28 : index
scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
}
%30 = arith.index_cast %arg8 : i32 to index
%31 = arith.muli %16, %30 : index
%32 = arith.addi %31, %18 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%33 = arith.addi %16, %c32 : index
%34 = arith.minsi %33, %19 : index
%35 = arith.maxsi %34, %16 : index
%36 = arith.subi %35, %16 : index
%37 = arith.addi %18, %c64 : index
%38 = arith.minsi %37, %24 : index
%39 = arith.maxsi %38, %18 : index
%40 = arith.subi %39, %18 : index
%41 = arith.minsi %36, %c32 : index
%42 = arith.minsi %40, %c64 : index
%extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
%subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
return
}
}
After bufferization:
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%41 = arith.addi %arg18, %16 : index
%42 = arith.remsi %41, %22 : index
%43 = arith.subi %41, %42 : index
%44 = arith.addi %42, %c64 : index
%45 = arith.minsi %44, %22 : index
%46 = arith.subi %45, %42 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%47 = arith.subi %c64, %46 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%48 = arith.remsi %arg17, %18 : index
%49 = arith.addi %20, %48 : index
%50 = arith.subi %49, %arg17 : index
%51 = arith.divsi %50, %18 : index
%reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%52 = arith.subi %c32, %51 : index
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%53 = arith.muli %arg15, %c16_i32 : i32
%54 = arith.subi %arg5, %53 : i32
%55 = arith.index_cast %54 : i32 to index
%56 = arith.minsi %55, %c16 : index
%57 = arith.maxsi %56, %c0 : index
%alloc_6 = memref.alloc() : memref<32x16xf32>
%58 = arith.cmpi slt, %57, %c16 : index
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
}
%59 = arith.minsi %51, %c32 : index
%60 = arith.subi %c32, %59 : index
%subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x64xf32>
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
}
%61 = arith.minsi %46, %c64 : index
%62 = arith.subi %c64, %61 : index
%subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
%alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_17 : memref<32x64xf32>) {
^bb0(%in: f32, %in_18: f32, %out: f32):
%65 = arith.addf %in, %in_18 : f32
linalg.yield %65 : f32
}
%63 = arith.addi %arg17, %c16 : index
%64 = arith.addi %arg18, %26 : index
scf.yield %alloc_17, %63, %64 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.maxsi %32, %14 : index
%34 = arith.subi %33, %14 : index
%35 = arith.addi %16, %c64 : index
%36 = arith.minsi %35, %22 : index
%37 = arith.maxsi %36, %16 : index
%38 = arith.subi %37, %16 : index
%39 = arith.minsi %34, %c32 : index
%40 = arith.minsi %38, %c64 : index
%subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
return
}
}
The real problem appeared in the IR after
Thanks for your time on this problem. We can keep this change on our branch; if you are willing to merge it, we would really appreciate it. |
thanks for your explanation. i'm ok with merging your patch as a temporary fix, but would you mind adding a comment explaining the issue with your findings regarding using tensor.empty as the out param and linking to this issue here: #196. we will likely need to find the correct fix for all cases at some point. |
@MercuryChen Sorry about the confusion, I meant updating your code to include a comment about swapping the order of the operands with a link to the issue above. But thank you for your Github comment anyway, it helps make the issue clearer. The test also now has a merge conflict, could you take a look in addition to adding the comment like I suggested above? Thanks! |
Force-pushed from f186c1e to 5d8e5bf
The tt.dot with an accumulator lowers to linalg.matmul and arith.addf; the arith.addf then lowers to linalg.generic, which takes the lhs of the add as the DPS init, so the lhs should be the matmul accumulator. This is a temporary fix for issue microsoft#196.
@nhat-nguyen Updated. Thanks! |
The tt.dot with an accumulator lowers to linalg.matmul and arith.addf; the arith.addf further lowers to linalg.generic, which takes the lhs of the add as the DPS init, so the lhs of the add must be the accumulator.
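A minimal sketch of that lowering shape (illustrative only; the function name, value names, and the 32x16/16x64 tile sizes are assumptions, not the converter's literal output), with the accumulator kept as the lhs of the add:

```mlir
// Sketch: d = dot(a, b) + acc. The matmul writes into a zero-filled tile,
// and acc stays on the lhs of the add so that the later arith-to-linalg
// lowering can use it as the DPS init.
func.func @dot_with_acc(%a: tensor<32x16xf32>, %b: tensor<16x64xf32>,
                        %acc: tensor<32x64xf32>) -> tensor<32x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<32x64xf32>
  %zeroed = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x64xf32>) -> tensor<32x64xf32>
  %prod = linalg.matmul ins(%a, %b : tensor<32x16xf32>, tensor<16x64xf32>)
                        outs(%zeroed : tensor<32x64xf32>) -> tensor<32x64xf32>
  // Accumulator first: the lhs is what the add-to-generic lowering treats as the init.
  %res = arith.addf %acc, %prod : tensor<32x64xf32>
  return %res : tensor<32x64xf32>
}
```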