
[Backend][AIE] Support global indexing, tensor access and kernel reindex for AIE kernel mapping #300

Merged · 24 commits · Feb 20, 2025

Conversation

EthanMeng324 (Contributor)

Description

The original AIE backend with the dataflow interface only supports simple element-wise applications, such as vector addition. It used local indexing, which is insufficient for more complex tensor access patterns such as matrix multiplication.

Problems

In the original AIE interface, the compiler lacks sufficient information to determine which dimension (M, K, or N) should be used for tiling in this GEMM implementation.

Ty = int32
M, N, K = 16, 16, 16
P0 = 1
Mt = M // P0

@df.region()
def top():
    @df.kernel(mapping=[P0])
    def gemm(A: Ty[M, K], B: Ty[K, N], C: Ty[M, N]):
        for i, j, k in allo.grid(Mt, K, N):
            C[i, j] += A[i, k] * B[k, j]

Proposed Solutions

In the new interface, we adopt global indexing for AIE, similar to the HLS backend. Additionally, we leverage tensor-slice accesses, inspired by Triton, so the compiler can determine the precise access range of each tensor, which facilitates backend reindexing. Furthermore, we use primitives such as allo.matmul to encapsulate fundamental operations, streamlining MLIR construction and lowering.

Ty = int32
M, N, K = 16, 16, 16
P0 = 2
Mt = M // P0

@df.region()
def top():
    @df.kernel(mapping=[P0])
    def gemm(A: Ty[M, K], B: Ty[K, N], C: Ty[M, N]):
        pi = df.get_pid()
        C[pi * Mt: (pi + 1) * Mt, :] = allo.matmul(
            A[pi * Mt: (pi + 1) * Mt, :], B)
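The tiling scheme in the example above can be sketched in plain Python (a hypothetical illustration without the Allo/AIE runtime; `matmul_rows` stands in for `allo.matmul`): each of the P0 kernel instances computes an Mt-row slice of C, with global row offsets derived from its PE id (`df.get_pid()`).

```python
M = N = K = 4   # small sizes for illustration; the PR uses 16
P0 = 2
Mt = M // P0

A = [[i * K + k for k in range(K)] for i in range(M)]
B = [[k * N + j for j in range(N)] for k in range(K)]
C = [[0] * N for _ in range(M)]

def matmul_rows(A_rows, B):
    """Reference matmul over a row block, standing in for allo.matmul."""
    return [[sum(a[k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for a in A_rows]

for pi in range(P0):                     # one iteration per compute tile
    lo, hi = pi * Mt, (pi + 1) * Mt      # global row range for this PE
    C[lo:hi] = matmul_rows(A[lo:hi], B)  # C[pi*Mt:(pi+1)*Mt, :] = matmul(A[...], B)

# Partitioned result matches a full matmul over all rows
assert C == matmul_rows(A, B)
```

Each PE sees only its own row block of A but all of B, which is exactly the distribute/broadcast split reflected in the object FIFOs of the generated MLIR.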
The generated AIE MLIR for this example:

module {
  aie.device(npu1_2col) {
    %tile_shim = aie.tile(0, 0)
    %tile_mem0 = aie.tile(0, 1)
    %tile_mem1 = aie.tile(1, 1)
    %tile_comp0 = aie.tile(0, 2)
    %tile_comp0_buf0 = aie.buffer(%tile_comp0) : memref<8x16xi32>
    %tile_comp1 = aie.tile(0, 3)
    %tile_comp1_buf0 = aie.buffer(%tile_comp1) : memref<8x16xi32>
    aie.objectfifo @in_sh0(%tile_shim, {%tile_mem0}, 2 : i32) : !aie.objectfifo<memref<16x16xi32>>
    aie.objectfifo @in0_p0(%tile_mem0, {%tile_comp0}, 2 : i32) : !aie.objectfifo<memref<8x16xi32>>
    aie.objectfifo @in0_p1(%tile_mem0, {%tile_comp1}, 2 : i32) : !aie.objectfifo<memref<8x16xi32>>
    aie.objectfifo.link [@in_sh0] -> [@in0_p0, @in0_p1]([] [0, 128])
    aie.objectfifo @in_sh1(%tile_shim, {%tile_mem1}, 2 : i32) : !aie.objectfifo<memref<16x16xi32>>
    aie.objectfifo @in1_p0(%tile_mem1, {%tile_comp0, %tile_comp1}, 2 : i32) : !aie.objectfifo<memref<16x16xi32>>
    aie.objectfifo.link [@in_sh1] -> [@in1_p0]([] [])
    aie.objectfifo @out_p0(%tile_comp0, {%tile_mem0}, 2 : i32) : !aie.objectfifo<memref<8x16xi32>>
    aie.objectfifo @out_p1(%tile_comp1, {%tile_mem0}, 2 : i32) : !aie.objectfifo<memref<8x16xi32>>
    aie.objectfifo @out_sh(%tile_mem0, {%tile_shim}, 2 : i32) : !aie.objectfifo<memref<16x16xi32>>
    aie.objectfifo.link [@out_p0, @out_p1] -> [@out_sh]([0, 128] [])
    %core_0_2 = aie.core(%tile_comp0) {
      %c1000 = arith.constant 0 : index
      %c1001 = arith.constant 1 : index
      %c9223372036854775807 = arith.constant 9223372036854775807 : index
      scf.for %arg0 = %c1000 to %c9223372036854775807 step %c1001 {
        %fifo0 = aie.objectfifo.acquire @in0_p0(Consume, 1) : !aie.objectfifosubview<memref<8x16xi32>>
        %local0 = aie.objectfifo.subview.access %fifo0[0] : !aie.objectfifosubview<memref<8x16xi32>> -> memref<8x16xi32>
        %fifo1 = aie.objectfifo.acquire @in1_p0(Consume, 1) : !aie.objectfifosubview<memref<16x16xi32>>
        %local1 = aie.objectfifo.subview.access %fifo1[0] : !aie.objectfifosubview<memref<16x16xi32>> -> memref<16x16xi32>
        %fifo_out = aie.objectfifo.acquire @out_p0(Produce, 1) : !aie.objectfifosubview<memref<8x16xi32>>
        %local_out = aie.objectfifo.subview.access %fifo_out[0] : !aie.objectfifosubview<memref<8x16xi32>> -> memref<8x16xi32>
      %c0_i32 = arith.constant 0 : i32
      %subview = memref.subview %local0[0, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
      %c0 = arith.constant 0 : index
      %c8 = arith.constant 8 : index
      %c1 = arith.constant 1 : index
      scf.for %arg3 = %c0 to %c8 step %c1 {
        %c0_4 = arith.constant 0 : index
        %c16 = arith.constant 16 : index
        %c1_5 = arith.constant 1 : index
        scf.for %arg4 = %c0_4 to %c16 step %c1_5 {
          memref.store %c0_i32, %tile_comp0_buf0[%arg3, %arg4] : memref<8x16xi32>
        }
      }
      %c0_0 = arith.constant 0 : index
      %c8_1 = arith.constant 8 : index
      %c1_2 = arith.constant 1 : index
      scf.for %arg3 = %c0_0 to %c8_1 step %c1_2 {
        %c0_4 = arith.constant 0 : index
        %c16 = arith.constant 16 : index
        %c1_5 = arith.constant 1 : index
        scf.for %arg4 = %c0_4 to %c16 step %c1_5 {
          %c0_6 = arith.constant 0 : index
          %c16_7 = arith.constant 16 : index
          %c1_8 = arith.constant 1 : index
          scf.for %arg5 = %c0_6 to %c16_7 step %c1_8 {
            %0 = memref.load %subview[%arg3, %arg5] : memref<8x16xi32, strided<[16, 1]>>
            %1 = memref.load %local1[%arg5, %arg4] : memref<16x16xi32>
            %2 = memref.load %tile_comp0_buf0[%arg3, %arg4] : memref<8x16xi32>
            %3 = arith.muli %0, %1 : i32
            %4 = arith.addi %2, %3 : i32
            memref.store %4, %tile_comp0_buf0[%arg3, %arg4] : memref<8x16xi32>
          }
        }
      }
      %subview_3 = memref.subview %local_out[0, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
      memref.copy %tile_comp0_buf0, %subview_3 : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
        aie.objectfifo.release @in0_p0(Consume, 1)
        aie.objectfifo.release @in1_p0(Consume, 1)
        aie.objectfifo.release @out_p0(Produce, 1)
      }
      aie.end
    }
    %core_0_3 = aie.core(%tile_comp1) {
      %c1000 = arith.constant 0 : index
      %c1001 = arith.constant 1 : index
      %c9223372036854775807 = arith.constant 9223372036854775807 : index
      scf.for %arg0 = %c1000 to %c9223372036854775807 step %c1001 {
        %fifo0 = aie.objectfifo.acquire @in0_p1(Consume, 1) : !aie.objectfifosubview<memref<8x16xi32>>
        %local0 = aie.objectfifo.subview.access %fifo0[0] : !aie.objectfifosubview<memref<8x16xi32>> -> memref<8x16xi32>
        %fifo1 = aie.objectfifo.acquire @in1_p0(Consume, 1) : !aie.objectfifosubview<memref<16x16xi32>>
        %local1 = aie.objectfifo.subview.access %fifo1[0] : !aie.objectfifosubview<memref<16x16xi32>> -> memref<16x16xi32>
        %fifo_out = aie.objectfifo.acquire @out_p1(Produce, 1) : !aie.objectfifosubview<memref<8x16xi32>>
        %local_out = aie.objectfifo.subview.access %fifo_out[0] : !aie.objectfifosubview<memref<8x16xi32>> -> memref<8x16xi32>
      %c0_i32 = arith.constant 0 : i32
      %subview = memref.subview %local0[0, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
      %c0 = arith.constant 0 : index
      %c8 = arith.constant 8 : index
      %c1 = arith.constant 1 : index
      scf.for %arg3 = %c0 to %c8 step %c1 {
        %c0_4 = arith.constant 0 : index
        %c16 = arith.constant 16 : index
        %c1_5 = arith.constant 1 : index
        scf.for %arg4 = %c0_4 to %c16 step %c1_5 {
          memref.store %c0_i32, %tile_comp1_buf0[%arg3, %arg4] : memref<8x16xi32>
        }
      }
      %c0_0 = arith.constant 0 : index
      %c8_1 = arith.constant 8 : index
      %c1_2 = arith.constant 1 : index
      scf.for %arg3 = %c0_0 to %c8_1 step %c1_2 {
        %c0_4 = arith.constant 0 : index
        %c16 = arith.constant 16 : index
        %c1_5 = arith.constant 1 : index
        scf.for %arg4 = %c0_4 to %c16 step %c1_5 {
          %c0_6 = arith.constant 0 : index
          %c16_7 = arith.constant 16 : index
          %c1_8 = arith.constant 1 : index
          scf.for %arg5 = %c0_6 to %c16_7 step %c1_8 {
            %0 = memref.load %subview[%arg3, %arg5] : memref<8x16xi32, strided<[16, 1]>>
            %1 = memref.load %local1[%arg5, %arg4] : memref<16x16xi32>
            %2 = memref.load %tile_comp1_buf0[%arg3, %arg4] : memref<8x16xi32>
            %3 = arith.muli %0, %1 : i32
            %4 = arith.addi %2, %3 : i32
            memref.store %4, %tile_comp1_buf0[%arg3, %arg4] : memref<8x16xi32>
          }
        }
      }
      %subview_3 = memref.subview %local_out[0, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
      memref.copy %tile_comp1_buf0, %subview_3 : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1]>>
        aie.objectfifo.release @in0_p1(Consume, 1)
        aie.objectfifo.release @in1_p0(Consume, 1)
        aie.objectfifo.release @out_p1(Produce, 1)
      }
      aie.end
    }
    aiex.runtime_sequence(%arg0: memref<16x16xi32>, %arg1: memref<16x16xi32>, %arg2: memref<16x16xi32>) {
      aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][1, 1, 16, 16][0, 0, 16, 1]) {id = 1 : i64, issue_token = true, metadata = @in_sh0} : memref<16x16xi32>
      aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 0][1, 1, 32, 16][0, 0, 16, 1]) {id = 2 : i64, issue_token = true, metadata = @in_sh1} : memref<16x16xi32>
      aiex.npu.dma_memcpy_nd(0, 0, %arg2[0, 0, 0, 0][1, 1, 16, 16][0, 0, 16, 1]) {id = 0 : i64, metadata = @out_sh} : memref<16x16xi32>
      aiex.npu.dma_wait {symbol = @in_sh0}
      aiex.npu.dma_wait {symbol = @in_sh1}
      aiex.npu.dma_wait {symbol = @out_sh}
    }
  }
}
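The `objectfifo.link` offsets in the MLIR above can be read as element offsets into the flattened buffer. A Python sketch (an illustration only, not the AIE runtime) of what `link [@in_sh0] -> [@in0_p0, @in0_p1]([] [0, 128])` does: a 16x16 buffer of 256 i32 elements is split row-major into two 8x16 halves for the two compute tiles.

```python
ROWS, COLS = 16, 16
buf = list(range(ROWS * COLS))     # flattened 16x16 input from the shim tile

offsets = [0, 128]                 # destination offsets from the link op
half = (ROWS // 2) * COLS          # 8x16 = 128 elements per consumer

tiles = [buf[off:off + half] for off in offsets]

assert tiles[0] == buf[:128]       # rows 0-7 flow to %tile_comp0 via @in0_p0
assert tiles[1] == buf[128:]       # rows 8-15 flow to %tile_comp1 via @in0_p1
```

The output side mirrors this: `link [@out_p0, @out_p1] -> [@out_sh]([0, 128] [])` joins the two 8x16 result halves back into one 16x16 buffer at the same offsets.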

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [IR], [Builder], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage (It would be better to provide ~2 different test cases to test the robustness of your code)
  • Code is well-documented

@chhzh123 (Member) left a comment:

Mostly looks good to me. We need more comments in your implementation to make later development easier.

@@ -221,86 +237,153 @@ def codegen_aie_mlir(mod, orig_input_args, mapping):
code += format_str("%tile_shim = aie.tile(0, 0)")
for mid in range(mem_tile_size):
code += format_str(f"%tile_mem{mid} = aie.tile({mid}, 1)")
assert len(mapping) == 1, "Only support 1D mapping for now"
pe_size = mapping[0]
# assert len(mapping) == 1, "Only support 1D mapping for now"
chhzh123 (Member):
Can it support 2D now?

EthanMeng324 (Contributor, Author):
Not yet. I think I will support it in the next PR.

chhzh123 (Member):
Since you removed the mapping argument, you should also remove the assertion here

code += format_str(
f"aie.objectfifo @in_sh{i}(%tile_shim, {{%tile_mem{i}}}, 2 : i32) : !aie.objectfifo<{orig_in_type}>"
)
linkings = [False] * len(input_args)
chhzh123 (Member):
Add comment for linkings

chhzh123 (Member):
Based on your description, a clearer name might be dist_alloc, which reflects that when set to True the memory is distributed among compute tiles, and when False it is replicated to each tile.
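The suggested `dist_alloc` flag could behave roughly as follows (a hypothetical sketch based only on the review comment; the names and the helper are not from the actual implementation):

```python
def plan_input_fifo(arg_rows, num_tiles, dist_alloc):
    """Hypothetical sketch of the dist_alloc semantics suggested above.

    dist_alloc=True  -> the argument is split into per-tile row blocks
    dist_alloc=False -> every tile receives the full argument (replicated)
    """
    if dist_alloc:
        block = arg_rows // num_tiles
        return [(t * block, (t + 1) * block) for t in range(num_tiles)]
    return [(0, arg_rows)] * num_tiles

# A (16 rows) is distributed: each of 2 tiles gets an 8-row block
assert plan_input_fifo(16, 2, dist_alloc=True) == [(0, 8), (8, 16)]
# B (16 rows) is replicated: both tiles see all 16 rows
assert plan_input_fifo(16, 2, dist_alloc=False) == [(0, 16), (0, 16)]
```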

Comment on lines 357 to 358
code += format_str("%c1000 = arith.constant 0 : index")
code += format_str("%c1001 = arith.constant 1 : index")
chhzh123 (Member):

Why use 1000 and 1001?

EthanMeng324 (Contributor, Author):

Because there would be a naming conflict if %c0 and %c1 were also used in the actual computation, so I chose two large numbers.

chhzh123 (Member):
Give a better name then. Probably %global_c0 and %global_c1

@chhzh123 (Member) left a comment:

Just some naming issues.


@chhzh123 (Member) left a comment:
LGTM. Thx!

@chhzh123 chhzh123 merged commit 2f4c197 into cornell-zhang:main Feb 20, 2025
1 check passed