From a81889b4c7ab1f0681bda6da726f12f15a2cabcd Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 12 Nov 2024 12:50:19 -0600 Subject: [PATCH 1/7] Add a new failing test case for tag_inames. --- test/test_transform.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_transform.py b/test/test_transform.py index 98398fefd..11ec37159 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -377,6 +377,35 @@ def test_set_arg_order(): knl = lp.set_argument_order(knl, "out,a,n,b") +def test_tag_inames_keeps_all_tags_if_able(): + t_unit = lp.make_kernel( + "{ [i,j]: 0<=i,j Date: Tue, 12 Nov 2024 15:38:33 -0600 Subject: [PATCH 2/7] Fix tag_inames to apply multiple tags, type it --- loopy/transform/iname.py | 91 +++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 47 deletions(-) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 97257745c..1f318313c 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -21,10 +21,14 @@ """ +from collections.abc import Iterable, Mapping, Sequence from typing import Any, FrozenSet, Optional +from typing_extensions import TypeAlias + import islpy as isl from islpy import dim_type +from pytools.tag import Tag from loopy.diagnostic import LoopyError from loopy.kernel import LoopKernel @@ -675,9 +679,18 @@ def untag_inames(kernel, iname_to_untag, tag_type): # {{{ tag inames +_Tags_ish: TypeAlias = Tag | Sequence[Tag] | str | Sequence[str] + + @for_each_kernel -def tag_inames(kernel, iname_to_tag, force=False, - ignore_nonexistent=False): +def tag_inames( + kernel: LoopKernel, + iname_to_tag: (Mapping[str, _Tags_ish] + | Sequence[tuple[str, _Tags_ish]] + | str), + force: bool = False, + ignore_nonexistent: bool = False + ) -> LoopKernel: """Tag an iname :arg iname_to_tag: a list of tuples ``(iname, new_tag)``. *new_tag* is given @@ -697,74 +710,67 @@ def tag_inames(kernel, iname_to_tag, force=False, """ if isinstance(iname_to_tag, str): - def parse_kv(s): + def parse_kv(s: str) -> tuple[str, str]: colon_index = s.find(":") if colon_index == -1: raise ValueError("tag decl '%s' has no colon" % s) return (s[:colon_index].strip(), s[colon_index+1:].strip()) - iname_to_tag = [ + iname_to_tags_seq: Sequence[tuple[str, _Tags_ish]] = [ parse_kv(s) for s in iname_to_tag.split(",") if s.strip()] + elif isinstance(iname_to_tag, Mapping): + iname_to_tags_seq = list(iname_to_tag.items()) + else: + iname_to_tags_seq = iname_to_tag if not iname_to_tag: return kernel - # convert dict to list of tuples - if isinstance(iname_to_tag, dict): - iname_to_tag = list(iname_to_tag.items()) - # flatten iterables of tags for each iname - try: - from collections.abc import Iterable - except ImportError: - from collections import Iterable # pylint:disable=no-name-in-module - - unpack_iname_to_tag = [] - for iname, tags in iname_to_tag: + unpack_iname_to_tag: list[tuple[str, Tag | str]] = [] + for iname, tags in iname_to_tags_seq: if isinstance(tags, Iterable) and not isinstance(tags, str): for tag in tags: unpack_iname_to_tag.append((iname, tag)) else: unpack_iname_to_tag.append((iname, tags)) - iname_to_tag = unpack_iname_to_tag from loopy.kernel.data import parse_tag as inner_parse_tag - def parse_tag(tag): + def parse_tag(tag: Tag | str) -> Iterable[Tag]: if isinstance(tag, str): if tag.startswith("like."): - tags = kernel.iname_tags(tag[5:]) - if len(tags) == 0: - return None - if len(tags) == 1: - return tags[0] - else: - raise LoopyError("cannot use like for multiple tags (for now)") + return kernel.iname_tags(tag[5:]) elif tag == "unused.g": return find_unused_axis_tag(kernel, "g") elif tag == "unused.l": return find_unused_axis_tag(kernel, "l") - return inner_parse_tag(tag) - - iname_to_tag = [(iname, parse_tag(tag)) for iname, tag in iname_to_tag] + result = inner_parse_tag(tag) + if result is None: + return [] + else: + return [result] - # {{{ globbing + iname_to_parsed_tag = [ + (iname, subtag) + for iname, tag in unpack_iname_to_tag + for subtag in parse_tag(tag) + ] + knl_inames = dict(kernel.inames) all_inames = kernel.all_inames() from loopy.match import re_from_glob - new_iname_to_tag = {} - for iname, new_tag in iname_to_tag: + + for iname, new_tag in iname_to_parsed_tag: if "*" in iname or "?" in iname: match_re = re_from_glob(iname) - for sub_iname in all_inames: - if match_re.match(sub_iname): - new_iname_to_tag[sub_iname] = new_tag - + inames = [sub_iname for sub_iname in all_inames + if match_re.match(sub_iname)] else: if iname not in all_inames: if ignore_nonexistent: @@ -772,22 +778,13 @@ def parse_tag(tag): else: raise LoopyError("iname '%s' does not exist" % iname) - new_iname_to_tag[iname] = new_tag - - iname_to_tag = new_iname_to_tag - del new_iname_to_tag + inames = [iname] - # }}} - - knl_inames = kernel.inames.copy() - for name, new_tag in iname_to_tag.items(): - if not new_tag: + if new_tag is None: continue - if name not in kernel.all_inames(): - raise ValueError("cannot tag '%s'--not known" % name) - - knl_inames[name] = knl_inames[name].tagged(new_tag) + for sub_iname in inames: + knl_inames[sub_iname] = knl_inames[sub_iname].tagged(new_tag) return kernel.copy(inames=knl_inames) From 038793df683442a5d0a36a7969c24ff826fafe51 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 13 Nov 2024 11:07:23 -0600 Subject: [PATCH 3/7] Fix type errors from more precise types in DependencyMapper --- loopy/kernel/data.py | 4 ++-- loopy/target/execution.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 31c06fdb4..5d1de0e5d 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -45,7 +45,7 @@ import numpy as np from immutables import Map -from pymbolic import ArithmeticExpressionT +from pymbolic import ArithmeticExpressionT, Variable from pytools import ImmutableRecord from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase @@ -113,7 +113,7 @@ def _names_from_expr(expr: Union[None, ExpressionT, str]) -> FrozenSet[str]: if isinstance(expr, str): return frozenset({expr}) elif isinstance(expr, Expression): - return frozenset(v.name for v in dep_mapper(expr)) + return frozenset(cast(Variable, v).name for v in dep_mapper(expr)) elif expr is None: return frozenset() elif isinstance(expr, Number): diff --git a/loopy/target/execution.py b/loopy/target/execution.py index 1b62be8c3..2443a1420 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -36,11 +36,12 @@ Set, Tuple, Union, + cast, ) from immutables import Map -from pymbolic import var +from pymbolic import Variable, var from pytools.codegen import CodeGenerator, Indentation from pytools.py_codegen import PythonFunctionGenerator @@ -260,7 +261,7 @@ def generate_integer_arg_finding_from_array_data( unknown_var, = deps order_to_unknown_to_equations \ .setdefault(eqn.order, {}) \ - .setdefault(unknown_var.name, []) \ + .setdefault(cast(Variable, unknown_var).name, []) \ .append((eqn)) else: # Zero deps: nothing to determine, forget about it. @@ -274,7 +275,6 @@ def generate_integer_arg_finding_from_array_data( # {{{ generate arg finding code from pymbolic.algorithm import solve_affine_equations_for - from pymbolic.primitives import Variable from pytools.codegen import CodeGenerator gen("# {{{ find integer arguments from array data") From d52c2909ee7fce7a6088685f155013c0eca550ce Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 13 Nov 2024 11:16:03 -0600 Subject: [PATCH 4/7] Type re_from_glob --- loopy/match.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/match.py b/loopy/match.py index 5e409791b..ae52e6c65 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -24,6 +24,7 @@ THE SOFTWARE. """ +import re from abc import ABC, abstractmethod from dataclasses import dataclass from sys import intern @@ -66,8 +67,7 @@ """ -def re_from_glob(s): - import re +def re_from_glob(s: str) -> re.Pattern: from fnmatch import translate return re.compile("^"+translate(s.strip())+"$") From 920cb49887ef815eac2debf7aa2a4bc128f6e6a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 13 Nov 2024 11:16:08 -0600 Subject: [PATCH 5/7] Type rename_inames --- loopy/transform/iname.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 1f318313c..795154099 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -21,7 +21,7 @@ """ -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Collection, Iterable, Mapping, Sequence from typing import Any, FrozenSet, Optional from typing_extensions import TypeAlias @@ -34,6 +34,7 @@ from loopy.kernel import LoopKernel from loopy.kernel.function_interface import CallableKernel from loopy.kernel.instruction import InstructionBase +from loopy.match import ToStackMatchCovertible from loopy.symbolic import ( RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, @@ -2369,8 +2370,14 @@ def add_inames_for_unused_hw_axes(kernel, within=None): @for_each_kernel @remove_any_newly_unused_inames -def rename_inames(kernel, old_inames, new_iname, existing_ok=False, - within=None, raise_on_domain_mismatch: Optional[bool] = None): +def rename_inames( + kernel: LoopKernel, + old_inames: Collection[str], + new_iname: str, + existing_ok: bool = False, + within: ToStackMatchCovertible = None, + raise_on_domain_mismatch: Optional[bool] = None + ) -> LoopKernel: r""" :arg old_inames: A collection of inames that must be renamed to **new_iname**. :arg within: a stack match as understood by @@ -2380,7 +2387,6 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False, :math:`\exists (i_1,i_2) \in \{\text{old\_inames}\}^2 | \mathcal{D}_{i_1} \neq \mathcal{D}_{i_2}`. """ - from collections.abc import Collection if (isinstance(old_inames, str) or not isinstance(old_inames, Collection)): raise LoopyError("'old_inames' must be a collection of strings, " @@ -2508,9 +2514,15 @@ def does_insn_involve_iname(kernel, insn, *args): @for_each_kernel -def rename_iname(kernel, old_iname, new_iname, existing_ok=False, - within=None, preserve_tags=True, - raise_on_domain_mismatch: Optional[bool] = None): +def rename_iname( + kernel: LoopKernel, + old_iname: str, + new_iname: str, + existing_ok: bool = False, + within: ToStackMatchCovertible = None, + preserve_tags: bool = True, + raise_on_domain_mismatch: Optional[bool] = None + ) -> LoopKernel: r""" Single iname version of :func:`loopy.rename_inames`. :arg existing_ok: execute even if *new_iname* already exists. @@ -2528,7 +2540,7 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False, kernel = rename_inames(kernel, [old_iname], new_iname, existing_ok, within, raise_on_domain_mismatch) if preserve_tags: - kernel = tag_inames(kernel, product([new_iname], tags)) + kernel = tag_inames(kernel, list(product([new_iname], tags))) return kernel # }}} From 57f6654662dcf0188ae4d9976e166d59addb019a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 13 Nov 2024 14:02:54 -0600 Subject: [PATCH 6/7] schedule: update types --- loopy/schedule/__init__.py | 85 ++++++++++++++++------------------- loopy/schedule/tree.py | 23 +++++----- loopy/transform/precompute.py | 5 ++- 3 files changed, 53 insertions(+), 60 deletions(-) diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index 1364be850..a9121de8c 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -27,20 +27,11 @@ import logging import sys +from collections.abc import Hashable, Iterator, Mapping, Sequence, Set from dataclasses import dataclass, replace from typing import ( TYPE_CHECKING, - AbstractSet, Any, - Dict, - FrozenSet, - Hashable, - Iterator, - Mapping, - Optional, - Sequence, - Set, - Tuple, TypeVar, ) @@ -155,7 +146,7 @@ class Barrier(ScheduleItem): def gather_schedule_block( schedule: Sequence[ScheduleItem], start_idx: int - ) -> Tuple[Sequence[ScheduleItem], int]: + ) -> tuple[Sequence[ScheduleItem], int]: assert isinstance(schedule[start_idx], BeginBlockItem) level = 0 @@ -176,7 +167,7 @@ def gather_schedule_block( def generate_sub_sched_items( schedule: Sequence[ScheduleItem], start_idx: int - ) -> Iterator[Tuple[int, ScheduleItem]]: + ) -> Iterator[tuple[int, ScheduleItem]]: if not isinstance(schedule[start_idx], BeginBlockItem): yield start_idx, schedule[start_idx] @@ -203,7 +194,7 @@ def generate_sub_sched_items( def get_insn_ids_for_block_at( schedule: Sequence[ScheduleItem], start_idx: int - ) -> FrozenSet[str]: + ) -> frozenset[str]: return frozenset( sub_sched_item.insn_id for i, sub_sched_item in generate_sub_sched_items( @@ -212,7 +203,7 @@ def get_insn_ids_for_block_at( def find_used_inames_within( - kernel: LoopKernel, sched_index: int) -> AbstractSet[str]: + kernel: LoopKernel, sched_index: int) -> set[str]: assert kernel.linearization is not None sched_item = kernel.linearization[sched_index] @@ -234,7 +225,7 @@ def find_used_inames_within( return result -def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]: +def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other inames that are always nested with them. """ @@ -257,11 +248,11 @@ def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str] return result -def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]: +def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other inames that are always nested around them. """ - result: Dict[str, Set[str]] = {} + result: dict[str, set[str]] = {} all_inames = kernel.all_inames() @@ -299,14 +290,14 @@ def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[st def find_loop_insn_dep_map( kernel: LoopKernel, - loop_nest_with_map: Mapping[str, AbstractSet[str]], - loop_nest_around_map: Mapping[str, AbstractSet[str]] - ) -> Mapping[str, AbstractSet[str]]: + loop_nest_with_map: Mapping[str, Set[str]], + loop_nest_around_map: Mapping[str, Set[str]] + ) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other instruction ids that need to be scheduled before the iname should be eligible for scheduling. """ - result: Dict[str, Set[str]] = {} + result: dict[str, set[str]] = {} from loopy.kernel.data import ConcurrentTag, IlpBaseTag for insn in kernel.instructions: @@ -372,7 +363,7 @@ def find_loop_insn_dep_map( def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]: - result: Dict[str, int] = {} + result: dict[str, int] = {} for insn in kernel.instructions: for grp in insn.groups: @@ -382,7 +373,7 @@ def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]: def gen_dependencies_except( - kernel: LoopKernel, insn_id: str, except_insn_ids: AbstractSet[str] + kernel: LoopKernel, insn_id: str, except_insn_ids: Set[str] ) -> Iterator[str]: insn = kernel.id_to_insn[insn_id] for dep_id in insn.depends_on: @@ -396,9 +387,9 @@ def gen_dependencies_except( def get_priority_tiers( - wanted: AbstractSet[int], - priorities: AbstractSet[Sequence[int]] - ) -> Iterator[AbstractSet[int]]: + wanted: Set[int], + priorities: Set[Sequence[int]] + ) -> Iterator[set[int]]: # Get highest priority tier candidates: These are the first inames # of all the given priority constraints candidates = set() @@ -677,24 +668,24 @@ class SchedulerState: order with instruction priorities as tie breaker. """ kernel: LoopKernel - loop_nest_around_map: Mapping[str, AbstractSet[str]] - loop_insn_dep_map: Mapping[str, AbstractSet[str]] + loop_nest_around_map: Mapping[str, set[str]] + loop_insn_dep_map: Mapping[str, set[str]] - breakable_inames: AbstractSet[str] - ilp_inames: AbstractSet[str] - vec_inames: AbstractSet[str] - concurrent_inames: AbstractSet[str] + breakable_inames: set[str] + ilp_inames: set[str] + vec_inames: set[str] + concurrent_inames: set[str] - insn_ids_to_try: Optional[AbstractSet[str]] + insn_ids_to_try: set[str] | None active_inames: Sequence[str] - entered_inames: FrozenSet[str] - enclosing_subkernel_inames: Tuple[str, ...] + entered_inames: frozenset[str] + enclosing_subkernel_inames: tuple[str, ...] schedule: Sequence[ScheduleItem] - scheduled_insn_ids: AbstractSet[str] - unscheduled_insn_ids: AbstractSet[str] + scheduled_insn_ids: frozenset[str] + unscheduled_insn_ids: set[str] preschedule: Sequence[ScheduleItem] - prescheduled_insn_ids: AbstractSet[str] - prescheduled_inames: AbstractSet[str] + prescheduled_insn_ids: set[str] + prescheduled_inames: set[str] may_schedule_global_barriers: bool within_subkernel: bool group_insn_counts: Mapping[str, int] @@ -702,7 +693,7 @@ class SchedulerState: insns_in_topologically_sorted_order: Sequence[InstructionBase] @property - def last_entered_loop(self) -> Optional[str]: + def last_entered_loop(self) -> str | None: if self.active_inames: return self.active_inames[-1] else: @@ -718,7 +709,7 @@ def get_insns_in_topologically_sorted_order( kernel: LoopKernel) -> Sequence[InstructionBase]: from pytools.graph import compute_topological_order - rev_dep_map: Dict[str, Set[str]] = { + rev_dep_map: dict[str, set[str]] = { not_none(insn.id): set() for insn in kernel.instructions} for insn in kernel.instructions: for dep in insn.depends_on: @@ -733,7 +724,7 @@ def get_insns_in_topologically_sorted_order( # Instead of returning these features as a key, we assign an id to # each set of features to avoid comparing them which can be expensive. insn_id_to_feature_id = {} - insn_features: Dict[Hashable, int] = {} + insn_features: dict[Hashable, int] = {} for insn in kernel.instructions: feature = (insn.within_inames, insn.groups, insn.conflicts_with_groups) if feature not in insn_features: @@ -890,7 +881,7 @@ def _get_outermost_diverging_inames( tree: LoopTree, within1: InameStrSet, within2: InameStrSet - ) -> Tuple[InameStr, InameStr]: + ) -> tuple[InameStr, InameStr]: """ For loop nestings *within1* and *within2*, returns the first inames at which the loops nests diverge in the loop nesting tree *tree*. @@ -2180,7 +2171,7 @@ def __init__(self, kernel): def generate_loop_schedules( kernel: LoopKernel, callables_table: CallablesTable, - debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]: + debug_args: Mapping[str, Any] | None = None) -> Iterator[LoopKernel]: """ .. warning:: @@ -2236,7 +2227,7 @@ def _postprocess_schedule(kernel, callables_table, gen_sched): def _generate_loop_schedules_inner( kernel: LoopKernel, callables_table: CallablesTable, - debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]: + debug_args: Mapping[str, Any] | None) -> Iterator[LoopKernel]: if debug_args is None: debug_args = {} @@ -2337,7 +2328,7 @@ def _generate_loop_schedules_inner( get_insns_in_topologically_sorted_order(kernel)), ) - schedule_gen_kwargs: Dict[str, Any] = {} + schedule_gen_kwargs: dict[str, Any] = {} def print_longest_dead_end(): if debug.interactive: @@ -2402,7 +2393,7 @@ def print_longest_dead_end(): schedule_cache: WriteOncePersistentDict[ - Tuple[LoopKernel, CallablesTable], + tuple[LoopKernel, CallablesTable], LoopKernel ] = WriteOncePersistentDict( "loopy-schedule-cache-v4-"+DATA_MODEL_VERSION, diff --git a/loopy/schedule/tree.py b/loopy/schedule/tree.py index 253ff5f84..e98724f83 100644 --- a/loopy/schedule/tree.py +++ b/loopy/schedule/tree.py @@ -34,9 +34,10 @@ THE SOFTWARE. """ +from collections.abc import Hashable, Iterator, Sequence from dataclasses import dataclass from functools import cached_property -from typing import Generic, Hashable, Iterator, List, Optional, Sequence, Tuple, TypeVar +from typing import Generic, TypeVar from immutables import Map @@ -70,11 +71,11 @@ class Tree(Generic[NodeT]): this allocates a new stack frame for each iteration of the operation. """ - _parent_to_children: Map[NodeT, Tuple[NodeT, ...]] - _child_to_parent: Map[NodeT, Optional[NodeT]] + _parent_to_children: Map[NodeT, tuple[NodeT, ...]] + _child_to_parent: Map[NodeT, NodeT | None] @staticmethod - def from_root(root: NodeT) -> "Tree[NodeT]": + def from_root(root: NodeT) -> Tree[NodeT]: return Tree(Map({root: ()}), Map({root: None})) @@ -89,7 +90,7 @@ def root(self) -> NodeT: return guess @memoize_method - def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]: + def ancestors(self, node: NodeT) -> tuple[NodeT, ...]: """ Returns a :class:`tuple` of nodes that are ancestors of *node*. """ @@ -104,7 +105,7 @@ def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]: return (parent,) + self.ancestors(parent) - def parent(self, node: NodeT) -> Optional[NodeT]: + def parent(self, node: NodeT) -> NodeT | None: """ Returns the parent of *node*. """ @@ -112,7 +113,7 @@ def parent(self, node: NodeT) -> Optional[NodeT]: return self._child_to_parent[node] - def children(self, node: NodeT) -> Tuple[NodeT, ...]: + def children(self, node: NodeT) -> tuple[NodeT, ...]: """ Returns the children of *node*. """ @@ -150,7 +151,7 @@ def __contains__(self, node: NodeT) -> bool: """Return *True* if *node* is a node in the tree.""" return node in self._child_to_parent - def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]": + def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]: """ Returns a :class:`Tree` with added node *node* having a parent *parent*. @@ -165,7 +166,7 @@ def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]": .set(node, ())), self._child_to_parent.set(node, parent)) - def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]": + def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]: """ Returns a copy of *self* with *node* replaced with *new_node*. """ @@ -207,7 +208,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]": return Tree(parent_to_children_mut.finish(), child_to_parent_mut.finish()) - def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]": + def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]: """ Returns a copy of *self* with node *node* as a child of *new_parent*. """ @@ -262,7 +263,7 @@ def __str__(self) -> str: ├── D └── E """ - def rec(node: NodeT) -> List[str]: + def rec(node: NodeT) -> list[str]: children_result = [rec(c) for c in self.children(node)] def post_process_non_last_child(children: Sequence[str]) -> list[str]: diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index c2cd0a5ca..b0fbb5468 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -155,7 +155,8 @@ def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]: # {{{ gather rule invocations class RuleInvocationGatherer(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): + def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within) \ + -> None: super().__init__(rule_mapping_context) from loopy.symbolic import SubstitutionRuleExpander @@ -167,7 +168,7 @@ def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): self.subst_tag = subst_tag self.within = within - self.access_descriptors: List[RuleAccessDescriptor] = [] + self.access_descriptors: list[RuleAccessDescriptor] = [] def map_substitution(self, name, tag, arguments, expn_state): process_me = name == self.subst_name From 41b328882172aafddd6c16ee6260df1177a5319b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 14 Nov 2024 10:18:48 -0600 Subject: [PATCH 7/7] LazilyUnpickling{Dict,List}: add __repr__ (#817) * LazilyUnpickling{Dict,List}: better repr * add type to repr of PickledObject --- loopy/tools.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/loopy/tools.py b/loopy/tools.py index bf7785fcf..bb4904bf2 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -348,6 +348,9 @@ def unpickle(self): def __getstate__(self): return {"objstring": self.objstring} + def __repr__(self) -> str: + return type(self).__name__ + "(" + repr(self.unpickle()) + ")" + class _PickledObjectWithEqAndPersistentHashKeys(_PickledObject): """Like :class:`_PickledObject`, with two additional attributes: @@ -406,6 +409,9 @@ def __getstate__(self): key: _PickledObject(val) for key, val in self._map.items()}} + def __repr__(self) -> str: + return type(self).__name__ + "(" + repr(self._map) + ")" + # }}} @@ -444,6 +450,9 @@ def __add__(self, other): def __mul__(self, other): return self._list * other + def __repr__(self) -> str: + return type(self).__name__ + "(" + repr(self._list) + ")" + class LazilyUnpicklingListWithEqAndPersistentHashing(LazilyUnpicklingList): """A list which lazily unpickles its values, and supports equality comparison