diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 05000da10d..623106f303 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -194,6 +194,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib GTIR_BUILTINS = { *BUILTINS, "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) + "cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py new file mode 100644 index 0000000000..54b2e0e29d --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -0,0 +1,21 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir + + +__all__ = [ + "build_sdfg_from_gtir", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py new file mode 100644 index 0000000000..f401f46449 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -0,0 +1,360 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Optional, Protocol, TypeAlias + +import dace +import dace.subsets as sbs + +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_python_codegen, + gtir_to_tasklet, + utility as dace_fieldview_util, +) +from gt4py.next.type_system import type_specifications as ts + + +if TYPE_CHECKING: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg + + +IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes +TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] + + +class PrimitiveTranslator(Protocol): + @abc.abstractmethod + def __call__( + self, + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + ) -> list[TemporaryData]: + """Creates the dataflow subgraph representing a GTIR primitive function. + + This method is used by derived classes to build a specialized subgraph + for a specific GTIR primitive function. + + Arguments: + node: The GTIR node describing the primitive to be lowered + sdfg: The SDFG where the primitive subgraph should be instantiated + state: The SDFG state where the result of the primitive function should be made available + sdfg_builder: The object responsible for visiting child nodes of the primitive node. + + Returns: + A list of data access nodes and the associated GT4Py data type, which provide + access to the result of the primitive subgraph. The GT4Py data type is useful + in the case the returned data is an array, because the type provdes the domain + information (e.g. order of dimensions, dimension types). 
def _parse_arg_expr(
    node: gtir.Expr,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_to_sdfg.SDFGBuilder,
    domain: list[
        tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
    ],
) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
    """Lower one argument of a field operator and wrap it for the stencil lowering.

    The argument expression is visited through `sdfg_builder`; the visit is
    expected to yield exactly one `(node, type)` pair whose node is an access node.

    Arguments:
        node: The argument expression to be lowered.
        sdfg: The SDFG forwarded to the visitor.
        state: The state forwarded to the visitor as `head_state`.
        sdfg_builder: The visitor used to lower `node`.
        domain: The field operator domain as a list of `(dimension, lower bound, upper bound)`;
            only the dimensions are used here.

    Returns:
        For a scalar argument, a `MemletExpr` that reads index 0 of the data node.
        For a field argument, an `IteratorExpr` on the data node where each domain
        dimension is indexed by the map variable of that dimension.
    """
    visited: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)
    assert len(visited) == 1

    arg_node, arg_type = visited[0]
    # symbols are not accepted as arguments: they have to be exposed through access nodes
    assert isinstance(arg_node, dace.nodes.AccessNode)

    if isinstance(arg_type, ts.ScalarType):
        return gtir_to_tasklet.MemletExpr(arg_node, sbs.Indices([0]))

    assert isinstance(arg_type, ts.FieldType)
    iterator_indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = {}
    for dim, _, _ in domain:
        map_variable = dace_fieldview_util.get_map_variable(dim)
        iterator_indices[dim] = gtir_to_tasklet.SymbolExpr(map_variable, IteratorIndexDType)
    return gtir_to_tasklet.IteratorExpr(arg_node, arg_type.dims, iterator_indices)


def _create_temporary_field(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    domain: list[
        tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
    ],
    node_type: ts.ScalarType,
    output_desc: dace.data.Data,
    output_field_type: ts.DataType,
) -> tuple[dace.nodes.AccessNode, ts.FieldType]:
    """Allocate the transient array that receives the result of a field operator.

    The array has one dimension per domain entry, sized `upper bound - lower bound`.
    When the data written by the stencil is itself an array, its dimensions and
    shape are appended after the domain dimensions.

    Arguments:
        sdfg: The SDFG where the transient is allocated.
        state: The state where the access node of the transient is added.
        domain: List of `(dimension, lower bound, upper bound)`; must not be empty,
            since it is unpacked into three sequences.
        node_type: Element type of the result; currently always replaced by the type
            carried in `output_field_type` (see the TODO notes below).
        output_desc: Data descriptor of the value produced by the stencil.
        output_field_type: GT4Py type of the value produced by the stencil.

    Returns:
        The access node of the new transient and the GT4Py field type describing it.
    """
    dims, lower_bounds, upper_bounds = zip(*domain)
    result_dims = list(dims)
    # extent of each domain dimension
    result_shape = [(upper - lower) for lower, upper in zip(lower_bounds, upper_bounds)]
    # an array offset is only set when at least one lower bound is truthy
    result_offset: Optional[list[dace.symbolic.SymbolicType]] = (
        [-lower for lower in lower_bounds] if any(lower_bounds) else None
    )

    if isinstance(output_desc, dace.data.Array):
        # the stencil produces an array (e.g. `neighbors`): append its local dimensions
        assert isinstance(output_field_type, ts.FieldType)
        # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype`
        node_type = output_field_type.dtype
        result_dims += output_field_type.dims
        result_shape += output_desc.shape
    else:
        assert isinstance(output_desc, dace.data.Scalar)
        assert isinstance(output_field_type, ts.ScalarType)
        # TODO: enable `assert output_field_type == node_type`, remove variable `dtype`
        node_type = output_field_type

    # local temporary storage for the result field
    storage_dtype = dace_fieldview_util.as_dace_type(node_type)
    temp_name, _ = sdfg.add_temp_transient(result_shape, storage_dtype, offset=result_offset)
    return state.add_access(temp_name), ts.FieldType(result_dims, node_type)
gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder) + input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + output_desc = output_expr.node.desc(sdfg) + + # retrieve the tasklet node which writes the result + last_node = state.in_edges(output_expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = state.in_edges(output_expr.node)[0].src_conn + state.remove_node(output_expr.node) + else: + last_node = output_expr.node + last_node_connector = None + + # allocate local temporary storage for the result field + field_node, field_type = _create_temporary_field( + sdfg, state, domain, node_type, output_desc, output_expr.field_type + ) + + # assume tasklet with single output + output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] + if isinstance(output_desc, dace.data.Array): + # additional local dimension for neighbors + assert set(output_desc.offset) == {0} + output_subset.extend(f"0:{size}" for size in output_desc.shape) + + # create map range corresponding to the field operator domain + map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} + me, mx = sdfg_builder.add_map("field_op", state, map_ranges) + + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + state.add_nedge(me, last_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) + state.add_memlet_path( + last_node, + mx, + field_node, + src_conn=last_node_connector, + memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), + ) + + return 
def translate_cond(
    node: gtir.Node,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
    """Generates the dataflow subgraph for the `cond` builtin function.

    The node is expected to be a call without arguments whose callee is itself
    the call `cond(condition, true_expr, false_expr)`. The condition is rendered
    as Python source and used on the inter-state edges that select one of two
    new branch states; both branch states lead to `state`.

    Arguments:
        node: The GTIR node of the `cond` call.
        sdfg: The SDFG where the new states and transients are added.
        state: The state where the selected result is made available.
        sdfg_builder: The visitor used to lower the two branch expressions.

    Returns:
        One `(access node, type)` pair per branch result. The access node lives
        in `state` and the type is the one reported by the true branch.
    """
    assert isinstance(node, gtir.FunCall)
    assert cpm.is_call_to(node.fun, "cond")
    assert len(node.args) == 0

    fun_node = node.fun
    assert len(fun_node.args) == 3
    cond_expr, true_expr, false_expr = fun_node.args

    # expect condition as first argument
    cond = gtir_python_codegen.get_source(cond_expr)

    # use current head state to terminate the dataflow, and add an entry state
    # to connect the true/false branch states as follows:
    #
    #               ------------
    #           === |   cond   | ===
    #          ||   ------------   ||
    #          \/                  \/
    #     ------------       -------------
    #     |   true   |       |   false   |
    #     ------------       -------------
    #          ||                  ||
    #          ||   ------------   ||
    #           ==> |   head   | <==
    #               ------------
    #
    cond_state = sdfg.add_state_before(state, state.label + "_cond")
    # drop the first outgoing edge of the new state (presumably the unconditional
    # edge to `state` created by `add_state_before` -- TODO confirm), so that
    # `state` is only reached through the two branch states added below
    sdfg.remove_edge(sdfg.out_edges(cond_state)[0])

    # expect true branch as second argument
    true_state = sdfg.add_state(state.label + "_true_branch")
    sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})"))
    sdfg.add_edge(true_state, state, dace.InterstateEdge())

    # and false branch as third argument
    false_state = sdfg.add_state(state.label + "_false_branch")
    sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")))
    sdfg.add_edge(false_state, state, dace.InterstateEdge())

    # each branch expression is lowered inside its own branch state
    true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state)
    false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state)

    output_nodes = []
    # `strict=True`: both branches must produce the same number of results
    for true_br, false_br in zip(true_br_args, false_br_args, strict=True):
        true_br_node, true_br_type = true_br
        assert isinstance(true_br_node, dace.nodes.AccessNode)
        false_br_node, _ = false_br
        assert isinstance(false_br_node, dace.nodes.AccessNode)
        # the two branch results must have equal data descriptors, because both
        # are copied into the same transient allocated from the true-branch descriptor
        desc = true_br_node.desc(sdfg)
        assert false_br_node.desc(sdfg) == desc
        data_name, _ = sdfg.add_temp_transient_like(desc)
        output_nodes.append((state.add_access(data_name), true_br_type))

        # copy the result of the true branch into the shared transient
        true_br_output_node = true_state.add_access(data_name)
        true_state.add_nedge(
            true_br_node,
            true_br_output_node,
            dace.Memlet.from_array(data_name, desc),
        )

        # copy the result of the false branch into the shared transient
        false_br_output_node = false_state.add_access(data_name)
        false_state.add_nedge(
            false_br_node,
            false_br_output_node,
            dace.Memlet.from_array(data_name, desc),
        )

    return output_nodes


def translate_symbol_ref(
    node: gtir.Node,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
    """Generates the dataflow subgraph for a `ir.SymRef` node.

    Literals are handled here as well. A symbol of field type is exposed
    through an access node with the name of the symbol. Any other value is
    written by a tasklet into a new one-element transient.

    Arguments:
        node: The `Literal` or `SymRef` node to be lowered.
        sdfg: The SDFG where the transient for a scalar value is allocated.
        state: The state where the access node (and tasklet, if any) is added.
        sdfg_builder: Provides the type of referenced symbols and unique tasklet names.

    Returns:
        A single-element list with the access node and the GT4Py data type.
    """
    assert isinstance(node, (gtir.Literal, gtir.SymRef))

    data_type: ts.FieldType | ts.ScalarType
    if isinstance(node, gtir.Literal):
        # a literal carries both its value and its type
        sym_value = node.value
        data_type = node.type
        tasklet_name = "get_literal"
    else:
        # the type of a symbol reference is looked up by name in the builder
        sym_value = str(node.id)
        data_type = sdfg_builder.get_symbol_type(sym_value)
        tasklet_name = f"get_{sym_value}"

    if isinstance(data_type, ts.FieldType):
        # add access node to current state
        sym_node = state.add_access(sym_value)

    else:
        # scalar symbols are passed to the SDFG as symbols: build tasklet node
        # to write the symbol to a scalar access node
        tasklet_node = sdfg_builder.add_tasklet(
            tasklet_name,
            state,
            {},
            {"__out"},
            f"__out = {sym_value}",
        )
        # the scalar value is stored in a transient array with a single element
        temp_name, _ = sdfg.add_temp_transient((1,), dace_fieldview_util.as_dace_type(data_type))
        sym_node = state.add_access(temp_name)
        state.add_edge(
            tasklet_node,
            "__out",
            sym_node,
            None,
            dace.Memlet(data=sym_node.data, subset="0"),
        )

    return [(sym_node, data_type)]
# Format templates used to render GTIR builtin calls as Python source.
# Every `{}` placeholder is filled positionally with one rendered argument.
MATH_BUILTINS_MAPPING = {
    # unary math functions
    "abs": "abs({})",
    "sin": "math.sin({})",
    "cos": "math.cos({})",
    "tan": "math.tan({})",
    "arcsin": "asin({})",
    "arccos": "acos({})",
    "arctan": "atan({})",
    "sinh": "math.sinh({})",
    "cosh": "math.cosh({})",
    "tanh": "math.tanh({})",
    "arcsinh": "asinh({})",
    "arccosh": "acosh({})",
    "arctanh": "atanh({})",
    "sqrt": "math.sqrt({})",
    "exp": "math.exp({})",
    "log": "math.log({})",
    "gamma": "tgamma({})",
    "cbrt": "cbrt({})",
    "isfinite": "isfinite({})",
    "isinf": "isinf({})",
    "isnan": "isnan({})",
    "floor": "math.ifloor({})",
    "ceil": "ceil({})",
    "trunc": "trunc({})",
    # binary math functions
    "minimum": "min({}, {})",
    "maximum": "max({}, {})",
    "fmod": "fmod({}, {})",
    "power": "math.pow({}, {})",
    # type casts
    "float": "dace.float64({})",
    "float32": "dace.float32({})",
    "float64": "dace.float64({})",
    # the width of `int` follows the size of the numpy default integer on this platform
    "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})",
    "int32": "dace.int32({})",
    "int64": "dace.int64({})",
    "bool": "dace.bool_({})",
    # arithmetic operators
    "plus": "({} + {})",
    "minus": "({} - {})",
    "multiplies": "({} * {})",
    "divides": "({} / {})",
    "floordiv": "({} // {})",
    # comparison operators
    "eq": "({} == {})",
    "not_eq": "({} != {})",
    "less": "({} < {})",
    "less_equal": "({} <= {})",
    "greater": "({} > {})",
    "greater_equal": "({} >= {})",
    # logical operators
    "and_": "({} & {})",
    "or_": "({} | {})",
    "xor_": "({} ^ {})",
    "mod": "({} % {})",
    "not_": "(not {})",  # ~ is not bitwise in numpy
}


def format_builtin(bultin: str, *args: Any) -> str:
    """Render the call of a builtin function as Python source.

    Arguments:
        bultin: Name of the builtin, used as key in `MATH_BUILTINS_MAPPING`.
        args: The already rendered arguments, substituted positionally into the template.

    Returns:
        The template of the builtin with all placeholders filled in.

    Raises:
        NotImplementedError: If there is no template for the given builtin name.
    """
    template = MATH_BUILTINS_MAPPING.get(bultin)
    if template is None:
        raise NotImplementedError(f"'{bultin}' not implemented.")
    return template.format(*args)
+ """ + + SymRef = as_fmt("{id}") + Literal = as_fmt("{value}") + + def _visit_deref(self, node: gtir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], gtir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: gtir.FunCall) -> str: + assert isinstance(node.fun, gtir.SymRef) + fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: gtir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, gtir.SymRef): + args = self.visit(node.args) + builtin_name = str(node.fun.id) + return format_builtin(builtin_name, *args) + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + +get_source = PythonCodegen.apply +""" +Specialized visit method for symbolic expressions. + +Returns: + A string containing the Python code corresponding to a symbolic expression +""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py new file mode 100644 index 0000000000..7468056f0c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -0,0 +1,363 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Contains visitors to lower GTIR to DaCe SDFG. 
class DataflowBuilder(Protocol):
    """Interface offered to the code that builds dataflow subgraphs.

    Implementations supply the offset provider lookup and the generators of
    unique names; `add_map` and `add_tasklet` are helpers built on top of the
    name generators.
    """

    @abc.abstractmethod
    def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension:
        """Return the offset provider associated with the name `offset`."""
        ...

    @abc.abstractmethod
    def unique_map_name(self, name: str) -> str:
        """Return a unique map name derived from `name`."""
        ...

    @abc.abstractmethod
    def unique_tasklet_name(self, name: str) -> str:
        """Return a unique tasklet name derived from `name`."""
        ...

    def add_map(
        self,
        name: str,
        state: dace.SDFGState,
        ndrange: Union[
            Dict[str, Union[str, dace.subsets.Subset]],
            List[Tuple[str, Union[str, dace.subsets.Subset]]],
        ],
        **kwargs: Any,
    ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
        """Add a map to `state` through `state.add_map`, under a unique name.

        The unique name is obtained from `unique_map_name(name)`; `ndrange` and
        all keyword arguments are forwarded unchanged.
        """
        map_label = self.unique_map_name(name)
        return state.add_map(map_label, ndrange, **kwargs)

    def add_tasklet(
        self,
        name: str,
        state: dace.SDFGState,
        inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
        outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
        code: str,
        **kwargs: Any,
    ) -> dace.nodes.Tasklet:
        """Add a tasklet to `state` through `state.add_tasklet`, under a unique name.

        The unique name is obtained from `unique_tasklet_name(name)`; connectors,
        code and all keyword arguments are forwarded unchanged.
        """
        tasklet_label = self.unique_tasklet_name(name)
        return state.add_tasklet(tasklet_label, inputs, outputs, code, **kwargs)
translators.""" + + @abc.abstractmethod + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: + pass + + @abc.abstractmethod + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + pass + + +@dataclasses.dataclass(frozen=True) +class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): + """Provides translation capability from a GTIR program to a DaCe SDFG. + + This class is responsible for translation of `ir.Program`, that is the top level representation + of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. + Each statement is translated to a taskgraph inside a separate state. Statement states are chained + one after the other: concurrency between states should be extracted by means of SDFG analysis. + The translator will extend the SDFG while preserving the property of single exit state: + branching is allowed within the context of one statement, but in that case the statement should + terminate with a join state; the join state will represent the head state for next statement, + from where to continue building the SDFG. 
+ """ + + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] + symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( + default_factory=lambda: {} + ) + map_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") + ) + tesklet_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") + ) + + def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: + assert offset in self.offset_provider + return self.offset_provider[offset] + + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: + assert symbol_name in self.symbol_types + return self.symbol_types[symbol_name] + + def unique_map_name(self, name: str) -> str: + return f"{self.map_uids.sequential_id()}_{name}" + + def unique_tasklet_name(self, name: str) -> str: + return f"{self.tesklet_uids.sequential_id()}_{name}" + + def _make_array_shape_and_strides( + self, name: str, dims: Sequence[gtx_common.Dimension] + ) -> tuple[list[dace.symbol], list[dace.symbol]]: + """ + Parse field dimensions and allocate symbols for array shape and strides. + + For local dimensions, the size is known at compile-time and therefore + the corresponding array shape dimension is set to an integer literal value. + + Returns: + Two lists of symbols, one for the shape and the other for the strides of the array. 
+ """ + dtype = dace.int32 + neighbor_tables = dace_fieldview_util.filter_connectivities(self.offset_provider) + shape = [ + ( + neighbor_tables[dim.value].max_neighbors + if dim.kind == gtx_common.DimensionKind.LOCAL + # we reuse the same symbol for field size passed as scalar argument to the gt4py program + else dace.symbol(f"__{name}_size_{i}", dtype) + ) + for i, dim in enumerate(dims) + ] + strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i in range(len(dims))] + return shape, strides + + def _add_storage(self, sdfg: dace.SDFG, name: str, symbol_type: ts.DataType) -> None: + """ + Add external storage (aka non-transient) for data containers passed as arguments to the SDFG. + + For fields, it allocates dace arrays, while scalars are stored as SDFG symbols. + """ + if isinstance(symbol_type, ts.FieldType): + dtype = dace_fieldview_util.as_dace_type(symbol_type.dtype) + # use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, symbol_type.dims) + sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) + elif isinstance(symbol_type, ts.ScalarType): + dtype = dace_fieldview_util.as_dace_type(symbol_type) + # scalar arguments passed to the program are represented as symbols in DaCe SDFG + sdfg.add_symbol(name, dtype) + else: + raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.") + + # TODO: unclear why mypy complains about incompatible types + assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) + self.symbol_types[name] = symbol_type + + def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: + """ + Add temporary storage (aka transient) for data containers used as GTIR temporaries. + + Assume all temporaries to be fields, therefore represented as dace arrays. 
+ """ + raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") + + def _visit_expression( + self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState + ) -> list[dace.nodes.AccessNode]: + """ + Specialized visit method for fieldview expressions. + + This method represents the entry point to visit `ir.Stmt` expressions. + As such, it must preserve the property of single exit state in the SDFG. + + Returns a list of array nodes containing the result fields. + + TODO: Do we need to return the GT4Py `FieldType`/`ScalarType`? It is needed + in case the transient arrays containing the expression result are not guaranteed + to have the same memory layout as the target array. + """ + results: list[gtir_builtin_translators.TemporaryData] = self.visit( + node, sdfg=sdfg, head_state=head_state + ) + + field_nodes = [] + for node, _ in results: + assert isinstance(node, dace.nodes.AccessNode) + field_nodes.append(node) + + # sanity check: each statement should preserve the property of single exit state (aka head state), + # i.e. eventually only introduce internal branches, and keep the same head state + sink_states = sdfg.sink_nodes() + assert len(sink_states) == 1 + assert sink_states[0] == head_state + + return field_nodes + + def visit_Program(self, node: gtir.Program) -> dace.SDFG: + """Translates `ir.Program` to `dace.SDFG`. + + First, it will allocate field and scalar storage for global data. The storage + represents global data, available everywhere in the SDFG, either containing + external data (aka non-transient data) or temporary data (aka transient data). + The temporary data is global, therefore available everywhere in the SDFG + but not outside. Then, all statements are translated, one after the other. 
+ """ + if node.function_definitions: + raise NotImplementedError("Functions expected to be inlined as lambda calls.") + + sdfg = dace.SDFG(node.id) + sdfg.debuginfo = dace_fieldview_util.debug_info(node, default=sdfg.debuginfo) + entry_state = sdfg.add_state("program_entry", is_start_block=True) + + # declarations of temporaries result in transient array definitions in the SDFG + if node.declarations: + temp_symbols: dict[str, str] = {} + for decl in node.declarations: + temp_symbols |= self._add_storage_for_temporary(decl) + + # define symbols for shape and offsets of temporary arrays as interstate edge symbols + head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols) + else: + head_state = entry_state + + # add non-transient arrays and/or SDFG symbols for the program arguments + for param in node.params: + assert isinstance(param.type, ts.DataType) + self._add_storage(sdfg, str(param.id), param.type) + + # visit one statement at a time and expand the SDFG from the current head state + for i, stmt in enumerate(node.body): + # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy + head_state = sdfg.add_state_after(head_state, f"stmt_{i}") + head_state._debuginfo = dace_fieldview_util.debug_info(stmt, default=sdfg.debuginfo) + self.visit(stmt, sdfg=sdfg, state=head_state) + + sdfg.validate() + return sdfg + + def visit_SetAt(self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + """Visits a `SetAt` statement expression and writes the local result to some external storage. + + Each statement expression results in some sort of dataflow gragh writing to temporary storage. + The translation of `SetAt` ensures that the result is written back to the target external storage. 
+ """ + + expr_nodes = self._visit_expression(stmt.expr, sdfg, state) + + # the target expression could be a `SymRef` to an output node or a `make_tuple` expression + # in case the statement returns more than one field + target_nodes = self._visit_expression(stmt.target, sdfg, state) + + # convert domain expression to dictionary to ease access to dimension boundaries + domain = dace_fieldview_util.get_domain_ranges(stmt.domain) + + for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): + target_array = sdfg.arrays[target_node.data] + assert not target_array.transient + target_symbol_type = self.symbol_types[target_node.data] + + if isinstance(target_symbol_type, ts.FieldType): + subset = ",".join( + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims + ) + else: + assert len(domain) == 0 + subset = "0" + + state.add_nedge( + expr_node, + target_node, + dace.Memlet(data=target_node.data, subset=subset), + ) + + def visit_FunCall( + self, + node: gtir.FunCall, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> list[gtir_builtin_translators.TemporaryData]: + # use specialized dataflow builder classes for each builtin function + if cpm.is_call_to(node.fun, "as_fieldop"): + return gtir_builtin_translators.translate_as_field_op(node, sdfg, head_state, self) + elif cpm.is_call_to(node.fun, "cond"): + return gtir_builtin_translators.translate_cond(node, sdfg, head_state, self) + else: + raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + + def visit_Lambda(self, node: gtir.Lambda) -> Any: + """ + This visitor class should never encounter `itir.Lambda` expressions + because a lambda represents a stencil, which operates from iterator to values. + In fieldview, lambdas should only be arguments to field operators (`as_field_op`). 
def build_sdfg_from_gtir(
    program: gtir.Program,
    offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension],
) -> dace.SDFG:
    """
    Receives a GTIR program and lowers it to a DaCe SDFG.

    The lowering is done by visiting the program node with `GTIRToSDFG`.
    Type inference is NOT run here yet (see the TODO below), so the program node
    is handed to the visitor as received.
    As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form.

    Arguments:
        program: The GTIR program node to be lowered to SDFG
        offset_provider: The definitions of offset providers used by the program node

    Returns:
        An SDFG in the DaCe canonical form (simplified)
    """
    sdfg_generator = GTIRToSDFG(offset_provider)
    # TODO: run type inference on the `program` node before passing it to `GTIRToSDFG`
    sdfg = sdfg_generator.visit(program)
    # `visit` is generic: narrow the result type before calling SDFG methods
    assert isinstance(sdfg, dace.SDFG)

    sdfg.simplify()
    return sdfg
@dataclasses.dataclass(frozen=True)
class MemletExpr:
    """Scalar or array data access through a memlet."""

    # access node of the data container being read
    node: dace.nodes.AccessNode
    # subset of the data container accessed by the memlet
    subset: sbs.Indices | sbs.Range


@dataclasses.dataclass(frozen=True)
class SymbolExpr:
    """Any symbolic expression that is constant in the context of current SDFG."""

    value: dace.symbolic.SymExpr
    dtype: dace.typeclass


@dataclasses.dataclass(frozen=True)
class ValueExpr:
    """Result of the computation implemented by a tasklet node."""

    # access node of the transient that stores the tasklet result
    node: dace.nodes.AccessNode
    field_type: ts.FieldType | ts.ScalarType


# Define alias for the elements needed to setup input connections to a map scope:
# (source access node, source subset, destination node, destination connector)
InputConnection: TypeAlias = tuple[
    dace.nodes.AccessNode,
    sbs.Range,
    dace.nodes.Node,
    Optional[str],
]

IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr


@dataclasses.dataclass(frozen=True)
class IteratorExpr:
    """Iterator for field access to be consumed by `deref` or `shift` builtin functions."""

    field: dace.nodes.AccessNode
    dimensions: list[gtx_common.Dimension]
    # one index expression per field dimension
    indices: dict[gtx_common.Dimension, IteratorIndexExpr]


class LambdaToTasklet(eve.NodeVisitor):
    """Translates an `ir.Lambda` expression to a dataflow graph.

    Lambda functions should only be encountered as argument to the `as_field_op`
    builtin function, therefore the dataflow graph generated here typically
    represents the stencil function of a field operator.
    """

    sdfg: dace.SDFG
    state: dace.SDFGState
    subgraph_builder: gtir_to_sdfg.DataflowBuilder
    # connections from data nodes to tasklets, collected while visiting and returned
    # by `visit_Lambda`
    input_connections: list[InputConnection]
    # maps lambda parameter names to the expressions passed as arguments
    symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr]

    def __init__(
        self,
        sdfg: dace.SDFG,
        state: dace.SDFGState,
        subgraph_builder: gtir_to_sdfg.DataflowBuilder,
    ):
        self.sdfg = sdfg
        self.state = state
        self.subgraph_builder = subgraph_builder
        self.input_connections = []
        self.symbol_map = {}

    def _add_entry_memlet_path(
        self,
        src: dace.nodes.AccessNode,
        src_subset: sbs.Range,
        dst_node: dace.nodes.Node,
        dst_conn: Optional[str] = None,
    ) -> None:
        """Record an input connection; no edge is added to the state here."""
        self.input_connections.append((src, src_subset, dst_node, dst_conn))

    def _add_map(
        self,
        name: str,
        ndrange: Union[
            Dict[str, Union[str, dace.subsets.Subset]],
            List[Tuple[str, Union[str, dace.subsets.Subset]]],
        ],
        **kwargs: Any,
    ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
        """Helper method to add a map with unique name in current state."""
        return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs)

    def _add_tasklet(
        self,
        name: str,
        inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
        outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]],
        code: str,
        **kwargs: Any,
    ) -> dace.nodes.Tasklet:
        """Helper method to add a tasklet with unique name in current state."""
        return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs)

    def _get_tasklet_result(
        self,
        dtype: dace.typeclass,
        src_node: dace.nodes.Tasklet,
        src_connector: str,
    ) -> ValueExpr:
        """
        Store the output of a tasklet connector in a new transient scalar.

        Arguments:
            dtype: Data type of the transient scalar.
            src_node: Tasklet that produces the value.
            src_connector: Output connector of `src_node` to read the value from.

        Returns:
            A `ValueExpr` that refers to the access node of the new scalar.
        """
        temp_name = self.sdfg.temp_data_name()
        self.sdfg.add_scalar(temp_name, dtype, transient=True)
        data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype()))
        temp_node = self.state.add_access(temp_name)
        self.state.add_edge(
            src_node,
            src_connector,
            temp_node,
            None,
            dace.Memlet(data=temp_name, subset="0"),
        )
        return ValueExpr(temp_node, data_type)

    def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
        """
        Lower a `deref` call to a data access expression.

        Raises:
            NotImplementedError: If the iterator has an index that is not a `SymbolExpr`.
        """
        assert len(node.args) == 1
        it = self.visit(node.args[0])

        if not isinstance(it, IteratorExpr):
            # the argument is already a data access: nothing to dereference
            assert isinstance(it, MemletExpr)
            return it

        field_desc = it.field.desc(self.sdfg)
        assert len(field_desc.shape) == len(it.dimensions)
        if not all(isinstance(index, SymbolExpr) for index in it.indices.values()):
            raise NotImplementedError(
                "Dereferencing an iterator with non-symbolic indices is not supported."
            )

        # when all indices are symbolic expressions, we can perform direct field access through a memlet
        field_subset = sbs.Indices([it.indices[dim].value for dim in it.dimensions])  # type: ignore[union-attr]
        return MemletExpr(it.field, field_subset)

    def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr:
        """
        Lower a builtin call inside the stencil to a tasklet.

        `deref` is handled by `_visit_deref`. Any other call is turned into a tasklet
        whose code is produced by `gtir_python_codegen.format_builtin`; the result is
        written to a transient scalar.
        """
        if cpm.is_call_to(node, "deref"):
            return self._visit_deref(node)

        assert isinstance(node.fun, gtir.SymRef)

        node_internals = []
        node_connections: dict[str, MemletExpr | ValueExpr] = {}
        for i, arg in enumerate(node.args):
            arg_expr = self.visit(arg)
            if isinstance(arg_expr, MemletExpr | ValueExpr):
                # the argument value is the result of a tasklet node or direct field access
                connector = f"__inp_{i}"
                node_connections[connector] = arg_expr
                node_internals.append(connector)
            else:
                assert isinstance(arg_expr, SymbolExpr)
                # use the argument value without adding any connector
                node_internals.append(arg_expr.value)

        # use tasklet connectors as expression arguments
        builtin_name = str(node.fun.id)
        code = gtir_python_codegen.format_builtin(builtin_name, *node_internals)

        out_connector = "result"
        tasklet_node = self._add_tasklet(
            builtin_name,
            set(node_connections.keys()),
            {out_connector},
            f"{out_connector} = {code}",
        )

        for connector, arg_expr in node_connections.items():
            if isinstance(arg_expr, ValueExpr):
                # the value is produced in the current state: connect it directly
                self.state.add_edge(
                    arg_expr.node,
                    None,
                    tasklet_node,
                    connector,
                    dace.Memlet(data=arg_expr.node.data, subset="0"),
                )
            else:
                # field access: the connection is recorded and returned by `visit_Lambda`
                self._add_entry_memlet_path(
                    arg_expr.node,
                    arg_expr.subset,
                    tasklet_node,
                    connector,
                )

        # TODO: use type inference to determine the result type
        if len(node_connections) == 1:
            # with a single data input, reuse its type regardless of the argument
            # position (the connector name encodes the position)
            (single_input,) = node_connections.values()
            dtype = single_input.node.desc(self.sdfg).dtype
        else:
            node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
            dtype = dace_fieldview_util.as_dace_type(node_type)

        return self._get_tasklet_result(dtype, tasklet_node, out_connector)

    def visit_Lambda(
        self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr]
    ) -> tuple[list[InputConnection], ValueExpr]:
        """
        Lower the lambda body, binding each parameter to the corresponding argument.

        Arguments:
            node: The lambda expression to lower.
            args: One expression per lambda parameter, in parameter order.

        Returns:
            The input connections collected while visiting and the expression
            that holds the lambda result.
        """
        for p, arg in zip(node.params, args, strict=True):
            self.symbol_map[str(p.id)] = arg
        output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr)
        if isinstance(output_expr, ValueExpr):
            return self.input_connections, output_expr

        if isinstance(output_expr, MemletExpr):
            # special case where the field operator is simply copying data from source to destination node
            output_dtype = output_expr.node.desc(self.sdfg).dtype
            tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp")
            self._add_entry_memlet_path(
                output_expr.node,
                output_expr.subset,
                tasklet_node,
                "__inp",
            )
        else:
            # even simpler case, where a constant value is written to destination node
            output_dtype = output_expr.dtype
            tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}")
        return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out")

    def visit_Literal(self, node: gtir.Literal) -> SymbolExpr:
        """Represent a literal as a symbolic expression with the matching DaCe type."""
        dtype = dace_fieldview_util.as_dace_type(node.type)
        return SymbolExpr(node.value, dtype)

    def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr:
        """Return the expression bound to a lambda parameter by `visit_Lambda`."""
        param = str(node.id)
        assert param in self.symbol_map
        return self.symbol_map[param]
