Skip to content

Commit

Permalink
MBIRJAX v0.5.0 (#44)
Browse files Browse the repository at this point in the history
* Update batch size.

* Remove unused import.

* Improve the memory estimation.

* Reduce default number of iterations.

* Change parallel projectors to match conebeam horizontal projectors.

* Update to match prerelease.

* Update version number.

* Improve preprocssing docs.

* Refactor to loop over subsets in a partition.

* Remove unneeded block_until_ready.

* Minor updates.

* Simplify subset update.

* Simplify granularity sequence.

* Refactor auto rgularization for efficiency and to do the calculations on the CPU to avoid creating large sinogram-shaped arrays on the GPU.

* Convert to do GPU-CPU transfer within the subset update.

* Implement working version of GPU-CPU transfer in the subset updater.

* Implement adaptive choice of GPU-CPU transfer and view batch size.

* Update to use interface with new GPU-CPU transfer recon.

* Convert to do GPU-CPU transfer within the subset update.

* Remove blur model.

* Improve comments and output in Shepp-Logan demo.

* Fix bug in auto-regularization notification.

* Improve demo 1 and add demo 2.

* Improve demos 1 and 2 and add demo 3.

* Improve demos 1 and 2 and add demo 3.

* Improve demos 1 and add demo 4.

* Update version number.

* Simplify the specification of qggmrf neighbor weights.

* Update the import for the slice viewer.

* Update the import for the slice viewer.

* Update the demo to match the notebook.

* Update the demo to match the notebook.

* Update the demos to match the notebook.

* Improve and update docs.

* Improve and update docs.

* Improve and update docs.

* Improve and update docs.

* Improve imports for preprocessing.

* Correct docs.

* Update docs.

---------

Co-authored-by: Greg Buzzard <buzzard@purdue.edu>
Co-authored-by: Diyu Yang <yang1467@gilbreth-fe01.rcac.purdue.edu>
  • Loading branch information
3 people authored Jul 29, 2024
1 parent 7174f1d commit 0235f7d
Show file tree
Hide file tree
Showing 37 changed files with 1,365 additions and 987 deletions.
138 changes: 138 additions & 0 deletions demo/demo_1_shepp_logan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-
"""demo_1_shepp_logan.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC
**MBIRJAX: Basic Demo**
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details.
This script demonstrates the basic MBIRJAX code by creating a 3D phantom inspired by Shepp-Logan, forward projecting it to create a sinogram, and then using MBIRJAX to perform a Model-Based, Multi-Granular Vectorized Coordinate Descent reconstruction.
For the demo, we create some synthetic data by first making a phantom, then forward projecting it to obtain a sinogram.
In a real application, you would load your sinogram as a numpy array and use numpy.transpose if needed so that it
has axes in the order (views, rows, channels). For reference, assuming the rotation axis is vertical, then increasing the row index nominally moves down the rotation axis and increasing the channel index moves to the right as seen from the source.
Select a GPU as runtime type for best performance.
"""

# Commented out IPython magic to ensure Python compatibility.
# %pip install mbirjax

import numpy as np
import time
import pprint
import jax.numpy as jnp
import mbirjax

"""**Set the geometry parameters**"""

# Choose the geometry type
geometry_type = 'parallel' # 'cone' or 'parallel'

# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to
# channels, then the generated phantom may not have an interior.
num_views = 64
num_det_rows = 40
num_det_channels = 128

# For cone beam geometry, we need to describe the distances source to detector and source to rotation axis.
# np.Inf is an allowable value, in which case this is essentially parallel beam
source_detector_dist = 4 * num_det_channels
source_iso_dist = source_detector_dist

# For cone beam reconstruction, we need a little more than 180 degrees for full coverage.
if geometry_type == 'cone':
detector_cone_angle = 2 * np.arctan2(num_det_channels / 2, source_detector_dist)
else:
detector_cone_angle = 0
start_angle = -(np.pi + detector_cone_angle) * (1/2)
end_angle = (np.pi + detector_cone_angle) * (1/2)

"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram.
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python.
"""

# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

if geometry_type == 'cone':
ct_model_for_generation = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
elif geometry_type == 'parallel':
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles)
else:
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))

# Generate 3D Shepp Logan phantom
print('Creating phantom')
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom()

# Generate synthetic sinogram data
print('Creating sinogram')
sinogram = ct_model_for_generation.forward_project(phantom)
sinogram = np.array(sinogram)

# View sinogram
title = 'Original sinogram \nUse the sliders to change the view or adjust the intensity range.'
mbirjax.slice_viewer(sinogram, slice_axis=0, title=title, slice_label='View')

"""**Initialize for the reconstruction**"""

# ####################
# Initialize the model for reconstruction.
if geometry_type == 'cone':
ct_model_for_recon = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
else:
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles)

# Generate weights array - for an initial reconstruction, use weights = None, then modify if needed.
weights = None
# weights = ct_model_for_recon.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root')

# Set reconstruction parameter values
# Increase sharpness by 1 or 2 to get clearer edges, possibly with more high-frequency artifacts.
# Decrease by 1 or 2 to get softer edges and smoother interiors.
sharpness = 0.0
ct_model_for_recon.set_params(sharpness=sharpness)

# Print out model parameters
ct_model_for_recon.print_params()

"""**Do the reconstruction and display the results.**"""

# ##########################
# Perform VCD reconstruction
print('Starting recon')
time0 = time.time()
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights)

recon.block_until_ready()
elapsed = time.time() - time0
# ##########################

# Print parameters used in recon
pprint.pprint(recon_params._asdict(), compact=True)

max_diff = np.amax(np.abs(phantom - recon))
print('Geometry = {}'.format(geometry_type))
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom)
pct_95 = np.percentile(np.abs(recon - phantom), 95)
print('NRMSE between recon and phantom = {}'.format(nrmse))
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff))
print('95% of recon pixels are within {} of phantom'.format(pct_95))

mbirjax.get_memory_stats()
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))

# Display results
title = 'Phantom (left) vs VCD Recon (right) \nUse the sliders to change the slice or adjust the intensity range.'
mbirjax.slice_viewer(phantom, recon, title=title)

"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/examples.html). """
169 changes: 169 additions & 0 deletions demo/demo_2_large_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""demo_2_large_object.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1-kk_HeR8Y8f6pZ2zjTza8NTEpAgwgVRB
**MBIRJAX: Large Object Demo**
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details.
This script demonstrates how to improve the reconstruction when the object does not project completely inside the detector. For simplicity, we show this only for parallel beam, but the same steps apply for cone beam.
See [demo_1_shepp_logan.py](https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC) for the basic steps of synthetic sinogram generation and reconstruction.
Select a GPU as runtime type for best performance.
"""

# Commented out IPython magic to ensure Python compatibility.
# %pip install mbirjax

import numpy as np
import time
import pprint
import jax.numpy as jnp
import mbirjax

"""**Set the geometry parameters**"""

# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to
# channels, then the generated phantom may not have an interior.
num_views = 120
num_det_rows = 80
num_det_channels = 100

start_angle = - np.pi / 2
end_angle = np.pi / 2

"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram.
The default recon shape for parallel beam is
(rows, columns, slices) = (num_det_channels, num_det_channels, num_det_rows),
where we assume that the recon voxels are cubes and have the same size as the detector elements.
Here we generate a phantom that is bigger than the detector to show how to deal with this case.
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python.
"""

# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles)

# Generate large 3D Shepp Logan phantom
print('Creating phantom that projects partially outside the detector')
phantom_row_scale = 1.0
phantom_col_scale = 1.75
phantom_rows = int(num_det_channels * phantom_row_scale)
phantom_cols = int(num_det_channels * phantom_col_scale)
phantom_slices = num_det_rows
phantom_shape = (phantom_rows, phantom_cols, phantom_slices)
phantom = mbirjax.generate_3d_shepp_logan_low_dynamic_range(phantom_shape)

# Generate synthetic sinogram data
print('Creating sinogram')
ct_model_for_generation.set_params(recon_shape=phantom_shape)
sinogram = ct_model_for_generation.forward_project(phantom)
sinogram = np.array(sinogram)

# View sinogram
mbirjax.slice_viewer(sinogram, title='Original sinogram\nChange view to see projections in and outside detector',
slice_axis=0, slice_label='View')

"""**Do a baseline reconstruction**
First we do a reconstruction with the default settings. This will have significant artifacts because some of the information in the sinogram comes from voxels that are projected to the detector on only some of the views. In the default reconstruction, all of the voxels in the reconstruction project to the detector in all of the views, so there is only way to account for these partially projected voxels leads to a corrupted reconstruciton.
"""

# Initialize model for default reconstruction.
weights = None
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles)

