Skip to content

Commit

Permalink
MBIRJAX v0.5.4 (#51)
Browse files Browse the repository at this point in the history
This version of MBIRJAX adds the following features:

    FBP parallel beam reconstruction using the adjoint forward projector
    FDK cone beam reconstruction using the adjoint forward projector
    Demos and tests for FBP and FDK

* fbp_fdk (#50)

* Incorporate fbp and fdk recon from separate branches.

* Incorporate fdk_recon from "fdk" branch

* Clean up commented lines

* Add comments

* switch the order of u and v

* fix the scale factor: recon_filter *= scaling_factor

* Restore fbp and fdk code from separate branches.

* Reorganize the setting of parameters and update the tolerances for fbp and fdk.

* Update the documentation for fbp and fdk (and remove the shepp-logan filter).

---------

Co-authored-by: Yufang Sun <yufang.sun2@gmail.com>

* Minor docstring correction

* Update version number and pin a version of jaxlib to avoid an error in loading cuda libraries on Gilbreth.

* Update test instructions for the GPU.

* Update pip install for jax cuda.

* use jax.scipy.signal.fftconvolve

* Fbp fdk demo (#52)

* Constructed the demo file for FBP and FDK reconstruction

* Changed the scanning angle to full 360 degrees and added projection angles

* write with python notebook format

* Update the faqs to discuss fbp/fdk.

* Update the install scripts to work on Gilbreth.

* Improve and fix a bug in the install scripts.

* Update install instructions.

* Update fbp/fdk description.

* Removed unused parameters: detector_cone_angle, weigth, sharpness and related code; remove VCD reconstruction.

* rename file demo_5_FBP_FDK.py to demo_5_fbp_fdk.py

* add the fbp/fdk demo link

* Update module name for gautschi.

* Correct typo.

---------

Co-authored-by: ZiyunLiiii <a124601@MacBook-Pro-215.local>
Co-authored-by: Yufang Sun <yufang.sun2@gmail.com>

---------

Co-authored-by: gbuzzard <54102356+gbuzzard@users.noreply.github.com>
Co-authored-by: Yufang Sun <yufang.sun2@gmail.com>
Co-authored-by: Greg Buzzard <buzzard@purdue.edu>
Co-authored-by: ZiyunLiiii <a124601@MacBook-Pro-215.local>
  • Loading branch information
5 people authored Jan 9, 2025
1 parent 1159eef commit 3d6eda7
Show file tree
Hide file tree
Showing 18 changed files with 512 additions and 101 deletions.
129 changes: 129 additions & 0 deletions demo/demo_5_fbp_fdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""demo_5_fbp_fdk.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/10ZiCSk1C9D4Fb7Uv6jTtQYxF2lKjmbyh
**MBIRJAX: FBP and FDK Reconstruction Demo**
See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details.
This script demonstrates the MBIRJAX code by creating a 3D phantom inspired by Shepp-Logan, forward projecting it to create a sinogram, and then using MBIRJAX to perform Filtered Back Projection (FBP) for parallel beam reconstruction and Feldkamp-Davis-Kress reconstruction (FDK) for cone beam reconstruntion.
For the demo, we create some synthetic data by first making a phantom, then forward projecting it to obtain a sinogram.
In a real application, you would load your sinogram as a numpy array and use numpy.transpose if needed so that it
has axes in the order (views, rows, channels). For reference, assuming the rotation axis is vertical, then increasing the row index nominally moves down the rotation axis and increasing the channel index moves to the right as seen from the source.
Select a GPU as runtime type for best performance.
"""

# Commented out IPython magic to ensure Python compatibility.
# %pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple mbirjax

import numpy as np
import time
import jax.numpy as jnp
import mbirjax

"""**Set the geometry parameters**"""

# Choose the geometry type
geometry_type = 'cone' # 'cone' or 'parallel'

# Set parameters for the problem size - you can vary these, but if you make num_det_rows very small relative to
# channels, then the generated phantom may not have an interior.
num_views = 128
num_det_rows = 128
num_det_channels = 128

# For cone beam geometry, we need to describe the distances source to detector and source to rotation axis.
# np.Inf is an allowable value, in which case this is essentially parallel beam.
source_detector_dist = 4 * num_det_channels
source_iso_dist = source_detector_dist

# Set parameters for viewing angle.
start_angle = -np.pi
end_angle = np.pi

"""**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram.
Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the python code with an interactive matplotlib backend, typcially using the command line or a development environment like Spyder or Pycharm to invoke python.
"""

# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

if geometry_type == 'cone':
ct_model_for_generation = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
elif geometry_type == 'parallel':
ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles)
else:
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))

# Generate 3D Shepp Logan phantom
print('Creating phantom')
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom()

# Generate synthetic sinogram data
print('Creating sinogram')
sinogram = ct_model_for_generation.forward_project(phantom)
sinogram = np.array(sinogram)

# View sinogram
title = 'Original sinogram \nUse the sliders to change the view or adjust the intensity range.'
mbirjax.slice_viewer(sinogram, slice_axis=0, title=title, slice_label='View')

"""**Initialize for the reconstruction**"""

# ####################
# Initialize the model for reconstruction.
if geometry_type == 'cone':
ct_model_for_recon = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
elif geometry_type == 'parallel':
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles)
else:
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))

# Print out model parameters
ct_model_for_recon.print_params()

"""**Do the reconstruction and display the results.**"""

# ##########################
# Perform FBP/FDK reconstruction
if geometry_type == 'cone':
print("Starting FDK recon")
time0 = time.time()
recon = ct_model_for_recon.fdk_recon(sinogram, filter_name="ramp")
else:
print("Starting FBP recon")
time0 = time.time()
recon = ct_model_for_recon.fbp_recon(sinogram, filter_name="ramp")

recon.block_until_ready()
elapsed = time.time() - time0
# ##########################

max_diff = np.amax(np.abs(phantom - recon))
print('Geometry = {}'.format(geometry_type))
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom)
pct_95 = np.percentile(np.abs(recon - phantom), 95)
print('NRMSE between recon and phantom = {}'.format(nrmse))
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff))
print('95% of recon pixels are within {} of phantom'.format(pct_95))

mbirjax.get_memory_stats()
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))

# Display results
title = (f"Phantom (left) vs {'FDK' if geometry_type == 'cone' else 'FBP'} Recon (right). "
f"Filter used: ramp. \nUse the sliders to change the slice or adjust the intensity range.")

mbirjax.slice_viewer(phantom, recon, title=title)

"""**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/demos_and_faqs.html). """
109 changes: 66 additions & 43 deletions dev_scripts/clean_install_all.sh
Original file line number Diff line number Diff line change
@@ -1,52 +1,81 @@
#!/bin/bash
# This script installs everything from scratch
# This script installs mbirjax from scratch, but there may still be cached packaages.
# To do a really clean install, first run
# source deep_clean_conda.sh

# Steps to recover from jax/cuda mismatch
# 1. Delete env from .conda directory. This should be achieved with `conda remove env --name $NAME --all` but check
# in /scratch/gilbreth/$USER/.conda/envs and ~/.conda/envs
# 2. Delete all nvidia and jax related library dirs from /home/$USER/.local
# 3. Delete numpy, scipy, importlib_metadata and other related dirs downloaded with jax installation command from .local
# 4. Delete ~/.cache
#####
# Update the cluster host names, modules, and jax installation as needed, here and in
# get_demo_data_server.sh
#####
GPUCLUSTER="gilbreth"
CPUCLUSTER="negishi"
NAME="mbirjax"
GILBRETH="gilbreth"
NEGISHI="negishi"
GAUTSCHI="gautschi"
PYTHON_VERSION="3.12"

if [[ "$HOSTNAME" == *"$GPUCLUSTER"* ]]; then
module load anaconda
echo "$GPUCLUSTER setting"
conda config --add pkgs_dirs /scratch/$GPUCLUSTER/$USER/.conda/pkgs
CONDA_ENVS_PATH="/scratch/$GPUCLUSTER/$USER/.conda/envs"
conda config --add envs_dirs /scratch/$GPUCLUSTER/$USER/.conda/envs
fi
if [[ "$HOSTNAME" == *"$CPUCLUSTER"* ]]; then
module load anaconda
echo "$CPUCLUSTER setting"
conda config --add pkgs_dirs /scratch/$CPUCLUSTER/$USER/.conda/pkgs
CONDA_ENVS_PATH="/scratch/$CPUCLUSTER/$USER/.conda/envs"
conda config --add envs_dirs /scratch/$CPUCLUSTER/$USER/.conda/envs
fi
# Remove any previous builds
cd ..
/bin/rm -r docs/build &> /dev/null
/bin/rm -r dist &> /dev/null
/bin/rm -r "$NAME.egg-info" &> /dev/null
/bin/rm -r build &> /dev/null
cd dev_scripts

source install_conda_environment.sh
# Create and activate new conda environment
# First check if the target environment is active and deactivate if so

if nvidia-smi | grep -q "CUDA"; then
# pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# To install lower version of jax (say v0.4.13) incase of XLA parallel compilation warnings use the following
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# pip install --upgrade "jax[cuda12]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# pip install jaxlib==0.4.13+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Ref: https://github.com/google/jax/issues/18027
echo " "
echo "To run with jax on Gilbreth, first load the cuda module using "
echo " module load cudnn/cuda-12.1_8.9"
echo " "
# Deactivate all conda environments
while [ ${#CONDA_DEFAULT_ENV} -gt 0 ]; do
echo "Deactivating $CONDA_DEFAULT_ENV"
conda deactivate
done
echo "No conda environment active"

# Remove the environment
output=$(yes | conda remove --name $NAME --all 2>&1)
if echo "$output" | grep -q "DirectoryNotACondaEnvironmentError:"; then
# In some cases the directory may still exist but not really be an environment, so remove the directory itself.
conda activate $NAME
CUR_ENV_PATH=$CONDA_PREFIX
conda deactivate
rm -rf $CUR_ENV_PATH
fi

# Install based on the host
# Gilbreth (gpu)
if [[ "$HOSTNAME" == *"$GILBRETH"* ]]; then
echo "Installing on Gilbreth"
module --force purge
# The following two lines are required in late 2024 to interface jax to an older version of cuda in use on gilbreth.
# After gilbreth/cuda is updated, then the pattern for gautschi could be used here.
module load jax/0.4.31
yes | conda create -n $NAME python=3.11.7
conda activate $NAME
pip install -e ..[cuda12]
# Gautschi (gpu)
elif [[ "$HOSTNAME" == *"$GAUTSCHI"* ]]; then
echo "Installing on Gautschi"
module load conda/2024.09
yes | conda create -n $NAME python="$PYTHON_VERSION"
conda activate $NAME
pip install -e ..[cuda12]
# Negishi (cpu)
elif [[ "$HOSTNAME" == *"$NEGISHI"* ]]; then
echo "Installing on Negishi"
module load anaconda
yes | conda create -n $NAME python="$PYTHON_VERSION"
conda activate $NAME
pip install -e ..
# Other (cpu)
else
pip install --upgrade "jax[cpu]"
echo "Installing on non-RCAC machine"
yes | conda create -n $NAME python="$PYTHON_VERSION"
conda activate $NAME
pip install -e ..
fi

#source install_package.sh
pip install ..[test]
pip install ..[docs]
source build_docs.sh

red=`tput setaf 1`
Expand All @@ -58,9 +87,3 @@ echo "Use"
echo "${red} conda activate mbirjax ${reset}"
echo "to activate the conda environment."
echo " "

if [[ "$HOSTNAME" == *"gilbreth"* ]]; then
echo " "
echo "Verify the versions of anaconda, jax, and cuda as specified in clean_install_all.sh"
echo " "
fi
5 changes: 0 additions & 5 deletions dev_scripts/create_conda_test_environment.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
#!/bin/bash
# Create and activate new conda environment
# First check if the target environment is active and deactivate if so
NAME="mbirjax"
NEW_NAME="test"

if [ "$CONDA_DEFAULT_ENV" = "$NAME" ]; then
conda deactivate
fi

if [ "$CONDA_DEFAULT_ENV" = "$NEW_NAME" ]; then
conda deactivate
fi
Expand Down
7 changes: 7 additions & 0 deletions dev_scripts/deep_clean_conda.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash

while [ ${#CONDA_DEFAULT_ENV} -gt 0 ]; do
conda deactivate
done

rm -rf ~/.conda/* ~/.cache/conda/* ~/.cache/pip/* ~/.local/lib/python*
25 changes: 0 additions & 25 deletions dev_scripts/install_conda_environment.sh

This file was deleted.

14 changes: 0 additions & 14 deletions dev_scripts/install_package.sh

This file was deleted.

18 changes: 18 additions & 0 deletions docs/source/demos_and_faqs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Here are some demos to illustrate the basics of MBIRJAX along with some more adv
2. **Large Object:** `Jupyter notebook <https://colab.research.google.com/drive/1-kk_HeR8Y8f6pZ2zjTza8NTEpAgwgVRB?usp=sharing>`__ or `Python script <https://github.com/cabouman/mbirjax/blob/main/demo/demo_2_large_object.py>`__
3. **Cropped Center:** `Jupyter notebook <https://colab.research.google.com/drive/1WQwIJ_mDcuMMcWseM66aRPvtv6FmMWF-?usp=sharing>`__ or `Python script <https://github.com/cabouman/mbirjax/blob/main/demo/demo_3_cropped_center.py>`__
4. **Wrong Rotation:** `Jupyter notebook <https://colab.research.google.com/drive/1Gd-fMm3XK1WBsuJUklHdZ-4jjsvdpeIT?usp=sharing>`__ or `Python script <https://github.com/cabouman/mbirjax/blob/main/demo/demo_4_wrong_rotation_direction.py>`__
5. **FBP/FDK:** `Jupyter notebook <https://colab.research.google.com/drive/10ZiCSk1C9D4Fb7Uv6jTtQYxF2lKjmbyh?usp=sharing>`__ or `Python script <https://github.com/cabouman/mbirjax/blob/main/demo/demo_5_fbp_fdk.py>`__

First browse the notebooks, then copy and run in your own notebook environment,
or follow the installation instructions at :ref:`InstallationDocs` and run the scripts directly.
Expand Down Expand Up @@ -127,3 +128,20 @@ Positive values of ``offset`` will shift the region down relative to the detecto
This is useful if you would like to reconstruct the top or bottom half of a conebeam reconstruction in order to save memory.


Q: What are the differences between (iterative) recon and fbp_recon/fdk_recon?
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

A: The primary reconstruction method in MBIRJAX is iterative reconstruction (``mbirjax.TomographyModel.recon``)
using a Bayesian formulation that balances a data-fitting loss function with a prior function on the reconstruction that
reduces noise while maintaining sharp edges. This approach updates the reconstruction multiple times in order to
minimize the sum of these two loss functions.

In contrast, FBP (``mbirjax.ParallelBeamModel.fbp_recon``) and FDK (``mbirjax.ConeBeamModel.fdk_recon``) are direct
methods, in which the sinograms are filtered and then backprojected once to form the reconstruction. In this case,
there is no prior information and no attempt to denoise the sinogram or the reconstruction.

In general, FBP and FDK work well when the number of views is large (at least as large as the number of channels in the
detector) and the sinograms have little noise. Iterative reconstruction typically works better when there are
relatively few views and/or the sinograms are noisy. Iterative reconstruction takes more time and memory than
FBP/FDK but can produce significantly better reconstructions when the collected data is less than ideal.

8 changes: 6 additions & 2 deletions docs/source/dev_maintenance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ This is only available for registered maintainers. Typically, you would perform
View the package upload here:
`https://test.pypi.org/project/mbirjax <https://test.pypi.org/project/mbirjax>`__

3. Test the uploaded package::
3. Test the uploaded package (NOTE: to test on the GPU, use 'pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple mbirjax[cuda12]')::

pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple mbirjax
python -c "import mbirjax" # spin the wheel
pip install pytest
pytest tests

4. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.

Expand Down Expand Up @@ -69,10 +71,12 @@ This is only available for registered maintainers.
View the package upload here:
`https://pypi.org/project/mbirjax <https://pypi.org/project/mbirjax>`__

3. Test the uploaded package::
3. Test the uploaded package (NOTE: to test on the GPU, use 'pip install mbirjax[cuda12]')::

pip install mbirjax # OR, "mbirjax==0.1.1" e.g. for a specific version number
python -c "import mbirjax" # spin the wheel
pip install pytest
pytest tests

4. Run one of the demos in `mbirjax/demo <https://github.com/cabouman/mbirjax/tree/main/demo>`__.

Expand Down
9 changes: 1 addition & 8 deletions docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,9 @@ In order to download the python code, move to a directory of your choice and run

2. Install the conda environment and package

Option 1: Clean install using mbirjax/dev_scripts - We provide bash scripts that will do a clean install of ``MBIRJAX`` in a new conda environment using the following commands::
Clean install using mbirjax/dev_scripts - We provide bash scripts that will do a clean install of ``MBIRJAX`` in a new conda environment using the following commands::

cd dev_scripts
source clean_install_all.sh

Option 2: Manual install - You can also manually install ``MBIRJAX`` from the main directory of the repository with the following commands::

conda env create --name mbirjax --file environment.yml
conda activate mbirjax
pip install .



Loading

0 comments on commit 3d6eda7

Please sign in to comment.