Skip to content

Commit

Permalink
Add test for neighbor / sparse input field
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Apr 18, 2024
1 parent 520c805 commit 129f6b3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
12 changes: 9 additions & 3 deletions src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class DomainType(ts.DataType):
dims: list[common.Dimension]


# TODO: how about ts.OffsetType?
@dataclasses.dataclass(frozen=True)
class OffsetLiteralType(ts.TypeSpec):
value: IntegralScalar | common.Dimension
Expand All @@ -42,7 +41,7 @@ class ListType(ts.DataType):


@dataclasses.dataclass(frozen=True)
class IteratorType(ts.DataType, ts.CallableType): # todo: rename to iterator
class IteratorType(ts.DataType, ts.CallableType):
position_dims: list[common.Dimension] | typing.Literal["unknown"]
defined_dims: list[common.Dimension]
element_type: ts.DataType
Expand All @@ -52,9 +51,16 @@ class IteratorType(ts.DataType, ts.CallableType): # todo: rename to iterator
class StencilClosureType(ts.TypeSpec):
domain: DomainType
stencil: ts.FunctionType
output: ts.FieldType | ts.TupleType # todo: validate tuple of fields
output: ts.FieldType | ts.TupleType
inputs: list[ts.FieldType]

def __post_init__(self):
# local import to avoid importing type_info from a type_specification module
from gt4py.next.type_system import type_info

for el_type in type_info.primitive_constituents(self.output):
assert isinstance(el_type, ts.FieldType), "All constituent types must be field types."


@dataclasses.dataclass(frozen=True)
class FencilType(ts.TypeSpec):
Expand Down
58 changes: 43 additions & 15 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@
float64_list_type = it_ts.ListType(element_type=float64_type)
int_list_type = it_ts.ListType(element_type=int_type)

bool_i_field = ts.FieldType(dims=[IDim], dtype=bool_type)
bool_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=bool_type)
bool_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=bool_type)
float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type)
float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type)
float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type)
float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type)

it_on_v_of_e_type = it_ts.IteratorType(
position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=int_type
Expand Down Expand Up @@ -193,7 +194,7 @@ def test_cartesian_fencil_definition():
testee = itir.FencilDefinition(
id="f",
function_definitions=[],
params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)],
params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)],
closures=[
itir.StencilClosure(
domain=cartesian_domain,
Expand All @@ -218,10 +219,10 @@ def test_cartesian_fencil_definition():
kw_only_args={},
returns=bool_type,
),
output=bool_i_field,
inputs=[bool_i_field],
output=float_i_field,
inputs=[float_i_field],
)
fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type])
fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type])
assert result.type == fencil_type
assert result.closures[0].type == closure_type

Expand All @@ -236,7 +237,7 @@ def test_unstructured_fencil_definition():
testee = itir.FencilDefinition(
id="f",
function_definitions=[],
params=[im.sym("inp", bool_edge_k_field), im.sym("out", bool_vertex_k_field)],
params=[im.sym("inp", float_edge_k_field), im.sym("out", float_vertex_k_field)],
closures=[
itir.StencilClosure(
domain=unstructured_domain,
Expand All @@ -261,11 +262,11 @@ def test_unstructured_fencil_definition():
kw_only_args={},
returns=bool_type,
),
output=bool_vertex_k_field,
inputs=[bool_edge_k_field],
output=float_vertex_k_field,
inputs=[float_edge_k_field],
)
fencil_type = it_ts.FencilType(
params=[bool_edge_k_field, bool_vertex_k_field], closures=[closure_type]
params=[float_edge_k_field, float_vertex_k_field], closures=[closure_type]
)
assert result.type == fencil_type
assert result.closures[0].type == closure_type
Expand All @@ -282,7 +283,7 @@ def test_function_definition():
itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")),
itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")),
],
params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)],
params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)],
closures=[
itir.StencilClosure(
domain=cartesian_domain,
Expand All @@ -307,9 +308,36 @@ def test_function_definition():
kw_only_args={},
returns=bool_type,
),
output=bool_i_field,
inputs=[bool_i_field],
output=float_i_field,
inputs=[float_i_field],
)
fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type])
fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type])
assert result.type == fencil_type
assert result.closures[0].type == closure_type


def test_fencil_with_nb_field_input():
mesh = simple_mesh()
unstructured_domain = im.call("unstructured_domain")(
im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1),
im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1),
)

testee = itir.FencilDefinition(
id="f",
function_definitions=[],
params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)],
closures=[
itir.StencilClosure(
domain=unstructured_domain,
stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))),
output=im.ref("out"),
inputs=[im.ref("inp")],
),
],
)

result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider)

assert result.closures[0].stencil.expr.args[0].type == float64_list_type
assert result.closures[0].stencil.type.returns == float64_type

0 comments on commit 129f6b3

Please sign in to comment.