def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
    """Return the DaCe type that corresponds to a GT4Py scalar type.

    Raises:
        ValueError: If the scalar kind has no DaCe counterpart in the table below.
    """
    dace_types = {
        ts.ScalarKind.BOOL: dace.bool_,
        ts.ScalarKind.INT32: dace.int32,
        ts.ScalarKind.INT64: dace.int64,
        ts.ScalarKind.FLOAT32: dace.float32,
        ts.ScalarKind.FLOAT64: dace.float64,
    }
    dace_type = dace_types.get(type_.kind)
    if dace_type is None:
        raise ValueError(f"Scalar type '{type_}' not supported.")
    return dace_type


def as_scalar_type(typestr: str) -> ts.ScalarType:
    """Return the GT4Py scalar type named by a numpy-style type string.

    The string is upper-cased and looked up as a member of `ts.ScalarKind`.

    Raises:
        ValueError: If the lookup fails.
    """
    try:
        scalar_kind = getattr(ts.ScalarKind, typestr.upper())
    except AttributeError as ex:
        raise ValueError(f"Data type {typestr} not supported.") from ex
    else:
        return ts.ScalarType(scalar_kind)


def debug_info(
    node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None
) -> Optional[dace.dtypes.DebugInfo]:
    """Build the DaCe debug info for an IR node.

    Returns `default` when the node carries no source location. Missing columns
    are reported as 0 and a missing end line as -1.
    """
    loc = node.location
    if not loc:
        return default
    return dace.dtypes.DebugInfo(
        start_line=loc.line,
        start_column=loc.column or 0,
        end_line=loc.end_line or -1,
        end_column=loc.end_column or 0,
        filename=loc.filename,
    )


