diff --git a/run_tests.sh b/run_tests.sh index e41c0ec..aa4e611 100644 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +set -e pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest +# We test ndft_test.py separately as it causes some issues with tracing resulting in hangs. python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py python -m pytest tfkbnufft/tests/ndft_test.py diff --git a/tfkbnufft/__init__.py b/tfkbnufft/__init__.py index ae3e4fe..f360877 100644 --- a/tfkbnufft/__init__.py +++ b/tfkbnufft/__init__.py @@ -1,6 +1,6 @@ """Package info""" -__version__ = '0.2.2' +__version__ = '0.2.3' __author__ = 'Zaccharie Ramzi' __author_email__ = 'zaccharie.ramzi@inria.fr' __license__ = 'MIT' diff --git a/tfkbnufft/mri/dcomp_calc.py b/tfkbnufft/mri/dcomp_calc.py index 627201e..60018e0 100644 --- a/tfkbnufft/mri/dcomp_calc.py +++ b/tfkbnufft/mri/dcomp_calc.py @@ -85,7 +85,7 @@ def calculate_radial_dcomp_tf(interpob, nufftob_forw, nufftob_back, ktraj, stack return dcomp -def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, num_iterations=10): +def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, num_iterations=10, zero_grad=True): """Numerical density compensation estimation for a any trajectory. Estimates the density compensation function numerically using a NUFFT @@ -105,28 +105,47 @@ def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, n trajectory. num_iterations (int): default 10 number of iterations + zero_grad (bool): default True + when true, assumes that the density compensator is a constant and + returns zero gradients Returns: tensor: The density compensation coefficients for ktraj of size (m). """ - test_sig = tf.ones([1, 1, ktraj.shape[1]], dtype=tf.complex64) - for i in range(num_iterations): - test_sig = test_sig / tf.cast(tf.math.abs(kbinterp( - adjkbinterp(test_sig, ktraj[None, :], interpob), - ktraj[None, :], - interpob - )), 'complex64') - im_size = interpob['im_size'] - test_size = tf.concat([(1, 1,), im_size], axis=0) - test_im = tf.ones(test_size, dtype=tf.complex64) - test_im_recon = nufftob_back( - test_sig * nufftob_forw( - test_im, + def _calculate_density_compensator(ktraj): + test_sig = tf.ones([1, 1, ktraj.shape[1]], dtype=tf.float32) + for i in range(num_iterations): + test_sig = test_sig / tf.math.abs(kbinterp( + adjkbinterp(tf.cast(test_sig, tf.complex64), ktraj[None, :], interpob), + ktraj[None, :], + interpob + )) + im_size = interpob['im_size'] + test_sig = tf.cast(test_sig, tf.complex64) + test_size = tf.concat([(1, 1,), im_size], axis=0) + test_im = tf.ones(test_size, dtype=tf.complex64) + test_im_recon = nufftob_back( + test_sig * nufftob_forw( + test_im, + ktraj[None, :] + ), ktraj[None, :] - ), - ktraj[None, :] - ) - ratio = tf.reduce_mean(test_im_recon) - test_sig = test_sig / tf.cast(ratio, test_sig.dtype) - test_sig = test_sig[0, 0] - return test_sig + ) + ratio = tf.reduce_mean(tf.math.abs(test_im_recon)) + test_sig = test_sig / tf.cast(ratio, test_sig.dtype) + test_sig = test_sig[0, 0] + return test_sig + + @tf.custom_gradient + def _calculate_density_compensator_no_grad(ktraj): + """Internal function that returns density compensators, but also returns + no gradients""" + dc_weights = _calculate_density_compensator(ktraj) + def grad(dy): + return None + return dc_weights, grad + + if zero_grad: + return _calculate_density_compensator_no_grad(ktraj) + else: + return _calculate_density_compensator(ktraj) diff --git a/tfkbnufft/tests/mri/dcomp_calc_test.py b/tfkbnufft/tests/mri/dcomp_calc_test.py index 34efaf5..07b90c6 100644 --- a/tfkbnufft/tests/mri/dcomp_calc_test.py +++ b/tfkbnufft/tests/mri/dcomp_calc_test.py @@ -57,4 +57,5 @@ def test_density_compensators_tf(): tf_ktraj = tf.convert_to_tensor(ktraj) nufftob_back = kbnufft_adjoint(interpob) nufftob_forw = kbnufft_forward(interpob) - tf_dcomp = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj) + tf_dcomp = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=False) + tf_dcomp_no_grad = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=True)