Skip to content

Commit

Permalink
Track renames in pymbolic 2024.2
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 21, 2024
1 parent c80684d commit c840231
Show file tree
Hide file tree
Showing 21 changed files with 115 additions and 115 deletions.
4 changes: 2 additions & 2 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
check_each_kernel,
)
from loopy.type_inference import TypeReader
from loopy.typing import ExpressionT, not_none
from loopy.typing import Expression, not_none


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -221,7 +221,7 @@ def check_offsets_and_dim_tags(kernel: LoopKernel) -> None:
dep_mapper = DependencyMapper()

def ensure_depends_only_on_arguments(
what: str, expr: Union[str, ExpressionT]) -> None:
what: str, expr: Union[str, Expression]) -> None:
if isinstance(expr, str):
expr = Variable(expr)

Expand Down
8 changes: 4 additions & 4 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from loopy.target import TargetBase
from loopy.tools import LoopyKeyBuilder, caches
from loopy.types import LoopyType
from loopy.typing import ExpressionT
from loopy.typing import Expression
from loopy.version import DATA_MODEL_VERSION


Expand Down Expand Up @@ -200,14 +200,14 @@ class CodeGenerationState:
kernel: LoopKernel
target: TargetBase
implemented_domain: isl.Set
implemented_predicates: FrozenSet[Union[str, ExpressionT]]
implemented_predicates: FrozenSet[Union[str, Expression]]

# /!\ mutable
seen_dtypes: Set[LoopyType]
seen_functions: Set[SeenFunction]
seen_atomic_dtypes: Set[LoopyType]

var_subst_map: Map[str, ExpressionT]
var_subst_map: Map[str, Expression]
allow_complex: bool
callables_table: CallablesTable
is_entrypoint: bool
Expand All @@ -231,7 +231,7 @@ def copy(self, **kwargs: Any) -> "CodeGenerationState":
return replace(self, **kwargs)

def copy_and_assign(
self, name: str, value: ExpressionT) -> "CodeGenerationState":
self, name: str, value: Expression) -> "CodeGenerationState":
"""Make a copy of self with variable *name* fixed to *value*."""
return self.copy(var_subst_map=self.var_subst_map.set(name, value))

Expand Down
6 changes: 3 additions & 3 deletions loopy/isl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def make_slab(space, iname, start, stop, iname_multiplier=1):

space = zero.get_domain_space()

from pymbolic.primitives import Expression
from pymbolic.primitives import ExpressionNode

from loopy.symbolic import aff_from_expr
if isinstance(start, Expression):
if isinstance(start, ExpressionNode):
start = aff_from_expr(space, start)
if isinstance(stop, Expression):
if isinstance(stop, ExpressionNode):
stop = aff_from_expr(space, stop)

if isinstance(start, int):
Expand Down
10 changes: 5 additions & 5 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic import ArithmeticExpressionT
from pymbolic import ArithmeticExpression
from pytools import (
UniqueNameGenerator,
generate_unique_names,
Expand All @@ -75,7 +75,7 @@
from loopy.target import TargetBase
from loopy.tools import update_persistent_hash
from loopy.types import LoopyType, NumpyType
from loopy.typing import ExpressionT, InameStr
from loopy.typing import Expression, InameStr


if TYPE_CHECKING:
Expand Down Expand Up @@ -193,7 +193,7 @@ class LoopKernel(Taggable):
with non-parallel implementation tags.
"""

applied_iname_rewrites: Tuple[Dict[InameStr, ExpressionT], ...] = ()
applied_iname_rewrites: Tuple[Dict[InameStr, Expression], ...] = ()
"""
A list of past substitution dictionaries that
were applied to the kernel. These are stored so that they may be repeated
Expand Down Expand Up @@ -1036,8 +1036,8 @@ def get_grid_size_upper_bounds_as_exprs(
self, callables_table,
ignore_auto=False, return_dict=False
) -> Tuple[
Tuple[ArithmeticExpressionT, ...],
Tuple[ArithmeticExpressionT, ...]]:
Tuple[ArithmeticExpression, ...],
Tuple[ArithmeticExpression, ...]]:
"""Return a tuple (global_size, local_size) containing a grid that
could accommodate execution of *all* instructions in the kernel.
Expand Down
30 changes: 15 additions & 15 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@
import numpy as np # noqa
from typing_extensions import TypeAlias

from pymbolic import ArithmeticExpressionT
from pymbolic import ArithmeticExpression
from pymbolic.primitives import is_arithmetic_expression
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable

from loopy.diagnostic import LoopyError
from loopy.symbolic import flatten
from loopy.types import LoopyType
from loopy.typing import ExpressionT, ShapeType, auto, is_integer
from loopy.typing import Expression, ShapeType, auto, is_integer


if TYPE_CHECKING:
Expand Down Expand Up @@ -609,8 +609,8 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,

# {{{ array base class (for arguments and temporary arrays)

ToShapeLikeConvertible: TypeAlias = (Tuple[ExpressionT | str, ...]
| ExpressionT | type[auto] | str | tuple[str, ...])
ToShapeLikeConvertible: TypeAlias = (Tuple[Expression | str, ...]
| Expression | type[auto] | str | tuple[str, ...])


def _parse_shape_or_strides(
Expand All @@ -634,12 +634,12 @@ def _parse_shape_or_strides(
raise ValueError("shape can't be a list")

if isinstance(x_parsed, tuple):
x_tup: tuple[ExpressionT | str, ...] = x_parsed
x_tup: tuple[Expression | str, ...] = x_parsed
else:
assert x_parsed is not auto
x_tup = (cast(ExpressionT, x_parsed),)
x_tup = (cast(Expression, x_parsed),)

def parse_arith(x: ExpressionT | str) -> ArithmeticExpressionT:
def parse_arith(x: Expression | str) -> ArithmeticExpression:
if isinstance(x, str):
res = parse(x)
else:
Expand Down Expand Up @@ -714,7 +714,7 @@ class ArrayBase(ImmutableRecord, Taggable):
"""See :ref:`data-dim-tags`.
"""

offset: Union[ExpressionT, str, None]
offset: Union[Expression, str, None]
"""Offset from the beginning of the buffer to the point from
which the strides are counted, in units of the :attr:`dtype`.
May be one of
Expand Down Expand Up @@ -1158,9 +1158,9 @@ def drop_vec_dims(
if not isinstance(dim_tag, VectorArrayDimTag))


