Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Add more debug info to DaCe #1384

Merged
merged 38 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
caebdff
Add more debug info to DaCe
kotsaloscv Nov 7, 2023
f14d091
Merge branch 'main' into more_debinfo
kotsaloscv Nov 27, 2023
bebb122
Add more debug info to DaCe
kotsaloscv Nov 27, 2023
7eeaddb
Add more debug info to DaCe : WIP
kotsaloscv Nov 27, 2023
7447807
Add more debug info to DaCe : WIP
kotsaloscv Nov 29, 2023
6d89149
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 29, 2023
16bc489
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 30, 2023
0ed8090
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
3e38756
merge main
kotsaloscv Dec 4, 2023
0b1fe1a
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
55abb29
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
54dd79d
merge main
kotsaloscv Dec 5, 2023
774b2f5
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
1622866
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
7baedb8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
223de4e
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
6c636ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
93fcb14
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
d6da4ab
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
8460c67
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
5632def
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
b59fd83
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
68dde06
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
a1a91c4
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
1ed9764
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
371dc36
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
bb880dd
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
e0a254f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
ea2f672
merge main
kotsaloscv Jan 4, 2024
50f96a8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 4, 2024
9c9c8ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
6fb28a1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
3f4e9d1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
b29bc5f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
1a2e978
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
bf33827
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
616a833
merge main
kotsaloscv Jan 23, 2024
371b1da
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/gt4py/eve/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
kwargs["symtable"] = kwargs["symtable"].parents

return result


class PreserveLocationWithSymbolTableTrait(VisitorWithSymbolTableTrait):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
result.location = node.location
return result
8 changes: 8 additions & 0 deletions src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,11 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
)

return copy.deepcopy(node, memo=memo)


class PreserveLocation(NodeVisitor):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
result.location = node.location
return result
24 changes: 18 additions & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from typing import Any, Callable, Optional

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, concepts, extended_typing
from gt4py.eve.utils import UIDGenerator
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -61,15 +61,27 @@ class FieldOperatorLowering(NodeTranslator):
<class 'gt4py.next.iterator.ir.FunctionDefinition'>
>>> lowered.id
SymbolName('fieldop')
>>> lowered.params
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
>>> lowered.params # doctest: +ELLIPSIS
[Sym(location=..., id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
"""

uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator)
preserve_location: bool = True

@classmethod
def apply(cls, node: foast.LocatedNode) -> itir.Expr:
return cls().visit(node)
def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr:
return cls(preserve_location=preserve_location).visit(node)

def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any:
result = super().visit(node, **kwargs)
if (
hasattr(node, "location")
and hasattr(result, "location")
and not isinstance(node, foast.Name)
and self.preserve_location
):
result.location = node.location
return result

def visit_FunctionDefinition(
self, node: foast.FunctionDefinition, **kwargs
Expand Down Expand Up @@ -142,7 +154,7 @@ def visit_IfStmt(
self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
# the lowered if call doesn't need to be lifted as the condition can only originate
# from a scalar value (and not a field)
# from a scalar value (and not a field)
assert (
isinstance(node.condition.type, ts.ScalarType)
and node.condition.type.kind == ts.ScalarKind.BOOL
Expand Down
23 changes: 18 additions & 5 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def _flatten_tuple_expr(
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")


class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class ProgramLowering(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Lower Program AST (PAST) to Iterator IR (ITIR).

Expand Down Expand Up @@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure:
stencil=itir.SymRef(id=node.func.id),
inputs=[*lowered_args, *lowered_kwargs.values()],
output=output,
location=node.location,
)

def _visit_slice_bound(
Expand All @@ -175,17 +178,22 @@ def _visit_slice_bound(
lowered_bound = self.visit(slice_bound, **kwargs)
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
lowered_bound.location = slice_bound.location
return lowered_bound

def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr:
if isinstance(node, past.Name):
return itir.SymRef(id=node.id)
return itir.SymRef(id=node.id, location=node.location)
elif isinstance(node, past.Subscript):
return self._construct_itir_out_arg(node.value)
itir_node = self._construct_itir_out_arg(node.value)
itir_node.location = node.location
return itir_node
elif isinstance(node, past.TupleExpr):
return itir.FunCall(
fun=itir.SymRef(id="make_tuple"),
args=[self._construct_itir_out_arg(el) for el in node.elts],
location=node.location,
)
else:
raise ValueError(
Expand Down Expand Up @@ -247,7 +255,9 @@ def _construct_itir_domain_arg(
else:
raise AssertionError()

return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args)
return itir.FunCall(
fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location
)

def _construct_itir_initialized_domain_arg(
self,
Expand All @@ -263,7 +273,10 @@ def _construct_itir_initialized_domain_arg(
f"expected '{dim}', got '{keys_dims_types}'."
)

return [self.visit(bound) for bound in node_domain.values_[dim_i].elts]
itir_node = [self.visit(bound) for bound in node_domain.values_[dim_i].elts]
for i, bound in enumerate(node_domain.values_[dim_i].elts):
itir_node[i].location = bound.location
return itir_node

@staticmethod
def _compute_field_slice(node: past.Subscript):
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
from gt4py.eve.concepts import SourceLocation
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.eve.utils import noninstantiable


@noninstantiable
class Node(eve.Node):
location: Optional[SourceLocation] = None

def __str__(self) -> str:
from gt4py.next.iterator.pretty_printer import pformat

Expand Down
34 changes: 17 additions & 17 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym:
Examples
--------
>>> sym("a")
Sym(id=SymbolName('a'), kind=None, dtype=None)
Sym(location=None, id=SymbolName('a'), kind=None, dtype=None)

>>> sym(itir.Sym(id="b"))
Sym(id=SymbolName('b'), kind=None, dtype=None)
Sym(location=None, id=SymbolName('b'), kind=None, dtype=None)
"""
if isinstance(sym_or_name, itir.Sym):
return sym_or_name
Expand All @@ -43,10 +43,10 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef:
Examples
--------
>>> ref("a")
SymRef(id=SymbolRef('a'))
SymRef(location=None, id=SymbolRef('a'))

>>> ref(itir.SymRef(id="b"))
SymRef(id=SymbolRef('b'))
SymRef(location=None, id=SymbolRef('b'))
"""
if isinstance(ref_or_name, itir.SymRef):
return ref_or_name
Expand All @@ -60,13 +60,13 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti
Examples
--------
>>> ensure_expr("a")
SymRef(id=SymbolRef('a'))
SymRef(location=None, id=SymbolRef('a'))

>>> ensure_expr(3)
Literal(value='3', type='int32')
Literal(location=None, value='3', type='int32')

>>> ensure_expr(itir.OffsetLiteral(value="i"))
OffsetLiteral(value='i')
OffsetLiteral(location=None, value='i')
"""
if isinstance(literal_or_expr, str):
return ref(literal_or_expr)
Expand All @@ -83,10 +83,10 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of
Examples
--------
>>> ensure_offset("V2E")
OffsetLiteral(value='V2E')
OffsetLiteral(location=None, value='V2E')

>>> ensure_offset(itir.OffsetLiteral(value="J"))
OffsetLiteral(value='J')
OffsetLiteral(location=None, value='J')
"""
if isinstance(str_or_offset, (str, int)):
return itir.OffsetLiteral(value=str_or_offset)
Expand All @@ -100,7 +100,7 @@ class lambda_:
Examples
--------
>>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS
Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))]))
Lambda(location=None, params=[Sym(location=None, id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('deref')), args=[SymRef(location=None, id=SymbolRef('a'))]))
"""

def __init__(self, *args):
Expand All @@ -117,7 +117,7 @@ class call:
Examples
--------
>>> call("plus")(1, 1)
FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')])
FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('plus')), args=[Literal(location=None, value='1', type='int32'), Literal(location=None, value='1', type='int32')])
"""

def __init__(self, expr):
Expand Down Expand Up @@ -264,10 +264,10 @@ def shift(offset, value=None):
Examples
--------
>>> shift("i", 0)("a")
FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='i'), OffsetLiteral(value=0)]), args=[SymRef(id=SymbolRef('a'))])
FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='i'), OffsetLiteral(location=None, value=0)]), args=[SymRef(location=None, id=SymbolRef('a'))])

>>> shift("V2E")("b")
FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='V2E')]), args=[SymRef(id=SymbolRef('b'))])
FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='V2E')]), args=[SymRef(location=None, id=SymbolRef('b'))])
"""
offset = ensure_offset(offset)
args = [offset]
Expand All @@ -286,13 +286,13 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal:
Make a literal node from a value.

