From 5b2da27166520a0ee502eabeae68640ead279b4a Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 13 Feb 2025 12:53:42 +0100 Subject: [PATCH] Fix some tests --- src/gt4py/next/common.py | 6 +- .../iterator/type_system/type_synthesizer.py | 56 ++++++++++--------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b13741233a..3f5ab56cad 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -1142,7 +1142,11 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} - return sorted(unique_dims, key=lambda dim: (kind_order[dim.kind], dim.value)) + return ( + sorted(unique_dims, key=lambda dim: (kind_order[dim.kind], dim.value)) + if unique_dims + else [] + ) class FieldBuiltinFuncRegistry: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5cb9df24e1..45f1e9d597 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -352,31 +352,32 @@ def resolve_shift( return None final_target: common.Dimension | None = None - for off_literal in shift_tuple[::2]: - offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element - if isinstance(offset_type, common.Dimension): - final_target = offset_type - elif isinstance( - offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType) - ): - off_source = ( - offset_type.source - if isinstance(offset_type, fbuiltins.FieldOffset) - else offset_type.codomain - ) - off_targets = ( - offset_type.target - if isinstance(offset_type, fbuiltins.FieldOffset) - else offset_type.domain - ) - if input_dim == off_source: # check if input fits to offset - for target in off_targets: - if ( - target.value != off_literal.value - ): # off_targets also contains a dimension with value off_literal.value which is excluded here - input_dim = final_target = ( - target # setting new input_dim for next iteration - ) + if offset_provider_type: + for off_literal in shift_tuple[::2]: + offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element + if isinstance(offset_type, common.Dimension): + final_target = offset_type + elif isinstance( + offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType) + ): + off_source = ( + offset_type.source + if isinstance(offset_type, fbuiltins.FieldOffset) + else offset_type.codomain + ) + off_targets = ( + offset_type.target + if isinstance(offset_type, fbuiltins.FieldOffset) + else offset_type.domain + ) + if input_dim == off_source: # check if input fits to offset + for target in off_targets: + if ( + target.value != off_literal.value + ): # off_targets also contains a dimension with value off_literal.value which is excluded here + input_dim = final_target = ( + target # setting new input_dim for next iteration + ) return final_target @@ -404,7 +405,10 @@ def resolve_shift( assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=common.promote_dims(output_dims), dtype=el_type), + lambda el_type: ts.FieldType( + dims=common.promote_dims(output_dims) if output_dims != "unknown" else [], + dtype=el_type, + ), stencil_return, )