def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, gtx_common.Connectivity]:
    """
    Select the offset providers that are `Connectivity` instances.

    Every other entry of the mapping is dropped. The input is not modified:
    the result is a new dictionary.
    """
    connectivities: dict[str, gtx_common.Connectivity] = {}
    for offset, table in offset_provider.items():
        if isinstance(table, gtx_common.Connectivity):
            connectivities[offset] = table
    return connectivities


def get_domain(
    node: gtir.Expr,
) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]:
    """
    Lower a domain expression to a list of `(dimension, lower bound, upper bound)`.

    The node must be a call to `cartesian_domain` or `unstructured_domain` whose
    arguments are `named_range` calls with three arguments: the axis and the two bounds.
    The bounds are converted to symbolic expressions from their generated source.

    TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea
    would be to cache the results of lowering here (e.g. using `functools.lru_cache`)
    """
    assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))

    ranges = []
    for range_expr in node.args:
        assert cpm.is_call_to(range_expr, "named_range")
        assert len(range_expr.args) == 3
        axis = range_expr.args[0]
        assert isinstance(axis, gtir.AxisLiteral)
        lower = dace.symbolic.SymExpr(gtir_python_codegen.get_source(range_expr.args[1]))
        upper = dace.symbolic.SymExpr(gtir_python_codegen.get_source(range_expr.args[2]))
        ranges.append((gtx_common.Dimension(axis.value, axis.kind), lower, upper))

    return ranges


