Commit 08f2a63: Fix typing errors
bwohlberg committed Jun 27, 2024
1 parent: 7095558 · commit: 08f2a63
Showing 1 changed file with 19 additions and 12 deletions.
scico/linop/_grad.py: 31 changes (19 additions, 12 deletions)
@@ -12,18 +12,18 @@
 # see https://www.python.org/dev/peps/pep-0563/
 from __future__ import annotations
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import numpy as np
 
 import scico.numpy as snp
 from scico.numpy import Array, BlockArray
-from scico.typing import DType, Shape
+from scico.typing import BlockShape, DType, Shape
 
 from ._linop import LinearOperator
 
 
-def diffstack(x, axis=None):
+def diffstack(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
     """Compute the discrete difference along multiple axes.
 
     Apply :func:`snp.diff` along multiple axes, stacking the results on
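
A note for readers of this diff: in Python's typing module, Tuple[int]
denotes a tuple of exactly one int, so the original annotations rejected
multi-axis arguments such as axes=(0, 1), while Tuple[int, ...] admits an
int tuple of any length. A minimal sketch (illustrative, not part of the
commit) of how mypy treats the two:

    # Tuple[int] admits only 1-tuples; Tuple[int, ...] admits any length.
    from typing import Optional, Tuple

    def f(axes: Optional[Tuple[int]] = None) -> None: ...
    def g(axes: Optional[Tuple[int, ...]] = None) -> None: ...

    f((0, 1))  # mypy error: expected Tuple[int], got Tuple[int, int]
    g((0, 1))  # accepted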
@@ -81,8 +81,8 @@ class ProjectedGradient(LinearOperator):
     def __init__(
         self,
         input_shape: Shape,
-        axes: Optional[Tuple[int]] = None,
-        coord: Optional[Tuple[Union[Array, BlockArray]]] = None,
+        axes: Optional[Tuple[int, ...]] = None,
+        coord: Optional[Sequence[Union[Array, BlockArray]]] = None,
         cdiff: bool = False,
         input_dtype: DType = np.float32,
         jit: bool = True,
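
The coord annotation switches from Tuple[Union[...]] (a 1-tuple, per the
note above) to Sequence[Union[...]], which accepts lists as well as
tuples, of any length. A minimal sketch (illustrative, not from the
commit):

    from typing import Sequence

    def f(coord: Sequence[float]) -> float:
        return sum(coord)

    f([0.5, 1.5])  # list accepted
    f((0.5, 1.5))  # tuple accepted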
@@ -133,6 +133,7 @@ def __init__(
                 f"len(input_shape)={len(input_shape)}."
             )
         self.axes = axes
+        output_shape: Union[Shape, BlockShape]
         if coord is None:
             # If coord is None, output shape is determined by number of axes.
             if len(self.axes) == 1:
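
The added line is a bare variable annotation: declaring output_shape as
Union[Shape, BlockShape] up front lets mypy accept assignments of either
shape type in the branches that follow, rather than inferring a narrower
type from the first assignment. A minimal sketch, with simplified
stand-ins for the scico.typing aliases:

    from typing import Tuple, Union

    Shape = Tuple[int, ...]                   # assumed simplification
    BlockShape = Tuple[Tuple[int, ...], ...]  # assumed simplification

    def out_shape(block: bool) -> Union[Shape, BlockShape]:
        output_shape: Union[Shape, BlockShape]
        if block:
            output_shape = ((2, 3), (2, 3))
        else:
            output_shape = (2, 3)
        return output_shape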
@@ -199,8 +200,8 @@ class PolarGradient(ProjectedGradient):
     def __init__(
         self,
         input_shape: Shape,
-        axes: Optional[Tuple[int]] = None,
-        center: Optional[Union[Tuple[int], Array]] = None,
+        axes: Optional[Tuple[int, ...]] = None,
+        center: Optional[Union[Tuple[int, ...], Array]] = None,
         angular: bool = True,
         radial: bool = True,
         cdiff: bool = False,
@@ -246,6 +247,8 @@ def __init__(
         if center is None:
             center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2
         else:
+            if isinstance(center, (tuple, list)):
+                center = snp.array(center)
             center = center.astype(real_input_dtype)
         end = snp.array(axes_shape, dtype=real_input_dtype) - center
         g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]]
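
Since center may now be a tuple per the widened annotation, and tuples
have no astype method, the added isinstance check converts tuple or list
input to an array first. A standalone sketch of the pattern (numpy used
here for illustration; the commit uses snp):

    import numpy as np

    def normalize_center(center, dtype=np.float32):
        # tuples/lists lack .astype, so convert to an array first
        if isinstance(center, (tuple, list)):
            center = np.array(center)
        return center.astype(dtype)

    normalize_center((1, 2))      # tuple input
    normalize_center(np.ones(2))  # array input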
@@ -295,8 +298,8 @@ class CylindricalGradient(ProjectedGradient):
     def __init__(
         self,
         input_shape: Shape,
-        axes: Optional[Tuple[int]] = None,
-        center: Optional[Union[Tuple[int], Array]] = None,
+        axes: Optional[Tuple[int, ...]] = None,
+        center: Optional[Union[Tuple[int, ...], Array]] = None,
         angular: bool = True,
         radial: bool = True,
         axial: bool = True,
@@ -349,8 +352,10 @@ def __init__(
         axes_shape = [input_shape[ax] for ax in axes]
         if center is None:
             center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2
-            center = center.at[-1].set(0)
+            center = center.at[-1].set(0)  # type: ignore
         else:
+            if isinstance(center, (tuple, list)):
+                center = snp.array(center)
             center = center.astype(real_input_dtype)
         end = snp.array(axes_shape, dtype=real_input_dtype) - center
         g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]]
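
The "# type: ignore" silences mypy on the JAX functional-update idiom:
jax arrays are immutable, so an element is replaced via
.at[index].set(value), which returns a new array. Presumably mypy cannot
reconcile this update with the parameter's widened Union annotation. A
minimal illustration of the idiom itself:

    import jax.numpy as jnp

    center = (jnp.array([8, 8, 8], dtype=jnp.float32) - 1) / 2
    center = center.at[-1].set(0)  # -> Array([3.5, 3.5, 0.0])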
@@ -420,8 +425,8 @@ class SphericalGradient(ProjectedGradient):
     def __init__(
         self,
         input_shape: Shape,
-        axes: Optional[Tuple[int]] = None,
-        center: Optional[Union[Tuple[int], Array]] = None,
+        axes: Optional[Tuple[int, ...]] = None,
+        center: Optional[Union[Tuple[int, ...], Array]] = None,
         azimuthal: bool = True,
         polar: bool = True,
         radial: bool = True,
@@ -474,6 +479,8 @@ def __init__(
         if center is None:
             center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2
         else:
+            if isinstance(center, (tuple, list)):
+                center = snp.array(center)
             center = center.astype(real_input_dtype)
         end = snp.array(axes_shape, dtype=real_input_dtype) - center
         g0, g1, g2 = snp.ogrid[-center[0] : end[0], -center[1] : end[1], -center[2] : end[2]]
