Skip to content

Commit

Permalink
Prerelease (#47)
Browse files Browse the repository at this point in the history
* Change default settings to investigate convergence.

* Update demo_1

* Improve plotting and change sharpness and snr_db.

* Change default granularity.

* Change default granularity.

* Restore demo_1 and move new convergence investigation to experiments.

* Get changes from main.

* Change default granularity.

* Use extreme values for sharpness and snr_db.

* Include ability to select a circular region and display mean and std dev.

* Fix bug in region selection and incorporate changes from investigate_convergence.

* Fix bug in circle selection.

* Fix bug in circle selection.

* Fix bug in region selection and incorporate changes from investigate_convergence.

* Update version number and maintenance docs

---------

Co-authored-by: Charles Bouman <Charles.Bouman@gmail.com>
  • Loading branch information
gbuzzard and cabouman authored Oct 29, 2024
1 parent 2249e95 commit 1159eef
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 43 deletions.
24 changes: 14 additions & 10 deletions docs/source/dev_maintenance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,63 +17,67 @@ Uploading to Test PyPI

This is only available for registered maintainers. Typically, you would perform these steps on the prerelease branch before the final commit to main.

0. Install the newest versions of `setuptools`, `wheel`, `build`, and `twine`. Then from the main mbirjax directory, delete any previous build and then build the project::
0. Update the version number in prerelease and accept the PR to main.

1. Install the newest versions of `setuptools`, `wheel`, `build`, and `twine`. Then from the main mbirjax directory, delete any previous build and then build the project::

pip install setuptools build wheel twine
rm -r dist
python -m build


1. Upload to Test PyPI. You will need to get an API token from TestPyPI. You will be prompted for this token from the command line. NOTE: You cannot upload the same version more than once::
2. Upload to Test PyPI. You will need to get an API token from TestPyPI. You will be prompted for this token from the command line. NOTE: You cannot upload the same version more than once::

python -m twine upload --repository testpypi dist/*

View the package upload here:
`https://test.pypi.org/project/mbirjax <https://test.pypi.org/project/mbirjax>`__

2. Test the uploaded package::
3. Test the uploaded package::

pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple mbirjax
python -c "import mbirjax" # spin the wheel

3. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.
4. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.

NOTE: If the install fails and you need to re-test, *temporarily* set
the version number in `pyproject.toml` from X.X.X to X.X.X.1 (then 2, 3, etc.),
for further testing. After the test is successful, reset the version number in
`pyproject.toml`, then merge any required changes into the master branch,
then delete and re-create the git tag, and proceed to PyPI upload.

4. Verify that the `corresponding build <https://readthedocs.org/projects/mbirjax/builds/>`__ of the MBIRJAX documentation has built correctly.
5. Verify that the `corresponding build <https://readthedocs.org/projects/mbirjax/builds/>`__ of the MBIRJAX documentation has built correctly.

Uploading to PyPI
-----------------

This is only available for registered maintainers.

0. First, make sure you have installed the newest versions of `setuptools`, `wheel`, `build`, and `twine`. Then from the main mbirjax directory, delete any previous build and then build the project::
0. Update the version number in prerelease and accept the PR to main.

1. First, make sure you have installed the newest versions of `setuptools`, `wheel`, `build`, and `twine`. Then from the main mbirjax directory, delete any previous build and then build the project::

pip install setuptools build wheel twine
rm -r dist
python -m build


1. Upload to PyPI. As above, you will need an API token, this time from PyPI. NOTE: You cannot upload the same version more than once::
2. Upload to PyPI. As above, you will need an API token, this time from PyPI. NOTE: You cannot upload the same version more than once::

python -m twine upload dist/*

View the package upload here:
`https://pypi.org/project/mbirjax <https://pypi.org/project/mbirjax>`__

2. Test the uploaded package::
3. Test the uploaded package::

pip install mbirjax # OR, "mbirjax==0.1.1" e.g. for a specific version number
python -c "import mbirjax" # spin the wheel

3. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.
4. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.


4. Verify that the `corresponding build <https://readthedocs.org/projects/mbirjax/builds/>`__ of the MBIRJAX documentation has built correctly.
5. Verify that the `corresponding build <https://readthedocs.org/projects/mbirjax/builds/>`__ of the MBIRJAX documentation has built correctly.

Reference
---------
Expand Down
41 changes: 24 additions & 17 deletions experiments/cvpr-2024/vcd_figs_for_abst_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import mbirjax.parallel_beam


def display_slices_for_abstract( recon1, recon2, recon3, labels) :
def display_slices_for_abstract( recon1, recon2, recon3, labels, fig_title=None):
# Set global font size
plt.rcParams.update({'font.size': 15}) # Adjust font size here

vmin = 0.0
vmax = phantom.max()
slice_index = recon1.shape[2] // 2

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))
fig.suptitle(fig_title)

a0 = ax[0].imshow(recon1[:, :, slice_index], vmin=vmin, vmax=vmax, cmap='gray')
#plt.colorbar(a0, ax=ax[0])
Expand All @@ -29,6 +30,7 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :
ax[2].set_title(labels[2])

plt.show()
fig.savefig('../figs/' + fig_title + '.png')


if __name__ == "__main__":
Expand All @@ -41,12 +43,13 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :
print('Using {} geometry.'.format(geometry_type))

# Set parameters
num_views = 256
num_det_rows = 10
num_views = 32
num_det_rows = 40
num_det_channels = 256
start_angle = 0
end_angle = np.pi
sharpness = 0.0
sharpness = 1.5
snr_db = 35

# These can be adjusted to describe the geometry in the cone beam case.
# np.Inf is an allowable value, in which case this is essentially parallel beam
Expand Down Expand Up @@ -83,32 +86,34 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :
weights = ct_model.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root')

# Set reconstruction parameter values
ct_model.set_params(sharpness=sharpness, verbose=1)
ct_model.set_params(sharpness=sharpness, snr_db=snr_db, verbose=1)

# Print out model parameters
ct_model.print_params()

# 'granularity': {'val': [1, 4, 64, 128], 'recompile_flag': False},
# 'partition_sequence': {'val': [0, 1, 2, 2, 2], 'recompile_flag': False},
# 'granularity': {'val': [2, 48, 96, 128], 'recompile_flag': False},
# 'partition_sequence': {'val': [0, 1, 2, 2, 2, 2, 3], 'recompile_flag': False},

granularity_alt_1 = [4, 48, 96, 128]
partition_sequence_alt_1 = [0, 1, 2, 2, 2, 2, 3]
granularity_alt_1 = [1, 2, 4, 8, 16, 32, 64, 128, 256]
partition_sequence_alt_1 = [0, 1, 2, 3, 4, 5, 6, 7, 8]
# granularity_alt_1 = [1, 4, 16, 64, 256]
# partition_sequence_alt_1 = [0, 1, 2, 3, 4]

granularity_alt_2 = [2, 48, 96, 128]
partition_sequence_alt_2 = [0, 1, 2, 2, 2, 2, 3]
granularity_alt_2 = [1, 3, 9, 27, 81, 243]
partition_sequence_alt_2 = [0, 1, 2, 3, 4, 5]

# ##########################
# Perform default VCD reconstruction
print('Starting default sequence')
num_iterations = 8
num_iterations = 20
recon_default, recon_params_default = ct_model.recon(sinogram, weights=weights, num_iterations=num_iterations,
compute_prior_loss=True)
compute_prior_loss=True)
fm_rmse_default = recon_params_default.fm_rmse
prior_loss_default = recon_params_default.prior_loss
partition_sequence = recon_params_default.partition_sequence
granularity = np.array(recon_params_default.granularity)
granularity_sequence_default = granularity[partition_sequence]
label_default = 'Default: ' + str(granularity_sequence_default)
label_default = 'Base: ' + str(granularity_sequence_default)

# Perform alt_1 default reconstruction
print('Starting alt_1 sequence')
Expand Down Expand Up @@ -140,15 +145,17 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :
# ##########################

# Display reconstructions
fig_title = '(v, r, c) = ({}, {}, {}), sharpness={}, snr_db={}'.format(num_views, num_det_rows, num_det_channels, sharpness, snr_db)
labels = [label_alt_1, label_default, label_alt_2]
# display_slices_for_abstract(recon_alt_1, recon_default, recon_alt_2, labels)
display_slices_for_abstract(recon_alt_1, recon_default, recon_alt_2, labels, fig_title=fig_title)

# Display granularity plots:
granularity_sequences = [granularity_sequence_alt_1, granularity_sequence_default, granularity_sequence_alt_2]
fm_losses = [fm_rmse_alt_1, fm_rmse_default, fm_rmse_alt_2]
prior_losses = [prior_loss_alt_1, prior_loss_default, prior_loss_alt_2]
# labels = ['Gradient Descent', 'Vectorized Coordinate Descent', 'Coordinate Descent']
mbirjax.plot_granularity_and_loss(granularity_sequences, fm_losses, prior_losses, labels, granularity_ylim=(0, 256), loss_ylim=(0.1, 15))
mbirjax.plot_granularity_and_loss(granularity_sequences, fm_losses, prior_losses, labels, granularity_ylim=(0, 256),
loss_ylim=(0.1, 15), fig_title=fig_title)

# Generate sequence of partition images for Figure 1
recon_shape = (32, 32, 1)
Expand Down
154 changes: 154 additions & 0 deletions experiments/shepp_logan_convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""demo_1_shepp_logan.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC
**MBIRJAX: Basic Demo**
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details.
This script demonstrates the basic MBIRJAX code by creating a 3D phantom inspired by Shepp-Logan, forward projecting it to create a sinogram, and then using MBIRJAX to perform a Model-Based, Multi-Granular Vectorized Coordinate Descent reconstruction.
For the demo, we create some synthetic data by first making a phantom, then forward projecting it to obtain a sinogram.
In a real application, you would load your sinogram as a numpy array and use numpy.transpose if needed so that it
has axes in the order (views, rows, channels). For reference, assuming the rotation axis is vertical, then increasing the row index nominally moves down the rotation axis and increasing the channel index moves to the right as seen from the source.
Select a GPU as runtime type for best performance.
"""

# Commented out IPython magic to ensure Python compatibility.
# %pip install mbirjax

import numpy as np
import time
import pprint
import jax.numpy as jnp
import mbirjax

"""**Set the geometry parameters**"""

# Choose the geometry type
geometry_type = 'parallel' # 'cone' or 'parallel'

# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to
# channels, then the generated phantom may not have an interior.
num_views = 32
num_det_rows = 60
num_det_channels = 256
sharpness = 2.0
snr_db = 40

# For cone beam geometry, we need to describe the distances source to detector and source to rotation axis.
# np.Inf is an allowable value, in which case this is essentially parallel beam
source_detector_dist = 4 * num_det_channels
source_iso_dist = source_detector_dist

# For cone beam reconstruction, we need a little more than 180 degrees for full coverage.
if geometry_type == 'cone':
detector_cone_angle = 2 * np.arctan2(num_det_channels / 2, source_detector_dist)
else:
detector_cone_angle = 0
start_angle = -(np.pi + detector_cone_angle) * (1/2)
end_angle = (np.pi + detector_cone_angle) * (1/2)

"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram.
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python.
"""

# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

if geometry_type == 'cone':
ct_model_for_generation = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
elif geometry_type == 'parallel':
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles)
else:
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))

# Generate 3D Shepp Logan phantom
print('Creating phantom')
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom()

# Generate synthetic sinogram data
print('Creating sinogram')
sinogram = ct_model_for_generation.forward_project(phantom)
sinogram = np.array(sinogram)

# View sinogram
title = 'Original sinogram \nUse the sliders to change the view or adjust the intensity range.'
mbirjax.slice_viewer(sinogram, slice_axis=0, title=title, slice_label='View')

"""**Initialize for the reconstruction**"""

# ####################
# Initialize the model for reconstruction.
if geometry_type == 'cone':
ct_model_for_recon = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
else:
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles)

# Generate weights array - for an initial reconstruction, use weights = None, then modify if needed.
weights = None
# weights = ct_model_for_recon.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root')

# Set reconstruction parameter values
# Increase sharpness by 1 or 2 to get clearer edges, possibly with more high-frequency artifacts.
# Decrease by 1 or 2 to get softer edges and smoother interiors.


# Set parameters
ct_model_for_recon.set_params(sharpness=sharpness, snr_db=snr_db)

# Print out model parameters
ct_model_for_recon.print_params()

"""**Do the reconstruction and display the results.**"""

# ##########################
# Perform VCD reconstruction
print('Starting recon')
time0 = time.time()
recon = None
iterations_per_step = 5
num_iterations = 20
for iteration in range(0, num_iterations, iterations_per_step):
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights, num_iterations=iteration + iterations_per_step,
first_iteration=iteration, init_recon=recon,
compute_prior_loss=True)
mbirjax.slice_viewer(recon)
mbirjax.slice_viewer(recon, slice_axis=1)
# if iteration > 8:
# sharpness = 2.0
# snr_db = 40
# ct_model_for_recon.set_params(sharpness=sharpness, snr_db=snr_db)

recon.block_until_ready()
elapsed = time.time() - time0
# ##########################

# Print parameters used in recon
pprint.pprint(recon_params._asdict(), compact=True)

max_diff = np.amax(np.abs(phantom - recon))
print('Geometry = {}'.format(geometry_type))
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom)
pct_95 = np.percentile(np.abs(recon - phantom), 95)
print('NRMSE between recon and phantom = {}'.format(nrmse))
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff))
print('95% of recon pixels are within {} of phantom'.format(pct_95))

mbirjax.get_memory_stats()
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))

# Display results
title = 'Phantom (left) vs VCD Recon (right) \nUse the sliders to change the slice or adjust the intensity range.'
mbirjax.slice_viewer(phantom, recon, title=title)

"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/demos_and_faqs.html). """
4 changes: 2 additions & 2 deletions mbirjax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
'positivity_flag': {'val': False, 'recompile_flag': False},
'snr_db': {'val': 30.0, 'recompile_flag': False},
'sharpness': {'val': 0.0, 'recompile_flag': False},
'granularity': {'val': [2, 48, 96, 128], 'recompile_flag': False},
'partition_sequence': {'val': [0, 1, 2, 2, 2, 2, 3], 'recompile_flag': False},
'granularity': {'val': [1, 2, 4, 8, 16, 32, 64, 128], 'recompile_flag': False},
'partition_sequence': {'val': [0, 1, 2, 3, 4, 5, 6, 7], 'recompile_flag': False},
'verbose': {'val': 1, 'recompile_flag': False},
'use_gpu': {'val': 'automatic', 'recompile_flag': True} # Possible values are 'automatic', 'full', 'worker', 'none'
}
Expand Down
Loading

0 comments on commit 1159eef

Please sign in to comment.