def get_domain_ranges(
    node: gtir.Expr,
) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]:
    """
    Return the result of `get_domain` as a mapping from dimension to `(lower, upper)`.
    """
    ranges = {}
    for dim, lower, upper in get_domain(node):
        ranges[dim] = (lower, upper)
    return ranges


def get_map_variable(dim: gtx_common.Dimension) -> str:
    """
    Format the map variable name of a dimension.

    The name follows the naming convention for application-specific SDFG transformations:
    it embeds the dimension value and kind, with an extra `dim` suffix for local dimensions.
    """
    if dim.kind == gtx_common.DimensionKind.LOCAL:
        suffix = "dim"
    else:
        suffix = ""
    return f"i_{dim.value}_gtx_{dim.kind}{suffix}"
# skip the whole module when the optional `dace` dependency is not installed
dace = pytest.importorskip("dace")


# number of elements of the 1D test fields
N = 10
# type of a float64 field defined on `IDim`
IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))
# type of the `size` program parameter
SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32)
# symbol values passed to every SDFG call, in addition to the array arguments
# NOTE(review): presumably the shape/stride symbols of the arrays w, x, y, z — confirm
FSYMBOLS = dict(
    __w_size_0=N,
    __w_stride_0=1,
    __x_size_0=N,
    __x_stride_0=1,
    __y_size_0=N,
    __y_stride_0=1,
    __z_size_0=N,
    __z_stride_0=1,
    size=N,
)