def get_strides(array: ArrayBase) -> Tuple[ExpressionT, ...]:
def get_strides(array: ArrayBase) -> Tuple[Expression, ...]:
from pymbolic import var
result: List[ExpressionT] = []
result: List[Expression] = []

if array.dim_tags is None:
return ()
Expand Down Expand Up @@ -1188,10 +1188,10 @@ def get_strides(array: ArrayBase) -> Tuple[ExpressionT, ...]:
class AccessInfo(ImmutableRecord):
array_name: str
vector_index: Optional[int]
subscripts: Tuple[ExpressionT, ...]
subscripts: Tuple[Expression, ...]


def _apply_offset(sub: ExpressionT, ary: ArrayBase) -> ExpressionT:
def _apply_offset(sub: Expression, ary: ArrayBase) -> Expression:
"""
Helper for :func:`get_access_info`.
Augments *ary*'s subscript index expression (*sub*) with its offset info.
Expand Down Expand Up @@ -1228,8 +1228,8 @@ def _apply_offset(sub: ExpressionT, ary: ArrayBase) -> ExpressionT:

def get_access_info(kernel: "LoopKernel",
ary: Union["ArrayArg", "TemporaryVariable"],
index: Union[ExpressionT, Tuple[ExpressionT, ...]],
eval_expr: Callable[[ExpressionT], int],
index: Union[Expression, Tuple[Expression, ...]],
eval_expr: Callable[[Expression], int],
vectorization_info: "VectorizationInfo") -> AccessInfo:
"""
:arg ary: an object of type :class:`ArrayBase`
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def eval_expr_assert_integer_constant(i, expr) -> int:
num_target_axes = ary.num_target_axes()

vector_index = None
subscripts: List[ExpressionT] = [0] * num_target_axes
subscripts: List[Expression] = [0] * num_target_axes

vector_size = ary.vector_size(kernel.target)

Expand Down
20 changes: 10 additions & 10 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpressionT, Variable
from pymbolic import ArithmeticExpression, Variable
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase

Expand All @@ -65,7 +65,7 @@
make_assignment,
)
from loopy.types import LoopyType, ToLoopyTypeConvertible
from loopy.typing import ExpressionT, ShapeType, auto
from loopy.typing import Expression, ShapeType, auto


__doc__ = """
Expand Down Expand Up @@ -103,7 +103,7 @@

# {{{ utilities

def _names_from_expr(expr: Union[None, ExpressionT, str]) -> FrozenSet[str]:
def _names_from_expr(expr: Union[None, Expression, str]) -> FrozenSet[str]:
from numbers import Number

from loopy.symbolic import DependencyMapper
Expand Down Expand Up @@ -651,7 +651,7 @@ class TemporaryVariable(ArrayBase):
"""

