Skip to content

Commit

Permalink
Some fixes and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Feb 18, 2025
1 parent e1c2bef commit 8bea1b5
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 65 deletions.
101 changes: 51 additions & 50 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,80 +323,81 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
):
return ts.DeferredType(constraint=None)

input_dims = common.promote_dims(
*[field.dims for field in fields if isinstance(field, ts.FieldType)]
)
output_dims, seen = [], set()
output_dims = []
for i, field in enumerate(fields):
input_dims = common.promote_dims(
*[field.dims if isinstance(field, ts.FieldType) else []]
)
seen = set()
if isinstance(stencil.node, itir.Expr):
shifts_results = trace_shifts.trace_stencil(
stencil.node, num_args=len(fields)
) # TODO: access node differently?

if isinstance(stencil.node, itir.Expr):
shifts_results = trace_shifts.trace_stencil(
stencil.node, num_args=len(fields)
) # TODO: access node differently?
def resolve_shift(
input_dim: common.Dimension, shift_tuple: tuple[itir.OffsetLiteral, ...]
) -> common.Dimension | None:
"""
Resolves the final dimension by applying shifts from the given shift tuple.
def resolve_shift(
input_dim: common.Dimension, shift_tuple: tuple[itir.OffsetLiteral, ...]
) -> common.Dimension | None:
"""
Resolves the final dimension by applying shifts from the given shift tuple.
Iterates through the shift tuple, updating `input_dim` based on matching offsets.
Iterates through the shift tuple, updating `input_dim` based on matching offsets.
Parameters:
- input_dim (common.Dimension): The initial dimension to resolve.
- shift_tuple (tuple[itir.OffsetLiteral, ...]): A tuple of offset literals defining the shift.
Parameters:
- input_dim (common.Dimension): The initial dimension to resolve.
- shift_tuple (tuple[itir.OffsetLiteral, ...]): A tuple of offset literals defining the shift.
Returns:
- common.Dimension | None: The resolved dimension or `None` if no shift is applied.
"""
if not shift_tuple or not offset_provider_type:
return None

Returns:
- common.Dimension | None: The resolved dimension or `None` if no shift is applied.
"""
if not shift_tuple:
return None
final_target: common.Dimension | None = None

final_target: common.Dimension | None = None
if offset_provider_type:
for off_literal in shift_tuple[::2]:
offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element
if isinstance(offset_type, common.Dimension):
final_target = offset_type
elif isinstance(
if (
isinstance(offset_type, common.Dimension) and input_dim == offset_type
): # no shift applied
return offset_type
if isinstance(
offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)
):
off_source = (
offset_type.source
if isinstance(offset_type, fbuiltins.FieldOffset)
else offset_type.codomain
else (offset_type.codomain)
)
off_targets = (
offset_type.target
if isinstance(offset_type, fbuiltins.FieldOffset)
else offset_type.domain
else (offset_type.domain)
)
if input_dim == off_source: # check if input fits to offset

if input_dim == off_source: # Check if input fits to offset
for target in off_targets:
if (
target.value != off_literal.value
): # off_targets also contains a dimension with value off_literal.value which is excluded here
input_dim = final_target = (
target # setting new input_dim for next iteration
)

return final_target

if any(shifts_results):
for input_dim in input_dims:
for shifts_list in shifts_results:
for shift_tuple in shifts_list:
if (
final_dim := resolve_shift(input_dim, shift_tuple)
) and final_dim not in seen:
): # Exclude target matching off_literal.value
final_target = target
input_dim = target # Update input_dim for next iteration
return final_target

if any(shifts_results[i]):
for input_dim in input_dims:
for shift_tuple in shifts_results[
i
]: # Use shift tuple corresponding to the input field
final_dim = (
resolve_shift(input_dim, shift_tuple) or input_dim
) # If ther are no shifts, take input_dim
if final_dim not in seen:
seen.add(final_dim)
output_dims.append(final_dim)
elif input_dim not in seen:
seen.add(input_dim)
output_dims.append(input_dim)
else:
output_dims.extend(input_dims)
else:
output_dims.extend(input_dims)
else:
output_dims = domain.dims
output_dims = domain.dims

