Skip to content

Commit

Permalink
clean up type inference of as_fieldop
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Feb 12, 2025
1 parent 98e424d commit 392153a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
44 changes: 28 additions & 16 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,29 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
stencil.node, num_args=len(fields)
) # TODO: access node differently?

def resolve_shift(input_dim, shift_tuple):
def resolve_shift(
input_dim: common.Dimension, shift_tuple: tuple[itir.OffsetLiteral, ...]
) -> common.Dimension | None:
"""
Resolves the final dimension by applying shifts from the given shift tuple.
Iterates through the shift tuple, updating `input_dim` based on matching offsets.
Parameters:
- input_dim (common.Dimension): The initial dimension to resolve.
- shift_tuple (tuple[itir.OffsetLiteral, ...]): A tuple of offset literals defining the shift.
Returns:
- common.Dimension | None: The resolved dimension or `None` if no shift is applied.
"""
if not shift_tuple:
return None

final_target = None
final_target: common.Dimension | None = None
for off_literal in shift_tuple[::2]:
off = off_literal.value
offset_type = offset_provider_type[off]

offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element
if isinstance(offset_type, common.Dimension):
if input_dim != offset_type:
pass
final_target = offset_type

elif isinstance(
offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)
):
Expand All @@ -359,28 +368,31 @@ def resolve_shift(input_dim, shift_tuple):
if isinstance(offset_type, fbuiltins.FieldOffset)
else offset_type.domain
)

if input_dim == off_source:
if input_dim == off_source: # check if input fits to offset
for target in off_targets:
if target.value != off:
input_dim = final_target = target
if (
target.value != off_literal.value
): # off_targets also contains a dimension with value off_literal.value which is excluded here
input_dim = final_target = (
target # setting new input_dim for next iteration
)

return final_target

if any(shift_tuple for shift_list in shifts_results for shift_tuple in shift_list):
if any(shifts_results):
for input_dim in input_dims:
for shifts_list in shifts_results:
for shift_tuple in shifts_list:
final_dim = resolve_shift(input_dim, shift_tuple)
if final_dim and final_dim not in seen:
if (
final_dim := resolve_shift(input_dim, shift_tuple)
) and final_dim not in seen:
seen.add(final_dim)
output_dims.append(final_dim)
elif input_dim not in seen:
seen.add(input_dim)
output_dims.append(input_dim)
else:
output_dims.extend(input_dims)

else:
output_dims = domain.dims

Expand Down
24 changes: 24 additions & 0 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,30 @@ def test_as_fieldop_without_domain_V2E():
)


def test_as_fieldop_without_domain_only_one_shift():
stencil = im.lambda_("it1", "it2")(
im.plus(im.deref(im.shift("V2E", 1)("it1")), im.deref("it2"))
)

testee = im.as_fieldop(stencil)(
im.ref("inp1", float_edge_field), im.ref("inp2", float_edge_field)
)
result = itir_type_inference.infer(
testee, offset_provider_type={"V2E": V2E, "Edge": Edge}, allow_undeclared_symbols=True
)
assert result.type == ts.FieldType(dims=[Edge, Vertex], dtype=float64_type)
assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_edge_field.dims,
element_type=float_edge_field.dtype,
)
assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_edge_field.dims,
element_type=float_edge_field.dtype,
)


def test_as_fieldop_without_domain_new_nested_shifts():
stencil = im.lambda_("it")(im.deref(im.shift("E2V", 0)(im.shift("V2E", 0)("it"))))

Expand Down

0 comments on commit 392153a

Please sign in to comment.