Skip to content

Commit

Permalink
Merge branch 'main' into withtag
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Nov 14, 2024
2 parents ad08e3e + 41b3288 commit 123a457
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 120 deletions.
4 changes: 2 additions & 2 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpressionT
from pymbolic import ArithmeticExpressionT, Variable
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase

Expand Down Expand Up @@ -113,7 +113,7 @@ def _names_from_expr(expr: Union[None, ExpressionT, str]) -> FrozenSet[str]:
if isinstance(expr, str):
return frozenset({expr})
elif isinstance(expr, Expression):
return frozenset(v.name for v in dep_mapper(expr))
return frozenset(cast(Variable, v).name for v in dep_mapper(expr))
elif expr is None:
return frozenset()
elif isinstance(expr, Number):
Expand Down
4 changes: 2 additions & 2 deletions loopy/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
THE SOFTWARE.
"""

import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from sys import intern
Expand Down Expand Up @@ -66,8 +67,7 @@
"""


def re_from_glob(s):
import re
def re_from_glob(s: str) -> re.Pattern:
    """Compile the shell-style glob *s* into an anchored regular expression.

    Leading/trailing whitespace in *s* is ignored; the resulting pattern
    must match the entire string (it is wrapped in ``^``/``$``).
    """
    from fnmatch import translate
    anchored = "^" + translate(s.strip()) + "$"
    return re.compile(anchored)

Expand Down
85 changes: 38 additions & 47 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,11 @@

import logging
import sys
from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
from dataclasses import dataclass, replace
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Dict,
FrozenSet,
Hashable,
Iterator,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
)

Expand Down Expand Up @@ -155,7 +146,7 @@ class Barrier(ScheduleItem):

def gather_schedule_block(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Tuple[Sequence[ScheduleItem], int]:
) -> tuple[Sequence[ScheduleItem], int]:
assert isinstance(schedule[start_idx], BeginBlockItem)
level = 0