stencil_return = stencil(
*(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields),
Expand Down
55 changes: 40 additions & 15 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
V2E,
E2VDim,
Edge,
Cell,
IDim,
Ioff,
JDim,
Expand Down Expand Up @@ -518,39 +519,63 @@ def test_as_fieldop_without_domain_different_three_datatypes():
stencil = im.lambda_("it1", "it2", "it3")(
im.plus(
im.plus(
im.deref(im.shift("E2V", 1)(im.shift("V2E", 1)("it1"))),
im.deref(im.shift("C2E", 1)(im.shift("E2V", 1)("it1"))),
im.deref(im.shift("KOff", 1)("it2")),
),
im.deref(im.shift("IOff", 1)("it3")),
)
)

testee = im.as_fieldop(stencil)(
im.ref("inp1", float_edge_field),
im.ref("inp2", float_vertex_k_field),
im.ref("inp1", float_vertex_field),
im.ref("inp2", float_edge_k_field),
im.ref("inp3", float_i_field),
)
result = itir_type_inference.infer(
testee,
offset_provider_type={"V2E": V2E, "E2V": E2V, "KOff": KDim, "IOff": IDim},
offset_provider_type={"C2E": C2E, "E2V": E2V, "KOff": KDim, "IOff": IDim},
allow_undeclared_symbols=True,
)
assert result.type == ts.FieldType(dims=[Edge, IDim, Vertex, KDim], dtype=float64_type)
assert result.type == ts.FieldType(dims=[Cell, Edge, IDim, KDim], dtype=float64_type)
assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_edge_field.dims,
element_type=float_edge_field.dtype,
defined_dims=float_vertex_field.dims,
element_type=float_vertex_field.dtype,
)
assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_vertex_k_field.dims,
element_type=float_vertex_k_field.dtype,
defined_dims=float_edge_k_field.dims,
element_type=float_edge_k_field.dtype,
)
assert result.fun.args[0].type.pos_only_args[2] == it_ts.IteratorType(
position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype
)


def test_as_fieldop_without_domain_new():
stencil = im.lambda_("it1", "it2")(
im.plus(im.deref("it1"), im.deref(im.shift("KOff", 1)(im.shift("V2E", 0)("it2"))))
)

testee = im.as_fieldop(stencil)(
im.ref("inp2", float_vertex_field), im.ref("inp1", float_edge_k_field)
)
result = itir_type_inference.infer(
testee, offset_provider_type={"V2E": V2E, "KOff": KDim}, allow_undeclared_symbols=True
)
assert result.type == ts.FieldType(dims=[Vertex, KDim], dtype=float64_type)
assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_vertex_field.dims,
element_type=float_vertex_field.dtype,
)
assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_edge_k_field.dims,
element_type=float_edge_k_field.dtype,
)


def test_as_fieldop_without_domain_V2E():
stencil = im.lambda_("it")(im.deref(im.shift("V2E", 0)("it")))

Expand Down Expand Up @@ -591,17 +616,17 @@ def test_as_fieldop_without_domain_only_one_shift():


def test_as_fieldop_without_domain_new_nested_shifts():
stencil = im.lambda_("it")(im.deref(im.shift("E2V", 0)(im.shift("V2E", 0)("it"))))
stencil = im.lambda_("it")(im.deref(im.shift("C2E", 0)(im.shift("E2V", 0)("it"))))

testee = im.as_fieldop(stencil)(im.ref("inp", float_edge_field))
testee = im.as_fieldop(stencil)(im.ref("inp", float_vertex_field))
result = itir_type_inference.infer(
testee, offset_provider_type={"E2V": E2V, "V2E": V2E}, allow_undeclared_symbols=True
testee, offset_provider_type={"C2E": C2E, "E2V": E2V}, allow_undeclared_symbols=True
)
assert result.type == ts.FieldType(dims=[Edge], dtype=float64_type)
assert result.type == ts.FieldType(dims=[Cell], dtype=float64_type)
assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType(
position_dims="unknown",
defined_dims=float_edge_field.dims,
element_type=float_edge_field.dtype,
defined_dims=float_vertex_field.dims,
element_type=float_vertex_field.dtype,
)


Expand Down

0 comments on commit 8bea1b5

Please sign in to comment.