>>> literal_from_value(1.)
Literal(value='1.0', type='float64')
Literal(location=None, value='1.0', type='float64')
>>> literal_from_value(1)
Literal(value='1', type='int32')
Literal(location=None, value='1', type='int32')
>>> literal_from_value(2147483648)
Literal(value='2147483648', type='int64')
Literal(location=None, value='2147483648', type='int64')
>>> literal_from_value(True)
Literal(value='True', type='bool')
Literal(location=None, value='True', type='bool')
"""
if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673
raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.")
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py import eve
from gt4py.eve.visitors import PreserveLocation
from gt4py.next.iterator import ir


class CollapseListGet(eve.NodeTranslator):
class CollapseListGet(PreserveLocation, eve.NodeTranslator):
"""Simplifies expressions containing `list_get`.

Examples
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Optional

from gt4py import eve
from gt4py.eve.visitors import PreserveLocation
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference

Expand Down Expand Up @@ -48,7 +49,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t


@dataclass(frozen=True)
class CollapseTuple(eve.NodeTranslator):
class CollapseTuple(PreserveLocation, eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.

Expand Down Expand Up @@ -88,13 +89,6 @@ def apply(
node_types,
).visit(node)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
if (
self.collapse_make_tuple_tuple_get
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im


class ConstantFolding(NodeTranslator):
class ConstantFolding(PreserveLocation, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Expand Down
17 changes: 13 additions & 4 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@
import operator
import typing

from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
SymbolTableTrait,
VisitorWithSymbolTableTrait,
traits,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.eve.visitors import PreserveLocation
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda


@dataclasses.dataclass
class _NodeReplacer(NodeTranslator):
class _NodeReplacer(PreserveLocation, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
Expand Down Expand Up @@ -72,7 +79,9 @@ def _is_collectable_expr(node: ir.Node) -> bool:


@dataclasses.dataclass
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
class CollectSubexpressions(
traits.PreserveLocationWithSymbolTableTrait, VisitorWithSymbolTableTrait, NodeVisitor
):
@dataclasses.dataclass
class SubexpressionData:
#: A list of node ids with equal hash and a set of collected child subexpression ids
Expand Down Expand Up @@ -341,7 +350,7 @@ def extract_subexpression(


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(NodeTranslator):
class CommonSubexpressionElimination(PreserveLocation, NodeTranslator):
"""
Perform common subexpression elimination.

Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/transforms/eta_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.next.iterator import ir


class EtaReduction(NodeTranslator):
class EtaReduction(PreserveLocation, NodeTranslator):
"""Eta reduction: simplifies `λ(args...) → f(args...)` to `f`."""

def visit_Lambda(self, node: ir.Lambda) -> ir.Node:
Expand Down
Loading
Loading