diff --git a/pyproject.toml b/pyproject.toml index acdaa019ee..23051eeccd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,14 +150,16 @@ warn_unused_ignores = true # GT4Py configs [[tool.mypy.overrides]] +allow_incomplete_defs = false +allow_untyped_defs = false ignore_missing_imports = false module = 'gt4py.*' [[tool.mypy.overrides]] # The following ignore_errors are only temporary. # TODO: Fix errors and enable these settings. -disallow_incomplete_defs = false -disallow_untyped_defs = false +allow_incomplete_defs = true +allow_untyped_defs = true follow_imports = 'silent' module = 'gt4py.cartesian.*' warn_unused_ignores = false @@ -186,10 +188,6 @@ module = 'gt4py.cartesian.frontend.defir_to_gtir' ignore_errors = true module = 'gt4py.cartesian.frontend.meta' -[[tool.mypy.overrides]] -disallow_untyped_defs = true -module = 'gt4py.eve.*' - [[tool.mypy.overrides]] module = 'gt4py.eve.extended_typing' warn_unused_ignores = false @@ -202,14 +200,14 @@ module = 'gt4py.storage.*' warn_unused_ignores = false [[tool.mypy.overrides]] -# # TODO: this should be changed to true after a transition period -disallow_incomplete_defs = false -module = 'gt4py.next.*' +allow_incomplete_defs = true +allow_untyped_defs = true +module = 'gt4py.next.iterator.*' [[tool.mypy.overrides]] -# TODO: temporarily to propagate it to all of next -disallow_incomplete_defs = true -module = 'gt4py.next.ffront.*' +allow_incomplete_defs = true +allow_untyped_defs = true +module = 'gt4py.next.program_processors.runners.dace_iterator.*' [[tool.mypy.overrides]] ignore_errors = true diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 801048c54e..b85f203c01 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1666,7 +1666,7 @@ def __init__(self, definition, *, options, externals=None, dtypes=None): self.block = None self.dtypes = dtypes - def __str__(self): + def __str__(self) -> str: result = " {\n" result += "\n".join("\t{}: {}".format(name, getattr(self, name)) for name in vars(self)) result += "\n}" diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index 09f02aa501..5838dcfe99 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -190,7 +190,7 @@ class LevelMarker(enum.Enum): START = 0 END = -1 - def __str__(self): + def __str__(self) -> str: return self.name @@ -251,7 +251,7 @@ def from_value(cls, value): return result - def __str__(self): + def __str__(self) -> str: return self.name @@ -268,7 +268,7 @@ class DataType(enum.Enum): FLOAT32 = 104 FLOAT64 = 108 - def __str__(self): + def __str__(self) -> str: return self.name @property @@ -665,7 +665,7 @@ def symbol(self): elif self == self.FORWARD: return "->" - def __str__(self): + def __str__(self) -> str: return self.name def __lshift__(self, steps: int): diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index fd254d46fa..0f43537758 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -140,7 +140,7 @@ def from_dace_storage(cls, schedule): class AxisBound(common.AxisBound): axis: Axis - def __str__(self): + def __str__(self) -> str: return get_axis_bound_str(self, self.axis.domain_symbol()) @classmethod diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index 1b06d70737..d760082872 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -27,7 +27,7 @@ class Axis(enum.Enum): J = 1 K = 2 - def __str__(self): + def __str__(self) -> str: return self.name names = [ax.name for ax in Axis] @@ -191,7 +191,7 @@ def __repr__(self): def __hash__(self): return tuple.__hash__(self) - def __str__(self): + def __str__(self) -> str: return tuple.__repr__(self) @property @@ -384,7 +384,7 @@ def __repr__(self): def __hash__(self): return tuple.__hash__(self) - def __str__(self): + def __str__(self) -> str: return tuple.__repr__(self) @property diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 81958f7e50..a04b3729e7 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -485,7 +485,7 @@ def __repr__(self): def __eq__(self, other): return repr(self) == repr(other) - def __str__(self): + def __str__(self) -> str: return f"{self.axis}[{self.index}] + {self.offset}" def __add__(self, offset: int): @@ -516,7 +516,7 @@ def __init__(self, axis: str, start: int, end: int): def __repr__(self): return f"AxisInterval(axis={self.axis}, start={self.start}, end={self.end})" - def __str__(self): + def __str__(self) -> str: return f"{self.axis}[{self.start}:{self.end}]" def __len__(self): @@ -532,7 +532,7 @@ def __init__(self, name: str, shift: int): def __repr__(self): return f"ShiftedAxis(name={self.name}, shift={self.shift})" - def __str__(self): + def __str__(self) -> str: return f"{self.name}+{self.shift}" def __add__(self, shift): @@ -559,7 +559,7 @@ def __gt_axis_name__(self) -> str: def __repr__(self): return f"Axis(name={self.name})" - def __str__(self): + def __str__(self) -> str: return self.name def __getitem__(self, interval): @@ -654,7 +654,7 @@ def __repr__(self): args = f"dtype={self.dtype!r}, axes={self.axes!r}, data_dims={self.data_dims!r}" return f"_FieldDescriptor({args})" - def __str__(self): + def __str__(self) -> str: return ( f"Field<[{', '.join(str(ax) for ax in self.axes)}], ({self.dtype}, {self.data_dims})>" ) diff --git a/src/gt4py/cartesian/utils/text.py b/src/gt4py/cartesian/utils/text.py index 87a44c8b32..208343aa1d 100644 --- a/src/gt4py/cartesian/utils/text.py +++ b/src/gt4py/cartesian/utils/text.py @@ -53,7 +53,7 @@ def __init__(self, joiner, index): self.joiner = joiner self.index = index - def __str__(self): + def __str__(self) -> str: return self.joiner.joiner_str if self.index < self.joiner.n_items - 1 else "" def __init__(self, joiner_str): @@ -138,5 +138,5 @@ def __iadd__(self, source_line): def __len__(self): return len(self.lines) - def __str__(self): + def __str__(self) -> str: return self.text diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index dd8225ed91..5033fae902 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -676,6 +676,23 @@ def generic_dump(cls, node: RootNode, **kwargs: Any) -> str: """ return str(node) + @overload + def generic_visit(self, node: Node, **kwargs: Any) -> str: ... + + @overload + def generic_visit( + self, + node: Union[ + list, + tuple, + collections.abc.Set, + collections.abc.Sequence, + dict, + collections.abc.Mapping, + ], + **kwargs: Any, + ) -> Collection[str]: ... + def generic_visit(self, node: RootNode, **kwargs: Any) -> Union[str, Collection[str]]: if isinstance(node, Node): template, key = self.get_template(node) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f21ee0e736..7f4d5b7b97 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -71,7 +71,7 @@ class DimensionKind(StrEnum): VERTICAL = "vertical" LOCAL = "local" - def __str__(self): + def __str__(self) -> str: return self.value @@ -80,7 +80,7 @@ class Dimension: value: str kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) - def __str__(self): + def __str__(self) -> str: return f"{self.value}[{self.kind}]" def __call__(self, val: int) -> NamedIndex: @@ -641,7 +641,7 @@ def asnumpy(self) -> np.ndarray: ... def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field: ... + def restrict(self, item: AnyIndexSpec) -> Self: ... @abc.abstractmethod def as_scalar(self) -> core_defs.ScalarT: ... @@ -651,7 +651,7 @@ def as_scalar(self) -> core_defs.ScalarT: ... def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field: ... + def __getitem__(self, item: AnyIndexSpec) -> Self: ... @abc.abstractmethod def __abs__(self) -> Field: ... @@ -867,22 +867,6 @@ def _connectivity( raise NotImplementedError -@dataclasses.dataclass(frozen=True) -class GTInfo: - definition: Any - ir: Any - - -@dataclasses.dataclass(frozen=True) -class Backend: - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - - # TODO : proper definition and implementation - def generate_operator(self, ir): - return ir - - @runtime_checkable class Connectivity(Protocol): max_neighbors: int @@ -1083,7 +1067,7 @@ class FieldBuiltinFuncRegistry: collections.ChainMap() ) - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any) -> None: cls._builtin_func_map = collections.ChainMap( {}, # New empty `dict` for new registrations on this class *[ diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index f09e1b1ac3..6642d9a055 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -17,8 +17,9 @@ import functools import itertools import operator +from collections.abc import Iterator, Sequence -from gt4py.eve.extended_typing import Any, Optional, Sequence, cast +from gt4py.eve.extended_typing import Any, Optional, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -148,9 +149,11 @@ def restrict_to_intersection( ) -def iterate_domain(domain: common.Domain): +def iterate_domain( + domain: common.Domain, +) -> Iterator[tuple[tuple[common.Dimension, int]]]: for i in itertools.product(*[list(r) for r in domain.ranges]): - yield tuple(zip(domain.dims, i)) + yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]` def _expand_ellipsis( diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 93942a5959..672dc9c620 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -16,6 +16,7 @@ import contextlib import contextvars as cvars +from collections.abc import Generator from typing import Any import gt4py.eve as eve @@ -39,7 +40,7 @@ def new_context( *, closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING, offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, -): +) -> Generator[cvars.Context, None, None]: import gt4py.next.embedded.context as this_module updates: list[tuple[cvars.ContextVar[Any], Any]] = [] @@ -51,7 +52,7 @@ def new_context( # Create new context with provided values ctx = cvars.copy_context() - def ctx_updater(*args): + def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None: for cvar, value in args: cvar.set(value) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 5a07328531..e884c61f36 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -54,7 +54,7 @@ def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArra def _make_builtin( - builtin_name: str, array_builtin_name: str, reverse=False + builtin_name: str, array_builtin_name: str, reverse: bool = False ) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: cls_ = _get_nd_array_class(*fields) @@ -228,7 +228,7 @@ def remap( __call__ = remap # type: ignore[assignment] - def restrict(self, index: common.AnyIndexSpec) -> common.Field: + def restrict(self, index: common.AnyIndexSpec) -> NdArrayField: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] new_buffer = self.__class__.array_ns.asarray(new_buffer) @@ -435,7 +435,7 @@ def inverse_image( return new_dims - def restrict(self, index: common.AnyIndexSpec) -> common.Field: + def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: cache_key = (id(self.ndarray), self.domain, index) if (restricted_connectivity := self._cache.get(cache_key, None)) is None: diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 023da5c5f8..c5f5fd0503 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -14,7 +14,7 @@ import dataclasses from types import ModuleType -from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar +from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar import numpy as np @@ -37,16 +37,19 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: @dataclasses.dataclass(frozen=True) -class ScanOperator(EmbeddedOperator[_R, _P]): +class ScanOperator(EmbeddedOperator[core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...], _P]): forward: bool - init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] + init: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] axis: common.Dimension def __call__( # type: ignore[override] self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar, # type: ignore[override] - ) -> common.Field: + ) -> ( + common.Field[Any, core_defs.ScalarT] + | tuple[common.Field[Any, core_defs.ScalarT] | tuple, ...] + ): scan_range = embedded_context.closure_column_range.get() assert self.axis == scan_range[0] scan_axis = scan_range[0] @@ -64,13 +67,13 @@ def __call__( # type: ignore[override] xp = _get_array_ns(*all_args) res = _construct_scan_array(out_domain, xp)(self.init) - def scan_loop(hpos): - acc = self.init + def scan_loop(hpos: Sequence[common.NamedIndex]) -> None: + acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init for k in scan_range[1] if self.forward else reversed(scan_range[1]): pos = (*hpos, (scan_axis, k)) new_args = [_tuple_at(pos, arg) for arg in args] new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} - acc = self.fun(acc, *new_args, **new_kwargs) + acc = self.fun(acc, *new_args, **new_kwargs) # type: ignore[arg-type] # need to express that the first argument is the same type as the return _tuple_assign_value(pos, res, acc) if len(non_scan_domain) == 0: @@ -91,7 +94,7 @@ def _get_out_domain( ]) -def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): +def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> Optional[_R]: if "out" in kwargs: # called from program or direct field_operator as program new_context_kwargs = {} @@ -118,9 +121,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): res = ctx.run(op, *args, **kwargs) _tuple_assign_field( out, - res, + res, # type: ignore[arg-type] # maybe can't be inferred properly because decorator.py is not properly typed yet domain=out_domain, ) + return None else: # called from other field_operator or missing `out` argument if "offset_provider" in kwargs: @@ -139,9 +143,9 @@ def _tuple_assign_field( target: tuple[common.MutableField | tuple, ...] | common.MutableField, source: tuple[common.Field | tuple, ...] | common.Field, domain: common.Domain, -): +) -> None: @utils.tree_map - def impl(target: common.MutableField, source: common.Field): + def impl(target: common.MutableField, source: common.Field) -> None: if common.is_field(source): target[domain] = source[domain] else: @@ -169,11 +173,17 @@ def _get_array_ns( def _construct_scan_array( - domain: common.Domain, xp: ModuleType -): # TODO(havogt) introduce a NDArrayNamespace protocol + domain: common.Domain, + xp: ModuleType, # TODO(havogt) introduce a NDArrayNamespace protocol +) -> Callable[ + [core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]], + common.MutableField | tuple[common.MutableField | tuple, ...], +]: @utils.tree_map - def impl(init: core_defs.Scalar) -> common.Field: - return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain) + def impl(init: core_defs.Scalar) -> common.MutableField: + res = common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain) + assert common.is_mutable_field(res) + return res return impl @@ -184,7 +194,7 @@ def _tuple_assign_value( source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], ) -> None: @utils.tree_map - def impl(target: common.MutableField, source: core_defs.Scalar): + def impl(target: common.MutableField, source: core_defs.Scalar) -> None: target[pos] = source impl(target, source) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 673eaca757..e0be899c00 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -21,7 +21,9 @@ """ import sys -from typing import Callable +import types +from collections.abc import Callable +from typing import Optional from gt4py.next import config @@ -41,7 +43,9 @@ def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) - return formatting.format_compilation_error(type(err), err.message, err.location) -def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None: +def compilation_error_hook( + fallback: Callable, type_: type, value: BaseException, tb: Optional[types.TracebackType] +) -> None: """ Format `CompilationError`s in a neat way. diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8d9552ab6d..5c0da54ab8 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -89,7 +89,7 @@ class BuiltInFunction(Generic[_R, _P]): # e.g. a fused multiply add could have a default implementation as a*b+c, but an optimized implementation for a specific `Field` function: Callable[_P, _R] - def __post_init__(self): + def __post_init__(self) -> None: object.__setattr__(self, "name", f"{self.function.__module__}.{self.function.__name__}") object.__setattr__(self, "__doc__", self.function.__doc__) @@ -246,7 +246,7 @@ def astype( UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES = ["isfinite", "isinf", "isnan"] -def _make_unary_math_builtin(name): +def _make_unary_math_builtin(name: str) -> None: def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: ERA001 [commented-out-code] @@ -267,7 +267,7 @@ def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs BINARY_MATH_NUMBER_BUILTIN_NAMES = ["minimum", "maximum", "fmod", "power"] -def _make_binary_math_builtin(name): +def _make_binary_math_builtin(name: str) -> None: def impl( lhs: common.Field | core_defs.ScalarT, rhs: common.Field | core_defs.ScalarT, @@ -322,11 +322,11 @@ class FieldOffset(runtime.Offset): def _cache(self) -> dict: return {} - def __post_init__(self): + def __post_init__(self) -> None: if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL: raise ValueError("Second dimension in offset must be a local dimension.") - def __gt_type__(self): + def __gt_type__(self) -> ts.OffsetType: return ts.OffsetType(source=self.source, target=self.target) def __getitem__(self, offset: int) -> common.ConnectivityField: @@ -351,7 +351,7 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: return connectivity - def as_connectivity_field(self): + def as_connectivity_field(self) -> common.ConnectivityField: """Convert to connectivity field using the offset providers in current embedded execution context.""" assert isinstance(self.value, str) current_offset_provider = embedded.context.offset_provider.get(None) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 88564900ac..2a9cb47950 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -27,7 +27,7 @@ class LocatedNode(Node): location: SourceLocation - def __str__(self): + def __str__(self) -> str: from gt4py.next.ffront.foast_pretty_printer import pretty_format try: diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 982655e079..6044b41421 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -871,7 +871,6 @@ def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: - return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) @@ -891,6 +890,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) ), ) + assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType)) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 7cdece6c59..0e39853a3c 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -444,7 +444,7 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op, *args, **kwargs): + def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: lowered_args = [self.visit(arg, **kwargs) for arg in args] if any(type_info.contains_local_field(arg.type) for arg in args): lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index f90fd8efcd..5b7e5f80d3 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable, TypeVar +from typing import Any, Callable, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next.ffront import type_info as ti_ffront @@ -38,7 +38,7 @@ def to_tuples_of_iterator(expr: itir.Expr | str, arg_type: ts.TypeSpec) -> itir. """ param = f"__toi_{eve_utils.content_hash(expr)}" - def fun(primitive_type, path): + def fun(primitive_type: ts.TypeSpec, path: tuple[int, ...]) -> itir.Expr: inner_expr = im.deref("it") for path_part in path: inner_expr = im.tuple_get(path_part, inner_expr) @@ -79,7 +79,7 @@ def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec) -> itir. for type_ in type_constituents ) - def fun(_, path): + def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall: param_name = "__iot_el" for path_part in path: param_name = f"{param_name}_{path_part}" diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index b08d355990..eb4d07c8fc 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -168,7 +168,7 @@ def apply( ) -> itir.FencilDefinition: return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) - def __init__(self, grid_type): + def __init__(self, grid_type: common.GridType): self.grid_type = grid_type def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index e8095b2777..83b86ac656 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -197,7 +197,9 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: # TODO: we want some generic field type here, but our type system does not support it yet. return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype) - return type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True) + res = type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True) + assert isinstance(res, (ts.FieldType, ts.TupleType)) + return res @type_info.function_signature_incompatibilities.register diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f9f1ba47e0..b7cf36187d 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -24,7 +24,13 @@ import math import sys import warnings -from typing import ( + +import numpy as np +import numpy.typing as npt + +from gt4py._core import definitions as core_defs +from gt4py.eve import extended_typing as xtyping +from gt4py.eve.extended_typing import ( Any, Callable, Generic, @@ -34,6 +40,7 @@ NoReturn, Optional, Protocol, + Self, Sequence, SupportsFloat, SupportsInt, @@ -45,12 +52,6 @@ overload, runtime_checkable, ) - -import numpy as np -import numpy.typing as npt - -from gt4py._core import definitions as core_defs -from gt4py.eve import extended_typing as xtyping from gt4py.next import common, embedded as next_embedded from gt4py.next.embedded import exceptions as embedded_exceptions from gt4py.next.ffront import fbuiltins @@ -86,7 +87,7 @@ class SparseTag(Tag): ... class NeighborTableOffsetProvider: def __init__( self, - table: npt.NDArray, + table: core_defs.NDArrayObject, origin_axis: common.Dimension, neighbor_axis: common.Dimension, max_neighbors: int, @@ -103,7 +104,9 @@ def __init__( def mapped_index( self, primary: common.IntIndex, neighbor_idx: common.IntIndex ) -> common.IntIndex: - return self.table[(primary, neighbor_idx)] + res = self.table[(primary, neighbor_idx)] + assert common.is_int_index(res) + return res class StridedNeighborOffsetProvider: @@ -1088,7 +1091,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field: + def restrict(self, item: common.AnyIndexSpec) -> Self: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off d, r = item[0] assert d == self._dimension @@ -1207,7 +1210,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field: + def restrict(self, item: common.AnyIndexSpec) -> Self: # TODO set a domain... return self diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 017068bb19..e64220844d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -355,7 +355,7 @@ def promote_to_const_iterator(expr: str | itir.Expr) -> itir.Expr: def promote_to_lifted_stencil( op: str | itir.SymRef | Callable, -) -> Callable[..., itir.Expr]: +) -> Callable[..., itir.FunCall]: """ Promotes a function `op` from values to iterators. @@ -372,7 +372,7 @@ def promote_to_lifted_stencil( if isinstance(op, (str, itir.SymRef, itir.Lambda)): op = call(op) - def _impl(*its: itir.Expr) -> itir.Expr: + def _impl(*its: itir.Expr) -> itir.FunCall: args = [ f"__arg{i}" for i in range(len(its)) ] # TODO: `op` must not contain `SymRef(id="__argX")` diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 4feeca43e3..1d675beea2 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -415,7 +415,7 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di ) sizes[provider.neighbor_axis.value] = max( sizes.get(provider.neighbor_axis.value, 0), - provider.table.max(), + provider.table.max(), # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes diff --git a/src/gt4py/next/otf/binding/interface.py b/src/gt4py/next/otf/binding/interface.py index 2ea13edd58..44b5ef3e9e 100644 --- a/src/gt4py/next/otf/binding/interface.py +++ b/src/gt4py/next/otf/binding/interface.py @@ -21,7 +21,7 @@ from gt4py.next.otf import languages -def format_source(settings: languages.LanguageSettings, source): +def format_source(settings: languages.LanguageSettings, source: str) -> str: return codegen.format_source(settings.formatter_key, source, style=settings.formatter_style) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 05e072b0d6..e308883af6 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -144,11 +144,11 @@ class BindingCodeGenerator(TemplatedGenerator): BindingFunction = as_jinja("""module.def("{{exported_name}}", &{{wrapper_name}}, "{{doc}}");""") - def visit_FunctionCall(self, call: FunctionCall): + def visit_FunctionCall(self, call: FunctionCall) -> str: args = [self.visit(arg) for arg in call.args] return cpp_interface.render_function_call(call.target, args) - def visit_BufferSID(self, sid: BufferSID, **kwargs): + def visit_BufferSID(self, sid: BufferSID, **kwargs: Any) -> str: pybuffer = f"{sid.source_buffer}.first" dims = [self.visit(dim) for dim in sid.dimensions] origin = f"{sid.source_buffer}.second" @@ -158,7 +158,7 @@ def visit_BufferSID(self, sid: BufferSID, **kwargs): renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})" return renamed - def visit_CompositeSID(self, node: CompositeSID, **kwargs): + def visit_CompositeSID(self, node: CompositeSID, **kwargs: Any) -> str: kwargs["composite_ids"] = ( f"gridtools::integral_constant" for i in range(len(node.elems)) ) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 2aadd4e21f..694a99e54e 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -98,12 +98,12 @@ class CMakeProject( build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG extra_cmake_flags: list[str] = dataclasses.field(default_factory=list) - def build(self): + def build(self) -> None: self._write_files() self._run_config() self._run_build() - def _write_files(self): + def _write_files(self) -> None: for name, content in self.source_files.items(): (self.root_path / name).write_text(content, encoding="utf-8") @@ -118,7 +118,7 @@ def _write_files(self): self.root_path, ) - def _run_config(self): + def _run_config(self) -> None: logfile = self.root_path / "log_config.txt" with logfile.open(mode="w") as log_file_pointer: subprocess.check_call( @@ -139,7 +139,7 @@ def _run_config(self): build_data.update_status(new_status=build_data.BuildStatus.CONFIGURED, path=self.root_path) - def _run_build(self): + def _run_build(self) -> None: logfile = self.root_path / "log_build.txt" with logfile.open(mode="w") as log_file_pointer: subprocess.check_call( diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 2c60b32ee3..2fcb7ad0d9 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -85,7 +85,7 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator): """ ) - def visit_FindDependency(self, dep: FindDependency): + def visit_FindDependency(self, dep: FindDependency) -> str: # TODO(ricoh): do not add more libraries here # and do not use this design in a new build system. # Instead, design this to be extensible (refer to ADR-0016). @@ -103,7 +103,7 @@ def visit_FindDependency(self, dep: FindDependency): case _: raise ValueError(f"Library '{dep.name}' is not supported") - def visit_LinkDependency(self, dep: LinkDependency): + def visit_LinkDependency(self, dep: LinkDependency) -> str: # TODO(ricoh): do not add more libraries here # and do not use this design in a new build system. # Instead, design this to be extensible (refer to ADR-0016). diff --git a/src/gt4py/next/otf/compilation/common.py b/src/gt4py/next/otf/compilation/common.py index 784295a55e..9212608558 100644 --- a/src/gt4py/next/otf/compilation/common.py +++ b/src/gt4py/next/otf/compilation/common.py @@ -19,5 +19,5 @@ import importlib -def python_module_suffix(): +def python_module_suffix() -> str: return importlib.machinery.EXTENSION_SUFFIXES[0][1:] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 88c1b44792..6609ea5229 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -67,7 +67,7 @@ class ProgramSource(Generic[SrcL, SettingT]): language: type[SrcL] language_settings: SettingT - def __post_init__(self): + def __post_init__(self) -> None: if not isinstance(self.language_settings, self.language.settings_class): raise TypeError( f"Wrong language settings type for '{self.language}', must be subclass of '{self.language.settings_class}'." @@ -124,7 +124,7 @@ def build(self) -> None: ... class CompiledProgram(Protocol): """Executable python representation of a program.""" - def __call__(self, *args, **kwargs) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> None: ... def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 6a020f1102..fde41182cc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -115,7 +115,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: TernaryExpr = as_fmt("({cond}?{true_expr}:{false_expr})") CastExpr = as_fmt("static_cast<{new_dtype}>({obj_expr})") - def visit_TaggedValues(self, node: gtfn_ir.TaggedValues, **kwargs): + def visit_TaggedValues(self, node: gtfn_ir.TaggedValues, **kwargs: Any) -> str: tags = self.visit(node.tags) values = self.visit(node.values) if self.is_cartesian: @@ -135,7 +135,7 @@ def visit_OffsetLiteral(self, node: gtfn_ir.OffsetLiteral, **kwargs: Any) -> str "::gridtools::sid::composite::keys<${','.join(f'::gridtools::integral_constant' for i in range(len(values)))}>::make_values(${','.join(values)})" ) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): + def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> str: if ( isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id in self.user_defined_function_ids @@ -179,7 +179,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) - def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): + def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs: Any) -> str: expr_ = "return " + self.visit(node.expr) return self.generic_visit(node, expr_=expr_) @@ -210,11 +210,12 @@ def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): """ ) - def visit_TemporaryAllocation(self, node, **kwargs): + def visit_TemporaryAllocation(self, node: gtfn_ir.TemporaryAllocation, **kwargs: Any) -> str: # TODO(tehrengruber): Revisit. We are currently converting an itir.NamedRange with # start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with # size and offset values, just to here convert back in order to obtain stop values again. # TODO(tehrengruber): Fix memory alignment. + assert isinstance(node.domain, (gtfn_ir.CartesianDomain, gtfn_ir.UnstructuredDomain)) assert node.domain.tagged_offsets.tags == node.domain.tagged_sizes.tags tags = node.domain.tagged_offsets.tags new_sizes = [] @@ -330,14 +331,14 @@ class GTFNIMCodegen(GTFNCodegen): ReturnStmt = as_fmt("return {ret};") - def visit_Conditional(self, node: gtfn_im_ir.Conditional, **kwargs): + def visit_Conditional(self, node: gtfn_im_ir.Conditional, **kwargs: Any) -> str: if_rhs_ = self.visit(node.if_stmt.rhs) else_rhs_ = self.visit(node.else_stmt.rhs) return self.generic_visit(node, if_rhs_=if_rhs_, else_rhs_=else_rhs_) def visit_ImperativeFunctionDefinition( - self, node: gtfn_im_ir.ImperativeFunctionDefinition, **kwargs - ): + self, node: gtfn_im_ir.ImperativeFunctionDefinition, **kwargs: Any + ) -> str: expr_ = "".join(self.visit(stmt) for stmt in node.fun) return self.generic_visit(node, expr_=expr_) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index 49cf16bf4c..c92b269a3a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -from typing import Any, Dict, Iterable, Iterator, List, TypeGuard, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, TypeGuard, Union import gt4py.eve as eve from gt4py.eve import NodeTranslator @@ -145,14 +145,18 @@ def _make_sparse_acess( class PlugInCurrentIdx(NodeTranslator): - def visit_SymRef(self, node): + def visit_SymRef( + self, node: gtfn_ir_common.SymRef + ) -> gtfn_ir.OffsetLiteral | gtfn_ir_common.SymRef: if node.id == "nbh_iter": return self.cur_idx if self.acc is not None and node.id == self.acc.id: return gtfn_ir_common.SymRef(id=self.red_idx) return self.generic_visit(node) - def __init__(self, cur_idx, acc, red_idx): + def __init__( + self, cur_idx: gtfn_ir.OffsetLiteral, acc: Optional[gtfn_ir_common.Sym], red_idx: str + ): self.cur_idx = cur_idx self.acc = acc self.red_idx = red_idx @@ -164,12 +168,14 @@ class GTFN_IM_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) - def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs): + def visit_SymRef(self, node: gtfn_ir_common.SymRef, **kwargs: Any) -> gtfn_ir_common.SymRef: if "localized_symbols" in kwargs and node.id in kwargs["localized_symbols"]: return gtfn_ir_common.SymRef(id=kwargs["localized_symbols"][node.id]) return node - def commit_args(self, node: gtfn_ir.FunCall, tmp_id: str, fun_id: str, **kwargs): + def commit_args( + self, node: gtfn_ir.FunCall, tmp_id: str, fun_id: str, **kwargs: Any + ) -> gtfn_ir.FunCall: for i, arg in enumerate(node.args): expr = self.visit(arg, **kwargs) self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{tmp_id}_{i}"), rhs=expr)) @@ -182,14 +188,14 @@ def _expand_lambda( new_args: List[gtfn_ir.FunCall], red_idx: str, max_neighbors: int, - **kwargs, - ): + **kwargs: Any, + ) -> None: fun, init = node.fun.args # type: ignore param_to_args = dict(zip([param.id for param in fun.params[1:]], new_args)) acc = fun.params[0] class InlineArgs(NodeTranslator): - def visit_Expr(self, node): + def visit_Expr(self, node: gtfn_ir_common.Expr) -> gtfn_ir_common.Expr: if hasattr(node, "id") and node.id in param_to_args: return param_to_args[node.id] return self.generic_visit(node) @@ -212,8 +218,8 @@ def _expand_symref( new_args: List[gtfn_ir.FunCall], red_idx: str, max_neighbors: int, - **kwargs, - ): + **kwargs: Any, + ) -> None: fun, init = node.fun.args # type: ignore red_lit = gtfn_ir_common.Sym(id=f"{red_idx}") @@ -232,7 +238,7 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs): + def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: offset_provider = kwargs["offset_provider"] assert offset_provider is not None @@ -258,7 +264,7 @@ def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs): return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): + def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any( isinstance( arg, @@ -309,7 +315,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): args=[self.visit(arg, **kwargs) for arg in node.args], ) - def visit_TernaryExpr(self, node: gtfn_ir.TernaryExpr, **kwargs): + def visit_TernaryExpr(self, node: gtfn_ir.TernaryExpr, **kwargs: Any) -> gtfn_ir_common.SymRef: cond = self.visit(node.cond, **kwargs) if_ = self.visit(node.true_expr, **kwargs) else_ = self.visit(node.false_expr, **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 7fa43e78c9..2207b1c1d5 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -13,14 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -from typing import Any, ClassVar, Iterable, Optional, Type, Union +from typing import Any, ClassVar, Iterable, Optional, Type, TypeGuard, Union import gt4py.eve as eve from gt4py.eve.concepts import SymbolName from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries +from gt4py.next.iterator.transforms import global_tmps from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -48,7 +48,7 @@ from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef -def pytype_to_cpptype(t: str): +def pytype_to_cpptype(t: str) -> Optional[str]: try: return { "float32": "float", @@ -135,7 +135,7 @@ def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, offset_provider: dict[str, common.Dimension | common.Connectivity], -): +) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() .if_isinstance(itir.OffsetLiteral) @@ -191,6 +191,16 @@ def _literal_as_integral_constant(node: itir.Literal) -> IntegralConstant: return IntegralConstant(value=int(node.value)) +def _is_scan(node: itir.Node) -> TypeGuard[itir.FunCall]: + return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") + + +def _bool_from_literal(node: itir.Node) -> bool: + assert isinstance(node, itir.Literal) + assert node.type == "bool" and node.value in ("True", "False") + return node.value == "True" + + @dataclasses.dataclass(frozen=True) class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): _binary_op_map: ClassVar[dict[str, str]] = { @@ -222,12 +232,12 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @classmethod def apply( cls, - node: itir.FencilDefinition | FencilWithTemporaries, + node: itir.FencilDefinition | global_tmps.FencilWithTemporaries, *, offset_provider: dict, column_axis: Optional[common.Dimension], - ): - if isinstance(node, FencilWithTemporaries): + ) -> FencilDefinition: + if isinstance(node, global_tmps.FencilWithTemporaries): fencil_definition = node.fencil elif isinstance(node, itir.FencilDefinition): fencil_definition = node @@ -294,7 +304,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> Offset def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs: Any) -> Literal: return Literal(value=node.value, type="axis_literal") - def _make_domain(self, node: itir.FunCall): + def _make_domain(self, node: itir.FunCall) -> tuple[TaggedValues, TaggedValues]: tags = [] sizes = [] offsets = [] @@ -431,32 +441,21 @@ def visit_FunctionDefinition( expr=self.visit(node.expr, **kwargs), ) - @staticmethod - def _is_scan(node: itir.Node): - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - def _visit_output_argument(self, node: itir.Expr): + def _visit_output_argument(self, node: itir.Expr) -> SidComposite | SymRef: if isinstance(node, itir.SymRef): return self.visit(node) elif isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="make_tuple"): return SidComposite(values=[self._visit_output_argument(v) for v in node.args]) raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.") - @staticmethod - def _bool_from_literal(node: itir.Node): - assert isinstance(node, itir.Literal) - assert node.type == "bool" and node.value in ("True", "False") - return node.value == "True" - def visit_StencilClosure( self, node: itir.StencilClosure, extracted_functions: list, **kwargs: Any ) -> Union[ScanExecution, StencilExecution]: backend = Backend(domain=self.visit(node.domain, stencil=node.stencil, **kwargs)) - if self._is_scan(node.stencil): + if _is_scan(node.stencil): scan_id = self.uids.sequential_id(prefix="_scan") - assert isinstance(node.stencil, itir.FunCall) scan_lambda = self.visit(node.stencil.args[0], **kwargs) - forward = self._bool_from_literal(node.stencil.args[1]) + forward = _bool_from_literal(node.stencil.args[1]) scan_def = ScanPassDefinition( id=scan_id, params=scan_lambda.params, expr=scan_lambda.expr, forward=forward ) @@ -556,20 +555,29 @@ def visit_FencilDefinition( temporaries=[], ) - def visit_Temporary(self, node, *, params: list, **kwargs) -> TemporaryAllocation: - def dtype_to_cpp(x): + def visit_Temporary( + self, node: global_tmps.Temporary, *, params: list, **kwargs: Any + ) -> TemporaryAllocation: + def dtype_to_cpp(x: int | tuple | str) -> str: if isinstance(x, int): return f"std::remove_const_t<::gridtools::sid::element_type>" if isinstance(x, tuple): return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x) + ">" assert isinstance(x, str) - return pytype_to_cpptype(x) + res = pytype_to_cpptype(x) + assert isinstance(res, str) + return res + assert isinstance( + node.dtype, (int, tuple, str) + ) # TODO(havogt): this looks weird, consider refactoring return TemporaryAllocation( id=node.id, dtype=dtype_to_cpp(node.dtype), domain=self.visit(node.domain, **kwargs) ) - def visit_FencilWithTemporaries(self, node, **kwargs) -> FencilDefinition: + def visit_FencilWithTemporaries( + self, node: global_tmps.FencilWithTemporaries, **kwargs: Any + ) -> FencilDefinition: fencil = self.visit(node.fencil, **kwargs) return FencilDefinition( id=fencil.id, diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index fd80f9aacf..a65676af76 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -56,7 +56,7 @@ class ToLispLike(TemplatedGenerator): ) @classmethod - def apply(cls, root, **kwargs: Any) -> str: + def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override] transformed = apply_common_transforms( root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] ) diff --git a/src/gt4py/next/program_processors/formatters/type_check.py b/src/gt4py/next/program_processors/formatters/type_check.py index 8f17b8cf98..03aeef1264 100644 --- a/src/gt4py/next/program_processors/formatters/type_check.py +++ b/src/gt4py/next/program_processors/formatters/type_check.py @@ -12,13 +12,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Any + from gt4py.next.iterator import ir as itir, type_inference from gt4py.next.iterator.transforms import apply_common_transforms, global_tmps from gt4py.next.program_processors.processor_interface import program_formatter @program_formatter -def check_type_inference(program: itir.FencilDefinition, *args, **kwargs) -> str: +def check_type_inference(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: type_inference.pprint(type_inference.infer(program, offset_provider=kwargs["offset_provider"])) transformed = apply_common_transforms( program, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index b8032c17b8..91f35178b3 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -27,7 +27,7 @@ class ModularExecutor(ppi.ProgramExecutor): otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram] name: Optional[str] = None - def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: + def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( *args, offset_provider=kwargs["offset_provider"] ) diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 36870252c5..fcde1cc2b6 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -41,7 +41,7 @@ class ProgramProcessorCallable(Protocol[OutputT]): - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: ... + def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> OutputT: ... class ProgramProcessor(ProgramProcessorCallable[OutputT], Protocol[OutputT, ProcessorKindT]): @@ -140,7 +140,7 @@ def make_program_processor( filtered_kwargs = _make_kwarg_filter(accept_kwargs) @functools.wraps(func) - def _wrapper(program: itir.FencilDefinition, *args, **kwargs) -> OutputT: + def _wrapper(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> OutputT: return func(program, *args_filter(args), **filtered_kwargs(kwargs)) if name is not None: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 77f86b4084..72874147cb 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -37,7 +37,7 @@ def dace_debuginfo( return debuginfo -def as_dace_type(type_: ts.ScalarType): +def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ elif type_.kind == ts.ScalarKind.INT32: @@ -59,7 +59,7 @@ def as_scalar_type(typestr: str) -> ts.ScalarType: return ts.ScalarType(kind) -def filter_neighbor_tables(offset_provider: dict[str, Any]): +def filter_neighbor_tables(offset_provider: dict[str, Any]) -> dict[str, NeighborTable]: return { offset: table for offset, table in offset_provider.items() @@ -67,7 +67,7 @@ def filter_neighbor_tables(offset_provider: dict[str, Any]): } -def connectivity_identifier(name: str): +def connectivity_identifier(name: str) -> str: return f"__connectivity_{name}" diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7bd65d791e..75944e1dd4 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -50,8 +50,8 @@ def convert_args( inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args, offset_provider: dict[str, common.Connectivity | common.Dimension] - ): + *args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension] + ) -> None: converted_args = [convert_arg(arg) for arg in args] conn_args = extract_connectivity_args(offset_provider, device) return inp( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index eb6c4e9d9e..7ca88eab06 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -33,7 +33,7 @@ from gt4py.next.program_processors import modular_executor, processor_interface as ppi -def _create_tmp(axes, origin, shape, dtype): +def _create_tmp(axes: str, origin: str, shape: str, dtype: Any) -> str: if isinstance(dtype, tuple): return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype)},)" else: @@ -69,7 +69,9 @@ def ${id}(${','.join(params)}): ) # extension required by global_tmps - def visit_FencilWithTemporaries(self, node, **kwargs): + def visit_FencilWithTemporaries( + self, node: gtmps_transform.FencilWithTemporaries, **kwargs: Any + ) -> str: params = self.visit(node.params) tmps = "\n ".join(self.visit(node.tmps)) @@ -84,16 +86,21 @@ def visit_FencilWithTemporaries(self, node, **kwargs): + f"\n {node.fencil.id}({args}, **kwargs)\n" ) - def visit_Temporary(self, node, **kwargs): - assert isinstance(node.domain, itir.FunCall) and node.domain.fun.id in ( - "cartesian_domain", - "unstructured_domain", + def visit_Temporary(self, node: gtmps_transform.Temporary, **kwargs: Any) -> str: + assert ( + isinstance(node.domain, itir.FunCall) + and isinstance(node.domain.fun, itir.SymRef) + and node.domain.fun.id + in ( + "cartesian_domain", + "unstructured_domain", + ) ) assert all( isinstance(r, itir.FunCall) and r.fun == itir.SymRef(id="named_range") for r in node.domain.args ) - domain_ranges = [self.visit(r.args) for r in node.domain.args] + domain_ranges = [self.visit(r.args) for r in node.domain.args] # type: ignore[attr-defined] # `node.domain` is `FunCall` checked in previous assert axes = ", ".join(label for label, _, _ in domain_ranges) origin = "{" + ", ".join(f"{label}: -{start}" for label, start, _ in domain_ranges) + "}" shape = "(" + ", ".join(f"{stop}-{start}" for _, start, stop in domain_ranges) + ")" @@ -203,7 +210,7 @@ def fencil_generator( @ppi.program_executor # type: ignore[arg-type] def execute_roundtrip( ir: itir.Node, - *args, + *args: Any, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], debug: bool = False, @@ -253,7 +260,7 @@ class Meta: class RoundtripExecutor(modular_executor.ModularExecutor): dispatch_backend: Optional[ppi.ProgramExecutor] = None - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: + def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: kwargs["backend"] = self.dispatch_backend self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( *args, **kwargs diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 88eb9b4273..a0ec811900 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -15,7 +15,8 @@ import functools import types import typing -from typing import Any, Callable, Iterator, Type, TypeGuard, cast +from collections.abc import Callable, Iterator +from typing import Any, Generic, Protocol, Type, TypeGuard, TypeVar, cast import numpy as np @@ -98,13 +99,13 @@ def primitive_constituents( def primitive_constituents( symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[True], -) -> XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: ... +) -> XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: ... def primitive_constituents( symbol_type: ts.TypeSpec, with_path_arg: bool = False, -) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: +) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: """ Return the primitive types contained in a composite type. @@ -122,7 +123,9 @@ def primitive_constituents( [FieldType(...), ScalarType(...), FieldType(...)] """ - def constituents_yielder(symbol_type: ts.TypeSpec, path: tuple[int, ...]): + def constituents_yielder( + symbol_type: ts.TypeSpec, path: tuple[int, ...] + ) -> Iterator[ts.TypeSpec] | Iterator[tuple[ts.TypeSpec, tuple[int, ...]]]: if isinstance(symbol_type, ts.TupleType): for i, el_type in enumerate(symbol_type.types): yield from constituents_yielder(el_type, (*path, i)) @@ -132,19 +135,26 @@ def constituents_yielder(symbol_type: ts.TypeSpec, path: tuple[int, ...]): else: yield symbol_type - return xiter(constituents_yielder(symbol_type, ())) + return xiter(constituents_yielder(symbol_type, ())) # type: ignore[return-value] # why resolved to XIterable[object]? +_R = TypeVar("_R", covariant=True) +_T = TypeVar("_T") + + +class TupleConstructorType(Protocol, Generic[_R]): + def __call__(self, *args: Any) -> _R: ... + + +# TODO(havogt): the complicated typing is a hint that this function needs refactoring def apply_to_primitive_constituents( symbol_type: ts.TypeSpec, - fun: ( - Callable[[ts.TypeSpec], ts.TypeSpec] | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec] - ), - _path=(), + fun: (Callable[[ts.TypeSpec], _T] | Callable[[ts.TypeSpec, tuple[int, ...]], _T]), + _path: tuple[int, ...] = (), *, - with_path_arg=False, - tuple_constructor=lambda *elements: ts.TupleType(types=[*elements]), -): + with_path_arg: bool = False, + tuple_constructor: TupleConstructorType[_R] = lambda *elements: ts.TupleType(types=[*elements]), # type: ignore[assignment] # probably related to https://github.com/python/mypy/issues/10854 +) -> _T | _R: """ Apply function to all primitive constituents of a type. @@ -280,9 +290,9 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: return is_floating_point(symbol_type) or is_integral(symbol_type) -def arithmetic_bounds(arithmetic_type: ts.ScalarType): +def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: assert is_arithmetic(arithmetic_type) - return { + return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), ts.ScalarKind.INT32: (np.iinfo(np.int32).min, np.iinfo(np.int32).max), @@ -532,8 +542,8 @@ def canonicalize_arguments( args: tuple | list, kwargs: dict, *, - ignore_errors=False, - use_signature_ordering=False, + ignore_errors: bool = False, + use_signature_ordering: bool = False, ) -> tuple[list, dict]: raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @@ -544,8 +554,8 @@ def canonicalize_function_arguments( args: tuple | list, kwargs: dict, *, - ignore_errors=False, - use_signature_ordering=False, + ignore_errors: bool = False, + use_signature_ordering: bool = False, ) -> tuple[list, dict]: num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) cargs = [UNDEFINED_ARG] * max(num_pos_params, len(args)) @@ -649,8 +659,8 @@ def function_signature_incompatibilities_func( args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec], *, - skip_canonicalization=False, - skip_structural_checks=False, + skip_canonicalization: bool = False, + skip_structural_checks: bool = False, ) -> Iterator[str]: if not skip_canonicalization: args, kwargs = canonicalize_arguments(func_type, args, kwargs, ignore_errors=True) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index f178a5752f..9487d2f12b 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -67,7 +67,7 @@ class OffsetType(TypeSpec): source: func_common.Dimension target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] - def __str__(self): + def __str__(self) -> str: return f"Offset[{self.source}, {self.target}]" @@ -85,7 +85,7 @@ class ScalarType(DataType): kind: ScalarKind shape: Optional[list[int]] = None - def __str__(self): + def __str__(self) -> str: kind_str = self.kind.name.lower() if self.shape is None: return kind_str @@ -96,7 +96,7 @@ def __str__(self): class TupleType(DataType): types: list[DataType] - def __str__(self): + def __str__(self) -> str: return f"tuple[{', '.join(map(str, self.types))}]" def __iter__(self) -> Iterator[DataType]: @@ -108,7 +108,7 @@ class FieldType(DataType, CallableType): dims: list[func_common.Dimension] dtype: ScalarType - def __str__(self): + def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" @@ -120,7 +120,7 @@ class FunctionType(TypeSpec, CallableType): kw_only_args: dict[str, DataType | DeferredType] returns: DataType | DeferredType | VoidType - def __str__(self): + def __str__(self) -> str: arg_strs = [str(arg) for arg in self.pos_only_args] kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 21932afd70..14caa1ae3e 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -43,12 +43,12 @@ class RecursionDetected(Exception): def __init__(self, obj: Any): self.obj = obj - def __enter__(self): + def __enter__(self) -> None: if id(self.obj) in self.guarded_objects: raise self.RecursionDetected() self.guarded_objects.add(id(self.obj)) - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: self.guarded_objects.remove(id(self.obj))