From ac282860cc29386c890cf8dfed50fe4848c50e9d Mon Sep 17 00:00:00 2001 From: Eivind Fonn Date: Tue, 13 Feb 2024 21:09:08 +0100 Subject: [PATCH] Mypy type safety: round 5 --- splipy/splinemodel.py | 339 +++++++++++++++++++++++++-------------- splipy/splineobject.py | 6 +- splipy/types.py | 11 +- splipy/utils/__init__.py | 22 ++- 4 files changed, 238 insertions(+), 140 deletions(-) diff --git a/splipy/splinemodel.py b/splipy/splinemodel.py index 3d0e9eb..f1ffd79 100644 --- a/splipy/splinemodel.py +++ b/splipy/splinemodel.py @@ -1,24 +1,30 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from collections import Counter, OrderedDict, namedtuple +from dataclasses import dataclass from itertools import chain, product, permutations, islice from operator import itemgetter -from typing import Callable, Dict, List, Tuple, Any, Optional +from typing import Callable, Dict, List, Tuple, Any, Optional, Literal, Sequence, Union, TypeVar, cast, Iterator +from typing_extensions import Self, Unpack import numpy as np +from numpy.typing import NDArray from .splineobject import SplineObject from .utils import check_section, sections, section_from_index, section_to_index, uniquify, is_right_hand from .utils import bisect +from .types import Scalar, FArray, Section, SectionLike, SectionElt, SectionKwargs from . import state -try: - from collections.abc import MutableMapping -except ImportError: - from collections import MutableMapping +from collections.abc import MutableMapping + +IArray = NDArray[np.int_] -def _section_to_index(section): + +def _section_to_index(section: Section) -> tuple[Union[Literal[-1, 0], slice], ...]: """Replace all `None` in `section` with `slice(None)`, so that it works as a numpy array indexing tuple. """ @@ -28,7 +34,11 @@ def _section_to_index(section): face_t = np.dtype([('nodes', int, (4,)), ('owner', int, ()), ('neighbor', int, ()), ('name', object, ())]) -class VertexDict(MutableMapping): +T = TypeVar("T") +G = TypeVar("G", bound=np.generic) + + +class VertexDict(MutableMapping[FArray, T]): """A dictionary where the keys are numpy arrays, and where equality is computed in an approximate sense for floating point numbers. @@ -38,20 +48,19 @@ class VertexDict(MutableMapping): rtol: float atol: float - _keys: List[Optional[np.ndarray]] - _values: List[Any] + _keys: list[Optional[FArray]] + _values: list[Optional[T]] - lut: Dict[Tuple[int, ...], List[Tuple[int, float]]] + lut: dict[tuple[int, ...], list[tuple[int, float]]] - def __init__(self, rtol=1e-5, atol=1e-8): - # List of (key, value) pairs + def __init__(self, rtol: float = 1e-5, atol: float = 1e-8) -> None: self.rtol = rtol self.atol = atol self._keys = [] self._values = [] self.lut = dict() - def _bounds(self, key): + def _bounds(self, key: Scalar) -> tuple[Scalar, Scalar]: if key >= self.atol: return ( (key - self.atol) / (1 + self.rtol), @@ -69,14 +78,14 @@ def _bounds(self, key): (key + self.atol) / (1 - self.rtol), ) - def _candidate(self, key): + def _candidate(self, key: FArray) -> int: """Return the internal index for the first stored mapping that matches the given key. :param numpy.array key: The key to look for :raises KeyError: If the key is not found """ - candidates = None + candidates: Optional[set[int]] = None for coord, k in np.ndenumerate(key): lut = self.lut.setdefault(coord, []) minval, maxval = self._bounds(k) @@ -86,12 +95,14 @@ def _candidate(self, key): candidates = {i for i, _ in lut[lo:hi]} else: candidates &= {i for i, _ in lut[lo:hi]} + + assert candidates is not None for c in candidates: if self._keys[c] is not None: return c raise KeyError(key) - def _insert(self, key, value): + def _insert(self, key: FArray, value: T) -> None: newindex = len(self._values) for coord, v in np.ndenumerate(key): lut = self.lut.setdefault(coord, []) @@ -99,7 +110,7 @@ def _insert(self, key, value): self._keys.append(key) self._values.append(value) - def __setitem__(self, key, value): + def __setitem__(self, key: FArray, value: T) -> None: """Assign a key to a value.""" try: c = self._candidate(key) @@ -107,15 +118,15 @@ def __setitem__(self, key, value): except KeyError: self._insert(key, value) - def __getitem__(self, key): + def __getitem__(self, key: FArray) -> T: """Gets the value assigned to a key. :raises KeyError: If the key is not found """ c = self._candidate(key) - return self._values[c] + return cast(T, self._values[c]) - def __delitem__(self, key): + def __delitem__(self, key: FArray) -> None: """Deletes an assignment.""" try: i = self._candidate(key) @@ -124,18 +135,16 @@ def __delitem__(self, key): self._keys[i] = None self._values[i] = None - def __iter__(self): + def __iter__(self) -> Iterator[FArray]: """Iterate over all keys. .. note:: This generates all the stored keys, not all matching keys. """ - yield from self._keys - - def items(self): - """Return a list of key, value pairs.""" - yield from self._values + for key in self._keys: + if key is not None: + yield key - def __len__(self): + def __len__(self) -> int: """Returns the number of stored assignments.""" return len(self._values) @@ -147,13 +156,15 @@ class OrientationError(RuntimeError): """ pass + class TwinError(RuntimeError): """A `TwinError` is raised when two objects with identical interfaces are added, but different interiors. """ pass -class Orientation(object): + +class Orientation: """An `Orientation` represents a mapping between two coordinate systems: the *reference* system and the *actual* or *mapped* system. @@ -165,7 +176,11 @@ class Orientation(object): direction `d` *in the reference system* should be reversed. """ - def __init__(self, perm, flip): + perm: tuple[int, ...] + perm_inv: tuple[int, ...] + flip: tuple[bool, ...] + + def __init__(self, perm: tuple[int, ...], flip: tuple[bool, ...]): """Initialize an Orientation object. .. warning:: This constructor is for internal use. Use @@ -177,7 +192,7 @@ def __init__(self, perm, flip): self.perm_inv = tuple(perm.index(d) for d in range(len(perm))) @classmethod - def compute(cls, cpa, cpb=None): + def compute(cls, cpa: SplineObject, cpb: Optional[SplineObject] = None) -> Self: """Compute and return a new orientation object representing the mapping between `cpa` (the reference system) and `cpb` (the mapped system). @@ -194,8 +209,10 @@ def compute(cls, cpa, cpb=None): # Return the identity orientation if no cpb if cpb is None: - return cls(tuple(range(pardim)), - tuple(False for _ in range(pardim))) + return cls( + tuple(range(pardim)), + (False,) * pardim, + ) # Deal with the easy cases: dimension mismatch, and # comparing the shapes as multisets @@ -243,10 +260,10 @@ def compute(cls, cpa, cpb=None): raise OrientationError("Non-matching objects") @property - def pardim(self): + def pardim(self) -> int: return len(self.perm) - def __mul__(self, other): + def __mul__(self, other: Orientation) -> Orientation: """Compose two mappings. If `ort_left` maps system `A` (reference) to system `B`, and @@ -261,13 +278,13 @@ def __mul__(self, other): return Orientation(perm, flip) - def map_array(self, array): + def map_array(self, array: NDArray[G]) -> NDArray[G]: """Map an array in the mapped system to the reference system.""" array = array.transpose(*self.perm) flips = tuple(slice(None, None, -1) if f else slice(None) for f in self.flip) return array[flips] - def map_section(self, section): + def map_section(self, section: SectionLike) -> Section: """Map a section in the mapped system to the reference system. The input is a section tuple as described in @@ -277,7 +294,7 @@ def map_section(self, section): """ permuted = tuple(section[d] for d in self.perm) - flipped = () + flipped: Section = () for s, f, in zip(permuted, self.flip): # Flipping only applies to indexed directions, not variable ones if f and s is not None: @@ -287,7 +304,7 @@ def map_section(self, section): return flipped - def view_section(self, section): + def view_section(self, section: Section) -> Self: """Reduce a mapping to a lower dimension. The input is a section tuple as described in @@ -315,7 +332,7 @@ def view_section(self, section): return self.__class__(new_perm, new_flip) @property - def ifem_format(self): + def ifem_format(self) -> int: """Compute the orientation in IFEM format. For one-dimensional objects, this is a single binary digit indicating @@ -349,7 +366,7 @@ def ifem_format(self): ) -class TopologicalNode(object): +class TopologicalNode: """A `TopologicalNode` object refers to a single, persistent point in the topological graph. It represents some object of dimension `d` (that is, a point, an edge, etc.) and it has references to all the other objects it @@ -375,7 +392,17 @@ class TopologicalNode(object): of any kind. """ - def __init__(self, obj, lower_nodes, index): + obj: SplineObject + lower_nodes: list[tuple[TopologicalNode, ...]] + higher_nodes: dict[int, list[TopologicalNode]] + index: int + owner: Optional[TopologicalNode] + + name: Optional[str] + cell_numbers: Optional[IArray] + cp_numbers: Optional[IArray] + + def __init__(self, obj: SplineObject, lower_nodes: list[tuple[TopologicalNode, ...]], index: int) -> None: """Initialize a `TopologicalNode` object associated with the given `SplineObject` and lower order nodes. @@ -404,26 +431,26 @@ def __init__(self, obj, lower_nodes, index): node._transfer_ownership(self) @property - def pardim(self): + def pardim(self) -> int: return self.obj.pardim @property - def nhigher(self): + def nhigher(self) -> int: return len(self.higher_nodes[self.pardim + 1]) @property - def super_owner(self): + def super_owner(self) -> TopologicalNode: """Return the highest owning node.""" owner = self while owner.owner is not None: owner = owner.owner return owner - def assign_higher(self, node): + def assign_higher(self, node: TopologicalNode) -> None: """Add a link to a node of higher dimension.""" self.higher_nodes.setdefault(node.pardim, list()).append(node) - def view(self, other_obj=None): + def view(self, other_obj: Optional[SplineObject] = None) -> NodeView: """Return a `NodeView` object of this node. The returned view has an orientation that matches that of the input @@ -439,7 +466,7 @@ def view(self, other_obj=None): orientation = Orientation.compute(self.obj) return NodeView(self, orientation) - def _transfer_ownership(self, new_owner): + def _transfer_ownership(self, new_owner: TopologicalNode) -> None: """Transfers ownership of this node to a new owner. This operation is transitive, so all child nodes owned by this node, or who are owner-less will also be transferred. @@ -453,7 +480,7 @@ def _transfer_ownership(self, new_owner): if child.owner is self or child.owner is None: child._transfer_ownership(new_owner) - def generate_cp_numbers(self, start=0): + def generate_cp_numbers(self, start: int = 0) -> int: """Generate a control point numbering starting at `start`. Return the next unused index.""" assert self.owner is None @@ -476,7 +503,7 @@ def generate_cp_numbers(self, start=0): self.assign_cp_numbers(numbers) return start + nowned - def assign_cp_numbers(self, numbers): + def assign_cp_numbers(self, numbers: IArray) -> None: """Directly assign control point numbers.""" self.cp_numbers = numbers @@ -488,17 +515,19 @@ def assign_cp_numbers(self, numbers): # orientations not matching up. node.assign_cp_numbers(numbers[_section_to_index(section)]) - def read_cp_numbers(self): + def read_cp_numbers(self) -> None: """Read control point numbers for unowned control points from child nodes.""" + assert self.cp_numbers is not None for node, section in zip(self.lower_nodes[-1], sections(self.pardim, self.pardim-1)): + assert node.cp_numbers is not None if node.owner is not self: # The two sections may not agree on orientation, so we fix this here. - ori = Orientation.compute(self.obj.section(*section), node.obj) + ori = Orientation.compute(self.obj.section(*section, unwrap_points=False), node.obj) self.cp_numbers[_section_to_index(section)] = ori.map_array(node.cp_numbers) assert (self.cp_numbers != -1).all() - def generate_cell_numbers(self, start=0): + def generate_cell_numbers(self, start: int = 0) -> int: """Generate a cell numbering starting at `start`. Return the next unused index.""" assert self.owner is None @@ -506,17 +535,25 @@ def generate_cell_numbers(self, start=0): shape = [len(kvec) - 1 for kvec in self.obj.knots()] nelems = np.prod(shape) self.cell_numbers = np.reshape(np.arange(start, start + nelems, dtype=int), shape) - return start + nelems + return start + int(nelems) - def faces(self): + def faces(self) -> list[NDArray]: """Return all faces owned by this node, as a list of numpy arrays with dtype `face_t`.""" assert self.pardim == 3 assert self.obj.order() == (2,2,2) + assert self.cp_numbers is not None + assert self.cell_numbers is not None + shape = [len(kvec) - 1 for kvec in self.obj.knots()] ncells = np.prod(shape) retval = [] - def mkindex(dim, z, a, b): + def mkindex( + dim: int, + z: Union[slice, int], + a: Union[slice, int], + b: Union[slice, int], + ) -> tuple[Union[slice, int], ...]: rval = [a, b] if dim != 1 else [b, a] rval.insert(dim, z) return tuple(rval) @@ -568,13 +605,14 @@ def mkindex(dim, z, a, b): faces['neighbor'] = -1 else: neighbor = next(c for c in bdnode.higher_nodes[3] if c is not self) + assert neighbor.cell_numbers is not None # Find out which face the interface is as numbered from the neighbor's perspective nb_index = neighbor.lower_nodes[2].index(bdnode) # Get the spline object on that interface as oriented from the neighbor's perspective nb_sec = section_from_index(3, 2, nb_index) - nb_obj = neighbor.obj.section(*nb_sec) + nb_obj = neighbor.obj.section(*nb_sec, unwrap_points=False) # Compute the relative orientation ori = Orientation.compute(bdnode.obj, nb_obj) @@ -590,7 +628,7 @@ def mkindex(dim, z, a, b): return retval -class NodeView(object): +class NodeView: """A `NodeView` object refers to a *view* to a point in the topological graph. It is composed of a node (:class:`splipy.SplineModel.TopologicalNode`) and an orientation (:class:`splipy.SplineModel.Orienation`). @@ -599,7 +637,10 @@ class NodeView(object): persistent. """ - def __init__(self, node, orientation=None): + node: TopologicalNode + orientation: Optional[Orientation] + + def __init__(self, node: TopologicalNode, orientation: Optional[Orientation] = None) -> None: """Initialize a `NodeView` object with the given node and orientation. .. warning:: This constructor is for internal use. @@ -608,18 +649,18 @@ def __init__(self, node, orientation=None): self.orientation = orientation @property - def pardim(self): + def pardim(self) -> int: return self.node.pardim @property - def name(self): + def name(self) -> Optional[str]: return self.node.name @name.setter - def name(self, value): + def name(self, value: str) -> None: self.node.name = value - def section(self, *args, **kwargs): + def section(self, *args: SectionElt, **kwargs: Unpack[SectionKwargs]) -> NodeView: """Return a section. See :func:`splipy.SplineObject.section` for more details on the input arguments. @@ -631,6 +672,7 @@ def section(self, *args, **kwargs): tgt_dim = sum(1 for s in section if s is None) # The index of the section in the reference system + assert self.orientation is not None ref_idx = section_to_index(self.orientation.map_section(section)) # The underlying node @@ -643,40 +685,48 @@ def section(self, *args, **kwargs): return NodeView(node, ref_ori * my_ori) - def corner(self, i): + def corner(self, i: int) -> NodeView: """Return the i'th corner.""" return self.section(*section_from_index(self.pardim, 0, i)) @property - def corners(self): + def corners(self) -> tuple[NodeView, ...]: """A tuple of all corners.""" - return tuple(self.section(s) for s in sections(self.pardim, 0)) + return tuple(self.section(*s) for s in sections(self.pardim, 0)) - def edge(self, i): + def edge(self, i: int) -> NodeView: """Return the i'th edge.""" return self.section(*section_from_index(self.pardim, 1, i)) @property - def edges(self): + def edges(self) -> tuple[NodeView, ...]: """A tuple of all edges.""" - return tuple(self.section(s) for s in sections(self.pardim, 1)) + return tuple(self.section(*s) for s in sections(self.pardim, 1)) - def face(self, i): + def face(self, i: int) -> NodeView: """Return the i'th face.""" return self.section(*section_from_index(self.pardim, 2, i)) @property - def faces(self): + def faces(self) -> tuple[NodeView, ...]: """A tuple of all faces.""" - return tuple(self.section(s) for s in sections(self.pardim, 2)) + return tuple(self.section(*s) for s in sections(self.pardim, 2)) -class ObjectCatalogue(object): +class ObjectCatalogue: """An `ObjectCatalogue` maintains a complete topological graph of objects with at most `pardim` parametric directions. """ - def __init__(self, pardim): + pardim: int + count: int + + internal: OrderedDict[tuple[TopologicalNode, ...], list[TopologicalNode]] + + lower: Union[ObjectCatalogue, VertexDict[TopologicalNode]] + callbacks: dict[str, list[Callable[[TopologicalNode], None]]] + + def __init__(self, pardim: int) -> None: """Initialize a catalogue for objects of parametric dimension `pardim`. """ @@ -696,11 +746,11 @@ def __init__(self, pardim): # Callbacks for events self.callbacks = dict() - def add_callback(self, event: str, callback: Callable[[TopologicalNode], None]): + def add_callback(self, event: str, callback: Callable[[TopologicalNode], None]) -> None: """Add a callback function to be called on a given event.""" self.callbacks.setdefault(event, []).append(callback) - def lookup(self, obj, add=False, raise_on_twins=()): + def lookup(self, obj: SplineObject, add: bool = False, raise_on_twins: Sequence[int] = ()) -> NodeView: """Obtain the `NodeView` object corresponding to a given object. If the keyword argument `add` is true, this function may generate one @@ -724,10 +774,12 @@ def lookup(self, obj, add=False, raise_on_twins=()): """ # Pass lower-dimensional objects through to the lower levels if self.pardim > obj.pardim: + assert isinstance(self.lower, ObjectCatalogue) return self.lower.lookup(obj, add=add, raise_on_twins=raise_on_twins) # Special case for points: self.lower is a mapping from array to node if self.pardim == 0: + assert isinstance(self.lower, VertexDict) cps = obj.controlpoints if obj.rational: cps = cps[..., :-1] @@ -740,13 +792,20 @@ def lookup(self, obj, add=False, raise_on_twins=()): return rval return self.lower[cps].view() + assert isinstance(self.lower, ObjectCatalogue) + # Get all nodes of lower dimension (points, vertices, etc.) # This involves a recursive call to self.lower.__call__ lower_nodes = [] for i in range(0, self.pardim): - nodes = tuple(self.lower.lookup(obj.section(*args, unwrap_points=False), add=add, - raise_on_twins=raise_on_twins).node - for args in sections(self.pardim, i)) + nodes = tuple( + self.lower.lookup( + obj.section(*args, unwrap_points=False), + add=add, + raise_on_twins=raise_on_twins + ).node + for args in sections(self.pardim, i) + ) lower_nodes.append(nodes) # Try looking up the lower-order nodes in the internal dictionary, @@ -795,7 +854,7 @@ def lookup(self, obj, add=False, raise_on_twins=()): raise KeyError("No such object found") return self._add(obj, lower_nodes) - def add(self, obj, raise_on_twins=()): + def add(self, obj: SplineObject, raise_on_twins: Sequence[int] = ()) -> NodeView: """Add new nodes to the graph to accommodate the given object, then return the corresponding `NodeView` object. @@ -820,7 +879,7 @@ def add(self, obj, raise_on_twins=()): """ return self.lookup(obj, add=True, raise_on_twins=raise_on_twins) - def _add(self, obj, lower_nodes): + def _add(self, obj: SplineObject, lower_nodes: list[tuple[TopologicalNode, ...]]) -> NodeView: node = TopologicalNode(obj, lower_nodes, index=self.count) self.count += 1 # Assign the new node to each possible permutation of lower-order @@ -836,25 +895,37 @@ def _add(self, obj, lower_nodes): __call__ = add __getitem__ = lookup - def top_nodes(self): + def top_nodes(self) -> list[TopologicalNode]: """Return all nodes of the highest parametric dimension.""" return self.nodes(self.pardim) - def nodes(self, pardim): + def nodes(self, pardim: int) -> list[TopologicalNode]: """Return all nodes of a given parametric dimension.""" if self.pardim == pardim: if self.pardim > 0: return list(uniquify(chain.from_iterable(self.internal.values()))) + assert isinstance(self.lower, VertexDict) return list(uniquify(self.lower.values())) + assert isinstance(self.lower, ObjectCatalogue) return self.lower.nodes(pardim) -# FIXME: This class is unfinished, and right now it doesn't do much other than -# wrap ObjectCatalogue +# TODO: This class is unfinished, and right now it doesn't do much other than wrap ObjectCatalogue +class SplineModel: -class SplineModel(object): + pardim: int + dimension: int + force_right_hand: bool + catalogue: ObjectCatalogue + names: dict[str, SplineObject] - def __init__(self, pardim=3, dimension=3, objs=[], force_right_hand=False): + def __init__( + self, + pardim: int = 3, + dimension: int = 3, + objs: Sequence[SplineObject] = (), + force_right_hand: bool = False, + ) -> None: self.pardim = pardim self.dimension = dimension @@ -866,39 +937,48 @@ def __init__(self, pardim=3, dimension=3, objs=[], force_right_hand=False): self.names = {} self.add(objs) - def add_callback(self, event: str, callback: Callable[[TopologicalNode], None]): - catalogue = self.catalogue + def add_callback(self, event: str, callback: Callable[[TopologicalNode], None]) -> None: + catalogue: Union[ObjectCatalogue, VertexDict] = self.catalogue while isinstance(catalogue, ObjectCatalogue): catalogue.add_callback(event, callback) catalogue = catalogue.lower - def add(self, obj, name=None, raise_on_twins=True): + def add( + self, + obj: Union[SplineObject, Sequence[SplineObject]], + name: Optional[str] = None, + raise_on_twins: Union[bool, Sequence[int]] = True, + ) -> None: + rot: tuple[int, ...] if raise_on_twins is True: - raise_on_twins = tuple(range(self.pardim + 1)) + rot = tuple(range(self.pardim + 1)) elif raise_on_twins is False: - raise_on_twins = () - if isinstance(obj, SplineObject): - obj = [obj] - self._validate(obj) - self._generate(obj, raise_on_twins=raise_on_twins) - if name and isinstance(obj, SplineObject): - self.names[name] = obj - - def __getitem__(self, obj): + rot = () + else: + rot = tuple(raise_on_twins) + objs = [obj] if isinstance(obj, SplineObject) else obj + + self._validate(objs) + self._generate(objs, raise_on_twins=rot) + if name: + for obj in objs: + self.names[name] = obj + + def __getitem__(self, obj: SplineObject) -> NodeView: return self.catalogue[obj] - def boundary(self, name=None): + def boundary(self, name: Optional[str] = None) -> Iterator[TopologicalNode]: for node in self.catalogue.nodes(self.pardim-1): if node.nhigher == 1 and (name is None or name == node.name): yield node - def assign_boundary(self, name): + def assign_boundary(self, name: str) -> None: """Give a name to all unnamed boundary nodes.""" for node in self.boundary(): if node.name is None: node.name = name - def _validate(self, objs): + def _validate(self, objs: Sequence[SplineObject]) -> None: if any(p.dimension != self.dimension for p in objs): raise ValueError("Patches with different dimension added") if any(p.pardim > self.pardim for p in objs): @@ -909,10 +989,10 @@ def _validate(self, objs): indices = ', '.join(map(str, left_inds)) raise ValueError(f"Possibly left-handed patches detected, indexes {indices}") - def _generate(self, objs, **kwargs): + def _generate(self, objs: Sequence[SplineObject], raise_on_twins: Sequence[int]) -> None: for i, p in enumerate(objs): try: - self.catalogue.add(p, **kwargs) + self.catalogue.add(p, raise_on_twins=raise_on_twins) except OrientationError as err: # TODO: Mutating exceptions is fishy. if len(err.args) > 1: @@ -923,7 +1003,7 @@ def _generate(self, objs, **kwargs): ) raise err - def generate_cp_numbers(self): + def generate_cp_numbers(self) -> None: index = 0 for node in self.catalogue.top_nodes(): index = node.generate_cp_numbers(index) @@ -931,49 +1011,61 @@ def generate_cp_numbers(self): for node in self.catalogue.top_nodes(): node.read_cp_numbers() - def generate_cell_numbers(self): + def generate_cell_numbers(self) -> None: index = 0 for node in self.catalogue.top_nodes(): index = node.generate_cell_numbers(index) self.ncells = index - def cps(self): - cps = np.zeros((self.ncps, self.dimension)) + def cps(self) -> FArray: + cps = np.zeros((self.ncps, self.dimension), dtype=float) for node in self.catalogue.top_nodes(): + assert node.cp_numbers is not None indices = node.cp_numbers.reshape(-1) values = node.obj.controlpoints.reshape(-1, self.dimension) cps[indices] = values return cps - def faces(self): + def faces(self) -> NDArray: assert self.pardim == 3 faces = list(chain.from_iterable(node.faces() for node in self.catalogue.top_nodes())) return np.hstack(faces) - def summary(self): - c = self.catalogue + def summary(self) -> None: + c: Union[ObjectCatalogue, VertexDict] = self.catalogue while isinstance(c, ObjectCatalogue): print('Dim {}: {}'.format(c.pardim, len(c.top_nodes()))) c = c.lower - def write_ifem(self, filename): + def write_ifem(self, filename: str) -> None: IFEMWriter(self).write(filename) - -IFEMConnection = namedtuple('IFEMConnection', ['master', 'slave', 'midx', 'sidx', 'orient']) +# TODO: Py310 add slots=True +@dataclass(frozen=True) +class IFEMConnection: + master: int + slave: int + midx: int + sidx: int + orient: int class IFEMWriter: - def __init__(self, model): + model: SplineModel + + nodes: list[TopologicalNode] + node_ids: dict[TopologicalNode, int] + + def __init__(self, model: SplineModel) -> None: self.model = model # List the nodes so that the order is deterministic self.nodes = list(model.catalogue.top_nodes()) self.node_ids = {node: i for i, node in enumerate(self.nodes)} - def connections(self): + def connections(self) -> Iterator[IFEMConnection]: p = self.model.pardim # For every object in the model... @@ -1014,14 +1106,14 @@ def connections(self): orientation = Orientation.compute(node_sub, neigh_sub) yield IFEMConnection( - master = self.node_ids[node] + 1, - slave = self.node_ids[neigh] + 1, - midx = node_sub_idx + 1, - sidx = neigh_sub_idx + 1, - orient = orientation.ifem_format, + master=self.node_ids[node] + 1, + slave=self.node_ids[neigh] + 1, + midx=node_sub_idx + 1, + sidx=neigh_sub_idx + 1, + orient=orientation.ifem_format, ) - def write(self, filename): + def write(self, filename: str) -> None: lines = [ "", "", @@ -1052,11 +1144,12 @@ def write(self, filename): }) for name in names: - entries = {} + entries: dict[int, set[int]] = {} for node in self.model.catalogue.nodes(self.model.pardim - 1): if node.name != name: continue parent = node.owner + assert parent is not None sub_idx = next(idx for idx, sub in enumerate(parent.lower_nodes[self.model.pardim - 1]) if sub is node) entries.setdefault(self.node_ids[parent], set()).add(sub_idx) if entries: diff --git a/splipy/splineobject.py b/splipy/splineobject.py index 0e8c10e..8c97235 100644 --- a/splipy/splineobject.py +++ b/splipy/splineobject.py @@ -12,7 +12,7 @@ from numpy.typing import ArrayLike from .basis import BSplineBasis -from .types import Direction, Scalars, Scalar, FArray, IArray, ScalarOrScalars +from .types import Direction, Scalars, Scalar, FArray, IArray, ScalarOrScalars, SectionElt from .utils import ( reshape, rotation_matrix, is_singleton, ensure_listlike, check_direction, ensure_flatlist, check_section, sections, @@ -452,7 +452,7 @@ def tangent(self, *params, direction = None, above = True, tensor = True): # ty @overload def section( self, - *args: Literal[-1, 0, None], + *args: SectionElt, unwrap_points: Literal[True] = True, **kwargs: Unpack[SectionKwargs] ) -> Union[SplineObject, FArray]: @@ -461,7 +461,7 @@ def section( @overload def section( self, - *args: Literal[-1, 0, None], + *args: SectionElt, unwrap_points: Literal[False], **kwargs: Unpack[SectionKwargs] ) -> SplineObject: diff --git a/splipy/types.py b/splipy/types.py index 1c407f2..db42d1b 100644 --- a/splipy/types.py +++ b/splipy/types.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Union, Literal +from typing import Union, Literal, TypedDict from numpy.typing import NDArray from numpy import float_, int_ @@ -21,3 +21,12 @@ ] ScalarOrScalars = Union[Scalar, Scalars] + +SectionElt = Literal[-1, 0, None] +SectionLike = Sequence[SectionElt] +Section = tuple[SectionElt, ...] + +class SectionKwargs(TypedDict, total=False): + u: SectionElt + v: SectionElt + w: SectionElt diff --git a/splipy/utils/__init__.py b/splipy/utils/__init__.py index 432a805..b8d4065 100644 --- a/splipy/utils/__init__.py +++ b/splipy/utils/__init__.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from itertools import combinations, product from math import atan2, sqrt import numpy as np -from typing import TYPE_CHECKING, SupportsFloat, Literal, TypedDict, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, SupportsFloat, Literal, TypedDict, Sequence, TypeVar, Union, Iterator from typing_extensions import Unpack try: @@ -11,7 +13,7 @@ except ImportError: from collections import Sized -from ..types import Direction, ScalarOrScalars, Scalar +from ..types import Direction, ScalarOrScalars, Scalar, Section, SectionElt, SectionKwargs, SectionLike if TYPE_CHECKING: from ..basis import BSplineBasis @@ -53,7 +55,7 @@ def rotation_matrix(theta, axis): [2*(b*c+a*d), a*a+c*c-b*b-d*d, 2*(c*d-a*b)], [2*(b*d-a*c), 2*(c*d+a*b), a*a+d*d-b*b-c*c]]) -def sections(src_dim, tgt_dim): +def sections(src_dim, tgt_dim) -> Iterator[Section]: """Generate all boundary sections from a source dimension to a target dimension. For example, `sections(3,1)` generates all edges on a volume. @@ -69,9 +71,9 @@ def sections(src_dim, tgt_dim): args = [None] * src_dim for f, i in zip(fixed, indices[::-1]): args[f] = i - yield args + yield tuple(args) -def section_from_index(src_dim, tgt_dim, i): +def section_from_index(src_dim, tgt_dim, i) -> Section: """Return the i'th section from a source dimension to a target dimension. See :func:`splipy.Utils.sections` for more information. @@ -80,7 +82,7 @@ def section_from_index(src_dim, tgt_dim, i): if i == j: return s -def section_to_index(section): +def section_to_index(section: SectionLike) -> int: """Return the index corresponding to a section.""" src_dim = len(section) tgt_dim = sum(1 for s in section if s is None) @@ -88,13 +90,7 @@ def section_to_index(section): if tuple(section) == tuple(t): return i - -class SectionKwargs(TypedDict, total=False): - u: Literal[-1, 0, None] - v: Literal[-1, 0, None] - w: Literal[-1, 0, None] - -def check_section(*args: Literal[-1, 0, None], pardim: int = 0, **kwargs: Unpack[SectionKwargs]) -> tuple[Literal[-1, 0, None]]: +def check_section(*args: SectionElt, pardim: int = 0, **kwargs: Unpack[SectionKwargs]) -> Section: """check_section(u, v, ...) Parse arguments and return a section spec.