Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Mar 19, 2024
1 parent f350973 commit 3f6ed5e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ class FieldBuiltinFuncRegistry:
collections.ChainMap()
)

def __init_subclass__(cls, **kwargs: Any):
def __init_subclass__(cls, **kwargs: Any) -> None:
cls._builtin_func_map = collections.ChainMap(
{}, # New empty `dict` for new registrations on this class
*[
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def iterate_domain(
domain: common.Domain,
) -> Iterator[tuple[tuple[common.Dimension, int]]]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int]`
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]`


def _expand_ellipsis(
Expand Down
13 changes: 8 additions & 5 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:


@dataclasses.dataclass(frozen=True)
class ScanOperator(EmbeddedOperator[_R, _P]):
class ScanOperator(EmbeddedOperator[core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...], _P]):
forward: bool
init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]
init: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...]
axis: common.Dimension

def __call__( # type: ignore[override]
self,
*args: common.Field | core_defs.Scalar,
**kwargs: common.Field | core_defs.Scalar, # type: ignore[override]
) -> common.Field | tuple[common.Field | tuple, ...]:
) -> (
common.Field[Any, core_defs.ScalarT]
| tuple[common.Field[Any, core_defs.ScalarT] | tuple, ...]
):
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
Expand All @@ -65,13 +68,13 @@ def __call__( # type: ignore[override]
res = _construct_scan_array(out_domain, xp)(self.init)

def scan_loop(hpos: Sequence[common.NamedIndex]) -> None:
acc: _R = self.init # type: ignore[assignment] # `_R` not resolved?
acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init
for k in scan_range[1] if self.forward else reversed(scan_range[1]):
pos = (*hpos, (scan_axis, k))
new_args = [_tuple_at(pos, arg) for arg in args]
new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()}
acc = self.fun(acc, *new_args, **new_kwargs) # type: ignore[arg-type] # need to express that the first argument is the same type as the return
_tuple_assign_value(pos, res, acc) # type: ignore[arg-type] # requires more precise typing for `_R`
_tuple_assign_value(pos, res, acc)

if len(non_scan_domain) == 0:
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
Expand Down
19 changes: 10 additions & 9 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
import math
import sys
import warnings
from typing import (

import numpy as np
import numpy.typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import (
Any,
Callable,
Generic,
Expand All @@ -34,6 +40,7 @@
NoReturn,
Optional,
Protocol,
Self,
Sequence,
SupportsFloat,
SupportsInt,
Expand All @@ -45,12 +52,6 @@
overload,
runtime_checkable,
)

import numpy as np
import numpy.typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.next import common, embedded as next_embedded
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next.ffront import fbuiltins
Expand Down Expand Up @@ -1090,7 +1091,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> xtyping.Self:
def restrict(self, item: common.AnyIndexSpec) -> Self:
if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
d, r = item[0]
assert d == self._dimension
Expand Down Expand Up @@ -1209,7 +1210,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> xtyping.Self:
def restrict(self, item: common.AnyIndexSpec) -> Self:
# TODO set a domain...
return self

Expand Down

0 comments on commit 3f6ed5e

Please sign in to comment.