diff --git a/demo/slice_viewer_demo.gif b/demo/slice_viewer_demo.gif new file mode 100644 index 0000000..279de63 Binary files /dev/null and b/demo/slice_viewer_demo.gif differ diff --git a/docs/source/_static/new_model_template.py b/docs/source/_static/new_model_template.py index 846f645..3e754dc 100644 --- a/docs/source/_static/new_model_template.py +++ b/docs/source/_static/new_model_template.py @@ -57,14 +57,19 @@ def from_file(cls, filename): ConeBeamModel with the specified parameters. """ # Load the parameters and convert view-dependent parameters to use the geometry-specific keywords. - params = ParameterHandler.load_param_dict(filename, values_only=True) - view_params_array = params['view_params_array'] + # TODO: Adjust these to match the signature of __init__ + required_param_names = ['sinogram_shape', 'param1', 'param2'] + required_params, params = ParameterHandler.load_param_dict(filename, required_param_names, values_only=True) # TODO: Adjust these to match the signature of __init__ - view_dependent_vec1 = view_params_array[:, 0] - view_dependent_vec2 = view_params_array[:, 1] + view_params_array = params['view_params_array'] + required_params['view_dependent_vec1'] = view_params_array[:, 0] + required_params['view_dependent_vec2'] = view_params_array[:, 1] del params['view_params_array'] - return cls(view_dependent_vec1=view_dependent_vec1, view_dependent_vec2=view_dependent_vec2, **params) + + new_model = cls(**required_params) + new_model.set_params(**params) + return new_model def get_magnification(self): """ diff --git a/docs/source/dev_api.rst b/docs/source/dev_api.rst index 1f2c1c2..071b411 100644 --- a/docs/source/dev_api.rst +++ b/docs/source/dev_api.rst @@ -1,31 +1,9 @@ Developer API reference ======================= -Documentation for all methods is available at the following links: - -* :ref:`TomographyBeamModelDevDocs` -* :ref:`ParallelBeamModelDevDocs` -* :ref:`ProjectorsDevDocs` - **MBIRJAX** can be extended to include other geometries by following the outline in new_model_template.py: .. include:: _static/new_model_template.py :code: python -.. automodule:: mbirjax - :members: - :undoc-members: - :show-inheritance: - -.. toctree:: - :hidden: - :maxdepth: 4 - :caption: Classes - - dev_tomography_model - dev_projectors - dev_parameter_handler - dev_parallel_beam_model - dev_cone_beam_model - diff --git a/docs/source/dev_cone_beam_model.rst b/docs/source/dev_cone_beam_model.rst deleted file mode 100644 index 2b52c02..0000000 --- a/docs/source/dev_cone_beam_model.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. _ConeBeamModelDevDocs: - -============= -ConeBeamModel -============= - -The ``ConeBeamModel`` extends the functionalities provided by :class:`mbirjax.Tomography`. -This class inherits all behaviors and attributes of the TomographyModel and implements projectors specific -to cone beam CT. - -.. autoclass:: mbirjax.ConeBeamModel - :members: - :member-order: bysource - :show-inheritance: - diff --git a/docs/source/dev_parallel_beam_model.rst b/docs/source/dev_parallel_beam_model.rst deleted file mode 100644 index 72816d6..0000000 --- a/docs/source/dev_parallel_beam_model.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. _ParallelBeamModelDevDocs: - -================= -ParallelBeamModel -================= - -The ``ParallelBeamModel`` extends the functionalities provided by :class:`mbirjax.Tomography`. -This class inherits all behaviors and attributes of the TomographyModel and implements projectors specific -to parallel beam CT. - -.. autoclass:: mbirjax.ParallelBeamModel - :members: - :member-order: bysource - :show-inheritance: - diff --git a/docs/source/dev_parameter_handler.rst b/docs/source/dev_parameter_handler.rst deleted file mode 100644 index 449d7f3..0000000 --- a/docs/source/dev_parameter_handler.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. _ParameterHandlerDevDocs: - -================ -ParameterHandler -================ - -The ``ParameterHandler`` class implements the parameter handling used by :class:`mbirjax.Tomography`. - -.. autoclass:: mbirjax.ParameterHandler - :members: - :member-order: bysource - :show-inheritance: - diff --git a/docs/source/dev_projectors.rst b/docs/source/dev_projectors.rst deleted file mode 100644 index 9e946fd..0000000 --- a/docs/source/dev_projectors.rst +++ /dev/null @@ -1,21 +0,0 @@ -.. _ProjectorsDevDocs: - -========== -Projectors -========== - -The ``Projectors`` class uses the low-level projection functions -implemented in a specific geometry in order to override - -* :meth:`mbirjax.TomographyModel.sparse_forward_project` -* :meth:`mbirjax.TomographyModel.sparse_back_project` -* :meth:`mbirjax.TomographyModel.compute_hessian_diagonal` - -The ``Projectors`` class provides JAX-specific code using vmap and scan along with batching along pixels and -views in order to provide code that balances memory-efficiency with time-efficiency. - -.. autoclass:: mbirjax.Projectors - :members: - :member-order: bysource - :show-inheritance: - diff --git a/docs/source/dev_tomography_model.rst b/docs/source/dev_tomography_model.rst deleted file mode 100644 index fb92582..0000000 --- a/docs/source/dev_tomography_model.rst +++ /dev/null @@ -1,14 +0,0 @@ -.. _TomographyBeamModelDevDocs: - -=============== -TomographyModel -=============== - -The ``TomographyModel`` provides the basic interface for all specific geometries for tomographic projection -and reconstruction. - -.. autoclass:: mbirjax.TomographyModel - :members: - :member-order: bysource - :show-inheritance: - diff --git a/docs/source/usr_api.rst b/docs/source/usr_api.rst index 135b2da..6046f90 100644 --- a/docs/source/usr_api.rst +++ b/docs/source/usr_api.rst @@ -44,7 +44,7 @@ Parameter Handling ------------------ See the :ref:`Primary Parameters ` page for a description of the primary parameters. -Parameter handling is inherited from :ref:`ParameterHandlerDevDocs`, with the following primary methods. +Parameter handling uses the following primary methods. .. autosummary:: @@ -69,3 +69,19 @@ Parameter handling is inherited from :ref:`ParameterHandlerDevDocs`, with the fo usr_parallel_beam_model usr_cone_beam_model usr_plot_utils + usr_preprocess + +Preprocessing +------------------ + +Preprocessing functions are implemented in :ref:`PreprocessDocs`. This includes various methods to compute and correct the sinogram data as needed. + +.. autosummary:: + + mbirjax.preprocess.compute_sino_transmission + mbirjax.preprocess.interpolate_defective_pixels + mbirjax.preprocess.correct_det_rotation + mbirjax.preprocess.estimate_background_offset + mbirjax.preprocess.NSI.compute_sino_and_params + mbirjax.preprocess.NSI.load_scans_and_params + diff --git a/docs/source/usr_cone_beam_model.rst b/docs/source/usr_cone_beam_model.rst index ec38f96..169a1ce 100644 --- a/docs/source/usr_cone_beam_model.rst +++ b/docs/source/usr_cone_beam_model.rst @@ -12,7 +12,6 @@ Constructor ----------- .. autoclass:: mbirjax.ConeBeamModel - :no-index: :show-inheritance: Parent Class diff --git a/docs/source/usr_parallel_beam_model.rst b/docs/source/usr_parallel_beam_model.rst index 816b6e8..41f987d 100644 --- a/docs/source/usr_parallel_beam_model.rst +++ b/docs/source/usr_parallel_beam_model.rst @@ -12,7 +12,6 @@ Constructor ----------- .. autoclass:: mbirjax.ParallelBeamModel - :no-index: :show-inheritance: Parent Class diff --git a/docs/source/usr_plot_utils.rst b/docs/source/usr_plot_utils.rst index bc1cff2..cf0a4c2 100644 --- a/docs/source/usr_plot_utils.rst +++ b/docs/source/usr_plot_utils.rst @@ -11,7 +11,6 @@ Here is an example showing views of a modified Shepp-Logan phantom, with changin :alt: An animated image of the slice viewer. .. automodule:: mbirjax.plot_utils - :no-index: :members: :undoc-members: :show-inheritance: diff --git a/docs/source/usr_preprocess.rst b/docs/source/usr_preprocess.rst new file mode 100644 index 0000000..797c3d0 --- /dev/null +++ b/docs/source/usr_preprocess.rst @@ -0,0 +1,39 @@ +.. _PreprocessDocs: + +==================== +Preprocess utilities +==================== + +The ``preprocess`` module provides scanner-specific preprocessing and more general preprocessing to compute and correct the sinogram data. + +NorthStar Instrument (NSI) functions +------------------------------------ + +.. automodule:: mbirjax.preprocess.NSI + :members: compute_sino_and_params, load_scans_and_params + :undoc-members: + :show-inheritance: + + .. rubric:: **Functions:** + + .. autosummary:: + + compute_sino_and_params + load_scans_and_params + +General preprocess functions +---------------------------- + +.. automodule:: mbirjax.preprocess + :members: compute_sino_transmission, estimate_background_offset, interpolate_defective_pixels, correct_det_rotation + :undoc-members: + :show-inheritance: + + .. rubric:: **Functions:** + + .. autosummary:: + + compute_sino_transmission + estimate_background_offset + interpolate_defective_pixels + correct_det_rotation diff --git a/docs/source/usr_tomography_model.rst b/docs/source/usr_tomography_model.rst index 1705fb5..6a87a2e 100644 --- a/docs/source/usr_tomography_model.rst +++ b/docs/source/usr_tomography_model.rst @@ -12,54 +12,42 @@ Constructor ----------- .. autoclass:: mbirjax.TomographyModel - :no-index: :show-inheritance: Recon and Projection -------------------- .. automethod:: mbirjax.TomographyModel.recon - :no-index: .. automethod:: mbirjax.TomographyModel.prox_map - :no-index: .. automethod:: mbirjax.TomographyModel.forward_project - :no-index: .. automethod:: mbirjax.TomographyModel.back_project - :no-index: Saving and Loading ------------------ .. automethod:: mbirjax.TomographyModel.to_file - :no-index: .. automethod:: mbirjax.TomographyModel.from_file - :no-index: Parameter Handling ------------------ .. automethod:: mbirjax.TomographyModel.set_params - :no-index: .. automethod:: mbirjax.ParameterHandler.get_params - :no-index: .. automethod:: mbirjax.ParameterHandler.print_params - :no-index: Data Generation --------------- .. automethod:: mbirjax.TomographyModel.gen_weights - :no-index: .. automethod:: mbirjax.TomographyModel.gen_modified_3d_sl_phantom - :no-index: .. _detailed-parameter-docs: diff --git a/environment.yml b/environment.yml index 56b0797..e59194c 100644 --- a/environment.yml +++ b/environment.yml @@ -17,6 +17,7 @@ dependencies: - sphinx-copybutton - ruamel.yaml - psutil + - striprtf - pip - pip: - -e . diff --git a/experiments/blur_debug.py b/experiments/blur_debug.py new file mode 100644 index 0000000..dd14379 --- /dev/null +++ b/experiments/blur_debug.py @@ -0,0 +1,136 @@ + +import numpy as np +import jax.numpy as jnp +import jax +import time +import matplotlib.pyplot as plt +import gc +import mbirjax + + +if __name__ == "__main__": + + indices = np.arange(20) + + def mosaic(x): + y = jnp.zeros(x.shape) + y = y.at[1::2].set(x[1::2]) + return y + + + x = jax.device_put(np.r_[:10].astype(np.float32)) + fx = mosaic(x) + trans_fun = jax.linear_transpose(mosaic, x) + b = trans_fun(fx) # Raises exception: + + def mosaic(x): + blurred_image_shape, sigma = ((128, 128, 10), 2.0) # self.get_params(['sinogram_shape', 'sigma_psf']) + blurred_image = jnp.zeros(blurred_image_shape).reshape((-1, blurred_image_shape[2])) + blurred_image = blurred_image.at[indices].set(x) + blurred_image = blurred_image.reshape(blurred_image_shape) + return blurred_image + + x = jnp.ones((len(indices), 10)) + fx = mosaic(x) + + vjp_fun = jax.vjp(mosaic, x)[1] + a = vjp_fun(fx) # works fine! + + # trans_fun = jax.linear_transpose(mosaic, x) + # b = trans_fun(fx) # Raises exception: + + """ + This is a script to develop, debug, and tune the blur model + """ + # Initialize the image + image_shape = (64, 64, 5) + sigma = 0.5 + sharpness = -2 + noise_std = 0.04 + + # Set up blur model + blur_model = mbirjax.blur.Blur(image_shape, sigma) + + # Generate phantom + recon_shape = blur_model.get_params('recon_shape') + num_recon_rows, num_recon_cols, num_recon_slices = recon_shape[:3] + phantom = mbirjax.generate_3d_shepp_logan_low_dynamic_range((image_shape[0], image_shape[1], image_shape[0])) + center = phantom.shape[2] // 2 + start = center - image_shape[2] // 2 + phantom = phantom[:, :, start:start+image_shape[2]] + blurred_phantom = blur_model.forward_project(phantom) + blurred_phantom += noise_std * np.random.randn(*blurred_phantom.shape) + mbirjax.slice_viewer(phantom, blurred_phantom) + + blur_model.set_params(sharpness=sharpness) + # blur_model.set_params(sigma_y=noise_std/2) + blur_model.set_params(partition_sequence=(0, 0, 0, 1, 2, 2, )) + recon, recon_params = blur_model.recon(blurred_phantom, weights=None, compute_prior_loss=True, num_iterations=20) + mbirjax.slice_viewer(phantom, recon) + + # Generate indices of pixels + num_subsets = 1 + full_indices = mbirjax.gen_pixel_partition(recon_shape, num_subsets)[0] + num_subsets = 5 + subset_indices = mbirjax.gen_pixel_partition(recon_shape, num_subsets) + voxel_values = phantom.reshape((-1,) + recon_shape[2:])[full_indices] + + x = np.random.rand(*voxel_values.shape) + Ax = blur_model.sparse_forward_project(x, full_indices) + y = np.random.rand(*Ax.shape) + Aty = blur_model.sparse_back_project(y, full_indices) + yt_Ax = np.sum(y * Ax) + xt_Aty = np.sum(x * Aty) + assert(np.allclose(yt_Ax, xt_Aty)) + print("Adjoint property holds for random x, y = : {}".format(np.allclose(yt_Ax, xt_Aty))) + + # ########################## + # ## Test the hessian against a finite difference approximation ## # + hessian = blur_model.compute_hessian_diagonal() + + x = jnp.zeros(recon_shape) + key = jax.random.key(np.random.randint(100000)) + key, subkey = jax.random.split(key) + i, j = jax.random.randint(subkey, shape=(2,), minval=0, maxval=num_recon_rows) + key, subkey = jax.random.split(key) + k = jax.random.randint(subkey, shape=(), minval=0, maxval=num_recon_slices) + + eps = 0.01 + x = x.at[i, j, k].set(eps) + indices = jnp.arange(num_recon_rows * num_recon_cols) + voxel_values = x.reshape((-1, num_recon_slices))[indices] + Ax = blur_model.sparse_forward_project(voxel_values, indices) + AtAx = blur_model.sparse_back_project(Ax, indices).reshape(x.shape) + finite_diff_hessian = AtAx[i, j, k] / eps + print('Hessian matches finite difference: {}'.format(jnp.allclose(hessian.reshape(x.shape)[i, j, k], finite_diff_hessian))) + + # ########################## + # Show the forward and back projection from a single pixel + i, j = num_recon_rows // 4, num_recon_cols // 3 + x = jnp.zeros(recon_shape) + x = x.at[i, j, :].set(1) + voxel_values = x.reshape((-1, num_recon_slices))[indices] + + Ax = blur_model.sparse_forward_project(voxel_values, indices) + Aty = blur_model.sparse_back_project(Ax, indices) + Aty = blur_model.reshape_recon(Aty) + + y = jnp.zeros_like(Ax) + view_index = 30 + y = y.at[view_index].set(Ax[view_index]) + index = jnp.ravel_multi_index((6, 6), (num_recon_rows, num_recon_cols)) + + slice_index = (num_recon_slices + 1) // 2 + fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5)) + cax = ax[0].imshow(x[:, :, slice_index]) + ax[0].set_title('x = phantom') + fig.colorbar(cax, ax=ax[0]) + cax = ax[1].imshow(Ax[:, :, slice_index]) + ax[1].set_title('y = Ax') + fig.colorbar(cax, ax=ax[1]) + cax = ax[2].imshow(Aty[:, :, slice_index]) + ax[2].set_title('Aty = AtAx') + fig.colorbar(cax, ax=ax[2]) + plt.pause(1) + + a = 0 diff --git a/experiments/cvpr-2024/fourier_response.py b/experiments/cvpr-2024/fourier_response.py new file mode 100644 index 0000000..8b32cb2 --- /dev/null +++ b/experiments/cvpr-2024/fourier_response.py @@ -0,0 +1,237 @@ +import numpy as np +import jax.numpy as jnp +import matplotlib.pyplot as plt +import mbirjax +import mbirjax.parallel_beam +from scipy.sparse.linalg import svds, eigsh, aslinearoperator, LinearOperator +import jax + + +def create_deltas(diam): + # Create an array of images, each with a single nonzero entry, corresponding to increasing Fourier frequencies, + # where low frequencies are tin the center. + # Each image has size diam x diam if diam is even or (diam-1) x (diam-1) otherwise. + # Image 0 has a 1 in the approximate center of the image, corresponding to the constant image after FFT. + # The subsequent images fill in concentric square rings centered at the first point: 3x3, then 5x5, etc. + # Taking the FFT of these images produces a set of complex images in space domain with increasing frequency. + # These delta images will also be used in space to illustrate the AtA psf. + n = diam // 2 + size = 2 * n + total_layers = 4 * (n ** 2) + deltas = np.zeros((size, size, total_layers)) + ordered_indices = np.zeros((total_layers,), dtype=int) + + # Center point + deltas[n, n, 0] = 1 + ordered_indices[0] = np.ravel_multi_index((n, n), (size, size)) + index = 1 + + # Iterate over increasing square rings + for radius in range(1, n+1): + # Get the top right point + i = n - radius + j = n + radius + # Move left to right along the top + for dj in range(0, 2*radius + 1): + if j - dj < size: + deltas[i, j - dj, index] = 1 + ordered_indices[index] = np.ravel_multi_index([i, j - dj], (size, size)) + index += 1 + # Then the left side + j = j - 2*radius + for di in range(1, 2*radius + 1): + if i + di < size: + deltas[i + di, j, index] = 1 + ordered_indices[index] = np.ravel_multi_index([i + di, j], (size, size)) + index += 1 + # Bottom + i = i + 2*radius + for dj in range(1, 2*radius + 1): + if np.maximum(j + dj, i) < size: + deltas[i, j + dj, index] = 1 + ordered_indices[index] = np.ravel_multi_index([i, j + dj], (size, size)) + index += 1 + # Right side + j = j + 2 * radius + for di in range(1, 2 * radius): + if np.maximum(j, i - di) < size: + deltas[i - di, j, index] = 1 + ordered_indices[index] = np.ravel_multi_index([i - di, j], (size, size)) + index += 1 + + return deltas, ordered_indices + + +def neighbor_mean(image_stack): + # Determine the mean over 4 adjacent nearest neighbors in xy directions only + output = np.zeros_like(image_stack) + + # Pad using reflected boundaries + m0, m1 = image_stack.shape[:2] + padded_stack = np.zeros((m0+2, m1+2, image_stack.shape[2])) + padded_stack[1:-1, 1:-1] = image_stack + padded_stack[0, 1:-1] = image_stack[1] + padded_stack[-1, 1:-1] = image_stack[-2] + padded_stack[1:-1, 0] = image_stack[:, 1] + padded_stack[1:-1, -1] = image_stack[:, -2] + for i in [0, 1]: + for j in [0, 1]: + padded_stack[(m0+1)*i, (m1+1)*j] = image_stack[(m0-1)*i, (m1-1)*j] + + # Sum over the xy neighbor differences + for a0 in [0, 1]: + for a1 in [0, 1]: + output[a0:m0-1+a0, a1:m1-1+a1] += image_stack[1-a0:m0-a0, 1-a1:m1-a1] + + output /= 4 + return output + + +if __name__ == "__main__": + """ + This is a script to investigate the Fourier response of the forward and prior models, with and without masking. + """ + # Set the gray level (0-1) and whether the subset is random or a grid + g = 0.5 + grid = False + + view_batch_size = None + pixel_batch_size = None + with jax.experimental.enable_x64(True): # Finite difference requires 64 bit arithmetic + + # Initialize sinogram + num_views = 32 + num_det_rows = 1024 + num_det_channels = 32 + start_angle = 0 + end_angle = jnp.pi + sinogram = jnp.zeros((num_views, num_det_rows, num_det_channels)) + angles = jnp.linspace(start_angle, jnp.pi, num_views, endpoint=False) + + # Set up parallel beam model + parallel_model = mbirjax.ParallelBeamModel(sinogram.shape, angles) + parallel_model.set_params(view_batch_size=view_batch_size, pixel_batch_size=pixel_batch_size) + recon_shape = parallel_model.get_params('recon_shape') + hess = parallel_model.compute_hessian_diagonal()[:, :, 0].reshape((-1, 1)) + + # Generate indices of pixels + flat_recon_shape = (recon_shape[0] * recon_shape[1], recon_shape[2]) + if grid: + linear_subsample = np.ceil(np.sqrt(1 / g)).astype(int) + subset_mask = np.zeros(recon_shape[:2], dtype=int) + subset_mask[::linear_subsample] += 1 + subset_mask[:, ::linear_subsample] += 1 + subset_mask = np.clip(subset_mask - 1, 0, None) + subset_mask = subset_mask.flatten() + subset_indices = np.where(subset_mask > 0)[0] + else: + full_indices = np.arange(np.prod(recon_shape[:2])) + subset_indices = np.where(np.random.rand(*full_indices.shape) <= g)[0] + subset_mask = np.zeros(flat_recon_shape[0]) + subset_mask[subset_indices] = 1 + + clip_min = 0.001 + # Generate single point images, starting from the centering and spiralling out + deltas, ordered_inds = create_deltas(num_det_channels) + + ################## + # Get psf in space + voxel_values = deltas.reshape((-1,) + recon_shape[2:]) + + print('Starting forward projection of spatial deltas') + sinogram = parallel_model.sparse_forward_project(voxel_values[subset_indices], subset_indices) + bp_subset = parallel_model.sparse_back_project(sinogram, subset_indices) + bp = np.zeros(flat_recon_shape) + bp[subset_indices] = bp_subset + bp = bp / hess + bp = bp.reshape(recon_shape) + bp_norm = np.linalg.norm(bp, axis=(0, 1)) + scale = np.amax(bp) + title = 'AtA PSF in space: output scaled by 1 / {:.1f}'.format(scale) + title += '\nLeft: single point in space, Right: AtA of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, bp / scale, title=title) + title = 'AtA PSF in space: output in log10'.format(scale) + title += '\nLeft: single point in space, Right: AtA of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.log10(np.clip(bp / scale, clip_min, 1)), title=title) + + ###################### + # Get psf in frequency + deltas_shift = np.fft.fftshift(deltas, axes=(0, 1)) + fourier_images = np.fft.fft2(deltas_shift, axes=(0, 1)) + title = 'fftshift Fourier frequency and corresponding real(FFT)' + title += '\nLeft: single point in frequency, Right: real(FFT) of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.real(fourier_images), slice_axis=2, title=title) + fft_images_phantom = fourier_images[:, :, :num_det_rows] + + # Generate sinogram data + voxel_values = fft_images_phantom.reshape((-1,) + recon_shape[2:]) + + print('Starting forward projection of frequency deltas') + sinogram_real = parallel_model.sparse_forward_project(np.real(voxel_values[subset_indices]), subset_indices) + sinogram_imag = parallel_model.sparse_forward_project(np.imag(voxel_values[subset_indices]), subset_indices) + + bp_real_subset = parallel_model.sparse_back_project(sinogram_real, subset_indices) + bp_imag_subset = parallel_model.sparse_back_project(sinogram_imag, subset_indices) + bp_complex_subset = bp_real_subset + 1j * bp_imag_subset + bp_complex = np.zeros(flat_recon_shape, dtype=np.complex64) + bp_complex[subset_indices] = bp_complex_subset + bp_complex = bp_complex / hess + bp_complex = bp_complex.reshape(recon_shape) + bp_fft = np.fft.ifft2(bp_complex, axes=(0, 1)) + bp_fft = np.fft.ifftshift(bp_fft, axes=(0, 1)) + scale = np.amax(np.abs(bp_fft)) + title = '|AtA frequency transfer function|: output scaled by 1 / {:.1f}'.format(scale) + title += '\nLeft: single point in frequency, Right: |IFFT(AtA(FFT))| of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.abs(bp_fft) / scale, title=title, + vmin=0, vmax=1, cmap='viridis') + title = '|AtA frequency transfer function|: output in log10' + title += '\nLeft: single point in frequency, Right: |IFFT(AtA(FFT))| of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.log10(np.clip(np.abs(bp_fft), clip_min, None)), + title=title, cmap='viridis') + + ###################################### + # Get psf in frequency for prior model + print('Computing prior update of frequency deltas') + prior_space_real = 2 * (np.real(fft_images_phantom) - neighbor_mean(np.real(fft_images_phantom))) + prior_space_imag = 2 * (np.imag(fft_images_phantom) - neighbor_mean(np.imag(fft_images_phantom))) + prior_space_real *= subset_mask + prior_space_imag *= subset_mask + prior_fft = np.fft.ifft2(prior_space_real + 1j * prior_space_imag, axes=(0, 1)) + prior_fft = np.fft.ifftshift(prior_fft, axes=(0, 1)) + scale = np.amax(np.abs(prior_fft)) + title = '|Prior step frequency transfer function|: output scaled by 1 / {:.1f}'.format(scale) + title += '\nLeft: single point in frequency, Right: |FFT(prior step(FFT))| of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.abs(prior_fft) / scale, title=title, + vmin=0, vmax=1, cmap='viridis') + title = '|Prior step frequency transfer function|: output in log10'.format(scale) + title += '\nLeft: single point in frequency, Right: log10|FFT(prior step(FFT))| of that point, g={}'.format(g) + mbirjax.slice_viewer(deltas, np.log10(np.clip(np.abs(prior_fft), clip_min, None)), + title=title, cmap='viridis') + + bp_fft_flat = bp_fft.reshape((-1, bp_fft.shape[2])) + bp_fft_flat = bp_fft_flat[ordered_inds] + bp_fft_flat_log10 = np.log10(np.clip(np.abs(bp_fft_flat), clip_min, None)) + + prior_fft_flat = prior_fft.reshape((-1, prior_fft.shape[2])) + prior_fft_flat = prior_fft_flat[ordered_inds] + prior_fft_flat_log10 = np.log10(np.clip(np.abs(prior_fft_flat), clip_min, None)) + + plt.plot(np.diag(bp_fft_flat_log10), '.') + plt.plot(np.diag(prior_fft_flat_log10), '.') + plt.title('Diagonal elements of log10 of |freq transfer function|') + plt.legend(['AtA', 'Prior']) + title = 'log10 of |freq transfer function|\nLeft: AtA, Right: Prior' + title += '\nEach row is one input frequency, each column one ouptut frequency' + mbirjax.slice_viewer(bp_fft_flat_log10, prior_fft_flat_log10, cmap='viridis',title=title) + + gammas = np.linspace(0, 2, 20) + weighted_sum = bp_fft_flat[:, :, None] + prior_fft_flat[:, :, None] * gammas[None, None, :] + joint_transfer_log10 = np.log10(np.clip(np.abs(weighted_sum), clip_min, None)) + title = 'log10 of |freq trans func| of AtA + gamma * prior' + title += '\nAdjust the slider to change gamma' + mbirjax.slice_viewer(joint_transfer_log10, title=title, slice_label='10 * gamma =', cmap='viridis') + + mbirjax.slice_viewer(np.log10(np.clip(np.abs(bp_fft), clip_min, None)), + np.log10(np.clip(np.abs(prior_fft), clip_min, None)), + title='Forward (left) and prior (right) PSF in frequency with output in log10') + a = 0 diff --git a/experiments/cvpr-2024/spectral_response.py b/experiments/cvpr-2024/spectral_response.py index 603a448..1145fa0 100644 --- a/experiments/cvpr-2024/spectral_response.py +++ b/experiments/cvpr-2024/spectral_response.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import mbirjax import mbirjax.parallel_beam -from scipy.sparse.linalg import svds, aslinearoperator, LinearOperator +from scipy.sparse.linalg import svds, eigsh, aslinearoperator, LinearOperator import jax if __name__ == "__main__": @@ -15,28 +15,25 @@ with jax.experimental.enable_x64(True): # Finite difference requires 64 bit arithmetic # Initialize sinogram - num_views = 128 + num_views = 32 num_det_rows = 1 - num_det_channels = 128 + num_det_channels = 32 start_angle = 0 end_angle = jnp.pi sinogram = jnp.zeros((num_views, num_det_rows, num_det_channels)) angles = jnp.linspace(start_angle, jnp.pi, num_views, endpoint=False) - # Initialize a random key - seed_value = np.random.randint(1000000) - key = jax.random.PRNGKey(seed_value) - # Set up parallel beam model parallel_model = mbirjax.ParallelBeamModel(sinogram.shape, angles) # Generate phantom recon_shape = parallel_model.get_params('recon_shape') num_recon_rows, num_recon_cols, num_recon_slices = recon_shape[:3] - phantom = mbirjax.gen_cube_phantom(recon_shape) + phantom = mbirjax.generate_3d_shepp_logan_low_dynamic_range((num_det_channels,num_det_channels,num_det_channels)) + phantom = phantom[:, :, num_det_channels // 2] # Generate indices of pixels - num_subsets = 1 + num_subsets = 4 full_indices = mbirjax.gen_pixel_partition(recon_shape, num_subsets)[0] # Generate sinogram data @@ -51,76 +48,128 @@ print('Sinogram shape: {}'.format(sinogram.shape)) # Get the vector of indices - indices = jnp.arange(num_recon_rows * num_recon_cols) + all_indices = jnp.arange(num_recon_rows * num_recon_cols) sinogram = jnp.array(sinogram) - indices = jnp.array(indices) + all_indices = jnp.array(all_indices) + + hess = parallel_model.compute_hessian_diagonal().flatten() # Run once to finish compiling print('Starting back projection') - bp = parallel_model.sparse_back_project(sinogram, indices) + bp = parallel_model.sparse_back_project(sinogram, all_indices) print('Recon shape: ({}, {}, {})'.format(num_recon_rows, num_recon_cols, num_recon_slices)) - # ########################## - # Test the adjoint property - # Get a random 3D phantom to test the adjoint property - key, subkey = jax.random.split(key) - x = jax.random.uniform(subkey, shape=bp.shape) - key, subkey = jax.random.split(key) - y = jax.random.uniform(subkey, shape=sinogram.shape) - - # Do a forward projection, then a backprojection - x = x.reshape((-1, num_recon_slices))[indices] - Ax = parallel_model.sparse_forward_project(x, indices) - Aty = parallel_model.sparse_back_project(y, indices) - - # Calculate and - Aty_x = jnp.sum(Aty * x) - y_Ax = jnp.sum(y * Ax) - - assert(np.allclose(Aty_x, y_Ax)) - print("Adjoint property holds for random x, y = : {}".format(np.allclose(Aty_x, y_Ax))) - - def Ax_flat(local_x): - local_x = np.reshape(local_x, x.shape) - ax_flat = parallel_model.sparse_forward_project(local_x, indices) + input_size = all_indices.size + input_shape = (all_indices.size, 1) + output_size = sinogram.size + output_shape = sinogram.shape + + # Set up for svd + def Ax_full(local_x): + local_x = np.reshape(local_x, input_shape) + ax_flat = parallel_model.sparse_forward_project(local_x, all_indices) ax_flat = np.array(ax_flat.flatten()) return ax_flat - assert(np.allclose(Ax.flatten(), Ax_flat(x.flatten()))) - print('Ax_flat matches forward projection') - - def Aty_flat(y): - local_y = np.reshape(y, Ax.shape) - aty_flat = parallel_model.sparse_back_project(local_y, indices) + def Aty_full(local_y): + local_y = np.reshape(local_y, output_shape) + aty_flat = parallel_model.sparse_back_project(local_y, all_indices) aty_flat = np.array(aty_flat.flatten()) return aty_flat - assert(np.allclose(Aty.flatten(), Aty_flat(y.flatten()))) - print('Aty_flat matches back projection') + def precond_AtAx(local_x): + atax = Aty_full(Ax_full(local_x)) + precond_atax = atax / hess + return precond_atax + + def precond_AtAx_T(local_x): + local_x = local_x / hess + atax = Aty_full(Ax_full(local_x)) + return atax + + AtAx_linear_operator = LinearOperator(matvec=precond_AtAx, rmatvec=precond_AtAx_T, shape=(input_size, input_size)) + + operator_shape = (sinogram.size, input_size) + num_sing_values = np.amin(operator_shape) - 20 + eig_vects = True + print('Computing full AtAx / H eigen-decomposition') + if eig_vects: + u, s, vh = svds(AtAx_linear_operator, k=num_sing_values, tol=1e-6, return_singular_vectors=True, solver='propack') + vh = vh[::-1, :] + u = u[:, ::-1] + # mbirjax.slice_viewer(vh.reshape(num_sing_values, num_det_channels, num_det_channels), slice_axis=0) + # mbirjax.slice_viewer(u.reshape((num_det_channels, num_det_channels, num_sing_values)), slice_axis=2) + else: + s = svds(AtAx_linear_operator, k=num_sing_values, tol=1e-6, return_singular_vectors=False, solver='propack') - Ax_linear_operator = LinearOperator(matvec=Ax_flat, rmatvec=Aty_flat, shape=(Ax.size, x.size)) + s = s[::-1] - Ax_lo = Ax_linear_operator(x.flatten()) - Aty_lo = Ax_linear_operator.rmatvec(np.array(y).flatten()) + # Get a mask + g = 1 / num_subsets + mask = np.array(np.random.rand(*phantom.shape) < g) + mask = mask.reshape((-1, 1)) + mask_indices = np.where(mask)[0] + hess_m = hess[mask_indices] - assert(np.allclose(Ax.flatten(), Ax_lo)) - assert(np.allclose(Aty.flatten(), Aty_lo)) - print('Linear operator matches known projectors') + # Define the masked operators for svd + def Ax_masked(local_x): + local_x = local_x.reshape((-1, 1)) + ax_flat = parallel_model.sparse_forward_project(local_x, mask_indices) + ax_flat = np.array(ax_flat.flatten()) + return ax_flat - num_sing_values = 15 # num_views * num_det_channels - sing_vects = True - if sing_vects: - u, s, vh = svds(Ax_linear_operator, k=num_sing_values, tol=1e-6, return_singular_vectors=True, solver='propack') - vh = vh.reshape(num_sing_values, num_det_channels, num_det_channels) - vh = vh[::-1, :, :] - mbirjax.slice_viewer(vh, slice_axis=0) - u = u.reshape((num_det_channels, num_det_channels, num_sing_values)) - u = u[:, :, ::-1] - mbirjax.slice_viewer(u, slice_axis=2) - else: - s = svds(Ax_linear_operator, k=num_sing_values, tol=1e-6, return_singular_vectors=False, - solver='propack') - plt.plot(np.sort(s)[::-1], '.') + def Aty_masked(local_y): + local_y = local_y.reshape(output_shape) + aty_flat = parallel_model.sparse_back_project(local_y, mask_indices) + aty_flat = np.array(aty_flat.flatten()) + return aty_flat + + def precond_AtAx_masked(local_x): + atax = Aty_masked(Ax_masked(local_x)) + precond_atax = atax / hess_m + return precond_atax + + def precond_AtAx_T_masked(local_x): + local_x = local_x / hess_m + atax = Aty_masked(Ax_masked(local_x)) + return atax + + # Get the svd for the masked operator + masked_operator_shape = (len(mask_indices), len(mask_indices)) + Ax_masked_linear_operator = LinearOperator(matvec=precond_AtAx_masked, rmatvec=precond_AtAx_T_masked, shape=masked_operator_shape) + + print('Computing masked AtA / H eigen-decomposition') + num_masked_sing_values = np.amin(masked_operator_shape) - 1 + num_masked_sing_values = np.minimum(num_masked_sing_values, num_sing_values) + u_m, s_m, vh_m = svds(Ax_masked_linear_operator, k=num_masked_sing_values, tol=1e-6, return_singular_vectors=True, solver='propack') + vh_m = vh_m[::-1, :] + u_m = u_m[:, ::-1] + s_m = s_m[::-1] + + vm = np.zeros((vh_m.shape[0], np.prod(phantom.shape))) + vm[:, mask_indices] = vh_m + vm = vm.reshape((vm.shape[0],) + phantom.shape) + mbirjax.slice_viewer(vm, slice_axis=0, title='Eigenimages for masked A matrix, {} subsets'.format(num_subsets)) + + fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10)) + plt.suptitle('recon_shape={}; masking with {} subsets'.format(recon_shape, num_subsets)) + + axs[0, 0].semilogy(s, '.') + axs[0, 0].set_title('Singular values of AtA / H') + axs[0, 1].semilogy(s_m, '.') + axs[0, 1].set_title('Singular values of masked AtA / H') + y_limits = list(axs[0, 0].get_ylim()) + y_limits[0] = 1e-2 + axs[0, 0].set_ylim(y_limits) + axs[0, 1].set_ylim(y_limits) + + im0 = axs[1, 0].imshow(u @ (np.diag(s) @ vh)) + fig.colorbar(im0, ax=axs[1, 0]) + axs[1, 0].set_title('Full AtA / H') + + im1 = axs[1, 1].imshow(u_m @ (np.diag(s_m) @ vh_m)) + fig.colorbar(im1, ax=axs[1, 1]) + axs[1, 1].set_title('Masked AtA / H') plt.show() a = 0 \ No newline at end of file diff --git a/experiments/cvpr-2024/subset_investigation.py b/experiments/cvpr-2024/subset_investigation.py index 5cfbfc7..9271bd5 100644 --- a/experiments/cvpr-2024/subset_investigation.py +++ b/experiments/cvpr-2024/subset_investigation.py @@ -6,11 +6,11 @@ """ This is a script to develop subset selection for VCD. """ - image_shape = (1024, 1024) - small_tile_side = 16 - tile_type = 'repeat' # 'repeat', 'permute', 'random', 'select' + image_shape = (128, 128) + small_tile_side = 10 + tile_type = 'grid' # 'repeat', 'permute', 'random', 'select', 'grid' - num_subsets = 1 + num_subsets = 2 ror_mask = mbirjax.get_2d_ror_mask(image_shape) @@ -33,10 +33,10 @@ num_tiles = [np.ceil(image_shape[k] / pattern.shape[k]).astype(int) for k in [0, 1]] single_subset_inds = np.floor(pattern / (2**16 / num_subsets)).astype(int) - full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) if tile_type == 'repeat': # Repeat each bn subset to do the tiling + full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) subset_inds = np.tile(single_subset_inds, num_tiles) subset_inds = subset_inds[:image_shape[0], :image_shape[1]] subset_inds = (subset_inds + 1) * ror_mask - 1 # Get a 0 at each location outside the mask, subset_ind + 1 at other points @@ -68,6 +68,7 @@ elif tile_type == 'permute': # TODO: work with indices rather than masks # Using a permutation of the bn subsets in each tile location + full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) single_subsets = [(pattern >= bin_boundaries[j]) * (pattern < bin_boundaries[j + 1]) for j in range(num_subsets)] perms = [np.random.permutation(num_subsets) for j in np.arange(np.prod(num_tiles))] for k in range(num_subsets): @@ -83,6 +84,7 @@ # For each subset, select one element from each small tile. Use a different permutation of the subsets # for each tile to determine which subset gets which element from each small tile. num_subsets = small_tile_side ** 2 + full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) num_small_tiles = [np.ceil(image_shape[k] / small_tile_side).astype(int) for k in [0, 1]] perms = [np.random.permutation(num_subsets) for j in np.arange(np.prod(num_small_tiles))] perms = np.array(perms).T @@ -105,8 +107,35 @@ full_mask = full_mask.reshape((num_subsets,) + image_shape) full_mask = full_mask * ror_mask.reshape((1,) + image_shape) + elif tile_type == 'grid': + # For each subset, select the same element from each small tile. + num_subsets = small_tile_side ** 2 + full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) + num_small_tiles = [np.ceil(image_shape[k] / small_tile_side).astype(int) for k in [0, 1]] + perms = [np.arange(num_subsets) for j in np.arange(np.prod(num_small_tiles))] + perms = np.array(perms).T + small_tile_corners = np.meshgrid(np.arange(num_small_tiles[0]), np.arange(num_small_tiles[1])) + small_tile_corners[0] *= small_tile_side + small_tile_corners[1] *= small_tile_side + tile_inds = np.unravel_index(perms, (small_tile_side, small_tile_side)) + subset_inds = [small_tile_corners[j].reshape((1, -1)) + tile_inds[j] for j in [0, 1]] + good_inds = (subset_inds[0] < image_shape[0]) * (subset_inds[1] < image_shape[1]) + flat_inds = [] + for k in range(num_subsets): + flat_inds.append( + np.ravel_multi_index((subset_inds[0][k][good_inds[k]], subset_inds[1][k][good_inds[k]]), + image_shape)) + + full_mask = full_mask.reshape((num_subsets, np.prod(image_shape))) + for j in range(num_subsets): + full_mask[j][flat_inds[j]] = 1 + + full_mask = full_mask.reshape((num_subsets,) + image_shape) + full_mask = full_mask * ror_mask.reshape((1,) + image_shape) + else: # 'random' # Random sampling - THIS DOES NOT GIVE A PARTITION! + full_mask = np.zeros((num_subsets,) + image_shape, dtype=np.float32) for k in range(num_subsets): cur_mask = np.random.rand(*image_shape) full_mask[k] = cur_mask < 1 / num_subsets @@ -114,8 +143,8 @@ full_mask_fft = np.fft.fft2(full_mask) full_mask_fft = np.fft.fftshift(full_mask_fft, axes=(1, 2)) + # mbirjax.slice_viewer(np.real(full_mask_fft), np.imag(full_mask_fft), slice_axis=0) full_mask_fft = 20 * np.log10(np.abs(full_mask_fft) + 1e-12) - # print('Number of points = {}'.format(np.sum(subsets, axis=(1, 2)))) mbirjax.slice_viewer(40 * full_mask, full_mask_fft, slice_axis=0, slice_label='Subset', title='Subset mask and FFT in dB', vmin=0, vmax=60) diff --git a/experiments/recon timing.txt b/experiments/recon timing.txt new file mode 100644 index 0000000..fea0831 --- /dev/null +++ b/experiments/recon timing.txt @@ -0,0 +1,550 @@ + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 76.066 seconds +CPU + bytes_in_use: 0.221GB + peak_bytes_in_use: 0.402GB + bytes_limit: 68.091GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 10.688 seconds +CPU + bytes_in_use: 0.204GB + peak_bytes_in_use: 0.378GB + bytes_limit: 68.060GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 8.960 seconds +CPU + bytes_in_use: 0.234GB + peak_bytes_in_use: 0.403GB + bytes_limit: 68.094GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.612 seconds +CPU + bytes_in_use: 0.197GB + peak_bytes_in_use: 0.387GB + bytes_limit: 68.155GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (64, 128, 128) +Recon shape = (128, 128, 128) +Elapsed time for recon is 54.834 seconds +CPU + bytes_in_use: 0.247GB + peak_bytes_in_use: 0.711GB + bytes_limit: 67.827GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (64, 128, 128) +Recon shape = (128, 128, 128) +Elapsed time for recon is 58.490 seconds +CPU + bytes_in_use: 0.279GB + peak_bytes_in_use: 0.722GB + bytes_limit: 67.802GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 63.454 seconds +CPU + bytes_in_use: 0.218GB + peak_bytes_in_use: 0.442GB + bytes_limit: 68.091GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.751 seconds +CPU + bytes_in_use: 0.198GB + peak_bytes_in_use: 0.412GB + bytes_limit: 68.087GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.732 seconds +CPU + bytes_in_use: 0.195GB + peak_bytes_in_use: 0.413GB + bytes_limit: 68.023GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.679 seconds +CPU + bytes_in_use: 0.198GB + peak_bytes_in_use: 0.403GB + bytes_limit: 68.057GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 174.643 seconds +CPU + bytes_in_use: 0.205GB + peak_bytes_in_use: 0.435GB + bytes_limit: 67.577GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.691 seconds +CPU + bytes_in_use: 0.230GB + peak_bytes_in_use: 0.447GB + bytes_limit: 67.495GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.683 seconds +CPU + bytes_in_use: 0.220GB + peak_bytes_in_use: 0.454GB + bytes_limit: 67.463GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.666 seconds +CPU + bytes_in_use: 0.199GB + peak_bytes_in_use: 0.418GB + bytes_limit: 67.551GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.698 seconds +CPU + bytes_in_use: 0.203GB + peak_bytes_in_use: 0.423GB + bytes_limit: 67.561GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.671 seconds +CPU + bytes_in_use: 0.206GB + peak_bytes_in_use: 0.415GB + bytes_limit: 67.572GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.940 seconds +CPU + bytes_in_use: 0.208GB + peak_bytes_in_use: 0.438GB + bytes_limit: 64.475GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.932 seconds +CPU + bytes_in_use: 0.209GB + peak_bytes_in_use: 0.435GB + bytes_limit: 64.444GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.788 seconds +CPU + bytes_in_use: 0.202GB + peak_bytes_in_use: 0.418GB + bytes_limit: 64.186GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.724 seconds +CPU + bytes_in_use: 0.190GB + peak_bytes_in_use: 0.414GB + bytes_limit: 64.163GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.955 seconds +CPU + bytes_in_use: 0.215GB + peak_bytes_in_use: 0.450GB + bytes_limit: 64.180GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 12.252 seconds +CPU + bytes_in_use: 0.208GB + peak_bytes_in_use: 0.729GB + bytes_limit: 63.885GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 10.887 seconds +CPU + bytes_in_use: 0.206GB + peak_bytes_in_use: 0.677GB + bytes_limit: 64.232GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 11.981 seconds +CPU + bytes_in_use: 0.199GB + peak_bytes_in_use: 0.691GB + bytes_limit: 64.223GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.915 seconds +CPU + bytes_in_use: 0.214GB + peak_bytes_in_use: 0.435GB + bytes_limit: 64.566GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.921 seconds +CPU + bytes_in_use: 0.213GB + peak_bytes_in_use: 0.439GB + bytes_limit: 64.585GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.693 seconds +CPU + bytes_in_use: 0.192GB + peak_bytes_in_use: 0.419GB + bytes_limit: 64.607GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.702 seconds +CPU + bytes_in_use: 0.205GB + peak_bytes_in_use: 0.416GB + bytes_limit: 64.558GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 8.912 seconds +CPU + bytes_in_use: 0.233GB + peak_bytes_in_use: 0.397GB + bytes_limit: 64.554GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.707 seconds +CPU + bytes_in_use: 0.196GB + peak_bytes_in_use: 0.419GB + bytes_limit: 64.553GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (64, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 12.200 seconds +CPU + bytes_in_use: 0.202GB + peak_bytes_in_use: 0.453GB + bytes_limit: 63.195GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.662 seconds +CPU + bytes_in_use: 0.199GB + peak_bytes_in_use: 0.421GB + bytes_limit: 63.317GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.638 seconds +CPU + bytes_in_use: 0.204GB + peak_bytes_in_use: 0.420GB + bytes_limit: 63.347GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (32, 32, 32) +Elapsed time for recon is 3.068 seconds +CPU + bytes_in_use: 0.177GB + peak_bytes_in_use: 0.334GB + bytes_limit: 63.394GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 20.528 seconds +CPU + bytes_in_use: 0.233GB + peak_bytes_in_use: 0.594GB + bytes_limit: 63.271GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.686 seconds +CPU + bytes_in_use: 0.208GB + peak_bytes_in_use: 0.440GB + bytes_limit: 63.335GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 20.353 seconds +CPU + bytes_in_use: 0.239GB + peak_bytes_in_use: 0.597GB + bytes_limit: 63.201GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 20.460 seconds +CPU + bytes_in_use: 0.240GB + peak_bytes_in_use: 0.574GB + bytes_limit: 63.243GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.681 seconds +CPU + bytes_in_use: 0.198GB + peak_bytes_in_use: 0.403GB + bytes_limit: 63.316GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.702 seconds +CPU + bytes_in_use: 0.230GB + peak_bytes_in_use: 0.534GB + bytes_limit: 62.959GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 20.603 seconds +CPU + bytes_in_use: 0.270GB + peak_bytes_in_use: 0.677GB + bytes_limit: 62.951GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (64, 64, 64) +Elapsed time for recon is 7.665 seconds +CPU + bytes_in_use: 0.239GB + peak_bytes_in_use: 0.530GB + bytes_limit: 62.738GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 64, 64) +Recon shape = (32, 64, 64) +Elapsed time for recon is 5.773 seconds +CPU + bytes_in_use: 0.234GB + peak_bytes_in_use: 0.460GB + bytes_limit: 62.628GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.593 seconds +CPU + bytes_in_use: 0.216GB + peak_bytes_in_use: 0.479GB + bytes_limit: 60.698GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.526 seconds +CPU + bytes_in_use: 0.215GB + peak_bytes_in_use: 0.481GB + bytes_limit: 60.598GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.511 seconds +CPU + bytes_in_use: 0.208GB + peak_bytes_in_use: 0.490GB + bytes_limit: 60.755GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.645 seconds +CPU + bytes_in_use: 0.228GB + peak_bytes_in_use: 0.509GB + bytes_limit: 60.726GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.564 seconds +CPU + bytes_in_use: 0.242GB + peak_bytes_in_use: 0.489GB + bytes_limit: 60.726GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.575 seconds +CPU + bytes_in_use: 0.249GB + peak_bytes_in_use: 0.488GB + bytes_limit: 60.647GB +------------------------- + +------------------------- +Current stats: +Sinogram shape = (32, 32, 64) +Recon shape = (64, 64, 32) +Elapsed time for recon is 5.536 seconds +CPU + bytes_in_use: 0.237GB + peak_bytes_in_use: 0.492GB + bytes_limit: 60.664GB +------------------------- diff --git a/mbirjax/__init__.py b/mbirjax/__init__.py index 47fe6d3..1822b32 100644 --- a/mbirjax/__init__.py +++ b/mbirjax/__init__.py @@ -4,7 +4,9 @@ from .qggmrf import * from .parallel_beam import * from .cone_beam import * +from .blur import * from .vcd_utils import * from .memory_stats import * from .plot_utils import * - +from .preprocess import * +from mbirjax.preprocess.NSI import * diff --git a/mbirjax/_utils.py b/mbirjax/_utils.py index d0cad72..8f45a50 100644 --- a/mbirjax/_utils.py +++ b/mbirjax/_utils.py @@ -3,6 +3,9 @@ FILE_FORMAT_NUMBER = 1.0 # The format number should be changed if the file format changes. +# Update to include new geometries that should be included in the tests suite +_geometry_types_for_tests = ['parallel', 'cone'] + # The order and content of these dictionaries must match the headings and list of dicts below # The second entry in each case indicates if changing that parameter should trigger a recompile _forward_model_defaults_dict = { @@ -33,8 +36,8 @@ 'positivity_flag': {'val': False, 'recompile_flag': False}, 'snr_db': {'val': 30.0, 'recompile_flag': False}, 'sharpness': {'val': 0.0, 'recompile_flag': False}, - 'granularity': {'val': [1, 2, 64, 512, 2048], 'recompile_flag': False}, - 'partition_sequence': {'val': [0, 1, 2, 3, 3, 2, 2, 2, 3, 3, 3, 4, 3, 4, 4], 'recompile_flag': False}, + 'granularity': {'val': [1, 4, 64, 512, 2048], 'recompile_flag': False}, + 'partition_sequence': {'val': [0, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4], 'recompile_flag': False}, 'verbose': {'val': 1, 'recompile_flag': False}, 'pixel_batch_size': {'val': 2048, 'recompile_flag': True}, # TODO: Determine batch sizes dynamically. 'view_batch_size': {'val': 4, 'recompile_flag': True} diff --git a/mbirjax/blur.py b/mbirjax/blur.py new file mode 100644 index 0000000..f3250e6 --- /dev/null +++ b/mbirjax/blur.py @@ -0,0 +1,393 @@ +import jax +import jax.numpy as jnp +import scipy.ndimage as snd +import mbirjax +from functools import partial +from mbirjax import TomographyModel, ParameterHandler + + +class Blur(TomographyModel): + """ + A class designed for handling forward and backward projections using a blur. + This class is used mainly for development and demo. It shows how to use + jax.vjp to implement a back projector given a forward projector. + + This class inherits all methods and properties from the :ref:`TomographyModelDocs` and may override some + to suit parallel beam geometrical requirements. See the documentation of the parent class for standard methods + like setting parameters and performing projections and reconstructions. + + Parameters not included in the constructor can be set using the set_params method of :ref:`TomographyModelDocs`. + Refer to :ref:`TomographyModelDocs` documentation for a detailed list of possible parameters. + + Args: + recon_shape (tuple): + Shape of the recon as a tuple in the form `(rows, columns, slices)`. + sigma_psf (float): + The 2D spatial standard deviation of the Gaussian blur to be applied.. + + See Also + -------- + TomographyModel : The base class from which this class inherits. + """ + + def __init__(self, recon_shape, sigma_psf): + view_params_array = jnp.zeros((1, 1)) + sinogram_shape = recon_shape + super().__init__(sinogram_shape, view_params_array=view_params_array, sigma_psf=sigma_psf) + self.set_params(recon_shape=recon_shape) + + @classmethod + def from_file(cls, filename): + """ + Construct a model from parameters saved using save_params() + + Args: + filename (str): Name of the file containing parameters to load. + + Returns: + Model with the specified parameters. + """ + # Load the parameters and convert to use the ParallelBeamModel keywords. + required_param_names = ['recon_shape', 'sigma_psf'] + required_params, params = ParameterHandler.load_param_dict(filename, required_param_names, values_only=True) + + new_model = cls(**required_params) + new_model.set_params(**params) + return new_model + + def get_magnification(self): + """ + Compute the scale factor from a voxel at iso (at the origin on the center of rotation) to + its projection on the detector. For parallel beam, this is 1, but it may be parameter-dependent + for other geometries. + + Returns: + (float): magnification + """ + magnification = 1.0 + return magnification + + def verify_valid_params(self): + """ + Check that all parameters are compatible for a reconstruction. + + Note: + Raises ValueError for invalid parameters. + """ + pass + + def auto_set_recon_size(self, sinogram_shape, no_compile=True, no_warning=False): + + recon_shape = sinogram_shape + self.set_params(no_compile=no_compile, no_warning=no_warning, recon_shape=recon_shape) + + def create_projectors(self): + """ + Creates an instance of the Projectors class and set the local instance variables needed for forward + and back projection and compute_hessian_diagonal. This method requires that the current geometry has + implementations of :meth:`forward_project_pixel_batch_to_one_view` and :meth:`back_project_one_view_to_pixel_batch` + + Returns: + Nothing, but creates jit-compiled functions. + """ + self.sparse_forward_project = self.sparse_forward + self.sparse_back_project = self.sparse_back + self.compute_hessian_diagonal = self.hessian_diagonal + + def sparse_forward(self, voxel_values, indices, view_indices=()): + """ + Forward project the given voxel cylinders. + The indices are into a flattened 2D array of shape (recon_rows, recon_cols), and the projection is done using + all voxels with those indices across all the slices. + + Args: + voxel_values (ndarray or jax array): 2D array of voxel values to project, size (len(pixel_indices), num_recon_slices). + indices (ndarray or jax array): Array of indices specifying which voxels to project. + view_indices (ndarray or jax array, optional): Unused + + Returns: + jnp array: The resulting 3D image after applying the blur. + """ + blurred_image_shape, sigma = self.get_params(['sinogram_shape', 'sigma_psf']) + blurred_image = jnp.zeros(blurred_image_shape).reshape((-1, blurred_image_shape[2])) + blurred_image = blurred_image.at[indices].set(voxel_values) + blurred_image = blurred_image.reshape(blurred_image_shape) + if sigma > 0: + blurred_image = gaussian_filter(blurred_image, sigma, axes=(0, 1)) + + return blurred_image + + def sparse_back(self, blurred_image, indices, coeff_power=1, view_indices=()): + """ + Back project the image to the voxels given by the indices. + The indices are into a flattened 2D array of shape (recon_rows, recon_cols), and the projection is done using + all voxels with those indices across all the slices. + + Args: + blurred_image (jnp array): 3D jax array containing the image. + indices (jnp array): Array of indices specifying which voxels to back project. + coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). + Normally 1, but should be 2 for compute_hessian_diagonal. + view_indices (ndarray or jax array, optional): Unused + + Returns: + A jax array of shape (len(indices), num_slices) + """ + sigma = self.get_params('sigma_psf') + # TODO: This assumes a symmetric kernel. Should change to allow a more general kernel. + if sigma > 0: + bp_image = gaussian_filter(blurred_image, sigma, axes=(0, 1), coeff_power=coeff_power) + else: + bp_image = blurred_image + bp_image = bp_image.reshape((-1, blurred_image.shape[2])) + recon_at_indices = bp_image[indices] + return recon_at_indices + + def sparse_back_vjp(self, blurred_image, indices, coeff_power=1, view_indices=()): + """ + Back project the image to the voxels given by the indices by using autodifferentiation on sparse_forward. + The indices are into a flattened 2D array of shape (recon_rows, recon_cols), and the projection is done using + all voxels with those indices across all the slices. + + Args: + blurred_image (jnp array): 3D jax array containing the image. + indices (jnp array): Array of indices specifying which voxels to back project. + coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). + Normally 1, but should be 2 for compute_hessian_diagonal. + view_indices (ndarray or jax array, optional): Unused + + Returns: + A jax array of shape (len(indices), num_slices) + """ + sigma = self.get_params('sigma_psf') + input_shape = blurred_image.shape + x = jnp.ones((len(indices), blurred_image.shape[2])) + # TODO: Implement a direct Hessian diagonal. The problem is how to differentiate efficiently on just one input and output at a time. + if coeff_power == 2: + return self.sparse_back(blurred_image, indices, coeff_power=2) + + # Set up the forward model to be able to use vjp = vector Jacobian product = v^T A = (A^T v)^T + def local_forward(voxel_values): + return self.sparse_forward(voxel_values, indices) + if sigma > 0: + vjp_fun = jax.vjp(local_forward, x)[1] + bp_image = vjp_fun(blurred_image)[0] + if coeff_power == 2: + bp_image = bp_image.reshape(input_shape) + bp_image = vjp_fun(bp_image)[0] + else: + bp_image = blurred_image + return bp_image + + def hessian_diagonal(self, weights=None): + """ + Computes the diagonal elements of the Hessian matrix for given weights. + + Args: + weights (jax array, optional): 3D positive weights with same shape as sinogram. Defaults to all 1s. + view_indices (ndarray or jax array, optional): 1D array of indices into the view parameters array. + If None, then all views are used. + + Returns: + jnp array: Diagonal of the Hessian matrix with same shape as recon. + """ + projected_shape, recon_shape = self.get_params(['sinogram_shape', 'recon_shape']) + if weights is None: + weights = jnp.ones(projected_shape) + elif weights.shape != projected_shape: + error_message = 'Weights must be constant or an array compatible with sinogram' + error_message += '\nGot weights.shape = {}, but sinogram.shape = {}'.format(weights.shape, projected_shape) + raise ValueError(error_message) + + num_recon_rows, num_recon_cols, num_recon_slices = recon_shape[:3] + max_index = num_recon_rows * num_recon_cols + indices = jnp.arange(max_index) + hessian = self.sparse_back(weights, indices, coeff_power=2) + return hessian + + +# From scipy: +# Copyright (C) 2003-2005 Peter J. Verveer +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# 3. The name of the author may not be used to endorse or promote +# products derived from this software without specific prior +# written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS +# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +# GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from scipy.ndimage import _ni_support + + +@partial(jax.jit, static_argnames=['sigma', 'order', 'axes', 'truncate', 'radius', 'coeff_power']) +def gaussian_filter(input_image, sigma, order=0, axes=None, + truncate=4.0, *, radius=None, coeff_power=1): + """Multidimensional Gaussian filter. + + Parameters + ---------- + %(input)s + sigma : scalar or sequence of scalars + Standard deviation for Gaussian kernel. The standard + deviations of the Gaussian filter are given for each axis as a + sequence, or as a single number, in which case it is equal for + all axes. + order : int or sequence of ints, optional + The order of the filter along each axis is given as a sequence + of integers, or as a single number. An order of 0 corresponds + to convolution with a Gaussian kernel. A positive order + corresponds to convolution with that derivative of a Gaussian. + truncate : float, optional + Truncate the filter at this many standard deviations. + Default is 4.0. + radius : None or int or sequence of ints, optional + Radius of the Gaussian kernel. The radius are given for each axis + as a sequence, or as a single number, in which case it is equal + for all axes. If specified, the size of the kernel along each axis + will be ``2*radius + 1``, and `truncate` is ignored. + Default is None. + + Returns + ------- + gaussian_filter : ndarray + Returned array of same shape as `input`. + + Notes + ----- + The multidimensional filter is implemented as a sequence of + 1-D convolution filters. The intermediate arrays are + stored in the same data type as the output. Therefore, for output + types with a limited precision, the results may be imprecise + because intermediate results may be stored with insufficient + precision. + + The Gaussian kernel will have size ``2*radius + 1`` along each axis + where ``radius = round(truncate * sigma)``. + + Examples + -------- + >>> from scipy.ndimage import gaussian_filter + >>> import numpy as np + >>> a = np.arange(50, step=2).reshape((5,5)) + >>> a + array([[ 0, 2, 4, 6, 8], + [10, 12, 14, 16, 18], + [20, 22, 24, 26, 28], + [30, 32, 34, 36, 38], + [40, 42, 44, 46, 48]]) + >>> gaussian_filter(a, sigma=1) + array([[ 4, 6, 8, 9, 11], + [10, 12, 14, 15, 17], + [20, 22, 24, 25, 27], + [29, 31, 33, 34, 36], + [35, 37, 39, 40, 42]]) + + >>> from scipy import datasets + >>> import matplotlib.pyplot as plt + >>> fig = plt.figure() + >>> plt.gray() # show the filtered result in grayscale + >>> ax1 = fig.add_subplot(121) # left side + >>> ax2 = fig.add_subplot(122) # right side + >>> ascent = datasets.ascent() + >>> result = gaussian_filter(ascent, sigma=5) + >>> ax1.imshow(ascent) + >>> ax2.imshow(result) + >>> plt.show() + """ + if axes is None: + axes = list(range(input_image.ndim)) + orders = _ni_support._normalize_sequence(order, len(axes)) + sigmas = _ni_support._normalize_sequence(sigma, len(axes)) + radiuses = _ni_support._normalize_sequence(radius, len(axes)) + + axes = [(axes[ii], sigmas[ii], orders[ii], radiuses[ii]) + for ii in range(len(axes)) if sigmas[ii] > 1e-15] + + if len(axes) > 0: + for axis, sigma, order, radius in axes: + sd = float(sigma) + # make the radius of the filter equal to truncate standard deviations + lw = int(truncate * sd + 0.5) + if radius is not None: + lw = radius + if not isinstance(lw, int) or lw < 0: + raise ValueError('Radius must be a nonnegative integer.') + + weights = _gaussian_kernel1d(sigma, order, lw) ** coeff_power + input_image = jnp.moveaxis(input_image, axis, 0) + cur_shape = input_image.shape + input_image = input_image.reshape((cur_shape[0], -1)) + input_image = jax.vmap(convolve_reflect_1d, in_axes=(1, None), out_axes=1)(input_image, weights) + input_image = input_image.reshape(cur_shape) + input_image = jnp.moveaxis(input_image, 0, axis) + else: + input_image = jnp.copy(input_image) + return input_image + + +def convolve_reflect_1d(input1, weights): + """ + Perform 1D convolution with reflected boundary conditions to match scipy.ndimage.convolve1d + """ + lw = (len(weights) - 1) // 2 + # Convolve with full overlap between inputs and weights + c1 = jnp.convolve(input1, weights) + # To get reflected convolution, flip the two ends outside the original region and add them to the full convolution + # so that each end point is added twice, then subtract the contributions that were doubled + c1 = c1.at[lw:2 * lw].set(c1[lw:2 * lw] + jnp.flip(c1[0:lw])) + c1 = c1.at[-2 * lw:-lw].set(c1[-2 * lw:-lw] + jnp.flip(c1[-lw:])) + c1 = c1.at[lw:-lw].get() + return c1 + + +def _gaussian_kernel1d(sigma, order, radius): + """ + Computes a 1-D Gaussian convolution kernel. + """ + if order < 0: + raise ValueError('order must be non-negative') + exponent_range = jnp.arange(order + 1) + sigma2 = sigma * sigma + x = jnp.arange(-radius, radius+1) + phi_x = jnp.exp(-0.5 / sigma2 * x ** 2) + phi_x = phi_x / phi_x.sum() + + if order == 0: + return phi_x + else: + # f(x) = q(x) * phi(x) = q(x) * exp(p(x)) + # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x) + # p'(x) = -1 / sigma ** 2 + # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the + # coefficients of q(x) + q = jnp.zeros(order + 1) + q[0] = 1 + D = jnp.diag(exponent_range[1:], 1) # D @ q(x) = q'(x) + P = jnp.diag(jnp.ones(order)/-sigma2, -1) # P @ q(x) = q(x) * p'(x) + Q_deriv = D + P + for _ in range(order): + q = Q_deriv.dot(q) + q = (x[:, None] ** exponent_range).dot(q) + return q * phi_x diff --git a/mbirjax/cone_beam.py b/mbirjax/cone_beam.py index 79e7497..08594ad 100644 --- a/mbirjax/cone_beam.py +++ b/mbirjax/cone_beam.py @@ -22,18 +22,19 @@ class ConeBeamModel(TomographyModel): Shape of the sinogram as a tuple in the form `(views, rows, channels)`, where 'views' is the number of different projection angles, 'rows' correspond to the number of detector rows, and 'channels' index columns of the detector that are assumed to be aligned with the rotation axis. - angles (jnp.ndarray): + angles (ndarray or jax array): A 1D array of projection angles, in radians, specifying the angle of each projection relative to the origin. source_detector_dist (float): Distance between the X-ray source and the detector in units of ALU. source_iso_dist (float): Distance between the X-ray source and the center of rotation in units of ALU. - recon_slice_offset (float, optional, default=0): Vertical offset of the image in ALU. - If recon_slice_offset is positive, we reconstruct the region below iso. - det_rotation (float, optional, default=0): Angle in radians between the projection of the object rotation axis - and the detector vertical axis, where positive describes a clockwise rotation of the detector as seen from the source. + + Note: + One additional parameter for ConeBeamModel that can be set using set_params() is + + **recon_slice_offset** (float, default=0) - + Vertical offset of the image in ALU. If recon_slice_offset is positive, we reconstruct the region below iso. """ - def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist, - recon_slice_offset=0.0, det_rotation=0.0): + def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist): # Convert the view-dependent vectors to an array # This is more complicated than needed with only a single view-dependent vector but is included to # illustrate the process as shown in TemplateModel @@ -46,7 +47,7 @@ def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist super().__init__(sinogram_shape, view_params_array=view_params_array, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist, - recon_slice_offset=recon_slice_offset, det_rotation=det_rotation) + recon_slice_offset=0.0) @classmethod def from_file(cls, filename): @@ -60,10 +61,18 @@ def from_file(cls, filename): ConeBeamModel with the specified parameters. """ # Load the parameters and convert to use the ConeBeamModel keywords. - params = ParameterHandler.load_param_dict(filename, values_only=True) + required_param_names = ['sinogram_shape', 'source_detector_dist', 'source_iso_dist'] + required_params, params = ParameterHandler.load_param_dict(filename, required_param_names, values_only=True) + + # Collect the required parameters into a separate dictionary and remove them from the loaded dict. angles = params['view_params_array'] del params['view_params_array'] - return cls(angles=angles, **params) + required_params['angles'] = angles + + # Get an instance with the required parameters, then set any optional parameters + new_model = cls(**required_params) + new_model.set_params(**params) + return new_model def get_magnification(self): """ @@ -109,7 +118,7 @@ def get_geometry_parameters(self): """ # First get the parameters managed by ParameterHandler geometry_param_names = \ - ['delta_det_row', 'delta_det_channel', 'det_row_offset', 'det_channel_offset', 'det_rotation', + ['delta_det_row', 'delta_det_channel', 'det_row_offset', 'det_channel_offset', 'source_detector_dist', 'delta_voxel', 'recon_slice_offset'] geometry_param_values = self.get_params(geometry_param_names) @@ -495,7 +504,7 @@ def compute_vertical_data_single_pixel(pixel_index, angle, projector_params): u_p, v_p, pixel_mag = ConeBeamModel.geometry_xyz_to_uv_mag(x_p, y_p, z_p, gp.source_detector_dist, gp.magnification) # Convert from uv to index coordinates in detector and get the vector of center detector rows for this cylinder m_p, _ = ConeBeamModel.detector_uv_to_mn(u_p, v_p, gp.delta_det_channel, gp.delta_det_row, gp.det_channel_offset, - gp.det_row_offset, num_det_rows, num_det_channels, gp.det_rotation) + gp.det_row_offset, num_det_rows, num_det_channels) m_p_center = jnp.round(m_p).astype(int) # Compute vertical cone angle of pixel @@ -605,18 +614,15 @@ def geometry_xyz_to_uv_mag(x, y, z, source_detector_dist, magnification): return u, v, pixel_mag @staticmethod - @partial(jax.jit, static_argnames='det_rotation') + @jax.jit def detector_uv_to_mn(u, v, delta_det_channel, delta_det_row, det_channel_offset, det_row_offset, num_det_rows, - num_det_channels, det_rotation=0): + num_det_channels): """ Convert (u, v) detector coordinates to fractional indices (m, n) into the detector. Note: This version does not account for nonzero detector rotation. """ - if det_rotation != 0: - raise ValueError('Nonzero det_rotation is not implemented.') - # Account for small rotation of the detector # TODO: In addition to including the rotation, we'd need to adjust the calculation of the channel as a # function of slice. diff --git a/mbirjax/parallel_beam.py b/mbirjax/parallel_beam.py index f0d9ba5..84a0a34 100644 --- a/mbirjax/parallel_beam.py +++ b/mbirjax/parallel_beam.py @@ -61,11 +61,19 @@ def from_file(cls, filename): Returns: ConeBeamModel with the specified parameters. """ - # Load the parameters and convert to use the ParallelBeamModel keywords. - params = ParameterHandler.load_param_dict(filename, values_only=True) + # Load the parameters and convert to use the ConeBeamModel keywords. + required_param_names = ['sinogram_shape'] + required_params, params = ParameterHandler.load_param_dict(filename, required_param_names, values_only=True) + + # Collect the required parameters into a separate dictionary and remove them from the loaded dict. angles = params['view_params_array'] del params['view_params_array'] - return cls(angles=angles, **params) + required_params['angles'] = angles + + # Get an instance with the required parameters, then set any optional parameters + new_model = cls(**required_params) + new_model.set_params(**params) + return new_model def get_magnification(self): """ diff --git a/mbirjax/parameter_handler.py b/mbirjax/parameter_handler.py index ca93559..db55813 100644 --- a/mbirjax/parameter_handler.py +++ b/mbirjax/parameter_handler.py @@ -3,6 +3,7 @@ from ruamel.yaml import YAML import mbirjax._utils as utils import warnings +import copy class ParameterHandler(): @@ -86,7 +87,12 @@ def save_params(self, filename): Returns: Nothing but creates or overwrites the specified file. """ - output_params = ParameterHandler.convert_arrays_to_strings(self.params.copy()) + output_params = ParameterHandler.convert_arrays_to_strings(copy.deepcopy(self.params)) + # Convert any lists to tuples for consistency with load + keys = output_params.keys() + for key in keys: + if isinstance(output_params[key]['val'], list): + output_params[key]['val'] = tuple(output_params[key]['val']) # Determine file type if filename[-4:] == '.yml' or filename[-5:] == '.yaml': @@ -99,16 +105,18 @@ def save_params(self, filename): raise ValueError('Filename must end in .yaml or .yml: ' + filename) @staticmethod - def load_param_dict(filename, values_only=True): + def load_param_dict(filename, required_param_names=None, values_only=True): """ Load parameter dictionary from yaml file. Args: filename (str): Path to load to store the parameter dictionary. Must end in .yml or .yaml + required_param_names (list of strings): List of parameter names that are required for a class. values_only (bool): If True, then extract and return the values of each entry only. Returns: - dict: The dictionary of paramters. + required_params (dict): Dictionary of required parameter entries. + params (dict): Dictionary of all other parameters. """ # Determine file type if filename[-4:] == '.yml' or filename[-5:] == '.yaml': @@ -118,25 +126,23 @@ def load_param_dict(filename, values_only=True): params = yaml.load(file) params = ParameterHandler.convert_strings_to_arrays(params) - keys = params.keys() - if 'recon_shape' in keys: - params['recon_shape']['val'] = tuple(params['recon_shape']['val']) - if values_only: - for key in keys: - params[key] = params[key]['val'] - return params + # Convert any lists to tuples for consistency with save + for key in params.keys(): + if isinstance(params[key]['val'], list): + params[key]['val'] = tuple(params[key]['val']) - def load_params(self, filename): - """ - Load parameter dictionary from yaml file. - Args: - filename (str): Path to load to store the parameter dictionary. Must end in .yml or .yaml + # Separate the required parameters into a new dict and delete those entries from the original + required_params = dict() + for name in required_param_names: + required_params[name] = params[name] + del params[name] - Returns: - Nothing, but the parameters are set from the file. - """ - # Determine file type - self.params = ParameterHandler.load_param_dict(filename) + if values_only: + for key in required_params.keys(): + required_params[key] = required_params[key]['val'] + for key in params.keys(): + params[key] = params[key]['val'] + return required_params, params def set_params(self, no_warning=False, no_compile=False, **kwargs): """ @@ -221,14 +227,14 @@ def get_params_from_dict(param_dict, parameter_names): if parameter_names in param_dict.keys(): value = param_dict[parameter_names]['val'] else: - raise NameError('"{}" not a recognized argument'.format(parameter_names)) + raise NameError('"{}" is not a recognized argument'.format(parameter_names)) return value values = [] for name in parameter_names: if name in param_dict.keys(): values.append(param_dict[name]['val']) else: - raise NameError('"{}" not a recognized argument'.format(name)) + raise NameError('"{}" is not a recognized argument'.format(name)) return values def get_params(self, parameter_names): diff --git a/mbirjax/plot_utils.py b/mbirjax/plot_utils.py index 92ad1c7..45c5aaa 100644 --- a/mbirjax/plot_utils.py +++ b/mbirjax/plot_utils.py @@ -1,3 +1,4 @@ +import warnings import matplotlib.pyplot as plt import numpy import numpy as np @@ -104,8 +105,10 @@ def slice_viewer(data, data2=None, title='', vmin=None, vmax=None, slice_label=' # Then add the slice slider ax_slice_slider = fig.add_subplot(gs[1, :]) + warnings.filterwarnings('ignore', category=UserWarning) slice_slider = Slider(ax=ax_slice_slider, label=slice_label, valmin=0, valmax=data.shape[2] - 1, valinit=slice_index, valfmt='%0.0f') + warnings.filterwarnings('default', category=UserWarning) # Then the intensity slider ax_intensity_slider = fig.add_subplot(gs[2, :]) diff --git a/mbirjax/preprocess/NSI.py b/mbirjax/preprocess/NSI.py new file mode 100644 index 0000000..175c99d --- /dev/null +++ b/mbirjax/preprocess/NSI.py @@ -0,0 +1,630 @@ +import os, sys +import re +import numpy as np +import warnings +import striprtf.striprtf as striprtf +import mbirjax.preprocess as preprocess +import glob +import pprint +pp = pprint.PrettyPrinter(indent=4) + + +def compute_sino_and_params(dataset_dir, + downsample_factor=(1, 1), crop_region=((0, 1), (0, 1)), + subsample_view_factor=1): + """ + Load NSI sinogram data and prepare all needed arrays and parameters for a ConeBeamModel reconstruction. + + This function computes the sinogram and geometry parameters from an NSI scan directory containing scan data and parameters. + More specifically, the function performs the following operations in a single easy-to-use manner: + + 1. Loads the object, blank, and dark scans, as well as the geometry parameters from an NSI dataset directory. + + 2. Computes the sinogram from object, blank, and dark scans. + + 3. Replaces defective pixels with interpolated values. + + 4. Performs background offset correction to the sinogram from the edge pixels. + + 5. Corrects sinogram data to account for detector rotation. + + Args: + dataset_dir (string): Path to an NSI scan directory. The directory is assumed to have the following structure: + + - ``*.nsipro`` (NSI config file) + - ``Geometry*.rtf`` (geometry report) + - ``Radiographs*/`` (directory containing all radiograph images) + - ``**/gain0.tif`` (blank scan image) + - ``**/offset.tif`` (dark scan image) + - ``**/*.defect`` (defective pixel information) + + downsample_factor ((int, int), optional) - Down-sample factors along the detector rows and channels respectively. + If scan size is not divisible by `downsample_factor`, the scans will be first truncated to a size that is divisible by `downsample_factor`. + + crop_region (((float, float),(float, float)), optional) - Values of ((row_start, row_end), (col_start, col_end)) define a bounding box that crops the scan. + The default of ((0, 1), (0, 1)) retains the entire scan. + + subsample_view_factor (int, optional): View subsample factor. By default no view subsampling will be performed. + + Returns: + tuple: [sinogram, cone_beam_params, optional_params] + + sino (jax array): 3D sinogram data with shape (num_views, num_det_rows, num_det_channels). + + cone_beam_params (dict): Required parameters for the ConeBeamModel constructor. + + optional_params (dict): Additional ConeBeamModel parameters to be set using set_params(). + + Example: + .. code-block:: python + + # Get data and recon parameters + sino, cone_beam_params, optional_params = mbirjax.preprocess.NSI.compute_sino_and_params(dataset_dir, downsample_factor=downsample_factor, subsample_view_factor=subsample_view_factor) + + # Create the model and set the parameters + ct_model = mbirjax.ConeBeamModel(**cone_beam_params) + ct_model.set_params(**optional_params) + ct_model.set_params(sharpness=sharpness, verbose=1) + + # Compute sinogram weights and do the reconstruction + weights = ct_model.gen_weights(sino, weight_type='transmission_root') + recon, recon_params = ct_model.recon(sino, weights=weights) + + """ + + print("\n\n########## Loading object, blank, dark scans, as well as geometry parameters from NSI dataset directory ...") + obj_scan, blank_scan, dark_scan, cone_beam_params, optional_params, defective_pixel_list = \ + load_scans_and_params(dataset_dir, + downsample_factor=downsample_factor, crop_region=crop_region, + subsample_view_factor=subsample_view_factor) + + print("MBIRJAX geometry parameters:") + pp.pprint(cone_beam_params) + pp.pprint(optional_params) + print('obj_scan shape = ', obj_scan.shape) + print('blank_scan shape = ', blank_scan.shape) + print('dark_scan shape = ', dark_scan.shape) + + print("\n\n########## Computing sinogram from object, blank, and dark scans ...") + sino, defective_pixel_list = \ + preprocess.compute_sino_transmission(obj_scan, blank_scan, dark_scan, defective_pixel_list) + del obj_scan, blank_scan, dark_scan # delete scan images to save memory + + print("\n\n########## Correcting background offset to the sinogram from edge pixels ...") + background_offset = preprocess.estimate_background_offset(sino) + print("background_offset = ", background_offset) + sino = sino - background_offset + + print("\n\n########## Correcting sinogram data to account for detector rotation ...") + sino = preprocess.correct_det_rotation(sino, det_rotation=optional_params["det_rotation"]) + del optional_params["det_rotation"] + + return sino, cone_beam_params, optional_params + + +def load_scans_and_params(dataset_dir, downsample_factor=(1, 1), crop_region=((0, 1), (0, 1)), + view_id_start=0, view_id_end=None, subsample_view_factor=1): + """ + Load the object scan, blank scan, dark scan, view angles, defective pixel information, and geometry parameters from an NSI scan directory. + + This function loads the sinogram data and parameters from an NSI scan directory for users who would prefer to implement custom preprocessing of the data. + + Args: + dataset_dir (string): Path to an NSI scan directory. The directory is assumed to have the following structure: + + - ``*.nsipro`` (NSI config file) + - ``Geometry*.rtf`` (geometry report) + - ``Radiographs*/`` (directory containing all radiograph images) + - ``**/gain0.tif`` (blank scan image) + - ``**/offset.tif`` (dark scan image) + - ``**/*.defect`` (defective pixel information) + + downsample_factor ((int, int), optional) - Down-sample factors along the detector rows and channels respectively. + If scan size is not divisible by `downsample_factor`, the scans will be first truncated to a size that is divisible by `downsample_factor`. + + crop_region (((float, float),(float, float)), optional) - Values of ((row_start, row_end), (col_start, col_end)) define a bounding box that crops the scan. + The default of ((0, 1), (0, 1)) retains the entire scan. + + view_id_start (int, optional): view index corresponding to the first view. + + view_id_end (int, optional): view index corresponding to the last view. If None, this will be equal to the total number of object scan images in ``obj_scan_dir``. + + subsample_view_factor (int, optional): view subsample factor. + + Returns: + tuple: [obj_scan, blank_scan, dark_scan, cone_beam_params, optional_params, defective_pixel_list] + + obj_scan (jax array): 3D object scan with shape (num_views, num_det_rows, num_det_channels). + + blank_scan (jax array): 3D blank scan with shape (1, num_det_rows, num_det_channels). + + dark_scan (jax array): 3D dark scan with shape (1, num_det_rows, num_det_channels). + + cone_beam_params (dict): Required parameters for the ConeBeamModel constructor. + + optional_params (dict): Additional ConeBeamModel parameters to be set using set_params(). + + defective_pixel_list (list(tuple)): A list of tuples containing indices of invalid sinogram pixels, with the format (detector_row_idx, detector_channel_idx). + """ + ### automatically parse the paths to NSI metadata and scans from dataset_dir + config_file_path, geom_report_path, obj_scan_dir, blank_scan_path, dark_scan_path, defective_pixel_path = \ + _parse_filenames_from_dataset_dir(dataset_dir) + + print("The following files will be used to compute the NSI reconstruction:\n", + f" - NSI config file: {config_file_path}\n", + f" - Geometry report: {geom_report_path}\n", + f" - Radiograph directory: {obj_scan_dir}\n", + f" - Blank scan image: {blank_scan_path}\n", + f" - Dark scan image: {dark_scan_path}\n", + f" - Defective pixel information: {defective_pixel_path}\n") + ### NSI param tags in nsipro file + tag_section_list = [['source', 'Result'], # vector from origin to source + ['reference', 'Result'], # vector from origin to first row and column of the detector + ['pitch', 'Object Radiograph'], # detector pixel pitch + ['width pixels', 'Detector'], # number of detector rows + ['height pixels', 'Detector'], # number of detector channels + ['number', 'Object Radiograph'], # number of views + ['Rotation range', 'CT Project Configuration'], # Range of rotation angle (usually 360) + ['rotate', 'Correction'], # rotation of radiographs + ['flipH', 'Correction'], # Horizontal flip (boolean) + ['flipV', 'Correction'], # Vertical flip (boolean) + ['angleStep', 'Object Radiograph'], # step size of adjacent view angles + ['clockwise', 'Processed'], # rotation direction (boolean) + ['axis', 'Result'], # unit vector in direction ofrotation axis + ['normal', 'Result'], # unit vector in direction of source-detector line + ['horizontal', 'Result'] # unit vector in direction of detector rows + ] + assert(os.path.isfile(config_file_path)), f'Error! NSI config file does not exist. Please check whether {config_file_path} is a valid file.' + NSI_params = _read_str_from_config(config_file_path, tag_section_list) + + # vector from origin to source + r_s = NSI_params[0].split(' ') + r_s = np.array([np.single(elem) for elem in r_s]) + + # vector from origin to reference, where reference is the center of first row and column of the detector + r_r = NSI_params[1].split(' ') + r_r = np.array([np.single(elem) for elem in r_r]) + + # correct the coordinate of (0,0) detector pixel based on "Geometry Report.rtf" + x_r, y_r = _read_detector_location_from_geom_report(geom_report_path) + r_r[0] = x_r + r_r[1] = y_r + print("Corrected coordinate of (0,0) detector pixel (from Geometry Report) = ", r_r) + + # detector pixel pitch + pixel_pitch_det = NSI_params[2].split(' ') + delta_det_channel = np.single(pixel_pitch_det[0]) + delta_det_row = np.single(pixel_pitch_det[1]) + + # dimension of radiograph + num_det_channels = int(NSI_params[3]) + num_det_rows = int(NSI_params[4]) + + # total number of radiograph scans + num_acquired_scans = int(NSI_params[5]) + + # total angles (usually 360 for 3D data, and (360*number_of_full_rotations) for 4D data + total_angles = int(NSI_params[6]) + + # Radiograph rotation (degree) + scan_rotate = int(NSI_params[7]) + if (scan_rotate == 180) or (scan_rotate == 0): + print('scans are in portrait mode!') + elif (scan_rotate == 270) or (scan_rotate == 90): + print('scans are in landscape mode!') + num_det_channels, num_det_rows = num_det_rows, num_det_channels + else: + warnings.warn("Picture mode unknown! Should be either portrait (0 or 180 deg rotation) or landscape (90 or 270 deg rotation). Automatically setting picture mode to portrait.") + scan_rotate = 180 + + # Radiograph horizontal & vertical flip + if NSI_params[8] == "True": + flipH = True + else: + flipH = False + if NSI_params[9] == "True": + flipV = True + else: + flipV = False + + # Detector rotation angle step (degree) + angle_step = np.single(NSI_params[10]) + + # Detector rotation direction + if NSI_params[11] == "True": + print("clockwise rotation.") + else: + print("counter-clockwise rotation.") + # counter-clockwise rotation + angle_step = -angle_step + + # Rotation axis + r_a = NSI_params[12].split(' ') + r_a = np.array([np.single(elem) for elem in r_a]) + # make sure rotation axis points down + if r_a[1] > 0: + r_a = -r_a + + # Detector normal vector + r_n = NSI_params[13].split(' ') + r_n = np.array([np.single(elem) for elem in r_n]) + + # Detector horizontal vector + r_h = NSI_params[14].split(' ') + r_h = np.array([np.single(elem) for elem in r_h]) + + print("############ NSI geometry parameters ############") + print("vector from origin to source = ", r_s, " [mm]") + print("vector from origin to (0,0) detector pixel = ", r_r, " [mm]") + print("Unit vector of rotation axis = ", r_a) + print("Unit vector of normal = ", r_n) + print("Unit vector of horizontal = ", r_h) + print(f"Detector pixel pitch: (delta_det_row, delta_det_channel) = ({delta_det_row:.3f},{delta_det_channel:.3f}) [mm]") + print(f"Detector size: (num_det_rows, num_det_channels) = ({num_det_rows},{num_det_channels})") + print("############ End NSI geometry parameters ############") + ### END load NSI parameters from an nsipro file + + + ### Convert NSI geometry parameters to MBIR parameters + source_detector_dist, source_iso_dist, magnification, det_rotation = calc_source_detector_params(r_a, r_n, r_h, r_s, r_r) + + det_channel_offset, det_row_offset = calc_row_channel_params(r_a, r_n, r_h, r_s, r_r, delta_det_channel, delta_det_row, num_det_channels, num_det_rows, magnification) + + ### END Convert NSI geometry parameters to MBIR parameters + + ### Adjust geometry NSI_params according to crop_region and downsample_factor + if isinstance(crop_region[0], (list, tuple)): + (row0, row1), (col0, col1) = crop_region + else: + row0, row1, col0, col1 = crop_region + + ### Adjust detector size and pixel pitch params w.r.t. downsampling arguments + num_det_rows = num_det_rows // downsample_factor[0] + num_det_channels = num_det_channels // downsample_factor[1] + + delta_det_row *= downsample_factor[0] + delta_det_channel *= downsample_factor[1] + + ### Adjust detector size params w.r.t. cropping arguments + num_det_rows_shift0 = np.round(num_det_rows * row0) + num_det_rows_shift1 = np.round(num_det_rows * (1 - row1)) + num_det_rows = num_det_rows - (num_det_rows_shift0 + num_det_rows_shift1) + + num_det_channels_shift0 = np.round(num_det_channels * col0) + num_det_channels_shift1 = np.round(num_det_channels * (1 - col1)) + num_det_channels = num_det_channels - (num_det_channels_shift0 + num_det_channels_shift1) + + ### read blank scans and dark scans + blank_scan = np.expand_dims(preprocess.read_scan_img(blank_scan_path), axis=0) + if dark_scan_path is not None: + dark_scan = np.expand_dims(preprocess.read_scan_img(dark_scan_path), axis=0) + else: + dark_scan = np.zeros(blank_scan.shape) + + if view_id_end is None: + view_id_end = num_acquired_scans + view_ids = list(range(view_id_start, view_id_end, subsample_view_factor)) + obj_scan = preprocess.read_scan_dir(obj_scan_dir, view_ids) + + ### Load defective pixel information + if defective_pixel_path is not None: + tag_section_list = [['Defect', 'Defective Pixels']] + defective_loc = _read_str_from_config(defective_pixel_path, tag_section_list) + defective_pixel_list = np.array([defective_pixel_ind.split()[1::-1] for defective_pixel_ind in defective_loc ]).astype(int) + defective_pixel_list = list(map(tuple, defective_pixel_list)) + else: + defective_pixel_list = None + + + ### flip the scans according to flipH and flipV information from nsipro file + if flipV: + print("Flip scans vertically!") + obj_scan = np.flip(obj_scan, axis=1) + blank_scan = np.flip(blank_scan, axis=1) + dark_scan = np.flip(dark_scan, axis=1) + # adjust the defective pixel information: vertical flip + if defective_pixel_list is not None: + for i in range(len(defective_pixel_list)): + (r,c) = defective_pixel_list[i] + defective_pixel_list[i] = (blank_scan.shape[1]-r-1, c) + if flipH: + print("Flip scans horizontally!") + obj_scan = np.flip(obj_scan, axis=2) + blank_scan = np.flip(blank_scan, axis=2) + dark_scan = np.flip(dark_scan, axis=2) + # adjust the defective pixel information: horizontal flip + if defective_pixel_list is not None: + for i in range(len(defective_pixel_list)): + (r,c) = defective_pixel_list[i] + defective_pixel_list[i] = (r, blank_scan.shape[2]-c-1) + + ### rotate the scans according to scan_rotate param + rot_count = scan_rotate // 90 + for n in range(rot_count): + obj_scan = np.rot90(obj_scan, 1, axes=(2,1)) + blank_scan = np.rot90(blank_scan, 1, axes=(2,1)) + dark_scan = np.rot90(dark_scan, 1, axes=(2,1)) + # adjust the defective pixel information: rotation (clockwise) + if defective_pixel_list is not None: + for i in range(len(defective_pixel_list)): + (r,c) = defective_pixel_list[i] + defective_pixel_list[i] = (c, blank_scan.shape[2]-r-1) + + ### crop the scans based on input params + obj_scan, blank_scan, dark_scan, defective_pixel_list = preprocess.crop_scans(obj_scan, blank_scan, dark_scan, + crop_region=crop_region, + defective_pixel_list=defective_pixel_list) + + ### downsample the scans with block-averaging + if downsample_factor[0]*downsample_factor[1] > 1: + obj_scan, blank_scan, dark_scan, defective_pixel_list = preprocess.downsample_scans(obj_scan, blank_scan, dark_scan, + downsample_factor=downsample_factor, + defective_pixel_list=defective_pixel_list) + + ### compute projection angles based on angle_step and view_ids + angles = np.deg2rad(np.array([(view_idx*angle_step) % 360.0 for view_idx in view_ids])) + + ### Set 1 ALU = delta_det_channel + source_detector_dist /= delta_det_channel # mm to ALU + source_iso_dist /= delta_det_channel # mm to ALU + det_channel_offset /= delta_det_channel # mm to ALU + det_row_offset /= delta_det_row # mm to ALU + delta_det_channel = 1.0 + delta_det_row = 1.0 + + # Create a dictionary to store MBIR parameters + num_views = len(angles) + cone_beam_params = dict() + cone_beam_params["sinogram_shape"] = (num_views, num_det_rows, num_det_channels) + cone_beam_params["angles"] = angles + cone_beam_params["source_detector_dist"] = source_detector_dist + cone_beam_params["source_iso_dist"] = source_iso_dist + + optional_params = dict() + optional_params["delta_det_channel"] = delta_det_channel + optional_params["delta_det_row"] = delta_det_row + optional_params['delta_voxel'] = delta_det_channel * (source_iso_dist/source_detector_dist) + optional_params["det_channel_offset"] = det_channel_offset + optional_params["det_row_offset"] = det_row_offset + optional_params["det_rotation"] = det_rotation # tilt angle of rotation axis + + return obj_scan, blank_scan, dark_scan, cone_beam_params, optional_params, defective_pixel_list + + +######## subroutines for parsing NSI metadata +def _parse_filenames_from_dataset_dir(dataset_dir): + """ Given the path to an NSI dataset directory, automatically parse the paths to the following files and directories: + - NSI config file (nsipro file), + - geometry report (Geometry Report.rtf), + - object scan directory (Radiographs/), + - blank scan (Corrections/gain0.tif), + - dark scan (Corrections/offset.tif), + - defective pixel information (Corrections/defective_pixels.defect), + If multiple files with the same patterns are found, then the user will be prompted to select the correct file. + + Args: + dataset_dir (string): Path to the directory containing the NSI scans and metadata. + Returns: + 6-element tuple containing: + - config_file_path (string): Path to the NSI config file (nsipro file). + - geom_report_path (string): Path to the geometry report file (Geometry Report.rtf) + - obj_scan_dir (string): Path to the directory containing the object scan images (radiographs). + - blank_scan_path (string): Path to the blank scan image. + - dark_scan_path (string): Path to the dark scan image. + - defective_pixel_path (string): Path to the file containing defective pixel information. + """ + # NSI config file + config_file_path_list = glob.glob(os.path.join(dataset_dir, "*.nsipro")) + config_file_path = _prompt_user_choice("NSI config files", config_file_path_list) + + # geometry report + geom_report_path_list = glob.glob(os.path.join(dataset_dir, "Geometry*.rtf")) + geom_report_path = _prompt_user_choice("geometry report files", geom_report_path_list) + + # Radiograph directory + obj_scan_dir_list = glob.glob(os.path.join(dataset_dir, "Radiographs*")) + obj_scan_dir = _prompt_user_choice("radiograph directories", obj_scan_dir_list) + + # blank scan + blank_scan_path_list = glob.glob(os.path.join(dataset_dir, "**/gain0.tif")) + blank_scan_path = _prompt_user_choice("blank scans", blank_scan_path_list) + + # dark scan + dark_scan_path_list = glob.glob(os.path.join(dataset_dir, "**/offset.tif")) + dark_scan_path = _prompt_user_choice("dark scans", dark_scan_path_list) + + # defective pixel file + defective_pixel_path_list = glob.glob(os.path.join(dataset_dir, "**/*.defect")) + defective_pixel_path = _prompt_user_choice("defective pixel files", defective_pixel_path_list) + + return config_file_path, geom_report_path, obj_scan_dir, blank_scan_path, dark_scan_path, defective_pixel_path + +def _prompt_user_choice(file_description, file_path_list): + """ Given a list of candidate files, prompt the user to select the desired one. + If only one candidate exists, the function will return the name of that file without any user prompts. + """ + # file_path_list should contain at least one element + assert(len(file_path_list) > 0), f"No {file_description} found!! Please make sure you provided a valid NSI scan path." + + # if only file_path_list contains only one file, then return it without user prompt. + if len(file_path_list) == 1: + return file_path_list[0] + + # file_path_list contains multiple files. Prompt the user to select the desired one. + choice_min = 0 + choice_max = len(file_path_list)-1 + question = f"Multiple {file_description} detected. Please select the desired one from the following candidates " + prompt = ":\n" + for i in range(len(file_path_list)): + prompt += f"\n {i}: {file_path_list[i]}" + prompt += f"\n[{choice_min}-{choice_max}]" + while True: + sys.stdout.write(question + prompt) + try: + choice = int(input()) + if choice in range(len(file_path_list)): + return file_path_list[choice] + else: + sys.stdout.write(f"Please respond with a number between {choice_min} and {choice_max}.\n") + except: + sys.stdout.write(f"Please respond with a number between {choice_min} and {choice_max}.\n") + return + +def _read_detector_location_from_geom_report(geom_report_path): + """ Give the path to "Geometry Report.rtf", returns the X and Y coordinates of the first row and first column of the detector. + It is observed that the coordinates given in "Geometry Report.rtf" is more accurate than the coordinates given in the field in nsipro file. + Specifically, this function parses the information of "Image center" from "Geometry Report.rtf". + Example: + - content in "Geometry Report.rtf": Image center (95.707, 123.072) [mm] / (3.768, 4.845) [in] + - Returns: (95.707, 123.072) + Args: + geom_report_path (string): Path to "Geometry Report.rtf" file. This file contains more accurate information regarding the coordinates of the first detector row and column. + Returns: + (x_r, y_r): A tuple containing the X and Y coordinates of center of the first detector row and column. + """ + rtf_file = open(geom_report_path, 'r') + rtf_raw = rtf_file.read() + rtf_file.close() + # convert rft file content to plain text. + rtf_converted = striprtf.rtf_to_text(rtf_raw).split("\n") + for line in rtf_converted: + if "Image center" in line: + # read the two floating numbers immediately following the keyword "Image center". + # This is the X and Y coordinates of (0,0) detector pixel in units of mm. + data = re.findall(r"(\d+\.*\d*, \d+\.*\d*)", line) + break + data = data[0].split(",") + x_r = float(data[0]) + y_r = float(data[1]) + return x_r, y_r + +def _read_str_from_config(filepath, tags_sections): + """Returns strings about dataset information read from NSI configuration file. + + Args: + filepath (string): Path to NSI configuration file. The filename extension is '.nsipro'. + tags_sections (list[string,string]): Given tags and sections to locate the information we want to read. + Returns: + list[string], a list of strings have needed dataset information for reconstruction. + + """ + tag_strs = ['<' + tag + '>' for tag, section in tags_sections] + section_starts = ['<' + section + '>' for tag, section in tags_sections] + section_ends = ['' for tag, section in tags_sections] + NSI_params = [] + + try: + with open(filepath, 'r') as f: + lines = f.readlines() + except IOError: + print("Could not read file:", filepath) + + for tag_str, section_start, section_end in zip(tag_strs, section_starts, section_ends): + section_start_inds = [ind for ind, match in enumerate(lines) if section_start in match] + section_end_inds = [ind for ind, match in enumerate(lines) if section_end in match] + section_start_ind = section_start_inds[0] + section_end_ind = section_end_inds[0] + + for line_ind in range(section_start_ind + 1, section_end_ind): + line = lines[line_ind] + if tag_str in line: + tag_ind = line.find(tag_str, 1) + len(tag_str) + if tag_ind == -1: + NSI_params.append("") + else: + NSI_params.append(line[tag_ind:].strip('\n')) + + return NSI_params +######## END subroutines for parsing NSI metadata + +######## subroutines for NSI-MBIR parameter conversion + +def calc_det_rotation(r_a, r_n, r_h, r_v): + """ Calculate the tilt angle between the rotation axis and the detector columns in unit of radians. User should call `preprocess.correct_det_rotation()` to rotate the sinogram images w.r.t. to the tilt angle. + + Args: + r_a: 3D real-valued unit vector in direction of rotation axis pointing down. + r_n: 3D real-valued unit vector perpendicular to the detector plan pointing from source to detector. + r_h: 3D real-valued unit vector in direction parallel to detector rows pointing from left to right. + r_v: 3D real-valued unit vector in direction parallel to detector columns pointing down. + Returns: + float number specifying the angle between the rotation axis and the detector columns in units of radians. + """ + # project the rotation axis onto the detector plane + r_a_p = preprocess.unit_vector(r_a - preprocess.project_vector_to_vector(r_a, r_n)) + # calculate angle between the projected rotation axis and the horizontal detector vector + det_rotation = -np.arctan(np.dot(r_a_p, r_h)/np.dot(r_a_p, r_v)) + return det_rotation + +def calc_source_detector_params(r_a, r_n, r_h, r_s, r_r): + """ Calculate the MBIRJAX geometry parameters: source_detector_dist, magnification, and rotation axis tilt angle. + Args: + r_a (tuple): 3D real-valued unit vector in direction of rotation axis pointing down. + r_n (tuple): 3D real-valued unit vector perpendicular to the detector plan pointing from source to detector. + r_h (tuple): 3D real-valued unit vector in direction parallel to detector rows pointing from left to right. + r_s (tuple): 3D real-valued vector from origin to the source location. + r_r (tuple): 3D real-valued vector from origin to the center of pixel on first row and colum of detector. + Returns: + 4-element tuple containing: + - **source_detector_dist** (float): Distance between the X-ray source and the detector. + - **source_iso_dist** (float): Distance between the X-ray source and the center of rotation. + - **det_rotation (float)**: Angle between the rotation axis and the detector columns in units of radians. + - **magnification** (float): Magnification of the cone-beam geometry defined as + (source to detector distance)/(source to center-of-rotation distance). + """ + r_n = preprocess.unit_vector(r_n) # make sure r_n is normalized + r_v = np.cross(r_n, r_h) # r_v = r_n x r_h + + #### vector pointing from source to center of rotation along the source-detector line. + r_s_r = preprocess.project_vector_to_vector(-r_s, r_n) # project -r_s to r_n + + #### vector pointing from source to detector along the source-detector line. + r_s_d = preprocess.project_vector_to_vector(r_r-r_s, r_n) + + source_detector_dist = np.linalg.norm(r_s_d) # ||r_s_d|| + source_iso_dist = np.linalg.norm(r_s_r) # ||r_s_r|| + magnification = source_detector_dist/source_iso_dist + det_rotation = calc_det_rotation(r_a, r_n, r_h, r_v) # rotation axis tilt angle + return source_detector_dist, source_iso_dist, magnification, det_rotation + +def calc_row_channel_params(r_a, r_n, r_h, r_s, r_r, delta_det_channel, delta_det_row, num_det_channels, num_det_rows, magnification): + """ Calculate the MBIRJAX geometry parameters: det_channel_offset, det_row_offset. + Args: + r_a (tuple): 3D real-valued unit vector in direction of rotation axis pointing down. + r_n (tuple): 3D real-valued unit vector perpendicular to the detector plan pointing from source to detector. + r_h (tuple): 3D real-valued unit vector in direction parallel to detector rows pointing from left to right. + r_s (tuple): 3D real-valued vector from origin to the source location. + r_r (tuple): 3D real-valued vector from origin to the center of pixel on first row and colum of detector. + delta_det_channel (float): spacing between detector columns. + delta_det_row (float): spacing between detector rows. + num_det_channels (int): Number of detector channels. + num_det_rows (int): Number of detector rows. + magnification (float): Magnification of the cone-beam geometry. + Returns: + 2-element tuple containing: + - **det_channel_offset** (float): Distance from center of detector to the source-detector line along a row. + - **det_row_offset** (float): Distance from center of detector to the source-detector line along a column. + """ + r_n = preprocess.unit_vector(r_n) # make sure r_n is normalized + r_h = preprocess.unit_vector(r_h) # make sure r_h is normalized + r_v = np.cross(r_n, r_h) # r_v = r_n x r_h + + # vector pointing from center of detector to the first row and column of detector along detector columns. + c_v = -(num_det_rows-1)/2*delta_det_row*r_v + # vector pointing from center of detector to the first row and column of detector along detector rows. + c_h = -(num_det_channels-1)/2*delta_det_channel*r_h + # vector pointing from source to first row and column of detector. + r_s_r = r_r - r_s + # vector pointing from source-detector line to center of detector. + r_delta = r_s_r - preprocess.project_vector_to_vector(r_s_r, r_n) - c_v - c_h + # detector row and channel offsets + det_channel_offset = -np.dot(r_delta, r_h) + det_row_offset = -np.dot(r_delta, r_v) + # rotation offset + delta_source = r_s - preprocess.project_vector_to_vector(r_s, r_n) + delta_rot = delta_source - preprocess.project_vector_to_vector(delta_source, r_a)# rotation offset vector (perpendicular to rotation axis) + rotation_offset = np.dot(delta_rot, np.cross(r_n, r_a)) + det_channel_offset += rotation_offset*magnification + return det_channel_offset, det_row_offset + +######## END subroutines for NSI-MBIR parameter conversion diff --git a/mbirjax/preprocess/__init__.py b/mbirjax/preprocess/__init__.py new file mode 100644 index 0000000..01ec46c --- /dev/null +++ b/mbirjax/preprocess/__init__.py @@ -0,0 +1 @@ +from .utilities import * \ No newline at end of file diff --git a/mbirjax/preprocess/utilities.py b/mbirjax/preprocess/utilities.py new file mode 100644 index 0000000..9a3b92e --- /dev/null +++ b/mbirjax/preprocess/utilities.py @@ -0,0 +1,375 @@ +import numpy as np +import warnings +import math +import scipy +from PIL import Image +import glob +import os + +def compute_sino_transmission(obj_scan, blank_scan, dark_scan, defective_pixel_list=None, correct_defective_pixels=True): + """ + Compute sinogram from object, blank, and dark scans. + + This function computes sinogram by taking the negative log of the attenuation estimate. + It can also take in a list of defective pixels and correct those pixel values. + The invalid sinogram entries are the union of defective pixel entries and sinogram entries with values of inf or Nan. + + Args: + obj_scan (ndarray, float): 3D object scan with shape (num_views, num_det_rows, num_det_channels). + blank_scan (ndarray, float): [Default=None] 3D blank scan with shape (num_blank_scans, num_det_rows, num_det_channels). When num_blank_scans>1, the pixel-wise mean will be used as the blank scan. + dark_scan (ndarray, float): [Default=None] 3D dark scan with shape (num_dark_scans, num_det_rows, num_det_channels). When num_dark_scans>1, the pixel-wise mean will be used as the dark scan. + defective_pixel_list (optional, list(tuple)): A list of tuples containing indices of invalid sinogram pixels, with the format (view_idx, row_idx, channel_idx) or (detector_row_idx, detector_channel_idx). + If None, then the invalid pixels will be identified as sino entries with inf or Nan values. + correct_defective_pixels (optioonal, boolean): [Default=True] If true, the defective sinogram entries will be automatically corrected with `mbirjax.preprocess.interpolate_defective_pixels()`. + + Returns: + 2-element tuple containing: + - **sino** (*ndarray, float*): Sinogram data with shape (num_views, num_det_rows, num_det_channels). + - **defective_pixel_list** (list(tuple)): A list of tuples containing indices of invalid sinogram pixels, with the format (view_idx, row_idx, channel_idx) or (detector_row_idx, detector_channel_idx). + + """ + # take average of multiple blank/dark scans, and expand the dimension to be the same as obj_scan. + blank_scan = 0 * obj_scan + np.mean(blank_scan, axis=0, keepdims=True) + dark_scan = 0 * obj_scan + np.mean(dark_scan, axis=0, keepdims=True) + + obj_scan = obj_scan - dark_scan + blank_scan = blank_scan - dark_scan + + #### compute the sinogram. + # suppress warnings in np.log(), since the defective sino entries will be corrected. + with np.errstate(divide='ignore', invalid='ignore'): + sino = -np.log(obj_scan / blank_scan) + + # set the sino pixels corresponding to the provided defective list to 0.0 + if defective_pixel_list is None: + defective_pixel_list = [] + else: # if provided list is not None + for defective_pixel_idx in defective_pixel_list: + if len(defective_pixel_idx) == 2: + (r,c) = defective_pixel_idx + sino[:,r,c] = 0.0 + elif len(defective_pixel_idx) == 3: + (v,r,c) = defective_pixel_idx + sino[v,r,c] = 0.0 + else: + raise Exception("compute_sino_transmission: index information in defective_pixel_list cannot be parsed.") + + # set NaN sino pixels to 0.0 + nan_pixel_list = list(map(tuple, np.argwhere(np.isnan(sino)) )) + for (v,r,c) in nan_pixel_list: + sino[v,r,c] = 0.0 + + # set Inf sino pixels to 0.0 + inf_pixel_list = list(map(tuple, np.argwhere(np.isinf(sino)) )) + for (v,r,c) in inf_pixel_list: + sino[v,r,c] = 0.0 + + # defective_pixel_list = union{input_defective_pixel_list, nan_pixel_list, inf_pixel_list} + defective_pixel_list = list(set().union(defective_pixel_list,nan_pixel_list,inf_pixel_list)) + + if correct_defective_pixels: + print("Interpolate invalid sinogram entries.") + sino, defective_pixel_list = interpolate_defective_pixels(sino, defective_pixel_list) + else: + if defective_pixel_list: + print("Invalid sino entries detected! Please correct then manually or with function `mbirjax.preprocess.interpolate_defective_pixels()`.") + return sino, defective_pixel_list + +def interpolate_defective_pixels(sino, defective_pixel_list): + """ + Interpolates defective sinogram entries with the mean of neighboring pixels. + + Args: + sino (ndarray, float): Sinogram data with 3D shape (num_views, num_det_rows, num_det_channels). + defective_pixel_list (list(tuple)): A list of tuples containing indices of invalid sinogram pixels, with the format (detector_row_idx, detector_channel_idx) or (view_idx, detector_row_idx, detector_channel_idx). + Returns: + 2-element tuple containing: + - **sino** (*ndarray, float*): Corrected sinogram data with shape (num_views, num_det_rows, num_det_channels). + - **defective_pixel_list** (*list(tuple)*): Updated defective_pixel_list with the format (detector_row_idx, detector_channel_idx) or (view_idx, detector_row_idx, detector_channel_idx). + """ + defective_pixel_list_new = [] + num_views, num_det_rows, num_det_channels = sino.shape + weights = np.ones((num_views, num_det_rows, num_det_channels)) + + for defective_pixel_idx in defective_pixel_list: + if len(defective_pixel_idx) == 2: + (r,c) = defective_pixel_idx + weights[:,r,c] = 0.0 + elif len(defective_pixel_idx) == 3: + (v,r,c) = defective_pixel_idx + weights[v,r,c] = 0.0 + else: + raise Exception("replace_defective_with_mean: index information in defective_pixel_list cannot be parsed.") + + for defective_pixel_idx in defective_pixel_list: + if len(defective_pixel_idx) == 2: + v_list = list(range(num_views)) + (r,c) = defective_pixel_idx + elif len(defective_pixel_idx) == 3: + (v,r,c) = defective_pixel_idx + v_list = [v,] + + r_min, r_max = max(r-1, 0), min(r+2, num_det_rows) + c_min, c_max = max(c-1, 0), min(c+2, num_det_channels) + for v in v_list: + # Perform interpolation when there are non-defective pixels in the neighborhood + if np.sum(weights[v,r_min:r_max,c_min:c_max]) > 0: + sino[v,r,c] = np.average(sino[v,r_min:r_max,c_min:c_max], + weights=weights[v,r_min:r_max,c_min:c_max]) + # Corner case: all the neighboring pixels are defective + else: + print(f"Unable to correct sino entry ({v},{r},{c})! All neighborhood values are defective!") + defective_pixel_list_new.append((v,r,c)) + return sino, defective_pixel_list_new + +def correct_det_rotation(sino, weights=None, det_rotation=0.0): + """ + Correct sinogram data and weights to account for detector rotation. + + This function can be used to rotate sinogram views when the axis of rotation is not exactly aligned with the detector columns. + + Args: + sino (float, ndarray): Sinogram data with 3D shape (num_views, num_det_rows, num_det_channels). + weights (float, ndarray): Sinogram weights, with the same array shape as ``sino``. + det_rotation (optional, float): tilt angle between the rotation axis and the detector columns in unit of radians. + + Returns: + - A numpy array containing the corrected sinogram data if weights is None. + - A tuple (sino, weights) if weights is not None + """ + sino = scipy.ndimage.rotate(sino, np.rad2deg(det_rotation), axes=(1,2), reshape=False, order=3) + # weights not provided + if weights is None: + return sino + # weights provided + print("correct_det_rotation: weights provided by the user. Please note that zero weight entries might become non-zero after tilt angle correction.") + weights = scipy.ndimage.rotate(weights, np.rad2deg(det_rotation), axes=(1,2), reshape=False, order=3) + return sino, weights + +def estimate_background_offset(sino, option=0, edge_width=9): + """ + Estimate background offset of a sinogram from the edge pixels. + + This function estimates the background offset when no object is present by computing a robust centroid estimate using `edge_width` pixels along the edge of the sinogram across views. + Typically, this estimate is subtracted from the sinogram so that air is reconstructed as approximately 0. + + Args: + sino (float, ndarray): Sinogram data with 3D shape (num_views, num_det_rows, num_det_channels). + option (int, optional): [Default=0] Option of algorithm used to calculate the background offset. + edge_width(int, optional): [Default=9] Width of the edge regions in pixels. It must be an odd integer >= 3. + Returns: + offset (float): Background offset value. + """ + + # Check validity of edge_width value + assert(isinstance(edge_width, int)), "edge_width must be an integer!" + if (edge_width % 2 == 0): + edge_width = edge_width+1 + warnings.warn(f"edge_width of background regions should be an odd number! Setting edge_width to {edge_width}.") + + if (edge_width < 3): + warnings.warn("edge_width of background regions should be >= 3! Setting edge_width to 3.") + edge_width = 3 + + _, _, num_det_channels = sino.shape + + # calculate mean sinogram + sino_median=np.median(sino, axis=0) + + # offset value of the top edge region. + # Calculated as median([median value of each horizontal line in top edge region]) + median_top = np.median(np.median(sino_median[:edge_width], axis=1)) + + # offset value of the left edge region. + # Calculated as median([median value of each vertical line in left edge region]) + median_left = np.median(np.median(sino_median[:, :edge_width], axis=0)) + + # offset value of the right edge region. + # Calculated as median([median value of each vertical line in right edge region]) + median_right = np.median(np.median(sino_median[:, num_det_channels-edge_width:], axis=0)) + + # offset = median of three offset values from top, left, right edge regions. + offset = np.median([median_top, median_left, median_right]) + return offset + +######## subroutines for image cropping and down-sampling +def downsample_scans(obj_scan, blank_scan, dark_scan, + downsample_factor, + defective_pixel_list=None): + """Performs Down-sampling to the scan images in the detector plane. + + Args: + obj_scan (float): A stack of sinograms. 3D numpy array, (num_views, num_det_rows, num_det_channels). + blank_scan (float): A blank scan. 2D numpy array, (num_det_rows, num_det_channels). + dark_scan (float): A dark scan. 3D numpy array, (num_det_rows, num_det_channels). + downsample_factor ([int, int]): Default=[1,1]] Two numbers to define down-sample factor. + Returns: + Downsampled scans + - **obj_scan** (*ndarray, float*): A stack of sinograms. 3D numpy array, (num_views, num_det_rows, num_det_channels). + - **blank_scan** (*ndarray, float*): A blank scan. 3D numpy array, (num_det_rows, num_det_channels). + - **dark_scan** (*ndarray, float*): A dark scan. 3D numpy array, (num_det_rows, num_det_channels). + """ + + assert len(downsample_factor) == 2, 'factor({}) needs to be of len 2'.format(downsample_factor) + assert (downsample_factor[0]>=1 and downsample_factor[1]>=1), 'factor({}) along each dimension should be greater or equal to 1'.format(downsample_factor) + + good_pixel_mask = np.ones((blank_scan.shape[1], blank_scan.shape[2]), dtype=int) + if defective_pixel_list is not None: + for (r,c) in defective_pixel_list: + good_pixel_mask[r,c] = 0 + + # crop the scan if the size is not divisible by downsample_factor. + new_size1 = downsample_factor[0] * (obj_scan.shape[1] // downsample_factor[0]) + new_size2 = downsample_factor[1] * (obj_scan.shape[2] // downsample_factor[1]) + + obj_scan = obj_scan[:, 0:new_size1, 0:new_size2] + blank_scan = blank_scan[:, 0:new_size1, 0:new_size2] + dark_scan = dark_scan[:, 0:new_size1, 0:new_size2] + good_pixel_mask = good_pixel_mask[0:new_size1, 0:new_size2] + + ### Compute block sum of the high res scan images. Defective pixels are excluded. + # filter out defective pixels + good_pixel_mask = good_pixel_mask.reshape(good_pixel_mask.shape[0] // downsample_factor[0], downsample_factor[0], + good_pixel_mask.shape[1] // downsample_factor[1], downsample_factor[1]) + obj_scan = obj_scan.reshape(obj_scan.shape[0], + obj_scan.shape[1] // downsample_factor[0], downsample_factor[0], + obj_scan.shape[2] // downsample_factor[1], downsample_factor[1]) * good_pixel_mask + + blank_scan = blank_scan.reshape(blank_scan.shape[0], + blank_scan.shape[1] // downsample_factor[0], downsample_factor[0], + blank_scan.shape[2] // downsample_factor[1], downsample_factor[1]) * good_pixel_mask + dark_scan = dark_scan.reshape(dark_scan.shape[0], + dark_scan.shape[1] // downsample_factor[0], downsample_factor[0], + dark_scan.shape[2] // downsample_factor[1], downsample_factor[1]) * good_pixel_mask + + # compute block sum + obj_scan = obj_scan.sum((2,4)) + blank_scan = blank_scan.sum((2, 4)) + dark_scan = dark_scan.sum((2, 4)) + # number of good pixels in each down-sampling block + good_pixel_count = good_pixel_mask.sum((1,3)) + + # new defective pixel list = {indices of pixels where the downsampling block contains all bad pixels} + defective_pixel_list = np.argwhere(good_pixel_count < 1) + + # compute block averaging by dividing block sum with number of good pixels in the block + obj_scan = obj_scan / good_pixel_count + blank_scan = blank_scan / good_pixel_count + dark_scan = dark_scan / good_pixel_count + + return obj_scan, blank_scan, dark_scan, defective_pixel_list + + +def crop_scans(obj_scan, blank_scan, dark_scan, + crop_region=[(0, 1), (0, 1)], + defective_pixel_list=None): + """Crop obj_scan, blank_scan, and dark_scan images by decimal factors, and update defective_pixel_list accordingly. + Args: + obj_scan (float): A stack of sinograms. 3D numpy array, (num_views, num_det_rows, num_det_channels). + blank_scan (float) : A blank scan. 3D numpy array, (1, num_det_rows, num_det_channels). + dark_scan (float): A dark scan. 3D numpy array, (1, num_det_rows, num_det_channels). + crop_region ([(float, float),(float, float)] or [float, float, float, float]): + [Default=[(0, 1), (0, 1)]] Two points to define the bounding box. Sequence of [(row0, row1), (col0, col1)] or + [row0, row1, col0, col1], where 0<=row0 <= row1<=1 and 0<=col0 <= col1<=1. + + The scan images will be cropped using the following algorithm: + obj_scan <- obj_scan[:,Nr_lo:Nr_hi, Nc_lo:Nc_hi], where + - Nr_lo = round(row0 * obj_scan.shape[1]) + - Nr_hi = round(row1 * obj_scan.shape[1]) + - Nc_lo = round(col0 * obj_scan.shape[2]) + - Nc_hi = round(col1 * obj_scan.shape[2]) + + Returns: + Cropped scans + - **obj_scan** (*ndarray, float*): A stack of sinograms. 3D numpy array, (num_views, num_det_rows, num_det_channels). + - **blank_scan** (*ndarray, float*): A blank scan. 3D numpy array, (1, num_det_rows, num_det_channels). + - **dark_scan** (*ndarray, float*): A dark scan. 3D numpy array, (1, num_det_rows, num_det_channels). + """ + if isinstance(crop_region[0], (list, tuple)): + (row0, row1), (col0, col1) = crop_region + else: + row0, row1, col0, col1 = crop_region + + assert 0 <= row0 <= row1 <= 1 and 0 <= col0 <= col1 <= 1, 'crop_region should be sequence of [(row0, row1), (col0, col1)] ' \ + 'or [row0, row1, col0, col1], where 1>=row1 >= row0>=0 and 1>=col1 >= col0>=0.' + assert math.isclose(col0, 1 - col1), 'horizontal crop limits must be symmetric' + + Nr_lo = round(row0 * obj_scan.shape[1]) + Nc_lo = round(col0 * obj_scan.shape[2]) + + Nr_hi = round(row1 * obj_scan.shape[1]) + Nc_hi = round(col1 * obj_scan.shape[2]) + + obj_scan = obj_scan[:, Nr_lo:Nr_hi, Nc_lo:Nc_hi] + blank_scan = blank_scan[:, Nr_lo:Nr_hi, Nc_lo:Nc_hi] + dark_scan = dark_scan[:, Nr_lo:Nr_hi, Nc_lo:Nc_hi] + + # adjust the defective pixel information: any down-sampling block containing a defective pixel is also defective + i = 0 + while i < len(defective_pixel_list): + (r,c) = defective_pixel_list[i] + (r_new, c_new) = (r-Nr_lo, c-Nc_lo) + # delete the index tuple if it falls outside the cropped region + if (r_new<0 or r_new>=obj_scan.shape[1] or c_new<0 or c_new>=obj_scan.shape[2]): + del defective_pixel_list[i] + else: + i+=1 + return obj_scan, blank_scan, dark_scan, defective_pixel_list +######## END subroutines for image cropping and down-sampling + + +######## subroutines for loading scan images +def read_scan_img(img_path): + """Reads a single scan image from an image path. This function is a subroutine to the function `read_scan_dir`. + + Args: + img_path (string): Path object or file object pointing to an image. + The image type must be compatible with `PIL.Image.open()`. See `https://pillow.readthedocs.io/en/stable/reference/Image.html` for more details. + Returns: + ndarray (float): 2D numpy array. A single scan image. + """ + + img = np.asarray(Image.open(img_path)) + + if np.issubdtype(img.dtype, np.integer): + # make float and normalize integer types + maxval = np.iinfo(img.dtype).max + img = img.astype(np.float32) / maxval + + return img.astype(np.float32) + + +def read_scan_dir(scan_dir, view_ids=[]): + """Reads a stack of scan images from a directory. This function is a subroutine to `load_scans_and_params`. + + Args: + scan_dir (string): Path to a ConeBeam Scan directory. + Example: "/Radiographs" + view_ids (list[int]): List of view indices to specify which scans to read. + Returns: + ndarray (float): 3D numpy array, (num_views, num_det_rows, num_det_channels). A stack of scan images. + """ + + if view_ids == []: + warnings.warn("view_ids should not be empty.") + + img_path_list = sorted(glob.glob(os.path.join(scan_dir, '*'))) + img_path_list = [img_path_list[idx] for idx in view_ids] + img_list = [read_scan_img(img_path) for img_path in img_path_list] + + # return shape = num_views x num_det_rows x num_det_channels + return np.stack(img_list, axis=0) +######## END subroutines for loading scan images + + +def unit_vector(v): + """ Normalize v. Returns v/||v|| """ + return v / np.linalg.norm(v) + + +def project_vector_to_vector(u1, u2): + """ Projects the vector u1 onto the vector u2. Returns the vector . + """ + u2 = unit_vector(u2) + u1_proj = np.dot(u1, u2)*u2 + return u1_proj diff --git a/mbirjax/tomography_model.py b/mbirjax/tomography_model.py index 94b2263..b18e7f0 100644 --- a/mbirjax/tomography_model.py +++ b/mbirjax/tomography_model.py @@ -57,8 +57,7 @@ def from_file(cls, filename): ConeBeamModel with the specified parameters. """ # Load the parameters and convert to use the ConeBeamModel keywords. - params = ParameterHandler.load_param_dict(filename, values_only=True) - return cls(**params) + raise ValueError('from_file is not implemented for base TomographyModel') def to_file(self, filename): """ @@ -458,7 +457,7 @@ def recon(self, sinogram, weights=None, num_iterations=15, first_iteration=0, in # Return num_iterations, granularity, partition_sequence, fm_rmse values, regularization_params recon_param_names = ['num_iterations', 'granularity', 'partition_sequence', 'fm_rmse', 'prior_loss', - 'regularization_params', 'nrms_recon_change'] + 'regularization_params', 'nrms_recon_change', 'alpha_values'] ReconParams = namedtuple('ReconParams', recon_param_names) partition_sequence = [int(val) for val in partition_sequence] fm_rmse = [float(val) for val in loss_vectors[0]] @@ -467,8 +466,9 @@ def recon(self, sinogram, weights=None, num_iterations=15, first_iteration=0, in else: prior_loss = [0] nrms_recon_change = [float(val) for val in loss_vectors[2]] + alpha_values = [float(val) for val in loss_vectors[3]] recon_param_values = [num_iterations, granularity, partition_sequence, fm_rmse, prior_loss, - regularization_params._asdict(), nrms_recon_change] + regularization_params._asdict(), nrms_recon_change, alpha_values] recon_params = ReconParams(*tuple(recon_param_values)) return recon, recon_params @@ -553,19 +553,21 @@ def vcd_recon(self, sinogram, partitions, partition_sequence, weights=None, init fm_rmse = np.zeros(num_iters) pm_loss = np.zeros(num_iters) nrms_update = np.zeros(num_iters) + alpha_values = np.zeros(num_iters) for i in range(num_iters): # Get the current partition (set of subsets) and shuffle the subsets partition = partitions[partition_sequence[i]] subset_indices = np.random.permutation(partition.shape[0]) # Do an iteration - vcd_data = [error_sinogram, flat_recon, partition, nrms_update[i]] - error_sinogram, flat_recon, norm_square_update = vcd_partition_iterator(vcd_data, subset_indices) + vcd_data = [error_sinogram, flat_recon, partition, nrms_update[i], alpha_values[i]] + error_sinogram, flat_recon, norm_square_update, alpha = vcd_partition_iterator(vcd_data, subset_indices) # Compute the stats and display as desired fm_rmse[i] = self.get_forward_model_loss(error_sinogram, sigma_y, weights) nrms_update[i] = norm_square_update / jnp.sum(flat_recon * flat_recon) es_rmse = jnp.linalg.norm(error_sinogram) / jnp.sqrt(error_sinogram.size) + alpha_values[i] = alpha if verbose >= 1: iter_output = 'After iteration {}: Pct change={:.3f}, Forward loss={:.3f}'.format(i + first_iteration, 100*nrms_update[i], fm_rmse[i]) @@ -583,12 +585,12 @@ def vcd_recon(self, sinogram, partitions, partition_sequence, weights=None, init iter_output += ', Prior loss={:.3f}, Weighted total loss={:.3f}'.format(pm_loss[i], total_loss) print(iter_output) - print(f'Error sino RMSE={es_rmse:.3f}') + print(f'Relative step size (alpha)={alpha:.2f}, Error sino RMSE={es_rmse:.3f}') if verbose >= 2: mbirjax.get_memory_stats() print('--------') - return self.reshape_recon(flat_recon), (fm_rmse, pm_loss, nrms_update) + return self.reshape_recon(flat_recon), (fm_rmse, pm_loss, nrms_update, alpha_values) @staticmethod def create_vcd_partition_iterator(vcd_subset_iterator): @@ -617,19 +619,21 @@ def vcd_partition_iterator(sinogram_recon_partition, subset_indices): * flat_recon (jax array): 2D array reconstruction with shape (num_recon_rows x num_recon_cols, num_recon_slices). * partition (jax array): 2D array where partition[subset_index] gives a 1D array of pixel indices. * Squared sum of changes to the recon during this iteration but before this subset. + * alpha_sum (float): Sum of the alpha step sizes over previous subsets in this partition. subset_indices (jax array): An array of indices into the partition - this gives the order in which the subsets are updated. Returns: - (error_sinogram, flat_recon, norm_square_update): The first two have the same shape as above, but + (error_sinogram, flat_recon, norm_square_update, alpha): The first two have the same shape as above, but are updated to reduce overall loss function. The norm_square_update includes the changes from this subset. + alpha is the relative step size in the gradient descent step, averaged over the subsets in the partition.. """ # Scan over the subsets of the partition, using the subset_indices to order them. sinogram_recon_partition, _ = jax.lax.scan(vcd_subset_iterator, sinogram_recon_partition, subset_indices) - error_sinogram, flat_recon, _, norm_square_update = sinogram_recon_partition - return error_sinogram, flat_recon, norm_square_update + error_sinogram, flat_recon, partition, norm_square_update, alpha_sum = sinogram_recon_partition + return error_sinogram, flat_recon, norm_square_update, alpha_sum / partition.shape[0] return jax.jit(vcd_partition_iterator) @@ -682,7 +686,7 @@ def vcd_subset_iterator(sinogram_recon_partition, subset_index): [error_sinogram, flat_recon, partition, norm_square_update]: The first two have the same shape as above, but are updated to reduce overall loss function. The norm_square_update includes the changes from this subset. """ - error_sinogram, flat_recon, partition, norm_square_update = sinogram_recon_partition + error_sinogram, flat_recon, partition, norm_square_update, alpha_prev_sum = sinogram_recon_partition pixel_indices = partition[subset_index] def delta_recon_batch(index_batch): @@ -776,7 +780,7 @@ def delta_recon_batch(index_batch): error_sinogram = error_sinogram - alpha * delta_sinogram norm_square_update += jnp.sum(update * update) - return [error_sinogram, flat_recon, partition, norm_square_update], None + return [error_sinogram, flat_recon, partition, norm_square_update, alpha + alpha_prev_sum], None return jax.jit(vcd_subset_iterator) @@ -898,12 +902,14 @@ def get_transpose(linear_map, input_shape): Returns: transpose: A function to evaluate the transpose of the given map. The input to transpose must be a jax or ndarray with the same shape as the output of the original linear_map. - transpose(input) returns a 1 element tuple containing an array holding the result, so the final output - must be obtained using transpose(input)[0] + transpose(input) returns an array of shape input_shape. """ # print('Defining transpose map') # t0 = time.time() input_info = types.SimpleNamespace(shape=input_shape, dtype=jnp.dtype(jnp.float32)) - transpose = jax.linear_transpose(linear_map, input_info) - # print('Done: ' + str(time.time() - t0)) + transpose_list = jax.linear_transpose(linear_map, input_info) + + def transpose(input): + return transpose_list(input)[0] + return transpose diff --git a/mbirjax/vcd_utils.py b/mbirjax/vcd_utils.py index d7e713f..8e0ed78 100644 --- a/mbirjax/vcd_utils.py +++ b/mbirjax/vcd_utils.py @@ -51,6 +51,55 @@ def gen_set_of_pixel_partitions(recon_shape, granularity): return partitions +def gen_pixel_partition_grid(recon_shape, num_subsets): + + small_tile_side = np.ceil(np.sqrt(num_subsets)).astype(int) + num_subsets = small_tile_side ** 2 + num_small_tiles = [np.ceil(recon_shape[k] / small_tile_side).astype(int) for k in [0, 1]] + + single_subset_inds = np.random.permutation(num_subsets).reshape((small_tile_side, small_tile_side)) + subset_inds = np.tile(single_subset_inds, num_small_tiles) + subset_inds = subset_inds[:recon_shape[0], :recon_shape[1]] + + ror_mask = get_2d_ror_mask(recon_shape[:2]) + subset_inds = (subset_inds + 1) * ror_mask - 1 # Get a - at each location outside the mask, subset_ind at other points + subset_inds = subset_inds.flatten() + num_inds = len(np.where(subset_inds > -1)[0]) + + if num_subsets > num_inds: + # num_subsets = len(indices) + warning = '\nThe number of partition subsets is greater than the number of pixels in the region of ' + warning += 'reconstruction. \nReducing the number of subsets to equal the number of indices.' + warnings.warn(warning) + subset_inds = subset_inds[subset_inds > -1] + return jnp.array(subset_inds).reshape((-1, 1)) + + flat_inds = [] + max_points = 0 + min_points = subset_inds.size + nonempty_subsets = np.unique(subset_inds[subset_inds>=0]) + for k in nonempty_subsets: + cur_inds = np.where(subset_inds == k)[0] + flat_inds.append(cur_inds) # Get all the indices for each subset + max_points = max(max_points, cur_inds.size) + min_points = min(min_points, cur_inds.size) + + extra_point_inds = np.random.randint(min_points, size=(max_points - min_points + 1,)) + for k in range(len(nonempty_subsets)): + cur_inds = flat_inds[k] + num_extra_points = max_points - cur_inds.size + if num_extra_points > 0: + extra_subset_inds = (k + 1 + np.arange(num_extra_points, dtype=int)) % len(nonempty_subsets) + new_point_inds = [flat_inds[extra_subset_inds[j]][extra_point_inds[j]] for j in range(num_extra_points)] + flat_inds[k] = np.concatenate((cur_inds, new_point_inds)) + flat_inds = np.array(flat_inds) + + # Reorganize into subsets, then sort each subset + indices = jnp.array(flat_inds) + + return jnp.array(indices) + + def gen_pixel_partition(recon_shape, num_subsets): """ Generates a partition of pixel indices into specified number of subsets for use in tomographic reconstruction algorithms. @@ -94,7 +143,7 @@ def gen_pixel_partition(recon_shape, num_subsets): # Reorganize into subsets, then sort each subset indices = indices.reshape(num_subsets, indices.size // num_subsets) - indices = np.sort(indices, axis=1) + indices = jnp.sort(indices, axis=1) return jnp.array(indices) diff --git a/pyproject.toml b/pyproject.toml index e887a4b..4df2610 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mbirjax" -version = "0.3.3" +version = "0.4.0" description = "High-performance tomographic reconstruction" keywords = ["tomography", "tomographic reconstruction", "computed tomography"] readme = "README.rst" diff --git a/tests/test_projectors.py b/tests/test_projectors.py index 29256a4..6ba5f85 100644 --- a/tests/test_projectors.py +++ b/tests/test_projectors.py @@ -1,4 +1,5 @@ import numpy as np +import os import jax import jax.numpy as jnp import mbirjax @@ -18,7 +19,7 @@ class TestProjectors(unittest.TestCase): def setUp(self): """Set up before each test method.""" # Choose the geometry type - self.geometry_types = ['parallel', 'cone'] + self.geometry_types = mbirjax._utils._geometry_types_for_tests # Set parameters self.num_views = 64 @@ -55,6 +56,9 @@ def get_model(self, geometry_type): source_iso_dist=self.source_iso_dist) elif geometry_type == 'parallel': ct_model = mbirjax.ParallelBeamModel(self.sinogram_shape, self.angles) + elif geometry_type == 'blur': + sigma = 2.0 + ct_model = mbirjax.Blur(self.sinogram_shape, sigma) else: raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type)) @@ -72,6 +76,69 @@ def test_all_hessians(self): print("Testing Hessian with", geometry_type) self.verify_hessian(geometry_type) + def test_save_load(self): + for geometry_type in self.geometry_types: + with self.subTest(geometry_type=geometry_type): + print("Testing save/load with", geometry_type) + self.verify_save_load(geometry_type) + + def verify_save_load(self, geometry_type): + """ + Verify the adjoint property of the projectors: + Choose a random phantom, x, and a random sinogram, y, and verify that = . + """ + self.set_angles(geometry_type) + ct_model = self.get_model(geometry_type) + + # Generate phantom + recon_shape = ct_model.get_params('recon_shape') + num_recon_rows, num_recon_cols, num_recon_slices = recon_shape[:3] + + # Get the vector of indices + indices = jnp.arange(num_recon_rows * num_recon_cols) + + # ########################## + # Do a forward and back projection from a single pixel + i, j = num_recon_rows // 4, num_recon_cols // 3 + x = jnp.zeros(recon_shape) + x = x.at[i, j, :].set(1) + voxel_values = x.reshape((-1, num_recon_slices))[indices] + + Ax = ct_model.sparse_forward_project(voxel_values, indices) + Aty = ct_model.sparse_back_project(Ax, indices) + Aty = ct_model.reshape_recon(Aty) + + # Save the model + filename = 'saved_model_test.yaml' + ct_model.to_file(filename) + + # Load the model + new_model = self.get_model(geometry_type) + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_model = new_model.from_file(filename) + if os.path.exists(filename): + os.remove(filename) + + # Compare parameters + for key, entry in ct_model.params.items(): + if isinstance(entry['val'], list): + entry['val'] = tuple(entry['val']) + loaded_entry = new_model.params[key] + if isinstance(entry['val'], type(jnp.zeros(1))): + assert(jnp.allclose(entry['val'], loaded_entry['val'])) + else: + assert(entry == loaded_entry) + + # Do a forward and back projection with loaded model + Ax_new = new_model.sparse_forward_project(voxel_values, indices) + Aty_new = new_model.sparse_back_project(Ax_new, indices) + Aty_new = new_model.reshape_recon(Aty_new) + + # Compare to original + assert(np.allclose(Aty, Aty_new, atol=1e-4)) + def verify_adjoint(self, geometry_type): """ Verify the adjoint property of the projectors: diff --git a/tests/test_vcd.py b/tests/test_vcd.py index be47623..63df43c 100644 --- a/tests/test_vcd.py +++ b/tests/test_vcd.py @@ -5,7 +5,7 @@ import unittest -class TestProjectors(unittest.TestCase): +class TestVCD(unittest.TestCase): """ Test the adjoint property of the forward and back projectors, both the full versions and the sparse voxel version. This means if x is an image, and y is a sinogram, then = . @@ -19,10 +19,11 @@ def setUp(self): """Set up before each test method.""" np.random.seed(0) # Set a seed to avoid variations due to partition creation. # Choose the geometry type - self.geometry_types = ['parallel', 'cone'] + self.geometry_types = mbirjax._utils._geometry_types_for_tests parallel_tolerances = {'nrmse': 0.15, 'max_diff': 0.38, 'pct_95': 0.04} cone_tolerances = {'nrmse': 0.19, 'max_diff': 0.56, 'pct_95': 0.05} - self.all_tolerances = [parallel_tolerances, cone_tolerances] + blur_tolerances = {'nrmse': 0.19, 'max_diff': 0.56, 'pct_95': 0.05} + self.all_tolerances = [parallel_tolerances, cone_tolerances, blur_tolerances] # Set parameters self.num_views = 64 @@ -59,6 +60,11 @@ def get_model(self, geometry_type): source_iso_dist=self.source_iso_dist) elif geometry_type == 'parallel': ct_model = mbirjax.ParallelBeamModel(self.sinogram_shape, self.angles) + elif geometry_type == 'blur': + sigma = 0.5 + sinogram_shape = (64, 64, 32) + ct_model = mbirjax.Blur(sinogram_shape, sigma) + ct_model.set_params(sharpness=-2) else: raise ValueError('Invalid geometry type. Expected cone or parallel, got {}'.format(geometry_type))