Expand All @@ -176,7 +167,7 @@ def gather_schedule_block(

def generate_sub_sched_items(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Iterator[Tuple[int, ScheduleItem]]:
) -> Iterator[tuple[int, ScheduleItem]]:
if not isinstance(schedule[start_idx], BeginBlockItem):
yield start_idx, schedule[start_idx]

Expand All @@ -203,7 +194,7 @@ def generate_sub_sched_items(

def get_insn_ids_for_block_at(
schedule: Sequence[ScheduleItem], start_idx: int
) -> FrozenSet[str]:
) -> frozenset[str]:
return frozenset(
sub_sched_item.insn_id
for i, sub_sched_item in generate_sub_sched_items(
Expand All @@ -212,7 +203,7 @@ def get_insn_ids_for_block_at(


def find_used_inames_within(
kernel: LoopKernel, sched_index: int) -> AbstractSet[str]:
kernel: LoopKernel, sched_index: int) -> set[str]:
assert kernel.linearization is not None
sched_item = kernel.linearization[sched_index]

Expand All @@ -234,7 +225,7 @@ def find_used_inames_within(
return result


def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested with them.
"""
Expand All @@ -257,11 +248,11 @@ def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]
return result


def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested around them.
"""
result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

all_inames = kernel.all_inames()

Expand Down Expand Up @@ -299,14 +290,14 @@ def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[st

def find_loop_insn_dep_map(
kernel: LoopKernel,
loop_nest_with_map: Mapping[str, AbstractSet[str]],
loop_nest_around_map: Mapping[str, AbstractSet[str]]
) -> Mapping[str, AbstractSet[str]]:
loop_nest_with_map: Mapping[str, Set[str]],
loop_nest_around_map: Mapping[str, Set[str]]
) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other instruction ids that need to
be scheduled before the iname should be eligible for scheduling.
"""

result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

from loopy.kernel.data import ConcurrentTag, IlpBaseTag
for insn in kernel.instructions:
Expand Down Expand Up @@ -372,7 +363,7 @@ def find_loop_insn_dep_map(


def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:
result: Dict[str, int] = {}
result: dict[str, int] = {}

for insn in kernel.instructions:
for grp in insn.groups:
Expand All @@ -382,7 +373,7 @@ def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:


def gen_dependencies_except(
kernel: LoopKernel, insn_id: str, except_insn_ids: AbstractSet[str]
kernel: LoopKernel, insn_id: str, except_insn_ids: Set[str]
) -> Iterator[str]:
insn = kernel.id_to_insn[insn_id]
for dep_id in insn.depends_on:
Expand All @@ -396,9 +387,9 @@ def gen_dependencies_except(


def get_priority_tiers(
wanted: AbstractSet[int],
priorities: AbstractSet[Sequence[int]]
) -> Iterator[AbstractSet[int]]:
wanted: Set[int],
priorities: Set[Sequence[int]]
) -> Iterator[set[int]]:
# Get highest priority tier candidates: These are the first inames
# of all the given priority constraints
candidates = set()
Expand Down Expand Up @@ -677,32 +668,32 @@ class SchedulerState:
order with instruction priorities as tie breaker.
"""
kernel: LoopKernel
loop_nest_around_map: Mapping[str, AbstractSet[str]]
loop_insn_dep_map: Mapping[str, AbstractSet[str]]
loop_nest_around_map: Mapping[str, set[str]]
loop_insn_dep_map: Mapping[str, set[str]]

breakable_inames: AbstractSet[str]
ilp_inames: AbstractSet[str]
vec_inames: AbstractSet[str]
concurrent_inames: AbstractSet[str]
breakable_inames: set[str]
ilp_inames: set[str]
vec_inames: set[str]
concurrent_inames: set[str]

insn_ids_to_try: Optional[AbstractSet[str]]
insn_ids_to_try: set[str] | None
active_inames: Sequence[str]
entered_inames: FrozenSet[str]
enclosing_subkernel_inames: Tuple[str, ...]
entered_inames: frozenset[str]
enclosing_subkernel_inames: tuple[str, ...]
schedule: Sequence[ScheduleItem]
scheduled_insn_ids: AbstractSet[str]
unscheduled_insn_ids: AbstractSet[str]
scheduled_insn_ids: frozenset[str]
unscheduled_insn_ids: set[str]
preschedule: Sequence[ScheduleItem]
prescheduled_insn_ids: AbstractSet[str]
prescheduled_inames: AbstractSet[str]
prescheduled_insn_ids: set[str]
prescheduled_inames: set[str]
may_schedule_global_barriers: bool
within_subkernel: bool
group_insn_counts: Mapping[str, int]
active_group_counts: Mapping[str, int]
insns_in_topologically_sorted_order: Sequence[InstructionBase]

@property
def last_entered_loop(self) -> Optional[str]:
def last_entered_loop(self) -> str | None:
if self.active_inames:
return self.active_inames[-1]
else:
Expand All @@ -718,7 +709,7 @@ def get_insns_in_topologically_sorted_order(
kernel: LoopKernel) -> Sequence[InstructionBase]:
from pytools.graph import compute_topological_order

rev_dep_map: Dict[str, Set[str]] = {
rev_dep_map: dict[str, set[str]] = {
not_none(insn.id): set() for insn in kernel.instructions}
for insn in kernel.instructions:
for dep in insn.depends_on:
Expand All @@ -733,7 +724,7 @@ def get_insns_in_topologically_sorted_order(
# Instead of returning these features as a key, we assign an id to
# each set of features to avoid comparing them which can be expensive.
insn_id_to_feature_id = {}
insn_features: Dict[Hashable, int] = {}
insn_features: dict[Hashable, int] = {}
for insn in kernel.instructions:
feature = (insn.within_inames, insn.groups, insn.conflicts_with_groups)
if feature not in insn_features:
Expand Down Expand Up @@ -890,7 +881,7 @@ def _get_outermost_diverging_inames(
tree: LoopTree,
within1: InameStrSet,
within2: InameStrSet
) -> Tuple[InameStr, InameStr]:
) -> tuple[InameStr, InameStr]:
"""
For loop nestings *within1* and *within2*, returns the first inames at which
the loops nests diverge in the loop nesting tree *tree*.
Expand Down Expand Up @@ -2180,7 +2171,7 @@ def __init__(self, kernel):
def generate_loop_schedules(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None = None) -> Iterator[LoopKernel]:
"""
.. warning::
Expand Down Expand Up @@ -2236,7 +2227,7 @@ def _postprocess_schedule(kernel, callables_table, gen_sched):
def _generate_loop_schedules_inner(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None) -> Iterator[LoopKernel]:
if debug_args is None:
debug_args = {}

Expand Down Expand Up @@ -2337,7 +2328,7 @@ def _generate_loop_schedules_inner(
get_insns_in_topologically_sorted_order(kernel)),
)

schedule_gen_kwargs: Dict[str, Any] = {}
schedule_gen_kwargs: dict[str, Any] = {}

def print_longest_dead_end():
if debug.interactive:
Expand Down Expand Up @@ -2402,7 +2393,7 @@ def print_longest_dead_end():


schedule_cache: WriteOncePersistentDict[
Tuple[LoopKernel, CallablesTable],
tuple[LoopKernel, CallablesTable],
LoopKernel
] = WriteOncePersistentDict(
"loopy-schedule-cache-v4-"+DATA_MODEL_VERSION,
Expand Down
23 changes: 12 additions & 11 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
THE SOFTWARE.
"""

from collections.abc import Hashable, Iterator, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Hashable, Iterator, List, Optional, Sequence, Tuple, TypeVar
from typing import Generic, TypeVar

from immutables import Map

Expand Down Expand Up @@ -70,11 +71,11 @@ class Tree(Generic[NodeT]):
this allocates a new stack frame for each iteration of the operation.
"""

_parent_to_children: Map[NodeT, Tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, Optional[NodeT]]
_parent_to_children: Map[NodeT, tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, NodeT | None]

@staticmethod
def from_root(root: NodeT) -> "Tree[NodeT]":
def from_root(root: NodeT) -> Tree[NodeT]:
    """Build a one-node tree: *root* has no children (empty tuple) and
    no parent (``None`` marks it as the root)."""
    return Tree(Map({root: ()}),
                Map({root: None}))

Expand All @@ -89,7 +90,7 @@ def root(self) -> NodeT:
return guess

@memoize_method
def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:
def ancestors(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns a :class:`tuple` of nodes that are ancestors of *node*.
"""
Expand All @@ -104,15 +105,15 @@ def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:

return (parent,) + self.ancestors(parent)

def parent(self, node: NodeT) -> Optional[NodeT]:
def parent(self, node: NodeT) -> NodeT | None:
    """
    Returns the parent of *node*.
    """
    assert node in self

    # The root node maps to None (see from_root), so callers get None back
    # when *node* is the root.
    return self._child_to_parent[node]

def children(self, node: NodeT) -> Tuple[NodeT, ...]:
def children(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns the children of *node*.
"""
Expand Down Expand Up @@ -150,7 +151,7 @@ def __contains__(self, node: NodeT) -> bool:
"""Return *True* if *node* is a node in the tree."""
return node in self._child_to_parent

def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:
"""
Returns a :class:`Tree` with added node *node* having a parent
*parent*.
Expand All @@ -165,7 +166,7 @@ def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
.set(node, ())),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
"""
Returns a copy of *self* with *node* replaced with *new_node*.
"""
Expand Down Expand Up @@ -207,7 +208,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
return Tree(parent_to_children_mut.finish(),
child_to_parent_mut.finish())

def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]":
def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
"""
Returns a copy of *self* with node *node* as a child of *new_parent*.
"""
Expand Down Expand Up @@ -262,7 +263,7 @@ def __str__(self) -> str:
├── D
└── E
"""
def rec(node: NodeT) -> List[str]:
def rec(node: NodeT) -> list[str]:
children_result = [rec(c) for c in self.children(node)]

def post_process_non_last_child(children: Sequence[str]) -> list[str]:
Expand Down
6 changes: 3 additions & 3 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
Set,
Tuple,
Union,
cast,
)

from immutables import Map

from pymbolic import var
from pymbolic import Variable, var
from pytools.codegen import CodeGenerator, Indentation
from pytools.py_codegen import PythonFunctionGenerator

Expand Down Expand Up @@ -260,7 +261,7 @@ def generate_integer_arg_finding_from_array_data(
unknown_var, = deps
order_to_unknown_to_equations \
.setdefault(eqn.order, {}) \
.setdefault(unknown_var.name, []) \
.setdefault(cast(Variable, unknown_var).name, []) \
.append((eqn))
else:
# Zero deps: nothing to determine, forget about it.
Expand All @@ -274,7 +275,6 @@ def generate_integer_arg_finding_from_array_data(
# {{{ generate arg finding code

from pymbolic.algorithm import solve_affine_equations_for
from pymbolic.primitives import Variable
from pytools.codegen import CodeGenerator

gen("# {{{ find integer arguments from array data")
Expand Down
Loading

0 comments on commit 123a457

Please sign in to comment.