Skip to content

Commit

Permalink
Merge decouple_inferences
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Feb 25, 2025
2 parents 249ba8b + e5ed262 commit 6627c8f
Show file tree
Hide file tree
Showing 17 changed files with 239 additions and 72 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _local
/src/__init__.py
/tests/__init__.py
.gt_cache/
.gt4py_cache/
.gt_cache_pytest*/

# DaCe
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def isbool(self):
return self == self.BOOL

def isinteger(self):
    """Return True when this data type is one of the integer kinds.

    The accepted members are INT8, INT16, INT32 and INT64.
    """
    # Single membership test covering every integer width, INT16 included.
    return self in (self.INT8, self.INT16, self.INT32, self.INT64)

def isfloat(self):
    """Return True when this data type is FLOAT32 or FLOAT64."""
    float_kinds = (self.FLOAT32, self.FLOAT64)
    return self in float_kinds
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,13 @@ class ScalarAccess(common.ScalarAccess, Expr):


class VariableKOffset(common.VariableKOffset[Expr]):
    # Validation hook for the ``k`` attribute: offset expressions must not contain casts.
    @datamodels.validator("k")
    def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None:
        """Raise ``ValueError`` if any value walked from ``expression`` is a ``Cast``."""
        contains_cast = any(isinstance(part, Cast) for part in expression.walk_values())
        if contains_cast:
            raise ValueError(
                "DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881."
            )


class IndexAccess(common.FieldAccess, Expr):
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def _visit_offset(
else:
int_sizes.append(None)
sym_offsets = [
dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs))
dace.symbolic.pystr_to_symbolic(
self.visit(off, access_info=access_info, decl=decl, **kwargs)
)
for off in (node.to_dict()["i"], node.to_dict()["j"], node.k)
]
for axis in access_info.variable_offset_axes:
Expand Down
40 changes: 30 additions & 10 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __str__(self) -> str:
return self.value


# Sort rank assigned to each dimension kind: horizontal first, then vertical, then local.
dims_kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3}


