From 0a9be51dc0177158aad8bd6182cddcefd6b10ca4 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Mar 2024 18:28:29 +0100 Subject: [PATCH] fix[next][dace]: Fixes to DaCe backend to support latest ITIR (#1499) Fixes in DaCe backend to support latest ITIR: - Add support for tuple argument to lambda functions. - Flatten list of expressions in if-statememts Minor code cleanup: skip debug information for memlets. --- .../runners/dace_iterator/itir_to_tasklet.py | 59 ++++++++++++++----- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6f9c5733ac..f671f4be3f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -401,7 +401,7 @@ def builtin_neighbors( origin_index_node, me, shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0", debuginfo=di), + memlet=dace.Memlet(data=origin_index_node.data, subset="0"), dst_conn="__idx", ) state.add_edge( @@ -469,7 +469,7 @@ def builtin_neighbors( data_access_tasklet, mx, neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), + memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), src_conn="__data", ) @@ -496,6 +496,7 @@ def builtin_neighbors( {"__idx"}, {"__valid"}, f"__valid = True if __idx != {neighbor_skip_value} else False", + debuginfo=di, ) state.add_edge( neighbor_index_node, @@ -545,7 +546,7 @@ def builtin_can_deref( "_out", result_node, None, - dace.Memlet(data=result_name, subset="0", debuginfo=di), + dace.Memlet(data=result_name, subset="0"), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -598,14 +599,14 @@ def build_if_state(arg, state): stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") ) sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = build_if_state(node_args[1], tbr_state) + tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) # fbr_state = sdfg.add_state("false_branch") sdfg.add_edge( stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") ) sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = build_if_state(node_args[2], fbr_state) + fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) assert isinstance(stmt_node, ValueExpr) assert stmt_node.dtype == dace.dtypes.bool @@ -804,7 +805,7 @@ def builtin_tuple_get( class GatherLambdaSymbolsPass(eve.NodeVisitor): _sdfg: dace.SDFG _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] + _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] _parent_symbol_map: dict[str, TaskletExpr] def __init__( @@ -827,7 +828,7 @@ def _add_symbol(self, param, arg): if isinstance(arg, ValueExpr): # create storage in lambda sdfg self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbol + # update table of lambda symbols self._symbol_map[param] = ValueExpr( self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype ) @@ -839,7 +840,7 @@ def _add_symbol(self, param, arg): index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} for _, index_name in index_names.items(): self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbol + # update table of lambda symbols field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) indices = { dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) @@ -850,6 +851,17 @@ def _add_symbol(self, param, arg): assert isinstance(arg, SymbolExpr) self._symbol_map[param] = arg + def _add_tuple(self, param, args): + nodes = [] + # create storage in lambda sdfg for each tuple element + for arg in args: + var = unique_var_name() + self._sdfg.add_scalar(var, dtype=arg.dtype) + arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) + nodes.append(ValueExpr(arg_node, arg.dtype)) + # update table of lambda symbols + self._symbol_map[param] = tuple(nodes) + def visit_SymRef(self, node: itir.SymRef): name = str(node.id) if name in self._parent_symbol_map and name not in self._symbol_map: @@ -858,9 +870,13 @@ def visit_SymRef(self, node: itir.SymRef): def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): if args is not None: - assert len(node.params) == len(args) - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) + if len(node.params) == len(args): + for param, arg in zip(node.params, args): + self._add_symbol(str(param.id), arg) + else: + # implicitly make tuple + assert len(node.params) == 1 + self._add_tuple(str(node.params[0].id), args) self.visit(node.expr) @@ -937,7 +953,7 @@ def visit_Lambda( # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) lambda_sdfg.debuginfo = dace_debuginfo(node, self.context.body.debuginfo) - lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) + lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True) lambda_symbols_pass = GatherLambdaSymbolsPass( lambda_sdfg, lambda_state, self.context.symbol_map @@ -947,9 +963,13 @@ def visit_Lambda( # Add for input nodes for lambda symbols inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] for sym, input_node in lambda_symbols_pass.symbol_refs.items(): - arg = next((arg for param, arg in zip(node.params, args) if param.id == sym), None) - if arg: - outer_node = arg + params = [str(p.id) for p in node.params] + try: + param_index = params.index(sym) + except ValueError: + param_index = -1 + if param_index >= 0: + outer_node = args[param_index] else: # the symbol is not found among lambda arguments, then it is inherited from parent scope outer_node = self.context.symbol_map[sym] @@ -962,6 +982,13 @@ def visit_Lambda( elif isinstance(input_node, ValueExpr): assert isinstance(outer_node, ValueExpr) inputs.append((sym, outer_node)) + elif isinstance(input_node, tuple): + assert param_index >= 0 + for i, input_node_i in enumerate(input_node): + arg_i = args[param_index + i] + assert isinstance(arg_i, ValueExpr) + assert isinstance(input_node_i, ValueExpr) + inputs.append((input_node_i.value.data, arg_i)) # Add connectivities as arrays for name in connectivity_names: @@ -1530,7 +1557,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet(data=result_access.data, subset="0", debuginfo=di) + memlet = dace.Memlet(data=result_access.data, subset="0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)]