-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update batch size. * Remove unused import. * Improve the memory estimation. * Reduce default number of iterations. * Change parallel projectors to match conebeam horizontal projectors. * Update to match prerelease. * Update version number. * Improve preprocssing docs. * Refactor to loop over subsets in a partition. * Remove unneeded block_until_ready. * Minor updates. * Simplify subset update. * Simplify granularity sequence. * Refactor auto rgularization for efficiency and to do the calculations on the CPU to avoid creating large sinogram-shaped arrays on the GPU. * Convert to do GPU-CPU transfer within the subset update. * Implement working version of GPU-CPU transfer in the subset updater. * Implement adaptive choice of GPU-CPU transfer and view batch size. * Update to use interface with new GPU-CPU transfer recon. * Convert to do GPU-CPU transfer within the subset update. * Remove blur model. * Improve comments and output in Shepp-Logan demo. * Fix bug in auto-regularization notification. * Improve demo 1 and add demo 2. * Improve demos 1 and 2 and add demo 3. * Improve demos 1 and 2 and add demo 3. * Improve demos 1 and add demo 4. * Update version number. * Simplify the specification of qggmrf neighbor weights. * Update the import for the slice viewer. * Update the import for the slice viewer. * Update the demo to match the notebook. * Update the demo to match the notebook. * Update the demos to match the notebook. * Improve and update docs. * Improve and update docs. * Improve and update docs. * Improve and update docs. * Improve imports for preprocessing. * Correct docs. * Update docs. --------- Co-authored-by: Greg Buzzard <buzzard@purdue.edu> Co-authored-by: Diyu Yang <yang1467@gilbreth-fe01.rcac.purdue.edu>
- Loading branch information
1 parent
7174f1d
commit 0235f7d
Showing
37 changed files
with
1,365 additions
and
987 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# -*- coding: utf-8 -*- | ||
"""demo_1_shepp_logan.ipynb | ||
Automatically generated by Colab. | ||
Original file is located at | ||
https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC | ||
**MBIRJAX: Basic Demo** | ||
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details. | ||
This script demonstrates the basic MBIRJAX code by creating a 3D phantom inspired by Shepp-Logan, forward projecting it to create a sinogram, and then using MBIRJAX to perform a Model-Based, Multi-Granular Vectorized Coordinate Descent reconstruction. | ||
For the demo, we create some synthetic data by first making a phantom, then forward projecting it to obtain a sinogram. | ||
In a real application, you would load your sinogram as a numpy array and use numpy.transpose if needed so that it | ||
has axes in the order (views, rows, channels). For reference, assuming the rotation axis is vertical, then increasing the row index nominally moves down the rotation axis and increasing the channel index moves to the right as seen from the source. | ||
Select a GPU as runtime type for best performance. | ||
""" | ||
|
||
# Commented out IPython magic to ensure Python compatibility. | ||
# %pip install mbirjax | ||
|
||
import numpy as np | ||
import time | ||
import pprint | ||
import jax.numpy as jnp | ||
import mbirjax | ||
|
||
"""**Set the geometry parameters**""" | ||
|
||
# Choose the geometry type | ||
geometry_type = 'parallel' # 'cone' or 'parallel' | ||
|
||
# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to | ||
# channels, then the generated phantom may not have an interior. | ||
num_views = 64 | ||
num_det_rows = 40 | ||
num_det_channels = 128 | ||
|
||
# For cone beam geometry, we need to describe the distances source to detector and source to rotation axis. | ||
# np.Inf is an allowable value, in which case this is essentially parallel beam | ||
source_detector_dist = 4 * num_det_channels | ||
source_iso_dist = source_detector_dist | ||
|
||
# For cone beam reconstruction, we need a little more than 180 degrees for full coverage. | ||
if geometry_type == 'cone': | ||
detector_cone_angle = 2 * np.arctan2(num_det_channels / 2, source_detector_dist) | ||
else: | ||
detector_cone_angle = 0 | ||
start_angle = -(np.pi + detector_cone_angle) * (1/2) | ||
end_angle = (np.pi + detector_cone_angle) * (1/2) | ||
|
||
"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram. | ||
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python. | ||
""" | ||
|
||
# Initialize sinogram | ||
sinogram_shape = (num_views, num_det_rows, num_det_channels) | ||
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False) | ||
|
||
if geometry_type == 'cone': | ||
ct_model_for_generation = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist) | ||
elif geometry_type == 'parallel': | ||
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
else: | ||
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type)) | ||
|
||
# Generate 3D Shepp Logan phantom | ||
print('Creating phantom') | ||
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom() | ||
|
||
# Generate synthetic sinogram data | ||
print('Creating sinogram') | ||
sinogram = ct_model_for_generation.forward_project(phantom) | ||
sinogram = np.array(sinogram) | ||
|
||
# View sinogram | ||
title = 'Original sinogram \nUse the sliders to change the view or adjust the intensity range.' | ||
mbirjax.slice_viewer(sinogram, slice_axis=0, title=title, slice_label='View') | ||
|
||
"""**Initialize for the reconstruction**""" | ||
|
||
# #################### | ||
# Initialize the model for reconstruction. | ||
if geometry_type == 'cone': | ||
ct_model_for_recon = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist) | ||
else: | ||
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
|
||
# Generate weights array - for an initial reconstruction, use weights = None, then modify if needed. | ||
weights = None | ||
# weights = ct_model_for_recon.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root') | ||
|
||
# Set reconstruction parameter values | ||
# Increase sharpness by 1 or 2 to get clearer edges, possibly with more high-frequency artifacts. | ||
# Decrease by 1 or 2 to get softer edges and smoother interiors. | ||
sharpness = 0.0 | ||
ct_model_for_recon.set_params(sharpness=sharpness) | ||
|
||
# Print out model parameters | ||
ct_model_for_recon.print_params() | ||
|
||
"""**Do the reconstruction and display the results.**""" | ||
|
||
# ########################## | ||
# Perform VCD reconstruction | ||
print('Starting recon') | ||
time0 = time.time() | ||
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights) | ||
|
||
recon.block_until_ready() | ||
elapsed = time.time() - time0 | ||
# ########################## | ||
|
||
# Print parameters used in recon | ||
pprint.pprint(recon_params._asdict(), compact=True) | ||
|
||
max_diff = np.amax(np.abs(phantom - recon)) | ||
print('Geometry = {}'.format(geometry_type)) | ||
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom) | ||
pct_95 = np.percentile(np.abs(recon - phantom), 95) | ||
print('NRMSE between recon and phantom = {}'.format(nrmse)) | ||
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff)) | ||
print('95% of recon pixels are within {} of phantom'.format(pct_95)) | ||
|
||
mbirjax.get_memory_stats() | ||
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed)) | ||
|
||
# Display results | ||
title = 'Phantom (left) vs VCD Recon (right) \nUse the sliders to change the slice or adjust the intensity range.' | ||
mbirjax.slice_viewer(phantom, recon, title=title) | ||
|
||
"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/examples.html). """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# -*- coding: utf-8 -*- | ||
"""demo_2_large_object.ipynb | ||
Automatically generated by Colab. | ||
Original file is located at | ||
https://colab.research.google.com/drive/1-kk_HeR8Y8f6pZ2zjTza8NTEpAgwgVRB | ||
**MBIRJAX: Large Object Demo** | ||
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details. | ||
This script demonstrates how to improve the reconstruction when the object does not project completely inside the detector. For simplicity, we show this only for parallel beam, but the same steps apply for cone beam. | ||
See [demo_1_shepp_logan.py](https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC) for the basic steps of synthetic sinogram generation and reconstruction. | ||
Select a GPU as runtime type for best performance. | ||
""" | ||
|
||
# Commented out IPython magic to ensure Python compatibility. | ||
# %pip install mbirjax | ||
|
||
import numpy as np | ||
import time | ||
import pprint | ||
import jax.numpy as jnp | ||
import mbirjax | ||
|
||
"""**Set the geometry parameters**""" | ||
|
||
# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to | ||
# channels, then the generated phantom may not have an interior. | ||
num_views = 120 | ||
num_det_rows = 80 | ||
num_det_channels = 100 | ||
|
||
start_angle = - np.pi / 2 | ||
end_angle = np.pi / 2 | ||
|
||
"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram. | ||
The default recon shape for parallel beam is | ||
(rows, columns, slices) = (num_det_channels, num_det_channels, num_det_rows), | ||
where we assume that the recon voxels are cubes and have the same size as the detector elements. | ||
Here we generate a phantom that is bigger than the detector to show how to deal with this case. | ||
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python. | ||
""" | ||
|
||
# Initialize sinogram | ||
sinogram_shape = (num_views, num_det_rows, num_det_channels) | ||
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False) | ||
|
||
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
|
||
# Generate large 3D Shepp Logan phantom | ||
print('Creating phantom that projects partially outside the detector') | ||
phantom_row_scale = 1.0 | ||
phantom_col_scale = 1.75 | ||
phantom_rows = int(num_det_channels * phantom_row_scale) | ||
phantom_cols = int(num_det_channels * phantom_col_scale) | ||
phantom_slices = num_det_rows | ||
phantom_shape = (phantom_rows, phantom_cols, phantom_slices) | ||
phantom = mbirjax.generate_3d_shepp_logan_low_dynamic_range(phantom_shape) | ||
|
||
# Generate synthetic sinogram data | ||
print('Creating sinogram') | ||
ct_model_for_generation.set_params(recon_shape=phantom_shape) | ||
sinogram = ct_model_for_generation.forward_project(phantom) | ||
sinogram = np.array(sinogram) | ||
|
||
# View sinogram | ||
mbirjax.slice_viewer(sinogram, title='Original sinogram\nChange view to see projections in and outside detector', | ||
slice_axis=0, slice_label='View') | ||
|
||
"""**Do a baseline reconstruction** | ||
First we do a reconstruction with the default settings. This will have significant artifacts because some of the information in the sinogram comes from voxels that are projected to the detector on only some of the views. In the default reconstruction, all of the voxels in the reconstruction project to the detector in all of the views, so there is only way to account for these partially projected voxels leads to a corrupted reconstruciton. | ||
""" | ||
|
||
# Initialize model for default reconstruction. | ||
weights = None | ||
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
|
||
# Print model parameters | ||
ct_model_for_recon.print_params() | ||
|
||
# Default VCD reconstruction | ||
print('Starting default recon - will have significant artifacts because of the missing projections.\n') | ||
time0 = time.time() | ||
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights) | ||
recon.block_until_ready() | ||
elapsed = time.time() - time0 | ||
|
||
# Print out parameters used in recon | ||
pprint.pprint(recon_params._asdict(), compact=True) | ||
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed)) | ||
|
||
"""**Display the default reconstruction.**""" | ||
|
||
title = 'Default recon: Phantom (left) vs VCD Recon (right)' | ||
title += '\nAdjust intensity range to [0, 1] to see internal artifacts from projection outside detector.' | ||
title += '\nAdjust intensity range to [1.5, 2] to see outer ring from projection outside detector.' | ||
mbirjax.slice_viewer(phantom, recon, title=title, vmin=0.0, vmax=2.0) | ||
|
||
"""**Decrease sharpness to reduce artifacts.** | ||
One way to reduce the artifacts seen above is to decrease sharpness, which promotes smoother images. This does reduce artifacts but also blurs the edges. | ||
Below we show that we can pad the recon to reduce artifacts without blurring the edges. | ||
""" | ||
|
||
# Increased regularization VCD reconstruction | ||
# We can reduce the artifacts by increasing regularization (decreasing sharpness). | ||
sharpness = -1.5 | ||
ct_model_for_recon.set_params(sharpness=sharpness) | ||
print('\nStarting recon with reduced sharpness - will have reduced artifacts but blurred edges.\n') | ||
recon_smooth, recon_params_smooth = ct_model_for_recon.recon(sinogram, weights=weights) | ||
|
||
# Print out parameters used in recon | ||
pprint.pprint(recon_params_smooth._asdict(), compact=True) | ||
|
||
# Display results | ||
title = 'Recon with sharpness = {:.1f}: Phantom (left) vs VCD Recon (right)'.format(sharpness) | ||
title += '\nAdjust intensity range to [0, 1] to see reduced internal artifacts from projection outside detector.' | ||
title += '\nOuter ring is still evident in intensity range [1, 2], and edges are blurry.' | ||
mbirjax.slice_viewer(phantom, recon_smooth, title=title, vmin=0.0, vmax=2.0) | ||
|
||
"""**Padded recon VCD reconstruction** | ||
Alternatively, we can pad the recon to allow for a partial reconstruction of the pixels that project outside the detector in some views. This reduces the artifacts without blurring edges and greatly reduces the outer ring seen in the non-padded recon. | ||
Note that the enlarged recon doesn't have to match the phantom size. Increasing the recon size won't allow us to fully reconstruct the pixels that sometimes project outside the detector. However, it will provide | ||
room for those partial projections to be absorbed into partial projected pixels, which allows for better reconstruction of the pixels with full projections. | ||
""" | ||
|
||
# Increase the recon size. In this case, we increase just the columns to | ||
# approximate the phantom shape. Note that it doesn't have to be an exact match. | ||
recon_row_scale = 1.0 | ||
recon_col_scale = 1.5 | ||
# Version 0.4.X: | ||
(num_rows, num_cols, num_slices) = ct_model_for_recon.get_params('recon_shape') | ||
new_shape = (int(num_rows * recon_row_scale), int(num_cols * recon_col_scale), num_slices) | ||
ct_model_for_recon.set_params(recon_shape=new_shape) | ||
# Version 0.5.0: | ||
# ct_model_for_recon.scale_recon_shape(row_scale=recon_row_scale, col_scale=recon_col_scale) | ||
|
||
# Reset the default sharpness | ||
sharpness = 0 | ||
ct_model_for_recon.set_params(sharpness=sharpness) | ||
|
||
print('\nStarting enlarged recon - will have reduced artifacts, sharper edges, some extra pixel estimation.\n') | ||
recon_enlarged, recon_params_enlarged = ct_model_for_recon.recon(sinogram, weights=weights) | ||
|
||
# Print out parameters used in recon | ||
pprint.pprint(recon_params_enlarged._asdict(), compact=True) | ||
|
||
"""**Display the result using the enlarged reconstruction.**""" | ||
|
||
title = 'Padded recon with sharpness = {:.1f}: Phantom (left) vs VCD Recon (right)'.format(sharpness) | ||
title += '\nPadding the recon reduces the internal artifacts even with default sharpness.' | ||
title += '\nEdges are sharp, outer ring is mostly gone, and the partially projected pixels are partially recovered.' | ||
mbirjax.slice_viewer(phantom, recon_enlarged, title=title, vmin=0.0, vmax=2.0) | ||
|
||
"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/examples.html). """ |
Oops, something went wrong.