def dimension_to_implicit_offset(dim: str) -> str:
"""
Return name of offset implicitly defined by a dimension.
Expand Down Expand Up @@ -1123,12 +1126,20 @@ class GridType(StrEnum):
UNSTRUCTURED = "unstructured"


def check_dims(dims: list[Dimension]) -> None:
    """
    Validate a list of dimensions.

    Raises:
        ValueError: If more than one dimension has kind ``LOCAL``, or if ``dims``
            is not sorted by ``dims_kind_order`` of its kind and then by ``Dimension.value``.
    """
    local_dims = [dim for dim in dims if dim.kind == DimensionKind.LOCAL]
    if len(local_dims) > 1:
        raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")

    def _sort_key(dim):
        return (dims_kind_order[dim.kind], dim.value)

    expected_order = sorted(dims, key=_sort_key)
    if dims != expected_order:
        raise ValueError(f"Dimensions {dims} are not correctly ordered.")


def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]:
"""
Find a sorted ordering of multiple lists of dimensions.
Find an ordering of multiple lists of dimensions.
The resulting list contains all unique dimensions from the input lists,
sorted first by `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then
sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then
lexicographically by `Dimension.value`.
Examples:
Expand All @@ -1138,21 +1149,30 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]:
>>> K = Dimension("K", DimensionKind.VERTICAL)
>>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL)
>>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL)
>>> promote_dims([K, J], [I, K]) == [I, J, K]
>>> promote_dims([J, K], [I, K]) == [I, J, K]
True
>>> promote_dims([K, I], [E2C, E2V]) == [I, K, E2C, E2V]
>>> promote_dims([K, J], [I, K])
Traceback (most recent call last):
...
raise ValueError(f"Dimensions {dims} are not correctly ordered.")
ValueError: Dimensions [Dimension(value='K', kind=<DimensionKind.VERTICAL: 'vertical'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)] are not correctly ordered.
>>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V]
True
>>> promote_dims([I, E2C], [K, E2V])
Traceback (most recent call last):
...
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
ValueError: There are more than one dimension with DimensionKind 'LOCAL'.
"""

for dims in dims_list:
check_dims(list(dims))
unique_dims = {dim for dims in dims_list for dim in dims}

kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3}

return (
sorted(unique_dims, key=lambda dim: (kind_order[dim.kind], dim.value))
if unique_dims
else []
)
promoted_dims = sorted(unique_dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value))
check_dims(promoted_dims)
return promoted_dims if unique_dims else []


class FieldBuiltinFuncRegistry:
Expand Down
12 changes: 3 additions & 9 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,13 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType:


@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_BUILTINS)
def _(
lhs: ts.ScalarType | ts.FieldType, rhs: ts.ScalarType | ts.FieldType
) -> ts.ScalarType | ts.FieldType:
def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType:
if isinstance(lhs, ts.DeferredType):
return rhs
if isinstance(rhs, ts.DeferredType):
return lhs
if lhs == rhs:
return lhs
else:
assert isinstance(lhs, ts.FieldType) and isinstance(rhs, ts.FieldType)
assert lhs.dtype == rhs.dtype
return ts.FieldType(dims=common.promote_dims(*[lhs.dims, rhs.dims]), dtype=lhs.dtype)
assert lhs == rhs
return lhs


@_register_builtin_type_synthesizer(
Expand Down
39 changes: 30 additions & 9 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,36 @@ def connect(
dest: dace.nodes.AccessNode,
subset: dace_subsets.Range,
) -> None:
# retrieve the node which writes the result
last_node = self.state.in_edges(self.result.dc_node)[0].src
if isinstance(last_node, dace.nodes.Tasklet):
# the last transient node can be deleted
# Note that it could also be applied when `last_node` is a NestedSDFG,
# but an exception would be when the inner write to global data is a
# WCR memlet, because that prevents fusion of the outer map. This case
# happens for the reduce with skip values, which uses a map with WCR.
last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn
write_edge = self.state.in_edges(self.result.dc_node)[0]
write_size = write_edge.data.dst_subset.num_elements()
# check the kind of node which writes the result
if isinstance(write_edge.src, dace.nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted
assert write_size.is_constant()
remove_last_node = True
elif isinstance(write_edge.src, dace.nodes.NestedSDFG):
if write_size.is_constant():
# Temporary data with compile-time size is allocated on the stack
# and therefore is safe to keep. We decide to keep it as a workaround
# for a dace issue with memlet propagation in combination with
# nested SDFGs containing conditional blocks. The output memlet
# of such blocks will be marked as dynamic because dace is not able
# to detect the exact size of a conditional branch dataflow, even
# in case of if-else expressions with exact same output data.
remove_last_node = False
else:
# In case the output data has runtime size it is necessary to remove
# it in order to avoid dynamic memory allocation inside a parallel
# map scope. Otherwise, the memory allocation will for sure lead
# to performance degradation, and eventually illegal memory issues
# when the gpu runs out of local memory.
remove_last_node = True
else:
remove_last_node = False

if remove_last_node:
last_node = write_edge.src
last_node_connector = write_edge.src_conn
self.state.remove_node(self.result.dc_node)
else:
last_node = self.result.dc_node
Expand Down
7 changes: 5 additions & 2 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def promote(
True
>>> promoted: ts.FieldType = promote(
... ts.FieldType(dims=[J, I], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype)
... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype)
... )
>>> promoted.dims == [I, J, K] and promoted.dtype == dtype
True
Expand Down Expand Up @@ -641,7 +641,10 @@ def return_type_field(
new_dims.append(d)
else:
new_dims.extend(target_dims)
return ts.FieldType(dims=new_dims, dtype=field_type.dtype)
return ts.FieldType(
dims=sorted(new_dims, key=lambda dim: (common.dims_kind_order[dim.kind], dim.value)),
dtype=field_type.dtype,
)


UNDEFINED_ARG = types.new_class("UNDEFINED_ARG")
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def __str__(self) -> str:
dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]"
return f"Field[{dims}, {self.dtype}]"

@eve_datamodels.validator("dims")
def _dims_validator(
    self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension]
) -> None:
    """Validate the ``dims`` attribute by delegating to ``common.check_dims``."""
    # ``attribute`` is part of the validator signature but is not used here.
    common.check_dims(dims)


class TupleType(DataType):
# TODO(tehrengruber): Remove `DeferredType` again. This was erroneously
Expand Down
6 changes: 6 additions & 0 deletions tests/cartesian_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _get_backends_with_storage_info(storage_info_kind: str):
_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")]
PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES]

# Backend params whose name starts with "dace:", and every other backend param.
DACE_BACKENDS = [
    _backend_name_as_param(backend_name)
    for backend_name in _ALL_BACKEND_NAMES
    if backend_name.startswith("dace:")
]
NON_DACE_BACKENDS = [param for param in ALL_BACKENDS if param not in DACE_BACKENDS]


@pytest.fixture()
def id_version():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
)
from gt4py.storage.cartesian import utils as storage_utils

from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library
from cartesian_tests.definitions import (
ALL_BACKENDS,
CPU_BACKENDS,
DACE_BACKENDS,
NON_DACE_BACKENDS,
get_array_library,
)
from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import (
EXTERNALS_REGISTRY as externals_registry,
REGISTRY as stencil_definitions,
Expand Down Expand Up @@ -762,3 +768,89 @@ def test(
out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64)
test(in_arr, out_arr)
assert (out_arr[:, :, :] == 388.0).all()


@pytest.mark.parametrize("backend", NON_DACE_BACKENDS)
def test_cast_in_index(backend):
    """Build a copy stencil whose k offset mixes ``np.int32`` and ``np.int64`` scalars.

    The test only applies the ``gtscript.stencil`` decorator; the resulting
    stencil is never called.
    """

    @gtscript.stencil(backend)
    def cast_in_index(
        in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
    ):
        """Simple copy stencil with forced cast in index calculation."""
        with computation(PARALLEL), interval(...):
            out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", DACE_BACKENDS)
@pytest.mark.xfail(raises=ValueError)
def test_dace_no_cast_in_index(backend):
    """Build the mixed ``np.int32``/``np.int64`` k-offset stencil on the DaCe backends.

    Marked as an expected failure raising ``ValueError``.
    NOTE(review): presumably the error comes from the DaCe IR rejecting casts in
    variable k offsets -- confirm against the backend.
    """

    @gtscript.stencil(backend)
    def cast_in_index(
        in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
    ):
        """Simple copy stencil with forced cast in index calculation."""
        with computation(PARALLEL), interval(...):
            out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_read_after_write_stencil(backend):
"""Stencil with multiple read after write access patterns."""

@gtscript.stencil(backend=backend)
def lagrangian_contributions(
q: Field[np.float64],
pe1: Field[np.float64],
pe2: Field[np.float64],
q4_1: Field[np.float64],
q4_2: Field[np.float64],
q4_3: Field[np.float64],
q4_4: Field[np.float64],
dp1: Field[np.float64],
lev: gtscript.Field[gtscript.IJ, np.int64],
):
"""
Args:
q (out):
pe1 (in):
pe2 (in):
q4_1 (in):
q4_2 (in):
q4_3 (in):
q4_4 (in):
dp1 (in):
lev (inout):
"""
with computation(FORWARD), interval(...):
pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev]
if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]:
pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev]
q = (
q4_2[0, 0, lev]
+ 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl)
- q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl)
)
else:
qsum = (pe1[0, 0, lev + 1] - pe2) * (
q4_2[0, 0, lev]
+ 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl)
- q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl))
)
lev = lev + 1
while pe1[0, 0, lev + 1] < pe2[0, 0, 1]:
qsum += dp1[0, 0, lev] * q4_1[0, 0, lev]
lev = lev + 1
dp = pe2[0, 0, 1] - pe1[0, 0, lev]
esl = dp / dp1[0, 0, lev]
qsum += dp * (
q4_2[0, 0, lev]
+ 0.5
* esl
* (
q4_3[0, 0, lev]
- q4_2[0, 0, lev]
+ q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl)
)
)
q = qsum / (pe2[0, 0, 1] - pe2)
lev = lev - 1
18 changes: 18 additions & 0 deletions tests/cartesian_tests/unit_tests/test_gtc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@
# - For testing non-leaf nodes, introduce builders with defaults (for leaf nodes as well)


def test_data_type_methods():
    """Check ``isbool``, ``isinteger`` and ``isfloat`` for every member of ``DataType``."""
    integer_types = (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64)
    float_types = (DataType.FLOAT32, DataType.FLOAT64)

    for dtype in DataType:
        # Each predicate must be truthy exactly for its own group of members.
        assert bool(dtype.isbool()) == (dtype == DataType.BOOL)
        assert bool(dtype.isinteger()) == (dtype in integer_types)
        assert bool(dtype.isfloat()) == (dtype in float_types)


class DummyExpr(Expr):
"""Fake expression for cases where a concrete expression is not needed."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,9 @@ def reduction_ek_field(
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@gtx.field_operator
def reduction_ke_field(
edge_f: common.Field[[KDim, Edge], np.int32],
) -> common.Field[[KDim, Vertex], np.int32]:
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@pytest.mark.uses_unstructured_shift
@pytest.mark.parametrize(
"fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__
"fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__
)
def test_neighbor_sum(unstructured_case_3d, fop):
v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]:

def test_broadcast():
def foo(inp: gtx.Field[[TDim], float64]):
return broadcast(inp, (UDim, TDim))
return broadcast(inp, (TDim, UDim))

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
Expand All @@ -912,7 +912,7 @@ def foo(inp: gtx.Field[[TDim], float64]):

def test_scalar_broadcast():
def foo():
return broadcast(1, (UDim, TDim))
return broadcast(1, (TDim, UDim))

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
Expand Down
Loading

0 comments on commit 6627c8f

Please sign in to comment.