From 9afb3396aa843abf65f5ffe83947b7e400a98c52 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Wed, 21 Jul 2021 12:45:45 +0200 Subject: [PATCH] Add batching codes in grads (#40) * Add batching codes in grads * NDFT Test * make batching right * Remove parallel iterations Co-authored-by: chaithyagr --- tfkbnufft/kbnufft.py | 17 +++++++++++++-- tfkbnufft/tests/ndft_test.py | 42 +++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index 951f707..761b163 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -192,7 +192,13 @@ def grad(dy): grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] fft_dx_dom = scale_and_fft_on_image_volume( x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) - dy_dom = tf.cast(-1j * tf.math.conj(dy) * kbinterp(fft_dx_dom, om, interpob), tf.float32) + # Do this when handling batches + fft_dx_dom = tf.reshape(fft_dx_dom, shape=(-1, 1, *fft_dx_dom.shape[2:])) + nufft_dx_dom = kbinterp(fft_dx_dom, tf.repeat(om, im_rank, axis=0), interpob) + # Unbatch back the data + nufft_dx_dom = tf.reshape(nufft_dx_dom, shape=(-1, im_rank, *nufft_dx_dom.shape[2:])) + dy_dom = tf.cast(-1j * tf.math.conj(dy) * nufft_dx_dom, om.dtype) + # dy_dom = tf.math.reduce_sum(dy_dom, axis=1)[None, :] else: dy_dom = None return ifft_dy, dy_dom @@ -234,10 +240,17 @@ def grad(dx): if grad_traj: # Gradients with respect to trajectory locations r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] + # This wont work for multicoil case as the dimension for dx is `batch_size x coil x Nx x Ny` grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] ifft_dxr = scale_and_fft_on_image_volume( tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) - dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype) + # Do this when handling batches + ifft_dxr = tf.reshape(ifft_dxr, shape=(-1, 1, *ifft_dxr.shape[2:])) + inufft_dxr = kbinterp(ifft_dxr, tf.repeat(om, im_rank, axis=0), interpob, conj=True) + # Unbatch back the data + inufft_dxr = tf.reshape(inufft_dxr, shape=(-1, im_rank, *inufft_dxr.shape[2:])) + dx_dom = tf.cast(1j * y * inufft_dxr, om.dtype) + # dx_dom = tf.math.reduce_sum(dx_dom, axis=1)[None, :] else: dx_dom = None return dx_dy, dx_dom diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py index 05fc7e6..1943118 100644 --- a/tfkbnufft/tests/ndft_test.py +++ b/tfkbnufft/tests/ndft_test.py @@ -8,7 +8,7 @@ def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False): r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32) - traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64) + traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj, [0, 2, 1]), tf.repeat(grid_r[None], ktraj.shape[0], axis=0)), tf.complex64) if do_ifft: A = tf.exp(1j * traj_grid) else: @@ -18,41 +18,42 @@ def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False): @pytest.mark.parametrize('im_size', [(10, 10)]) -def test_adjoint_and_gradients(im_size): +@pytest.mark.parametrize('batch_size', [1, 2]) +def test_adjoint_and_gradients(im_size, batch_size): tf.random.set_seed(0) grid_size = tuple(np.array(im_size)*2) im_rank = len(im_size) M = im_size[0] * 2**im_rank nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random signal - signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) + signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64)) kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)) Idata = tf.Variable(kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj_ori)) ktraj_noise = np.copy(ktraj_ori) - ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) ktraj = tf.Variable(ktraj_noise) with tf.GradientTape(persistent=True) as g: - I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0] + I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj) A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True) - I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size) + I_ndft = tf.reshape(tf.transpose(tf.matmul(kdata, A), [0, 1, 2]), (batch_size, 1, *im_size)) loss_nufft = tf.math.reduce_mean(tf.abs(Idata - I_nufft)**2) loss_ndft = tf.math.reduce_mean(tf.abs(Idata - I_ndft)**2) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-3) + tf_test.assertAllClose(I_nufft, I_ndft, atol=2e-3) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3) # Test gradients in chain rule with respect to ktraj gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0] @@ -64,40 +65,41 @@ def test_adjoint_and_gradients(im_size): @pytest.mark.parametrize('im_size', [(10, 10)]) -def test_forward_and_gradients(im_size): +@pytest.mark.parametrize('batch_size', [1, 2]) +def test_forward_and_gradients(im_size, batch_size): tf.random.set_seed(0) grid_size = tuple(np.array(im_size)*2) im_rank = len(im_size) M = im_size[0] * 2**im_rank nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random signal - signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) - kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)[0] + signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64)) + kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori) ktraj_noise = np.copy(ktraj_ori) - ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) ktraj = tf.Variable(ktraj_noise) with tf.GradientTape(persistent=True) as g: - kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0] + kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj) A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False) - kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1)))) + kdata_ndft = tf.matmul(tf.reshape(signal, (batch_size, 1, tf.reduce_prod(im_size))), tf.transpose(A, [0, 2, 1])) loss_nufft = tf.math.reduce_mean(tf.abs(kdata - kdata_nufft)**2) loss_ndft = tf.math.reduce_mean(tf.abs(kdata - kdata_ndft)**2) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-3) + tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=2e-3) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0] gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3) # Test gradients in chain rule with respect to ktraj gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0]