Skip to content

Commit

Permalink
cleanup out field construction
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Apr 19, 2024
1 parent a07d8ea commit e7195a5
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 74 deletions.
38 changes: 7 additions & 31 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar

import numpy as np

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, errors, utils
from gt4py.next import common, errors, field_utils, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.field_utils import get_array_ns
from gt4py.next.type_system import type_translation


_P = ParamSpec("_P")
Expand Down Expand Up @@ -64,8 +63,10 @@ def __call__( # type: ignore[override]
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

xp = _get_array_ns(*all_args)
res = _construct_scan_array(out_domain, xp)(self.init)
xp = get_array_ns(*all_args)
res = field_utils.field_from_typespec(out_domain, xp)(
type_translation.from_value(self.init)
)

def scan_loop(hpos: Sequence[common.NamedIndex]) -> None:
acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init
Expand Down Expand Up @@ -163,31 +164,6 @@ def _intersect_scan_args(
)


def _get_array_ns(
    *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> ModuleType:
    """Return the array namespace of the first argument that exposes one.

    The (possibly nested) arguments are flattened and inspected in order. The
    `array_ns` attribute of the first leaf that has such an attribute is
    returned; if no leaf has one, `numpy` is returned as the fallback.
    """
    leaves = utils.flatten_nested_tuple(args)
    return next((leaf.array_ns for leaf in leaves if hasattr(leaf, "array_ns")), np)


def _construct_scan_array(
    domain: common.Domain,
    xp: ModuleType,  # TODO(havogt) introduce a NDArrayNamespace protocol
) -> Callable[
    [core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]],
    common.MutableField | tuple[common.MutableField | tuple, ...],
]:
    """Return a function that allocates uninitialized output fields on `domain`.

    The returned callable is mapped over (possibly nested) tuples of initial
    values via `utils.tree_map`. For every leaf it allocates an `xp.empty`
    buffer of shape `domain.shape` whose dtype is the Python type of that
    leaf, and wraps it in a field on `domain`.
    """

    def impl(init: core_defs.Scalar) -> common.MutableField:
        buffer = xp.empty(domain.shape, dtype=type(init))
        field = common._field(buffer, domain=domain)
        assert isinstance(field, common.MutableField)
        return field

    return utils.tree_map(impl)


def _tuple_assign_value(
pos: Sequence[common.NamedIndex],
target: common.MutableField | tuple[common.MutableField | tuple, ...],
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,40 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from types import ModuleType
from typing import Callable

import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.next import common, utils
from gt4py.next.type_system import type_specifications as ts, type_translation


@utils.tree_map
def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
    """Convert a field to a NumPy array; non-field inputs are returned unchanged.

    Decorated with `utils.tree_map`, so it is applied to each leaf of
    (possibly nested) tuples.
    """
    if isinstance(field, common.Field):
        return field.asnumpy()
    return field


def field_from_typespec(
    domain: common.Domain, xp: ModuleType
) -> Callable[..., common.MutableField | tuple[common.MutableField | tuple, ...]]:
    """Return a function that allocates uninitialized fields for a type spec.

    The returned callable traverses `ts.TupleType` collections (producing
    plain `tuple`s) and, for every scalar type leaf, allocates an `xp.empty`
    buffer of shape `domain.shape` with the matching dtype and wraps it in a
    field on `domain`.
    """
    map_over_type_tree = utils.tree_map(
        collection_type=ts.TupleType, result_collection_type=tuple
    )

    def impl(type_: ts.ScalarType) -> common.MutableField:
        shape = domain.shape
        # Translate the scalar type spec into a dtype understood by the array namespace.
        array_dtype = xp.dtype(type_translation.as_dtype(type_).scalar_type)
        res = common._field(xp.empty(shape, dtype=array_dtype), domain=domain)
        assert isinstance(res, common.MutableField)
        return res

    return map_over_type_tree(impl)


def get_array_ns(
    *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> ModuleType:
    """Return the array namespace of the first argument that exposes one.

    The (possibly nested) arguments are flattened and inspected in order. The
    `array_ns` attribute of the first leaf that has such an attribute is
    returned; if no leaf has one, `numpy` is returned as the fallback.
    """
    leaves = utils.flatten_nested_tuple(args)
    return next((leaf.array_ns for leaf in leaves if hasattr(leaf, "array_ns")), np)
73 changes: 43 additions & 30 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@
overload,
runtime_checkable,
)
from gt4py.next import common
from gt4py.next import common, field_utils
from gt4py.next.embedded import (
context as embedded_context,
exceptions as embedded_exceptions,
operators,
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime
from gt4py.next.type_system import type_specifications as ts, type_translation


EMBEDDED = "embedded"
Expand Down Expand Up @@ -210,6 +211,10 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data)
)

# Read-only accessor: exposes the dtype of `self.data` as the dtype of this object.
@property
def dtype(self) -> np.dtype:
return self.data.dtype

def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
# numpy type
Expand Down Expand Up @@ -792,7 +797,7 @@ def _make_tuple(
raise RuntimeError(
"Found 'Undefined' value, this should not happen for a legal program."
)
dtype = _column_dtype(first)
dtype = _elem_dtype(first)
return Column(column_range.start, np.asarray(col, dtype=dtype))


Expand Down Expand Up @@ -1472,11 +1477,12 @@ def as_tuple_field(field: tuple | TupleField) -> TupleField:
return TupleOfFields(tuple(_wrap_field(f) for f in field))


def _column_dtype(elem: Any) -> np.dtype:
def _elem_dtype(elem: Any) -> np.dtype:
if hasattr(elem, "dtype"):
return elem.dtype
if isinstance(elem, tuple):
return np.dtype([(f"f{i}", _column_dtype(e)) for i, e in enumerate(elem)])
else:
return np.dtype(type(elem))
return np.dtype([(f"f{i}", _elem_dtype(e)) for i, e in enumerate(elem)])
return np.dtype(type(elem))


@builtins.scan.register(EMBEDDED)
Expand All @@ -1488,7 +1494,7 @@ def impl(*iters: ItIterator):

sorted_column_range = column_range if is_forward else reversed(column_range)
state = init
col = Column(column_range.start, np.zeros(len(column_range), dtype=_column_dtype(init)))
col = Column(column_range.start, np.zeros(len(column_range), dtype=_elem_dtype(init)))
for i in sorted_column_range:
state = scan_pass(state, *map(shifted_scan_arg(i), iters))
col[i] = state
Expand Down Expand Up @@ -1547,36 +1553,43 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType:
return eve.NOTHING


# TODO handle in clean way
def _np_void_to_tuple(a):
if isinstance(a, np.void):
return tuple(_np_void_to_tuple(elem) for elem in a)
return a
def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType:
    """Translate a (possibly structured) NumPy dtype into a type spec.

    A dtype with named fields becomes a `ts.TupleType` whose element types are
    the recursively translated field dtypes, in the order of `names`. A dtype
    without named fields is translated via `type_translation.from_dtype`.
    """
    field_names = structured_dtype.names
    if field_names is not None:
        element_types = [
            _structured_dtype_to_typespec(structured_dtype[name]) for name in field_names
        ]
        return ts.TupleType(types=element_types)
    return type_translation.from_dtype(core_defs.dtype(structured_dtype))


# Determine the type spec of the values produced by `fun` by evaluating it at a
# single position of the domain and translating the dtype of that result.
def _get_output_type(
fun: Callable,
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
args: tuple[Any, ...],
) -> ts.TypeSpec:
domain = _dimension_to_tag(domain_)
col_range = _extract_column_range(domain)
# If a column range was found, its dimension entry is deleted from `domain`;
# the range itself is passed to the context below as `closure_column_range`.
if isinstance(col_range, common.NamedRange):
del domain[col_range.dim.value]

# Only the first position is evaluated; `next` raises StopIteration if the iterator yields nothing.
pos = next(iter(_domain_iterator(domain)))
with embedded_context.new_context(closure_column_range=col_range) as ctx:
single_point_result = ctx.run(_compute_point, fun, args, pos, col_range)
# The element dtype of the single-point result determines the returned type spec.
dtype = _elem_dtype(single_point_result)
return _structured_dtype_to_typespec(dtype)


@builtins.as_fieldop.register(EMBEDDED)
def as_fieldop(fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain):
def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.UnstructuredDomain):
def impl(*args):
# TODO extract function, move private utils
domain = _dimension_to_tag(domain_)
col_range = _extract_column_range(domain)
if col_range is not eve.NOTHING:
del domain[col_range.dim.value]

pos = next(_domain_iterator(domain))
with embedded_context.new_context(closure_column_range=col_range) as ctx:
single_point_result = ctx.run(_compute_point, fun, args, pos, col_range)
if isinstance(single_point_result, Column):
single_point_result = single_point_result.data[0]
single_point_result = _np_void_to_tuple(single_point_result)

xp = operators._get_array_ns(*args)
out = operators._construct_scan_array(common.domain(domain_), xp)(single_point_result)
xp = field_utils.get_array_ns(*args)
type_ = _get_output_type(fun, domain, args)
out = field_utils.field_from_typespec(common.domain(domain), xp)(type_)

# TODO `out` gets allocated in the order of domain_, but might not match the order of `target` in set_at

closure(
_dimension_to_tag(domain_),
_dimension_to_tag(domain),
fun,
out,
list(args),
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def __str__(self) -> str:
# Iterate over the element types stored in `self.types`, in order.
def __iter__(self) -> Iterator[DataType]:
yield from self.types

# Number of element types, i.e. the length of `self.types`.
def __len__(self) -> int:
return len(self.types)


@dataclass(frozen=True)
class FieldType(DataType, CallableType):
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import numpy.typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.next import common
from gt4py.next.type_system import type_info, type_specifications as ts
Expand Down Expand Up @@ -219,3 +220,31 @@ def from_value(value: Any) -> ts.TypeSpec:
return symbol_type
else:
raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.")


def as_dtype(type_: ts.ScalarType) -> core_defs.DType:
    """Translate a scalar type spec into the corresponding `core_defs` dtype.

    Supported kinds are BOOL, INT32, INT64, FLOAT32 and FLOAT64.

    Raises:
        ValueError: If the kind of `type_` is not one of the supported kinds.
    """
    # Ordered (kind, dtype class) pairs; the dtype is only instantiated on a match.
    kind_to_dtype = (
        (ts.ScalarKind.BOOL, core_defs.BoolDType),
        (ts.ScalarKind.INT32, core_defs.Int32DType),
        (ts.ScalarKind.INT64, core_defs.Int64DType),
        (ts.ScalarKind.FLOAT32, core_defs.Float32DType),
        (ts.ScalarKind.FLOAT64, core_defs.Float64DType),
    )
    for kind, dtype_cls in kind_to_dtype:
        if type_.kind == kind:
            return dtype_cls()
    raise ValueError(f"Scalar type '{type_}' not supported.")


def from_dtype(dtype: core_defs.DType) -> ts.ScalarType:
    """Translate a `core_defs` dtype into the corresponding scalar type spec.

    Supported dtypes are the bool, int32, int64, float32 and float64 dtypes.

    Raises:
        ValueError: If `dtype` does not compare equal to any supported dtype.
    """
    # Ordered (dtype class, kind) pairs; each candidate dtype is instantiated
    # only when it is compared, in the order listed.
    dtype_to_kind = (
        (core_defs.BoolDType, ts.ScalarKind.BOOL),
        (core_defs.Int32DType, ts.ScalarKind.INT32),
        (core_defs.Int64DType, ts.ScalarKind.INT64),
        (core_defs.Float32DType, ts.ScalarKind.FLOAT32),
        (core_defs.Float64DType, ts.ScalarKind.FLOAT64),
    )
    for dtype_cls, kind in dtype_to_kind:
        if dtype == dtype_cls():
            return ts.ScalarType(kind=kind)
    raise ValueError(f"DType '{dtype}' not supported.")
76 changes: 63 additions & 13 deletions src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import functools
from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast
from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload


class RecursionGuard:
Expand Down Expand Up @@ -69,26 +69,76 @@ def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]:
return (value,)


def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]:
@overload
def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, ...]]: ...


