From 17bae8ebabbf3bff8c862656083a369bb91cd28e Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 15 Jan 2025 12:20:05 +0100 Subject: [PATCH] feat[next][dace]: iterator-view support to DaCe backend (#1790) The lowering of scan to SDFG requires the support for iterator view. This PR introduces a subset of iterator features: - Local `if_` with exclusive branch execution - Lowering of `list_get`, `make_tuple` and `tuple_get` in iterator view - Field operators returning a tuple of fields - Tuple of fields with different size Iterator tests are enabled on dace CPU backend without SDFG transformations (`auto_optimize=False`). --------- Co-authored-by: Philip Mueller --- pyproject.toml | 8 + .../runners/dace_common/dace_backend.py | 31 +- .../runners/dace_common/workflow.py | 12 +- .../gtir_builtin_translators.py | 206 +++++-- .../runners/dace_fieldview/gtir_dataflow.py | 575 +++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 135 ++-- .../runners/dace_fieldview/utility.py | 81 ++- tests/next_tests/definitions.py | 25 +- .../ffront_tests/test_execution.py | 6 +- .../iterator_tests/test_builtins.py | 1 + .../iterator_tests/test_trivial.py | 2 + .../iterator_tests/test_tuple.py | 4 + .../iterator_tests/test_column_stencil.py | 9 +- .../test_with_toy_connectivity.py | 4 + tests/next_tests/unit_tests/conftest.py | 4 + .../dace_tests/test_gtir_to_sdfg.py | 62 +- 16 files changed, 875 insertions(+), 290 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 78735116ed..88bb2feac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,16 +237,23 @@ markers = [ 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', + 'uses_can_deref: tests that require backend support for can_deref builtin function', + 'uses_composite_shifts: tests that use composite shifts in unstructured domain', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', + 'uses_ir_if_stmts', + 'uses_lift: tests that require backend support for lift builtin function', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', + 'uses_reduce_with_lambda: tests that use lambdas as reduce functions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', + 'uses_scalar_in_domain_and_fo', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', + 'uses_scan_in_stencil: tests that require backend support for scan in stencil', 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', 'uses_scan_nested: tests that use nested scans', 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', @@ -254,6 +261,7 @@ markers = [ 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple 
arguments', + 'uses_tuple_iterator: tests that require backend support to deref tuple iterators', 'uses_tuple_returns: tests that require backend support for tuple results', 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', 'uses_cartesian_shift: tests that use a Cartesian connectivity', diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 90e7e07ad5..387619c667 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -7,13 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping, Sequence -from typing import Any, Iterable +from typing import Any import dace import numpy as np from gt4py._core import definitions as core_defs -from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next import common as gtx_common from . import utility as dace_utils @@ -46,10 +46,9 @@ def _convert_arg(arg: Any, sdfg_param: str) -> Any: def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names - flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { sdfg_param: _convert_arg(arg, sdfg_param) - for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) + for sdfg_param, arg in zip(sdfg_params, args, strict=True) } @@ -73,17 +72,8 @@ def _get_shape_args( for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): if isinstance(sym, dace.symbol): - if sym.name not in shape_args: - shape_args[sym.name] = size - elif shape_args[sym.name] != size: - # The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields - # in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array - # size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple. - # TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique, - # once the assumption on tuples is removed. - raise ValueError( - f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}." - ) + assert sym.name not in shape_args + shape_args[sym.name] = size elif sym != size: raise ValueError( f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." @@ -103,15 +93,8 @@ def _get_stride_args( f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) if isinstance(sym, dace.symbol): - if sym.name not in stride_args: - stride_args[str(sym)] = stride - elif stride_args[sym.name] != stride: - # See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple. - # TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique, - # once the assumption on tuples is removed. - raise ValueError( - f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}." - ) + assert sym.name not in stride_args + stride_args[sym.name] = stride elif sym != stride: raise ValueError( f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}." 
diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 5d9ac863c5..f0577ffaf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -10,7 +10,7 @@ import ctypes import dataclasses -from typing import Any +from typing import Any, Sequence import dace import factory @@ -112,11 +112,13 @@ def decorated_program( ) -> None: if out is not None: args = (*args, out) - if len(sdfg.arg_names) > len(args): - args = (*args, *arguments.iter_size_args(args)) + flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) + if len(sdfg.arg_names) > len(flat_args): + # The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments. + flat_args = (*flat_args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True)) + kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) use_fast_call = True @@ -151,7 +153,7 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, offset_provider, - *args, + *flat_args, check_args=False, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 354a9692d8..4cbc737312 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace from dace import subsets as dace_subsets @@ -27,6 +27,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, gtir_python_codegen, + gtir_sdfg, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -157,6 +158,33 @@ def get_local_view( """Data type used for field indexing.""" +def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. + """ + return ts.TupleType( + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + +def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. 
+ """ + if isinstance(arg, tuple): + tuple_type = get_tuple_type(arg) + tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) + tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) + return [ + (str(tsym.id), tfield) + for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) + ] + else: + return [(name, arg)] + + class PrimitiveTranslator(Protocol): @abc.abstractmethod def __call__( @@ -191,16 +219,20 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: +) -> ( + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr | tuple[Any, ...], ...] +): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, FieldopData): - raise ValueError(f"Received {node} as argument to field operator, expected a field.") - - return arg.get_local_view(domain) + if isinstance(arg, FieldopData): + return arg.get_local_view(domain) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda x: x.get_local_view(domain))(arg) def _get_field_layout( @@ -232,62 +264,107 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_field_operator( +def _create_field_operator_impl( + sdfg_builder: gtir_sdfg.SDFGBuilder, sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, - node_type: ts.FieldType, - sdfg_builder: gtir_sdfg.SDFGBuilder, - input_edges: Sequence[gtir_dataflow.DataflowInputEdge], output_edge: gtir_dataflow.DataflowOutputEdge, + output_type: ts.FieldType, + map_exit: dace.nodes.MapExit, ) -> FieldopData: """ - Helper method to allocate a temporary field to store the output of a field operator. + Helper method to allocate a temporary array that stores one field computed by a field operator. + + This method is called by `_create_field_operator()`. Args: + sdfg_builder: The object used to build the map scope in the provided SDFG. sdfg: The SDFG that represents the scope of the field data. state: The SDFG state where to create an access node to the field data. domain: The domain of the field operator that computes the field. - node_type: The GT4Py type of the IR node that produces this field. - sdfg_builder: The object used to build the map scope in the provided SDFG. - input_edges: List of edges to pass input data into the dataflow. - output_edge: Edge representing the dataflow output data. + output_edge: The dataflow write edge representing the output data. + output_type: The GT4Py field type descriptor. + map_exit: The `MapExit` node of the field operator map scope. Returns: The field data descriptor, which includes the field access node in the given `state` and the field domain offset. 
""" - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_indices = _get_domain_indices(field_dims, field_offset) - dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = dace_subsets.Range.from_indices(field_indices) + domain_dims, domain_offset, domain_shape = _get_field_layout(domain) + domain_indices = _get_domain_indices(domain_dims, domain_offset) + domain_subset = dace_subsets.Range.from_indices(domain_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == node_type.dtype - assert isinstance(dataflow_output_desc, dace.data.Scalar) - assert isinstance(node_type.dtype, ts.ScalarType) - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + assert output_edge.result.gt_dtype == output_type.dtype field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_subset = domain_subset else: - assert isinstance(node_type.dtype, ts.ListType) - assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type field_dtype = output_edge.result.gt_dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) assert output_edge.result.gt_dtype.offset_type is not None - field_dims.append(output_edge.result.gt_dtype.offset_type) - field_shape.extend(dataflow_output_desc.shape) - field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape[0]] + field_offset = [*domain_offset, dataflow_output_desc.offset[0]] + field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) + field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) field_node = state.add_access(field_name) + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(map_exit, field_node, field_subset) + + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) + + +def _create_field_operator( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: FieldopDomain, + node_type: ts.FieldType | ts.TupleType, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_edges: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], +) -> FieldopResult: + """ + Helper method to build the output of a field operator, which can consist of + a single field or a tuple of fields. 
+ + A tuple of fields is returned when one stencil computes a grid point on multiple + fields: for each field, this method will call `_create_field_operator_impl()`. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edges: Single edge or tuple of edges representing the dataflow output data. + + Returns: + The descriptor of the field operator result, which can be either a single field + or a tuple fields. + """ + # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( + map_entry, map_exit = sdfg_builder.add_map( "fieldop", state, ndrange={ @@ -298,16 +375,21 @@ def _create_field_operator( # here we setup the edges passing through the map entry node for edge in input_edges: - edge.connect(me) - - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(mx, field_node, field_subset) + edge.connect(map_entry) - return FieldopData( - field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), - ) + if isinstance(node_type, ts.FieldType): + assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge) + return _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edges, node_type, map_exit + ) + else: + # handle tuples of fields + output_symbol_tree = dace_gtir_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edge, output_sym.type, map_exit + ) + )(output_edges, output_symbol_tree) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -366,16 +448,17 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) fun_node = node.fun assert len(fun_node.args) == 2 fieldop_expr, domain_expr = fun_node.args - assert isinstance(node.type, ts.FieldType) if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. 
+ assert isinstance(node.type, ts.FieldType) stencil_expr = im.lambda_("a")(im.deref("a")) stencil_expr.expr.type = node.type.dtype elif isinstance(fieldop_expr, gtir.Lambda): @@ -394,12 +477,12 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - input_edges, output_edge = gtir_dataflow.visit_lambda( + input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( sdfg, state, sdfg_builder, stencil_expr, fieldop_args ) return _create_field_operator( - sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges ) @@ -458,7 +541,7 @@ def translate_if( def construct_output(inner_data: FieldopData) -> FieldopData: inner_desc = inner_data.dc_node.desc(sdfg) - outer, _ = sdfg.add_temp_transient_like(inner_desc) + outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc) outer_node = state.add_access(outer) return inner_data.make_copy(outer_node) @@ -518,8 +601,7 @@ def translate_index( dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - index_data = sdfg.temp_data_name() - sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_data, _ = sdfg_builder.add_temp_scalar(sdfg, INDEX_DTYPE) index_node = state.add_access(index_data) index_value = gtir_dataflow.ValueExpr( dc_node=index_node, @@ -570,11 +652,10 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) - return tuple( - _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) - for fname, ftype in tuple_fields - ) + symbol_tree = dace_gtir_utils.make_symbol_tree(data_name, data_type) + return gtx_utils.tree_map( + lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) + )(symbol_tree) else: raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") @@ -691,13 +772,11 @@ def translate_scalar_expr( visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: - # `gt_symbol` refers to symbols defined in the GT4Py program - gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id) - if not isinstance(gt_symbol_type, ts.ScalarType): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") + # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined + sdfg_builder.get_symbol_type(arg_expr.id) except KeyError: - # this is the case of non-variable argument, e.g. target type such as `float64`, - # used in a casting expression like `cast_(variable, float64)` + # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, + # e.g. 
the type name `float64` used in casting expressions like `cast_(variable, float64)` visit_expr = False if visit_expr: @@ -708,7 +787,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(node.type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -738,12 +817,7 @@ def translate_scalar_expr( dace.Memlet(data=arg_node.data, subset="0"), ) # finally, create temporary for the result value - temp_name, _ = sdfg.add_scalar( - sdfg.temp_data_name(), - dace_utils.as_dace_type(node.type), - find_new_name=True, - transient=True, - ) + temp_name, _ = sdfg_builder.add_temp_scalar(sdfg, dace_utils.as_dace_type(node.type)) temp_node = state.add_access(temp_name) state.add_edge( tasklet_node, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 22d6e17cad..d086b26a2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -28,9 +28,10 @@ from dace import subsets as dace_subsets from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, @@ -115,6 +116,9 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_field_type(self) -> ts.FieldType: + return ts.FieldType([dim for dim, _ in self.field_domain], self.gt_dtype) + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -140,16 +144,19 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: class DataflowInputEdge(Protocol): """ - This protocol represents an open connection into the dataflow. + This protocol describes how to concretize a data edge to read data from a source node + into the dataflow. It provides the `connect` method to setup an input edge from an external data source. - Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope - and connect its inputs and outputs to external data nodes by means of memlets that - traverse the map entry and exit nodes. + The most common case is that the dataflow represents a stencil, which is instantied + inside a map scope and whose inputs and outputs are connected to external data nodes + by means of memlets that traverse the map entry and exit nodes. + The dataflow can also be instatiated without a map, in which case the `map_entry` + argument is set to `None`. """ @abc.abstractmethod - def connect(self, me: dace.nodes.MapEntry) -> None: ... + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... 
@dataclasses.dataclass(frozen=True) @@ -167,15 +174,18 @@ class MemletInputEdge(DataflowInputEdge): dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] - def connect(self, me: dace.nodes.MapEntry) -> None: + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: memlet = dace.Memlet(data=self.source.data, subset=self.subset) - self.state.add_memlet_path( - self.source, - me, - self.dest, - dst_conn=self.dest_conn, - memlet=memlet, - ) + if map_entry is None: + self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) + else: + self.state.add_memlet_path( + self.source, + map_entry, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) @dataclasses.dataclass(frozen=True) @@ -190,8 +200,12 @@ class EmptyInputEdge(DataflowInputEdge): state: dace.SDFGState node: dace.nodes.Tasklet - def connect(self, me: dace.nodes.MapEntry) -> None: - self.state.add_nedge(me, self.node, dace.Memlet()) + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + if map_entry is None: + # outside of a map scope it is possible to instantiate a tasklet node + # without input connectors + return + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -200,10 +214,12 @@ class DataflowOutputEdge: Allows to setup an output memlet through a map exit node. The result of a dataflow subgraph needs to be written to an external data node. - Since the dataflow represents a stencil and the dataflow is computed over - a field domain, the dataflow is instatiated inside a map scope. The `connect` - method creates a memlet that writes the dataflow result to the external array - passing through the map exit node. + The most common case is that the dataflow represents a stencil and the dataflow + is computed over a field domain, therefore the dataflow is instatiated inside + a map scope. The `connect` method creates a memlet that writes the dataflow + result to the external array passing through the `map_exit` node. + The dataflow can also be instatiated without a map, in which case the `map_exit` + argument is set to `None`. 
""" state: dace.SDFGState @@ -211,13 +227,13 @@ class DataflowOutputEdge: def connect( self, - mx: dace.nodes.MapExit, + map_exit: Optional[dace.nodes.MapExit], dest: dace.nodes.AccessNode, subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): + if isinstance(last_node, (dace.nodes.Tasklet, dace.nodes.NestedSDFG)): # the last transient node can be deleted last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn self.state.remove_node(self.result.dc_node) @@ -225,13 +241,22 @@ def connect( last_node = self.result.dc_node last_node_connector = None - self.state.add_memlet_path( - last_node, - mx, - dest, - src_conn=last_node_connector, - memlet=dace.Memlet(data=dest.data, subset=subset), - ) + if map_exit is None: + self.state.add_edge( + last_node, + last_node_connector, + dest, + None, + dace.Memlet(data=dest.data, subset=subset), + ) + else: + self.state.add_memlet_path( + last_node, + map_exit, + dest, + src_conn=last_node_connector, + memlet=dace.Memlet(data=dest.data, subset=subset), + ) DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { @@ -267,6 +292,25 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +def get_tuple_type( + data: tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...], +) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions. + """ + data_types: list[ts.DataType] = [] + for dataitem in data: + if isinstance(dataitem, tuple): + data_types.append(get_tuple_type(dataitem)) + elif isinstance(dataitem, IteratorExpr): + data_types.append(dataitem.get_field_type()) + elif isinstance(dataitem, MemletExpr): + data_types.append(dataitem.gt_dtype) + else: + data_types.append(dataitem.gt_dtype) + return ts.TupleType(data_types) + + @dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ @@ -289,9 +333,10 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( - default_factory=lambda: {} - ) + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] = dataclasses.field(default_factory=dict) def _add_input_data_edge( self, @@ -370,9 +415,9 @@ def _add_mapped_tasklet( name: str, map_ranges: Dict[str, str | dace.subsets.Subset] | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + inputs: Dict[str, dace.Memlet], code: str, - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """ @@ -427,10 +472,9 @@ def _construct_tasklet_result( # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) # in order to allow for composition of array shape in external memlets. 
- temp_name, _ = self.sdfg.add_temp_transient((1,), dc_dtype) + temp_name, _ = self.subgraph_builder.add_temp_array(self.sdfg, (1,), dc_dtype) else: - temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dc_dtype, transient=True) + temp_name, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, dc_dtype) temp_node = self.state.add_access(temp_name) self._add_edge( @@ -467,6 +511,9 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # format used for field index tasklet connector IndexConnectorFmt: Final = "__index_{dim}" + if isinstance(node.type, ts.TupleType): + raise NotImplementedError("Tuple deref not supported.") + assert len(node.args) == 1 arg_expr = self.visit(node.args[0]) @@ -545,6 +592,274 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + def _visit_if_branch_arg( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + param_name: str, + arg: IteratorExpr | DataExpr, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + ) -> IteratorExpr | ValueExpr: + """ + Helper method to be called by `_visit_if_branch()` to visit the input arguments. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + param_name: The parameter name of the input argument. + arg: The input argument expression. + if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + """ + if isinstance(arg, (MemletExpr, ValueExpr)): + arg_expr = arg + arg_node = arg.dc_node + arg_desc = arg_node.desc(self.sdfg) + if isinstance(arg, MemletExpr): + assert arg.subset.num_elements() == 1 + arg_desc = dace.data.Scalar(arg_desc.dtype) + else: + assert isinstance(arg_desc, dace.data.Scalar) + elif isinstance(arg, IteratorExpr): + arg_node = arg.field + arg_desc = arg_node.desc(self.sdfg) + arg_expr = MemletExpr(arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc)) + else: + raise TypeError(f"Unexpected {arg} as input argument.") + + if param_name in if_sdfg.arrays: + inner_desc = if_sdfg.data(param_name) + assert not inner_desc.transient + else: + inner_desc = arg_desc.clone() + inner_desc.transient = False + if_sdfg.add_datadesc(param_name, inner_desc) + if_sdfg_input_memlets[param_name] = arg_expr + + inner_node = if_branch_state.add_access(param_name) + if isinstance(arg, IteratorExpr): + return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) + else: + return ValueExpr(inner_node, arg.gt_dtype) + + def _visit_if_branch( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + expr: gtir.Expr, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + """ + Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. + + This function is called by `_visit_if()` for each if-branch. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + expr: The if branch expression to lower. + if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. 
+ + Returns: + A tuple containing: + - the list of input edges for the parent dataflow + - the output data, in the form of a single data edge or a tuple of data edges. + """ + assert if_branch_state in if_sdfg.states() + + lambda_args = [] + lambda_params = [] + for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): + arg = self.symbol_map[pname] + if isinstance(arg, tuple): + ptype = get_tuple_type(arg) # type: ignore[arg-type] + psymbol = im.sym(pname, ptype) + psymbol_tree = dace_gtir_utils.make_symbol_tree(pname, ptype) + inner_arg = gtx_utils.tree_map( + lambda tsym, targ: self._visit_if_branch_arg( + if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets + ) + )(psymbol_tree, arg) + else: + psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + inner_arg = self._visit_if_branch_arg( + if_sdfg, if_branch_state, pname, arg, if_sdfg_input_memlets + ) + lambda_args.append(inner_arg) + lambda_params.append(psymbol) + + # visit each branch of the if-statement as if it was a Lambda node + lambda_node = gtir.Lambda(params=lambda_params, expr=expr) + input_edges, output_edges = translate_lambda_to_dataflow( + if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, args=lambda_args + ) + + for data_node in if_branch_state.data_nodes(): + # In case tuple arguments, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + if if_branch_state.degree(data_node) == 0: + assert not data_node.desc(if_sdfg).transient + if_branch_state.remove_node(data_node) + + return input_edges, output_edges + + def _visit_if_branch_result( + self, sdfg: dace.SDFG, state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + """ + Helper function to be called by `_visit_if` to create an output connector + on the nested SDFG that will write the result to the parent SDFG. + The result data inside the nested SDFG must have the same name as the connector. + """ + output_data = str(sym.id) + if output_data in sdfg.arrays: + output_desc = sdfg.data(output_data) + assert not output_desc.transient + else: + # If the result is currently written to a transient node, inside the nested SDFG, + # we need to allocate a non-transient data node. + result_desc = edge.result.dc_node.desc(sdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = state.add_access(output_data) + state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + """ + Lowers an if-expression with exclusive branch execution into a nested SDFG, + in which each branch is lowered into a dataflow in a separate state and + the if-condition is represented as the inter-state edge condtion. 
+ """ + + def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: + # Each output connector of the nested SDFG writes to a transient node in the parent SDFG + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.subgraph_builder.add_temp_array_like(self.sdfg, inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + + assert len(node.args) == 3 + + # TODO(edopao): enable once supported in next DaCe release + use_conditional_block: Final[bool] = False + + # evaluate the if-condition that will write to a boolean scalar node + condition_value = self.visit(node.args[0]) + assert ( + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool_) + ) + + nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) + nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + + # create states inside the nested SDFG for the if-branches + if use_conditional_block: + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) + + else: + entry_state = nsdfg.add_state("entry", is_start_block=True) + tstate = nsdfg.add_state("true_branch") + nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond")) + fstate = nsdfg.add_state("false_branch") + nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) + + input_memlets: dict[str, MemletExpr | ValueExpr] = {} + + # define scalar or symbol for the condition value inside the nested SDFG + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + input_memlets["__cond"] = condition_value + + for nstate, arg in zip([tstate, fstate], node.args[1:3]): + # visit each if-branch in the corresponding state of the nested SDFG + in_edges, out_edge = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) + for edge in in_edges: + edge.connect(map_entry=None) + + if isinstance(out_edge, tuple): + assert isinstance(node.type, ts.TupleType) + out_symbol_tree = dace_gtir_utils.make_symbol_tree("__output", node.type) + outer_value = gtx_utils.tree_map( + lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) + )(out_edge, out_symbol_tree) + else: + assert isinstance(node.type, ts.FieldType | ts.ScalarType) + outer_value = self._visit_if_branch_result( + nsdfg, nstate, out_edge, im.sym("__output", node.type) + ) + # Isolated access node will make validation fail. 
+ # Isolated access nodes can be found in `make_tuple` expressions that + # construct tuples from input arguments. + for data_node in nstate.data_nodes(): + if nstate.degree(data_node) == 0: + assert not data_node.desc(nsdfg).transient + nsdfg.remove_node(data_node) + else: + result = outer_value + + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, + self.sdfg, + inputs=set(input_memlets.keys()), + outputs=outputs, + symbol_mapping=None, # implicitly map all free symbols to the symbols available in parent SDFG + ) + + for inner, input_expr in input_memlets.items(): + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, inner) + else: + self._add_edge( + input_expr.dc_node, + None, + nsdfg_node, + inner, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + + return ( + gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) + if isinstance(result, tuple) + else write_output_of_nested_sdfg_to_temporary(result) + ) + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, ts.ListType) assert len(node.args) == 2 @@ -605,8 +920,8 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) ) - neighbors_temp, _ = self.sdfg.add_temp_transient( - (offset_provider.max_neighbors,), field_desc.dtype + neighbors_temp, _ = self.subgraph_builder.add_temp_array( + self.sdfg, (offset_provider.max_neighbors,), field_desc.dtype ) neighbors_node = self.state.add_access(neighbors_temp) offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL) @@ -652,6 +967,56 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type) ) + def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: + assert len(node.args) == 2 + index_arg = self.visit(node.args[0]) + list_arg = self.visit(node.args[1]) + assert isinstance(list_arg, ValueExpr) + assert isinstance(list_arg.gt_dtype, ts.ListType) + assert isinstance(list_arg.gt_dtype.element_type, ts.ScalarType) + + list_desc = list_arg.dc_node.desc(self.sdfg) + assert len(list_desc.shape) == 1 + + result_dtype = dace_utils.as_dace_type(list_arg.gt_dtype.element_type) + result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, result_dtype) + result_node = self.state.add_access(result) + + if isinstance(index_arg, SymbolExpr): + assert index_arg.dc_dtype in dace.dtypes.INTEGER_TYPES + self._add_edge( + list_arg.dc_node, + None, + result_node, + None, + dace.Memlet(data=list_arg.dc_node.data, subset=index_arg.value), + ) + elif isinstance(index_arg, ValueExpr): + tasklet_node = self._add_tasklet( + "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]" + ) + self._add_edge( + index_arg.dc_node, + None, + tasklet_node, + "index", + dace.Memlet(data=index_arg.dc_node.data, subset="0"), + ) + self._add_edge( + list_arg.dc_node, + None, + tasklet_node, + "list", + self.sdfg.make_array_memlet(list_arg.dc_node.data), + ) + self._add_edge( + tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0") + ) + else: + raise TypeError(f"Unexpected value {index_arg} as index argument.") + + return ValueExpr(dc_node=result_node, gt_dtype=list_arg.gt_dtype.element_type) + def _visit_map(self, node: gtir.FunCall) -> ValueExpr: """ A map node defines an operation to be mapped on all elements of input arguments. 
@@ -743,7 +1108,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: ) input_nodes[input_node.data] = input_node - result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) + result, _ = self.subgraph_builder.add_temp_array(self.sdfg, (local_size,), dc_dtype) result_node = self.state.add_access(result) if offset_provider_type.has_skip_values: @@ -930,8 +1295,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: op_name, reduce_init, reduce_identity = get_reduce_params(node) reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") - result = self.sdfg.temp_data_name() - self.sdfg.add_scalar(result, reduce_identity.dc_dtype, transient=True) + result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, reduce_identity.dc_dtype) result_node = self.state.add_access(result) input_expr = self.visit(node.args[0]) @@ -1119,10 +1483,7 @@ def _make_unstructured_shift( """Implements shift in unstructured domain by means of a neighbor table.""" assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain - assert neighbor_dim not in it.indices - origin_dim = connectivity.source_dim - assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1253,13 +1614,45 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: + def _visit_make_tuple(self, node: gtir.FunCall) -> tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "make_tuple") + return tuple(self.visit(arg) for arg in node.args) + + def _visit_tuple_get( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + tuple_fields = self.visit(node.args[1]) + return tuple_fields[index] + + def visit_FunCall( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "if_"): + return self._visit_if(node) + elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_call_to(node, "list_get"): + return self._visit_list_get(node) + + elif cpm.is_call_to(node, "make_tuple"): + return self._visit_make_tuple(node) + + elif cpm.is_call_to(node, "tuple_get"): + return self._visit_tuple_get(node) + elif cpm.is_applied_map(node): return self._visit_map(node) @@ -1279,35 +1672,52 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: - result: DataExpr = self.visit(node.expr) + def visit_Lambda( + self, node: gtir.Lambda + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: + def _visit_Lambda_impl( + output_expr: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(output_expr, DataflowOutputEdge): + return output_expr + if isinstance(output_expr, ValueExpr): + return DataflowOutputEdge(self.state, output_expr) + 
+ if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_data_edge( + output_expr.dc_node, + output_expr.subset, + tasklet_node, + "__inp", + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dc_dtype + tasklet_node = self._add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) - if isinstance(result, ValueExpr): - return DataflowOutputEdge(self.state, result) + output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") + return DataflowOutputEdge(self.state, output_expr) - if isinstance(result, MemletExpr): - # special case where the field operator is simply copying data from source to destination node - output_dtype = result.dc_node.desc(self.sdfg).dtype - tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_data_edge( - result.dc_node, - result.subset, - tasklet_node, - "__inp", - ) - else: - # even simpler case, where a constant value is written to destination node - output_dtype = result.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") + result = self.visit(node.expr) - output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return DataflowOutputEdge(self.state, output_expr) + return ( + gtx_utils.tree_map(_visit_Lambda_impl)(result) + if isinstance(result, tuple) + else _visit_Lambda_impl(result) + ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: + def visit_SymRef( + self, node: gtir.SymRef + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1318,8 +1728,13 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE def visit_let( self, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> DataflowOutputEdge: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: """ Maps lambda arguments to internal parameters. @@ -1349,13 +1764,21 @@ def visit_let( return self.visit(node) -def visit_lambda( +def translate_lambda_to_dataflow( sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], -) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], +) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], +]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, that can be instantiated inside a map scope implementing the field operator. @@ -1367,7 +1790,7 @@ def visit_lambda( Args: sdfg: The SDFG where the dataflow graph will be instantiated. 
state: The SDFG state where the dataflow graph will be instantiated. - sdfg_builder: Helper class to build the SDFG. + sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. @@ -1377,5 +1800,5 @@ def visit_lambda( - Output data connection. """ taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) - output_edge = taskgen.visit_let(node, args) - return taskgen.input_edges, output_edge + output_edges = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edges diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7cb1461746..23a36ba79f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -53,6 +53,25 @@ def unique_map_name(self, name: str) -> str: ... @abc.abstractmethod def unique_tasklet_name(self, name: str) -> str: ... + def add_temp_array( + self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient(shape, dtype) + + def add_temp_array_like( + self, sdfg: dace.SDFG, datadesc: dace.data.Array + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient_like(datadesc) + + def add_temp_scalar( + self, sdfg: dace.SDFG, dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary scalar to the SDFG.""" + temp_name = sdfg.temp_data_name() + return sdfg.add_scalar(temp_name, dtype, transient=True) + def add_map( self, name: str, @@ -86,9 +105,9 @@ def add_mapped_tasklet( state: dace.SDFGState, map_ranges: Dict[str, str | dace.subsets.Subset] | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + inputs: Dict[str, dace.Memlet], code: str, - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" @@ -149,15 +168,6 @@ def _collect_symbols_in_domain_expressions( ) -def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -173,9 +183,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( - default_factory=lambda: {} + default_factory=dict ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -246,7 +256,6 @@ def _add_storage( name: str, gt_type: ts.DataType, transient: bool = True, - tuple_name: Optional[str] = None, ) -> list[tuple[str, ts.DataType]]: """ Add storage in the SDFG for a given GT4Py data symbol. @@ -266,7 +275,6 @@ def _add_storage( name: Symbol Name to be allocated. gt_type: GT4Py symbol type. transient: True when the data symbol has to be allocated as internal storage. - tuple_name: Must be set for tuple fields in order to use the same array shape and strides symbols. Returns: List of tuples '(data_name, gt_type)' where 'data_name' is the name of @@ -277,11 +285,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): + for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): + assert isinstance(sym.type, ts.DataType) tuple_fields.extend( - self._add_storage( - sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name - ) + self._add_storage(sdfg, symbolic_arguments, sym.id, sym.type, transient) ) return tuple_fields @@ -293,16 +300,9 @@ def _add_storage( # ListType not supported: concept is represented as Field with local Dimension assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) - if tuple_name is None: - # Use symbolic shape, which allows to invoke the program with fields of different size; - # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) - else: - # All fields in a tuple must have the same dims and sizes, - # therefore we use the same shape and strides symbols based on 'tuple_name'. - sym_shape, sym_strides = self._make_array_shape_and_strides( - tuple_name, gt_type.dims - ) + # Use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) return [(name, gt_type)] @@ -367,7 +367,7 @@ def make_temps( if desc.transient or not use_temp: return field else: - temp, _ = sdfg.add_temp_transient_like(desc) + temp, _ = self.add_temp_array_like(sdfg, desc) temp_node = head_state.add_access(temp) head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) @@ -438,13 +438,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: assert len(self.field_offsets) == 0 sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) - - # DaCe requires C-compatible strings for the names of data containers, - # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name - # separator in the SSA pass, which generates invalid symbols for DaCe. 
- # Here we find new names for invalid symbols present in the IR. - node = dace_gtir_utils.replace_invalid_symbols(sdfg, node) + sdfg.debuginfo = dace_utils.debug_info(node) # start block of the stateful graph entry_state = sdfg.add_state("program_entry", is_start_block=True) @@ -633,24 +627,13 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] - def flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = _get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - lambda_arg_nodes = dict( - itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + itertools.chain( + *[ + gtir_builtin_translators.flatten_tuples(pname, arg) + for pname, arg in lambda_args_mapping + ] + ) ) # inherit symbols from parent scope but eventually override with local symbols @@ -658,7 +641,9 @@ def flatten_tuples( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: gtir_builtin_translators.get_tuple_type(arg) + if isinstance(arg, tuple) + else arg.gt_type for pname, arg in lambda_args_mapping } @@ -673,12 +658,12 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( - lambda field_offsets, field: ( - field_offsets | get_field_domain_offset(field[0], field[1]) + lambda field_offsets, sym: ( + field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] ), - p_fields, + tsyms, {}, ) return {} @@ -722,15 +707,24 @@ def get_field_domain_offset( } input_memlets = {} - nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): if nsdfg_datadesc.transient: continue - datadesc: Optional[dace.dtypes.Data] = None + if nsdfg_dataname in lambda_arg_nodes: src_node = lambda_arg_nodes[nsdfg_dataname].dc_node dataname = src_node.data datadesc = src_node.desc(sdfg) + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], + [*datadesc.shape, *datadesc.strides], + strict=True, + ) + if isinstance(nested_symbol, dace.symbol) + } else: dataname = nsdfg_dataname datadesc = sdfg.arrays[nsdfg_dataname] @@ -741,16 +735,6 @@ def get_field_domain_offset( input_memlets[nsdfg_dataname] = sdfg.make_array_memlet(dataname) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], - [*datadesc.shape, *datadesc.strides], - strict=True, - ) - if isinstance(nested_symbol, dace.symbol) - } - # Process lambda outputs # # The output arguments do not really exist, so they are not allocated before @@ -817,7 +801,7 @@ def 
construct_output_for_nested_sdfg( # that is externally allocated, as required by the SDFG IR. An output edge will write the result # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. inner_desc.transient = False - outer, outer_desc = sdfg.add_temp_transient_like(inner_desc) + outer, outer_desc = self.add_temp_array_like(sdfg, inner_desc) # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. dace.symbolic.safe_replace( nsdfg_symbols_mapping, @@ -884,6 +868,13 @@ def build_sdfg_from_gtir( ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) + + # DaCe requires C-compatible strings for the names of data containers, + # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name + # separator in the SSA pass, which generates invalid symbols for DaCe. + # Here we find new names for invalid symbols present in the IR. + ir = dace_gtir_utils.replace_invalid_symbols(ir) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index c46420c24b..6121529161 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -8,14 +8,14 @@ from __future__ import annotations -import itertools from typing import Dict, TypeVar import dace from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -27,43 +27,55 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: return f"i_{dim.value}_gtx_{dim.kind}{suffix}" -def get_tuple_fields( - tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> list[tuple[str, ts.DataType]]: +def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a tree representation of the symbols corresponding to the tuple fields. + The constructed tree preserves the nested nature of the tuple type, if any. Examples -------- >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert get_tuple_fields("a", t, flatten=True) == [ - ... ("a_0", sty), - ... ("a_1_0", fty), - ... ("a_1_1", sty), - ... ] + >>> assert make_symbol_tree("a", t) == ( + ... im.sym("a_0", sty), + ... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)), + ... 
) """ assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] - if flatten: - expanded_fields: list[list[tuple[str, ts.DataType]]] = [ - get_tuple_fields(field_name, field_type) - if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] # type: ignore[list-item] # checked in assert - for field_name, field_type in fields - ] - return list(itertools.chain(*expanded_fields)) - else: - return fields # type: ignore[return-value] # checked in assert - - -def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: + return tuple( + make_symbol_tree(field_name, field_type) # type: ignore[misc] + if isinstance(field_type, ts.TupleType) + else im.sym(field_name, field_type) + for field_name, field_type in fields + ) + + +def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: + """ + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert flatten_tuple_fields("a", t) == [ + ... im.sym("a_0", sty), + ... im.sym("a_1_0", fty), + ... im.sym("a_1_1", sty), + ... ] + """ + symbol_tree = make_symbol_tree(tuple_name, tuple_type) + return list(gtx_utils.flatten_nested_tuple(symbol_tree)) + + +def replace_invalid_symbols(ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). - If any invalid symbol present, this funtion returns a copy of the input IR where + If any invalid symbol present, this function returns a copy of the input IR where the invalid symbols have been replaced with new names. If all symbols are valid, the input IR is returned without copying it. 
""" @@ -85,12 +97,17 @@ def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.S if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params): raise ValueError("Invalid symbol in program parameters.") + ir_sym_ids = {str(sym.id) for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set()} + ir_ssa_uuid = eve.utils.UIDGenerator(prefix="gtir_tmp") + invalid_symbols_mapping = { - sym_id: sdfg.temp_data_name() - for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set() - if not dace.dtypes.validate_name(sym_id := str(sym.id)) + sym_id: ir_ssa_uuid.sequential_id() + for sym_id in ir_sym_ids + if not dace.dtypes.validate_name(sym_id) } - if len(invalid_symbols_mapping) != 0: - return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) - else: + if len(invalid_symbols_mapping) == 0: return ir + + # assert that the new symbol names are not used in the IR + assert ir_sym_ids.isdisjoint(invalid_symbols_mapping.values()) + return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index bed6e89a52..e19d9e1d81 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -86,14 +86,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ALL = "all" REQUIRES_ATLAS = "requires_atlas" USES_APPLIED_SHIFTS = "uses_applied_shifts" +USES_CAN_DEREF = "uses_can_deref" +USES_COMPOSITE_SHIFTS = "uses_composite_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" +USES_LIFT = "uses_lift" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" +USES_REDUCE_WITH_LAMBDA = "uses_reduce_with_lambda" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" @@ -105,6 +109,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" +USES_TUPLE_ITERATOR = "uses_tuple_iterator" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" USES_CARTESIAN_SHIFT = "uses_cartesian_shift" @@ -132,11 +137,21 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): DOMAIN_INFERENCE_SKIP_LIST = [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), -] +DACE_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), + (USES_COMPOSITE_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), diff 
--git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9de4449ac2..2e40cb897a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -360,6 +360,7 @@ def testee(qc: cases.IKFloatField, scalar: float): @pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator +@pytest.mark.uses_tuple_iterator def test_tuple_scalar_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( @@ -867,8 +868,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) -@pytest.mark.uses_tuple_args @pytest.mark.uses_scan +@pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -897,6 +899,7 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_different_domain_in_tuple(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] @@ -936,6 +939,7 @@ def foo( @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_tuple_field_scalar_mixed(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c0a4cd166d..885a272bfe 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -244,6 +244,7 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index fe89fe7c9d..7836b1b110 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -38,6 +38,7 @@ def baz(baz_inp): return deref(lift(bar)(baz_inp)) +@pytest.mark.uses_lift def test_trivial(program_processor): program_processor, validate = program_processor @@ -66,6 +67,7 @@ def stencil_shifted_arg_to_lift(inp): return deref(lift(deref)(shift(I, -1)(inp))) +@pytest.mark.uses_lift def test_shifted_arg_to_lift(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 39d0bd69c3..ea89bb23ba 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -219,6 +219,7 @@ def tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -272,6 +273,7 @@ def tuple_tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def 
test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor @@ -319,6 +321,7 @@ def test_field_of_2_extra_dim_input(program_processor): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_scalar_tuple_args(program_processor): @fundef def stencil(inp): @@ -348,6 +351,7 @@ def stencil(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_mixed_field_scalar_tuple_arg(program_processor): @fundef def stencil(inp): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index f8e9f22eff..3b4fc0a70c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -122,19 +122,19 @@ def k_level_condition_upper_tuple(k_idx, k_level): @pytest.mark.parametrize( "fun, k_level, inp_function, ref_function", [ - ( + pytest.param( k_level_condition_lower, lambda inp: 0, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([[0], inp[:-1]]), ), - ( + pytest.param( k_level_condition_upper, lambda inp: inp.shape[0] - 1, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), - ( + pytest.param( k_level_condition_upper_tuple, lambda inp: inp[0].shape[0] - 1, lambda k_size: ( @@ -142,6 +142,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + marks=pytest.mark.uses_tuple_iterator, ), ], ) @@ -184,6 +185,7 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): "kstart, reference", [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) +@pytest.mark.uses_scan def test_ksum_scan(program_processor, kstart, reference): program_processor, validate = program_processor shape = [1, 7] @@ -211,6 +213,7 @@ def ksum_back_fencil(i_size, k_size, inp, out): set_at(as_fieldop(scan(ksum, False, 0.0), domain)(inp), domain, out) +@pytest.mark.uses_scan def test_ksum_back_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ac7ce9e544..ff87de7348 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -149,6 +149,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) +@pytest.mark.uses_composite_shifts def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor): program_processor, validate = program_processor inp = vertex_index_field() @@ -174,6 +175,7 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field(program_processor): program_processor, validate = program_processor @@ -196,6 +198,7 @@ def test_sparse_input_field(program_processor): assert 
np.allclose(out.asnumpy(), ref) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field_v2v(program_processor): program_processor, validate = program_processor @@ -330,6 +333,7 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) +@pytest.mark.uses_lift def test_lift(program_processor): program_processor, validate = program_processor inp = vertex_index_field() diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 03662f8dcc..0bd8653a03 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,6 +60,10 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), + pytest.param( + (next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, True), + marks=pytest.mark.requires_dace, + ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 03b8e3bc15..225d22562f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -256,7 +256,16 @@ def test_gtir_tuple_args(): x_fields = (a, a, b) - sdfg(*x_fields, c, **FSYMBOLS) + tuple_symbols = { + "__x_0_size_0": N, + "__x_0_stride_0": 1, + "__x_1_0_size_0": N, + "__x_1_0_stride_0": 1, + "__x_1_1_size_0": N, + "__x_1_1_stride_0": 1, + } + + sdfg(*x_fields, c, **FSYMBOLS, **tuple_symbols) assert np.allclose(c, a * 2 + b) @@ -418,7 +427,16 @@ def test_gtir_tuple_return(): z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_0_size_0": N, + "__z_0_0_stride_0": 1, + "__z_0_1_size_0": N, + "__z_0_1_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], a) assert np.allclose(z_fields[2], b) @@ -673,9 +691,16 @@ def test_gtir_cond_with_tuple_return(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + for s in [False, True]: z_fields = (np.empty_like(a), np.empty_like(a)) - sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS) + sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a if s else b) assert np.allclose(z_fields[1], b if s else a) @@ -1846,7 +1871,14 @@ def test_gtir_let_lambda_with_tuple1(): a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a_ref) assert np.allclose(z_fields[1], b_ref) @@ -1886,7 +1918,16 @@ def 
test_gtir_let_lambda_with_tuple2(): z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + "__z_2_size_0": N, + "__z_2_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], val) assert np.allclose(z_fields[2], b) @@ -1938,8 +1979,17 @@ def test_gtir_if_scalars(): sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + tuple_symbols = { + "__x_0_size_0": N, + "__x_0_stride_0": 1, + "__x_1_0_size_0": N, + "__x_1_0_stride_0": 1, + "__x_1_1_size_0": N, + "__x_1_1_stride_0": 1, + } + for s in [False, True]: - sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS) + sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert np.allclose(b, (a + d1 if s else a + d2))
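The `tuple_symbols` dictionaries added to the tests above follow from the storage change in `_add_storage`: tuple elements no longer share the shape/stride symbols of a common `tuple_name`, so every flattened element (`x_0`, `x_1_0`, `x_1_1`, ...) gets its own `__<name>_size_<axis>` and `__<name>_stride_<axis>` symbols that the caller must bind. The helper below only illustrates that naming convention; the nesting, the 1D shape, and the unit strides are assumptions for the example and not part of the patch:

    def flattened_tuple_names(prefix, structure):
        """Yield flattened element names ('x_0', 'x_1_0', ...) for a nested tuple."""
        for i, elem in enumerate(structure):
            name = f"{prefix}_{i}"
            if isinstance(elem, tuple):
                yield from flattened_tuple_names(name, elem)
            else:
                yield name

    def tuple_shape_symbols(prefix, structure, size):
        """Build '__<name>_size_0'/'__<name>_stride_0' bindings for 1D fields of length `size`."""
        symbols = {}
        for name in flattened_tuple_names(prefix, structure):
            symbols[f"__{name}_size_0"] = size
            symbols[f"__{name}_stride_0"] = 1
        return symbols

    # For a parameter 'x' holding the tuple (a, (b, c)), this reproduces the
    # dictionary written out by hand in test_gtir_tuple_args above:
    # {'__x_0_size_0': N, '__x_0_stride_0': 1, '__x_1_0_size_0': N, ...}
    assert tuple_shape_symbols("x", ("a", ("b", "c")), size=10) == {
        "__x_0_size_0": 10,
        "__x_0_stride_0": 1,
        "__x_1_0_size_0": 10,
        "__x_1_0_stride_0": 1,
        "__x_1_1_size_0": 10,
        "__x_1_1_stride_0": 1,
    }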