Skip to content

Commit

Permalink
Make gradient of DC estimation to zero based on an argument (#38)
Browse files Browse the repository at this point in the history
This also solves a problem with CI regarding test failures (#39 )
Co-authored-by: chaithyagr <chaithyagr@gitlab.com>
  • Loading branch information
chaithyagr authored Apr 22, 2021
1 parent 0d29bf4 commit 6739678
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 23 deletions.
2 changes: 2 additions & 0 deletions run_tests.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env bash
set -e
pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest
# We test ndft_test.py separately as it causes some issues with tracing resulting in hangs.
python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft/tests/ndft_test.py
2 changes: 1 addition & 1 deletion tfkbnufft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Package info"""

__version__ = '0.2.2'
__version__ = '0.2.3'
__author__ = 'Zaccharie Ramzi'
__author_email__ = 'zaccharie.ramzi@inria.fr'
__license__ = 'MIT'
Expand Down
61 changes: 40 additions & 21 deletions tfkbnufft/mri/dcomp_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def calculate_radial_dcomp_tf(interpob, nufftob_forw, nufftob_back, ktraj, stack
return dcomp


def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, num_iterations=10):
def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, num_iterations=10, zero_grad=True):
"""Numerical density compensation estimation for a any trajectory.
Estimates the density compensation function numerically using a NUFFT
Expand All @@ -105,28 +105,47 @@ def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, n
trajectory.
num_iterations (int): default 10
number of iterations
zero_grad (bool): default True
when true, assumes that the density compensator is a constant and
returns zero gradients
Returns:
tensor: The density compensation coefficients for ktraj of size (m).
"""
test_sig = tf.ones([1, 1, ktraj.shape[1]], dtype=tf.complex64)
for i in range(num_iterations):
test_sig = test_sig / tf.cast(tf.math.abs(kbinterp(
adjkbinterp(test_sig, ktraj[None, :], interpob),
ktraj[None, :],
interpob
)), 'complex64')
im_size = interpob['im_size']
test_size = tf.concat([(1, 1,), im_size], axis=0)
test_im = tf.ones(test_size, dtype=tf.complex64)
test_im_recon = nufftob_back(
test_sig * nufftob_forw(
test_im,
def _calculate_density_compensator(ktraj):
test_sig = tf.ones([1, 1, ktraj.shape[1]], dtype=tf.float32)
for i in range(num_iterations):
test_sig = test_sig / tf.math.abs(kbinterp(
adjkbinterp(tf.cast(test_sig, tf.complex64), ktraj[None, :], interpob),
ktraj[None, :],
interpob
))
im_size = interpob['im_size']
test_sig = tf.cast(test_sig, tf.complex64)
test_size = tf.concat([(1, 1,), im_size], axis=0)
test_im = tf.ones(test_size, dtype=tf.complex64)
test_im_recon = nufftob_back(
test_sig * nufftob_forw(
test_im,
ktraj[None, :]
),
ktraj[None, :]
),
ktraj[None, :]
)
ratio = tf.reduce_mean(test_im_recon)
test_sig = test_sig / tf.cast(ratio, test_sig.dtype)
test_sig = test_sig[0, 0]
return test_sig
)
ratio = tf.reduce_mean(tf.math.abs(test_im_recon))
test_sig = test_sig / tf.cast(ratio, test_sig.dtype)
test_sig = test_sig[0, 0]
return test_sig

@tf.custom_gradient
def _calculate_density_compensator_no_grad(ktraj):
"""Internal function that returns density compensators, but also returns
no gradients"""
dc_weights = _calculate_density_compensator(ktraj)
def grad(dy):
return None
return dc_weights, grad

if zero_grad:
return _calculate_density_compensator_no_grad(ktraj)
else:
return _calculate_density_compensator(ktraj)
3 changes: 2 additions & 1 deletion tfkbnufft/tests/mri/dcomp_calc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ def test_density_compensators_tf():
tf_ktraj = tf.convert_to_tensor(ktraj)
nufftob_back = kbnufft_adjoint(interpob)
nufftob_forw = kbnufft_forward(interpob)
tf_dcomp = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj)
tf_dcomp = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=False)
tf_dcomp_no_grad = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=True)

0 comments on commit 6739678

Please sign in to comment.