# Print model parameters
ct_model_for_recon.print_params()

# Default VCD reconstruction
print('Starting default recon - will have significant artifacts because of the missing projections.\n')
time0 = time.time()
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights)
recon.block_until_ready()
elapsed = time.time() - time0

# Print out parameters used in recon
pprint.pprint(recon_params._asdict(), compact=True)
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))

"""**Display the default reconstruction.**"""

title = 'Default recon: Phantom (left) vs VCD Recon (right)'
title += '\nAdjust intensity range to [0, 1] to see internal artifacts from projection outside detector.'
title += '\nAdjust intensity range to [1.5, 2] to see outer ring from projection outside detector.'
mbirjax.slice_viewer(phantom, recon, title=title, vmin=0.0, vmax=2.0)

"""**Decrease sharpness to reduce artifacts.**
One way to reduce the artifacts seen above is to decrease sharpness, which promotes smoother images. This does reduce artifacts but also blurs the edges.
Below we show that we can pad the recon to reduce artifacts without blurring the edges.
"""

# Increased regularization VCD reconstruction
# We can reduce the artifacts by increasing regularization (decreasing sharpness).
sharpness = -1.5
ct_model_for_recon.set_params(sharpness=sharpness)
print('\nStarting recon with reduced sharpness - will have reduced artifacts but blurred edges.\n')
recon_smooth, recon_params_smooth = ct_model_for_recon.recon(sinogram, weights=weights)

