diff --git a/loopy/check.py b/loopy/check.py index ee24d6e4b..1c552eae1 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -68,7 +68,7 @@ check_each_kernel, ) from loopy.type_inference import TypeReader -from loopy.typing import ExpressionT, not_none +from loopy.typing import Expression, not_none logger = logging.getLogger(__name__) @@ -221,7 +221,7 @@ def check_offsets_and_dim_tags(kernel: LoopKernel) -> None: dep_mapper = DependencyMapper() def ensure_depends_only_on_arguments( - what: str, expr: Union[str, ExpressionT]) -> None: + what: str, expr: Union[str, Expression]) -> None: if isinstance(expr, str): expr = Variable(expr) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 2e39d89bd..389ecd042 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -57,7 +57,7 @@ from loopy.target import TargetBase from loopy.tools import LoopyKeyBuilder, caches from loopy.types import LoopyType -from loopy.typing import ExpressionT +from loopy.typing import Expression from loopy.version import DATA_MODEL_VERSION @@ -200,14 +200,14 @@ class CodeGenerationState: kernel: LoopKernel target: TargetBase implemented_domain: isl.Set - implemented_predicates: FrozenSet[Union[str, ExpressionT]] + implemented_predicates: FrozenSet[Union[str, Expression]] # /!\ mutable seen_dtypes: Set[LoopyType] seen_functions: Set[SeenFunction] seen_atomic_dtypes: Set[LoopyType] - var_subst_map: Map[str, ExpressionT] + var_subst_map: Map[str, Expression] allow_complex: bool callables_table: CallablesTable is_entrypoint: bool @@ -231,7 +231,7 @@ def copy(self, **kwargs: Any) -> "CodeGenerationState": return replace(self, **kwargs) def copy_and_assign( - self, name: str, value: ExpressionT) -> "CodeGenerationState": + self, name: str, value: Expression) -> "CodeGenerationState": """Make a copy of self with variable *name* fixed to *value*.""" return self.copy(var_subst_map=self.var_subst_map.set(name, value)) diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 28aa3be30..9fbb3c9d7 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -93,12 +93,12 @@ def make_slab(space, iname, start, stop, iname_multiplier=1): space = zero.get_domain_space() - from pymbolic.primitives import Expression + from pymbolic.primitives import ExpressionNode from loopy.symbolic import aff_from_expr - if isinstance(start, Expression): + if isinstance(start, ExpressionNode): start = aff_from_expr(space, start) - if isinstance(stop, Expression): + if isinstance(stop, ExpressionNode): stop = aff_from_expr(space, stop) if isinstance(start, int): diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 967640260..4f392edd6 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -49,7 +49,7 @@ import islpy as isl from islpy import dim_type -from pymbolic import ArithmeticExpressionT +from pymbolic import ArithmeticExpression from pytools import ( UniqueNameGenerator, generate_unique_names, @@ -75,7 +75,7 @@ from loopy.target import TargetBase from loopy.tools import update_persistent_hash from loopy.types import LoopyType, NumpyType -from loopy.typing import ExpressionT, InameStr +from loopy.typing import Expression, InameStr if TYPE_CHECKING: @@ -193,7 +193,7 @@ class LoopKernel(Taggable): with non-parallel implementation tags. """ - applied_iname_rewrites: Tuple[Dict[InameStr, ExpressionT], ...] = () + applied_iname_rewrites: Tuple[Dict[InameStr, Expression], ...] = () """ A list of past substitution dictionaries that were applied to the kernel. These are stored so that they may be repeated @@ -1036,8 +1036,8 @@ def get_grid_size_upper_bounds_as_exprs( self, callables_table, ignore_auto=False, return_dict=False ) -> Tuple[ - Tuple[ArithmeticExpressionT, ...], - Tuple[ArithmeticExpressionT, ...]]: + Tuple[ArithmeticExpression, ...], + Tuple[ArithmeticExpression, ...]]: """Return a tuple (global_size, local_size) containing a grid that could accommodate execution of *all* instructions in the kernel. diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 8cabbec23..220477c30 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -45,7 +45,7 @@ import numpy as np # noqa from typing_extensions import TypeAlias -from pymbolic import ArithmeticExpressionT +from pymbolic import ArithmeticExpression from pymbolic.primitives import is_arithmetic_expression from pytools import ImmutableRecord from pytools.tag import Tag, Taggable @@ -53,7 +53,7 @@ from loopy.diagnostic import LoopyError from loopy.symbolic import flatten from loopy.types import LoopyType -from loopy.typing import ExpressionT, ShapeType, auto, is_integer +from loopy.typing import Expression, ShapeType, auto, is_integer if TYPE_CHECKING: @@ -609,8 +609,8 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes, # {{{ array base class (for arguments and temporary arrays) -ToShapeLikeConvertible: TypeAlias = (Tuple[ExpressionT | str, ...] - | ExpressionT | type[auto] | str | tuple[str, ...]) +ToShapeLikeConvertible: TypeAlias = (Tuple[Expression | str, ...] + | Expression | type[auto] | str | tuple[str, ...]) def _parse_shape_or_strides( @@ -634,12 +634,12 @@ def _parse_shape_or_strides( raise ValueError("shape can't be a list") if isinstance(x_parsed, tuple): - x_tup: tuple[ExpressionT | str, ...] = x_parsed + x_tup: tuple[Expression | str, ...] = x_parsed else: assert x_parsed is not auto - x_tup = (cast(ExpressionT, x_parsed),) + x_tup = (cast(Expression, x_parsed),) - def parse_arith(x: ExpressionT | str) -> ArithmeticExpressionT: + def parse_arith(x: Expression | str) -> ArithmeticExpression: if isinstance(x, str): res = parse(x) else: @@ -714,7 +714,7 @@ class ArrayBase(ImmutableRecord, Taggable): """See :ref:`data-dim-tags`. """ - offset: Union[ExpressionT, str, None] + offset: Union[Expression, str, None] """Offset from the beginning of the buffer to the point from which the strides are counted, in units of the :attr:`dtype`. May be one of @@ -1158,9 +1158,9 @@ def drop_vec_dims( if not isinstance(dim_tag, VectorArrayDimTag)) -def get_strides(array: ArrayBase) -> Tuple[ExpressionT, ...]: +def get_strides(array: ArrayBase) -> Tuple[Expression, ...]: from pymbolic import var - result: List[ExpressionT] = [] + result: List[Expression] = [] if array.dim_tags is None: return () @@ -1188,10 +1188,10 @@ def get_strides(array: ArrayBase) -> Tuple[ExpressionT, ...]: class AccessInfo(ImmutableRecord): array_name: str vector_index: Optional[int] - subscripts: Tuple[ExpressionT, ...] + subscripts: Tuple[Expression, ...] -def _apply_offset(sub: ExpressionT, ary: ArrayBase) -> ExpressionT: +def _apply_offset(sub: Expression, ary: ArrayBase) -> Expression: """ Helper for :func:`get_access_info`. Augments *ary*'s subscript index expression (*sub*) with its offset info. @@ -1228,8 +1228,8 @@ def _apply_offset(sub: ExpressionT, ary: ArrayBase) -> ExpressionT: def get_access_info(kernel: "LoopKernel", ary: Union["ArrayArg", "TemporaryVariable"], - index: Union[ExpressionT, Tuple[ExpressionT, ...]], - eval_expr: Callable[[ExpressionT], int], + index: Union[Expression, Tuple[Expression, ...]], + eval_expr: Callable[[Expression], int], vectorization_info: "VectorizationInfo") -> AccessInfo: """ :arg ary: an object of type :class:`ArrayBase` @@ -1283,7 +1283,7 @@ def eval_expr_assert_integer_constant(i, expr) -> int: num_target_axes = ary.num_target_axes() vector_index = None - subscripts: List[ExpressionT] = [0] * num_target_axes + subscripts: List[Expression] = [0] * num_target_axes vector_size = ary.vector_size(kernel.target) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 5d1de0e5d..8ca5aa87a 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -45,7 +45,7 @@ import numpy as np from immutables import Map -from pymbolic import ArithmeticExpressionT, Variable +from pymbolic import ArithmeticExpression, Variable from pytools import ImmutableRecord from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase @@ -65,7 +65,7 @@ make_assignment, ) from loopy.types import LoopyType, ToLoopyTypeConvertible -from loopy.typing import ExpressionT, ShapeType, auto +from loopy.typing import Expression, ShapeType, auto __doc__ = """ @@ -103,7 +103,7 @@ # {{{ utilities -def _names_from_expr(expr: Union[None, ExpressionT, str]) -> FrozenSet[str]: +def _names_from_expr(expr: Union[None, Expression, str]) -> FrozenSet[str]: from numbers import Number from loopy.symbolic import DependencyMapper @@ -651,7 +651,7 @@ class TemporaryVariable(ArrayBase): """ storage_shape: Optional[ShapeType] - base_indices: Optional[Tuple[ExpressionT, ...]] + base_indices: Optional[Tuple[Expression, ...]] address_space: Union[AddressSpace, Type[auto]] base_storage: Optional[str] """The name of a storage array that is to be used to actually @@ -698,12 +698,12 @@ def __init__( shape: Union[ShapeType, Type["auto"], None] = auto, address_space: Union[AddressSpace, Type[auto], None] = None, dim_tags: Optional[Sequence[ArrayDimImplementationTag]] = None, - offset: Union[ExpressionT, str, None] = 0, + offset: Union[Expression, str, None] = 0, dim_names: Optional[Tuple[str, ...]] = None, - strides: Optional[Tuple[ExpressionT, ...]] = None, + strides: Optional[Tuple[Expression, ...]] = None, order: str | None = None, - base_indices: Optional[Tuple[ExpressionT, ...]] = None, + base_indices: Optional[Tuple[Expression, ...]] = None, storage_shape: ShapeType | None = None, base_storage: Optional[str] = None, @@ -809,7 +809,7 @@ def copy(self, **kwargs: Any) -> TemporaryVariable: return super().copy(**kwargs) @property - def nbytes(self) -> ExpressionT: + def nbytes(self) -> Expression: if self.storage_shape is not None: shape = self.storage_shape else: @@ -817,7 +817,7 @@ def nbytes(self) -> ExpressionT: raise ValueError("shape is None") if self.shape is auto: raise ValueError("shape is auto") - shape = cast(Tuple[ArithmeticExpressionT], self.shape) + shape = cast(Tuple[ArithmeticExpression], self.shape) if self.dtype is None: raise ValueError("data type is indeterminate") @@ -898,7 +898,7 @@ class SubstitutionRule: name: str arguments: Sequence[str] - expression: ExpressionT + expression: Expression def copy(self, **kwargs: Any) -> SubstitutionRule: return replace(self, **kwargs) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 51d4856da..5902e2579 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -37,7 +37,7 @@ from loopy.diagnostic import LoopyError from loopy.tools import Optional as LoopyOptional from loopy.types import LoopyType -from loopy.typing import ExpressionT, InameStr +from loopy.typing import Expression, InameStr # {{{ instruction tags @@ -250,7 +250,7 @@ class InstructionBase(ImmutableRecord, Taggable): groups: FrozenSet[str] conflicts_with_groups: FrozenSet[str] no_sync_with: FrozenSet[Tuple[str, str]] - predicates: FrozenSet[ExpressionT] + predicates: FrozenSet[Expression] within_inames: FrozenSet[InameStr] within_inames_is_final: bool priority: int @@ -901,8 +901,8 @@ class Assignment(MultiAssignmentBase): .. automethod:: __init__ """ - assignee: ExpressionT - expression: ExpressionT + assignee: Expression + expression: Expression temp_var_type: LoopyOptional atomicity: Tuple[VarAtomicity, ...] @@ -910,8 +910,8 @@ class Assignment(MultiAssignmentBase): set("assignee temp_var_type atomicity".split()) def __init__(self, - assignee: Union[str, ExpressionT], - expression: Union[str, ExpressionT], + assignee: Union[str, Expression], + expression: Union[str, Expression], id: Optional[str] = None, happens_after: Union[ Mapping[str, HappensAfter], FrozenSet[str], str, None] = None, @@ -1271,8 +1271,8 @@ def modify_assignee_for_array_call(assignee): "SubArrayRef as its inputs") -def make_assignment(assignees: tuple[ExpressionT, ...], - expression: ExpressionT, +def make_assignment(assignees: tuple[Expression, ...], + expression: Expression, temp_var_types: ( Sequence[LoopyType | None] | None) = None, **kwargs: Any) -> Assignment | CallInstruction: diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 3293e9a1e..4dc824ead 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -68,7 +68,7 @@ # for the benefit of loopy.statistics, for now from loopy.type_inference import infer_unknown_types -from loopy.typing import ExpressionT +from loopy.typing import Expression # {{{ check for writes to predicates @@ -174,14 +174,14 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel: sep_axis_indices_set = frozenset(sep_axis_indices) assert isinstance(arg.shape, tuple) - new_shape: Optional[Tuple[ExpressionT, ...]] = \ + new_shape: Optional[Tuple[Expression, ...]] = \ _remove_at_indices(sep_axis_indices_set, arg.shape) new_dim_tags: Optional[Tuple[ArrayDimImplementationTag, ...]] = \ _remove_at_indices(sep_axis_indices_set, arg.dim_tags) new_dim_names: Optional[Tuple[Optional[str], ...]] = \ _remove_at_indices(sep_axis_indices_set, arg.dim_names) - sep_shape: List[ExpressionT] = [arg.shape[i] for i in sep_axis_indices] + sep_shape: List[Expression] = [arg.shape[i] for i in sep_axis_indices] for i, sep_shape_i in enumerate(sep_shape): if not isinstance(sep_shape_i, (int, np.integer)): raise LoopyError( diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ad502e1a5..f7d10b9d5 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -47,7 +47,7 @@ import pymbolic.primitives as p import pytools.lex from islpy import dim_type -from pymbolic import ArithmeticExpressionT, Variable +from pymbolic import ArithmeticExpression, Variable from pymbolic.mapper import ( CachedCombineMapper as CombineMapperBase, CachedIdentityMapper as IdentityMapperBase, @@ -81,7 +81,7 @@ UnableToDetermineAccessRangeError, ) from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible -from loopy.typing import ExpressionT, auto +from loopy.typing import Expression, auto if TYPE_CHECKING: @@ -129,7 +129,7 @@ # {{{ mappers with support for loopy-specific primitives -class IdentityMapperMixin(Mapper[ExpressionT, P]): +class IdentityMapperMixin(Mapper[Expression, P]): def map_literal(self, expr: Literal, *args, **kwargs): return expr @@ -206,7 +206,7 @@ def map_resolved_function(self, expr, *args, **kwargs): class FlattenMapper(FlattenMapperBase, IdentityMapperMixin): # FIXME: Lies! This needs to be made precise. - def is_expr_integer_valued(self, expr: ExpressionT) -> bool: + def is_expr_integer_valued(self, expr: Expression) -> bool: return True @@ -505,7 +505,7 @@ def map_substitution(self, name, rule, arguments): # {{{ loopy-specific primitives -class LoopyExpressionBase(p.Expression): +class LoopyExpressionBase(p.ExpressionNode): def stringifier(self): from loopy.diagnostic import LoopyError raise LoopyError("pymbolic < 2019.1 is in use. Please upgrade.") @@ -539,7 +539,7 @@ class ArrayLiteral(LoopyExpressionBase): similar mappers). Not for use in Loopy source representation. """ - children: tuple[ExpressionT, ...] + children: tuple[Expression, ...] @p.expr_dataclass() @@ -602,7 +602,7 @@ class TypeAnnotation(LoopyExpressionBase): """ type: LoopyType - child: ExpressionT + child: Expression @p.expr_dataclass(init=False) @@ -618,10 +618,10 @@ class TypeCast(LoopyExpressionBase): # numpy pickling bug madness. (see loopy.types) _type_name: str - child: ExpressionT + child: Expression """The expression to be cast.""" - def __init__(self, type: ToLoopyTypeConvertible, child: ExpressionT): + def __init__(self, type: ToLoopyTypeConvertible, child: Expression): super().__init__() from loopy.types import NumpyType, to_loopy_type @@ -700,7 +700,7 @@ class Reduction(LoopyExpressionBase): carried out. """ - expr: ExpressionT + expr: Expression """An expression which may have tuple type. If the expression has tuple type, it must be one of the following: @@ -718,7 +718,7 @@ def __init__(self, operation: ReductionOperation | str, inames: (tuple[str | pymbolic.primitives.Variable, ...] | pymbolic.primitives.Variable | str), - expr: ExpressionT, + expr: Expression, allow_simultaneous: bool = False ) -> None: if isinstance(inames, str): @@ -780,8 +780,8 @@ class LinearSubscript(LoopyExpressionBase): """Represents a linear index into a multi-dimensional array, completely ignoring any multi-dimensional layout. """ - aggregate: ExpressionT - index: ExpressionT + aggregate: Expression + index: Expression @p.expr_dataclass() @@ -966,11 +966,11 @@ def _get_dependencies_and_reduction_inames(expr): return deps, reduction_inames -def get_dependencies(expr: ExpressionT | type[auto]) -> AbstractSet[str]: +def get_dependencies(expr: Expression | type[auto]) -> AbstractSet[str]: return _get_dependencies_and_reduction_inames(expr)[0] -def get_reduction_inames(expr: ExpressionT) -> AbstractSet[str]: +def get_reduction_inames(expr: Expression) -> AbstractSet[str]: return _get_dependencies_and_reduction_inames(expr)[1] @@ -1255,9 +1255,9 @@ def map_call(self, expr, expn_state, *args, **kwargs): def make_new_arg_context( rule_name: str, arg_names: Sequence[str], - arguments: Sequence[ExpressionT], - arg_context: Mapping[str, ExpressionT] - ) -> Mapping[str, ExpressionT]: + arguments: Sequence[Expression], + arg_context: Mapping[str, Expression] + ) -> Mapping[str, Expression]: if len(arg_names) != len(arguments): raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" % (rule_name, len(arguments), len(arg_names), )) @@ -1709,7 +1709,7 @@ def map_subscript(self, expr): # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ArithmeticExpressionT: +def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression: from pymbolic import var denom = aff.get_denominator_val().to_python() @@ -1730,7 +1730,7 @@ def aff_to_expr(aff: isl.Aff) -> ArithmeticExpressionT: return flatten(result // denom) -def pw_aff_to_expr(pw_aff: isl.PwAff, int_ok: bool = False) -> ExpressionT: +def pw_aff_to_expr(pw_aff: isl.PwAff, int_ok: bool = False) -> Expression: if isinstance(pw_aff, int): if not int_ok: warn("expected PwAff, got int", stacklevel=2) @@ -1844,7 +1844,7 @@ def map_call(self, expr): "for as-pwaff evaluation") -def aff_from_expr(space: isl.Space, expr: ExpressionT, vars_to_zero=None) -> isl.Aff: +def aff_from_expr(space: isl.Space, expr: Expression, vars_to_zero=None) -> isl.Aff: if vars_to_zero is None: vars_to_zero = frozenset() diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index 9f227bd37..a2961eee9 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -64,7 +64,7 @@ from loopy.tools import remove_common_indentation from loopy.translation_unit import FunctionIdT, TranslationUnit from loopy.types import LoopyType, NumpyType, to_loopy_type -from loopy.typing import ExpressionT, auto +from loopy.typing import Expression, auto __doc__ = """ @@ -880,8 +880,8 @@ def get_function_declaration( def get_kernel_call(self, codegen_state: CodeGenerationState, subkernel_name: str, - gsize: Tuple[ExpressionT, ...], - lsize: Tuple[ExpressionT, ...]) -> Optional[Generable]: + gsize: Tuple[Expression, ...], + lsize: Tuple[Expression, ...]) -> Optional[Generable]: return None def emit_temp_var_decl_for_tv_with_base_storage(self, diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 0c15faa58..e201326a5 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -48,7 +48,7 @@ from loopy.target.c import CExpression from loopy.type_inference import TypeReader from loopy.types import LoopyType -from loopy.typing import ExpressionT, is_integer +from loopy.typing import Expression, is_integer __doc__ = """ @@ -92,7 +92,7 @@ def with_assignments(self, names_to_vars): type_inf_mapper = self.type_inf_mapper.with_assignments(names_to_vars) return type(self)(self.codegen_state, self.fortran_abi, type_inf_mapper) - def infer_type(self, expr: ExpressionT) -> LoopyType: + def infer_type(self, expr: Expression) -> LoopyType: result = self.type_inf_mapper(expr) assert isinstance(result, LoopyType) diff --git a/loopy/target/execution.py b/loopy/target/execution.py index 2443a1420..eaeb76b4a 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -58,7 +58,7 @@ from loopy.tools import LoopyKeyBuilder, caches from loopy.translation_unit import TranslationUnit from loopy.types import LoopyType, NumpyType -from loopy.typing import ExpressionT, integer_expr_or_err +from loopy.typing import Expression, integer_expr_or_err from loopy.version import DATA_MODEL_VERSION @@ -109,7 +109,7 @@ def __call__(self, kernel_kwargs: Dict[str, Any]) -> Dict[str, Any]: # {{{ ExecutionWrapperGeneratorBase -def _str_to_expr(name_or_expr: Union[str, ExpressionT]) -> ExpressionT: +def _str_to_expr(name_or_expr: Union[str, Expression]) -> Expression: if isinstance(name_or_expr, str): return var(name_or_expr) else: @@ -118,8 +118,8 @@ def _str_to_expr(name_or_expr: Union[str, ExpressionT]) -> ExpressionT: @dataclass(frozen=True) class _ArgFindingEquation: - lhs: ExpressionT - rhs: ExpressionT + lhs: Expression + rhs: Expression # Arg finding code is sorted by priority, all equations (across all unknowns) # of lowest priority first. @@ -389,7 +389,7 @@ def handle_non_numpy_arg(self, gen: CodeGenerator, arg): def handle_alloc( self, gen: CodeGenerator, arg: ArrayArg, - strify: Callable[[Union[ExpressionT, Tuple[ExpressionT]]], str], + strify: Callable[[Union[Expression, Tuple[Expression]]], str], skip_arg_checks: bool) -> None: """ Handle allocation of non-specified arguments for C-execution @@ -534,7 +534,7 @@ def strify_allowing_none(shape_axis): else: return strify(shape_axis) - def strify_tuple(t: Optional[Tuple[ExpressionT, ...]]) -> str: + def strify_tuple(t: Optional[Tuple[Expression, ...]]) -> str: if t is None: return "None" if len(t) == 0: diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py index 1cd7a5bd2..4200a4b24 100644 --- a/loopy/target/ispc.py +++ b/loopy/target/ispc.py @@ -43,7 +43,7 @@ from loopy.target.c import CFamilyASTBuilder, CFamilyTarget from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper from loopy.types import LoopyType -from loopy.typing import ExpressionT +from loopy.typing import Expression # {{{ expression mapper @@ -252,8 +252,8 @@ def get_function_declaration( def get_kernel_call(self, codegen_state: CodeGenerationState, subkernel_name: str, - gsize: Tuple[ExpressionT, ...], - lsize: Tuple[ExpressionT, ...]) -> Generable: + gsize: Tuple[Expression, ...], + lsize: Tuple[Expression, ...]) -> Generable: kernel = codegen_state.kernel ecm = self.get_expression_to_code_mapper(codegen_state) diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index e4da6cd8b..fa7fd20e8 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -67,7 +67,7 @@ from loopy.target.python import PythonASTBuilderBase from loopy.translation_unit import FunctionIdT, TranslationUnit from loopy.types import NumpyType -from loopy.typing import ExpressionT +from loopy.typing import Expression logger = logging.getLogger(__name__) @@ -855,7 +855,7 @@ def get_temporary_decls(self, codegen_state, schedule_index): def get_kernel_call( self, codegen_state: CodeGenerationState, subkernel_name: str, - gsize: Tuple[ExpressionT, ...], lsize: Tuple[ExpressionT, ...] + gsize: Tuple[Expression, ...], lsize: Tuple[Expression, ...] ) -> genpy.Suite: from genpy import Assert, Assign, Comment, Line, Suite diff --git a/loopy/target/pyopencl_execution.py b/loopy/target/pyopencl_execution.py index be859ab70..248f5f2eb 100644 --- a/loopy/target/pyopencl_execution.py +++ b/loopy/target/pyopencl_execution.py @@ -37,7 +37,7 @@ from loopy.schedule.tools import KernelArgInfo from loopy.target.execution import ExecutionWrapperGeneratorBase, ExecutorBase from loopy.types import LoopyType -from loopy.typing import ExpressionT, integer_expr_or_err +from loopy.typing import Expression, integer_expr_or_err logger = logging.getLogger(__name__) @@ -109,7 +109,7 @@ def handle_non_numpy_arg(self, gen: CodeGenerator, arg: ArrayArg) -> None: def handle_alloc( self, gen: CodeGenerator, arg: ArrayArg, - strify: Callable[[ExpressionT], str], + strify: Callable[[Expression], str], skip_arg_checks: bool) -> None: """ Handle allocation of non-specified arguments for pyopencl execution diff --git a/loopy/transform/array_buffer_map.py b/loopy/transform/array_buffer_map.py index 81b5c933f..5e8e56234 100644 --- a/loopy/transform/array_buffer_map.py +++ b/loopy/transform/array_buffer_map.py @@ -29,12 +29,12 @@ import islpy as isl from islpy import dim_type -from pymbolic import ArithmeticExpressionT, var +from pymbolic import ArithmeticExpression, var from pymbolic.mapper.substitutor import make_subst_func from pytools import memoize_method from loopy.symbolic import SubstitutionMapper, get_dependencies -from loopy.typing import ExpressionT +from loopy.typing import Expression @dataclass(frozen=True) @@ -47,7 +47,7 @@ class AccessDescriptor: """ identifier: Any = None - storage_axis_exprs: Optional[Sequence[ArithmeticExpressionT]] = None + storage_axis_exprs: Optional[Sequence[ArithmeticExpression]] = None def copy(self, **kwargs) -> Self: return replace(self, **kwargs) @@ -72,10 +72,10 @@ def to_parameters_or_project_out(param_inames, set_inames, set): # {{{ construct storage->sweep map def build_per_access_storage_to_domain_map( - storage_axis_exprs: Sequence[ExpressionT], + storage_axis_exprs: Sequence[Expression], domain: isl.BasicSet, storage_axis_names: Sequence[str], - prime_sweep_inames: Callable[[ExpressionT], ExpressionT] + prime_sweep_inames: Callable[[Expression], Expression] ) -> isl.BasicMap: map_space = domain.space @@ -204,9 +204,9 @@ def compute_bounds(kernel, domain, stor2sweep, class ArrayToBufferMapBase(ABC): non1_storage_axis_names: Tuple[str, ...] - storage_base_indices: Tuple[ArithmeticExpressionT, ...] - non1_storage_shape: Tuple[ArithmeticExpressionT, ...] - non1_storage_axis_flags: Tuple[ArithmeticExpressionT, ...] + storage_base_indices: Tuple[ArithmeticExpression, ...] + non1_storage_shape: Tuple[ArithmeticExpression, ...] + non1_storage_axis_flags: Tuple[ArithmeticExpression, ...] @abstractmethod def is_access_descriptor_in_footprint(self, accdesc: AccessDescriptor) -> bool: diff --git a/loopy/transform/data.py b/loopy/transform/data.py index c63604f8c..739717860 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -36,7 +36,7 @@ from loopy.kernel.function_interface import CallableKernel, ScalarCallable from loopy.translation_unit import TranslationUnit, for_each_kernel from loopy.types import LoopyType -from loopy.typing import ExpressionT +from loopy.typing import Expression # {{{ convenience: add_prefetch @@ -984,11 +984,11 @@ def add_padding_to_avoid_bank_conflicts(kernel, device): @dataclass(frozen=True) class _BaseStorageInfo: name: str - next_offset: ExpressionT + next_offset: Expression approx_nbytes: Optional[int] = None -def _sym_max(a: ExpressionT, b: ExpressionT) -> ExpressionT: +def _sym_max(a: Expression, b: Expression) -> Expression: from numbers import Number if isinstance(a, Number) and isinstance(b, Number): return max(a, b) diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index b0fbb5468..0982c43fd 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -28,7 +28,7 @@ from immutables import Map import islpy as isl -from pymbolic import ArithmeticExpressionT, var +from pymbolic import ArithmeticExpression, var from pymbolic.mapper.substitutor import make_subst_func from pytools import memoize_on_first_arg from pytools.tag import Tag @@ -60,7 +60,7 @@ from loopy.translation_unit import CallablesTable, TranslationUnit from loopy.types import LoopyType, ToLoopyTypeConvertible, to_loopy_type from loopy.typing import ( - ExpressionT, + Expression, auto, integer_expr_or_err, integer_or_err, @@ -133,14 +133,14 @@ def contains_a_subst_rule_invocation(kernel, insn): @dataclass(frozen=True) class RuleAccessDescriptor(AccessDescriptor): - args: Optional[Sequence[ArithmeticExpressionT]] = None + args: Optional[Sequence[ArithmeticExpression]] = None def access_descriptor_id(args, expansion_stack): return (args, expansion_stack) -def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]: +def storage_axis_exprs(storage_axis_sources, args) -> Sequence[Expression]: result = [] for saxis_source in storage_axis_sources: @@ -577,9 +577,9 @@ def precompute_for_single_kernel( for fpg in footprint_generators: if isinstance(fpg, Variable): - args: tuple[ArithmeticExpressionT, ...] = () + args: tuple[ArithmeticExpression, ...] = () elif isinstance(fpg, Call): - args = cast(tuple[ArithmeticExpressionT, ...], fpg.parameters) + args = cast(tuple[ArithmeticExpression, ...], fpg.parameters) else: raise ValueError("footprint generator must " "be substitution rule invocation") diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py index 7d1f3c870..e981ad4be 100644 --- a/loopy/transform/realize_reduction.py +++ b/loopy/transform/realize_reduction.py @@ -34,7 +34,7 @@ from immutables import Map import islpy as isl -from pymbolic.primitives import Expression +from pymbolic.primitives import ExpressionNode from pytools import memoize_on_first_arg from pytools.tag import Tag @@ -103,7 +103,7 @@ class _ReductionRealizationContext: surrounding_within_inames: FrozenSet[str] surrounding_depends_on: FrozenSet[str] surrounding_no_sync_with: FrozenSet[Tuple[str, str]] - surrounding_predicates: FrozenSet[Expression] + surrounding_predicates: FrozenSet[ExpressionNode] # }}} diff --git a/loopy/typing.py b/loopy/typing.py index 7cc7209b9..bcd4afc4f 100644 --- a/loopy/typing.py +++ b/loopy/typing.py @@ -36,13 +36,13 @@ import numpy as np from typing_extensions import TypeAlias, TypeIs -from pymbolic.primitives import Expression -from pymbolic.typing import ArithmeticExpressionT, ExpressionT, IntegerT +from pymbolic.primitives import ExpressionNode +from pymbolic.typing import ArithmeticExpression, Expression, Integer # The Fortran parser may insert dimensions of 'None', but I'd like to phase # that out, so we're not encoding that in the type. -ShapeType: TypeAlias = Tuple[ArithmeticExpressionT, ...] +ShapeType: TypeAlias = Tuple[ArithmeticExpression, ...] StridesType: TypeAlias = ShapeType InameStr: TypeAlias = str @@ -67,15 +67,15 @@ def is_integer(obj: object) -> TypeIs[int | np.integer]: return isinstance(obj, (int, np.integer)) -def integer_or_err(expr: ExpressionT) -> IntegerT: +def integer_or_err(expr: Expression) -> Integer: if isinstance(expr, (int, np.integer)): return expr else: raise ValueError(f"expected integer, got {type(expr)}") -def integer_expr_or_err(expr: ExpressionT) -> IntegerT | Expression: - if isinstance(expr, (int, np.integer, Expression)): +def integer_expr_or_err(expr: Expression) -> Integer | ExpressionNode: + if isinstance(expr, (int, np.integer, ExpressionNode)): return expr else: raise ValueError(f"expected integer or expression, got {type(expr)}") diff --git a/pyproject.toml b/pyproject.toml index 3204163f0..57b6ba44e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "pytools>=2024.1.5", - "pymbolic>=2024.1", + "pymbolic>=2024.2", "genpy>=2016.1.2", # https://github.com/inducer/loopy/pull/419