diff --git a/docs/source/references.bib b/docs/source/references.bib index 257f2428..e612e36e 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -396,6 +396,13 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Book {kak-1988-principles, + author = {Avinash C. Kak and Malcolm Slaney}, + title = {Principles of Computerized Tomographic Imaging}, + publisher = {IEEE Press}, + year = 1988 +} + @TechReport {kamilov-2016-minimizing, author = {Ulugbek S. Kamilov}, title = {Minimizing Isotropic Total Variation without @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn pages = {3142--3155} } + @Article {zhang-2021-plug, author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 770bf627..127b4e7e 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -124,6 +124,69 @@ def back_project(self, y: ArrayLike) -> snp.Array: """Compute X-ray back projection""" return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) + def fbp(self, y: ArrayLike) -> snp.Array: + """Compute Filter Back Projection inverse of projection. + + Compute the Filter Back Projection inverse by filtering each row + of the sinogram with the filter defined in (61) in + :cite:`kak-1988-principles` and then back projecting. The + projection angles are assumed to be evenly spaced: poor results + may be obtained if this assumption is violated. + + Args: + y: Input projection, (num_angles, N). + + Returns: + Filtered Back Projection inverse of projection. + """ + + N = y.shape[1] + nvec = np.arange(N) - (N - 1) // 2 + dx = np.sqrt(self.dx[0] * self.dx[1]) # type: ignore + h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx) + + # Apply ramp filter in the frequency domain, padding to avoid + # boundary effects + hf = snp.fft.fft(h.reshape(1, -1), n=2 * N - 1, axis=1) + yf = snp.fft.fft(y, n=2 * N - 1, axis=1) + hy = snp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + :, (N - 1) // 2 : -(N - 1) // 2 + ].real.astype(snp.float32) + + x = (snp.pi / y.shape[0]) * self.back_project(hy) + # Mask out the invalid region of the reconstruction + gi, gj = snp.mgrid[: x.shape[0], : x.shape[1]] + x = snp.where( + snp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, + x, + 0.0, + ) + return x + + @staticmethod + def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: + """Compute coefficients of ramp filter used in FBP. + + Compute coefficients of ramp filter used in FBP, as defined in + (61) in :cite:`kak-1988-principles`. + + Args: + x: Sampling locations at which to compute filter coefficients. + tau: Sampling rate. + + Returns: + Spatial-domain coefficients of ramp filter. + """ + # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0) + # is included to avoid division by zero warnings when x == 1 + # since np.where evaluates all values for both True and False + # branches. + return snp.where( + x == 0, + 1.0 / (4.0 * tau**2), + snp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + ) + @staticmethod @partial(jax.jit, static_argnames=["ny"]) def _project( diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index cd7c0dcd..4aab2e92 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -6,6 +6,7 @@ import scico from scico.linop.xray import XRayTransform2D, XRayTransform3D +from scico.metric import psnr @pytest.mark.filterwarnings("error") @@ -71,6 +72,22 @@ def test_apply_adjoint(): assert y.shape[1] == det_count +@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)]) +@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) +def test_fbp(dx, det_count_factor): + N = 256 + x_gt = np.zeros((256, 256), dtype=np.float32) + x_gt[64:-64, 64:-64] = 1.0 + + det_count = int(det_count_factor * N) + n_proj = 360 + angles = np.linspace(0, np.pi, n_proj) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) + y = A(x_gt) + x_fbp = A.fbp(y) + assert psnr(x_gt, x_fbp) > 28 + + def test_3d_scaling(): x = jnp.zeros((4, 4, 1)) x = x.at[1:3, 1:3, 0].set(1.0)