diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index deacce792..63a11becb 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -8,8 +8,6 @@ """Circular convolution linear operators.""" import math -import operator -from functools import partial from typing import Optional, Sequence, Tuple, Union import numpy as np @@ -205,7 +203,7 @@ def _adj(self, x: snp.Array) -> snp.Array: # type: ignore H_adj_x = H_adj_x.real return H_adj_x - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.ndims != other.ndims: raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") @@ -218,7 +216,7 @@ def __add__(self, other): h_is_dft=True, ) - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.ndims != other.ndims: raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index 01f8789de..b8adcdfe4 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -12,9 +12,6 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial - import numpy as np from jax.dtypes import result_type @@ -85,7 +82,7 @@ def __init__( def _eval(self, x: snp.Array) -> snp.Array: return convolve(x, self.h, mode=self.mode) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -102,7 +99,7 @@ def __add__(self, other): raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -216,7 +213,7 @@ def __init__( def _eval(self, h: snp.Array) -> snp.Array: return convolve(self.x, h, mode=self.mode) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -231,7 +228,7 @@ def __add__(self, other): ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index f9ef22a47..393d172da 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -11,8 +11,6 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial from typing import Optional, Union import scico.numpy as snp @@ -101,13 +99,13 @@ def gram_op(self) -> Diagonal: """ return Diagonal(diagonal=self.diagonal.conj() * self.diagonal) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.diagonal.shape == other.diagonal.shape: return Diagonal(diagonal=self.diagonal + other.diagonal) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.diagonal.shape == other.diagonal.shape: return Diagonal(diagonal=self.diagonal - other.diagonal) @@ -205,7 +203,7 @@ def gram_op(self) -> ScaledIdentity: input_dtype=self.input_dtype, ) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.input_shape == other.input_shape: return ScaledIdentity( @@ -215,7 +213,7 @@ def __add__(self, other): ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.input_shape == other.input_shape: return ScaledIdentity( diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index a0cca0ee1..344d27a4a 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -11,8 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial, wraps +from functools import wraps from typing import Callable, Optional, Union import numpy as np @@ -29,36 +28,55 @@ from scico.typing import BlockShape, DType, Shape -def _wrap_add_sub(func: Callable, op: Callable) -> Callable: - r"""Wrapper function for defining `__add__`, `__sub__`. +def _wrap_add_sub(func: Callable) -> Callable: + r"""Wrapper function for defining `__add__` and `__sub__`. - Wrapper function for defining `__add__`,` __sub__` between + Wrapper function for defining `__add__` and ` __sub__` between :class:`LinearOperator` and derived classes. Operations between :class:`LinearOperator` and :class:`.Operator` types are also supported. - Handles shape checking and dispatching based on operand types: - - - If one of the two operands is an :class:`.Operator`, an - :class:`.Operator` is returned. - - If both operands are :class:`LinearOperator` of different types, - a generic :class:`LinearOperator` is returned. - - If both operands are :class:`LinearOperator` of the same type, a - special constructor can be called + Handles shape checking and function dispatch based on types of + operands `a` and `b` in the call `func(a, b)`. Note that `func` + will always be a method of the type of `a`, and since this wrapper + should only be applied within :class:`LinearOperator` or derived + classes, we can assume that `a` is always an instance of + :class:`LinearOperator`. The general rule for dispatch is that the + `__add__` or `__sub__` operator of the nearest common base class + of `a` and `b` should be called. If `b` is derived from `a`, this + entails using the operator defined in the class of `a`, and + vice-versa. If one of the operands is not a descendant of the other + in the class hierarchy, then it is assumed that their common base + class is either :class:`.Operator` or :class:`LinearOperator`, + depending on the type of `b`. + + - If `b` is not an instance of :class:`.Operator`, a :exc:`TypeError` + is raised. + - If the shapes of `a` and `b` do not match, a :exc:`ValueError` is + raised. + - If `b` is an instance of the type of `a` then `func(a, b)` is + called where `func` is the argument of this wrapper, i.e. + the unwrapped function defined in the class of `a`. + - If `a` is an instance of the type of `b` then `func(a, b)` is + called where `func` is the unwrapped function defined in the class + of `b`. + - If `b` is a :class:`LinearOperator` then `func(a, b)` is called + where `func` is the operator defined in :class:`LinearOperator`. + - Othwerwise, `func(a, b)` is called where `func` is the operator + defined in :class:`.Operator`. Args: func: should be either `.__add__` or `.__sub__`. - op: functional equivalent of func, ex. op.add for func = - `__add__`. + + Returns: + Wrapped version of `func`. Raises: - ValueError: If the shape of both operators does not match. + ValueError: If the shapes of two operators do not match. TypeError: If one of the two operands is not an :class:`.Operator` or :class:`LinearOperator`. """ - # https://stackoverflow.com/a/58290475 - @wraps(func) def wrapper( a: LinearOperator, b: Union[Operator, LinearOperator] @@ -66,21 +84,33 @@ def wrapper( if isinstance(b, Operator): if a.shape == b.shape: if isinstance(b, type(a)): - # same type of linop, e.g. convolution can have special - # behavior (see Conv2d.__add__) + # b is an instance of the class of a: call the unwrapped operator + # defined in the class of a, which is the func argument of this + # wrapper return func(a, b) if isinstance(a, type(b)): - # same type of linop, but with operands reversed from case above + # a is an instance of class b: call the unwrapped operator + # defined in the class of b. A test is required because + # the operators defined in Operator and non-LinearOperator + # derived classes are not wrapped. if hasattr(getattr(type(b), func.__name__), "_unwrapped"): uwfunc = getattr(type(b), func.__name__)._unwrapped else: uwfunc = getattr(type(b), func.__name__) return uwfunc(a, b) - if isinstance(a, LinearOperator) and isinstance(b, LinearOperator): + # The most general approach here would be to automatically determine + # the nearest common ancestor of the classes of a and b (e.g. as + # discussed in https://stackoverflow.com/a/58290475 ), but the + # simpler approach adopted here is to just assume that the common + # base of two classes that do not have an ancestor-descendant + # relationship is either Operator or LinearOperator. + if isinstance(b, LinearOperator): # LinearOperator + LinearOperator -> LinearOperator uwfunc = getattr(LinearOperator, func.__name__)._unwrapped return uwfunc(a, b) - # LinearOperator + Operator -> Operator + # LinearOperator + Operator -> Operator (access to the function + # definition differs from that for LinearOperator because + # Operator __add__ and __sub__ are not wrapped) uwfunc = getattr(Operator, func.__name__) return uwfunc(a, b) raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") @@ -178,7 +208,7 @@ def jit(self): self._adj = jax.jit(self._adj) self._gram = jax.jit(self._gram) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): return LinearOperator( input_shape=self.input_shape, @@ -189,7 +219,7 @@ def __add__(self, other): output_dtype=result_type(self.output_dtype, other.output_dtype), ) - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): return LinearOperator( input_shape=self.input_shape, diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 662df0e49..5ae43b85f 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -18,10 +18,10 @@ import numpy as np import jax.numpy as jnp -from jax.dtypes import result_type from jax.typing import ArrayLike import scico.numpy as snp +from scico.operator._operator import Operator from ._diag import Identity from ._linop import LinearOperator @@ -45,17 +45,17 @@ def wrapper(a, b): raise ValueError(f"Shapes {a.matrix_shape} and {b.shape} do not match.") + if isinstance(b, Operator): + if a.shape != b.shape: + raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") + if isinstance(b, LinearOperator): - if a.shape == b.shape: - return LinearOperator( - input_shape=a.input_shape, - output_shape=a.output_shape, - eval_fn=lambda x: op(a(x), b(x)), - input_dtype=a.input_dtype, - output_dtype=result_type(a.output_dtype, b.output_dtype), - ) + uwfunc = getattr(LinearOperator, func.__name__)._unwrapped + return uwfunc(a, b) - raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") + if isinstance(b, Operator): + uwfunc = getattr(Operator, func.__name__) + return uwfunc(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index ea9e5fb70..9a1fe6cbe 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -40,6 +40,9 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable: func: should be either `.__mul__()`, `.__rmul__()`, or `.__truediv__()`. + Returns: + Wrapped version of `func`. + Raises: TypeError: If a binop with the form `binop(Operator, other)` is called and `other` is not a scalar.