Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tree_map and replace apply_to_primitive_constituents #1570

Open
wants to merge 76 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
8add029
refactor[next]: itir embedded: cleaner closure run
havogt Apr 4, 2024
853d3e1
cleanup
havogt Apr 4, 2024
f661cd3
fix test
havogt Apr 4, 2024
09e568d
without temporaries
havogt Apr 8, 2024
12b8696
temporaries
havogt Apr 8, 2024
540a2d8
cleanup
havogt Apr 9, 2024
23ddef1
move to SetAt
havogt Apr 10, 2024
e64b986
Merge branch 'refactor_itir_embedded' into itir_program_embedded2
havogt Apr 10, 2024
c99f44d
embedded
havogt Apr 10, 2024
1a6f885
roundtrip+double_roundtrip with shortcuts
havogt Apr 11, 2024
39d6d7c
changes
havogt Apr 11, 2024
ab44009
fencil2program only for gtfn
havogt Apr 11, 2024
12f1663
fix import
havogt Apr 11, 2024
aa80949
Merge remote-tracking branch 'upstream/main' into itir_program
havogt Apr 11, 2024
5037493
fix builtins list
havogt Apr 11, 2024
751581e
add comment
havogt Apr 11, 2024
3d2f33e
fix type checker
havogt Apr 11, 2024
53bad75
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 11, 2024
4cbce7e
Apply suggestions from code review
havogt Apr 12, 2024
c955645
format
havogt Apr 12, 2024
6effe10
pretty printing/parsing
havogt Apr 12, 2024
66de3ec
Apply suggestions from code review
havogt Apr 15, 2024
e63da77
address more review comments
havogt Apr 15, 2024
45fba85
move tmp to pretty_printer
havogt Apr 15, 2024
1a70218
pparse for temporaries
havogt Apr 15, 2024
c39c603
rename gtfn.FencilDefinition -> Program
havogt Apr 15, 2024
705cfcf
remove TODO
havogt Apr 15, 2024
c5e78c4
Apply suggestions from code review
havogt Apr 15, 2024
a336bf5
rename as_field_operator -> as_fieldop
havogt Apr 15, 2024
af16f40
missed a file
havogt Apr 15, 2024
2d6bfbf
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 15, 2024
df1146b
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 15, 2024
bc2c2d3
add fencil2program to roundtrip
havogt Apr 16, 2024
c7ccd6a
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 16, 2024
f45b460
pre-allocate result buffer
havogt Apr 17, 2024
5d4fc3d
fix tracer context
havogt Apr 17, 2024
1d93192
first (almost) complete embedded version
havogt Apr 17, 2024
5882c28
add dim kind to print/parse
havogt Apr 17, 2024
b7cbf16
fix tests
havogt Apr 17, 2024
e97ca25
cleanup test_program
havogt Apr 17, 2024
8c2bd8f
re-enable lift mode in roundtrip
havogt Apr 18, 2024
21b230b
replace lift_mode fixture by backend in program_processor
havogt Apr 18, 2024
0946783
fix doctests
havogt Apr 18, 2024
f93da09
fix tests
havogt Apr 18, 2024
97663bd
undo quickstart changes
havogt Apr 18, 2024
3297f7b
undo delete cpp_backend_tests
havogt Apr 18, 2024
f37b372
fix quickstart guide again
havogt Apr 18, 2024
e242ab6
remove runtime lift
havogt Apr 19, 2024
a07d8ea
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 19, 2024
3c2b9a5
Merge remote-tracking branch 'upstream/main' into test_lift_mode_to_p…
havogt Apr 19, 2024
a9f1043
Update docs/user/next/QuickstartGuide.md
havogt Apr 19, 2024
e7195a5
cleanup out field construction
havogt Apr 19, 2024
3f67746
Update src/gt4py/next/program_processors/runners/double_roundtrip.py
havogt Apr 22, 2024
369eae7
read config.DEBUG at execution
havogt Apr 22, 2024
35f2132
remove LiftMode.SIMPLE_HEURISTIC
havogt Apr 22, 2024
4a4f9b1
fix formatting
havogt Apr 23, 2024
2825588
Merge branch 'test_lift_mode_to_processor' into itir_program_embedded2
havogt Apr 23, 2024
65abde7
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 23, 2024
7373eb6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 6, 2024
bfcc118
move ordering of unstructured domain to gtfn
havogt May 15, 2024
54f44cc
fix problem in column dtype if contains None
havogt May 16, 2024
b8b26e6
address more review comments
havogt May 16, 2024
b7b489e
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 16, 2024
9ee02e4
fix tuples in columns
havogt May 16, 2024
d104633
fix preserve axis kind in global tmps
havogt May 17, 2024
866c3d6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 17, 2024
66e2464
fix follow up issue
havogt May 17, 2024
56e0086
Start using tree_map instead of apply_to_primitive_constituents
SF-N Jun 10, 2024
8ab97c4
Add functionality to call also tree_map(lambda x: x + 1, ((1, 2), 3))…
SF-N Jul 3, 2024
7fbd083
Merge main
SF-N Dec 31, 2024
827a4d3
Run pre-commit
SF-N Dec 31, 2024
0a5ac37
Minor
SF-N Dec 31, 2024
70283e6
Replace more apply_to_primitive_constituents by tree_map
SF-N Dec 31, 2024
14619d2
Minor fix
SF-N Dec 31, 2024
fac65de
Revert replacing when tuple_constructor is present
SF-N Dec 31, 2024
de34ef6
Try to use result_collection_constructor
SF-N Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,12 +848,12 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.",
)