def test_gtir_copy():
    """Lower a field operator that only dereferences `x` and check that `y` equals `x`."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    testee = gtir.Program(
        id="gtir_copy",
        function_definitions=[],
        params=[
            gtir.Sym(id="x", type=IFTYPE),
            gtir.Sym(id="y", type=IFTYPE),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.call(
                    im.call("as_fieldop")(
                        im.lambda_("a")(im.deref("a")),
                        domain,
                    )
                )("x"),
                domain=domain,
                target=gtir.SymRef(id="y"),
            )
        ],
    )

    a = np.random.rand(N)
    b = np.empty_like(a)

    sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

    sdfg(x=a, y=b, **FSYMBOLS)
    assert np.allclose(a, b)


def test_gtir_update():
    """Check in-place update `x = x + 1.0`, with the constant given in two different ways."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    # the constant is a literal inside the stencil
    stencil1 = im.call(
        im.call("as_fieldop")(
            im.lambda_("a")(im.plus(im.deref("a"), 1.0)),
            domain,
        )
    )("x")
    # the constant is passed as second argument of the field operator
    stencil2 = im.call(
        im.call("as_fieldop")(
            im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
            domain,
        )
    )("x", 1.0)

    for i, stencil in enumerate([stencil1, stencil2]):
        testee = gtir.Program(
            id=f"gtir_update_{i}",
            function_definitions=[],
            params=[
                gtir.Sym(id="x", type=IFTYPE),
                gtir.Sym(id="size", type=SIZE_TYPE),
            ],
            declarations=[],
            body=[
                gtir.SetAt(
                    expr=stencil,
                    domain=domain,
                    target=gtir.SymRef(id="x"),
                )
            ],
        )
        sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

        a = np.random.rand(N)
        # compute the reference before the call, because `x` is both input and output
        ref = a + 1.0

        sdfg(x=a, **FSYMBOLS)
        assert np.allclose(a, ref)


def test_gtir_sum2():
    """Check the sum of two distinct fields: `z = x + y`."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    testee = gtir.Program(
        id="sum_2fields",
        function_definitions=[],
        params=[
            gtir.Sym(id="x", type=IFTYPE),
            gtir.Sym(id="y", type=IFTYPE),
            gtir.Sym(id="z", type=IFTYPE),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.call(
                    im.call("as_fieldop")(
                        im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                        domain,
                    )
                )("x", "y"),
                domain=domain,
                target=gtir.SymRef(id="z"),
            )
        ],
    )

    a = np.random.rand(N)
    b = np.random.rand(N)
    c = np.empty_like(a)

    sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

    sdfg(x=a, y=b, z=c, **FSYMBOLS)
    assert np.allclose(c, (a + b))


def test_gtir_sum2_sym():
    """Check the sum when the same field is passed as both arguments: `z = x + x`."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    testee = gtir.Program(
        id="sum_2fields_sym",
        function_definitions=[],
        params=[
            gtir.Sym(id="x", type=IFTYPE),
            gtir.Sym(id="z", type=IFTYPE),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.call(
                    im.call("as_fieldop")(
                        im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                        domain,
                    )
                )("x", "x"),
                domain=domain,
                target=gtir.SymRef(id="z"),
            )
        ],
    )

    a = np.random.rand(N)
    b = np.empty_like(a)

    sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

    sdfg(x=a, z=b, **FSYMBOLS)
    assert np.allclose(b, (a + a))


