diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index db12ba13ed..6d39fc376d 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -93,19 +93,13 @@ def _transform_by_pattern( domain_expr = domain.as_expr() assert isinstance(tmp_expr.type, ts.TypeSpec) - tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( - lambda x: uids.sequential_id(), - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) # TODO: how should tuple_constructorb e handled? - - tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( - type_info.apply_to_primitive_constituents( - type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) - ) # TODO: how should tuple_constructorb e handled? + tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map( + result_collection_constructor=lambda elements: tuple(elements) + )(lambda x: uids.sequential_id())(tmp_expr.type) + + tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map( + result_collection_constructor=lambda elements: tuple(elements) + )(type_info.extract_dtype)(tmp_expr.type) # allocate temporary for all tuple elements def allocate_temporary(tmp_name: str, dtype: ts.ScalarType):