Skip to content

Commit

Permalink
Add batching codes in grads (#40)
Browse files Browse the repository at this point in the history
* Add batching codes in grads

* NDFT Test

* make batching right

* Remove parallel iterations

Co-authored-by: chaithyagr <chaithyagr@gitlab.com>
  • Loading branch information
chaithyagr and chaithyagr authored Jul 21, 2021
1 parent 6739678 commit 9afb339
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
17 changes: 15 additions & 2 deletions tfkbnufft/kbnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,13 @@ def grad(dy):
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
fft_dx_dom = scale_and_fft_on_image_volume(
x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
dy_dom = tf.cast(-1j * tf.math.conj(dy) * kbinterp(fft_dx_dom, om, interpob), tf.float32)
# Do this when handling batches
fft_dx_dom = tf.reshape(fft_dx_dom, shape=(-1, 1, *fft_dx_dom.shape[2:]))
nufft_dx_dom = kbinterp(fft_dx_dom, tf.repeat(om, im_rank, axis=0), interpob)
# Unbatch back the data
nufft_dx_dom = tf.reshape(nufft_dx_dom, shape=(-1, im_rank, *nufft_dx_dom.shape[2:]))
dy_dom = tf.cast(-1j * tf.math.conj(dy) * nufft_dx_dom, om.dtype)
# dy_dom = tf.math.reduce_sum(dy_dom, axis=1)[None, :]
else:
dy_dom = None
return ifft_dy, dy_dom
Expand Down Expand Up @@ -234,10 +240,17 @@ def grad(dx):
if grad_traj:
# Gradients with respect to trajectory locations
r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
# This wont work for multicoil case as the dimension for dx is `batch_size x coil x Nx x Ny`
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...]
ifft_dxr = scale_and_fft_on_image_volume(
tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True)
dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype)
# Do this when handling batches
ifft_dxr = tf.reshape(ifft_dxr, shape=(-1, 1, *ifft_dxr.shape[2:]))
inufft_dxr = kbinterp(ifft_dxr, tf.repeat(om, im_rank, axis=0), interpob, conj=True)
# Unbatch back the data
inufft_dxr = tf.reshape(inufft_dxr, shape=(-1, im_rank, *inufft_dxr.shape[2:]))
dx_dom = tf.cast(1j * y * inufft_dxr, om.dtype)
# dx_dom = tf.math.reduce_sum(dx_dom, axis=1)[None, :]
else:
dx_dom = None
return dx_dy, dx_dom
Expand Down
42 changes: 22 additions & 20 deletions tfkbnufft/tests/ndft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False):
r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)
traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj, [0, 2, 1]), tf.repeat(grid_r[None], ktraj.shape[0], axis=0)), tf.complex64)
if do_ifft:
A = tf.exp(1j * traj_grid)
else:
Expand All @@ -18,41 +18,42 @@ def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False):


@pytest.mark.parametrize('im_size', [(10, 10)])
def test_adjoint_and_gradients(im_size):
@pytest.mark.parametrize('batch_size', [1, 2])
def test_adjoint_and_gradients(im_size, batch_size):
tf.random.set_seed(0)
grid_size = tuple(np.array(im_size)*2)
im_rank = len(im_size)
M = im_size[0] * 2**im_rank
nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
# Generate Trajectory
ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
# Have a random signal
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64))
signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64))
kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori))
Idata = tf.Variable(kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj_ori))
ktraj_noise = np.copy(ktraj_ori)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj = tf.Variable(ktraj_noise)
with tf.GradientTape(persistent=True) as g:
I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0]
I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)
A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True)
I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size)
I_ndft = tf.reshape(tf.transpose(tf.matmul(kdata, A), [0, 1, 2]), (batch_size, 1, *im_size))
loss_nufft = tf.math.reduce_mean(tf.abs(Idata - I_nufft)**2)
loss_ndft = tf.math.reduce_mean(tf.abs(Idata - I_ndft)**2)

tf_test = tf.test.TestCase()
# Test if the NUFFT and NDFT operation is same
tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-3)
tf_test.assertAllClose(I_nufft, I_ndft, atol=2e-3)

# Test gradients with respect to kdata
gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0]
gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0]
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3)
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3)

# Test gradients with respect to trajectory location
gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0]
gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3)
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3)

# Test gradients in chain rule with respect to ktraj
gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0]
Expand All @@ -64,40 +65,41 @@ def test_adjoint_and_gradients(im_size):


@pytest.mark.parametrize('im_size', [(10, 10)])
def test_forward_and_gradients(im_size):
@pytest.mark.parametrize('batch_size', [1, 2])
def test_forward_and_gradients(im_size, batch_size):
tf.random.set_seed(0)
grid_size = tuple(np.array(im_size)*2)
im_rank = len(im_size)
M = im_size[0] * 2**im_rank
nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
# Generate Trajectory
ktraj_ori = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
# Have a random signal
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64))
kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)[0]
signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64))
kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)
ktraj_noise = np.copy(ktraj_ori)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
ktraj = tf.Variable(ktraj_noise)
with tf.GradientTape(persistent=True) as g:
kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0]
kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)
A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False)
kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1))))
kdata_ndft = tf.matmul(tf.reshape(signal, (batch_size, 1, tf.reduce_prod(im_size))), tf.transpose(A, [0, 2, 1]))
loss_nufft = tf.math.reduce_mean(tf.abs(kdata - kdata_nufft)**2)
loss_ndft = tf.math.reduce_mean(tf.abs(kdata - kdata_ndft)**2)

tf_test = tf.test.TestCase()
# Test if the NUFFT and NDFT operation is same
tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-3)
tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=2e-3)

# Test gradients with respect to kdata
gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0]
gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0]
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=5e-3)
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3)

# Test gradients with respect to trajectory location
gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0]
gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0]
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=5e-3)
tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3)

# Test gradients in chain rule with respect to ktraj
gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0]
Expand Down

0 comments on commit 9afb339

Please sign in to comment.