
Adjust the operand order of tt.dot to linalg.matmul converter #191

Merged (1 commit, Dec 30, 2024)

Conversation

@MercuryChen (Contributor)

A tt.dot with an accumulator lowers to linalg.matmul plus arith.add, and the arith.add is further lowered to a linalg.generic. That linalg.generic takes the lhs of the add as its DPS init, so the lhs of the add must be the accumulator.
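
For illustration, here is a hand-written sketch of that chain at the tensor level; the value names (%a, %b, %acc, %zero_init) and the shapes are placeholders for this example, not the converter's actual output:

// dot part: partial sums of the two inputs
%prod = linalg.matmul ins(%a, %b : tensor<32x16xf32>, tensor<16x64xf32>)
                      outs(%zero_init : tensor<32x64xf32>) -> tensor<32x64xf32>
// accumulate part: this add is later lowered to a linalg.generic whose
// DPS init is taken from its lhs, so the accumulator must be the lhs
%res = arith.addf %acc, %prod : tensor<32x64xf32>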

@MercuryChen changed the title from "Adjust the operand order of tt.dot to linalg,matmul converter" to "Adjust the operand order of tt.dot to linalg.matmul converter" on Nov 21, 2024
@nhat-nguyen (Collaborator)

Thanks for your PR! Could you help me understand why swapping the order is necessary? Does having the rhs as the DPS init, as we currently do, introduce any incorrect codegen?

@MercuryChen (Contributor, Author)

This is the FileCheck case for dot.
The linalg.matmul output (the partial sum) is %VAR_6_ and the accumulator is %VAR_3_, so the accumulate operation should look like %out = linalg.map {arith.addf} ins(%VAR_6_, %VAR_3_) outs(%VAR_3_) ..., rather than using %VAR_6_ as the DPS init.
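
To spell that out, here is a sketch that reuses the same value names (the shapes are placeholders, and this is not the verbatim FileCheck output); the fix only changes which value is used as the DPS init:

// intended: accumulate into the accumulator %VAR_3_
%out = linalg.map {arith.addf} ins(%VAR_6_, %VAR_3_ : tensor<32x64xf32>, tensor<32x64xf32>) outs(%VAR_3_ : tensor<32x64xf32>)
// buggy: the partial sum %VAR_6_ is used as the DPS init instead
%out_bad = linalg.map {arith.addf} ins(%VAR_6_, %VAR_3_ : tensor<32x64xf32>, tensor<32x64xf32>) outs(%VAR_6_ : tensor<32x64xf32>)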

The arith-to-linalg conversion code shows why the lhs is taken as the DPS init.
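
Schematically (my own sketch of that pattern, with placeholder names and example shapes, not the converter's verbatim output), an elementwise add on tensors such as %sum = arith.addf %lhs, %rhs : tensor<32x64xf32> is rewritten into a linalg.generic that reuses the first operand as its out, so whichever value is the lhs of the add ends up as the DPS init:

// with #map = affine_map<(d0, d1) -> (d0, d1)>, as in the IR below
%sum = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
    ins(%lhs, %rhs : tensor<32x64xf32>, tensor<32x64xf32>)
    outs(%lhs : tensor<32x64xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %r = arith.addf %in, %in_0 : f32
  linalg.yield %r : f32
} -> tensor<32x64xf32>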

I don't know why this bug does not cause a mismatch error in the triton-shared CPU backend, but in our custom SPIR-V backend it did.
Thanks!

@MercuryChen (Contributor, Author) commented Nov 22, 2024 via email

@nhat-nguyen (Collaborator) commented Nov 26, 2024

@MercuryChen Would you be able to share the buggy IR after the bufferization pass, before and after your changes? I'm OK with the changes but just want to understand the problem a bit better. My understanding is that there should be no difference in semantics between the two orderings. At the tensor level, tensors are not "rewritten" but rather created fresh every time; it is only after bufferization that buffers are assigned. In addition, the out params are only used for value initialization, and the body of this particular linalg.generic {addf} does not involve the out value in any way.

@MercuryChen (Contributor, Author) commented Nov 26, 2024

Thanks for your reply!
The IR after the change:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %41 = arith.addi %arg18, %16 : index
      %42 = arith.remsi %41, %22 : index
      %43 = arith.subi %41, %42 : index
      %44 = arith.addi %42, %c64 : index
      %45 = arith.minsi %44, %22 : index
      %46 = arith.subi %45, %42 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %47 = arith.subi %c64, %46 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %48 = arith.remsi %arg17, %18 : index
      %49 = arith.addi %20, %48 : index
      %50 = arith.subi %49, %arg17 : index
      %51 = arith.divsi %50, %18 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %52 = arith.subi %c32, %51 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %53 = arith.muli %arg15, %c16_i32 : i32
      %54 = arith.subi %arg5, %53 : i32
      %55 = arith.index_cast %54 : i32 to index
      %56 = arith.minsi %55, %c16 : index
      %57 = arith.maxsi %56, %c0 : index
      %alloc_6 = memref.alloc() : memref<32x16xf32>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
      }
      %59 = arith.minsi %51, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %alloc_11 = memref.alloc() : memref<16x64xf32>
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
      }
      %61 = arith.minsi %46, %c64 : index
      %62 = arith.subi %c64, %61 : index
      %subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
      linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%arg16: memref<32x64xf32>) {
      ^bb0(%in: f32, %in_17: f32, %out: f32):
        %65 = arith.addf %in, %in_17 : f32
        linalg.yield %65 : f32
      }
      %63 = arith.addi %arg17, %c16 : index
      %64 = arith.addi %arg18, %26 : index
      scf.yield %alloc_16, %63, %64 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.maxsi %32, %14 : index
    %34 = arith.subi %33, %14 : index
    %35 = arith.addi %16, %c64 : index
    %36 = arith.minsi %35, %22 : index
    %37 = arith.maxsi %36, %16 : index
    %38 = arith.subi %37, %16 : index
    %39 = arith.minsi %34, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
    %subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    return
  }
}

The IR before the change:

...
// differs from the IR above only in the operand order: here the `out` is %alloc_16 instead of %arg16
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_16, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_16: memref<32x64xf32>) {
      ^bb0(%in: f32, %in_17: f32, %out: f32):
        %65 = arith.addf %in, %in_17 : f32
        linalg.yield %65 : f32
      }
...

And the correct IR before bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
    %2 = arith.addi %arg3, %c31_i32 : i32
    %3 = arith.divsi %2, %c32_i32 : i32
    %4 = arith.addi %arg4, %c63_i32 : i32
    %5 = arith.divsi %4, %c64_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %arg12, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %arg12, %10 : i32
    %12 = arith.addi %8, %11 : i32
    %13 = arith.remsi %arg12, %6 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c32_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.muli %14, %c64_i32 : i32
    %18 = arith.index_cast %17 : i32 to index
    %19 = arith.index_cast %arg3 : i32 to index
    %20 = arith.index_cast %arg6 : i32 to index
    %21 = arith.muli %16, %20 : index
    %22 = arith.muli %19, %20 : index
    %23 = arith.index_cast %arg7 : i32 to index
    %24 = arith.index_cast %arg4 : i32 to index
    %25 = arith.addi %arg5, %c15_i32 : i32
    %26 = arith.divsi %25, %c16_i32 : i32
    %27 = arith.muli %arg7, %c16_i32 : i32
    %28 = arith.index_cast %27 : i32 to index
    %29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %43 = arith.addi %arg18, %18 : index
      %44 = arith.remsi %43, %24 : index
      %45 = arith.subi %43, %44 : index
      %46 = arith.addi %44, %c64 : index
      %47 = arith.minsi %46, %24 : index
      %48 = arith.subi %47, %44 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %49 = arith.subi %c64, %48 : index
      %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %50 = arith.remsi %arg17, %20 : index
      %51 = arith.addi %22, %50 : index
      %52 = arith.subi %51, %arg17 : index
      %53 = arith.divsi %52, %20 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %54 = arith.subi %c32, %53 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %55 = arith.muli %arg15, %c16_i32 : i32
      %56 = arith.subi %arg5, %55 : i32
      %57 = arith.index_cast %56 : i32 to index
      %58 = arith.minsi %57, %c16 : index
      %59 = arith.maxsi %58, %c0 : index
      %alloc = memref.alloc() : memref<32x16xf32>
      %60 = arith.cmpi slt, %59, %c16 : index
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
      }
      %61 = arith.minsi %53, %c32 : index
      %62 = arith.subi %c32, %61 : index
      %subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
      %alloc_8 = memref.alloc() : memref<16x64xf32>
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
      }
      %64 = arith.minsi %48, %c64 : index
      %65 = arith.subi %c64, %64 : index
      %subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
      %67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
      %68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%arg16 : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %71 = arith.addf %in, %in_13 : f32
        linalg.yield %71 : f32
      } -> tensor<32x64xf32>
      %69 = arith.addi %arg17, %c16 : index
      %70 = arith.addi %arg18, %28 : index
      scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
    }
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %16, %30 : index
    %32 = arith.addi %31, %18 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %33 = arith.addi %16, %c32 : index
    %34 = arith.minsi %33, %19 : index
    %35 = arith.maxsi %34, %16 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.addi %18, %c64 : index
    %38 = arith.minsi %37, %24 : index
    %39 = arith.maxsi %38, %18 : index
    %40 = arith.subi %39, %18 : index
    %41 = arith.minsi %36, %c32 : index
    %42 = arith.minsi %40, %c64 : index
    %extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
    %subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}

As the IR shows, the addf does involve the out of the linalg.generic: the tensor used as the out (%arg16) is also one of the ins. I agree that

tensors are not "rewritten" but rather created fresh every time

In the tensor semantics everything is SSA, so maybe the proper way is to create a new empty tensor for the DPS init.

@nhat-nguyen (Collaborator)

Thanks for your quick response. Would you mind also sharing the IR before bufferization? Sorry, I should have asked in the previous reply. 😄

@nhat-nguyen (Collaborator) commented Nov 26, 2024

In the tensor semantics everything is SSA, so maybe the proper way is to create a new empty tensor for the DPS init.

This seems like a good idea to try too. I would be curious to see what the bufferization result looks like if you manually create a linalg.generic {addf} with an empty tensor as its out.

@MercuryChen (Contributor, Author)

Sorry, I accidentally closed the PR by mistake.
I have updated the related test cases, so I think the CI should pass now.

To use a new empty tensor as the out of the linalg.generic, we would need to insert another linalg.copy op to copy the out memref into the "accumulator" memref. That would add unnecessary local memory usage and an extra copy, so the current arith-to-linalg implementation is better; I'd prefer the current solution.

@nhat-nguyen (Collaborator) commented Nov 27, 2024

Sorry, I accidentally closed the PR by mistake. I have updated the related test cases, so I think the CI should pass now.

To use a new empty tensor as the out of the linalg.generic, we would need to insert another linalg.copy op to copy the out memref into the "accumulator" memref. That would add unnecessary local memory usage and an extra copy, so the current arith-to-linalg implementation is better; I'd prefer the current solution.

What if you simply use an empty tensor as the out without explicitly inserting linalg.copy? What does the bufferization output look like? The reason I'm asking is that there seems to be a fundamental issue with how we use the out params, which could produce incorrect codegen in various scenarios. While swapping the order of the two operands would fix this matmul bug, it would be great if we could figure out the correct bufferization behaviour.

@MercuryChen (Contributor, Author)

Yes, we don't need to insert a linalg.copy, because the result of linalg.generic aliases with its DPS init; I forgot this rule. So just inserting a tensor.empty works well.
Init with an empty tensor, before bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
    %2 = arith.addi %arg3, %c31_i32 : i32
    %3 = arith.divsi %2, %c32_i32 : i32
    %4 = arith.addi %arg4, %c63_i32 : i32
    %5 = arith.divsi %4, %c64_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %arg12, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %arg12, %10 : i32
    %12 = arith.addi %8, %11 : i32
    %13 = arith.remsi %arg12, %6 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c32_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.muli %14, %c64_i32 : i32
    %18 = arith.index_cast %17 : i32 to index
    %19 = arith.index_cast %arg3 : i32 to index
    %20 = arith.index_cast %arg6 : i32 to index
    %21 = arith.muli %16, %20 : index
    %22 = arith.muli %19, %20 : index
    %23 = arith.index_cast %arg7 : i32 to index
    %24 = arith.index_cast %arg4 : i32 to index
    %25 = arith.addi %arg5, %c15_i32 : i32
    %26 = arith.divsi %25, %c16_i32 : i32
    %27 = arith.muli %arg7, %c16_i32 : i32
    %28 = arith.index_cast %27 : i32 to index
    %29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %43 = arith.addi %arg18, %18 : index
      %44 = arith.remsi %43, %24 : index
      %45 = arith.subi %43, %44 : index
      %46 = arith.addi %44, %c64 : index
      %47 = arith.minsi %46, %24 : index
      %48 = arith.subi %47, %44 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %49 = arith.subi %c64, %48 : index
      %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %50 = arith.remsi %arg17, %20 : index
      %51 = arith.addi %22, %50 : index
      %52 = arith.subi %51, %arg17 : index
      %53 = arith.divsi %52, %20 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %54 = arith.subi %c32, %53 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %55 = arith.muli %arg15, %c16_i32 : i32
      %56 = arith.subi %arg5, %55 : i32
      %57 = arith.index_cast %56 : i32 to index
      %58 = arith.minsi %57, %c16 : index
      %59 = arith.maxsi %58, %c0 : index
      %alloc = memref.alloc() : memref<32x16xf32>
      %60 = arith.cmpi slt, %59, %c16 : index
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
      }
      %61 = arith.minsi %53, %c32 : index
      %62 = arith.subi %c32, %61 : index
      %subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
      %alloc_8 = memref.alloc() : memref<16x64xf32>
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
      }
      %64 = arith.minsi %48, %c64 : index
      %65 = arith.subi %c64, %64 : index
      %subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
      %67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
      %manual_empty = tensor.empty() : tensor<32x64xf32>
      %68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %71 = arith.addf %in, %in_13 : f32
        linalg.yield %71 : f32
      } -> tensor<32x64xf32>
      %69 = arith.addi %arg17, %c16 : index
      %70 = arith.addi %arg18, %28 : index
      scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
    }
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %16, %30 : index
    %32 = arith.addi %31, %18 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %33 = arith.addi %16, %c32 : index
    %34 = arith.minsi %33, %19 : index
    %35 = arith.maxsi %34, %16 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.addi %18, %c64 : index
    %38 = arith.minsi %37, %24 : index
    %39 = arith.maxsi %38, %18 : index
    %40 = arith.subi %39, %18 : index
    %41 = arith.minsi %36, %c32 : index
    %42 = arith.minsi %40, %c64 : index
    %extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
    %subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}

After bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %41 = arith.addi %arg18, %16 : index
      %42 = arith.remsi %41, %22 : index
      %43 = arith.subi %41, %42 : index
      %44 = arith.addi %42, %c64 : index
      %45 = arith.minsi %44, %22 : index
      %46 = arith.subi %45, %42 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %47 = arith.subi %c64, %46 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %48 = arith.remsi %arg17, %18 : index
      %49 = arith.addi %20, %48 : index
      %50 = arith.subi %49, %arg17 : index
      %51 = arith.divsi %50, %18 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %52 = arith.subi %c32, %51 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %53 = arith.muli %arg15, %c16_i32 : i32
      %54 = arith.subi %arg5, %53 : i32
      %55 = arith.index_cast %54 : i32 to index
      %56 = arith.minsi %55, %c16 : index
      %57 = arith.maxsi %56, %c0 : index
      %alloc_6 = memref.alloc() : memref<32x16xf32>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
      }
      %59 = arith.minsi %51, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %alloc_11 = memref.alloc() : memref<16x64xf32>
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
      }
      %61 = arith.minsi %46, %c64 : index
      %62 = arith.subi %c64, %61 : index
      %subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
      linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
      %alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_17 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_18: f32, %out: f32):
        %65 = arith.addf %in, %in_18 : f32
        linalg.yield %65 : f32
      }
      %63 = arith.addi %arg17, %c16 : index
      %64 = arith.addi %arg18, %26 : index
      scf.yield %alloc_17, %63, %64 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.maxsi %32, %14 : index
    %34 = arith.subi %33, %14 : index
    %35 = arith.addi %16, %c64 : index
    %36 = arith.minsi %35, %22 : index
    %37 = arith.maxsi %36, %16 : index
    %38 = arith.subi %37, %16 : index
    %39 = arith.minsi %34, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
    %subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    return
  }
}

The real problem appeared in the IR after convert-scf-to-cf and the lowering of cf to SPIR-V. The upstream mlir-opt does not support lowering cf with dynamic memref arguments, and our lowering implementation is not robust enough, so we need this PR's change.

Thanks for your time on this problem. We can keep this change on our branch, but if you are willing to merge it, we would really appreciate it.

@nhat-nguyen (Collaborator)

Thanks for your explanation. I'm OK with merging your patch as a temporary fix, but would you mind adding a comment explaining the issue, with your findings regarding using tensor.empty as the out param, and linking to this issue: #196? We will likely need to find the correct fix for all cases at some point.

@nhat-nguyen (Collaborator)

@MercuryChen Sorry about the confusion, I meant updating your code to include a comment about swapping the order of the operands, with a link to the issue above. But thank you for your GitHub comment anyway; it helps make the issue clearer. The test also now has a merge conflict; could you take a look, in addition to adding the comment as I suggested above? Thanks!

@MercuryChen force-pushed the main branch 2 times, most recently from f186c1e to 5d8e5bf on December 5, 2024 at 02:19
Commit message: A tt.dot with an accumulator lowers to linalg.matmul plus arith.add, and the arith.add is further lowered to a linalg.generic, which takes the lhs of the add as the DPS init, so the lhs should be the matmul accumulator. This is a temporary fix for issue microsoft#196.
@MercuryChen (Contributor, Author)

@nhat-nguyen Updated. Thanks!

@nhat-nguyen nhat-nguyen merged commit d9933bb into microsoft:main Dec 30, 2024
3 checks passed