From d5cfa7d7b1c74056059d0e42822b4cf01a2a2a22 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 23 Jan 2024 18:13:11 +0100 Subject: [PATCH] feat[next][dace]: Add more debug info to DaCe (#1384) * Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/__init__.py | 8 +- src/gt4py/eve/traits.py | 8 + src/gt4py/next/ffront/foast_to_itir.py | 8 +- src/gt4py/next/ffront/past_to_itir.py | 20 ++- src/gt4py/next/iterator/ir.py | 3 + .../iterator/transforms/collapse_list_get.py | 2 +- .../iterator/transforms/collapse_tuple.py | 9 +- .../iterator/transforms/constant_folding.py | 4 +- src/gt4py/next/iterator/transforms/cse.py | 14 +- .../next/iterator/transforms/eta_reduction.py | 4 +- .../next/iterator/transforms/fuse_maps.py | 4 +- .../next/iterator/transforms/global_tmps.py | 10 +- .../iterator/transforms/inline_fundefs.py | 6 +- .../iterator/transforms/inline_into_scan.py | 7 +- .../iterator/transforms/inline_lambdas.py | 6 +- .../next/iterator/transforms/inline_lifts.py | 8 +- .../next/iterator/transforms/merge_let.py | 2 +- .../iterator/transforms/normalize_shifts.py | 4 +- .../iterator/transforms/propagate_deref.py | 4 +- .../transforms/prune_closure_inputs.py | 4 +- .../next/iterator/transforms/remap_symbols.py | 6 +- .../iterator/transforms/scan_eta_reduction.py | 7 +- .../iterator/transforms/symbol_ref_utils.py | 2 +- .../next/iterator/transforms/trace_shifts.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 17 +- .../runners/dace_iterator/__init__.py | 14 ++ .../runners/dace_iterator/itir_to_sdfg.py | 47 ++++-- .../runners/dace_iterator/itir_to_tasklet.py | 152 +++++++++++++----- .../runners/dace_iterator/utility.py | 24 ++- 29 files changed, 288 insertions(+), 120 deletions(-) diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 617a889e28..e726db1f1a 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -58,7 +58,12 @@ field, frozenmodel, ) -from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait +from .traits import ( + PreserveLocationVisitor, + SymbolTableTrait, + ValidatedSymbolTableTrait, + VisitorWithSymbolTableTrait, +) from .trees import ( bfs_walk_items, bfs_walk_values, @@ -113,6 +118,7 @@ "SymbolTableTrait", "ValidatedSymbolTableTrait", "VisitorWithSymbolTableTrait", + "PreserveLocationVisitor", # trees "bfs_walk_items", "bfs_walk_values", diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index df556c9d7f..aacae804d8 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result + + +class PreserveLocationVisitor(visitors.NodeVisitor): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c4d518d279..0c9ab4ab27 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,7 +15,7 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next.ffront import ( dialect_ast_enums, @@ -39,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params + >>> lowered.params # doctest: +ELLIPSIS [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ @@ -142,7 +142,7 @@ def visit_IfStmt( self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) + # from a scalar value (and not a field) assert ( isinstance(node.condition.type, ts.ScalarType) and node.condition.type.kind == ts.ScalarKind.BOOL diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 709912077b..ed239e0436 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -40,7 +40,9 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Lower Program AST (PAST) to Iterator IR (ITIR). @@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: stencil=itir.SymRef(id=node.func.id), inputs=[*lowered_args, *lowered_kwargs.values()], output=output, + location=node.location, ) def _visit_slice_bound( @@ -175,17 +178,22 @@ def _visit_slice_bound( lowered_bound = self.visit(slice_bound, **kwargs) else: raise AssertionError("Expected 'None' or 'past.Constant'.") + if slice_bound: + lowered_bound.location = slice_bound.location return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): - return self._construct_itir_out_arg(node.value) + itir_node = self._construct_itir_out_arg(node.value) + itir_node.location = node.location + return itir_node elif isinstance(node, past.TupleExpr): return itir.FunCall( fun=itir.SymRef(id="make_tuple"), args=[self._construct_itir_out_arg(el) for el in node.elts], + location=node.location, ) else: raise ValueError( @@ -247,7 +255,11 @@ def _construct_itir_domain_arg( else: raise AssertionError() - return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args) + return itir.FunCall( + fun=itir.SymRef(id=domain_builtin), + args=domain_args, + location=(node_domain or out_field).location, + ) def _construct_itir_initialized_domain_arg( self, diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e6ee20e227..37abbec9e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -17,12 +17,15 @@ import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @noninstantiable class Node(eve.Node): + location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) + def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 08cbd7313e..6acb8a79c4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -16,7 +16,7 @@ from gt4py.next.iterator import ir -class CollapseListGet(eve.NodeTranslator): +class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 30457f2246..42bbf28909 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(eve.NodeTranslator): +class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -88,13 +88,6 @@ def apply( node_types, ).visit(node) - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.collapse_make_tuple_tuple_get diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index fa326760b0..696a87a197 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,12 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im -class ConstantFolding(NodeTranslator): +class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 034a39d68f..f9cf272c45 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,14 +17,20 @@ import operator import typing -from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait +from gt4py.eve import ( + NodeTranslator, + NodeVisitor, + PreserveLocationVisitor, + SymbolTableTrait, + VisitorWithSymbolTableTrait, +) from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @dataclasses.dataclass -class _NodeReplacer(NodeTranslator): +class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) expr_map: dict[int, ir.SymRef] @@ -72,7 +78,7 @@ def _is_collectable_expr(node: ir.Node) -> bool: @dataclasses.dataclass -class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): +class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor): @dataclasses.dataclass class SubexpressionData: #: A list of node ids with equal hash and a set of collected child subexpression ids @@ -341,7 +347,7 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) -class CommonSubexpressionElimination(NodeTranslator): +class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): """ Perform common subexpression elimination. diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 55b2141499..93702a6c96 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class EtaReduction(NodeTranslator): +class EtaReduction(PreserveLocationVisitor, NodeTranslator): """Eta reduction: simplifies `λ(args...) → f(args...)` to `f`.""" def visit_Lambda(self, node: ir.Lambda) -> ir.Node: diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index e9fbb0f81d..694dcd6a61 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Fuses nested `map_`s. @@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: return ir.Lambda( params=params, expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]), + location=fun.location, ) def visit_FunCall(self, node: ir.FunCall, **kwargs): @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): ir.FunCall( fun=inner_op, args=[ir.SymRef(id=param.id) for param in inner_op.params], + location=node.location, ) ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 0033f36cab..c423a3c277 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx -from gt4py.eve import Coerced, NodeTranslator +from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + location=current_closure.location, ) ) @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + location=current_closure.location, ) ) else: @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp + [ir.Sym(id=tmp.id) for tmp in tmps] + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), + location=node.location, ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari function_definitions=node.fencil.function_definitions, params=[p for p in node.fencil.params if p.id not in unused_tmps], closures=closures, + location=node.fencil.location, ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], @@ -456,6 +460,7 @@ def update_domains( stencil=closure.stencil, output=closure.output, inputs=closure.inputs, + location=closure.location, ) else: domain = closure.domain @@ -521,6 +526,7 @@ def update_domains( function_definitions=node.fencil.function_definitions, params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), + location=node.fencil.location, ), params=node.params, tmps=node.tmps, @@ -580,7 +586,7 @@ def convert_type(dtype): # TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be # tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore # and hence also not extract as a temporary. -class CreateGlobalTmps(NodeTranslator): +class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): """Main entry point for introducing global temporaries. Transforms an existing iterator IR fencil into a fencil with global temporaries. diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 6bf2b60592..a53232745f 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Set -from gt4py.eve import NOTHING, NodeTranslator +from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class InlineFundefs(NodeTranslator): +class InlineFundefs(PreserveLocationVisitor, NodeTranslator): def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): return ir.Lambda( @@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition): return self.generic_visit(node, symtable=node.annex.symtable) -class PruneUnreferencedFundefs(NodeTranslator): +class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator): def visit_FunctionDefinition( self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index fe1eae6e07..a1c9a2eb5b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Inline non-SymRef arguments into the scan. @@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) - return result + return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a56ad5cb10..0b89fe6d98 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -15,7 +15,7 @@ import dataclasses from typing import Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols @@ -104,6 +104,7 @@ def new_name(name): new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) if all(eligible_params): + new_expr.location = node.location return new_expr else: return ir.FunCall( @@ -116,11 +117,12 @@ def new_name(name): expr=new_expr, ), args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], + location=node.location, ) @dataclasses.dataclass -class InlineLambdas(NodeTranslator): +class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" PRESERVED_ANNEX_ATTRS = ("type",) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index d7d8e5e612..d6146d9fc8 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -103,14 +103,18 @@ def _transform_and_extract_lift_args( extracted_args[new_symbol] = arg new_args.append(ir.SymRef(id=new_symbol.id)) - return (im.lift(inner_stencil)(*new_args), extracted_args) + itir_node = im.lift(inner_stencil)(*new_args) + itir_node.location = node.location + return (itir_node, extracted_args) # TODO(tehrengruber): This pass has many different options that should be written as dedicated # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 7426617ac8..bcfc6b2a17 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -17,7 +17,7 @@ from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(eve.NodeTranslator): +class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Merge let-like statements. diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index efc9064612..c70dc1ccd1 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class NormalizeShifts(NodeTranslator): +class NormalizeShifts(PreserveLocationVisitor, NodeTranslator): def visit_FunCall(self, node: ir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 54bdafcda8..783e54ede0 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir @@ -22,7 +22,7 @@ # `(λ(...) → plus(multiplies(...), ...))(...)`. -class PropagateDeref(NodeTranslator): +class PropagateDeref(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node): """ diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 7fd3c50c6e..1e637a0bfb 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class PruneClosureInputs(NodeTranslator): +class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): """Removes all unused input arguments from a stencil closure.""" def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index cdf3d76173..431dd6cd7a 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Optional, Set -from gt4py.eve import NodeTranslator, SymbolTableTrait +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir -class RemapSymbolRefs(NodeTranslator): +class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): @@ -39,7 +39,7 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] return super().generic_visit(node, **kwargs) -class RenameSymbols(NodeTranslator): +class RenameSymbols(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_Sym( diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 3266c25c4b..d93b4242ab 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -24,7 +24,7 @@ def _is_scan(node: ir.Node): ) -class ScanEtaReduction(NodeTranslator): +class ScanEtaReduction(PreserveLocationVisitor, NodeTranslator): """Applies eta-reduction-like transformation involving scans. Simplifies `λ(x, y) → scan(λ(state, param_y, param_x) → ..., ...)(y, x)` to `scan(λ(state, param_x, param_y) → ..., ...)`. @@ -55,9 +55,8 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: original_scanpass.params[i + 1] for i in new_scanpass_params_idx ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) - result = ir.FunCall( + return ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]] ) - return result return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1c587fb9d6..05d137e8c4 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -21,7 +21,7 @@ @dataclasses.dataclass -class CountSymbolRefs(eve.NodeVisitor): +class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 5c607e7df1..082987ac96 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any, Final, Iterable, Literal -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -235,7 +235,7 @@ def _tuple_get(index, tuple_val): @dataclasses.dataclass(frozen=True) -class TraceShifts(NodeTranslator): +class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 861052bb25..3c878b2b00 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -16,7 +16,7 @@ from collections.abc import Iterable, Iterator from typing import TypeGuard -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -100,31 +100,36 @@ def _get_connectivity( def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), + args=[iterator], + location=iterator.location, ) def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator]) + return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="can_deref"), args=[iterator]) + return itir.FunCall( + fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location + ) def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], + location=cond.location, ) def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr]) + return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) @dataclasses.dataclass(frozen=True) -class UnrollReduce(NodeTranslator): +class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fdd8a61054..54ca08fe6e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings +from inspect import currentframe, getframeinfo from typing import Any, Mapping, Optional, Sequence import dace @@ -265,6 +266,19 @@ def build_sdfg_from_itir( if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") + for nested_sdfg in sdfg.all_sdfgs_recursive(): + if not nested_sdfg.debuginfo: + _, frameinfo = warnings.warn( + f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + ), getframeinfo( + currentframe() # type: ignore + ) + nested_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=frameinfo.lineno, + end_line=frameinfo.lineno, + filename=frameinfo.filename, + ) + # run DaCe transformations to simplify the SDFG sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index fb2f82fed0..dc194c0436 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, get_sorted_dims, @@ -143,6 +144,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) + program_sdfg.debuginfo = dace_debuginfo(node) last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) @@ -187,15 +189,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, + debuginfo=closure_sdfg.debuginfo, ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. @@ -213,6 +216,7 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") + closure_sdfg.debuginfo = dace_debuginfo(node) closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) @@ -239,8 +243,8 @@ def visit_StencilClosure( transient=True, ) closure_init_state.add_nedge( - closure_init_state.add_access(name), - closure_init_state.add_access(transient_name), + closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), + closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), create_memlet_full(name, closure_sdfg.arrays[name]), ) input_transients_mapping[name] = transient_name @@ -276,9 +280,15 @@ def visit_StencilClosure( out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", {}, {"__result"}, f"__result = {name}" + f"get_{name}", + {}, + {"__result"}, + f"__result = {name}", + debuginfo=closure_sdfg.debuginfo, + ) + access = closure_init_state.add_access( + out_name, debuginfo=closure_sdfg.debuginfo ) - access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) @@ -356,19 +366,20 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, + debuginfo=nsdfg.debuginfo, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data if memlet.data not in output_connectors_mapping: continue - transient_access = closure_state.add_access(memlet.data) + transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) closure_state.add_edge( nsdfg_node, edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset), + dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo), ) inner_memlet = dace.Memlet.simple( memlet.data, output_subset, other_subset_str=memlet.subset @@ -417,6 +428,7 @@ def _visit_scan_stencil_closure( # the scan operator is implemented as an SDFG to be nested in the closure SDFG scan_sdfg = dace.SDFG(name="scan") + scan_sdfg.debuginfo = dace_debuginfo(node) # create a state machine for lambda call over the scan dimension start_state = scan_sdfg.add_state("start", True) @@ -429,12 +441,16 @@ def _visit_scan_stencil_closure( # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + "get_carry_init_value", + {}, + {"__result"}, + f"__result = {init_carry_value}", + debuginfo=scan_sdfg.debuginfo, ) start_state.add_edge( carry_init_tasklet, "__result", - start_state.add_access(scan_carry_name), + start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), None, dace.Memlet.simple(scan_carry_name, "0"), ) @@ -512,11 +528,12 @@ def _visit_scan_stencil_closure( inputs=set(lambda_input_names) | set(connectivity_names), outputs=set(lambda_output_names), symbol_mapping=symbol_mapping, + debuginfo=lambda_context.body.debuginfo, ) # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name) + access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo) lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] @@ -526,7 +543,7 @@ def _visit_scan_stencil_closure( lambda_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name), + lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, dace.Memlet.simple(name, f"i_{scan_dim}"), ) @@ -534,8 +551,10 @@ def _visit_scan_stencil_closure( # add state to scan SDFG to update the carry value at each loop iteration lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name), - lambda_update_state.add_access(scan_carry_name), + lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + lambda_update_state.add_access( + scan_carry_name, debuginfo=lambda_context.body.debuginfo + ), memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 4c202b1fe8..0ace6948b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -35,6 +35,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, @@ -183,6 +184,7 @@ def __init__( def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) offset_literal, data = node_args assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value @@ -214,13 +216,14 @@ def builtin_neighbors( sdfg.add_array( result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True ) - result_access = state.add_access(result_name) + result_access = state.add_access(result_name, debuginfo=di) # generate unique map index name to avoid conflict with other maps inside same state neighbor_index = unique_name("neighbor_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, + debuginfo=di, ) table_name = connectivity_identifier(offset_dim) table_subset = (f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) @@ -230,17 +233,19 @@ def builtin_neighbors( code="__result = __table[__idx]", inputs={"__table", "__idx"}, outputs={"__result"}, + debuginfo=di, ) data_access_tasklet = state.add_tasklet( "data_access", code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", inputs={"__field", field_index}, outputs={"__result"}, + debuginfo=di, ) idx_name = unique_var_name() sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( - state.add_access(table_name), + state.add_access(table_name, debuginfo=di), me, shift_tasklet, memlet=create_memlet_at(table_name, table_subset), @@ -250,7 +255,7 @@ def builtin_neighbors( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di), dst_conn="__idx", ) state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) @@ -270,7 +275,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, neighbor_index), + memlet=dace.Memlet.simple(result_name, neighbor_index, debuginfo=di), src_conn="__result", ) @@ -280,6 +285,7 @@ def builtin_neighbors( def builtin_can_deref( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) # first visit shift, to get set of indices for deref can_deref_callable = node_args[0] assert isinstance(can_deref_callable, itir.FunCall) @@ -296,13 +302,15 @@ def builtin_can_deref( # Returning a SymbolExpr would be preferable, but it requires update to type-checking. result_name = unique_var_name() transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name) + result_node = transformer.context.state.add_access(result_name, debuginfo=di) transformer.context.state.add_edge( - transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + transformer.context.state.add_tasklet( + "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di + ), "_out", result_node, None, - dace.Memlet.simple(result_name, "0"), + dace.Memlet.simple(result_name, "0", debuginfo=di), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -313,13 +321,18 @@ def builtin_can_deref( # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + list(zip(args, internals)), + expr_code, + dace.dtypes.bool, + "can_deref", + dace_debuginfo=di, ) def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args) assert len(args) == 3 if_node = args[0][0] if isinstance(args[0], list) else args[0] @@ -346,7 +359,7 @@ def builtin_if( for arg in (if_node, a, b) ] expr = "({1} if {0} else {2})".format(*internals) - if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if") + if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if", dace_debuginfo=di) if_expr_values.append(if_expr[0]) return if_expr_values @@ -355,6 +368,7 @@ def builtin_if( def builtin_list_get( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = list(itertools.chain(*transformer.visit(node_args))) assert len(args) == 2 # index node @@ -369,12 +383,15 @@ def builtin_list_get( arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args ] expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + return transformer.add_expr_tasklet( + expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di + ) def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] @@ -383,7 +400,13 @@ def builtin_cast( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast") + return transformer.add_expr_tasklet( + list(zip(args, internals)), + expr, + type_, + "cast", + dace_debuginfo=di, + ) def builtin_make_tuple( @@ -443,7 +466,9 @@ def _add_symbol(self, param, arg): # create storage in lambda sdfg self._sdfg.add_scalar(param, dtype=arg.dtype) # update table of lambda symbol - self._symbol_map[param] = ValueExpr(self._state.add_access(param), arg.dtype) + self._symbol_map[param] = ValueExpr( + self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype + ) elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) @@ -453,9 +478,10 @@ def _add_symbol(self, param, arg): for _, index_name in index_names.items(): self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol - field = self._state.add_access(param) + field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) indices = { - dim: self._state.add_access(index_arg) for dim, index_arg in index_names.items() + dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) + for dim, index_arg in index_names.items() } self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) else: @@ -503,7 +529,7 @@ def visit_SymRef(self, node: itir.SymRef): if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: node_type = self._node_types[id(node)] assert isinstance(node_type, Val) - access_node = self._state.add_access(param) + access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) self._symbol_map[param] = ValueExpr( access_node, dtype=itir_type_as_dace_type(node_type.dtype) ) @@ -542,6 +568,7 @@ def visit_Lambda( # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) + lambda_sdfg.debuginfo = dace_debuginfo(node) lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) lambda_symbols_pass = GatherLambdaSymbolsPass( @@ -586,11 +613,14 @@ def visit_Lambda( results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lambda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. + node.expr.location = node.location for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access(result_name) + result_access = lambda_state.add_access( + result_name, debuginfo=lambda_sdfg.debuginfo + ) lambda_state.add_nedge( expr.value, result_access, @@ -599,7 +629,9 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + result = lambda_taskgen.add_expr_tasklet( + [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo + )[0] lambda_sdfg.arrays[result.value.data].transient = False results.append(result) @@ -624,6 +656,7 @@ def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))] def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: + node.fun.location = node.location if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): @@ -646,7 +679,7 @@ def _visit_call(self, node: itir.FunCall): args = self.visit(node.args) args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] args = list(itertools.chain(*args)) - + node.fun.location = node.location func_context, func_inputs, results = self.visit(node.fun, args=args) nsdfg_inputs = {} @@ -679,6 +712,7 @@ def _visit_call(self, node: itir.FunCall): inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, + debuginfo=dace_debuginfo(node, func_context.body.debuginfo), ) for name, value in func_inputs: @@ -698,14 +732,14 @@ def _visit_call(self, node: itir.FunCall): for conn, _ in neighbor_tables: var = connectivity_identifier(conn) memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var) + access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) self.context.state.add_edge(access, None, nsdfg_node, var, memlet) result_exprs = [] for result in results: name = unique_var_name() self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name) + result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) result_exprs.append(ValueExpr(result_access, result.dtype)) memlet = create_memlet_full(name, self.context.body.arrays[name]) self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) @@ -713,6 +747,7 @@ def _visit_call(self, node: itir.FunCall): return result_exprs def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr @@ -727,7 +762,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + return self.add_expr_tasklet( + list(zip(args, internals)), + expr, + iterator.dtype, + "deref", + dace_debuginfo=di, + ) else: # Not all dimensions are included in the deref index list: @@ -741,7 +782,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name) + result_node = self.context.state.add_access(result_name, debuginfo=di) deref_connectors = ["_inp"] + [ f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices @@ -776,6 +817,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: output_nodes={ result_name: result_node, }, + debuginfo=di, ) return [ValueExpr(result_node, iterator.dtype)] @@ -789,10 +831,13 @@ def _split_shift_args( def _make_shift_for_rest(self, rest, iterator): return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), + args=[iterator], + location=iterator.location, ) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -815,7 +860,9 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] - connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + connectivity = self.context.state.add_access( + connectivity_identifier(offset_dim), debuginfo=di + ) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -850,7 +897,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -860,13 +907,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var) + offset_node = self.context.state.add_access(offset_var, debuginfo=di) tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}" + "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di ) self.context.state.add_edge( tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") @@ -874,6 +922,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): + di = dace_debuginfo(node, self.context.body.debuginfo) node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) reduce_dtype = itir_type_as_dace_type(node_type.dtype) @@ -930,7 +979,9 @@ def _visit_reduce(self, node: itir.FunCall): reduce_input_name, nreduce_shape, reduce_dtype, transient=True ) - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_node = itir.Lambda( + expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location + ) lambda_context, inner_inputs, inner_outputs = self.visit( lambda_node, args=args, use_neighbor_tables=False ) @@ -946,7 +997,7 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - reduce_input_node = self.context.state.add_access(reduce_input_name) + reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( self.context.state, @@ -957,6 +1008,7 @@ def _visit_reduce(self, node: itir.FunCall): symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, output_nodes={reduce_input_name: reduce_input_node}, + debuginfo=di, ) reduce_input_desc = reduce_input_node.desc(self.context.body) @@ -964,7 +1016,7 @@ def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) @@ -997,7 +1049,13 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return self.add_expr_tasklet(expr_args, expr, type_, "numeric") + return self.add_expr_tasklet( + expr_args, + expr, + type_, + "numeric", + dace_debuginfo=dace_debuginfo(node, self.context.body.debuginfo), + ) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -1005,17 +1063,24 @@ def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: return expr_func(self, node, node.args) def add_expr_tasklet( - self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str + self, + args: list[tuple[ValueExpr, str]], + expr: str, + result_type: Any, + name: str, + dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, ) -> list[ValueExpr]: + di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) expr_tasklet = self.context.state.add_tasklet( name=name, inputs={internal for _, internal in args}, outputs={"__result"}, code=f"__result = {expr}", + debuginfo=di, ) for arg, internal in args: @@ -1033,7 +1098,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0") + memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] @@ -1052,6 +1117,7 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") + body.debuginfo = dace_debuginfo(node) state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} @@ -1059,8 +1125,10 @@ def closure_to_tasklet_sdfg( for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") - access = state.add_access(name) + tasklet = state.add_tasklet( + f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo + ) + access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for name, ty in inputs: @@ -1070,14 +1138,14 @@ def closure_to_tasklet_sdfg( dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name) + field = state.add_access(name, debuginfo=body.debuginfo) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: assert isinstance(ty, ts.ScalarType) dtype = as_dace_type(ty) body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name), dtype) + symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) for arr, name in connectivities: shape, strides = new_array_symbols(name, ndim=2) body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) @@ -1089,10 +1157,12 @@ def closure_to_tasklet_sdfg( if is_scan(node.stencil): stencil = cast(FunCall, node.stencil) assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - fun_node = itir.FunCall(fun=lambda_node, args=args) + lambda_node = itir.Lambda( + expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location + ) + fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) else: - fun_node = itir.FunCall(fun=node.stencil, args=args) + fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) results = translator.visit(fun_node) for r in results: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 55717326a3..971c1bbdf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,15 +12,31 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import itertools -from typing import Any, Sequence +from typing import Any, Optional, Sequence import dace from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.iterator.ir import Node from gt4py.next.type_system import type_specifications as ts +def dace_debuginfo( + node: Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return debuginfo + + def as_dace_type(type_: ts.ScalarType): if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ @@ -119,11 +135,13 @@ def add_mapped_nested_sdfg( if input_nodes is None: input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in inputs.items() } if output_nodes is None: output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in outputs.items() } if not inputs: state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet())