Skip to content

Commit

Permalink
Improve qggmrf efficiency and accuracy. (#29)
Browse files Browse the repository at this point in the history
* Convert from cos_sin_angles to angles.

* Working version of new install.

* Update install to work with readthedocs (I hope).

* Update docs and doc installation.

* Improve developer docs.

* Add gpu memory and timing scripts.

* Improve and revise docs

* Minor correction

* buzzard_dev (#10)

* Refactor num_iterations as a method input rather than a class parameter.

* Implement voxel_batch_size for forward projection.

* Implement voxel_batch_size for back projection.

* Add initialization feature to recon

* Prepare for merge of add_prox.

* Improve error message.

---------

Co-authored-by: Charles Bouman <Charles.Bouman@gmail.com>

* Add proximal map (#12)

* Possibly buggy but functioning commit of refactored VCD functions

* Partially tested initial implementation of prox map

* Prox bug causes nan when initial condition is perfect

* Update gpu evaluatation for new interface.

* Add epsilon to avoid alpha=nan.

* Pull auto_set_regularization_params() out of prox()

* Update demo script

---------

Co-authored-by: Greg Buzzard <buzzard@purdue.edu>
Co-authored-by: gbuzzard <54102356+gbuzzard@users.noreply.github.com>

* Remove unused attributes: proximal_map_flag, max_resolutions, initialization

* Add batch size management to recon_test.py

* Add advanced features to documentation

* Fix typo.

* Add prox_map to docs

* Convert to use public instance variables in Projectors.

* Refactor to use recon_shape.

* Refactor to use recon_shape and projector_params.

* Remove prox_recon parameter

* Improve documentation of Projectors.

* Improve documentation of Projectors.

* Basic Conebeam (#16)

* Refactor parallel_beam.py for clarity

* Remove references to svmbir.

* More renaming in parallel_beam.py

* Add partial implementation of cone_beam.py

* Update incomplete version of cone_beam.py

* Updates to incomplete version cone beam projector

* Updates to incomplete version cone beam projector

* Progress towards working version.

* More progress towards working version.

* More progress towards working version.

* More progress towards working version.

* More progress towards working version.

* More progress towards working version.

* Working version - still needs to be tested in multiple cases.

* Add cone beam demo and experiment scripts

* Refactor per-pixel magnification to allow for infinite source-detector distance.

* Include more options to exercise conebeam model.

* Remove reshape from test and demo scripts

* Bug fix to get the correct system matrix row entries.

* Add link to powerpoint slides.

* Rename variables and functions for clarity

* Rename various variables and methods.

* Rename voxel to pixel as needed.

* Fix naming problems and restore vcd_figs_for_abst.py to working condition.

* Minor improvement to vcd_figs_for_abst.py script

* Add variable intensity window to slice_viewer and add 'modified' to the function that generates the modified Shepp-Logan phantom.

* Partial progress towards improving efficiency in conebeam.

* Update cvpr scripts

* Partial progress towards improving efficiency in conebeam.

* Rename delta_voxel_xy, delta_voxel_z, and delta_pixel_recon to delta_voxel.

* First implementation of forward vertical fan beam projection.

* Fix bugs in sphinx docs

* Add TomographyModel parameter definitions

* Refactor sphinx docs names

* Add documentation for cone beam class

* Fix bug in conebeam backprojection and simplify parameter handling.

---------

Co-authored-by: gbuzzard <54102356+gbuzzard@users.noreply.github.com>
Co-authored-by: Greg Buzzard <buzzard@purdue.edu>

* Efficient Cone Beam and Parameter Handling (#18)

* Partial implementation of horizontal fan beam projection.

* Add cone beam developer sphinx doc

* Working refactored forward projection for cone beam.

* Remove unused code.

* Change from magnification to source_iso_dist.

* Refactor magnification from a parameter to a function.

* Additional comment.

* Working version of vertical fan beam back projection.

* Working version of full refactored back projection.

* Working version of full refactored back projection.

* Fix Hessian calculation for cone beam.

* Add cone beam figure

* Remove logo

* Add logo

* Add pdf, jpg, pptx to gitignore

* Fix bug in docs.

* Change interface for basic geometry class to operate on a batch of pixels and one view for forward and back projections.
Improve docstrings.

* Refactor to vmap over pixels in vertical back projection.

* Move p to the list of geometry parameters.

* Include default batch sizes.  Are you happy?

* Refactor top level data flow in forward projector.

* Add projectors. Yes :-)

* Add cone beam tests and pytest support

* Refactor top level data flow in back projector.

* Add vmap over slices in horizontal fan beam.

* Add warnings when tests fail.

* Update pixel and voxel batch sizes and handling.
Make source_detector_distance a multiple of num_det_channels.

* Change auto_set_recon_size to be geometry specific.

* Refactored auto_set_sigma functions

* Refactor TomographModel to put parameter handling in a superclass.
Perform some code cleanup.

* Fix docs to reflect ParameterHandler.

* Change p to psf_radius and include sinogram viewer.

* Add psf_radius support

* Add psf_radius for both parallel and cone geometry

* Add ability to save and load to file.

* Check that loaded parameters match the type of the loading class.

* Update documentation.

* Edit documentation

---------

Co-authored-by: Greg Buzzard <buzzard@purdue.edu>

* Add bug-fixes and PyPI support (#20)

* Refactor to separate out the calculation of quantities needed for projection.
Improve documentation.

* Provide a display default when vmin=vmax.

* Provide a display default when vmin=vmax.

* Allow the slice label to be set.

* Fix bug to prevent zooming in sliders.

* Add cpu memory stats.

* Refactor for joint pan and zoom with two images.

* Minor modification to pyproject.toml

* Add PyPI dependencies

---------

Co-authored-by: Greg Buzzard <buzzard@purdue.edu>

* Update step size in vcd and other various improvements.   (#26)

* Update PyPI instructions.

* Correct docstring, tweak auto_regularization, improve demos, and update version number

* Update gitignore for jupyter notebooks and add LLNL experiments directory

* Minor bug fixes.

* Add files to reconstruct and view nersc data.

* Fix bug in backproject and improve parameter return for recon.

* Fix bug in default weights.

* Fix step size in vcd updates.
Clean up docs.
Change default partitions.
Fix bug in row and channel offets in cone and parallel.
Combine 2 demos into one.
Include more tests.
Fix bug in default weights.
Use reflected boundary conditions for qggmrf top and bottom slices.

* Fix minor bug.

* Update code for cvpr abstract to use the original partition sequences.

* Remove lbnl files for PR.

* Update version number

---------

Co-authored-by: Charles Bouman <Charles.Bouman@gmail.com>

* Improve package maintenance instructions.

* Refactor get_delta as a separate function and include a test for it.

* Refactor to use apply_map_in_batches.

* Refactor to use sum_function_in_batches and rename apply_map_in batches to concatenate_function_in_batches.

* Label regularization parameters as static for more efficient compiling and execution.

* Reorganize qggmrf names and functions in preparation for refactoring to use vmap.

* Fix error in doc related to recon_shape

* Add parameter checking to raise an exception if the user tries to set an invalid parameter.

* Add parameter checking to raise an exception if the user tries to set an invalid parameter.

* Relax the tolerances a little to account for randomness in vcd partition selection.

* Add cube phantom.

* Refactor qggmrf gradient and hessian to use vmap.

* Update qggmrf test.

* Remove randomness to ensure reproducibility.

* Partially working version - not fully debugged:
Refactor sum and concatenate in batches to work on tuples of inputs and outputs and include tests of these functions.
Implement infrastructure to return prior cost and to choose optimal alpha using prior gradient and hessian.

* Add QR code generation script

* Update tests.

* Include cost function for full recon and test for qggmrf gradient correctness.

* Start to remove cost function.

* Finish refactoring prior cost.

* Change default number of iterations.

* Update cvpr figures.

* Include test for surrogate hessian.

* Update displayed total loss.

* Rename cost to loss, sigma_p to sigma_prox.

* Move qggmrf functions to separate file.

* Move qggmrf functions to separate file.

* Working version of alpha using approximate prior quadratic.

* Remove deprecated keywords.

* Change version number.

---------

Co-authored-by: Charles Bouman <Charles.Bouman@gmail.com>
  • Loading branch information
gbuzzard and cabouman authored Jun 13, 2024
1 parent 4be13de commit dc6ea60
Show file tree
Hide file tree
Showing 22 changed files with 1,279 additions and 462 deletions.
2 changes: 1 addition & 1 deletion demo/demo_shepp_logan.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
# ##########################
# Perform VCD reconstruction
time0 = time.time()
recon, recon_params = ct_model.recon(sinogram, weights=weights)
recon, recon_params = ct_model.recon(sinogram, weights=weights, compute_prior_loss=True)

recon.block_until_ready()
elapsed = time.time() - time0
Expand Down
61 changes: 61 additions & 0 deletions dev_scripts/QR-code/create-QR-code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import qrcode
from PIL import Image, ImageDraw

# pip install qrcode[pil]

# URL of the GitHub documentation
url = 'https://mbirjax.readthedocs.io'

# Path to the logo image
logo_path = '../../docs/source/_static/logo.png'

# Generate QR code
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_H,
box_size=10,
border=4,
)
qr.add_data(url)
qr.make(fit=True)

# Create an image from the QR code instance
img_qr = qr.make_image(fill='black', back_color='white').convert('RGB')

# Open the logo image
logo = Image.open(logo_path)

# Resize the logo
logo_width = 400
logo_ratio = logo_width / float(logo.size[0])
logo_height = int(float(logo.size[1]) * logo_ratio)
logo = logo.resize((logo_width, logo_height), Image.LANCZOS)

# Calculate the dimensions of the final image
qr_width, qr_height = img_qr.size
padding = 0
total_height = qr_height + logo_height + padding # Add some space between QR and logo

# Create a new image with an off-white background
bg_color = (235, 235, 230) # Off-white
final_image = Image.new('RGB', (qr_width, total_height), bg_color)

# Paste the QR code onto the final image
final_image.paste(img_qr, (0, 0))

# Paste the logo onto the final image
logo_pos = ((qr_width - logo_width) // 2, qr_height + padding) # Reduced padding
final_image.paste(logo, logo_pos, mask=logo)

# Create rounded corners mask
radius = 30 # Adjust the radius as needed
mask = Image.new('L', final_image.size, 0)
draw = ImageDraw.Draw(mask)
draw.rounded_rectangle([(0, 0), final_image.size], radius, fill=255)

# Apply rounded corners mask to the final image
rounded_final_image = Image.new('RGB', final_image.size)
rounded_final_image.paste(final_image, (0, 0), mask=mask)

# Save the final image
rounded_final_image.save('mbirjax-qr-code.png')
8 changes: 6 additions & 2 deletions docs/source/_static/new_model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ class TemplateModel(TomographyModel):
"""

# TODO: Adjust the signature as needed for a particular geometry and update the docstring to match.
def __init__(self, sinogram_shape, param1, param2, view_dependent_vec1, view_dependent_vec2, **kwargs):
# Don't include any additional unspecified keyword arguments in the form of **kwargs - use only the parameters that
# are required for the geometry. Any changes to existing parameters should be done by the user with set_params,
# which checks for invalid parameter names. There is no check for invalid parameter names here because the
# geometry may need to define new parameters.
def __init__(self, sinogram_shape, param1, param2, view_dependent_vec1, view_dependent_vec2):
# Convert the view-dependent vectors to an array
view_dependent_vecs = [vec.flatten() for vec in [view_dependent_vec1, view_dependent_vec2]]
try:
Expand All @@ -38,7 +42,7 @@ def __init__(self, sinogram_shape, param1, param2, view_dependent_vec1, view_dep
raise ValueError("Incompatible view dependent vector lengths: all view-dependent vectors must have the "
"same length.")

super().__init__(sinogram_shape, param1=param1, param2=param2, view_params_array=view_params_array, **kwargs)
super().__init__(sinogram_shape, param1=param1, param2=param2, view_params_array=view_params_array)

@classmethod
def from_file(cls, filename):
Expand Down
2 changes: 1 addition & 1 deletion experiments/cone_beam_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
# ##########################
# Do all the setup
view_batch_size = 2 # Reduce this for large detector row/channel count, increase for smaller
view_batch_size = 4 # Reduce this for large detector row/channel count, increase for smaller
pixel_batch_size = 2048

# Initialize sinogram parameters
Expand Down
93 changes: 93 additions & 0 deletions experiments/cube_phantom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np
import time
import pprint
import jax.numpy as jnp
import mbirjax.plot_utils as pu
import mbirjax.parallel_beam

if __name__ == "__main__":
"""
This is a script to develop, debug, and tune the parallel beam mbirjax code.
"""
# Choose the geometry type
geometry_type = 'cone' # 'cone' or 'parallel'

print('Using {} geometry.'.format(geometry_type))

# Set parameters
num_views = 64
num_det_rows = 64
num_det_channels = 64
sharpness = 0.0

# These can be adjusted to describe the geometry in the cone beam case.
# np.Inf is an allowable value, in which case this is essentially parallel beam
source_detector_dist = 4 * num_det_channels
source_iso_dist = source_detector_dist

if geometry_type == 'cone':
detector_cone_angle = 2 * np.arctan2(num_det_channels / 2, source_detector_dist)
else:
detector_cone_angle = 0
start_angle = -(np.pi + detector_cone_angle) * (1/2)
end_angle = (np.pi + detector_cone_angle) * (1/2)

# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

# Set up the model
if geometry_type == 'cone':
ct_model = mbirjax.ConeBeamModel(sinogram_shape, angles, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist)
elif geometry_type == 'parallel':
ct_model = mbirjax.ParallelBeamModel(sinogram_shape, angles)
else:
raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))

# Generate 3D Shepp Logan phantom
print('Creating phantom')
recon_shape = ct_model.get_params('recon_shape')
phantom = np.zeros(recon_shape) # ct_model.gen_modified_3d_sl_phantom()
phantom[16:48, 16:48, 16:48] = 1

# Generate synthetic sinogram data
print('Creating sinogram')
sinogram = ct_model.forward_project(phantom)

# View sinogram
pu.slice_viewer(sinogram, title='Original sinogram', slice_axis=0, slice_label='View')

# Generate weights array - for an initial reconstruction, use weights = None, then modify as desired.
weights = None
# weights = ct_model.gen_weights(sinogram / sinogram.max(), weight_type='transmission_root')

# Set reconstruction parameter values
ct_model.set_params(sharpness=sharpness, verbose=1)

# Print out model parameters
ct_model.print_params()

# ##########################
# Perform VCD reconstruction
time0 = time.time()
recon, recon_params = ct_model.recon(sinogram, weights=weights, num_iterations=15)

recon.block_until_ready()
elapsed = time.time() - time0
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))
# ##########################

# Print out parameters used in recon
pprint.pprint(recon_params._asdict())

max_diff = np.amax(np.abs(phantom - recon))
print('Geometry = {}'.format(geometry_type))
nrmse = np.linalg.norm(recon - phantom) / np.linalg.norm(phantom)
pct_95 = np.percentile(np.abs(recon - phantom), 95)
print('NRMSE between recon and phantom = {}'.format(nrmse))
print('Maximum pixel difference between phantom and recon = {}'.format(max_diff))
print('95% of recon pixels are within {} of phantom'.format(pct_95))

# Display results
pu.slice_viewer(phantom, recon, title='Phantom (left) vs VCD Recon (right)')

21 changes: 14 additions & 7 deletions experiments/cvpr-2024/vcd_figs_for_abst.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,32 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :

# ##########################
# Perform VCD reconstruction
num_iterations = 13
recon_vcd, recon_params_vcd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations)
num_iterations = 10
recon_vcd, recon_params_vcd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations,
compute_prior_loss=True)
fm_rmse_vcd = recon_params_vcd.fm_rmse
prior_loss_vcd = recon_params_vcd.prior_loss
default_partition_sequence = parallel_model.get_params('partition_sequence')
partition_sequence = mbirjax.gen_partition_sequence(default_partition_sequence, num_iterations=num_iterations)
granularity_sequence_vcd = granularity[partition_sequence]

# Perform GD reconstruction
partition_sequence = [0, ]
parallel_model.set_params(partition_sequence=partition_sequence)
recon_gd, recon_params_gd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations)
recon_gd, recon_params_gd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations,
compute_prior_loss=True)
fm_rmse_gd = recon_params_gd.fm_rmse
prior_loss_gd = recon_params_gd.prior_loss
partition_sequence = mbirjax.gen_partition_sequence(partition_sequence=partition_sequence, num_iterations=num_iterations)
granularity_sequence_gd = granularity[partition_sequence]

# Perform CD reconstruction
partition_sequence = [3, ]
parallel_model.set_params(partition_sequence=partition_sequence)
recon_cd, recon_params_cd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations)
fm_rmse_cd =recon_params_cd.fm_rmse
recon_cd, recon_params_cd = parallel_model.recon(sinogram, weights=weights, num_iterations=num_iterations,
compute_prior_loss=True)
fm_rmse_cd = recon_params_cd.fm_rmse
prior_loss_cd = recon_params_cd.prior_loss
partition_sequence = mbirjax.gen_partition_sequence(partition_sequence=partition_sequence, num_iterations=num_iterations)
granularity_sequence_cd = granularity[partition_sequence]
# ##########################
Expand All @@ -100,9 +106,10 @@ def display_slices_for_abstract( recon1, recon2, recon3, labels) :

# Display granularity plots:
granularity_sequences = [granularity_sequence_gd, granularity_sequence_vcd, granularity_sequence_cd]
losses = [fm_rmse_gd, fm_rmse_vcd, fm_rmse_cd]
fm_losses = [fm_rmse_gd, fm_rmse_vcd, fm_rmse_cd]
prior_losses = [prior_loss_gd, prior_loss_vcd, prior_loss_cd]
labels = ['Gradient Descent', 'Vectorized Coordinate Descent', 'Coordinate Descent']
pu.plot_granularity_and_loss(granularity_sequences, losses, labels, granularity_ylim=(0, 256), loss_ylim=(0.1, 15))
pu.plot_granularity_and_loss(granularity_sequences, fm_losses, prior_losses, labels, granularity_ylim=(0, 256), loss_ylim=(0.1, 15))

# Generate sequence of partition images for Figure 1
recon_shape = (32, 32, 1)
Expand Down
2 changes: 1 addition & 1 deletion experiments/prox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# Run auto regularization. If auto_regularize_flag is False, then this will have no effect
parallel_model.auto_set_regularization_params(sinogram, weights=weights)
init_recon = phantom + 1.0
recon, fm_rmse = parallel_model.prox_map(phantom, sinogram, weights=weights, init_recon=init_recon, num_iterations=13)
recon, loss_vectors = parallel_model.prox_map(phantom, sinogram, weights=weights, init_recon=init_recon, num_iterations=13)

# Reshape recon into 3D form
recon_3d = parallel_model.reshape_recon(recon)
Expand Down
1 change: 1 addition & 0 deletions experiments/recon_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def print_summary(file=None):
print('Granularity = {}'.format(recon_params.granularity))
print('Partition sequence = {}'.format(recon_params.partition_sequence))
print('Final RMSE = {:.3f}'.format(recon_params.fm_rmse[-1]))
print('Final prior loss = {:.3f}'.format(recon_params.prior_loss[-1]))
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed), file=file)
mbirjax.get_memory_stats(print_results=True, file=file)
print('-------------------------', file=file)
Expand Down
1 change: 1 addition & 0 deletions mbirjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .projectors import *
from .parameter_handler import *
from .tomography_model import *
from .qggmrf import *
from .parallel_beam import *
from .cone_beam import *
from .vcd_utils import *
Expand Down
4 changes: 2 additions & 2 deletions mbirjax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
'recon_shape': {'val': None, 'recompile_flag': True},
'delta_voxel': {'val': None, 'recompile_flag': True},
'sigma_x': {'val': 1.0, 'recompile_flag': False},
'sigma_p': {'val': 1.0, 'recompile_flag': False},
'sigma_prox': {'val': 1.0, 'recompile_flag': False},
'p': {'val': 2.0, 'recompile_flag': False},
'q': {'val': 1.2, 'recompile_flag': False},
'T': {'val': 1.0, 'recompile_flag': False},
'b': {'val': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'recompile_flag': False},
'b': {'val': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'recompile_flag': False}, # Order is [row+1, row-1, col+1, col-1, slice+1, slice-1]
}

_reconstruction_defaults_dict = {
Expand Down
5 changes: 2 additions & 3 deletions mbirjax/cone_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ConeBeamModel(TomographyModel):
"""

def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist,
recon_slice_offset=0.0, det_rotation=0.0, **kwargs):
recon_slice_offset=0.0, det_rotation=0.0):
# Convert the view-dependent vectors to an array
# This is more complicated than needed with only a single view-dependent vector but is included to
# illustrate the process as shown in TemplateModel
Expand All @@ -49,8 +49,7 @@ def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist

super().__init__(sinogram_shape, view_params_array=view_params_array,
source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist,
recon_slice_offset=recon_slice_offset, det_rotation=det_rotation,
**kwargs)
recon_slice_offset=recon_slice_offset, det_rotation=det_rotation)

@classmethod
def from_file(cls, filename):
Expand Down
4 changes: 2 additions & 2 deletions mbirjax/parallel_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ParallelBeamModel(TomographyModel):
TomographyModel : The base class from which this class inherits.
"""

def __init__(self, sinogram_shape, angles, **kwargs):
def __init__(self, sinogram_shape, angles):
# Convert the view-dependent vectors to an array
# This is more complicated than needed with only a single view-dependent vector but is included to
# illustrate the process as shown in TemplateModel
Expand All @@ -49,7 +49,7 @@ def __init__(self, sinogram_shape, angles, **kwargs):
except ValueError as e:
raise ValueError("Incompatible view dependent vector lengths: all view-dependent vectors must have the "
"same length.")
super().__init__(sinogram_shape, view_params_array=view_params_array, **kwargs)
super().__init__(sinogram_shape, view_params_array=view_params_array)

@classmethod
def from_file(cls, filename):
Expand Down
11 changes: 9 additions & 2 deletions mbirjax/parameter_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,20 @@ def set_params(self, no_warning=False, no_compile=False, **kwargs):

if key in self.params.keys():
recompile_flag = self.params[key]['recompile_flag']
elif not no_warning:
error_message = '{} is not a recognized parameter'.format(key)
error_message += '\nValid parameters are: \n'
for valid_key in self.params.keys():
error_message += ' {}\n'.format(valid_key)
raise ValueError(error_message)

new_entry = {'val': val, 'recompile_flag': recompile_flag}
self.params[key] = new_entry

# Handle special cases
if recompile_flag:
recompile = True
elif key in ["sigma_y", "sigma_x", "sigma_p"]:
elif key in ["sigma_y", "sigma_x", "sigma_prox"]:
regularization_parameter_change = True
elif key in ["sharpness", "snr_db"]:
meta_parameter_change = True
Expand All @@ -179,7 +186,7 @@ def set_params(self, no_warning=False, no_compile=False, **kwargs):
if regularization_parameter_change:
self.set_params(auto_regularize_flag=False)
if not no_warning:
warnings.warn('You are directly setting regularization parameters, sigma_x, sigma_y or sigma_p. '
warnings.warn('You are directly setting regularization parameters, sigma_x, sigma_y or sigma_prox. '
'This is an advanced feature that will disable auto-regularization.')

# Handle case if any meta regularization parameter changed
Expand Down
9 changes: 5 additions & 4 deletions mbirjax/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def debug_plot_indices(num_recon_rows, num_recon_cols, indices, recon_at_indices
plt.show()


def plot_granularity_and_loss(granularity_sequences, losses, labels, granularity_ylim=None, loss_ylim=None,
def plot_granularity_and_loss(granularity_sequences, fm_losses, prior_losses, labels, granularity_ylim=None, loss_ylim=None,
fig_title=None):
"""
Plots multiple granularity and loss data sets on a single figure.
Expand All @@ -261,8 +261,8 @@ def plot_granularity_and_loss(granularity_sequences, losses, labels, granularity
if num_plots == 1:
axes = [axes] # Make it iterable for a single subplot scenario

for ax, granularity_sequence, loss, label in zip(axes, granularity_sequences, losses, labels):
index = list(range(len(granularity_sequence)))
for ax, granularity_sequence, fm_loss, prior_loss, label in zip(axes, granularity_sequences, fm_losses, prior_losses, labels):
index = list(1 + np.arange(len(granularity_sequence)))

# Plot granularity sequence on the first y-axis
ax1 = ax
Expand All @@ -274,7 +274,8 @@ def plot_granularity_and_loss(granularity_sequences, losses, labels, granularity

# Create a second y-axis for the loss
ax2 = ax1.twinx()
ax2.plot(index, loss, label='Loss', color='r')
ax2.plot(index, fm_loss, label='Data loss', color='r')
ax2.plot(index, prior_loss, label='Prior loss', color='g')
ax2.set_ylabel('Loss', color='r')
ax2.tick_params(axis='y', labelcolor='r')
ax2.set_yscale('log')
Expand Down
Loading

0 comments on commit dc6ea60

Please sign in to comment.