Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Infer as_fieldop type without domain #1853

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
9b856cd
New global ordering relation of dimensions
SF-N Feb 5, 2025
a23c6cb
Minor text edits
SF-N Feb 5, 2025
545f98e
Merge branch 'main' into decouple_inferences
SF-N Feb 5, 2025
c851f8f
Use new promote_dims and update tests
SF-N Feb 7, 2025
74e4584
Merge branch 'main' into decouple_inferences
SF-N Feb 7, 2025
31a611d
Minor fix
SF-N Feb 7, 2025
428a979
Remove unnecessary test
SF-N Feb 7, 2025
4c2b63b
Fix Doctests
SF-N Feb 7, 2025
279fbb0
Merge branch 'main' into decouple_inferences
SF-N Feb 10, 2025
2f91dcf
Add functionality to infer return type of as_fieldop without domain, n…
SF-N Feb 11, 2025
fbfce3e
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 11, 2025
c0849a4
Merge branch 'main' into decouple_inferences
SF-N Feb 11, 2025
98e424d
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 12, 2025
392153a
clean up type inference of as_fieldop
SF-N Feb 12, 2025
360c10f
Working on tuples
SF-N Feb 13, 2025
da549c0
Some more fixes
SF-N Feb 13, 2025
462b672
Run pre-commit
SF-N Feb 13, 2025
9050385
Merge branch 'main' into decouple_inferences
SF-N Feb 13, 2025
e0be279
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 13, 2025
5b2da27
Fix some tests
SF-N Feb 13, 2025
2a9a25a
Merge branch 'main' into decouple_inferences
SF-N Feb 14, 2025
e1c2bef
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 14, 2025
9e48581
Merge branch 'main' into decouple_inferences
SF-N Feb 17, 2025
8bea1b5
Some fixes and add more tests
SF-N Feb 18, 2025
249ba8b
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 18, 2025
2820315
Merge branch 'main' into decouple_inferences
SF-N Feb 24, 2025
c6a841a
Address review comments
SF-N Feb 25, 2025
bd66acb
Fix tests
SF-N Feb 25, 2025
e5ed262
Fix test
SF-N Feb 25, 2025
6627c8f
Merge decouple_inferences
SF-N Feb 25, 2025
392a91d
Fix tests
SF-N Feb 25, 2025
435df54
Address some review comments
SF-N Feb 25, 2025
657bd9d
Fix for no offset_provider_type
SF-N Feb 25, 2025
dd3e6db
Merge branch 'main' into infer_as_fieldop_type_without_domain
SF-N Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 36 additions & 64 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __str__(self) -> str:
return self.value


dims_kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3}


def dimension_to_implicit_offset(dim: str) -> str:
"""
Return name of offset implicitly defined by a dimension.
Expand Down Expand Up @@ -1123,84 +1126,53 @@ class GridType(StrEnum):
UNSTRUCTURED = "unstructured"


def check_dims(dims: Sequence[Dimension]) -> None:
    """
    Validate that a sequence of dimensions is in canonical order.

    The canonical order sorts dimensions first by ``dims_kind_order`` of their
    ``kind`` and then by their ``value``.

    Args:
        dims: The dimensions to validate. Any sequence (e.g. list or tuple) is accepted.

    Raises:
        ValueError: If more than one dimension has kind ``DimensionKind.LOCAL``, or if
            ``dims`` is not in canonical order.
    """
    if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1:
        raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")

    # `sorted` always returns a list. Normalize `dims` to a list before comparing so that
    # a correctly ordered tuple is not rejected merely because a tuple never equals a list.
    if list(dims) != sorted(dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)):
        raise ValueError(f"Dimensions {dims} are not correctly ordered.")


def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]:
"""
Find a unique ordering of multiple (individually ordered) lists of dimensions.

The resulting list of dimensions contains all dimensions of the arguments
in the order they originally appear. If no unique order exists or a
contradicting order is found an exception is raised.
Find an ordering of multiple lists of dimensions.

A modified version (ensuring uniqueness of the order) of
`Kahn's algorithm <https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm>`_
is used to topologically sort the arguments.
The resulting list contains all unique dimensions from the input lists,
sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then
lexicographically by `Dimension.value`.

Examples:
>>> from gt4py.next.common import Dimension
>>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"])
>>> promote_dims([I, J], [I, J, K]) == [I, J, K]
>>> I = Dimension("I", DimensionKind.HORIZONTAL)
>>> J = Dimension("J", DimensionKind.HORIZONTAL)
>>> K = Dimension("K", DimensionKind.VERTICAL)
>>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL)
>>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL)
>>> promote_dims([J, K], [I, K]) == [I, J, K]
True

>>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS
>>> promote_dims([K, J], [I, K])
Traceback (most recent call last):
...
ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K.

>>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS
raise ValueError(f"Dimensions {dims} are not correctly ordered.")
ValueError: Dimensions [Dimension(value='K', kind=<DimensionKind.VERTICAL: 'vertical'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)] are not correctly ordered.
>>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V]
True
>>> promote_dims([I, E2C], [K, E2V])
Traceback (most recent call last):
...
ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J.
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
ValueError: There are more than one dimension with DimensionKind 'LOCAL'.
"""
# build a graph with the vertices being dimensions and edges representing
# the order between two dimensions. The graph is encoded as a dictionary
# mapping dimensions to their predecessors, i.e. a dictionary containing
# adjacency lists. Since graphlib.TopologicalSorter uses predecessors
# (contrary to successors) we also use this directionality here.
graph: dict[Dimension, set[Dimension]] = {}

for dims in dims_list:
if len(dims) == 0:
continue
# create a vertex for each dimension
for dim in dims:
graph.setdefault(dim, set())
# add edges
predecessor = dims[0]
for dim in dims[1:]:
graph[dim].add(predecessor)
predecessor = dim

# modified version of Kahn's algorithm
topologically_sorted_list: list[Dimension] = []

# compute in-degree for each vertex
in_degree = {v: 0 for v in graph.keys()}
for v1 in graph:
for v2 in graph[v1]:
in_degree[v2] += 1

# process vertices with in-degree == 0
# TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list
while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]:
if len(zero_in_degree_vertex_list) != 1:
raise ValueError(
f"Dimensions can not be promoted. Could not determine "
f"order of the following dimensions: "
f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}."
)
v = zero_in_degree_vertex_list[0]
del in_degree[v]
topologically_sorted_list.insert(0, v)
# update in-degree
for predecessor in graph[v]:
in_degree[predecessor] -= 1

if len(in_degree.items()) > 0:
raise ValueError(
f"Dimensions can not be promoted. The following dimensions "
f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}."
)
check_dims(list(dims))
unique_dims = {dim for dims in dims_list for dim in dims}


return topologically_sorted_list
promoted_dims = sorted(unique_dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value))
check_dims(promoted_dims)
return promoted_dims if unique_dims else []


class FieldBuiltinFuncRegistry:
Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,7 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None:
# probably just change the behaviour of the lowering. Until then we do this more
# complicated comparison.
if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType):
assert (
set(expr_type.dims).issubset(set(target_type.dims))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the subset requirement is not necessary anymore, e.g. in the case

itir.SetAt(
        domain=unstructured_domain,
        expr=im.as_fieldop(
            im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))),
            unstructured_domain,
        )(im.ref("inp")),
        target=im.ref("out"),
)

with
im.sym("inp", float_vertex_v2e_k_field) and im.sym("out", float_vertex_k_field)

the expr_type is

[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>), Dimension(value='V2E', kind=<DimensionKind.LOCAL: 'local'>)]

and the target_type is

[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>)]

which is correct in my opinion.

cf. test_fencil_with_nb_field_input

and target_type.dtype == expr_type.dtype
)
assert target_type.dtype == expr_type.dtype

def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType:
return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind))
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ class NamedRangeType(ts.TypeSpec):


class DomainType(ts.DataType):
dims: list[common.Dimension] | Literal["unknown"]
dims: list[common.Dimension] | Literal["unknown"] # TODO: remove unknown


class OffsetLiteralType(ts.TypeSpec):
value: ts.ScalarType | common.Dimension


class IteratorType(ts.DataType, ts.CallableType):
position_dims: list[common.Dimension] | Literal["unknown"]
position_dims: list[common.Dimension] | Literal["unknown"] # TODO: remove unknown: 80%?
defined_dims: list[common.Dimension]
element_type: ts.DataType

Expand Down
89 changes: 85 additions & 4 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union
from gt4py.next import common
from gt4py.next.iterator import builtins
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.type_system import type_specifications as it_ts
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.utils import tree_map
Expand Down Expand Up @@ -302,6 +304,7 @@ def as_fieldop(
# information on the ordering of dimensions. In this example
# `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)`
# it is unclear if the result has dimension I, J or J, I.
# Todo: update comment
if domain is None:
domain = it_ts.DomainType(dims="unknown")

Expand All @@ -314,15 +317,93 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
):
return ts.DeferredType(constraint=None)

output_dims = []
for i, field in enumerate(fields):
input_dims = common.promote_dims(
*[field.dims if isinstance(field, ts.FieldType) else []]
)
seen = set()
if isinstance(stencil.node, itir.Expr):
shifts_results = trace_shifts.trace_stencil(
stencil.node, num_args=len(fields)
) # TODO: access node differently?

def resolve_shift(
input_dim: common.Dimension, shift_tuple: tuple[itir.OffsetLiteral, ...]
) -> common.Dimension | None:
"""
Resolves the final dimension by applying shifts from the given shift tuple.

Iterates through the shift tuple, updating `input_dim` based on matching offsets.

Parameters:
- input_dim (common.Dimension): The initial dimension to resolve.
- shift_tuple (tuple[itir.OffsetLiteral, ...]): A tuple of offset literals defining the shift.

Returns:
- common.Dimension | None: The resolved dimension or `None` if no shift is applied.
"""
if not shift_tuple or not offset_provider_type:
return None

final_target: common.Dimension | None = None

for off_literal in shift_tuple[::2]:
offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element
if (
isinstance(offset_type, common.Dimension) and input_dim == offset_type
): # no shift applied
return offset_type
if isinstance(
offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)
):
off_source = (
offset_type.source
if isinstance(offset_type, fbuiltins.FieldOffset)
else (offset_type.codomain)
)
off_targets = (
offset_type.target
if isinstance(offset_type, fbuiltins.FieldOffset)
else (offset_type.domain)
)

if input_dim == off_source: # Check if input fits to offset
for target in off_targets:
if (
target.value != off_literal.value
): # Exclude target matching off_literal.value
final_target = target
input_dim = target # Update input_dim for next iteration
return final_target

if any(shifts_results[i]):
for input_dim in input_dims:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change ordering to:

for shift in ...:
  for input_dim in ...:
    ...

for readability.

for shift_tuple in shifts_results[
i
]: # Use shift tuple corresponding to the input field
final_dim = (
resolve_shift(input_dim, shift_tuple) or input_dim
) # If ther are no shifts, take input_dim
if final_dim not in seen:
seen.add(final_dim)
output_dims.append(final_dim)
else:
output_dims.extend(input_dims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Do not special-case a tuple that contains only an empty tuple, i.e. ((),).
  • Do not add the input dims for an empty tuple (the argument is unused in that case). Add a test case for this.

else:
output_dims = domain.dims

stencil_return = stencil(
*(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields),
offset_provider_type=offset_provider_type,
)

assert isinstance(stencil_return, ts.DataType)
return type_info.apply_to_primitive_constituents(
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type)
if domain.dims != "unknown"
else ts.DeferredType(constraint=ts.FieldType),
lambda el_type: ts.FieldType(
dims=sorted({dim for dim in output_dims}, key=lambda dim: (common.dims_kind_order[dim.kind], dim.value)) if output_dims != "unknown" else [],
dtype=el_type,
),
stencil_return,
)

Expand Down
14 changes: 8 additions & 6 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,11 @@ def promote(
>>> promoted.dims == [I, J, K] and promoted.dtype == dtype
True

>>> promote(
>>> promoted: ts.FieldType = promote(
... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype)
... ) # doctest: +ELLIPSIS
Traceback (most recent call last):
...
ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K.
... )
>>> promoted.dims == [I, J, K] and promoted.dtype == dtype
True
"""
if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types):
if not all(type_ == types[0] for type_ in types):
Expand Down Expand Up @@ -642,7 +641,10 @@ def return_type_field(
new_dims.append(d)
else:
new_dims.extend(target_dims)
return ts.FieldType(dims=new_dims, dtype=field_type.dtype)
return ts.FieldType(
dims=sorted(new_dims, key=lambda dim: (common.dims_kind_order[dim.kind], dim.value)),
dtype=field_type.dtype,
)


UNDEFINED_ARG = types.new_class("UNDEFINED_ARG")
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def __str__(self) -> str:
dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]"
return f"Field[{dims}, {self.dtype}]"

@eve_datamodels.validator("dims")
def _dims_validator(
self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension]
) -> None:
common.check_dims(dims)


class TupleType(DataType):
# TODO(tehrengruber): Remove `DeferredType` again. This was erroneously
Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
# mypy does not accept [IDim, ...] as a type

IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type]
JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type]
IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type]
IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type]
KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,9 @@ def reduction_ek_field(
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@gtx.field_operator
def reduction_ke_field(
edge_f: common.Field[[KDim, Edge], np.int32],
) -> common.Field[[KDim, Vertex], np.int32]:
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@pytest.mark.uses_unstructured_shift
@pytest.mark.parametrize(
"fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__
"fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__
)
def test_neighbor_sum(unstructured_case_3d, fop):
v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]:

def test_broadcast():
def foo(inp: gtx.Field[[TDim], float64]):
return broadcast(inp, (UDim, TDim))
return broadcast(inp, (TDim, UDim))

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
Expand All @@ -912,7 +912,7 @@ def foo(inp: gtx.Field[[TDim], float64]):

def test_scalar_broadcast():
def foo():
return broadcast(1, (UDim, TDim))
return broadcast(1, (TDim, UDim))

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
Expand Down
12 changes: 5 additions & 7 deletions tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,11 @@ def test_binop_nonmatching_dims():
def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]):
return a + b

with pytest.raises(
errors.DSLError,
match=(
r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +."
),
):
_ = FieldOperatorParser.apply_to_function(nonmatching)
parsed = FieldOperatorParser.apply_to_function(nonmatching)

assert parsed.body.stmts[0].value.type == ts.FieldType(
dims=[X, Y], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
)


def test_bitopping_float():
Expand Down
Loading
Loading