return_type = type_info.apply_to_primitive_constituents(
return_type = type_info.type_tree_map(
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
value.type,
)
)
)(value.type)

assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

return foast.Call(
Expand Down
9 changes: 5 additions & 4 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec:
return ts.FieldType(dims=[], dtype=type_el)
return type_el

return type_info.apply_to_primitive_constituents(promote_el, type_)
return type_info.type_tree_map(promote_el)(type_)


def promote_zero_dims(
Expand Down Expand Up @@ -306,6 +306,7 @@ def return_type_scanop(
# field
[callable_type.axis],
)
return type_info.apply_to_primitive_constituents(
lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)), carry_dtype
)

return type_info.type_tree_map(
lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
)(carry_dtype)
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def fuse_as_fieldop(
# just a safety check if typing information is available
if arg.type and not isinstance(arg.type, ts.DeferredType):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
dtype = type_info.type_tree_map(type_info.extract_dtype)(arg.type)
assert not isinstance(dtype, it_ts.ListType)
new_param: str
if isinstance(
Expand Down Expand Up @@ -233,7 +233,7 @@ def visit_FunCall(self, node: itir.FunCall):
eligible_args = []
for arg, arg_shifts in zip(args, shifts, strict=True):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
dtype = type_info.type_tree_map(type_info.extract_dtype)(arg.type)
# TODO(tehrengruber): make this configurable
eligible_args.append(
_is_tuple_expr_of_literals(arg)
Expand Down
19 changes: 7 additions & 12 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,13 @@ def _transform_by_pattern(
domain_expr = domain.as_expr()

assert isinstance(tmp_expr.type, ts.TypeSpec)
tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents(
lambda x: uids.sequential_id(),
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = (
type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
)
tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map(
result_collection_constructor=lambda elements: tuple(elements)
)(lambda x: uids.sequential_id())(tmp_expr.type)

tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map(
result_collection_constructor=lambda elements: tuple(elements)
)(type_info.extract_dtype)(tmp_expr.type)

# allocate temporary for all tuple elements
def allocate_temporary(tmp_name: str, dtype: ts.ScalarType):
Expand Down
5 changes: 2 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,8 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup
assert isinstance(domain, it_ts.DomainType)
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above
node.dtype,
return type_info.type_tree_map(lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype))(
node.dtype
)

def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None:
Expand Down
9 changes: 4 additions & 5 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _convert_as_fieldop_input_to_iterator(
input_dims = []

element_type: ts.DataType
element_type = type_info.apply_to_primitive_constituents(type_info.extract_dtype, input_)
element_type = type_info.type_tree_map(type_info.extract_dtype)(input_)

# handle neighbor / sparse input fields
defined_dims = []
Expand Down Expand Up @@ -311,12 +311,11 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
offset_provider_type=offset_provider_type,
)
assert isinstance(stencil_return, ts.DataType)
return type_info.apply_to_primitive_constituents(
return type_info.type_tree_map(
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type)
if domain.dims != "unknown"
else ts.DeferredType(constraint=ts.FieldType),
stencil_return,
)
else ts.DeferredType(constraint=ts.FieldType)
)(stencil_return)

return applied_as_fieldop

Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np

from gt4py.eve.utils import XIterable, xiter
from gt4py.next import common
from gt4py.next import common, utils
from gt4py.next.type_system import type_specifications as ts


Expand Down Expand Up @@ -197,6 +197,11 @@ def apply_to_primitive_constituents(
return fun(*symbol_types)


type_tree_map = utils.tree_map(
collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x])
)


def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType:
"""
Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`.
Expand Down
60 changes: 40 additions & 20 deletions src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,36 @@ def tree_map(
@overload
def tree_map(
*,
collection_type: type | tuple[type, ...] = tuple,
result_collection_constructor: Optional[type | Callable] = None,
) -> Callable[
[Callable[_P, _R]], Callable[..., Any]
]: ... # TODO(havogt): if result_collection_constructor is Callable, improve typing
collection_type: type | tuple[type, ...],
result_collection_constructor: Optional[type] = None,
) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ...


def tree_map(
fun: Optional[Callable[_P, _R]] = None,
*,
*args: Callable[_P, _R],
collection_type: type | tuple[type, ...] = tuple,
result_collection_constructor: Optional[type | Callable] = None,
) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]:
) -> (
Callable[..., _R | tuple[_R | tuple, ...]]
| Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]
| _R
| tuple[_R | tuple, ...]
):
"""
Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s).
Apply `args[0]` to each entry of (possibly nested) collections (by default `tuple`s).

Args:
fun: Function to apply to each entry of the collection.
args[0]: Function to apply to each entry of the collection.
collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types.
result_collection_constructor: Type of the collection to be returned. If `None` the same type as `collection_type` is used.
result_collection_constructor: Constructor of the collection to be returned. If `None` the same type as `collection_type` is used.

Examples:
>>> tree_map(lambda x: x + 1)(((1, 2), 3))
((2, 3), 4)

>>> tree_map(lambda x: x + 1, ((1, 2), 3))
((2, 3), 4)

>>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6))
((5, 7), 9)

