Skip to content

Commit

Permalink
feat[next][dace]: iterator-view support to DaCe backend (#1790)
Browse files Browse the repository at this point in the history
The lowering of scan to SDFG requires the support for iterator view.
This PR introduces a subset of iterator features:
- Local `if_` with exclusive branch execution
- Lowering of `list_get`, `make_tuple` and `tuple_get` in iterator view
- Field operators returning a tuple of fields
- Tuple of fields with different size

Iterator tests are enabled on dace CPU backend without SDFG
transformations (`auto_optimize=False`).

---------

Co-authored-by: Philip Mueller <philip.mueller@cscs.ch>
  • Loading branch information
edopao and philip-paul-mueller authored Jan 15, 2025
1 parent 33bb68b commit 17bae8e
Show file tree
Hide file tree
Showing 16 changed files with 875 additions and 290 deletions.
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,31 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref builtin function',
'uses_composite_shifts: tests that use composite shifts in unstructured domain',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_floordiv: tests that require backend support for floor division',
'uses_if_stmts: tests that require backend support for if-statements',
'uses_index_fields: tests that require backend support for index fields',
'uses_ir_if_stmts',
'uses_lift: tests that require backend support for lift builtin function',
'uses_negative_modulo: tests that require backend support for modulo on negative numbers',
'uses_origin: tests that require backend support for domain origin',
'uses_reduce_with_lambda: tests that use lambdas as reduce functions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scalar_in_domain_and_fo',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_scan_in_stencil: tests that require backend support for scan in stencil',
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_scan_nested: tests that use nested scans',
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_iterator: tests that require backend support to deref tuple iterators',
'uses_tuple_returns: tests that require backend support for tuple results',
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, Iterable
from typing import Any

import dace
import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next import common as gtx_common

from . import utility as dace_utils

Expand Down Expand Up @@ -46,10 +46,9 @@ def _convert_arg(arg: Any, sdfg_param: str) -> Any:

def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
return {
sdfg_param: _convert_arg(arg, sdfg_param)
for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True)
for sdfg_param, arg in zip(sdfg_params, args, strict=True)
}


Expand All @@ -73,17 +72,8 @@ def _get_shape_args(
for name, value in args.items():
for sym, size in zip(arrays[name].shape, value.shape, strict=True):
if isinstance(sym, dace.symbol):
if sym.name not in shape_args:
shape_args[sym.name] = size
elif shape_args[sym.name] != size:
# The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields
# in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array
# size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple.
# TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique,
# once the assumption on tuples is removed.
raise ValueError(
f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}."
)
assert sym.name not in shape_args
shape_args[sym.name] = size
elif sym != size:
raise ValueError(
f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}."
Expand All @@ -103,15 +93,8 @@ def _get_stride_args(
f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)."
)
if isinstance(sym, dace.symbol):
if sym.name not in stride_args:
stride_args[str(sym)] = stride
elif stride_args[sym.name] != stride:
# See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple.
# TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique,
# once the assumption on tuples is removed.
raise ValueError(
f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}."
)
assert sym.name not in stride_args
stride_args[sym.name] = stride
elif sym != stride:
raise ValueError(
f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import ctypes
import dataclasses
from typing import Any
from typing import Any, Sequence

import dace
import factory
Expand Down Expand Up @@ -112,11 +112,13 @@ def decorated_program(
) -> None:
if out is not None:
args = (*args, out)
if len(sdfg.arg_names) > len(args):
args = (*args, *arguments.iter_size_args(args))
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True))
kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True))
kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu))

use_fast_call = True
Expand Down Expand Up @@ -151,7 +153,7 @@ def decorated_program(
sdfg_args = dace_backend.get_sdfg_args(
sdfg,
offset_provider,
*args,
*flat_args,
check_args=False,
on_gpu=on_gpu,
use_field_canonical_representation=use_field_canonical_representation,
Expand Down
Loading

0 comments on commit 17bae8e

Please sign in to comment.