Skip to content

Commit

Permalink
Clean up CT examples (#563)
Browse files Browse the repository at this point in the history
* Improve compatibility between companion scripts

* Clean up

* Docstring fix

* Cosmetic improvements

* Fix angular sampling

* Clean up

* Fix angular sampling

* Typo fix

* Update submodule

* Resolve warning when 64 bit float enabled

* Update submodule

---------

Co-authored-by: Brendt Wohlberg <brendt@lanl.gov>
  • Loading branch information
bwohlberg and Brendt Wohlberg authored Oct 24, 2024
1 parent 31c20a3 commit f8fd48e
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_3d_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
C = XRayTransform3D(
tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles
) # CT projection operator
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_3d_tv_padmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_spacing = [1.0, 1.0]
det_count = [Nz, max(Nx, Ny)]
vectors = angle_to_vector(det_spacing, angles)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_noreg_pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
Configure a CT projection operator and generate synthetic measurements.
"""
n_projection = N # matches the phantom size so this is not few-view CT
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
y = A @ x_gt # sinogram

Expand Down
13 changes: 6 additions & 7 deletions examples/scripts/ct_astra_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,22 @@
"""
N = 512 # phantom size
np.random.seed(1234)
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt) # convert to jax type
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))


"""
Configure CT projection operator and generate synthetic measurements.
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)
A = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles) # CT projection operator
y = A @ x_gt # sinogram


"""
Set up ADMM solver object.
Set up problem functional and ADMM solver object.
"""
λ = 2e0 # ℓ1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
Expand All @@ -65,9 +66,7 @@
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()

f = loss.SquaredL2Loss(y=y, A=A)

x0 = snp.clip(A.fbp(y), 0, 1.0)

solver = ADMM(
Expand Down
10 changes: 2 additions & 8 deletions examples/scripts/ct_astra_weighted_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
Create a ground truth image.
"""
N = 512 # phantom size

np.random.seed(0)
x_gt = discrete_phantom(Soil(porosity=0.80), size=384)
x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))
Expand All @@ -49,8 +48,7 @@
n_projection = 360 # number of projections
Io = 1e3 # source flux
𝛼 = 1e-2 # attenuation coefficient

angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
y_c = A @ x_gt # sinogram

Expand Down Expand Up @@ -99,13 +97,10 @@ def postprocess(x):
# shown here).
ρ = 2.5e3 # ADMM penalty parameter
lambda_unweighted = 3e2 # regularization strength

maxiter = 100 # number of ADMM iterations
cg_tol = 1e-5 # CG relative tolerance
cg_maxiter = 10 # maximum CG iterations per ADMM iteration

f = loss.SquaredL2Loss(y=y, A=A)

admm_unweighted = ADMM(
f=f,
g_list=[lambda_unweighted * functional.L21Norm()],
Expand Down Expand Up @@ -137,10 +132,8 @@ def postprocess(x):
$I_0$ changes.
"""
lambda_weighted = 5e1

weights = snp.array(counts / Io)
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))

admm_weighted = ADMM(
f=f,
g_list=[lambda_weighted * functional.L21Norm()],
Expand All @@ -151,6 +144,7 @@ def postprocess(x):
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 10},
)
print()
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
angles=angles,
Expand Down
7 changes: 3 additions & 4 deletions examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)


"""
Define CT geometry and construct array of (approximately) equivalent projectors.
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)
projectors = {
"astra": astra.XRayTransform2D(
x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_odp_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
angles=angles,
Expand Down
10 changes: 5 additions & 5 deletions examples/scripts/ct_svmbir_tv_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
r"""
TV-Regularized CT Reconstruction (Multiple Algorithms)
======================================================
Expand Down Expand Up @@ -51,7 +51,7 @@
"""
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt

Expand Down Expand Up @@ -87,12 +87,9 @@
"""
x0 = snp.array(x_mrf)
weights = snp.array(weights)

λ = 1e-1 # ℓ1 norm regularization parameter

f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g = λ * functional.L21Norm() # regularization functional

# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
Expand All @@ -112,6 +109,7 @@
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
print("ADMM:")
x_admm = solve_admm.solve()
hist_admm = solve_admm.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n")
Expand All @@ -130,6 +128,7 @@
maxiter=50,
itstat_options={"display": True, "period": 10},
)
print("Linearized ADMM:")
x_ladmm = solver_ladmm.solve()
hist_ladmm = solver_ladmm.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n")
Expand All @@ -148,6 +147,7 @@
maxiter=50,
itstat_options={"display": True, "period": 10},
)
print("PDHG:")
x_pdhg = solver_pdhg.solve()
hist_pdhg = solver_pdhg.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n")
Expand Down
32 changes: 21 additions & 11 deletions examples/scripts/ct_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@
Configure CT projection operator and generate synthetic measurements.
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
A = XRayTransform2D((N, N), angles) # CT projection operator
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_count = int(N * 1.05 / np.sqrt(2.0))
dx = 1.0 / np.sqrt(2)
A = XRayTransform2D(
(N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx
) # CT projection operator
y = A @ x_gt # sinogram


"""
Set up ADMM solver object.
Set up problem functional and ADMM solver object.
"""
λ = 2e0 # L1 norm regularization parameter
λ = 2e0 # ℓ1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
Expand All @@ -64,10 +68,8 @@
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()

f = loss.SquaredL2Loss(y=y, A=A)

x0 = snp.clip(A.T(y), 0, 1.0)
x0 = snp.clip(A.fbp(y), 0, 1.0)

solver = ADMM(
f=f,
Expand All @@ -94,18 +96,26 @@
Show the recovered image.
"""

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5))
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
x0,
title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(x_gt, x0), metric.mae(x_gt, x0)),
cbar=None,
fig=fig,
ax=ax[1],
)
plot.imview(
x_reconstruction,
title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),
fig=fig,
ax=ax[1],
ax=ax[2],
)
divider = make_axes_locatable(ax[1])
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()


Expand Down
4 changes: 1 addition & 3 deletions scico/test/flax/test_inv.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ def setup_method(self, method):
self.nproj = 60 # number of projections
angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32)
self.opCT = XRayTransform2D(
input_shape=(self.N, self.N),
det_count=self.N,
angles=angles,
input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0)
) # Radon transform operator
a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze()))
y = lax.map(a_f, xt)
Expand Down

0 comments on commit f8fd48e

Please sign in to comment.