def test_gtir_sum3():
    """Check `z = x + y + w`, built as nested field operators and as one nested stencil."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    # two field operators: the result of the inner one is an argument of the outer one
    stencil1 = im.call(
        im.call("as_fieldop")(
            im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
            domain,
        )
    )(
        "x",
        im.call(
            im.call("as_fieldop")(
                im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                domain,
            )
        )("y", "w"),
    )
    # one field operator with a nested `plus` expression in the stencil
    stencil2 = im.call(
        im.call("as_fieldop")(
            im.lambda_("a", "b", "c")(
                im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c")))
            ),
            domain,
        )
    )("x", "y", "w")

    a = np.random.rand(N)
    b = np.random.rand(N)
    c = np.random.rand(N)

    for i, stencil in enumerate([stencil1, stencil2]):
        testee = gtir.Program(
            id=f"sum_3fields_{i}",
            function_definitions=[],
            params=[
                gtir.Sym(id="x", type=IFTYPE),
                gtir.Sym(id="y", type=IFTYPE),
                gtir.Sym(id="w", type=IFTYPE),
                gtir.Sym(id="z", type=IFTYPE),
                gtir.Sym(id="size", type=SIZE_TYPE),
            ],
            declarations=[],
            body=[
                gtir.SetAt(
                    expr=stencil,
                    domain=domain,
                    target=gtir.SymRef(id="z"),
                )
            ],
        )

        sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

        d = np.empty_like(a)

        sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS)
        assert np.allclose(d, (a + b + c))


def test_gtir_cond():
    """Check `z = x + cond(pred, y + scalar, w + scalar)` for both values of `pred`."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    testee = gtir.Program(
        id="cond_2sums",
        function_definitions=[],
        params=[
            gtir.Sym(id="x", type=IFTYPE),
            gtir.Sym(id="y", type=IFTYPE),
            gtir.Sym(id="w", type=IFTYPE),
            gtir.Sym(id="z", type=IFTYPE),
            gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)),
            gtir.Sym(id="scalar", type=ts.ScalarType(ts.ScalarKind.FLOAT64)),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.call(
                    im.call("as_fieldop")(
                        im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                        domain,
                    )
                )(
                    "x",
                    im.call(
                        im.call("cond")(
                            im.deref("pred"),
                            im.call(
                                im.call("as_fieldop")(
                                    im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                                    domain,
                                )
                            )("y", "scalar"),
                            im.call(
                                im.call("as_fieldop")(
                                    im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                                    domain,
                                )
                            )("w", "scalar"),
                        )
                    )(),
                ),
                domain=domain,
                target=gtir.SymRef(id="z"),
            )
        ],
    )

    a = np.random.rand(N)
    b = np.random.rand(N)
    c = np.random.rand(N)

    sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

    # the SDFG is built once and called with both values of the predicate
    for s in [False, True]:
        d = np.empty_like(a)
        sdfg(pred=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS)
        assert np.allclose(d, (a + b + 1) if s else (a + c + 1))


