diff --git a/README.md b/README.md index d41c1e5..a7a4527 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,16 @@ python profile_tfkbnufft.py These numbers were obtained with a Quadro P5000. + +## Gradients + +### w.r.t. trajectory + +This is currently experimental and a work in progress. Please be cautious. +Currently this is tested in CI against results from NDFT, but the clear mathematical backing for some +aspects is still being understood for applying the chain rule. + + ## References 1. Fessler, J. A., & Sutton, B. P. (2003). Nonuniform fast Fourier transforms using min-max interpolation. *IEEE transactions on signal processing*, 51(2), 560-574. diff --git a/tfkbnufft/__init__.py b/tfkbnufft/__init__.py index e231808..ae3e4fe 100644 --- a/tfkbnufft/__init__.py +++ b/tfkbnufft/__init__.py @@ -1,6 +1,6 @@ """Package info""" -__version__ = '0.2.1' +__version__ = '0.2.2' __author__ = 'Zaccharie Ramzi' __author_email__ = 'zaccharie.ramzi@inria.fr' __license__ = 'MIT' diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index 8e267ce..951f707 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -43,6 +43,9 @@ def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None, self.im_size = im_size self.im_rank = len(im_size) self.grad_traj = grad_traj + if self.grad_traj: + warnings.warn('The gradient w.r.t trajectory is Experimental and WIP. ' + 'Please use with caution') if grid_size is None: self.grid_size = tuple(np.array(self.im_size) * 2) else: @@ -189,7 +192,7 @@ def grad(dy): grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
fft_dx_dom = scale_and_fft_on_image_volume( x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) - dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32) + dy_dom = tf.cast(-1j * tf.math.conj(dy) * kbinterp(fft_dx_dom, om, interpob), tf.float32) else: dy_dom = None return ifft_dy, dy_dom @@ -233,7 +236,7 @@ def grad(dx): r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] ifft_dxr = scale_and_fft_on_image_volume( - dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) + tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype) else: dx_dom = None diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py index 2cdcd7b..05fc7e6 100644 --- a/tfkbnufft/tests/ndft_test.py +++ b/tfkbnufft/tests/ndft_test.py @@ -25,29 +25,40 @@ def test_adjoint_and_gradients(im_size): M = im_size[0] * 2**im_rank nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random signal signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) - kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)) - + kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)) + Idata = tf.Variable(kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj_ori)) + ktraj_noise = np.copy(ktraj_ori) + ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj = tf.Variable(ktraj_noise) with 
tf.GradientTape(persistent=True) as g: I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0] A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True) I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size) + loss_nufft = tf.math.reduce_mean(tf.abs(Idata - I_nufft)**2) + loss_ndft = tf.math.reduce_mean(tf.abs(Idata - I_ndft)**2) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-2) + tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-3) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3) + + # Test gradients in chain rule with respect to ktraj + gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0] + gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0] + tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4) + # This is gradient of NDFT from matrix, will help in debug # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] @@ -60,27 +71,37 @@ def test_forward_and_gradients(im_size): M = im_size[0] * 2**im_rank nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random 
signal signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) - + kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)[0] + ktraj_noise = np.copy(ktraj_ori) + ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj = tf.Variable(ktraj_noise) with tf.GradientTape(persistent=True) as g: kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0] A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False) kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1)))) + loss_nufft = tf.math.reduce_mean(tf.abs(kdata - kdata_nufft)**2) + loss_ndft = tf.math.reduce_mean(tf.abs(kdata - kdata_ndft)**2) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-2) + tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-3) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0] gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3) + + # Test gradients in chain rule with respect to ktraj + gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0] + gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0] + tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4) # This is gradient of NDFT from matrix, will help in debug - # gradient_ndft_matrix = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, 
tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) + # gradient_ndft_matrix = -1j * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(grid_r, tf.complex64) * signal[0][0])))