-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Change default settings to investigate convergence. * Update demo_1 * Improve plotting and change sharpness and snr_db. * Change default granularity. * Change default granularity. * Restore demo_1 and move new convergence investigation to experiments. * Get changes from main. * Change default granularity. * Use extreme values for sharpness and snr_db. * Include ability to select a circular region and display mean and std dev. * Fix bug in region selection and incorporate changes from investigate_convergence. * Fix bug in circle selection. * Fix bug in circle selection. * Fix bug in region selection and incorporate changes from investigate_convergence. * Update version number and maintenance docs --------- Co-authored-by: Charles Bouman <Charles.Bouman@gmail.com>
- Loading branch information
Showing
7 changed files
with
367 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# -*- coding: utf-8 -*- | ||
"""demo_1_shepp_logan.ipynb | ||
Automatically generated by Colab. | ||
Original file is located at | ||
https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC | ||
**MBIRJAX: Basic Demo** | ||
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details. | ||
This script demonstrates the basic MBIRJAX code by creating a 3D phantom inspired by Shepp-Logan, forward projecting it to create a sinogram, and then using MBIRJAX to perform a Model-Based, Multi-Granular Vectorized Coordinate Descent reconstruction. | ||
For the demo, we create some synthetic data by first making a phantom, then forward projecting it to obtain a sinogram. | ||
In a real application, you would load your sinogram as a numpy array and use numpy.transpose if needed so that it | ||
has axes in the order (views, rows, channels). For reference, assuming the rotation axis is vertical, then increasing the row index nominally moves down the rotation axis and increasing the channel index moves to the right as seen from the source. | ||
Select a GPU as runtime type for best performance. | ||
""" | ||
|
||
# Commented out IPython magic to ensure Python compatibility. | ||
# %pip install mbirjax | ||
|
||
import numpy as np | ||
import time | ||
import pprint | ||
import jax.numpy as jnp | ||
import mbirjax | ||
|
||
"""**Set the geometry parameters**""" | ||
|
||
# Choose the geometry type | ||
geometry_type = 'parallel' # 'cone' or 'parallel' | ||
|
||
# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to | ||
# channels, then the generated phantom may not have an interior. | ||
num_views = 32 | ||
num_det_rows = 60 | ||
num_det_channels = 256 | ||
sharpness = 2.0 | ||
snr_db = 40 | ||
|
||
# For cone beam geometry, we need to describe the distances source to detector and source to rotation axis. | ||
# np.Inf is an allowable value, in which case this is essentially parallel beam | ||
source_detector_dist = 4 * num_det_channels | ||
source_iso_dist = source_detector_dist | ||
|
||
# For cone beam reconstruction, we need a little more than 180 degrees for full coverage. | ||
if geometry_type == 'cone': | ||
detector_cone_angle = 2 * np.arctan2(num_det_channels / 2, source_detector_dist) | ||
else: | ||
detector_cone_angle = 0 | ||
start_angle = -(np.pi + detector_cone_angle) * (1/2) | ||
end_angle = (np.pi + detector_cone_angle) * (1/2) | ||
|
||
"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram. | ||
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python. | ||
""" | ||
|
||
# Initialize sinogram | ||
sinogram_shape = (num_views, num_det_rows, num_det_channels) | ||
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False) | ||
|
||
if geometry_type == 'cone': | ||
ct_model_for_generation = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist) | ||
elif geometry_type == 'parallel': | ||
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
else: | ||
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type)) | ||
|
||
# Generate 3D Shepp Logan phantom | ||
print('Creating phantom') | ||
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom() | ||
|
||
# Generate synthetic sinogram data | ||
print('Creating sinogram') | ||
sinogram = ct_model_for_generation.forward_project(phantom) | ||
sinogram = np.array(sinogram) | ||
|
||
# View sinogram | ||
title = 'Original sinogram \nUse the sliders to change the view or adjust the intensity range.' | ||
mbirjax.slice_viewer(sinogram, slice_axis=0, title=title, slice_label='View') | ||
|
||
"""**Initialize for the reconstruction**""" | ||
|
||
# #################### | ||
# Initialize the model for reconstruction. | ||
if geometry_type == 'cone': | ||
ct_model_for_recon = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist) | ||
else: | ||
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles) | ||
|
||
# Generate weights array - for an initial reconstruction, use weights = None, then modify if needed. | ||
weights = None | ||
# weights = ct_model_for_recon.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root') | ||
|
||
# Set reconstruction parameter values | ||
# Increase sharpness by 1 or 2 to get clearer edges, possibly with more high-frequency artifacts. | ||
# Decrease by 1 or 2 to get softer edges and smoother interiors. | ||
|
||
|
||
# Set parameters | ||
ct_model_for_recon.set_params(sharpness=sharpness, snr_db=snr_db) | ||
|
||
# Print out model parameters | ||
ct_model_for_recon.print_params() | ||
|
||
"""**Do the reconstruction and display the results.**""" | ||
|
||
# ########################## | ||
# Perform VCD reconstruction | ||
print('Starting recon') | ||
time0 = time.time() | ||
recon = None | ||
iterations_per_step = 5 | ||
num_iterations = 20 | ||
for iteration in range(0, num_iterations, iterations_per_step): | ||
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights, num_iterations=iteration + iterations_per_step, | ||
first_iteration=iteration, init_recon=recon, | ||
compute_prior_loss=True) | ||
mbirjax.slice_viewer(recon) | ||
mbirjax.slice_viewer(recon, slice_axis=1) | ||
# if iteration > 8: | ||
# sharpness = 2.0 | ||
# snr_db = 40 | ||
# ct_model_for_recon.set_params(sharpness=sharpness, snr_db=snr_db) | ||
|
||
recon.block_until_ready() | ||
elapsed = time.time() - time0 | ||
# ########################## | ||
|
||
# Print parameters used in recon | ||
pprint.pprint(recon_params._asdict(), compact=True) | ||
|
||
max_diff = np.amax(np.abs(phantom - recon)) | ||
print('Geometry = {}'.format(geometry_type)) | ||
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom) | ||
pct_95 = np.percentile(np.abs(recon - phantom), 95) | ||
print('NRMSE between recon and phantom = {}'.format(nrmse)) | ||
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff)) | ||
print('95% of recon pixels are within {} of phantom'.format(pct_95)) | ||
|
||
mbirjax.get_memory_stats() | ||
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed)) | ||
|
||
# Display results | ||
title = 'Phantom (left) vs VCD Recon (right) \nUse the sliders to change the slice or adjust the intensity range.' | ||
mbirjax.slice_viewer(phantom, recon, title=title) | ||
|
||
"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/demos_and_faqs.html). """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.