def test_gtir_cond_nested():
    """Check a `cond` nested in the false branch of another `cond`, for all predicate values."""
    domain = im.call("cartesian_domain")(
        im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
    )
    testee = gtir.Program(
        id="cond_nested",
        function_definitions=[],
        params=[
            gtir.Sym(id="x", type=IFTYPE),
            gtir.Sym(id="z", type=IFTYPE),
            gtir.Sym(id="pred_1", type=ts.ScalarType(ts.ScalarKind.BOOL)),
            gtir.Sym(id="pred_2", type=ts.ScalarType(ts.ScalarKind.BOOL)),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.call(
                    im.call("cond")(
                        im.deref("pred_1"),
                        im.call(
                            im.call("as_fieldop")(
                                im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                                domain,
                            )
                        )("x", 1),
                        im.call(
                            im.call("cond")(
                                im.deref("pred_2"),
                                im.call(
                                    im.call("as_fieldop")(
                                        im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                                        domain,
                                    )
                                )("x", 2),
                                im.call(
                                    im.call("as_fieldop")(
                                        im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))),
                                        domain,
                                    )
                                )("x", 3),
                            )
                        )(),
                    )
                )(),
                domain=domain,
                target=gtir.SymRef(id="z"),
            )
        ],
    )

    a = np.random.rand(N)

    sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

    # expected: x + 1 when pred_1 holds, otherwise x + 2 or x + 3 depending on pred_2
    for s1 in [False, True]:
        for s2 in [False, True]:
            b = np.empty_like(a)
            sdfg(pred_1=np.bool_(s1), pred_2=np.bool_(s2), x=a, z=b, **FSYMBOLS)
            assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0))