Skip to content

Commit

Permalink
fix[next][dace]: Fixes to DaCe backend to support latest ITIR (#1499)
Browse files Browse the repository at this point in the history
Fixes in DaCe backend to support latest ITIR:

- Add support for tuple argument to lambda functions.
- Flatten list of expressions in if-statememts

Minor code cleanup: skip debug information for memlets.
  • Loading branch information
edopao authored Mar 19, 2024
1 parent b26e6a3 commit 0a9be51
Showing 1 changed file with 43 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def builtin_neighbors(
origin_index_node,
me,
shift_tasklet,
memlet=dace.Memlet(data=origin_index_node.data, subset="0", debuginfo=di),
memlet=dace.Memlet(data=origin_index_node.data, subset="0"),
dst_conn="__idx",
)
state.add_edge(
Expand Down Expand Up @@ -469,7 +469,7 @@ def builtin_neighbors(
data_access_tasklet,
mx,
neighbor_value_node,
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index),
src_conn="__data",
)

Expand All @@ -496,6 +496,7 @@ def builtin_neighbors(
{"__idx"},
{"__valid"},
f"__valid = True if __idx != {neighbor_skip_value} else False",
debuginfo=di,
)
state.add_edge(
neighbor_index_node,
Expand Down Expand Up @@ -545,7 +546,7 @@ def builtin_can_deref(
"_out",
result_node,
None,
dace.Memlet(data=result_name, subset="0", debuginfo=di),
dace.Memlet(data=result_name, subset="0"),
)
return [ValueExpr(result_node, dace.dtypes.bool)]

Expand Down Expand Up @@ -598,14 +599,14 @@ def build_if_state(arg, state):
stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True")
)
sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge())
tbr_values = build_if_state(node_args[1], tbr_state)
tbr_values = flatten_list(build_if_state(node_args[1], tbr_state))
#
fbr_state = sdfg.add_state("false_branch")
sdfg.add_edge(
stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False")
)
sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge())
fbr_values = build_if_state(node_args[2], fbr_state)
fbr_values = flatten_list(build_if_state(node_args[2], fbr_state))

assert isinstance(stmt_node, ValueExpr)
assert stmt_node.dtype == dace.dtypes.bool
Expand Down Expand Up @@ -804,7 +805,7 @@ def builtin_tuple_get(
class GatherLambdaSymbolsPass(eve.NodeVisitor):
_sdfg: dace.SDFG
_state: dace.SDFGState
_symbol_map: dict[str, TaskletExpr]
_symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]]
_parent_symbol_map: dict[str, TaskletExpr]

def __init__(
Expand All @@ -827,7 +828,7 @@ def _add_symbol(self, param, arg):
if isinstance(arg, ValueExpr):
# create storage in lambda sdfg
self._sdfg.add_scalar(param, dtype=arg.dtype)
# update table of lambda symbol
# update table of lambda symbols
self._symbol_map[param] = ValueExpr(
self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype
)
Expand All @@ -839,7 +840,7 @@ def _add_symbol(self, param, arg):
index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()}
for _, index_name in index_names.items():
self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE)
# update table of lambda symbol
# update table of lambda symbols
field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo)
indices = {
dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo)
Expand All @@ -850,6 +851,17 @@ def _add_symbol(self, param, arg):
assert isinstance(arg, SymbolExpr)
self._symbol_map[param] = arg

def _add_tuple(self, param, args):
nodes = []
# create storage in lambda sdfg for each tuple element
for arg in args:
var = unique_var_name()
self._sdfg.add_scalar(var, dtype=arg.dtype)
arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo)
nodes.append(ValueExpr(arg_node, arg.dtype))
# update table of lambda symbols
self._symbol_map[param] = tuple(nodes)

def visit_SymRef(self, node: itir.SymRef):
name = str(node.id)
if name in self._parent_symbol_map and name not in self._symbol_map:
Expand All @@ -858,9 +870,13 @@ def visit_SymRef(self, node: itir.SymRef):

def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None):
if args is not None:
assert len(node.params) == len(args)
for param, arg in zip(node.params, args):
self._add_symbol(str(param.id), arg)
if len(node.params) == len(args):
for param, arg in zip(node.params, args):
self._add_symbol(str(param.id), arg)
else:
# implicitly make tuple
assert len(node.params) == 1
self._add_tuple(str(node.params[0].id), args)
self.visit(node.expr)


Expand Down Expand Up @@ -937,7 +953,7 @@ def visit_Lambda(
# Create the SDFG for the lambda's body
lambda_sdfg = dace.SDFG(func_name)
lambda_sdfg.debuginfo = dace_debuginfo(node, self.context.body.debuginfo)
lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True)
lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True)

lambda_symbols_pass = GatherLambdaSymbolsPass(
lambda_sdfg, lambda_state, self.context.symbol_map
Expand All @@ -947,9 +963,13 @@ def visit_Lambda(
# Add for input nodes for lambda symbols
inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = []
for sym, input_node in lambda_symbols_pass.symbol_refs.items():
arg = next((arg for param, arg in zip(node.params, args) if param.id == sym), None)
if arg:
outer_node = arg
params = [str(p.id) for p in node.params]
try:
param_index = params.index(sym)
except ValueError:
param_index = -1
if param_index >= 0:
outer_node = args[param_index]
else:
# the symbol is not found among lambda arguments, then it is inherited from parent scope
outer_node = self.context.symbol_map[sym]
Expand All @@ -962,6 +982,13 @@ def visit_Lambda(
elif isinstance(input_node, ValueExpr):
assert isinstance(outer_node, ValueExpr)
inputs.append((sym, outer_node))
elif isinstance(input_node, tuple):
assert param_index >= 0
for i, input_node_i in enumerate(input_node):
arg_i = args[param_index + i]
assert isinstance(arg_i, ValueExpr)
assert isinstance(input_node_i, ValueExpr)
inputs.append((input_node_i.value.data, arg_i))

# Add connectivities as arrays
for name in connectivity_names:
Expand Down Expand Up @@ -1530,7 +1557,7 @@ def add_expr_tasklet(
)
self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet)

memlet = dace.Memlet(data=result_access.data, subset="0", debuginfo=di)
memlet = dace.Memlet(data=result_access.data, subset="0")
self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet)

return [ValueExpr(result_access, result_type)]
Expand Down

0 comments on commit 0a9be51

Please sign in to comment.