Skip to content

Commit

Permalink
Fix Gradient for trajectory (#37)
Browse files Browse the repository at this point in the history
* Merged and fixed gradient code

* Remove unwanted

* Remove unwanted debug lines

* Loosen constraints

* Add warnings, update the constraints, update version

* Update tfkbnufft/tests/ndft_test.py

Co-authored-by: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>

* Update tfkbnufft/tests/ndft_test.py

Co-authored-by: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>

* Add readme

* Update README.md

Co-authored-by: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>

Co-authored-by: chaithyagr <chaithyagr@gitlab.com>
Co-authored-by: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>
  • Loading branch information
3 people authored Apr 19, 2021
1 parent 3002c90 commit 0d29bf4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 15 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ python profile_tfkbnufft.py

These numbers were obtained with a Quadro P5000.


## Gradients

### w.r.t trajectory

This feature is currently experimental and a work in progress; please use it with caution.
It is tested in CI against results from the NDFT, but the clear mathematical backing for some
aspects of applying the chain rule is still being worked out.


## References

1. Fessler, J. A., & Sutton, B. P. (2003). Nonuniform fast Fourier transforms using min-max interpolation. *IEEE transactions on signal processing*, 51(2), 560-574.
Expand Down
2 changes: 1 addition & 1 deletion tfkbnufft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Package info"""

__version__ = '0.2.1'
__version__ = '0.2.2'
__author__ = 'Zaccharie Ramzi'
__author_email__ = 'zaccharie.ramzi@inria.fr'
__license__ = 'MIT'
Expand Down
7 changes: 5 additions & 2 deletions tfkbnufft/kbnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None,
self.im_size = im_size
self.im_rank = len(im_size)
self.grad_traj = grad_traj
if self.grad_traj:
warnings.warn('The gradient w.r.t trajectory is Experimental and WIP. '
'Please use with caution')
if grid_size is None:
self.grid_size = tuple(np.array(self.im_size) * 2)
else:
Expand Down Expand Up @@ -189,7 +192,7 @@ def grad(dy):
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
fft_dx_dom = scale_and_fft_on_image_volume(
x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32)
dy_dom = tf.cast(-1j * tf.math.conj(dy) * kbinterp(fft_dx_dom, om, interpob), tf.float32)
else:
dy_dom = None
return ifft_dy, dy_dom
Expand Down Expand Up @@ -233,7 +236,7 @@ def grad(dx):
r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...]
ifft_dxr = scale_and_fft_on_image_volume(
dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True)
tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True)
dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype)
else:
dx_dom = None
Expand Down
45 changes: 33 additions & 12 deletions tfkbnufft/tests/ndft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,40 @@ def test_adjoint_and_gradients(im_size):
M = im_size[0] * 2**im_rank
nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
# Generate Trajectory
ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
# Have a random signal
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64))
kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj))

kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori))
Idata = tf.Variable(kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj_ori))
ktraj_noise = np.copy(ktraj_ori)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj = tf.Variable(ktraj_noise)
with tf.GradientTape(persistent=True) as g:
I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0]
A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True)
I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size)
loss_nufft = tf.math.reduce_mean(tf.abs(Idata - I_nufft)**2)
loss_ndft = tf.math.reduce_mean(tf.abs(Idata - I_ndft)**2)

tf_test = tf.test.TestCase()
# Test if the NUFFT and NDFT operation is same
tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-2)
tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-3)

# Test gradients with respect to kdata
gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0]
gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0]
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2)
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3)

# Test gradients with respect to trajectory location
gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0]
gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2)
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3)

# Test gradients in chain rule with respect to ktraj
gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0]
gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4)

# This is gradient of NDFT from matrix, will help in debug
# gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0]

Expand All @@ -60,27 +71,37 @@ def test_forward_and_gradients(im_size):
M = im_size[0] * 2**im_rank
nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
# Generate Trajectory
ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
# Have a random signal
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64))

kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)[0]
ktraj_noise = np.copy(ktraj_ori)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj = tf.Variable(ktraj_noise)
with tf.GradientTape(persistent=True) as g:
kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0]
A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False)
kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1))))
loss_nufft = tf.math.reduce_mean(tf.abs(kdata - kdata_nufft)**2)
loss_ndft = tf.math.reduce_mean(tf.abs(kdata - kdata_ndft)**2)

tf_test = tf.test.TestCase()
# Test if the NUFFT and NDFT operation is same
tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-2)
tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-3)

# Test gradients with respect to kdata
gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0]
gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0]
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2)
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3)

# Test gradients with respect to trajectory location
gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0]
gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2)
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3)

# Test gradients in chain rule with respect to ktraj
gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0]
gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4)
# This is gradient of NDFT from matrix, will help in debug
# gradient_ndft_matrix = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,)))))
# gradient_ndft_matrix = -1j * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(grid_r, tf.complex64) * signal[0][0])))

0 comments on commit 0d29bf4

Please sign in to comment.