Skip to content

Commit

Permalink
Add filtered back projection for 2D projector
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 5, 2024
1 parent 8dc1a2a commit 4d6c632
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}

@Book {kak-1988-principles,
author = {Avinash C. Kak and Malcolm Slaney},
title = {Principles of Computerized Tomographic Imaging},
publisher = {IEEE Press},
year = 1988
}

@TechReport {kamilov-2016-minimizing,
author = {Ulugbek S. Kamilov},
title = {Minimizing Isotropic Total Variation without
Expand Down Expand Up @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}


@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
Expand Down
63 changes: 63 additions & 0 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,69 @@ def back_project(self, y: ArrayLike) -> snp.Array:
"""Compute X-ray back projection"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

def fbp(self, y: ArrayLike) -> snp.Array:
"""Compute Filter Back Projection inverse of projection.
Compute the Filter Back Projection inverse by filtering each row
of the sinogram with the filter defined in (61) in
:cite:`kak-1988-principles` and then back projecting. The
projection angles are assumed to be evenly spaced: poor results
may be obtained if this assumption is violated.
Args:
y: Input projection, (num_angles, N).
Returns:
Filtered Back Projection inverse of projection.
"""

N = y.shape[1]
nvec = np.arange(N) - (N - 1) // 2
dx = np.sqrt(self.dx[0] * self.dx[1]) # type: ignore
h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx)

# Apply ramp filter in the frequency domain, padding to avoid
# boundary effects
hf = snp.fft.fft(h.reshape(1, -1), n=2 * N - 1, axis=1)
yf = snp.fft.fft(y, n=2 * N - 1, axis=1)
hy = snp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[
:, (N - 1) // 2 : -(N - 1) // 2
].real.astype(snp.float32)

x = (snp.pi / y.shape[0]) * self.back_project(hy)
# Mask out the invalid region of the reconstruction
gi, gj = snp.mgrid[: x.shape[0], : x.shape[1]]
x = snp.where(
snp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2,
x,
0.0,
)
return x

@staticmethod
def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array:
"""Compute coefficients of ramp filter used in FBP.
Compute coefficients of ramp filter used in FBP, as defined in
(61) in :cite:`kak-1988-principles`.
Args:
x: Sampling locations at which to compute filter coefficients.
tau: Sampling rate.
Returns:
Spatial-domain coefficients of ramp filter.
"""
# The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0)
# is included to avoid division by zero warnings when x == 1
# since np.where evaluates all values for both True and False
# branches.
return snp.where(
x == 0,
1.0 / (4.0 * tau**2),
snp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0),
)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
Expand Down
17 changes: 17 additions & 0 deletions scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import scico
from scico.linop.xray import XRayTransform2D, XRayTransform3D
from scico.metric import psnr


@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -71,6 +72,22 @@ def test_apply_adjoint():
assert y.shape[1] == det_count


@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)])
@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0])
def test_fbp(dx, det_count_factor):
N = 256
x_gt = np.zeros((256, 256), dtype=np.float32)
x_gt[64:-64, 64:-64] = 1.0

det_count = int(det_count_factor * N)
n_proj = 360
angles = np.linspace(0, np.pi, n_proj)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx)
y = A(x_gt)
x_fbp = A.fbp(y)
assert psnr(x_gt, x_fbp) > 28


def test_3d_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)
Expand Down

0 comments on commit 4d6c632

Please sign in to comment.