storage_shape: Optional[ShapeType]
base_indices: Optional[Tuple[ExpressionT, ...]]
base_indices: Optional[Tuple[Expression, ...]]
address_space: Union[AddressSpace, Type[auto]]
base_storage: Optional[str]
"""The name of a storage array that is to be used to actually
Expand Down Expand Up @@ -698,12 +698,12 @@ def __init__(
shape: Union[ShapeType, Type["auto"], None] = auto,
address_space: Union[AddressSpace, Type[auto], None] = None,
dim_tags: Optional[Sequence[ArrayDimImplementationTag]] = None,
offset: Union[ExpressionT, str, None] = 0,
offset: Union[Expression, str, None] = 0,
dim_names: Optional[Tuple[str, ...]] = None,
strides: Optional[Tuple[ExpressionT, ...]] = None,
strides: Optional[Tuple[Expression, ...]] = None,
order: str | None = None,

base_indices: Optional[Tuple[ExpressionT, ...]] = None,
base_indices: Optional[Tuple[Expression, ...]] = None,
storage_shape: ShapeType | None = None,

base_storage: Optional[str] = None,
Expand Down Expand Up @@ -809,15 +809,15 @@ def copy(self, **kwargs: Any) -> TemporaryVariable:
return super().copy(**kwargs)

@property
def nbytes(self) -> ExpressionT:
def nbytes(self) -> Expression:
if self.storage_shape is not None:
shape = self.storage_shape
else:
if self.shape is None:
raise ValueError("shape is None")
if self.shape is auto:
raise ValueError("shape is auto")
shape = cast(Tuple[ArithmeticExpressionT], self.shape)
shape = cast(Tuple[ArithmeticExpression], self.shape)

if self.dtype is None:
raise ValueError("data type is indeterminate")
Expand Down Expand Up @@ -898,7 +898,7 @@ class SubstitutionRule:

name: str
arguments: Sequence[str]
expression: ExpressionT
expression: Expression

def copy(self, **kwargs: Any) -> SubstitutionRule:
return replace(self, **kwargs)
Expand Down
16 changes: 8 additions & 8 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from loopy.diagnostic import LoopyError
from loopy.tools import Optional as LoopyOptional
from loopy.types import LoopyType
from loopy.typing import ExpressionT, InameStr
from loopy.typing import Expression, InameStr


# {{{ instruction tags
Expand Down Expand Up @@ -250,7 +250,7 @@ class InstructionBase(ImmutableRecord, Taggable):
groups: FrozenSet[str]
conflicts_with_groups: FrozenSet[str]
no_sync_with: FrozenSet[Tuple[str, str]]
predicates: FrozenSet[ExpressionT]
predicates: FrozenSet[Expression]
within_inames: FrozenSet[InameStr]
within_inames_is_final: bool
priority: int
Expand Down Expand Up @@ -901,17 +901,17 @@ class Assignment(MultiAssignmentBase):
.. automethod:: __init__
"""

assignee: ExpressionT
expression: ExpressionT
assignee: Expression
expression: Expression
temp_var_type: LoopyOptional
atomicity: Tuple[VarAtomicity, ...]

fields = MultiAssignmentBase.fields | \
set("assignee temp_var_type atomicity".split())

def __init__(self,
assignee: Union[str, ExpressionT],
expression: Union[str, ExpressionT],
assignee: Union[str, Expression],
expression: Union[str, Expression],
id: Optional[str] = None,
happens_after: Union[
Mapping[str, HappensAfter], FrozenSet[str], str, None] = None,
Expand Down Expand Up @@ -1271,8 +1271,8 @@ def modify_assignee_for_array_call(assignee):
"SubArrayRef as its inputs")


def make_assignment(assignees: tuple[ExpressionT, ...],
expression: ExpressionT,
def make_assignment(assignees: tuple[Expression, ...],
expression: Expression,
temp_var_types: (
Sequence[LoopyType | None] | None) = None,
**kwargs: Any) -> Assignment | CallInstruction:
Expand Down
6 changes: 3 additions & 3 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@

# for the benefit of loopy.statistics, for now
from loopy.type_inference import infer_unknown_types
from loopy.typing import ExpressionT
from loopy.typing import Expression


# {{{ check for writes to predicates
Expand Down Expand Up @@ -174,14 +174,14 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel:
sep_axis_indices_set = frozenset(sep_axis_indices)

assert isinstance(arg.shape, tuple)
new_shape: Optional[Tuple[ExpressionT, ...]] = \
new_shape: Optional[Tuple[Expression, ...]] = \
_remove_at_indices(sep_axis_indices_set, arg.shape)
new_dim_tags: Optional[Tuple[ArrayDimImplementationTag, ...]] = \
_remove_at_indices(sep_axis_indices_set, arg.dim_tags)
new_dim_names: Optional[Tuple[Optional[str], ...]] = \
_remove_at_indices(sep_axis_indices_set, arg.dim_names)

sep_shape: List[ExpressionT] = [arg.shape[i] for i in sep_axis_indices]
sep_shape: List[Expression] = [arg.shape[i] for i in sep_axis_indices]
for i, sep_shape_i in enumerate(sep_shape):
if not isinstance(sep_shape_i, (int, np.integer)):
raise LoopyError(
Expand Down
Loading

0 comments on commit c840231

Please sign in to comment.