-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Infer as_fieldop type without domain #1853
base: main
Are you sure you want to change the base?
Conversation
…ot yet working for tuples
@@ -492,10 +492,7 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: | |||
# probably just change the behaviour of the lowering. Until then we do this more | |||
# complicated comparison. | |||
if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): | |||
assert ( | |||
set(expr_type.dims).issubset(set(target_type.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the subset requirement is not necessary anymore, e.g. in the case
itir.SetAt(
domain=unstructured_domain,
expr=im.as_fieldop(
im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))),
unstructured_domain,
)(im.ref("inp")),
target=im.ref("out"),
)
with
im.sym("inp", float_vertex_v2e_k_field)
and im.sym("out", float_vertex_k_field)
the expr_type is
[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>), Dimension(value='V2E', kind=<DimensionKind.LOCAL: 'local'>)]
and the target_type is
[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>)]
which is correct in my opinion.
cf. test_fencil_with_nb_field_input
)(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), | ||
ts.TupleType(types=[float_i_field, float_i_field]), | ||
)(im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), | ||
ts.TupleType(types=[float_ij_field, float_ij_field]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is ts.TupleType(types=[float_ij_field, float_ij_field]
correct here?
src/gt4py/next/common.py
Outdated
return ( | ||
sorted(unique_dims, key=lambda dim: (kind_order[dim.kind], dim.value)) | ||
if unique_dims | ||
else [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check this
if any(shifts_results[i]): | ||
for input_dim in input_dims: | ||
for shift_tuple in shifts_results[ | ||
i | ||
]: # Use shift tuple corresponding to the input field | ||
final_dim = ( | ||
resolve_shift(input_dim, shift_tuple) or input_dim | ||
) # If ther are no shifts, take input_dim | ||
if final_dim not in seen: | ||
seen.add(final_dim) | ||
output_dims.append(final_dim) | ||
else: | ||
output_dims.extend(input_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- No special casing for tuple of empty tuple
((),)
- Do not add input dims for empty tuple (the argument is unused in that case). Add test case
return final_target | ||
|
||
if any(shifts_results[i]): | ||
for input_dim in input_dims: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change ordering to:
for shift in ...:
for input_dim in ...:
...
for readability.
position_dims=float_vertex_k_field.dims, | ||
defined_dims=float_vertex_field.dims, | ||
element_type=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is the IteratorType
like this?
Please carefully check this test as I am not 100% sure.
This extends the GTIR type inference to infer the type of as_fieldop calls without a domain.
TODOs: