Skip to content

Commit

Permalink
Merge branch 'main' into decouple_inferences
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N authored Feb 17, 2025
2 parents 2a9a25a + 3f85165 commit 9e48581
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 46 deletions.
4 changes: 0 additions & 4 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,6 @@ def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, la
sdfg.add_state(gtir_pipeline.gtir.name)
return sdfg

for array in sdfg.arrays.values():
if array.transient:
array.lifetime = dace.AllocationLifetime.Persistent

sdfg.simplify(validate=False)

_set_expansion_orders(sdfg)
Expand Down
15 changes: 15 additions & 0 deletions src/gt4py/cartesian/gtc/dace/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause


from typing import Final


# StencilComputation in/out connector prefixes
CONNECTOR_PREFIX_IN: Final = "__in_"
CONNECTOR_PREFIX_OUT: Final = "__out_"
17 changes: 11 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sympy

from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder
from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder

Expand Down Expand Up @@ -77,11 +78,11 @@ def _fix_context(
"""
# change connector names
for in_edge in parent_state.in_edges(node):
assert in_edge.dst_conn.startswith("__in_")
in_edge.dst_conn = in_edge.dst_conn[len("__in_") :]
assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN)
in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)
for out_edge in parent_state.out_edges(node):
assert out_edge.src_conn.startswith("__out_")
out_edge.src_conn = out_edge.src_conn[len("__out_") :]
assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT)
out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)

# union input and output subsets
subsets = {}
Expand Down Expand Up @@ -125,9 +126,13 @@ def _get_parent_arrays(
) -> Dict[str, dace.data.Data]:
parent_arrays: Dict[str, dace.data.Data] = {}
for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None):
parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[
edge.data.data
]
for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None):
parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[
edge.data.data
]
return parent_arrays

@staticmethod
Expand Down
10 changes: 4 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset


class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait):
Expand Down Expand Up @@ -268,13 +267,13 @@ def visit_ComputationState(
for memlet in computation.read_memlets:
if memlet.field not in read_acc_and_conn:
read_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
for memlet in computation.write_memlets:
if memlet.field not in write_acc_and_conn:
write_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
node_ctx = StencilComputationSDFGBuilder.NodeContext(
Expand All @@ -298,7 +297,7 @@ def visit_FieldDecl(
dtype=data_type_to_dace_typeclass(node.dtype),
storage=node.storage.to_dace_storage(),
transient=node.name not in non_transients,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(node),
)

def visit_SymbolDecl(
Expand Down Expand Up @@ -343,7 +342,6 @@ def visit_NestedSDFG(
inputs=node.input_connectors,
outputs=node.output_connectors,
symbol_mapping=symbol_mapping,
debuginfo=dace.DebugInfo(0),
)
self.visit(
node.read_memlets,
Expand Down
14 changes: 0 additions & 14 deletions src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

from typing import TYPE_CHECKING, List

import dace
import dace.data
import dace.library
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
Expand All @@ -25,15 +20,6 @@
from gt4py.cartesian.gtc.dace.nodes import StencilComputation


def get_dace_debuginfo(node: common.LocNode):
if node.loc is not None:
return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)
else:
return dace.dtypes.DebugInfo(0)


class HorizontalIntervalRemover(eve.NodeTranslator):
def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis):
mask_attrs = dict(i=node.i, j=node.j)
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter
from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection

from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo
from .expansion_specification import ExpansionItem, make_expansion_order


def _set_expansion_order(
node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]]
Expand Down
32 changes: 19 additions & 13 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
get_dace_debuginfo,
make_dace_subset,
)
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import (
AccessCollector,
Expand Down Expand Up @@ -93,9 +98,7 @@ def _make_dace_subset(self, local_access_info, field):
global_access_info, local_access_info, self.decls[field].data_dims
)

def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs
):
def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext):
declarations = {
acc.name: ctx.decls[acc.name]
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
Expand All @@ -117,22 +120,24 @@ def visit_VerticalLoop(
access_collection = AccessCollector.apply(node)

for field in access_collection.read_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_in_connector("__in_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = f"{CONNECTOR_PREFIX_IN}{field}"
library_node.add_in_connector(connector_name)
subset = ctx.make_input_dace_subset(node, field)
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_out_connector("__out_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = f"{CONNECTOR_PREFIX_OUT}{field}"
library_node.add_out_connector(connector_name)
subset = ctx.make_output_dace_subset(node, field)
state.add_edge(
library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset)
library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset)
)

def visit_Stencil(self, node: oir.Stencil, **kwargs):
def visit_Stencil(self, node: oir.Stencil):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
for param in node.params:
if isinstance(param, oir.FieldDecl):
Expand All @@ -148,7 +153,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(param.dtype),
transient=False,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(param),
)
else:
ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
Expand All @@ -166,7 +171,8 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(decl.dtype),
transient=True,
debuginfo=dace.DebugInfo(0),
lifetime=dace.AllocationLifetime.Persistent,
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
ctx.sdfg.validate()
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo:
if node.loc is None:
return dace.dtypes.DebugInfo(0)

return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)


def array_dimensions(array: dace.data.Array):
dims = [
any(
Expand Down

0 comments on commit 9e48581

Please sign in to comment.