Skip to content

Commit

Permalink
Add filtered back projection for 2D projector (#558)
Browse files Browse the repository at this point in the history
* Add filtered back projection for 2D projector

* Update change summary

* Docstring fixes

* Resolve errors in jitting method

* Some improvements

* Clean up

* Improve tests

* Improve mask mechanism

* Improve docs
  • Loading branch information
bwohlberg authored Oct 18, 2024
1 parent d3193ea commit 0dd2f57
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 8 deletions.
7 changes: 5 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
• Rename integrated 2D X-ray transform class to
``linop.xray.XRayTransform2D`` and add filtered back projection method
``fbp``.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
Expand All @@ -17,8 +20,8 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31.
• Support ``flax`` versions 0.8.0 to 0.8.3.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33.
• Support ``flax`` versions 0.8.0 to 0.9.0.



Expand Down
8 changes: 8 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}

@Book {kak-1988-principles,
author = {Avinash C. Kak and Malcolm Slaney},
title = {Principles of Computerized Tomographic Imaging},
publisher = {IEEE Press},
year = 1988
}

@TechReport {kamilov-2016-minimizing,
author = {Ulugbek S. Kamilov},
title = {Minimizing Isotropic Total Variation without
Expand Down Expand Up @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}


@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
Expand Down
80 changes: 74 additions & 6 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def __init__(
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
dx: Image pixel side length in x- and y-direction. Must be
set so that the width of a projected pixel is never
larger than 1.0. By default, [:math:`\sqrt{2}/2`,
:math:`\sqrt{2}/2`].
dx: Image pixel side length in x- and y-direction (axis 0 and
1 respectively). Must be set so that the width of a
projected pixel is never larger than 1.0. By default,
[:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
Expand Down Expand Up @@ -114,6 +114,9 @@ def __init__(
self.y0 = y0
self.dy = 1.0

self.fbp_filter: Optional[snp.Array] = None
self.fbp_mask: Optional[snp.Array] = None

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
Expand All @@ -139,6 +142,71 @@ def back_project(self, y: ArrayLike) -> snp.Array:
"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

def fbp(self, y: ArrayLike) -> snp.Array:
r"""Compute filtered back projection (FBP) inverse of projection.
Compute the filtered back projection inverse by filtering each
row of the sinogram with the filter defined in (61) in
:cite:`kak-1988-principles` and then back projecting. The
projection angles are assumed to be evenly spaced in
:math:`[0, \pi)`; reconstruction quality may be poor if
this assumption is violated. Poor quality reconstructions should
also be expected when `dx[0]` and `dx[1]` are not equal.
Args:
y: Input projection, (num_angles, N).
Returns:
FBP inverse of projection.
"""
N = y.shape[1]

if self.fbp_filter is None:
nvec = jnp.arange(N) - (N - 1) // 2
self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)

if self.fbp_mask is None:
unit_sino = jnp.ones(self.output_shape, dtype=np.float32)
# Threshold is multiplied by 0.99... fudge factor to account for numerical errors
# in back projection.
self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore

# Apply ramp filter in the frequency domain, padding to avoid
# boundary effects
h = self.fbp_filter
hf = jnp.fft.fft(h, n=2 * N - 1, axis=1)
yf = jnp.fft.fft(y, n=2 * N - 1, axis=1)
hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[
:, (N - 1) // 2 : -(N - 1) // 2
].real.astype(jnp.float32)

x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore
return x

@staticmethod
def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array:
"""Compute coefficients of ramp filter used in FBP.
Compute coefficients of ramp filter used in FBP, as defined in
(61) in :cite:`kak-1988-principles`.
Args:
x: Sampling locations at which to compute filter coefficients.
tau: Sampling rate.
Returns:
Spatial-domain coefficients of ramp filter.
"""
# The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0)
# is included to avoid division by zero warnings when x == 1
# since np.where evaluates all values for both True and False
# branches.
return jnp.where(
x == 0,
1.0 / (4.0 * tau**2),
jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0),
)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
Expand Down Expand Up @@ -179,7 +247,7 @@ def _project(
@partial(jax.jit, static_argnames=["nx"])
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
) -> snp.Array:
r"""Compute X-ray back projection.
Args:
Expand Down Expand Up @@ -361,7 +429,7 @@ def _project_single(
return proj

@staticmethod
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike:
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:
r"""
Args:
proj: Input (set of) projection(s).
Expand Down
32 changes: 32 additions & 0 deletions scico/test/linop/xray/test_xray_2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import numpy as np

import jax
import jax.numpy as jnp

import pytest

import scico
import scico.linop
from scico.linop.xray import XRayTransform2D
from scico.metric import psnr


@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -81,3 +83,33 @@ def test_matched_adjoint():
angles = np.linspace(0, np.pi, n_projection, endpoint=False)
A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)
assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)


@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)])
@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0])
def test_fbp(dx, det_count_factor):
N = 256
x_gt = np.zeros((N, N), dtype=np.float32)
N4 = N // 4
x_gt[N4:-N4, N4:-N4] = 1.0

det_count = int(det_count_factor * N)
n_proj = 360
angles = np.linspace(0, np.pi, n_proj, endpoint=False)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx)
y = A(x_gt)
x_fbp = A.fbp(y)
assert psnr(x_gt, x_fbp) > 28


def test_fbp_jit():
N = 64
x_gt = np.ones((N, N), dtype=np.float32)

det_count = N
n_proj = 90
angles = np.linspace(0, np.pi, n_proj, endpoint=False)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count)
y = A(x_gt)
fbp = jax.jit(A.fbp)
x_fbp = fbp(y)

0 comments on commit 0dd2f57

Please sign in to comment.