diff --git a/allo/backend/aie.py b/allo/backend/aie.py index 80d62244..423216b1 100644 --- a/allo/backend/aie.py +++ b/allo/backend/aie.py @@ -1,15 +1,31 @@ # Copyright Allo authors. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 # mlir-aie commit: 8329b6 -# pylint: disable=consider-using-with, bad-builtin +# pylint: disable=consider-using-with, bad-builtin, no-name-in-module, too-many-branches, pointless-string-statement import os import subprocess +import re import numpy as np +from .._mlir.ir import ( + IntegerAttr, + IntegerType, + DenseI64ArrayAttr, + Context, + RankedTensorType, + FunctionType, + TypeAttr, + Location, +) +from .._mlir.dialects import func as func_d +from .._mlir.passmanager import PassManager as mlir_pass_manager from .vitis import read_tensor_from_file -from ..ir.transform import find_func_in_module -from ..utils import get_func_inputs_outputs +from ..utils import ( + get_func_inputs_outputs, + get_dtype_and_shape_from_type, + get_element_type_from_str, +) from .utils import format_str, format_code from .vitis import ctype_map @@ -211,7 +227,29 @@ def codegen_host(input_args): return code -def codegen_aie_mlir(mod, orig_input_args, mapping): +def codegen_aie_mlir(mod, orig_input_args, func_arg_sizes, func_buf_dicts): + """ + Generates MLIR-AIE code with MLIR module and extra information + + Parameters + ---------- + mod: allo._mlir.ir.Module + The MLIR module built by allo. + + orig_input_args: List[Tuple[str, List[int]]] + The original types of the argument of the function. + Each element in the list stands for the type of each argument. e.g. ('i32', [16, 16]). + For current version, we assume all function in the module have the same input arguments. + + func_arg_sizes: List[List[List[int]]] + The actual size of each argument that each function will be using. + The first dim stands for each function, the second dim stands for each argument, the last dim stands for shape. + + func_buf_dicts: List[Dict[str, Tuple[str, List[int]]]] + The local buffer each function creates. + Each function in the list is a dictionary, where the key is the name of the buffer, and the value is the type of + the element. e.g. ('i32', [16, 16]). + """ input_args = orig_input_args.copy() code = format_str("module {", indent=0) mem_tile_size = 2 if len(input_args) > 2 else 1 @@ -221,86 +259,162 @@ def codegen_aie_mlir(mod, orig_input_args, mapping): code += format_str("%tile_shim = aie.tile(0, 0)") for mid in range(mem_tile_size): code += format_str(f"%tile_mem{mid} = aie.tile({mid}, 1)") - assert len(mapping) == 1, "Only support 1D mapping for now" - pe_size = mapping[0] + # TODO: maybe use name of the function to support 2D? + # number of function declaration except top + funcs = list(mod.body.operations)[:-1] + pe_size = len(funcs) + buf_name_dicts = [] for pid in range(pe_size): code += format_str(f"%tile_comp{pid} = aie.tile(0, {pid + 2})") + buf_dict = func_buf_dicts[pid] + buf_name_dict = {} + for i, name in enumerate(buf_dict.keys()): + tile_name = f"%tile_comp{pid}" + new_name = f"{tile_name}_buf{i}" + buf_name_dict[name] = new_name + ele_type, shape = buf_dict[name] + str_list = list(map(str, shape)) + str_list.append(ele_type) + buf_type = f"memref<{'x'.join(map(str, str_list))}>" + code += format_str(f"{new_name} = aie.buffer({tile_name}) : {buf_type}") + buf_name_dicts.append(buf_name_dict) # update module and args - mod_str = str(mod) - for i, (ele_type, shape) in enumerate(input_args): - orig_ele_type = f"memref<{'x'.join(map(str, shape))}x{ele_type}>" - shape = (shape[0] // pe_size, *shape[1:]) + for j, (ele_type, orig_shape) in enumerate(input_args): + orig_ele_type = f"memref<{'x'.join(map(str, orig_shape))}x{ele_type}>" + # TODO: need to deal with different sizes for different funcs + shape = func_arg_sizes[0][j] ele_type = f"memref<{'x'.join(map(str, shape))}x{ele_type}>" - input_args[i] = (ele_type, orig_ele_type, shape) - mod_str = mod_str.replace(orig_ele_type, ele_type) + input_args[j] = (ele_type, orig_ele_type, shape, orig_shape) + func_strs = list(map(str, funcs)) + # update buffers + for pid in range(pe_size): + func_str = func_strs[pid] + buf_name_dict = buf_name_dicts[pid] + # remove memref.alloc + pattern_alloc = re.compile(r"^.*memref\.alloc.*\n?", re.MULTILINE) + func_str = re.sub(pattern_alloc, "", func_str) + # replace new buffer name + pattern_boundary = r"(?" - ) + """ + dist_allocs define the allocation strategy for each argument. Currently, there are only two options: True or False. + + - If True: The memory tile divides the memory of the argument and distributes it among all compute tiles using + `aie.objectfifo.link`. This is only feasible if the total memory consumed by all compute tiles does not exceed + the original memory allocated to the argument. + Example: If there are 3 compute tiles, each consuming 1/3 of matrix A, this strategy can be applied. + + - If False: The memory tile assigns the entire memory of the argument to each compute tile. + """ + dist_allocs = [False] * len(input_args) + for i, (in_type, orig_in_type, shape, orig_shape) in enumerate(input_args[:-1]): + total_sizes = [0] * len(orig_shape) + for sizes in func_arg_sizes: + for dim in range(len(orig_shape)): + total_sizes[dim] += sizes[i][dim] + for dim, orig_len in enumerate(orig_shape): + if total_sizes[dim] <= orig_len: + dist_allocs[i] = True + break + if dist_allocs[i]: + # depth=2 means double buffer + code += format_str( + f"aie.objectfifo @in_sh{i}(%tile_shim, {{%tile_mem{i}}}, 2 : i32) : !aie.objectfifo<{orig_in_type}>" + ) + for pid in range(pe_size): + code += format_str( + f"aie.objectfifo @in{i}_p{pid}(%tile_mem{i}, {{%tile_comp{pid}}}, 2 : i32) : !aie.objectfifo<{in_type}>" + ) + in_mem_str = ", ".join([f"@in{i}_p{pid}" for pid in range(pe_size)]) + shape_prod = np.prod(shape) + in_mem_stride = list(range(0, shape_prod * pe_size, shape_prod)) + # (src_offsets, dst_offsets) + code += format_str( + f"aie.objectfifo.link [@in_sh{i}] -> [{in_mem_str}]([] {in_mem_stride})" + ) + else: + code += format_str( + f"aie.objectfifo @in_sh{i}(%tile_shim, {{%tile_mem{i}}}, 2 : i32) : !aie.objectfifo<{orig_in_type}>" + ) + in_tile_str = ", ".join([f"%tile_comp{pid}" for pid in range(pe_size)]) + code += format_str( + f"aie.objectfifo @in{i}_p0(%tile_mem{i}, {{{in_tile_str}}}, 2 : i32) : !aie.objectfifo<{in_type}>" + ) + code += format_str(f"aie.objectfifo.link [@in_sh{i}] -> [@in{i}_p0]([] [])") + out_id = len(input_args) - 1 + out_type, orig_out_type, out_shape, orig_out_shape = input_args[-1] + total_sizes = [0] * len(orig_out_shape) + for sizes in func_arg_sizes: + for dim in range(len(orig_out_shape)): + total_sizes[dim] += sizes[-1][dim] + for dim, orig_out_len in enumerate(orig_out_shape): + if total_sizes[dim] <= orig_out_len: + dist_allocs[-1] = True + break + if dist_allocs[-1]: + # output uses tile_mem0 for pid in range(pe_size): code += format_str( - f"aie.objectfifo @in{i}_p{pid}(%tile_mem{i}, {{%tile_comp{pid}}}, 2 : i32) : !aie.objectfifo<{in_type}>" + f"aie.objectfifo @out_p{pid}(%tile_comp{pid}, {{%tile_mem0}}, 2 : i32) : !aie.objectfifo<{out_type}>" ) - in_mem_str = ", ".join([f"@in{i}_p{pid}" for pid in range(pe_size)]) - shape_prod = np.prod(shape) - in_mem_stride = list(range(0, shape_prod * pe_size, shape_prod)) - # (src_offsets, dst_offsets) code += format_str( - f"aie.objectfifo.link [@in_sh{i}] -> [{in_mem_str}]([] {in_mem_stride})" + f"aie.objectfifo @out_sh(%tile_mem0, {{%tile_shim}}, 2 : i32) : !aie.objectfifo<{orig_out_type}>" ) - out_id = len(input_args) - 1 - out_type, orig_out_type, out_shape = input_args[-1] - # output uses tile_mem0 - for pid in range(pe_size): + out_mem_str = ", ".join([f"@out_p{pid}" for pid in range(pe_size)]) + shape_prod = np.prod(out_shape) + out_mem_stride = list(range(0, shape_prod * pe_size, shape_prod)) code += format_str( - f"aie.objectfifo @out_p{pid}(%tile_comp{pid}, {{%tile_mem0}}, 2 : i32) : !aie.objectfifo<{out_type}>" + f"aie.objectfifo.link [{out_mem_str}] -> [@out_sh]({out_mem_stride} [])" ) - code += format_str( - f"aie.objectfifo @out_sh(%tile_mem0, {{%tile_shim}}, 2 : i32) : !aie.objectfifo<{orig_out_type}>" - ) - out_mem_str = ", ".join([f"@out_p{pid}" for pid in range(pe_size)]) - shape_prod = np.prod(out_shape) - out_mem_stride = list(range(0, shape_prod * pe_size, shape_prod)) - code += format_str( - f"aie.objectfifo.link [{out_mem_str}] -> [@out_sh]({out_mem_stride} [])" - ) + else: + out_tile_str = ", ".join([f"%tile_comp{pid}" for pid in range(pe_size)]) + code += format_str( + f"aie.objectfifo @out_sh(%tile_mem0, {{%tile_shim}}, 2 : i32) : !aie.objectfifo<{orig_out_type}>" + ) + code += format_str( + f"aie.objectfifo @out_p0({{{out_tile_str}}}, {{%tile_mem0}}, 2 : i32) : !aie.objectfifo<{out_type}>" + ) + code += format_str("aie.objectfifo.link [@out_p0] -> [@out_sh]([] [])") # create core computation - for pid in range(pe_size): + for pid, func_str in enumerate(func_strs): code += format_str(f"%core_0_{pid + 2} = aie.core(%tile_comp{pid}) {{") with format_code(indent=6): - code += format_str("%c0 = arith.constant 0 : index") - code += format_str("%c1 = arith.constant 1 : index") + code += format_str("%global_c0 = arith.constant 0 : index") + code += format_str("%global_c1 = arith.constant 1 : index") code += format_str( "%c9223372036854775807 = arith.constant 9223372036854775807 : index" ) code += format_str( - "scf.for %arg0 = %c0 to %c9223372036854775807 step %c1 {" + "scf.for %arg0 = %global_c0 to %c9223372036854775807 step %global_c1 {" ) with format_code(indent=8): - for i, (in_type, _, shape) in enumerate(input_args[:-1]): + for i, (in_type, _, shape, _) in enumerate(input_args[:-1]): code += format_str( - f"%fifo{i} = aie.objectfifo.acquire @in{i}_p{pid}(Consume, 1) : !aie.objectfifosubview<{in_type}>" + f"%fifo{i} = aie.objectfifo.acquire @in{i}_p{pid if dist_allocs[i] else 0}(Consume, 1) : !aie.objectfifosubview<{in_type}>" ) code += format_str( f"%local{i} = aie.objectfifo.subview.access %fifo{i}[0] : !aie.objectfifosubview<{in_type}> -> {in_type}" ) - mod_str = mod_str.replace(f"%arg{i}", f"%local{i}") + func_str = func_str.replace(f"%arg{i}", f"%local{i}") code += format_str( f"%fifo_out = aie.objectfifo.acquire @out_p{pid}(Produce, 1) : !aie.objectfifosubview<{out_type}>" ) code += format_str( f"%local_out = aie.objectfifo.subview.access %fifo_out[0] : !aie.objectfifosubview<{out_type}> -> {out_type}" ) - mod_str = mod_str.replace(f"%arg{out_id}", "%local_out") + func_str = func_str.replace(f"%arg{out_id}", "%local_out") with format_code(indent=4): - for line in mod_str.splitlines()[2:-3]: + for line in func_str.splitlines()[1:-2]: code += format_str(line, strip=False) for i in range(len(input_args[:-1])): code += format_str( - f"aie.objectfifo.release @in{i}_p{pid}(Consume, 1)" + f"aie.objectfifo.release @in{i}_p{pid if dist_allocs[i] else 0}(Consume, 1)" ) code += format_str(f"aie.objectfifo.release @out_p{pid}(Produce, 1)") code += format_str("}") @@ -309,18 +423,18 @@ def codegen_aie_mlir(mod, orig_input_args, mapping): in_args = ", ".join( [ f"%arg{i}: {orig_in_type}" - for i, (_, orig_in_type, _) in enumerate(input_args[:-1]) + for i, (_, orig_in_type, _, _) in enumerate(input_args[:-1]) ] ) code += format_str( f"aiex.runtime_sequence({in_args}, %arg{out_id}: {orig_out_type}) {{" ) with format_code(indent=6): - for i, (_, orig_in_type, shape) in enumerate(input_args[:-1]): + for i, (_, orig_in_type, shape, _) in enumerate(input_args[:-1]): # (x, y, memref[offset][size][stride]) # issue_token: MM2S-false, S2MM-true if len(shape) == 1: - size_n_stride = f"[1, 1, 1, {shape[0] * pe_size}][0, 0, 0, 1]" + size_n_stride = f"[1, 1, 1, {shape[0] * (pe_size if dist_allocs[i] else 1)}][0, 0, 0, 1]" else: size_n_stride = ( f"[1, 1, {shape[0] * pe_size}, {shape[1]}][0, 0, {shape[1]}, 1]" @@ -344,21 +458,164 @@ def codegen_aie_mlir(mod, orig_input_args, mapping): return code +def reindex_tensor_access(mod): + ctx = mod.context + funcs = list(mod.body.operations)[:-1] + # func -> arg -> dim + func_arg_lower_bounds = [] + func_arg_sizes = [] + for pi, func in enumerate(funcs): + entry_block = func.regions[0].blocks[0] + args = entry_block.arguments + arg_types = args.types + # TODO: might need some specialization for scalar input arg + lower_bounds = [ + [float("inf") for _ in range(len(arg_type.shape))] for arg_type in arg_types + ] + sizes = [[0 for _ in range(len(arg_type.shape))] for arg_type in arg_types] + for block in func.regions[0].blocks: + for op in block.operations: + if op.operation.name in {"tensor.extract_slice", "tensor.insert_slice"}: + operand_idx = ( + 0 if op.operation.name == "tensor.extract_slice" else 1 + ) + if op.operands[operand_idx] not in args: + continue + index = list(args).index(op.operands[operand_idx]) + static_offsets = op.attributes["static_offsets"] + static_sizes = op.attributes["static_sizes"] + for i, (offset, size) in enumerate( + zip(static_offsets, static_sizes) + ): + lower_bounds[index][i] = min(lower_bounds[index][i], offset) + sizes[index][i] = max(sizes[index][i], size) + for i, lower_bound in enumerate(lower_bounds): + # Arguments never used with slice + if lower_bound[0] == float("inf"): + # If ever used, assume using entire tensor + if len(list(args[i].uses)) > 0: + lower_bounds[i] = [0] * len(lower_bound) + sizes[i] = args[i].type.shape + else: + lower_bounds[i] = [0] * len(lower_bound) + sizes[i] = [0] * len(lower_bound) + + func_arg_lower_bounds.append(lower_bounds) + func_arg_sizes.append(sizes) + + for pi, func in enumerate(funcs): + entry_block = func.regions[0].blocks[0] + args = entry_block.arguments + lower_bounds = func_arg_lower_bounds[pi] + sizes = func_arg_sizes[pi] + for block in func.regions[0].blocks: + for op in block.operations: + if op.operation.name == "tensor.extract_slice": + if op.operands[0] not in args: + continue + index = list(args).index(op.operands[0]) + static_offsets = op.attributes["static_offsets"] + new_offsets = [] + for i, offset in enumerate(static_offsets): + # TODO: need to support multi-dim mappings + # diff = pi * (op.operands[0].type.shape[0] // pe_size) + new_offset = offset - lower_bounds[index][i] + new_offset_attr = IntegerAttr.get( + IntegerType.get_signless(64, ctx), new_offset + ) + new_offsets.append(new_offset_attr) + op.attributes["static_offsets"] = DenseI64ArrayAttr.get( + new_offsets, ctx + ) + elif op.operation.name == "tensor.insert_slice": + if op.operands[1] not in args: + continue + index = list(args).index(op.operands[1]) + static_offsets = op.attributes["static_offsets"] + new_offsets = [] + for i, offset in enumerate(static_offsets): + # TODO: need to support multi-dim mappings + # diff = pi * (op.operands[1].type.shape[0] // pe_size) + new_offset = offset - lower_bounds[index][i] + new_offset_attr = IntegerAttr.get( + IntegerType.get_signless(64, ctx), new_offset + ) + new_offsets.append(new_offset_attr) + op.attributes["static_offsets"] = DenseI64ArrayAttr.get( + new_offsets, ctx + ) + return func_arg_lower_bounds, func_arg_sizes + + +def update_func_op_arg_types( + func_op: func_d.FuncOp, input_args, new_shapes, context: Context +): + old_func_type = func_op.function_type + old_result_types = old_func_type.value.results + new_input_types = [] + for (ele_type_str, _), shape in zip(input_args, new_shapes): + elem_ty = get_element_type_from_str(ele_type_str, context) + memref_ty = RankedTensorType.get(shape, elem_ty) + new_input_types.append(memref_ty) + new_func_type = FunctionType.get(new_input_types, old_result_types, context) + new_type = TypeAttr.get(new_func_type, context) + func_op.operation.attributes["function_type"] = new_type + entry_block = func_op.entry_block + for i, block_arg in enumerate(entry_block.arguments): + block_arg.set_type(new_input_types[i]) + + +def lower_tensor_to_memref(mod): + passes = [ + # "linalg-generalize-named-ops", + # "linalg-fuse-elementwise-ops", + "one-shot-bufferize{bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}", + "func.func(convert-linalg-to-affine-loops),lower-affine", + ] + pipeline = f'builtin.module({",".join(passes)})' + with mod.context: + mlir_pass_manager.parse(pipeline).run(mod.operation) + + +def record_local_buffer(mod): + func_buf_dicts = [] + funcs = list(mod.body.operations)[:-1] + for func in funcs: + buf_dict = {} + for block in func.regions[0].blocks: + for op in block.operations: + if op.operation.name == "memref.alloc": + name = op.result.get_name() + dtype, shape = get_dtype_and_shape_from_type(op.result.type) + buf_dict[name] = (dtype, shape) + func_buf_dicts.append(buf_dict) + return func_buf_dicts + + class AIEModule: - def __init__(self, module, top_func_name, project, mapping): + def __init__(self, module, top_func_name, project): self.module = module self.top_func_name = top_func_name - self.top_func = find_func_in_module(self.module, self.top_func_name) + # TODO: need to support multiple kernels + for op in module.body.operations: + if isinstance(op, func_d.FuncOp) and op.name.value != top_func_name: + self.kernel_func = op self.project = project self.module = module - self.mapping = mapping def build(self): assert "MLIR_AIE_INSTALL_DIR" in os.environ, "Please set MLIR_AIE_INSTALL_DIR" assert "PEANO_INSTALL_DIR" in os.environ, "Please set PEANO_INSTALL_DIR" - inputs, outputs = get_func_inputs_outputs(self.top_func) - input_args = inputs + outputs - code = codegen_aie_mlir(self.module, input_args, self.mapping) + self.inputs, self.outputs = get_func_inputs_outputs(self.kernel_func) + input_args = self.inputs + self.outputs + _, func_arg_sizes = reindex_tensor_access(self.module) + with self.module.context as ctx, Location.unknown(): + for i, func_op in enumerate(list(self.module.body.operations)[:-1]): + shapes = func_arg_sizes[i] + update_func_op_arg_types(func_op, input_args, shapes, ctx) + lower_tensor_to_memref(self.module) + func_buf_dicts = record_local_buffer(self.module) + code = codegen_aie_mlir(self.module, input_args, func_arg_sizes, func_buf_dicts) os.makedirs(os.path.join(self.project, "build"), exist_ok=True) with open(os.path.join(self.project, "top.mlir"), "w", encoding="utf-8") as f: f.write(code) @@ -393,8 +650,7 @@ def __call__(self, *args): process.wait() if process.returncode != 0: raise RuntimeError("Failed to execute AIE code.") - inputs, _ = get_func_inputs_outputs(self.top_func) result = read_tensor_from_file( - inputs[-1][0], args[-1].shape, f"{self.project}/output.data" + self.inputs[-1][0], args[-1].shape, f"{self.project}/output.data" ) args[-1][:] = result diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py index 89f1b767..54c9b721 100644 --- a/allo/backend/llvm.py +++ b/allo/backend/llvm.py @@ -1,6 +1,6 @@ # Copyright Allo authors. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-name-in-module, inconsistent-return-statements +# pylint: disable=no-name-in-module, inconsistent-return-statements, too-many-function-args import os import ctypes diff --git a/allo/dataflow.py b/allo/dataflow.py index 646bfac1..18ee695c 100644 --- a/allo/dataflow.py +++ b/allo/dataflow.py @@ -121,21 +121,22 @@ def remove_unused_func_ops(s, func_names): func_op.erase() -def _build_top(s, stream_info): +def _build_top(s, stream_info, enable_tensor): """ s: top-level schedule stream_info: {func_name: [(stream_names, direction)]} """ # remove unused kernel - passes = ["canonicalize"] - pipeline = f'builtin.module(func.func({",".join(passes)}))' - try: - with s.module.context: - mlir_pass_manager.parse(pipeline).run(s.module.operation) - except Exception as e: - print("Error: failed to run MLIR lower pipeline, printing module...") - print(s.module) - raise e + if not enable_tensor: + passes = ["canonicalize"] + pipeline = f'builtin.module(func.func({",".join(passes)}))' + try: + with s.module.context: + mlir_pass_manager.parse(pipeline).run(s.module.operation) + except Exception as e: + print("Error: failed to run MLIR lower pipeline, printing module...") + print(s.module) + raise e remove_unused_func_ops(s, stream_info.keys()) # create argument mapping @@ -237,11 +238,11 @@ def df_primitive_default(s): df_pipeline(s.module, rewind=True) -def customize(func, opt_default=True): +def customize(func, opt_default=True, enable_tensor=False): global_vars = get_global_vars(func) - s = _customize(func, global_vars=global_vars) + s = _customize(func, global_vars=global_vars, enable_tensor=enable_tensor) stream_info = move_stream_to_interface(s) - s = _build_top(s, stream_info) + s = _build_top(s, stream_info, enable_tensor) if opt_default: df_primitive_default(s) @@ -257,16 +258,18 @@ def build( configs=None, wrap_io=True, opt_default=True, + enable_tensor=False, ): if target == "aie": global_vars = get_global_vars(func) - s = _customize(func, global_vars=global_vars) - mapping = func.mapping - mod = AIEModule(s.module, s.top_func_name, project, mapping) + s = _customize(func, global_vars=global_vars, enable_tensor=True) + # stream_info = move_stream_to_interface(s) + # s = _build_top(s, stream_info, enable_tensor) + mod = AIEModule(s.module, s.top_func_name, project) mod.build() return mod # FPGA backend - s = customize(func, opt_default) + s = customize(func, opt_default, enable_tensor=enable_tensor) hls_mod = s.build( target=target, mode=mode, diff --git a/allo/dsl.py b/allo/dsl.py index 5fd6e5d0..0254244a 100644 --- a/allo/dsl.py +++ b/allo/dsl.py @@ -29,6 +29,10 @@ def sub(lhs, rhs, name=None): return lhs - rhs +def mul(lhs, rhs, name=None): + return lhs * rhs + + def div(lhs, rhs, name=None): return lhs / rhs diff --git a/allo/ir/builder.py b/allo/ir/builder.py index f2fffb14..438cbbff 100644 --- a/allo/ir/builder.py +++ b/allo/ir/builder.py @@ -980,15 +980,24 @@ def build_affine_expr(ctx, node): @staticmethod def build_slices(ctx, node, in_shape): # caculate the static offsets, sizes, strides for ExtractSlice and InsertSlice - slices = node.slice.dims + slices = node.slice.dims if len(node.shape) > 1 else [node.slice] static_offsets = [] static_sizes = [] static_strides = [] + offsets = [] + sizes = [] + # Not support dynamic strides? for index, size in zip(slices, in_shape): if isinstance(index, ast.Slice): - lower = 0 if index.lower is None else build_stmt(ctx, index.lower).val + lower = ( + 0 + if index.lower is None + else ASTResolver.resolve_constant(index.lower, ctx) + ) upper = ( - size if index.upper is None else build_stmt(ctx, index.upper).val + size + if index.upper is None + else ASTResolver.resolve_constant(index.upper, ctx) ) if index.step is None: step = 1 @@ -996,6 +1005,39 @@ def build_slices(ctx, node, in_shape): step = index.step.value else: raise RuntimeError("Unsupported step type") + if lower is None: + static_offsets.append(ShapedType.get_dynamic_size()) + offset_expr = build_stmt(ctx, index.lower) + offset = ASTTransformer.build_cast_op( + ctx, offset_expr, index.dtype, Index() + ).result + offsets.append(offset) + static_sizes.append(ShapedType.get_dynamic_size()) + if upper is None: + upper_expr = build_stmt(ctx, index.upper) + size_expr = tensor_d.FloorDivSOp( + tensor_d.SubOp(upper_expr, offset_expr).result, step + ) + else: + size_expr = tensor_d.FloorDivSOp( + tensor_d.SubOp(upper, offset_expr).result, step + ) + size = ASTTransformer.build_cast_op( + ctx, size_expr, index.dtype, Index() + ).result + sizes.append(size) + continue + if upper is None: + static_sizes.append(ShapedType.get_dynamic_size()) + upper_expr = build_stmt(ctx, index.upper) + size_expr = tensor_d.FloorDivSOp( + tensor_d.SubOp(upper_expr, lower).result, step + ) + size = ASTTransformer.build_cast_op( + ctx, size_expr, index.dtype, Index() + ).result + sizes.append(size) + continue elif isinstance(index, (ast.Index, ast.Constant)): lower = ( index.value.value if isinstance(index, ast.Index) else index.value @@ -1019,7 +1061,7 @@ def build_slices(ctx, node, in_shape): def build_tensor_access(ctx, node, val=None, idx=0): # TODO: Fix tuple idx value = build_stmt(ctx, node.value) - if len(node.shape) > 1: + if len(node.shape) >= 1: dtype = RankedTensorType(value.result.type).element_type in_shape = RankedTensorType(value.result.type).shape ( @@ -1933,6 +1975,7 @@ def build_Call(ctx, node): "log", "add", "sub", + "mul", "div", "relu", "conv2d", @@ -1944,6 +1987,23 @@ def build_Call(ctx, node): "view", "concat", }: + if fn_name in {"add", "sub", "mul", "div"}: + new_args[0] = ASTTransformer.build_broadcast_op( + ctx, + new_args[0], + node.dtype, + node.args[0].shape, + node.shape, + node.dims[0], + ) + new_args[1] = ASTTransformer.build_broadcast_op( + ctx, + new_args[1], + node.dtype, + node.args[1].shape, + node.shape, + node.dims[1], + ) return ASTTransformer.build_library_op( ctx, node=node, attr=fn_name, new_args=new_args ) diff --git a/allo/ir/infer.py b/allo/ir/infer.py index 4076d142..fdbecce1 100644 --- a/allo/ir/infer.py +++ b/allo/ir/infer.py @@ -892,16 +892,20 @@ def visit_library_op(ctx, node, op_name, new_args): "log", "add", "sub", + "mul", "div", "relu", "copy", }: # Element-wise operation - if op_name in {"add", "sub", "div"}: - assert ( - new_args[0].shape == new_args[1].shape - ), f"Only support element-wise {op_name} of two inputs with the same shape, got {new_args[0].shape} and {new_args[1].shape}" - node.shape = new_args[0].shape + if op_name in {"add", "sub", "mul", "div"}: + final_shape, lhs_dims, rhs_dims = TypeInferer.visit_broadcast( + ctx, new_args[0], new_args[1] + ) + node.dims = (lhs_dims, rhs_dims) + node.shape = final_shape + else: + node.shape = new_args[0].shape node.dtype = new_args[0].dtype return node if op_name in {"matmul", "bmm", "linear", "conv2d", "sumpool", "maxpool"}: diff --git a/allo/utils.py b/allo/utils.py index 91ad7f56..7717cba5 100644 --- a/allo/utils.py +++ b/allo/utils.py @@ -421,3 +421,13 @@ def extract_out_np_arrays_from_out_struct(out_struct_ptr_ptr, num_output): ranked_memref_to_numpy(getattr(out_struct_ptr_ptr[0][0], f"memref{i}")) ) return out_np_arrays + + +def get_element_type_from_str(element_type_str, context): + if element_type_str.startswith("f"): + bits = int(element_type_str[1:]) + return F32Type.get(context) if bits == 32 else F64Type.get(context) + if element_type_str.startswith("i"): + bits = int(element_type_str[1:]) + return IntegerType.get_signless(bits, context) + raise ValueError(f"unknown element_type_str: {element_type_str}") diff --git a/tests/dataflow/aie/test_gemm.py b/tests/dataflow/aie/test_gemm.py new file mode 100644 index 00000000..1d2b1f96 --- /dev/null +++ b/tests/dataflow/aie/test_gemm.py @@ -0,0 +1,35 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import allo.dataflow as df +from allo.ir.types import int32 +import numpy as np +import allo + + +def _test_gemm(): + Ty = int32 + M, N, K = 16, 16, 16 + P0 = 2 + Mt = M // P0 + + @df.region() + def top(): + @df.kernel(mapping=[P0]) + def gemm(A: Ty[M, K], B: Ty[K, N], C: Ty[M, N]): + pi = df.get_pid() + C[pi * Mt : (pi + 1) * Mt, :] = allo.matmul( + A[pi * Mt : (pi + 1) * Mt, :], B + ) + + mod = df.build(top, target="aie") + A = np.random.randint(0, 64, (M, K)).astype(np.int32) + B = np.random.randint(0, 64, (K, N)).astype(np.int32) + C = np.zeros((M, N)).astype(np.int32) + mod(A, B, C) + np.testing.assert_allclose(C, A @ B, atol=1e-5) + print("PASSED!") + + +if __name__ == "__main__": + _test_gemm() diff --git a/tests/dataflow/aie/test_matrix.py b/tests/dataflow/aie/test_matrix.py index bf810d63..f7b22b18 100644 --- a/tests/dataflow/aie/test_matrix.py +++ b/tests/dataflow/aie/test_matrix.py @@ -11,16 +11,19 @@ def _test_matrix_scalar_add(): Ty = int32 M, N = 64, 64 P0 = 4 + Mt = M // P0 - @df.kernel(mapping=[P0]) - def core(A: Ty[M, N], B: Ty[M, N]): - for i, j in allo.grid(M // P0, N): - B[i, j] = A[i, j] + 1 + @df.region() + def top(): + @df.kernel(mapping=[P0]) + def core(A: Ty[M, N], B: Ty[M, N]): + pi = df.get_pid() + B[pi * Mt : (pi + 1) * Mt, :] = allo.add(A[pi * Mt : (pi + 1) * Mt, :], 1) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.randint(0, 100, (M, N)).astype(np.int32) B = np.zeros((M, N)).astype(np.int32) - top(A, B) + mod(A, B) np.testing.assert_allclose(B, A + 1) print("PASSED!") @@ -29,17 +32,22 @@ def _test_matrix_matrix_add(): Ty = int32 M, N = 64, 64 P0 = 4 - - @df.kernel(mapping=[P0]) - def core(A: Ty[M, N], B: Ty[M, N], C: Ty[M, N]): - for i, j in allo.grid(M // P0, N): - C[i, j] = A[i, j] + B[i, j] - - top = df.build(core, target="aie") + Mt = M // P0 + + @df.region() + def top(): + @df.kernel(mapping=[P0]) + def core(A: Ty[M, N], B: Ty[M, N], C: Ty[M, N]): + pi = df.get_pid() + C[pi * Mt : (pi + 1) * Mt, :] = allo.add( + A[pi * Mt : (pi + 1) * Mt, :], B[pi * Mt : (pi + 1) * Mt, :] + ) + + mod = df.build(top, target="aie") A = np.random.randint(0, 100, (M, N)).astype(np.int32) B = np.random.randint(0, 100, (M, N)).astype(np.int32) C = np.zeros((M, N)).astype(np.int32) - top(A, B, C) + mod(A, B, C) np.testing.assert_allclose(C, A + B) print("PASSED!") diff --git a/tests/dataflow/aie/test_multi_core.py b/tests/dataflow/aie/test_multi_core.py index e7832bdb..8cb28a50 100644 --- a/tests/dataflow/aie/test_multi_core.py +++ b/tests/dataflow/aie/test_multi_core.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import allo -from allo.ir.types import int32, float32 +from allo.ir.types import int32 import allo.dataflow as df import numpy as np @@ -13,18 +13,21 @@ def _test_vector_scalar_add(): # v v-------------------------v v # shim tile <-> mem tile <-> comp tile0 comp tile1 comp tile2 Ty = int32 - M = 48 - P0 = 3 + M = 1024 + P0 = 4 + Mt = M // P0 - @df.kernel(mapping=[P0]) - def core(A: Ty[M], B: Ty[M]): - for i in range(M // P0): - B[i] = A[i] + 1 + @df.region() + def top(): + @df.kernel(mapping=[P0]) + def core(A: Ty[M], B: Ty[M]): + pi = df.get_pid() + B[pi * Mt : (pi + 1) * Mt] = allo.add(A[pi * Mt : (pi + 1) * Mt], 1) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.randint(0, 100, M).astype(np.int32) B = np.zeros(M).astype(np.int32) - top(A, B) + mod(A, B) np.testing.assert_allclose(B, A + 1) print("PASSED!") @@ -37,17 +40,22 @@ def _test_vector_vector_add(): Ty = int32 M = 1024 P0 = 4 + Mt = M // P0 - @df.kernel(mapping=[P0]) - def core(A: Ty[M], B: Ty[M], C: Ty[M]): - for i in range(M // P0): - C[i] = A[i] + B[i] + @df.region() + def top(): + @df.kernel(mapping=[P0]) + def core(A: Ty[M], B: Ty[M], C: Ty[M]): + pi = df.get_pid() + C[pi * Mt : (pi + 1) * Mt] = allo.add( + A[pi * Mt : (pi + 1) * Mt], B[pi * Mt : (pi + 1) * Mt] + ) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.randint(0, 100, M).astype(np.int32) B = np.random.randint(0, 100, M).astype(np.int32) C = np.zeros(M).astype(np.int32) - top(A, B, C) + mod(A, B, C) np.testing.assert_allclose(C, A + B) print("PASSED!") diff --git a/tests/dataflow/aie/test_vector.py b/tests/dataflow/aie/test_vector.py index 97927fd8..d088785a 100644 --- a/tests/dataflow/aie/test_vector.py +++ b/tests/dataflow/aie/test_vector.py @@ -12,15 +12,16 @@ def _test_vector_scalar_add(): Ty = int32 M = 1024 - @df.kernel(mapping=[1]) - def core(A: Ty[M], B: Ty[M]): - for i in range(M): - B[i] = A[i] + 1 + @df.region() + def top(): + @df.kernel(mapping=[1]) + def core(A: Ty[M], B: Ty[M]): + B[:] = allo.add(A, 1) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.randint(0, 100, M).astype(np.int32) B = np.zeros(M).astype(np.int32) - top(A, B) + mod(A, B) np.testing.assert_allclose(B, A + 1) print("PASSED!") @@ -30,15 +31,16 @@ def _test_vector_scalar_mul(): Ty = float32 M = 512 - @df.kernel(mapping=[1]) - def core(A: Ty[M], B: Ty[M]): - for i in range(M): - B[i] = A[i] * 2 + @df.region() + def top(): + @df.kernel(mapping=[1]) + def core(A: Ty[M], B: Ty[M]): + B[:] = allo.mul(A, 2) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.random(M).astype(np.float32) B = np.zeros(M).astype(np.float32) - top(A, B) + mod(A, B) np.testing.assert_allclose(B, A * 2, rtol=1e-5) print("PASSED!") @@ -48,16 +50,17 @@ def _test_vector_vector_add(): Ty = int32 M = 1024 - @df.kernel(mapping=[1]) - def core(A: Ty[M], B: Ty[M], C: Ty[M]): - for i in range(M): - C[i] = A[i] + B[i] + @df.region() + def top(): + @df.kernel(mapping=[1]) + def core(A: Ty[M], B: Ty[M], C: Ty[M]): + C[:] = allo.add(A, B) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.randint(0, 100, M).astype(np.int32) B = np.random.randint(0, 100, M).astype(np.int32) C = np.zeros(M).astype(np.int32) - top(A, B, C) + mod(A, B, C) np.testing.assert_allclose(C, A + B) print("PASSED!") @@ -67,16 +70,17 @@ def _test_vector_vector_mul(): Ty = float32 M = 1024 - @df.kernel(mapping=[1]) - def core(A: Ty[M], B: Ty[M], C: Ty[M]): - for i in range(M): - C[i] = A[i] * B[i] + @df.region() + def top(): + @df.kernel(mapping=[1]) + def core(A: Ty[M], B: Ty[M], C: Ty[M]): + C[:] = allo.mul(A, B) - top = df.build(core, target="aie") + mod = df.build(top, target="aie") A = np.random.random(M).astype(np.float32) B = np.random.random(M).astype(np.float32) C = np.zeros(M).astype(np.float32) - top(A, B, C) + mod(A, B, C) np.testing.assert_allclose(C, A * B, rtol=1e-5) print("PASSED!")