Expand All @@ -129,7 +134,24 @@ def tree_map(
)
result_collection_constructor = collection_type

if fun:
if len(args) == 0:
return functools.partial(
tree_map,
collection_type=collection_type,
result_collection_constructor=result_collection_constructor,
)

if callable(args[0]):
fun = args[0]
colls = args[1:]

if len(colls) == 0:
return functools.partial(
tree_map,
fun,
collection_type=collection_type,
result_collection_constructor=result_collection_constructor,
)

@functools.wraps(fun)
def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
Expand All @@ -140,14 +162,12 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
assert result_collection_constructor is not None
return result_collection_constructor(impl(*arg) for arg in zip(*args))

return fun( # type: ignore[call-arg, misc] # mypy not smart enough
return fun( # type: ignore[call-arg] # mypy not smart enough
*cast(_P.args, args)
) # mypy doesn't understand that `args` at this point is of type `_P.args`

return impl
else:
return functools.partial(
tree_map,
collection_type=collection_type,
result_collection_constructor=result_collection_constructor,
)
return impl(*colls)

raise TypeError(
"tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`, or with a function and collection."
)
33 changes: 32 additions & 1 deletion tests/next_tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,45 @@
import pytest

from gt4py.next import utils
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info
from gt4py.next.common import Field
from numpy import int64


def test_tree_map_scalar():
@utils.tree_map(collection_type=ts.ScalarType, result_collection_constructor=tuple)
def testee(x):
return x + 1

assert testee(1) == (2)


def test_apply_to_primitive_constituents():
int_type = ts.ScalarType(kind=ts.ScalarKind.INT64)
tuple_type = ts.TupleType(types=[ts.TupleType(types=[int_type, int_type]), int_type])

tree = type_info.type_tree_map(
lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type)
)(tuple_type)

prim = type_info.apply_to_primitive_constituents(
lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), tuple_type
)

assert tree == prim


def test_tree_map_default():
expected_result = ((2, 3), 4)

@utils.tree_map
def testee(x):
return x + 1

assert testee(((1, 2), 3)) == ((2, 3), 4)
assert testee(((1, 2), 3)) == expected_result
assert utils.tree_map(lambda x: x + 1)(((1, 2), 3)) == expected_result
assert utils.tree_map(lambda x: x + 1, ((1, 2), 3)) == expected_result


def test_tree_map_multi_arg():
Expand Down
Loading