@overload
def tree_map(
*, collection_type: type | tuple[type, ...], result_collection_type: Optional[type] = None
) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ...


def tree_map(
*args: Callable[_P, _R],
collection_type: type | tuple[type, ...] = tuple,
result_collection_type: Optional[type] = None,
) -> (
Callable[..., _R | tuple[_R | tuple, ...]]
| Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]
):
"""
Apply `fun` to each entry of (possibly nested) tuples.
Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s).
Args:
fun: Function to apply to each entry of the collection.
collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types.
result_collection_type: Type of the collection to be returned. If `None` the same type as `collection_type` is used.
Examples:
>>> tree_map(lambda x: x + 1)(((1, 2), 3))
((2, 3), 4)
>>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6))
((5, 7), 9)
"""
@functools.wraps(fun)
def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
if isinstance(args[0], tuple):
assert all(isinstance(arg, tuple) and len(args[0]) == len(arg) for arg in args)
return tuple(impl(*arg) for arg in zip(*args))
>>> tree_map(collection_type=list)(lambda x: x + 1)([[1, 2], 3])
[[2, 3], 4]
return fun(
*cast(_P.args, args)
) # mypy doesn't understand that `args` at this point is of type `_P.args`
>>> tree_map(collection_type=list, result_collection_type=tuple)(lambda x: x + 1)([[1, 2], 3])
((2, 3), 4)
"""

return impl
if result_collection_type is None:
if isinstance(collection_type, tuple):
raise TypeError(
"tree_map() requires `result_collection_type` when `collection_type` is a tuple."
)
result_collection_type = collection_type

if len(args) == 1:
fun = args[0]

@functools.wraps(fun)
def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
if isinstance(args[0], collection_type):
assert all(
isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args
)
assert result_collection_type is not None
return result_collection_type(impl(*arg) for arg in zip(*args))

return fun(
*cast(_P.args, args)
) # mypy doesn't understand that `args` at this point is of type `_P.args`

return impl
if len(args) == 0:
return functools.partial(
tree_map,
collection_type=collection_type,
result_collection_type=result_collection_type,
)
raise TypeError(
"tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_type`."
)
Loading

0 comments on commit e7195a5

Please sign in to comment.