# Print out parameters used in recon
pprint.pprint(recon_params_smooth._asdict(), compact=True)

# Display results
title = 'Recon with sharpness = {:.1f}: Phantom (left) vs VCD Recon (right)'.format(sharpness)
title += '\nAdjust intensity range to [0, 1] to see reduced internal artifacts from projection outside detector.'
title += '\nOuter ring is still evident in intensity range [1, 2], and edges are blurry.'
mbirjax.slice_viewer(phantom, recon_smooth, title=title, vmin=0.0, vmax=2.0)

"""**Padded recon VCD reconstruction**
Alternatively, we can pad the recon to allow for a partial reconstruction of the pixels that project outside the detector in some views. This reduces the artifacts without blurring edges and greatly reduces the outer ring seen in the non-padded recon.
Note that the enlarged recon doesn't have to match the phantom size. Increasing the recon size won't allow us to fully reconstruct the pixels that sometimes project outside the detector. However, it will provide
room for those partial projections to be absorbed into partial projected pixels, which allows for better reconstruction of the pixels with full projections.
"""

# Increase the recon size. In this case, we increase just the columns to
# approximate the phantom shape. Note that it doesn't have to be an exact match.
recon_row_scale = 1.0
recon_col_scale = 1.5
# Version 0.4.X:
(num_rows, num_cols, num_slices) = ct_model_for_recon.get_params('recon_shape')
new_shape = (int(num_rows * recon_row_scale), int(num_cols * recon_col_scale), num_slices)
ct_model_for_recon.set_params(recon_shape=new_shape)
# Version 0.5.0:
# ct_model_for_recon.scale_recon_shape(row_scale=recon_row_scale, col_scale=recon_col_scale)

# Reset the default sharpness
sharpness = 0
ct_model_for_recon.set_params(sharpness=sharpness)

print('\nStarting enlarged recon - will have reduced artifacts, sharper edges, some extra pixel estimation.\n')
recon_enlarged, recon_params_enlarged = ct_model_for_recon.recon(sinogram, weights=weights)

# Print out parameters used in recon
pprint.pprint(recon_params_enlarged._asdict(), compact=True)

"""**Display the result using the enlarged reconstruction.**"""

title = 'Padded recon with sharpness = {:.1f}: Phantom (left) vs VCD Recon (right)'.format(sharpness)
title += '\nPadding the recon reduces the internal artifacts even with default sharpness.'
title += '\nEdges are sharp, outer ring is mostly gone, and the partially projected pixels are partially recovered.'
mbirjax.slice_viewer(phantom, recon_enlarged, title=title, vmin=0.0, vmax=2.0)

"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/examples.html). """
Loading

0 comments on commit 0235f7d

Please sign in to comment.