From f8fd48ee40f5c3eea57e7e82d5c4a7a73584128d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 23 Oct 2024 20:23:07 -0600 Subject: [PATCH] Clean up CT examples (#563) * Improve compatibility between companion scripts * Clean up * Docstring fix * Cosmetic improvements * Fix angular sampling * Clean up * Fix angular sampling * Typo fix * Update submodule * Resolve warning when 64 bit float enabled * Update submodule --------- Co-authored-by: Brendt Wohlberg --- data | 2 +- examples/scripts/ct_astra_3d_tv_admm.py | 2 +- examples/scripts/ct_astra_3d_tv_padmm.py | 2 +- examples/scripts/ct_astra_noreg_pcg.py | 2 +- examples/scripts/ct_astra_tv_admm.py | 13 ++++---- examples/scripts/ct_astra_weighted_tv_admm.py | 10 ++---- examples/scripts/ct_modl_train_foam2.py | 2 +- examples/scripts/ct_multi_tv_admm.py | 7 ++-- examples/scripts/ct_odp_train_foam2.py | 2 +- examples/scripts/ct_svmbir_tv_multi.py | 10 +++--- examples/scripts/ct_tv_admm.py | 32 ++++++++++++------- scico/test/flax/test_inv.py | 4 +-- 12 files changed, 44 insertions(+), 44 deletions(-) diff --git a/data b/data index 1ceadbbbe..b186bddd1 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 1ceadbbbe6bef9f364fd76dbda44ff6e185a7d10 +Subproject commit b186bddd170ded03be04e7921f5d86d24c92c54f diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py index b3576fda2..9c462cd05 100644 --- a/examples/scripts/ct_astra_3d_tv_admm.py +++ b/examples/scripts/ct_astra_3d_tv_admm.py @@ -44,7 +44,7 @@ tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles C = XRayTransform3D( tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles ) # CT projection operator diff --git a/examples/scripts/ct_astra_3d_tv_padmm.py b/examples/scripts/ct_astra_3d_tv_padmm.py index c6c090075..ed54247ae 100644 --- a/examples/scripts/ct_astra_3d_tv_padmm.py +++ b/examples/scripts/ct_astra_3d_tv_padmm.py @@ -44,7 +44,7 @@ tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_spacing = [1.0, 1.0] det_count = [Nz, max(Nx, Ny)] vectors = angle_to_vector(det_spacing, angles) diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py index 362c8cc3c..a9dab9657 100644 --- a/examples/scripts/ct_astra_noreg_pcg.py +++ b/examples/scripts/ct_astra_noreg_pcg.py @@ -44,7 +44,7 @@ Configure a CT projection operator and generate synthetic measurements. """ n_projection = N # matches the phantom size so this is not few-view CT -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index 5349311c6..fd684e3d7 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -38,21 +38,22 @@ """ N = 512 # phantom size np.random.seed(1234) -x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = snp.array(x_gt) # convert to jax type +x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +det_spacing = np.sqrt(2) +A = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles) # CT projection operator y = A @ x_gt # sinogram """ -Set up ADMM solver object. +Set up problem functional and ADMM solver object. """ λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter @@ -65,9 +66,7 @@ # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() - f = loss.SquaredL2Loss(y=y, A=A) - x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index b3c285cb9..eb7493733 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -35,7 +35,6 @@ Create a ground truth image. """ N = 512 # phantom size - np.random.seed(0) x_gt = discrete_phantom(Soil(porosity=0.80), size=384) x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64))) @@ -49,8 +48,7 @@ n_projection = 360 # number of projections Io = 1e3 # source flux 𝛼 = 1e-2 # attenuation coefficient - -angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y_c = A @ x_gt # sinogram @@ -99,13 +97,10 @@ def postprocess(x): # shown here). ρ = 2.5e3 # ADMM penalty parameter lambda_unweighted = 3e2 # regularization strength - maxiter = 100 # number of ADMM iterations cg_tol = 1e-5 # CG relative tolerance cg_maxiter = 10 # maximum CG iterations per ADMM iteration - f = loss.SquaredL2Loss(y=y, A=A) - admm_unweighted = ADMM( f=f, g_list=[lambda_unweighted * functional.L21Norm()], @@ -137,10 +132,8 @@ def postprocess(x): $I_0$ changes. """ lambda_weighted = 5e1 - weights = snp.array(counts / Io) f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) - admm_weighted = ADMM( f=f, g_list=[lambda_weighted * functional.L21Norm()], @@ -151,6 +144,7 @@ def postprocess(x): subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) +print() admm_weighted.solve() x_weighted = postprocess(admm_weighted.x) diff --git a/examples/scripts/ct_modl_train_foam2.py b/examples/scripts/ct_modl_train_foam2.py index 19a3d810f..101322141 100644 --- a/examples/scripts/ct_modl_train_foam2.py +++ b/examples/scripts/ct_modl_train_foam2.py @@ -92,7 +92,7 @@ Build CT projection operator. Parameters are chosen so that the operator is equivalent to the one used to generate the training data. """ -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), angles=angles, diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index f2f13fd87..df72e3879 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -38,15 +38,14 @@ np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) -det_count = int(N * 1.05 / np.sqrt(2.0)) -det_spacing = np.sqrt(2) - """ Define CT geometry and construct array of (approximately) equivalent projectors. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +det_spacing = np.sqrt(2) projectors = { "astra": astra.XRayTransform2D( x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0 diff --git a/examples/scripts/ct_odp_train_foam2.py b/examples/scripts/ct_odp_train_foam2.py index e5cd58ae9..cad279bbf 100644 --- a/examples/scripts/ct_odp_train_foam2.py +++ b/examples/scripts/ct_odp_train_foam2.py @@ -95,7 +95,7 @@ Build CT projection operator. Parameters are chosen so that the operator is equivalent to the one used to generate the training data. """ -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), angles=angles, diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 8592b44ff..e152c6ff3 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -4,7 +4,7 @@ # and user license can be found in the 'LICENSE.txt' file distributed # with the package. -""" +r""" TV-Regularized CT Reconstruction (Multiple Algorithms) ====================================================== @@ -51,7 +51,7 @@ """ num_angles = int(N / 2) num_channels = N -angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32) +angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt @@ -87,12 +87,9 @@ """ x0 = snp.array(x_mrf) weights = snp.array(weights) - λ = 1e-1 # ℓ1 norm regularization parameter - f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = λ * functional.L21Norm() # regularization functional - # The append=0 option makes the results of horizontal and vertical finite # differences the same shape, which is required for the L21Norm. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) @@ -112,6 +109,7 @@ itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") +print("ADMM:") x_admm = solve_admm.solve() hist_admm = solve_admm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n") @@ -130,6 +128,7 @@ maxiter=50, itstat_options={"display": True, "period": 10}, ) +print("Linearized ADMM:") x_ladmm = solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n") @@ -148,6 +147,7 @@ maxiter=50, itstat_options={"display": True, "period": 10}, ) +print("PDHG:") x_pdhg = solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n") diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py index ec48d4eaa..c66a802b4 100644 --- a/examples/scripts/ct_tv_admm.py +++ b/examples/scripts/ct_tv_admm.py @@ -45,15 +45,19 @@ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles -A = XRayTransform2D((N, N), angles) # CT projection operator +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +dx = 1.0 / np.sqrt(2) +A = XRayTransform2D( + (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx +) # CT projection operator y = A @ x_gt # sinogram """ -Set up ADMM solver object. +Set up problem functional and ADMM solver object. """ -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance @@ -64,10 +68,8 @@ # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() - f = loss.SquaredL2Loss(y=y, A=A) - -x0 = snp.clip(A.T(y), 0, 1.0) +x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( f=f, @@ -94,18 +96,26 @@ Show the recovered image. """ -fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5)) +fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0]) +plot.imview( + x0, + title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)), + cbar=None, + fig=fig, + ax=ax[1], +) plot.imview( x_reconstruction, title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)), fig=fig, - ax=ax[1], + ax=ax[2], ) -divider = make_axes_locatable(ax[1]) +divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) -fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") +fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py index d49df64bf..03c43736a 100644 --- a/scico/test/flax/test_inv.py +++ b/scico/test/flax/test_inv.py @@ -157,9 +157,7 @@ def setup_method(self, method): self.nproj = 60 # number of projections angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32) self.opCT = XRayTransform2D( - input_shape=(self.N, self.N), - det_count=self.N, - angles=angles, + input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0) ) # Radon transform operator a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze())) y = lax.map(a_f, xt)