Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: embedded with itir.Program #1530

Merged
merged 73 commits into from
Jun 28, 2024

Conversation

havogt
Copy link
Contributor

@havogt havogt commented Apr 16, 2024

Updates itir.embedded to work with itir.Programs, i.e. set_at and as_fieldop.

For programs to be able to run in embedded, the domain needs to be provided as second argument to as_fieldop.

Introduces a DimensionKind to itir.AxisLiteral to be able to reconstruct the kind from the IR. This is needed now as the set_at assigns from field to field, which requires matching dimensions. However, previously the python program generated from IR would always construct horizontal dimensions (but the information would not be used).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a question, not related to this PR but in general to the new IR and specifically to this type:

class SetAt(Stmt):  # from JAX array.at[...].set()
    expr: Expr  # only `as_fieldop(stencil)(inp0, ...)` in first refactoring
    domain: Expr
    target: Expr  # `make_tuple` or SymRef

Do we really need to support make_tuple as target expression?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the current representation for something like

@fundef
def prog(a,b,c)
    setat(as_field_op(lambda x: make_tuple(deref(x)+1, deref(x)+2))(c), domain, make_tuple(a,b))
    

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the fieldview backend support that representation as is, I mean with inlined make_tuple? or just this one?

@fundef
def prog(a,b,c)
    setat(as_field_op(lambda x, y: make_tuple(x, y))(as_field_op(lambda x: x+1)(c), as_field_op(lambda x: x+2)(c), ), domain, make_tuple(a,b))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am asking because from a lowering perspective, make_tuple and tuple_get do not implement any kind of computation on fields. Therefore, it is difficult to represent them in my map-tasklet graph. Would it be too strange to treat these builtins on the same level as as_field_op?

Copy link
Contributor

@edopao edopao Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fundef
def prog(a,b,c)
    setat(make_tuple(as_field_op(lambda x: x+1)(c), as_field_op(lambda x: x+2)(c)), domain, make_tuple(a,b))

Copy link
Contributor

@tehrengruber tehrengruber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round.

) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ...


def tree_map(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is time to remove apply_to_primitive_constituents and absorb it in this function. I'll check if @SF-N has capacity to work on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sara will extract this into a new PR and also remove the apply_to_primitive_constituents cases there. I'll keep you in the loop.

@havogt havogt marked this pull request as ready for review May 16, 2024 11:27
@havogt
Copy link
Contributor Author

havogt commented May 16, 2024

cscs-ci run

@havogt havogt requested a review from tehrengruber May 16, 2024 13:09

return impl


@runtime.closure.register(EMBEDDED)
def closure(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation of ins appears to be wrong as there is the promotion of scalars below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

for inp in promoted_ins
)
res = sten(*ins_iters)
res = _compute_point(sten, ins, pos, column_range)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
res = _compute_point(sten, ins, pos, column_range)
res = _compute_at_point(sten, ins, pos, column_range)

Just a suggestion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided to go with _compute_at_position

@havogt
Copy link
Contributor Author

havogt commented Jun 27, 2024

cscs-ci run

@havogt havogt merged commit 3ca278d into GridTools:main Jun 28, 2024
31 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants