[Backend][AIE] Support global indexing, tensor access and kernel reindex for AIE kernel mapping #300

Merged 24 commits on Feb 20, 2025
372 changes: 314 additions & 58 deletions allo/backend/aie.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion allo/backend/llvm.py
@@ -1,6 +1,6 @@
# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-name-in-module, inconsistent-return-statements
# pylint: disable=no-name-in-module, inconsistent-return-statements, too-many-function-args

import os
import ctypes
37 changes: 20 additions & 17 deletions allo/dataflow.py
@@ -121,21 +121,22 @@ def remove_unused_func_ops(s, func_names):
func_op.erase()


def _build_top(s, stream_info):
def _build_top(s, stream_info, enable_tensor):
"""
s: top-level schedule
stream_info: {func_name: [(stream_names, direction)]}
"""
# remove unused kernel
passes = ["canonicalize"]
pipeline = f'builtin.module(func.func({",".join(passes)}))'
try:
with s.module.context:
mlir_pass_manager.parse(pipeline).run(s.module.operation)
except Exception as e:
print("Error: failed to run MLIR lower pipeline, printing module...")
print(s.module)
raise e
if not enable_tensor:
passes = ["canonicalize"]
pipeline = f'builtin.module(func.func({",".join(passes)}))'
try:
with s.module.context:
mlir_pass_manager.parse(pipeline).run(s.module.operation)
except Exception as e:
print("Error: failed to run MLIR lower pipeline, printing module...")
print(s.module)
raise e
remove_unused_func_ops(s, stream_info.keys())

# create argument mapping
@@ -237,11 +238,11 @@ def df_primitive_default(s):
df_pipeline(s.module, rewind=True)


def customize(func, opt_default=True):
def customize(func, opt_default=True, enable_tensor=False):
global_vars = get_global_vars(func)
s = _customize(func, global_vars=global_vars)
s = _customize(func, global_vars=global_vars, enable_tensor=enable_tensor)
stream_info = move_stream_to_interface(s)
s = _build_top(s, stream_info)
s = _build_top(s, stream_info, enable_tensor)

if opt_default:
df_primitive_default(s)
@@ -257,16 +258,18 @@ def build(
configs=None,
wrap_io=True,
opt_default=True,
enable_tensor=False,
):
if target == "aie":
global_vars = get_global_vars(func)
s = _customize(func, global_vars=global_vars)
mapping = func.mapping
mod = AIEModule(s.module, s.top_func_name, project, mapping)
s = _customize(func, global_vars=global_vars, enable_tensor=True)
# stream_info = move_stream_to_interface(s)
# s = _build_top(s, stream_info, enable_tensor)
mod = AIEModule(s.module, s.top_func_name, project)
mod.build()
return mod
# FPGA backend
s = customize(func, opt_default)
s = customize(func, opt_default, enable_tensor=enable_tensor)
hls_mod = s.build(
target=target,
mode=mode,
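For reference, a minimal usage sketch of the new `enable_tensor` flag (the region and kernel below are illustrative placeholders; only the `customize` call reflects the API added in this diff):

```python
import allo
import allo.dataflow as df
from allo.ir.types import int32

@df.region()
def top():
    @df.kernel(mapping=[1])
    def core(A: int32[16], B: int32[16]):
        B[:] = allo.add(A[:], 1)

# enable_tensor=True skips the early canonicalize pass in _build_top so that
# tensor-level ops (e.g. extract_slice / insert_slice) survive in the module;
# the AIE branch of build() now turns this on unconditionally.
s = df.customize(top, enable_tensor=True)
print(s.module)
```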
4 changes: 4 additions & 0 deletions allo/dsl.py
@@ -29,6 +29,10 @@ def sub(lhs, rhs, name=None):
return lhs - rhs


def mul(lhs, rhs, name=None):
return lhs * rhs


def div(lhs, rhs, name=None):
return lhs / rhs

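The new `mul` mirrors the existing `add`/`sub`/`div` element-wise helpers. A minimal sketch of invoking it through the standard customize/build flow (the shapes and the CPU backend here are illustrative, not part of this diff):

```python
import allo
from allo.ir.types import float32
import numpy as np

M, N = 8, 8

def kernel(A: float32[M, N], B: float32[M, N]) -> float32[M, N]:
    # Element-wise product; with the infer.py change below, the operands may
    # also have broadcast-compatible rather than identical shapes.
    return allo.mul(A, B)

s = allo.customize(kernel)
mod = s.build()  # default LLVM CPU backend
A = np.random.rand(M, N).astype(np.float32)
B = np.random.rand(M, N).astype(np.float32)
np.testing.assert_allclose(mod(A, B), A * B, rtol=1e-5)
```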
68 changes: 64 additions & 4 deletions allo/ir/builder.py
@@ -980,22 +980,64 @@ def build_affine_expr(ctx, node):
@staticmethod
def build_slices(ctx, node, in_shape):
# calculate the static offsets, sizes, strides for ExtractSlice and InsertSlice
slices = node.slice.dims
slices = node.slice.dims if len(node.shape) > 1 else [node.slice]
static_offsets = []
static_sizes = []
static_strides = []
offsets = []
sizes = []
# Dynamic strides are not supported?
for index, size in zip(slices, in_shape):
if isinstance(index, ast.Slice):
lower = 0 if index.lower is None else build_stmt(ctx, index.lower).val
lower = (
0
if index.lower is None
else ASTResolver.resolve_constant(index.lower, ctx)
)
upper = (
size if index.upper is None else build_stmt(ctx, index.upper).val
size
if index.upper is None
else ASTResolver.resolve_constant(index.upper, ctx)
)
if index.step is None:
step = 1
elif isinstance(index.step, ast.Constant):
step = index.step.value
else:
raise RuntimeError("Unsupported step type")
if lower is None:
static_offsets.append(ShapedType.get_dynamic_size())
offset_expr = build_stmt(ctx, index.lower)
offset = ASTTransformer.build_cast_op(
ctx, offset_expr, index.dtype, Index()
).result
offsets.append(offset)
static_sizes.append(ShapedType.get_dynamic_size())
if upper is None:
upper_expr = build_stmt(ctx, index.upper)
size_expr = tensor_d.FloorDivSOp(
tensor_d.SubOp(upper_expr, offset_expr).result, step
)
else:
size_expr = tensor_d.FloorDivSOp(
tensor_d.SubOp(upper, offset_expr).result, step
)
size = ASTTransformer.build_cast_op(
ctx, size_expr, index.dtype, Index()
).result
sizes.append(size)
continue
if upper is None:
static_sizes.append(ShapedType.get_dynamic_size())
upper_expr = build_stmt(ctx, index.upper)
size_expr = tensor_d.FloorDivSOp(
tensor_d.SubOp(upper_expr, lower).result, step
)
size = ASTTransformer.build_cast_op(
ctx, size_expr, index.dtype, Index()
).result
sizes.append(size)
continue
elif isinstance(index, (ast.Index, ast.Constant)):
lower = (
index.value.value if isinstance(index, ast.Index) else index.value
@@ -1019,7 +1061,7 @@ def build_slices(ctx, node, in_shape):
def build_tensor_access(ctx, node, val=None, idx=0):
# TODO: Fix tuple idx
value = build_stmt(ctx, node.value)
if len(node.shape) > 1:
if len(node.shape) >= 1:
dtype = RankedTensorType(value.result.type).element_type
in_shape = RankedTensorType(value.result.type).shape
(
@@ -1933,6 +1975,7 @@ def build_Call(ctx, node):
"log",
"add",
"sub",
"mul",
"div",
"relu",
"conv2d",
@@ -1944,6 +1987,23 @@ def build_Call(ctx, node):
"view",
"concat",
}:
if fn_name in {"add", "sub", "mul", "div"}:
new_args[0] = ASTTransformer.build_broadcast_op(
ctx,
new_args[0],
node.dtype,
node.args[0].shape,
node.shape,
node.dims[0],
)
new_args[1] = ASTTransformer.build_broadcast_op(
ctx,
new_args[1],
node.dtype,
node.args[1].shape,
node.shape,
node.dims[1],
)
return ASTTransformer.build_library_op(
ctx, node=node, attr=fn_name, new_args=new_args
)
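For one dimension of a slice `A[lo:hi:step]`, `build_slices` records `offset = lo`, `stride = step`, and `size = (hi - lo) // step`; when `lo` or `hi` is not a compile-time constant, the corresponding static entry becomes `ShapedType.get_dynamic_size()` and the value is emitted as an SSA operand instead. A plain-Python worked example of that bookkeeping (illustration only; the real code emits MLIR ops):

```python
def slice_triple(lo, hi, step, dim_size):
    # Mirrors the bounds handling in build_slices: missing bounds default to
    # the full dimension and a unit step.
    lo = 0 if lo is None else lo
    hi = dim_size if hi is None else hi
    step = 1 if step is None else step
    return lo, (hi - lo) // step, step  # (offset, size, stride)

print(slice_triple(4, 12, 2, 16))          # (4, 4, 2)
print(slice_triple(None, None, None, 16))  # (0, 16, 1) -- the whole dimension
```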
14 changes: 9 additions & 5 deletions allo/ir/infer.py
@@ -892,16 +892,20 @@ def visit_library_op(ctx, node, op_name, new_args):
"log",
"add",
"sub",
"mul",
"div",
"relu",
"copy",
}:
# Element-wise operation
if op_name in {"add", "sub", "div"}:
assert (
new_args[0].shape == new_args[1].shape
), f"Only support element-wise {op_name} of two inputs with the same shape, got {new_args[0].shape} and {new_args[1].shape}"
node.shape = new_args[0].shape
if op_name in {"add", "sub", "mul", "div"}:
final_shape, lhs_dims, rhs_dims = TypeInferer.visit_broadcast(
ctx, new_args[0], new_args[1]
)
node.dims = (lhs_dims, rhs_dims)
node.shape = final_shape
else:
node.shape = new_args[0].shape
node.dtype = new_args[0].dtype
return node
if op_name in {"matmul", "bmm", "linear", "conv2d", "sumpool", "maxpool"}:
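With `visit_broadcast`, the element-wise ops no longer require identical operand shapes, only broadcast-compatible ones. The shape rule presumably follows NumPy-style broadcasting (shown below with a NumPy analogue, not Allo code):

```python
import numpy as np

A = np.ones((16, 16), dtype=np.int32)
b = np.arange(16, dtype=np.int32)  # shape (16,)

# Trailing dimensions are matched and b's missing leading dimension is
# stretched, so the result keeps A's shape.
print(np.multiply(A, b).shape)        # (16, 16)
print(np.add(A, np.int32(3)).shape)   # (16, 16): a scalar broadcasts everywhere
```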
10 changes: 10 additions & 0 deletions allo/utils.py
@@ -421,3 +421,13 @@ def extract_out_np_arrays_from_out_struct(out_struct_ptr_ptr, num_output):
ranked_memref_to_numpy(getattr(out_struct_ptr_ptr[0][0], f"memref{i}"))
)
return out_np_arrays


def get_element_type_from_str(element_type_str, context):
if element_type_str.startswith("f"):
bits = int(element_type_str[1:])
return F32Type.get(context) if bits == 32 else F64Type.get(context)
if element_type_str.startswith("i"):
bits = int(element_type_str[1:])
return IntegerType.get_signless(bits, context)
raise ValueError(f"unknown element_type_str: {element_type_str}")
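
A quick usage sketch of the new helper; the `Context` import path below is an assumption (use whatever MLIR Python bindings the project vendors), while the outputs follow directly from the code above:

```python
from allo._mlir.ir import Context  # assumed import path for the MLIR bindings
from allo.utils import get_element_type_from_str

ctx = Context()
print(get_element_type_from_str("i32", ctx))  # signless i32
print(get_element_type_from_str("i8", ctx))   # signless i8
print(get_element_type_from_str("f32", ctx))  # f32
print(get_element_type_from_str("f16", ctx))  # falls back to f64: only a 32-bit "f" width maps to f32
```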
35 changes: 35 additions & 0 deletions tests/dataflow/aie/test_gemm.py
@@ -0,0 +1,35 @@
# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import allo.dataflow as df
from allo.ir.types import int32
import numpy as np
import allo


def _test_gemm():
Ty = int32
M, N, K = 16, 16, 16
P0 = 2
Mt = M // P0

@df.region()
def top():
@df.kernel(mapping=[P0])
def gemm(A: Ty[M, K], B: Ty[K, N], C: Ty[M, N]):
pi = df.get_pid()
C[pi * Mt : (pi + 1) * Mt, :] = allo.matmul(
A[pi * Mt : (pi + 1) * Mt, :], B
)

mod = df.build(top, target="aie")
A = np.random.randint(0, 64, (M, K)).astype(np.int32)
B = np.random.randint(0, 64, (K, N)).astype(np.int32)
C = np.zeros((M, N)).astype(np.int32)
mod(A, B, C)
np.testing.assert_allclose(C, A @ B, atol=1e-5)
print("PASSED!")


if __name__ == "__main__":
_test_gemm()
36 changes: 22 additions & 14 deletions tests/dataflow/aie/test_matrix.py
@@ -11,16 +11,19 @@ def _test_matrix_scalar_add():
Ty = int32
M, N = 64, 64
P0 = 4
Mt = M // P0

@df.kernel(mapping=[P0])
def core(A: Ty[M, N], B: Ty[M, N]):
for i, j in allo.grid(M // P0, N):
B[i, j] = A[i, j] + 1
@df.region()
def top():
@df.kernel(mapping=[P0])
def core(A: Ty[M, N], B: Ty[M, N]):
pi = df.get_pid()
B[pi * Mt : (pi + 1) * Mt, :] = allo.add(A[pi * Mt : (pi + 1) * Mt, :], 1)

top = df.build(core, target="aie")
mod = df.build(top, target="aie")
A = np.random.randint(0, 100, (M, N)).astype(np.int32)
B = np.zeros((M, N)).astype(np.int32)
top(A, B)
mod(A, B)
np.testing.assert_allclose(B, A + 1)
print("PASSED!")

@@ -29,17 +32,22 @@ def _test_matrix_matrix_add():
Ty = int32
M, N = 64, 64
P0 = 4

@df.kernel(mapping=[P0])
def core(A: Ty[M, N], B: Ty[M, N], C: Ty[M, N]):
for i, j in allo.grid(M // P0, N):
C[i, j] = A[i, j] + B[i, j]

top = df.build(core, target="aie")
Mt = M // P0

@df.region()
def top():
@df.kernel(mapping=[P0])
def core(A: Ty[M, N], B: Ty[M, N], C: Ty[M, N]):
pi = df.get_pid()
C[pi * Mt : (pi + 1) * Mt, :] = allo.add(
A[pi * Mt : (pi + 1) * Mt, :], B[pi * Mt : (pi + 1) * Mt, :]
)

mod = df.build(top, target="aie")
A = np.random.randint(0, 100, (M, N)).astype(np.int32)
B = np.random.randint(0, 100, (M, N)).astype(np.int32)
C = np.zeros((M, N)).astype(np.int32)
top(A, B, C)
mod(A, B, C)
np.testing.assert_allclose(C, A + B)
print("PASSED!")

38 changes: 23 additions & 15 deletions tests/dataflow/aie/test_multi_core.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import allo
from allo.ir.types import int32, float32
from allo.ir.types import int32
import allo.dataflow as df
import numpy as np

@@ -13,18 +13,21 @@ def _test_vector_scalar_add():
# v v-------------------------v v
# shim tile <-> mem tile <-> comp tile0 comp tile1 comp tile2
Ty = int32
M = 48
P0 = 3
M = 1024
P0 = 4
Mt = M // P0

@df.kernel(mapping=[P0])
def core(A: Ty[M], B: Ty[M]):
for i in range(M // P0):
B[i] = A[i] + 1
@df.region()
def top():
@df.kernel(mapping=[P0])
def core(A: Ty[M], B: Ty[M]):
pi = df.get_pid()
B[pi * Mt : (pi + 1) * Mt] = allo.add(A[pi * Mt : (pi + 1) * Mt], 1)

top = df.build(core, target="aie")
mod = df.build(top, target="aie")
A = np.random.randint(0, 100, M).astype(np.int32)
B = np.zeros(M).astype(np.int32)
top(A, B)
mod(A, B)
np.testing.assert_allclose(B, A + 1)
print("PASSED!")

@@ -37,17 +40,22 @@ def _test_vector_vector_add():
Ty = int32
M = 1024
P0 = 4
Mt = M // P0

@df.kernel(mapping=[P0])
def core(A: Ty[M], B: Ty[M], C: Ty[M]):
for i in range(M // P0):
C[i] = A[i] + B[i]
@df.region()
def top():
@df.kernel(mapping=[P0])
def core(A: Ty[M], B: Ty[M], C: Ty[M]):
pi = df.get_pid()
C[pi * Mt : (pi + 1) * Mt] = allo.add(
A[pi * Mt : (pi + 1) * Mt], B[pi * Mt : (pi + 1) * Mt]
)

top = df.build(core, target="aie")
mod = df.build(top, target="aie")
A = np.random.randint(0, 100, M).astype(np.int32)
B = np.random.randint(0, 100, M).astype(np.int32)
C = np.zeros(M).astype(np.int32)
top(A, B, C)
mod(A, B, C)
np.testing.assert_allclose(C, A + B)
print("PASSED!")
