[dataflow] [Pass] Enable multi-cache design for larger-scale gemm #280

Merged 6 commits on Dec 18, 2024
Changes from 5 commits
15 changes: 9 additions & 6 deletions allo/backend/hls.py
@@ -168,6 +168,7 @@ def __init__(
         ext_libs=None,
         configs=None,
         func_args=None,
+        wrapping=True,
     ):
         self.top_func_name = top_func_name
         self.mode = mode
@@ -186,12 +187,14 @@ def __init__(
         self.func = find_func_in_module(self.module, top_func_name)
         if platform == "vitis_hls":
             assert func_args is not None, "Need to specify func_args"
-            generate_input_output_buffers(
-                self.module,
-                top_func_name,
-                flatten=True,
-                mappings=configs.get("mappings", None),
-            )
+
+            if wrapping:
+                generate_input_output_buffers(
+                    self.module,
+                    top_func_name,
+                    flatten=True,
+                    mappings=configs.get("mappings", None),
+                )
 
         # TODO: Fix dataflow!
         # if "dataflow" in self.func.attributes:
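Note: with wrapping=False the backend skips generate_input_output_buffers entirely, so no wrapper I/O buffers are generated and the top-level arguments are not flattened. The multi-cache GEMM test added below builds this way, since its host arrays are already hand-packed into wide UInt words.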
3 changes: 2 additions & 1 deletion allo/customize.py
@@ -869,7 +869,7 @@ def get_equivalent_variables(self, name):
                 return ele
         return []
 
-    def build(self, target=None, mode=None, project=None, configs=None):
+    def build(self, target=None, mode=None, project=None, configs=None, wrapping=True):
         if target is None or target == "llvm":
             target = "llvm"
             return LLVMModule(
@@ -896,6 +896,7 @@ def build(self, target=None, mode=None, project=None, configs=None):
             ext_libs=self.ext_libs,
             configs=configs,
             func_args=self.func_args,
+            wrapping=wrapping,
         )
         raise NotImplementedError(f"Target {target} is not supported")

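As a usage sketch (the `scale` kernel and project name here are made up, not part of this PR), a host that already hands packed, flattened buffers to the accelerator can now opt out of the wrapping pass:

```python
import allo
from allo.ir.types import int32

def scale(A: int32[64]) -> int32[64]:
    B: int32[64] = 0
    for i in range(64):
        B[i] = A[i] * 2
    return B

s = allo.customize(scale)
# Skip the I/O buffer-wrapping pass when the host supplies
# already-packed, flattened buffers.
mod = s.build(target="vitis_hls", mode="csim", project="scale.prj", wrapping=False)
```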
13 changes: 12 additions & 1 deletion allo/dataflow.py
@@ -18,6 +18,7 @@
 from .ir.utils import get_global_vars, get_all_funcs_except_top
 from .backend.aie import AIEModule
 from .ir.types import Stream
+from .passes import df_pipeline
 
 
 def get_pid():
@@ -235,10 +236,19 @@ def customize(func):
     s = _customize(func, global_vars=global_vars)
     stream_info = move_stream_to_interface(s)
     s = _build_top(s, stream_info)
+
+    df_pipeline(s.module, rewind=True)
     return s
 
 
-def build(func, target="vitis_hls", mode="csim", project="top.prj", configs=None):
+def build(
+    func,
+    target="vitis_hls",
+    mode="csim",
+    project="top.prj",
+    configs=None,
+    wrapping=True,
+):
     if target == "aie":
         global_vars = get_global_vars(func)
         s = _customize(func, global_vars=global_vars)
@@ -253,5 +263,6 @@ def build(func, target="vitis_hls", mode="csim", project="top.prj", configs=None
         mode=mode,
         project=project,
         configs=configs,
+        wrapping=wrapping,
     )
     return hls_mod
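Taken together: df.customize() now tags every innermost loop for pipelining (II = 1, with rewind) via the new df_pipeline pass before the HLS build, and df.build() threads the wrapping flag through to HLSModule.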
29 changes: 29 additions & 0 deletions allo/passes.py
@@ -11,6 +11,7 @@
     UnrankedMemRefType,
     FunctionType,
     TypeAttr,
+    UnitAttr,
     FlatSymbolRefAttr,
     ArrayAttr,
     Attribute,
@@ -26,6 +27,7 @@
     func as func_d,
     affine as affine_d,
     memref as memref_d,
+    scf as scf_d,
     linalg as linalg_d,
     arith as arith_d,
 )
@@ -705,3 +707,30 @@ def add_use(val, val_name):
     # recover final sets
     res = recover_sets()
     return res
+
+
+def df_pipeline(module, initiation_interval=1, rewind=False):
+    def pipe_loop_innermost(forop, ii, rewind):
+        # collect the immediate child loops of this loop body
+        inner_forops = []
+        for op in forop.body.operations:
+            if isinstance(op, (scf_d.ForOp, affine_d.AffineForOp)):
+                inner_forops.append(op)
+        if inner_forops:
+            # recurse until the innermost loops are reached
+            for inner_forop in inner_forops:
+                pipe_loop_innermost(inner_forop, ii, rewind)
+        else:
+            # innermost loop: attach the pipeline attributes
+            forop.attributes["pipeline_ii"] = ii
+            if rewind:
+                forop.attributes["rewind"] = UnitAttr.get()
+
+    with module.context:
+        i32 = IntegerType.get_unsigned(32)
+        ii = IntegerAttr.get(i32, initiation_interval)
+        for op in module.body.operations:
+            if isinstance(op, func_d.FuncOp):
+                func = op
+                for op_ in func.entry_block.operations:
+                    if isinstance(op_, (scf_d.ForOp, affine_d.AffineForOp)):
+                        pipe_loop_innermost(op_, ii, rewind)
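A minimal standalone sketch of the new pass (the `vadd` kernel is illustrative only): df_pipeline visits every function in the module, recurses through nested scf/affine loops, and attaches pipeline_ii (plus the optional rewind unit attribute) to each innermost loop, which the HLS backend is then expected to lower to the matching pipeline pragma.

```python
import allo
from allo.ir.types import int32
from allo.passes import df_pipeline

def vadd(A: int32[32], B: int32[32]) -> int32[32]:
    C: int32[32] = 0
    for i in range(32):
        C[i] = A[i] + B[i]
    return C

s = allo.customize(vadd)
# Tag the innermost loop with II = 2 and rewind; downstream HLS
# codegen is expected to honor these attributes.
df_pipeline(s.module, initiation_interval=2, rewind=True)
print(s.module)
```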
9 changes: 1 addition & 8 deletions tests/dataflow/test_daisy_chain_gemm.py
@@ -25,7 +25,7 @@ def top():
     @df.kernel(mapping=[P0, P1])
     def gemm(A: int16[M, K], B: int16[K, N], C: int16[M, N]):
         i, j = df.get_pid()
-        # periperals kernels
+        # peripheral kernels
         with allo.meta_if(i == 0 and j == 0):
             for k in range(K):
                 # pack data A
@@ -56,9 +56,6 @@ def gemm(A: int16[M, K], B: int16[K, N], C: int16[M, N]):
                     fifo_A[i - 1, 0].put(a[16 * (i - 1) : 16 * i])
                     with allo.meta_if(i < M):
                         L2_A[i + 1].put(a)
-                # TODO: Fix meta matching
-                with allo.meta_else():
-                    pass
 
         with allo.meta_elif(i == 0):
             # j > 0, the first row
@@ -67,17 +64,13 @@ def gemm(A: int16[M, K], B: int16[K, N], C: int16[M, N]):
                     fifo_B[0, j - 1].put(b[16 * (j - 1) : 16 * j])
                     with allo.meta_if(j < N):
                         L2_B[j + 1].put(b)
-                with allo.meta_else():
-                    pass
 
         with allo.meta_elif(i == P0 - 1):
             c_C = L1_C[i - 2, N - j].get()
             L2_C[j - 1].put(c_C)
             with allo.meta_if(j != 1):
                 for ind in range(j - 1):
                     L2_C[j - 1].put(L2_C[j - 2].get())
-            with allo.meta_else():
-                pass
 
         with allo.meta_elif(j == P1 - 1):
             pass
208 changes: 208 additions & 0 deletions tests/dataflow/test_multi_cache_gemm.py
@@ -0,0 +1,208 @@
# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import allo
from allo.ir.types import int8, int16, UInt
from allo.utils import get_np_struct_type
import allo.dataflow as df
import allo.backend.hls as hls
import allo.dsl as dsl
import numpy as np

M, N, K = 128, 128, 128
Rt, Ct = 8, 8

# M, N, K = 16, 16, 16
# Rt, Ct = 4, 4

# M, N, K = 4, 4, 4
# Rt, Ct = 2, 2
P0, P1 = Rt + 2, Ct + 2
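# The mesh is (Rt + 2) x (Ct + 2): an Rt x Ct compute tile bordered by
# loader kernels (row i == 0 and column j == 0), a C-drain row
# (i == P0 - 1), and an idle column (j == P1 - 1).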


@df.region()
def top():

    L3_A = df.pipe(dtype=UInt(Rt * 8), shape=(), depth=4)
    L3_B = df.pipe(dtype=UInt(Ct * 8), shape=(), depth=4)
    L3_C = df.pipe(dtype=UInt(Rt * 8), shape=(), depth=4)

    L2_A = df.array(df.pipe(dtype=UInt(Rt * 8), shape=(), depth=4), shape=(P0 - 1,))
    L2_B = df.array(df.pipe(dtype=UInt(Ct * 8), shape=(), depth=4), shape=(P1 - 1,))

    L1_C = df.array(df.pipe(dtype=UInt(Rt * 8), shape=(), depth=4), shape=(Rt, Ct))
    L2_C = df.array(df.pipe(dtype=UInt(Rt * 8), shape=(), depth=4), shape=(Ct,))

    fifo_A = df.array(df.pipe(dtype=int8, shape=(), depth=4), shape=(Rt, Ct))
    fifo_B = df.array(df.pipe(dtype=int8, shape=(), depth=4), shape=(Rt, Ct))
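    # Stream roles, as used by the kernels below:
    #   L3_*   : single pipes to/from the off-chip load/store kernels
    #   L2_A/B : daisy chains carrying packed words down the first column
    #            and across the first row
    #   fifo_* : PE-to-PE operand forwarding inside the Rt x Ct tile
    #   L1_C   : packed partial sums flowing down each column of PEs
    #   L2_C   : drain chain across the bottom row toward L3_C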

    @df.kernel(mapping=[1])
    def offchip_loadA(A_Packed: UInt(Rt * 8)[M * K // Rt]):
        for mt, nt in dsl.grid(M // Rt, N // Ct):
            for k in range(K):
                L3_A.put(A_Packed[mt * K + k])

    @df.kernel(mapping=[1])
    def offchip_loadB(B_Packed: UInt(Ct * 8)[K * N // Ct]):
        for mt, nt in dsl.grid(M // Rt, N // Ct):
            for k in range(K):
                L3_B.put(B_Packed[nt * K + k])

    @df.kernel(mapping=[P0, P1])
    def gemm():
        i, j = df.get_pid()
        # peripheral kernels
        with allo.meta_if(i == 0 and j == 0):
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                for k in range(K):
                    L2_A[1].put(L3_A.get())
                    L2_B[1].put(L3_B.get())

        with allo.meta_elif(i == P0 - 1 and j == P1 - 1):
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                for n in range(Ct):
                    L3_C.put(L2_C[Ct - 1].get())

        with allo.meta_elif(i in {0, P0 - 1} and j in {0, P1 - 1}):
            pass

        with allo.meta_elif(j == 0):
            # i > 0, the first column
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                for k in range(K):
                    a = L2_A[i].get()
                    # unpack data
                    fifo_A[i - 1, 0].put(a[8 * (i - 1) : 8 * i])
                    with allo.meta_if(i < Rt):
                        L2_A[i + 1].put(a)

        with allo.meta_elif(i == 0):
            # j > 0, the first row
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                for k in range(K):
                    b = L2_B[j].get()
                    fifo_B[0, j - 1].put(b[8 * (j - 1) : 8 * j])
                    with allo.meta_if(j < Ct):
                        L2_B[j + 1].put(b)

        with allo.meta_elif(i == P0 - 1):
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                c_C = L1_C[i - 2, Ct - j].get()
                L2_C[j - 1].put(c_C)
                with allo.meta_if(j != 1):
                    for ind in range(j - 1):
                        L2_C[j - 1].put(L2_C[j - 2].get())

        with allo.meta_elif(j == P1 - 1):
            pass

        # main body
        with allo.meta_else():
            for mt, nt in dsl.grid(M // Rt, N // Ct):
                c: int8 = 0
                for k in range(K):
                    a: int8 = fifo_A[i - 1, j - 1].get()
                    b: int8 = fifo_B[i - 1, j - 1].get()
                    c += a * b
                    with allo.meta_if(j < Ct):
                        fifo_A[i - 1, j].put(a)
                    with allo.meta_if(i < Rt):
                        fifo_B[i, j - 1].put(b)

                with allo.meta_if(i == 1):
                    packed_tmp: UInt(M * 8) = 0
                with allo.meta_else():
                    packed_tmp: UInt(M * 8) = L1_C[i - 2, j - 1].get()

                packed_c: UInt(M * 8) = 0
                for m in range(Rt):
                    if m == i - 1:
                        packed_c[m * 8 : (m + 1) * 8] = c
                    else:
                        packed_c[m * 16 : (m + 1) * 16] = packed_tmp[
                            m * 16 : (m + 1) * 16
                        ]
                L1_C[i - 1, j - 1].put(packed_c)

    @df.kernel(mapping=[1])
    def offchip_store(C_Packed: UInt(Rt * 8)[M * N // Rt]):
        for mt, nt in dsl.grid(M // Rt, N // Ct):
            for n in range(Ct):
                C_Packed[mt * N + nt * Ct + n] = L3_C.get()


def test_large_scale_gemm():
    def serialize_A(matrix_A):
        A_ser = np.zeros((M * K), dtype=np.int8)
        for mt in range(M // Rt):
            for k in range(K):
                for m in range(Rt):
                    A_ser[mt * (K * Rt) + k * Rt + m] = matrix_A[mt * Rt + m, k]
        return A_ser

    def serialize_B(matrix_B):
        B_ser = np.zeros((K * N), dtype=np.int8)
        for nt in range(N // Ct):
            for k in range(K):
                for n in range(Ct):
                    B_ser[nt * (K * Ct) + k * Ct + n] = matrix_B[k, nt * Ct + n]
        return B_ser

    def deserialize_C(C_ser):
        matrix_C = np.zeros((M, N), dtype=np.int8)
        for mt in range(M // Rt):
            for n in range(N):
                for m in range(Rt):
                    matrix_C[mt * Rt + m, n] = C_ser[mt * (N * Rt) + n * Rt + m]
        return matrix_C

    # TODO: Fix the packing-related issue!
    # np_type_A = get_np_struct_type(Rt * 8)
    # np_type_B = get_np_struct_type(Ct * 8)
    # np_type_C = get_np_struct_type(Rt * 8)

    np_type_A = np.int64
    np_type_B = np.int64
    np_type_C = np.int64

    A = np.random.randint(-2, 2, (M, K), dtype=np.int8)
    B = np.random.randint(-2, 2, (K, N), dtype=np.int8)
    C = np.zeros((M, N), dtype=np.int8)

    A_packed = serialize_A(A).view(np_type_A)
    B_packed = serialize_B(B).view(np_type_B)

    if hls.is_available("vitis_hls"):
        print(A)
        print(B)
        C_golden = np.dot(A, B)

        C_packed = np.zeros((M * N // Rt), dtype=np_type_C)
        mod1 = df.build(top, wrapping=False)
        mod1(A_packed, B_packed, C_packed)
        C = deserialize_C(C_packed.view(np.int8))
        print(C)
        print(C_golden)
        np.testing.assert_allclose(C, C_golden, atol=1e-5)
        print("Passed csim Test!")

        C_packed = np.zeros((M * N // Rt), dtype=np_type_C)
        mod2 = df.build(
            top,
            target="vitis_hls",
            mode="hw_emu",
            project=f"df-packed-{Rt}x{Ct}.prj",
            wrapping=False,
        )
        mod2(A_packed, B_packed, C_packed)
        C = deserialize_C(C_packed.view(np.int8))
        print(C)
        print(C_golden)
        np.testing.assert_allclose(C, C_golden, atol=1e-5)
        print("Passed hw_emu Test!")


if __name__ == "__main__":
    test_large_scale_gemm()
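For intuition about the data layout, here is a tiny self-contained sketch of serialize_A with made-up sizes (M = K = 4, Rt = 2, much smaller than the test above): each UInt(Rt * 8) word carries one Rt-tall column slice of a row tile, one word per (tile, k) pair.

```python
import numpy as np

# Illustrative only: M = K = 4, Rt = 2.
A = np.arange(16, dtype=np.int8).reshape(4, 4)
A_ser = np.zeros(16, dtype=np.int8)
for mt in range(2):              # row tiles
    for k in range(4):           # columns
        for m in range(2):       # rows inside the tile
            A_ser[mt * (4 * 2) + k * 2 + m] = A[mt * 2 + m, k]

# One 16-bit word per (mt, k): word 0 holds [A[0, 0], A[1, 0]],
# word 1 holds [A[0, 1], A[1, 1]], and so on.
A_packed = A_ser.view(np.int16)
assert A_packed.shape == (8,)
```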
2 changes: 1 addition & 1 deletion tests/dataflow/test_tiled_systolic.py
@@ -28,7 +28,7 @@ def gemm(A: int32[M, K], B: int32[K, N], C: int32[M, N]):
         i, j = df.get_pid()
         for m in range(M // Mt):
             for n in range(N // Nt):
-                # periperals kernels
+                # peripheral kernels
                 with allo.meta_if(i in {0, Mt + 1} and j in {0, Nt + 1}):
                     pass
                 with allo.meta_elif(j == 0):