diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py
index 05000da10d..623106f303 100644
--- a/src/gt4py/next/iterator/ir.py
+++ b/src/gt4py/next/iterator/ir.py
@@ -194,6 +194,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
GTIR_BUILTINS = {
*BUILTINS,
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
+ "cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other
}
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py
new file mode 100644
index 0000000000..54b2e0e29d
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py
@@ -0,0 +1,21 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir
+
+
+__all__ = [
+ "build_sdfg_from_gtir",
+]
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
new file mode 100644
index 0000000000..f401f46449
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
@@ -0,0 +1,360 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Optional, Protocol, TypeAlias
+
+import dace
+import dace.subsets as sbs
+
+from gt4py.next import common as gtx_common
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
+from gt4py.next.program_processors.runners.dace_fieldview import (
+ gtir_python_codegen,
+ gtir_to_tasklet,
+ utility as dace_fieldview_util,
+)
+from gt4py.next.type_system import type_specifications as ts
+
+
+if TYPE_CHECKING:
+ from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg
+
+
+IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
+TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]
+
+
+class PrimitiveTranslator(Protocol):
+ @abc.abstractmethod
+ def __call__(
+ self,
+ node: gtir.Node,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ sdfg_builder: gtir_to_sdfg.SDFGBuilder,
+ ) -> list[TemporaryData]:
+ """Creates the dataflow subgraph representing a GTIR primitive function.
+
+ This method is used by derived classes to build a specialized subgraph
+ for a specific GTIR primitive function.
+
+ Arguments:
+ node: The GTIR node describing the primitive to be lowered
+ sdfg: The SDFG where the primitive subgraph should be instantiated
+ state: The SDFG state where the result of the primitive function should be made available
+ sdfg_builder: The object responsible for visiting child nodes of the primitive node.
+
+ Returns:
+ A list of data access nodes and the associated GT4Py data type, which provide
+ access to the result of the primitive subgraph. The GT4Py data type is useful
+ in the case the returned data is an array, because the type provdes the domain
+ information (e.g. order of dimensions, dimension types).
+ """
+
+
+def _parse_arg_expr(
+ node: gtir.Expr,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ sdfg_builder: gtir_to_sdfg.SDFGBuilder,
+ domain: list[
+ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
+ ],
+) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
+ fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)
+
+ assert len(fields) == 1
+ data_node, arg_type = fields[0]
+ # require all argument nodes to be data access nodes (no symbols)
+ assert isinstance(data_node, dace.nodes.AccessNode)
+
+ if isinstance(arg_type, ts.ScalarType):
+ return gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0]))
+ else:
+ assert isinstance(arg_type, ts.FieldType)
+ indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = {
+ dim: gtir_to_tasklet.SymbolExpr(
+ dace_fieldview_util.get_map_variable(dim),
+ IteratorIndexDType,
+ )
+ for dim, _, _ in domain
+ }
+ return gtir_to_tasklet.IteratorExpr(
+ data_node,
+ arg_type.dims,
+ indices,
+ )
+
+
+def _create_temporary_field(
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ domain: list[
+ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
+ ],
+ node_type: ts.ScalarType,
+ output_desc: dace.data.Data,
+ output_field_type: ts.DataType,
+) -> tuple[dace.nodes.AccessNode, ts.FieldType]:
+ domain_dims, domain_lbs, domain_ubs = zip(*domain)
+ field_dims = list(domain_dims)
+ field_shape = [
+ # diff between upper and lower bound
+ (ub - lb)
+ for lb, ub in zip(domain_lbs, domain_ubs)
+ ]
+ field_offset: Optional[list[dace.symbolic.SymbolicType]] = None
+ if any(domain_lbs):
+ field_offset = [-lb for lb in domain_lbs]
+
+ if isinstance(output_desc, dace.data.Array):
+ # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`)
+ assert isinstance(output_field_type, ts.FieldType)
+ # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype`
+ node_type = output_field_type.dtype
+ field_dims.extend(output_field_type.dims)
+ field_shape.extend(output_desc.shape)
+ else:
+ assert isinstance(output_desc, dace.data.Scalar)
+ assert isinstance(output_field_type, ts.ScalarType)
+ # TODO: enable `assert output_field_type == node_type`, remove variable `dtype`
+ node_type = output_field_type
+
+ # allocate local temporary storage for the result field
+ temp_name, _ = sdfg.add_temp_transient(
+ field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset
+ )
+ field_node = state.add_access(temp_name)
+ field_type = ts.FieldType(field_dims, node_type)
+
+ return field_node, field_type
+
+
+def translate_as_field_op(
+ node: gtir.Node,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ sdfg_builder: gtir_to_sdfg.SDFGBuilder,
+) -> list[TemporaryData]:
+ """Generates the dataflow subgraph for the `as_field_op` builtin function."""
+ assert isinstance(node, gtir.FunCall)
+ assert cpm.is_call_to(node.fun, "as_fieldop")
+
+ fun_node = node.fun
+ assert len(fun_node.args) == 2
+ stencil_expr, domain_expr = fun_node.args
+ # expect stencil (represented as a lambda function) as first argument
+ assert isinstance(stencil_expr, gtir.Lambda)
+ # the domain of the field operator is passed as second argument
+ assert isinstance(domain_expr, gtir.FunCall)
+
+ # add local storage to compute the field operator over the given domain
+ # TODO: use type inference to determine the result type
+ node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
+ domain = dace_fieldview_util.get_domain(domain_expr)
+
+ # first visit the list of arguments and build a symbol map
+ stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]
+
+ # represent the field operator as a mapped tasklet graph, which will range over the field domain
+ taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder)
+ input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
+ assert isinstance(output_expr, gtir_to_tasklet.ValueExpr)
+ output_desc = output_expr.node.desc(sdfg)
+
+ # retrieve the tasklet node which writes the result
+ last_node = state.in_edges(output_expr.node)[0].src
+ if isinstance(last_node, dace.nodes.Tasklet):
+ # the last transient node can be deleted
+ last_node_connector = state.in_edges(output_expr.node)[0].src_conn
+ state.remove_node(output_expr.node)
+ else:
+ last_node = output_expr.node
+ last_node_connector = None
+
+ # allocate local temporary storage for the result field
+ field_node, field_type = _create_temporary_field(
+ sdfg, state, domain, node_type, output_desc, output_expr.field_type
+ )
+
+ # assume tasklet with single output
+ output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain]
+ if isinstance(output_desc, dace.data.Array):
+ # additional local dimension for neighbors
+ assert set(output_desc.offset) == {0}
+ output_subset.extend(f"0:{size}" for size in output_desc.shape)
+
+ # create map range corresponding to the field operator domain
+ map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain}
+ me, mx = sdfg_builder.add_map("field_op", state, map_ranges)
+
+ if len(input_connections) == 0:
+ # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets
+ state.add_nedge(me, last_node, dace.Memlet())
+ else:
+ for data_node, data_subset, lambda_node, lambda_connector in input_connections:
+ memlet = dace.Memlet(data=data_node.data, subset=data_subset)
+ state.add_memlet_path(
+ data_node,
+ me,
+ lambda_node,
+ dst_conn=lambda_connector,
+ memlet=memlet,
+ )
+ state.add_memlet_path(
+ last_node,
+ mx,
+ field_node,
+ src_conn=last_node_connector,
+ memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)),
+ )
+
+ return [(field_node, field_type)]
+
+
+def translate_cond(
+ node: gtir.Node,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ sdfg_builder: gtir_to_sdfg.SDFGBuilder,
+) -> list[TemporaryData]:
+ """Generates the dataflow subgraph for the `cond` builtin function."""
+ assert isinstance(node, gtir.FunCall)
+ assert cpm.is_call_to(node.fun, "cond")
+ assert len(node.args) == 0
+
+ fun_node = node.fun
+ assert len(fun_node.args) == 3
+ cond_expr, true_expr, false_expr = fun_node.args
+
+ # expect condition as first argument
+ cond = gtir_python_codegen.get_source(cond_expr)
+
+ # use current head state to terminate the dataflow, and add a entry state
+ # to connect the true/false branch states as follows:
+ #
+ # ------------
+ # === | cond | ===
+ # || ------------ ||
+ # \/ \/
+ # ------------ -------------
+ # | true | | false |
+ # ------------ -------------
+ # || ||
+ # || ------------ ||
+ # ==> | head | <==
+ # ------------
+ #
+ cond_state = sdfg.add_state_before(state, state.label + "_cond")
+ sdfg.remove_edge(sdfg.out_edges(cond_state)[0])
+
+ # expect true branch as second argument
+ true_state = sdfg.add_state(state.label + "_true_branch")
+ sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})"))
+ sdfg.add_edge(true_state, state, dace.InterstateEdge())
+
+ # and false branch as third argument
+ false_state = sdfg.add_state(state.label + "_false_branch")
+ sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")))
+ sdfg.add_edge(false_state, state, dace.InterstateEdge())
+
+ true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state)
+ false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state)
+
+ output_nodes = []
+ for true_br, false_br in zip(true_br_args, false_br_args, strict=True):
+ true_br_node, true_br_type = true_br
+ assert isinstance(true_br_node, dace.nodes.AccessNode)
+ false_br_node, _ = false_br
+ assert isinstance(false_br_node, dace.nodes.AccessNode)
+ desc = true_br_node.desc(sdfg)
+ assert false_br_node.desc(sdfg) == desc
+ data_name, _ = sdfg.add_temp_transient_like(desc)
+ output_nodes.append((state.add_access(data_name), true_br_type))
+
+ true_br_output_node = true_state.add_access(data_name)
+ true_state.add_nedge(
+ true_br_node,
+ true_br_output_node,
+ dace.Memlet.from_array(data_name, desc),
+ )
+
+ false_br_output_node = false_state.add_access(data_name)
+ false_state.add_nedge(
+ false_br_node,
+ false_br_output_node,
+ dace.Memlet.from_array(data_name, desc),
+ )
+
+ return output_nodes
+
+
+def translate_symbol_ref(
+ node: gtir.Node,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ sdfg_builder: gtir_to_sdfg.SDFGBuilder,
+) -> list[TemporaryData]:
+ """Generates the dataflow subgraph for a `ir.SymRef` node."""
+ assert isinstance(node, (gtir.Literal, gtir.SymRef))
+
+ data_type: ts.FieldType | ts.ScalarType
+ if isinstance(node, gtir.Literal):
+ sym_value = node.value
+ data_type = node.type
+ tasklet_name = "get_literal"
+ else:
+ sym_value = str(node.id)
+ data_type = sdfg_builder.get_symbol_type(sym_value)
+ tasklet_name = f"get_{sym_value}"
+
+ if isinstance(data_type, ts.FieldType):
+ # add access node to current state
+ sym_node = state.add_access(sym_value)
+
+ else:
+ # scalar symbols are passed to the SDFG as symbols: build tasklet node
+ # to write the symbol to a scalar access node
+ tasklet_node = sdfg_builder.add_tasklet(
+ tasklet_name,
+ state,
+ {},
+ {"__out"},
+ f"__out = {sym_value}",
+ )
+ temp_name, _ = sdfg.add_temp_transient((1,), dace_fieldview_util.as_dace_type(data_type))
+ sym_node = state.add_access(temp_name)
+ state.add_edge(
+ tasklet_node,
+ "__out",
+ sym_node,
+ None,
+ dace.Memlet(data=sym_node.data, subset="0"),
+ )
+
+ return [(sym_node, data_type)]
+
+
+if TYPE_CHECKING:
+ # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
+ __primitive_translators: list[PrimitiveTranslator] = [
+ translate_as_field_op,
+ translate_cond,
+ translate_symbol_ref,
+ ]
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py
new file mode 100644
index 0000000000..fcb71e4e6d
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py
@@ -0,0 +1,128 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+
+from gt4py.eve import codegen
+from gt4py.eve.codegen import FormatTemplate as as_fmt
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
+
+
+MATH_BUILTINS_MAPPING = {
+ "abs": "abs({})",
+ "sin": "math.sin({})",
+ "cos": "math.cos({})",
+ "tan": "math.tan({})",
+ "arcsin": "asin({})",
+ "arccos": "acos({})",
+ "arctan": "atan({})",
+ "sinh": "math.sinh({})",
+ "cosh": "math.cosh({})",
+ "tanh": "math.tanh({})",
+ "arcsinh": "asinh({})",
+ "arccosh": "acosh({})",
+ "arctanh": "atanh({})",
+ "sqrt": "math.sqrt({})",
+ "exp": "math.exp({})",
+ "log": "math.log({})",
+ "gamma": "tgamma({})",
+ "cbrt": "cbrt({})",
+ "isfinite": "isfinite({})",
+ "isinf": "isinf({})",
+ "isnan": "isnan({})",
+ "floor": "math.ifloor({})",
+ "ceil": "ceil({})",
+ "trunc": "trunc({})",
+ "minimum": "min({}, {})",
+ "maximum": "max({}, {})",
+ "fmod": "fmod({}, {})",
+ "power": "math.pow({}, {})",
+ "float": "dace.float64({})",
+ "float32": "dace.float32({})",
+ "float64": "dace.float64({})",
+ "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})",
+ "int32": "dace.int32({})",
+ "int64": "dace.int64({})",
+ "bool": "dace.bool_({})",
+ "plus": "({} + {})",
+ "minus": "({} - {})",
+ "multiplies": "({} * {})",
+ "divides": "({} / {})",
+ "floordiv": "({} // {})",
+ "eq": "({} == {})",
+ "not_eq": "({} != {})",
+ "less": "({} < {})",
+ "less_equal": "({} <= {})",
+ "greater": "({} > {})",
+ "greater_equal": "({} >= {})",
+ "and_": "({} & {})",
+ "or_": "({} | {})",
+ "xor_": "({} ^ {})",
+ "mod": "({} % {})",
+ "not_": "(not {})", # ~ is not bitwise in numpy
+}
+
+
+def format_builtin(bultin: str, *args: Any) -> str:
+ if bultin in MATH_BUILTINS_MAPPING:
+ fmt = MATH_BUILTINS_MAPPING[bultin]
+ else:
+ raise NotImplementedError(f"'{bultin}' not implemented.")
+ return fmt.format(*args)
+
+
+class PythonCodegen(codegen.TemplatedGenerator):
+ """Helper class to visit a symbolic expression and translate it to Python code.
+
+ The generated Python code can be use either as the body of a tasklet node or,
+ as in the case of field domain definitions, for sybolic array shape and map range.
+ """
+
+ SymRef = as_fmt("{id}")
+ Literal = as_fmt("{value}")
+
+ def _visit_deref(self, node: gtir.FunCall) -> str:
+ assert len(node.args) == 1
+ if isinstance(node.args[0], gtir.SymRef):
+ return self.visit(node.args[0])
+ raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")
+
+ def _visit_numeric_builtin(self, node: gtir.FunCall) -> str:
+ assert isinstance(node.fun, gtir.SymRef)
+ fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)]
+ args = self.visit(node.args)
+ return fmt.format(*args)
+
+ def visit_FunCall(self, node: gtir.FunCall) -> str:
+ if cpm.is_call_to(node, "deref"):
+ return self._visit_deref(node)
+ elif isinstance(node.fun, gtir.SymRef):
+ args = self.visit(node.args)
+ builtin_name = str(node.fun.id)
+ return format_builtin(builtin_name, *args)
+ raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).")
+
+
+get_source = PythonCodegen.apply
+"""
+Specialized visit method for symbolic expressions.
+
+Returns:
+ A string containing the Python code corresponding to a symbolic expression
+"""
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py
new file mode 100644
index 0000000000..7468056f0c
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py
@@ -0,0 +1,363 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""
+Contains visitors to lower GTIR to DaCe SDFG.
+
+Note: this module covers the fieldview flavour of GTIR.
+"""
+
+from __future__ import annotations
+
+import abc
+import dataclasses
+from typing import Any, Dict, List, Protocol, Sequence, Set, Tuple, Union
+
+import dace
+
+from gt4py import eve
+from gt4py.eve import concepts
+from gt4py.next import common as gtx_common
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
+from gt4py.next.program_processors.runners.dace_fieldview import (
+ gtir_builtin_translators,
+ utility as dace_fieldview_util,
+)
+from gt4py.next.type_system import type_specifications as ts
+
+
+class DataflowBuilder(Protocol):
+ """Visitor interface to build a dataflow subgraph."""
+
+ @abc.abstractmethod
+ def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension:
+ pass
+
+ @abc.abstractmethod
+ def unique_map_name(self, name: str) -> str:
+ pass
+
+ @abc.abstractmethod
+ def unique_tasklet_name(self, name: str) -> str:
+ pass
+
+ def add_map(
+ self,
+ name: str,
+ state: dace.SDFGState,
+ ndrange: Union[
+ Dict[str, Union[str, dace.subsets.Subset]],
+ List[Tuple[str, Union[str, dace.subsets.Subset]]],
+ ],
+ **kwargs: Any,
+ ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
+ """Wrapper of `dace.SDFGState.add_map` that assigns unique name."""
+ unique_name = self.unique_map_name(name)
+ return state.add_map(unique_name, ndrange, **kwargs)
+
+ def add_tasklet(
+ self,
+ name: str,
+ state: dace.SDFGState,
+ inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
+ outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
+ code: str,
+ **kwargs: Any,
+ ) -> dace.nodes.Tasklet:
+ """Wrapper of `dace.SDFGState.add_tasklet` that assigns unique name."""
+ unique_name = self.unique_tasklet_name(name)
+ return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs)
+
+
+class SDFGBuilder(DataflowBuilder, Protocol):
+ """Visitor interface available to GTIR-primitive translators."""
+
+ @abc.abstractmethod
+ def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType:
+ pass
+
+ @abc.abstractmethod
+ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder):
+ """Provides translation capability from a GTIR program to a DaCe SDFG.
+
+ This class is responsible for translation of `ir.Program`, that is the top level representation
+ of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions.
+ Each statement is translated to a taskgraph inside a separate state. Statement states are chained
+ one after the other: concurrency between states should be extracted by means of SDFG analysis.
+ The translator will extend the SDFG while preserving the property of single exit state:
+ branching is allowed within the context of one statement, but in that case the statement should
+ terminate with a join state; the join state will represent the head state for next statement,
+ from where to continue building the SDFG.
+ """
+
+ offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension]
+ symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field(
+ default_factory=lambda: {}
+ )
+ map_uids: eve.utils.UIDGenerator = dataclasses.field(
+ init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map")
+ )
+ tesklet_uids: eve.utils.UIDGenerator = dataclasses.field(
+ init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet")
+ )
+
+ def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension:
+ assert offset in self.offset_provider
+ return self.offset_provider[offset]
+
+ def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType:
+ assert symbol_name in self.symbol_types
+ return self.symbol_types[symbol_name]
+
+ def unique_map_name(self, name: str) -> str:
+ return f"{self.map_uids.sequential_id()}_{name}"
+
+ def unique_tasklet_name(self, name: str) -> str:
+ return f"{self.tesklet_uids.sequential_id()}_{name}"
+
+ def _make_array_shape_and_strides(
+ self, name: str, dims: Sequence[gtx_common.Dimension]
+ ) -> tuple[list[dace.symbol], list[dace.symbol]]:
+ """
+ Parse field dimensions and allocate symbols for array shape and strides.
+
+ For local dimensions, the size is known at compile-time and therefore
+ the corresponding array shape dimension is set to an integer literal value.
+
+ Returns:
+ Two lists of symbols, one for the shape and the other for the strides of the array.
+ """
+ dtype = dace.int32
+ neighbor_tables = dace_fieldview_util.filter_connectivities(self.offset_provider)
+ shape = [
+ (
+ neighbor_tables[dim.value].max_neighbors
+ if dim.kind == gtx_common.DimensionKind.LOCAL
+ # we reuse the same symbol for field size passed as scalar argument to the gt4py program
+ else dace.symbol(f"__{name}_size_{i}", dtype)
+ )
+ for i, dim in enumerate(dims)
+ ]
+ strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i in range(len(dims))]
+ return shape, strides
+
+ def _add_storage(self, sdfg: dace.SDFG, name: str, symbol_type: ts.DataType) -> None:
+ """
+ Add external storage (aka non-transient) for data containers passed as arguments to the SDFG.
+
+ For fields, it allocates dace arrays, while scalars are stored as SDFG symbols.
+ """
+ if isinstance(symbol_type, ts.FieldType):
+ dtype = dace_fieldview_util.as_dace_type(symbol_type.dtype)
+ # use symbolic shape, which allows to invoke the program with fields of different size;
+ # and symbolic strides, which enables decoupling the memory layout from generated code.
+ sym_shape, sym_strides = self._make_array_shape_and_strides(name, symbol_type.dims)
+ sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False)
+ elif isinstance(symbol_type, ts.ScalarType):
+ dtype = dace_fieldview_util.as_dace_type(symbol_type)
+ # scalar arguments passed to the program are represented as symbols in DaCe SDFG
+ sdfg.add_symbol(name, dtype)
+ else:
+ raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.")
+
+ # TODO: unclear why mypy complains about incompatible types
+ assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType))
+ self.symbol_types[name] = symbol_type
+
+ def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]:
+ """
+ Add temporary storage (aka transient) for data containers used as GTIR temporaries.
+
+ Assume all temporaries to be fields, therefore represented as dace arrays.
+ """
+ raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.")
+
+ def _visit_expression(
+ self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState
+ ) -> list[dace.nodes.AccessNode]:
+ """
+ Specialized visit method for fieldview expressions.
+
+ This method represents the entry point to visit `ir.Stmt` expressions.
+ As such, it must preserve the property of single exit state in the SDFG.
+
+ Returns a list of array nodes containing the result fields.
+
+ TODO: Do we need to return the GT4Py `FieldType`/`ScalarType`? It is needed
+ in case the transient arrays containing the expression result are not guaranteed
+ to have the same memory layout as the target array.
+ """
+ results: list[gtir_builtin_translators.TemporaryData] = self.visit(
+ node, sdfg=sdfg, head_state=head_state
+ )
+
+ field_nodes = []
+ for node, _ in results:
+ assert isinstance(node, dace.nodes.AccessNode)
+ field_nodes.append(node)
+
+ # sanity check: each statement should preserve the property of single exit state (aka head state),
+ # i.e. eventually only introduce internal branches, and keep the same head state
+ sink_states = sdfg.sink_nodes()
+ assert len(sink_states) == 1
+ assert sink_states[0] == head_state
+
+ return field_nodes
+
+ def visit_Program(self, node: gtir.Program) -> dace.SDFG:
+ """Translates `ir.Program` to `dace.SDFG`.
+
+ First, it will allocate field and scalar storage for global data. The storage
+ represents global data, available everywhere in the SDFG, either containing
+ external data (aka non-transient data) or temporary data (aka transient data).
+ The temporary data is global, therefore available everywhere in the SDFG
+ but not outside. Then, all statements are translated, one after the other.
+ """
+ if node.function_definitions:
+ raise NotImplementedError("Functions expected to be inlined as lambda calls.")
+
+ sdfg = dace.SDFG(node.id)
+ sdfg.debuginfo = dace_fieldview_util.debug_info(node, default=sdfg.debuginfo)
+ entry_state = sdfg.add_state("program_entry", is_start_block=True)
+
+ # declarations of temporaries result in transient array definitions in the SDFG
+ if node.declarations:
+ temp_symbols: dict[str, str] = {}
+ for decl in node.declarations:
+ temp_symbols |= self._add_storage_for_temporary(decl)
+
+ # define symbols for shape and offsets of temporary arrays as interstate edge symbols
+ head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols)
+ else:
+ head_state = entry_state
+
+ # add non-transient arrays and/or SDFG symbols for the program arguments
+ for param in node.params:
+ assert isinstance(param.type, ts.DataType)
+ self._add_storage(sdfg, str(param.id), param.type)
+
+ # visit one statement at a time and expand the SDFG from the current head state
+ for i, stmt in enumerate(node.body):
+ # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy
+ head_state = sdfg.add_state_after(head_state, f"stmt_{i}")
+ head_state._debuginfo = dace_fieldview_util.debug_info(stmt, default=sdfg.debuginfo)
+ self.visit(stmt, sdfg=sdfg, state=head_state)
+
+ sdfg.validate()
+ return sdfg
+
+ def visit_SetAt(self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None:
+ """Visits a `SetAt` statement expression and writes the local result to some external storage.
+
+ Each statement expression results in some sort of dataflow gragh writing to temporary storage.
+ The translation of `SetAt` ensures that the result is written back to the target external storage.
+ """
+
+ expr_nodes = self._visit_expression(stmt.expr, sdfg, state)
+
+ # the target expression could be a `SymRef` to an output node or a `make_tuple` expression
+ # in case the statement returns more than one field
+ target_nodes = self._visit_expression(stmt.target, sdfg, state)
+
+ # convert domain expression to dictionary to ease access to dimension boundaries
+ domain = dace_fieldview_util.get_domain_ranges(stmt.domain)
+
+ for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True):
+ target_array = sdfg.arrays[target_node.data]
+ assert not target_array.transient
+ target_symbol_type = self.symbol_types[target_node.data]
+
+ if isinstance(target_symbol_type, ts.FieldType):
+ subset = ",".join(
+ f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims
+ )
+ else:
+ assert len(domain) == 0
+ subset = "0"
+
+ state.add_nedge(
+ expr_node,
+ target_node,
+ dace.Memlet(data=target_node.data, subset=subset),
+ )
+
+ def visit_FunCall(
+ self,
+ node: gtir.FunCall,
+ sdfg: dace.SDFG,
+ head_state: dace.SDFGState,
+ ) -> list[gtir_builtin_translators.TemporaryData]:
+ # use specialized dataflow builder classes for each builtin function
+ if cpm.is_call_to(node.fun, "as_fieldop"):
+ return gtir_builtin_translators.translate_as_field_op(node, sdfg, head_state, self)
+ elif cpm.is_call_to(node.fun, "cond"):
+ return gtir_builtin_translators.translate_cond(node, sdfg, head_state, self)
+ else:
+ raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).")
+
+ def visit_Lambda(self, node: gtir.Lambda) -> Any:
+ """
+ This visitor class should never encounter `itir.Lambda` expressions
+ because a lambda represents a stencil, which operates from iterator to values.
+ In fieldview, lambdas should only be arguments to field operators (`as_field_op`).
+ """
+ raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.")
+
+ def visit_Literal(
+ self,
+ node: gtir.Literal,
+ sdfg: dace.SDFG,
+ head_state: dace.SDFGState,
+ ) -> list[gtir_builtin_translators.TemporaryData]:
+ return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)
+
+ def visit_SymRef(
+ self,
+ node: gtir.SymRef,
+ sdfg: dace.SDFG,
+ head_state: dace.SDFGState,
+ ) -> list[gtir_builtin_translators.TemporaryData]:
+ return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)
+
+
+def build_sdfg_from_gtir(
+ program: gtir.Program,
+ offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension],
+) -> dace.SDFG:
+ """
+ Receives a GTIR program and lowers it to a DaCe SDFG.
+
+ The lowering to SDFG requires that the program node is type-annotated, therefore this function
+ runs type ineference as first step.
+ As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form.
+
+ Arguments:
+ program: The GTIR program node to be lowered to SDFG
+ offset_provider: The definitions of offset providers used by the program node
+
+ Returns:
+ An SDFG in the DaCe canonical form (simplified)
+ """
+ sdfg_genenerator = GTIRToSDFG(offset_provider)
+ # TODO: run type inference on the `program` node before passing it to `GTIRToSDFG`
+ sdfg = sdfg_genenerator.visit(program)
+ assert isinstance(sdfg, dace.SDFG)
+
+ sdfg.simplify()
+ return sdfg
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py
new file mode 100644
index 0000000000..485db254ea
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py
@@ -0,0 +1,273 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
+
+import dace
+import dace.subsets as sbs
+
+from gt4py import eve
+from gt4py.next import common as gtx_common
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
+from gt4py.next.program_processors.runners.dace_fieldview import (
+ gtir_python_codegen,
+ gtir_to_sdfg,
+ utility as dace_fieldview_util,
+)
+from gt4py.next.type_system import type_specifications as ts
+
+
+@dataclasses.dataclass(frozen=True)
+class MemletExpr:
+ """Scalar or array data access thorugh a memlet."""
+
+ node: dace.nodes.AccessNode
+ subset: sbs.Indices | sbs.Range
+
+
+@dataclasses.dataclass(frozen=True)
+class SymbolExpr:
+ """Any symbolic expression that is constant in the context of current SDFG."""
+
+ value: dace.symbolic.SymExpr
+ dtype: dace.typeclass
+
+
+@dataclasses.dataclass(frozen=True)
+class ValueExpr:
+ """Result of the computation implemented by a tasklet node."""
+
+ node: dace.nodes.AccessNode
+ field_type: ts.FieldType | ts.ScalarType
+
+
+# Define alias for the elements needed to setup input connections to a map scope
+InputConnection: TypeAlias = tuple[
+ dace.nodes.AccessNode,
+ sbs.Range,
+ dace.nodes.Node,
+ Optional[str],
+]
+
+IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr
+
+
+@dataclasses.dataclass(frozen=True)
+class IteratorExpr:
+ """Iterator for field access to be consumed by `deref` or `shift` builtin functions."""
+
+ field: dace.nodes.AccessNode
+ dimensions: list[gtx_common.Dimension]
+ indices: dict[gtx_common.Dimension, IteratorIndexExpr]
+
+
+class LambdaToTasklet(eve.NodeVisitor):
+ """Translates an `ir.Lambda` expression to a dataflow graph.
+
+ Lambda functions should only be encountered as argument to the `as_field_op`
+ builtin function, therefore the dataflow graph generated here typically
+ represents the stencil function of a field operator.
+ """
+
+ sdfg: dace.SDFG
+ state: dace.SDFGState
+ subgraph_builder: gtir_to_sdfg.DataflowBuilder
+ input_connections: list[InputConnection]
+ symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr]
+
+ def __init__(
+ self,
+ sdfg: dace.SDFG,
+ state: dace.SDFGState,
+ subgraph_builder: gtir_to_sdfg.DataflowBuilder,
+ ):
+ self.sdfg = sdfg
+ self.state = state
+ self.subgraph_builder = subgraph_builder
+ self.input_connections = []
+ self.symbol_map = {}
+
+ def _add_entry_memlet_path(
+ self,
+ src: dace.nodes.AccessNode,
+ src_subset: sbs.Range,
+ dst_node: dace.nodes.Node,
+ dst_conn: Optional[str] = None,
+ ) -> None:
+ self.input_connections.append((src, src_subset, dst_node, dst_conn))
+
+ def _add_map(
+ self,
+ name: str,
+ ndrange: Union[
+ Dict[str, Union[str, dace.subsets.Subset]],
+ List[Tuple[str, Union[str, dace.subsets.Subset]]],
+ ],
+ **kwargs: Any,
+ ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
+ """Helper method to add a map with unique ame in current state."""
+ return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs)
+
+ def _add_tasklet(
+ self,
+ name: str,
+ inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
+ outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
+ code: str,
+ **kwargs: Any,
+ ) -> dace.nodes.Tasklet:
+ """Helper method to add a tasklet with unique ame in current state."""
+ return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs)
+
+ def _get_tasklet_result(
+ self,
+ dtype: dace.typeclass,
+ src_node: dace.nodes.Tasklet,
+ src_connector: str,
+ ) -> ValueExpr:
+ temp_name = self.sdfg.temp_data_name()
+ self.sdfg.add_scalar(temp_name, dtype, transient=True)
+ data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype()))
+ temp_node = self.state.add_access(temp_name)
+ self.state.add_edge(
+ src_node,
+ src_connector,
+ temp_node,
+ None,
+ dace.Memlet(data=temp_name, subset="0"),
+ )
+ return ValueExpr(temp_node, data_type)
+
+ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
+ assert len(node.args) == 1
+ it = self.visit(node.args[0])
+
+ if isinstance(it, IteratorExpr):
+ field_desc = it.field.desc(self.sdfg)
+ assert len(field_desc.shape) == len(it.dimensions)
+ if all(isinstance(index, SymbolExpr) for index in it.indices.values()):
+ # when all indices are symblic expressions, we can perform direct field access through a memlet
+ field_subset = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr]
+ return MemletExpr(it.field, field_subset)
+
+ else:
+ raise NotImplementedError
+
+ else:
+ assert isinstance(it, MemletExpr)
+ return it
+
+ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr:
+ if cpm.is_call_to(node, "deref"):
+ return self._visit_deref(node)
+
+ else:
+ assert isinstance(node.fun, gtir.SymRef)
+
+ node_internals = []
+ node_connections: dict[str, MemletExpr | ValueExpr] = {}
+ for i, arg in enumerate(node.args):
+ arg_expr = self.visit(arg)
+ if isinstance(arg_expr, MemletExpr | ValueExpr):
+ # the argument value is the result of a tasklet node or direct field access
+ connector = f"__inp_{i}"
+ node_connections[connector] = arg_expr
+ node_internals.append(connector)
+ else:
+ assert isinstance(arg_expr, SymbolExpr)
+ # use the argument value without adding any connector
+ node_internals.append(arg_expr.value)
+
+ # use tasklet connectors as expression arguments
+ builtin_name = str(node.fun.id)
+ code = gtir_python_codegen.format_builtin(builtin_name, *node_internals)
+
+ out_connector = "result"
+ tasklet_node = self._add_tasklet(
+ builtin_name,
+ set(node_connections.keys()),
+ {out_connector},
+ "{} = {}".format(out_connector, code),
+ )
+
+ for connector, arg_expr in node_connections.items():
+ if isinstance(arg_expr, ValueExpr):
+ self.state.add_edge(
+ arg_expr.node,
+ None,
+ tasklet_node,
+ connector,
+ dace.Memlet(data=arg_expr.node.data, subset="0"),
+ )
+ else:
+ self._add_entry_memlet_path(
+ arg_expr.node,
+ arg_expr.subset,
+ tasklet_node,
+ connector,
+ )
+
+ # TODO: use type inference to determine the result type
+ if len(node_connections) == 1:
+ dtype = None
+ for conn_name in ["__inp_0", "__inp_1"]:
+ if conn_name in node_connections:
+ dtype = node_connections[conn_name].node.desc(self.sdfg).dtype
+ break
+ if dtype is None:
+ raise ValueError("Failed to determine the type")
+ else:
+ node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
+ dtype = dace_fieldview_util.as_dace_type(node_type)
+
+ return self._get_tasklet_result(dtype, tasklet_node, "result")
+
+ def visit_Lambda(
+ self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr]
+ ) -> tuple[list[InputConnection], ValueExpr]:
+ for p, arg in zip(node.params, args, strict=True):
+ self.symbol_map[str(p.id)] = arg
+ output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr)
+ if isinstance(output_expr, ValueExpr):
+ return self.input_connections, output_expr
+
+ if isinstance(output_expr, MemletExpr):
+ # special case where the field operator is simply copying data from source to destination node
+ output_dtype = output_expr.node.desc(self.sdfg).dtype
+ tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp")
+ self._add_entry_memlet_path(
+ output_expr.node,
+ output_expr.subset,
+ tasklet_node,
+ "__inp",
+ )
+ else:
+ # even simpler case, where a constant value is written to destination node
+ output_dtype = output_expr.dtype
+ tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}")
+ return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out")
+
+ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr:
+ dtype = dace_fieldview_util.as_dace_type(node.type)
+ return SymbolExpr(node.value, dtype)
+
+ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr:
+ param = str(node.id)
+ assert param in self.symbol_map
+ return self.symbol_map[param]
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py
new file mode 100644
index 0000000000..34e95506d8
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py
@@ -0,0 +1,128 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+from typing import Any, Mapping, Optional
+
+import dace
+
+from gt4py.next import common as gtx_common
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
+from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen
+from gt4py.next.type_system import type_specifications as ts
+
+
+def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
+ """Converts GT4Py scalar type to corresponding DaCe type."""
+ match type_.kind:
+ case ts.ScalarKind.BOOL:
+ return dace.bool_
+ case ts.ScalarKind.INT32:
+ return dace.int32
+ case ts.ScalarKind.INT64:
+ return dace.int64
+ case ts.ScalarKind.FLOAT32:
+ return dace.float32
+ case ts.ScalarKind.FLOAT64:
+ return dace.float64
+ case _:
+ raise ValueError(f"Scalar type '{type_}' not supported.")
+
+
+def as_scalar_type(typestr: str) -> ts.ScalarType:
+ """Obtain GT4Py scalar type from generic numpy string representation."""
+ try:
+ kind = getattr(ts.ScalarKind, typestr.upper())
+ except AttributeError as ex:
+ raise ValueError(f"Data type {typestr} not supported.") from ex
+ return ts.ScalarType(kind)
+
+
+def debug_info(
+ node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None
+) -> Optional[dace.dtypes.DebugInfo]:
+ location = node.location
+ if location:
+ return dace.dtypes.DebugInfo(
+ start_line=location.line,
+ start_column=location.column if location.column else 0,
+ end_line=location.end_line if location.end_line else -1,
+ end_column=location.end_column if location.end_column else 0,
+ filename=location.filename,
+ )
+ return default
+
+
+def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, gtx_common.Connectivity]:
+ """
+ Filter offset providers of type `Connectivity`.
+
+ In other words, filter out the cartesian offset providers.
+ Returns a new dictionary containing only `Connectivity` values.
+ """
+ return {
+ offset: table
+ for offset, table in offset_provider.items()
+ if isinstance(table, gtx_common.Connectivity)
+ }
+
+
+def get_domain(
+ node: gtir.Expr,
+) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]:
+ """
+ Specialized visit method for domain expressions.
+
+ Returns for each domain dimension the corresponding range.
+
+ TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea
+ would be to cache the results of lowering here (e.g. using `functools.lru_cache`)
+ """
+ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))
+
+ domain = []
+ for named_range in node.args:
+ assert cpm.is_call_to(named_range, "named_range")
+ assert len(named_range.args) == 3
+ axis = named_range.args[0]
+ assert isinstance(axis, gtir.AxisLiteral)
+ bounds = [
+ dace.symbolic.SymExpr(gtir_python_codegen.get_source(arg))
+ for arg in named_range.args[1:3]
+ ]
+ dim = gtx_common.Dimension(axis.value, axis.kind)
+ domain.append((dim, bounds[0], bounds[1]))
+
+ return domain
+
+
+def get_domain_ranges(
+ node: gtir.Expr,
+) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]:
+ """
+ Returns domain represented in dictionary form.
+ """
+ domain = get_domain(node)
+
+ return {dim: (lb, ub) for dim, lb, ub in domain}
+
+
+def get_map_variable(dim: gtx_common.Dimension) -> str:
+ """
+ Format map variable name based on the naming convention for application-specific SDFG transformations.
+ """
+ suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else ""
+ return f"i_{dim.value}_gtx_{dim.kind}{suffix}"
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py
new file mode 100644
index 0000000000..d674e71f0b
--- /dev/null
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py
@@ -0,0 +1,384 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""
+Test that ITIR can be lowered to SDFG.
+
+Note: this test module covers the fieldview flavour of ITIR.
+"""
+
+import copy
+from gt4py.next import common as gtx_common
+from gt4py.next.iterator import ir as gtir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.program_processors.runners import dace_fieldview as dace_backend
+from gt4py.next.type_system import type_specifications as ts
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import IDim
+
+import numpy as np
+import pytest
+
+dace = pytest.importorskip("dace")
+
+
+N = 10
+IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))
+SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32)
+FSYMBOLS = dict(
+ __w_size_0=N,
+ __w_stride_0=1,
+ __x_size_0=N,
+ __x_stride_0=1,
+ __y_size_0=N,
+ __y_stride_0=1,
+ __z_size_0=N,
+ __z_stride_0=1,
+ size=N,
+)
+
+
+def test_gtir_copy():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ testee = gtir.Program(
+ id="gtir_copy",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="y", type=IFTYPE),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a")(im.deref("a")),
+ domain,
+ )
+ )("x"),
+ domain=domain,
+ target=gtir.SymRef(id="y"),
+ )
+ ],
+ )
+
+ a = np.random.rand(N)
+ b = np.empty_like(a)
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ sdfg(x=a, y=b, **FSYMBOLS)
+ assert np.allclose(a, b)
+
+
+def test_gtir_update():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ stencil1 = im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a")(im.plus(im.deref("a"), 1.0)),
+ domain,
+ )
+ )("x")
+ stencil2 = im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", 1.0)
+
+ for i, stencil in enumerate([stencil1, stencil2]):
+ testee = gtir.Program(
+ id=f"gtir_update_{i}",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=stencil,
+ domain=domain,
+ target=gtir.SymRef(id="x"),
+ )
+ ],
+ )
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ a = np.random.rand(N)
+ ref = a + 1.0
+
+ sdfg(x=a, **FSYMBOLS)
+ assert np.allclose(a, ref)
+
+
+def test_gtir_sum2():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ testee = gtir.Program(
+ id="sum_2fields",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="y", type=IFTYPE),
+ gtir.Sym(id="z", type=IFTYPE),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", "y"),
+ domain=domain,
+ target=gtir.SymRef(id="z"),
+ )
+ ],
+ )
+
+ a = np.random.rand(N)
+ b = np.random.rand(N)
+ c = np.empty_like(a)
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ sdfg(x=a, y=b, z=c, **FSYMBOLS)
+ assert np.allclose(c, (a + b))
+
+
+def test_gtir_sum2_sym():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ testee = gtir.Program(
+ id="sum_2fields_sym",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="z", type=IFTYPE),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", "x"),
+ domain=domain,
+ target=gtir.SymRef(id="z"),
+ )
+ ],
+ )
+
+ a = np.random.rand(N)
+ b = np.empty_like(a)
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ sdfg(x=a, z=b, **FSYMBOLS)
+ assert np.allclose(b, (a + a))
+
+
+def test_gtir_sum3():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ stencil1 = im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )(
+ "x",
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("y", "w"),
+ )
+ stencil2 = im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b", "c")(
+ im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c")))
+ ),
+ domain,
+ )
+ )("x", "y", "w")
+
+ a = np.random.rand(N)
+ b = np.random.rand(N)
+ c = np.random.rand(N)
+
+ for i, stencil in enumerate([stencil1, stencil2]):
+ testee = gtir.Program(
+ id=f"sum_3fields_{i}",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="y", type=IFTYPE),
+ gtir.Sym(id="w", type=IFTYPE),
+ gtir.Sym(id="z", type=IFTYPE),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=stencil,
+ domain=domain,
+ target=gtir.SymRef(id="z"),
+ )
+ ],
+ )
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ d = np.empty_like(a)
+
+ sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS)
+ assert np.allclose(d, (a + b + c))
+
+
+def test_gtir_cond():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ testee = gtir.Program(
+ id="cond_2sums",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="y", type=IFTYPE),
+ gtir.Sym(id="w", type=IFTYPE),
+ gtir.Sym(id="z", type=IFTYPE),
+ gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)),
+ gtir.Sym(id="scalar", type=ts.ScalarType(ts.ScalarKind.FLOAT64)),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )(
+ "x",
+ im.call(
+ im.call("cond")(
+ im.deref("pred"),
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("y", "scalar"),
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("w", "scalar"),
+ )
+ )(),
+ ),
+ domain=domain,
+ target=gtir.SymRef(id="z"),
+ )
+ ],
+ )
+
+ a = np.random.rand(N)
+ b = np.random.rand(N)
+ c = np.random.rand(N)
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ for s in [False, True]:
+ d = np.empty_like(a)
+ sdfg(pred=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS)
+ assert np.allclose(d, (a + b + 1) if s else (a + c + 1))
+
+
+def test_gtir_cond_nested():
+ domain = im.call("cartesian_domain")(
+ im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
+ )
+ testee = gtir.Program(
+ id="cond_nested",
+ function_definitions=[],
+ params=[
+ gtir.Sym(id="x", type=IFTYPE),
+ gtir.Sym(id="z", type=IFTYPE),
+ gtir.Sym(id="pred_1", type=ts.ScalarType(ts.ScalarKind.BOOL)),
+ gtir.Sym(id="pred_2", type=ts.ScalarType(ts.ScalarKind.BOOL)),
+ gtir.Sym(id="size", type=SIZE_TYPE),
+ ],
+ declarations=[],
+ body=[
+ gtir.SetAt(
+ expr=im.call(
+ im.call("cond")(
+ im.deref("pred_1"),
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", 1),
+ im.call(
+ im.call("cond")(
+ im.deref("pred_2"),
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", 2),
+ im.call(
+ im.call("as_fieldop")(
+ im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
+ domain,
+ )
+ )("x", 3),
+ )
+ )(),
+ )
+ )(),
+ domain=domain,
+ target=gtir.SymRef(id="z"),
+ )
+ ],
+ )
+
+ a = np.random.rand(N)
+
+ sdfg = dace_backend.build_sdfg_from_gtir(testee, {})
+
+ for s1 in [False, True]:
+ for s2 in [False, True]:
+ b = np.empty_like(a)
+ sdfg(pred_1=np.bool_(s1), pred_2=np.bool_(s2), x=a, z=b, **FSYMBOLS)
+ assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0))