jacobian generator raises an error when using batch norm in training mode instead of silently failing
tfjgeorge committed Feb 10, 2021
1 parent 555aefb commit 6c3aee0
Showing 4 changed files with 65 additions and 24 deletions.
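A minimal usage sketch of the new behaviour (illustration only, not part of the commit): the toy model, data loader and import paths, including LayerCollection.from_model, are assumptions made for this example, while the Jacobian and FMatDense calls mirror the new test_bn_eval_mode test added below.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from nngeometry.generator import Jacobian
from nngeometry.layercollection import LayerCollection
from nngeometry.object.fspace import FMatDense

# toy network containing a BatchNorm1d layer
model = nn.Sequential(nn.Linear(10, 6), nn.BatchNorm1d(6), nn.Linear(6, 3))

inputs = torch.randn(20, 10)
targets = torch.randint(0, 3, (20,))
loader = DataLoader(TensorDataset(inputs, targets), batch_size=10)

def output_fn(input, target):
    return model(input)

lc = LayerCollection.from_model(model)
generator = Jacobian(layer_collection=lc, model=model, loader=loader,
                     function=output_fn, n_output=3)

model.eval()   # BN layers use their running statistics: supported
F_dense = FMatDense(generator)

model.train()  # BN layers in training mode: now raises instead of silently failing
try:
    F_dense = FMatDense(generator)
except RuntimeError as err:  # NotImplementedError is a RuntimeError subclass
    print('refused as expected:', err)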
42 changes: 34 additions & 8 deletions nngeometry/generator/jacobian.py
@@ -610,6 +610,12 @@ def implicit_Jv(self, v):

return FVector(vector_repr=Jv)

def _check_bn_training(self, mod):
# BN layers must be in eval mode: with batch statistics, per-example jacobians are not well defined
if mod.training:
raise NotImplementedError('BatchNorm layers must be in eval mode when '
'computing jacobians; call model.eval() first')

def _add_hooks(self, hook_x, hook_gy, mods):
handles = []
for m in mods:
@@ -653,6 +659,7 @@ def _hook_compute_flat_grad(self, mod, grad_input, grad_output):
start_p:start_p+mod.bias.numel()] \
.add_(gy.sum(dim=(2, 3)))
elif mod_class == 'BatchNorm1d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training,
@@ -665,6 +672,7 @@ def _hook_compute_flat_grad(self, mod, grad_input, grad_output):
start_p:start_p+mod.bias.numel()] \
.add_(gy)
elif mod_class == 'BatchNorm2d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training,
@@ -711,6 +719,7 @@ def _hook_compute_diag(self, mod, grad_input, grad_output):
self.diag_m[start_p:start_p+mod.bias.numel()] \
.add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
elif mod_class == 'BatchNorm1d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
@@ -720,6 +729,7 @@ def _hook_compute_diag(self, mod, grad_input, grad_output):
self.diag_m[start_p: start_p+mod.bias.numel()] \
.add_((gy**2).sum(dim=0))
elif mod_class == 'BatchNorm2d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
@@ -790,13 +800,15 @@ def _hook_compute_layer_blocks(self, mod, grad_input, grad_output):
gw = torch.cat([gw, gy.sum(dim=(2, 3)).view(bs, -1)], dim=1)
block.add_(torch.mm(gw.t(), gw))
elif mod_class == 'BatchNorm1d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
gw = gy * x_normalized
gw = torch.cat([gw, gy], dim=1)
block.add_(torch.mm(gw.t(), gw))
elif mod_class == 'BatchNorm2d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
@@ -891,10 +903,15 @@ def _hook_kxy(self, mod, grad_input, grad_output):
torch.mm(gy_inner.sum(dim=(2, 3)),
gy_outer.sum(dim=(2, 3)).t())
elif mod_class == 'BatchNorm1d':
x_norm_inner = F.batch_norm(x_inner, None, None, None,
None, True)
x_norm_outer = F.batch_norm(x_outer, None, None, None,
None, True)
self._check_bn_training(mod)
x_norm_inner = F.batch_norm(x_inner, mod.running_mean,
mod.running_var,
None, None, mod.training,
momentum=0.)
x_norm_outer = F.batch_norm(x_outer, mod.running_mean,
mod.running_var,
None, None, mod.training,
momentum=0.)
indiv_gw_inner = x_norm_inner * gy_inner
indiv_gw_outer = x_norm_outer * gy_outer
self.G[self.i_output_inner,
@@ -908,10 +925,15 @@ def _hook_kxy(self, mod, grad_input, grad_output):
self.e_outer:self.e_outer+bs_outer] += \
torch.mm(gy_inner, gy_outer.t())
elif mod_class == 'BatchNorm2d':
x_norm_inner = F.batch_norm(x_inner, None, None, None,
None, True)
x_norm_outer = F.batch_norm(x_outer, None, None, None,
None, True)
self._check_bn_training(mod)
x_norm_inner = F.batch_norm(x_inner, mod.running_mean,
mod.running_var,
None, None, mod.training,
momentum=0.)
x_norm_outer = F.batch_norm(x_outer, mod.running_mean,
mod.running_var,
None, None, mod.training,
momentum=0.)
indiv_gw_inner = (x_norm_inner * gy_inner).sum(dim=(2, 3))
indiv_gw_outer = (x_norm_outer * gy_outer).sum(dim=(2, 3))
self.G[self.i_output_inner,
@@ -1014,6 +1036,7 @@ def _hook_compute_Jv(self, mod, grad_input, grad_output):
self._Jv[self.i_output, self.start:self.start+bs].add_(
torch.mv(gy.sum(dim=(2, 3)), v_bias))
elif mod_class == 'BatchNorm1d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training,
@@ -1023,6 +1046,7 @@ def _hook_compute_Jv(self, mod, grad_input, grad_output):
self._Jv[self.i_output, self.start:self.start+bs].add_(
torch.mv(gy.contiguous(), v_bias))
elif mod_class == 'BatchNorm2d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training,
@@ -1057,12 +1081,14 @@ def _hook_compute_trace(self, mod, grad_input, grad_output):
if mod.bias is not None:
self._trace += (gy.sum(dim=(2, 3))**2).sum()
elif mod_class == 'BatchNorm1d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
self._trace += (gy**2 * x_normalized**2).sum()
self._trace += (gy**2).sum()
elif mod_class == 'BatchNorm2d':
self._check_bn_training(mod)
x_normalized = F.batch_norm(x, mod.running_mean,
mod.running_var,
None, None, mod.training)
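Context for the BatchNorm hooks above (editorial note, not part of the diff): in eval mode each example is normalized with the fixed running statistics,

    x_hat = (x - running_mean) / sqrt(running_var + eps),    y = weight * x_hat + bias,

so the per-example gradients with respect to the affine parameters factorize as gy * x_hat for the weight and gy for the bias, which is exactly the gw quantity the hooks accumulate. In training mode the mean and variance are recomputed from the current mini-batch, so every example's output and gradient depend on the rest of the batch and this per-example factorization no longer holds; _check_bn_training therefore rejects BN layers left in training mode instead of silently producing wrong values.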
7 changes: 7 additions & 0 deletions tests/tasks.py
@@ -96,6 +96,7 @@ def get_linear_fc_task():
shuffle=False)
net = LinearFCNet()
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -128,6 +129,7 @@ def get_linear_conv_task():
shuffle=False)
net = LinearConvNet()
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -162,6 +164,7 @@ def get_batchnorm_fc_linear_task():
shuffle=False)
net = BatchNormFCLinearNet()
net.to(device)
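# keep BN layers in eval mode: the jacobian generator now rejects BN layers in training mode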
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -203,6 +206,7 @@ def get_batchnorm_conv_linear_task():
shuffle=False)
net = BatchNormConvLinearNet()
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -255,6 +259,7 @@ def get_batchnorm_nonlinear_task():
shuffle=False)
net = BatchNormNonLinearNet()
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -280,6 +285,7 @@ def get_fullyconnect_task(normalization='none'):
shuffle=False)
net = FCNet(out_size=3, normalization=normalization)
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -302,6 +308,7 @@ def get_conv_task(normalization='none'):
shuffle=False)
net = ConvNet(normalization=normalization)
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
38 changes: 22 additions & 16 deletions tests/test_jacobian.py
@@ -48,7 +48,6 @@ def get_output_vector(loader, function):
def test_jacobian_pushforward_dense_linear():
for get_task in linear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -131,7 +130,6 @@ def test_jacobian_fdense_vs_pullback():
print(get_task)
for centering in [True, False]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -169,7 +167,6 @@ def test_jacobian_pdense_vs_pushforward():
for get_task in linear_tasks + nonlinear_tasks:
for centering in [True, False]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -209,7 +206,6 @@ def test_jacobian_pdense():
for get_task in nonlinear_tasks:
for centering in [True, False]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -254,7 +250,6 @@ def test_jacobian_pdense():

# Test add, sub, rmul
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -276,7 +271,6 @@ def test_jacobian_pdiag_vs_pdense():
def test_jacobian_pdiag_vs_pdense():
for get_task in nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -341,7 +335,6 @@ def test_jacobian_pdiag_vs_pdense():

# Test add, sub, rmul
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -362,7 +355,6 @@ def test_jacobian_pblockdiag_vs_pdense():
def test_jacobian_pblockdiag_vs_pdense():
for get_task in linear_tasks + nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -392,7 +384,6 @@ def test_jacobian_pblockdiag():
def test_jacobian_pblockdiag():
for get_task in nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -445,7 +436,6 @@ def test_jacobian_pblockdiag():

# Test add, sub, rmul
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -468,7 +458,6 @@ def test_jacobian_pimplicit_vs_pdense():
def test_jacobian_pimplicit_vs_pdense():
for get_task in linear_tasks + nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -508,7 +497,6 @@ def test_jacobian_pimplicit_vs_pdense():
def test_jacobian_plowrank_vs_pdense():
for get_task in nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
Expand All @@ -527,7 +515,6 @@ def test_jacobian_plowrank_vs_pdense():
def test_jacobian_plowrank():
for get_task in nonlinear_tasks:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -573,7 +560,6 @@ def test_jacobian_plowrank():
def test_jacobian_pquasidiag_vs_pdense():
for get_task in [get_conv_task, get_fullyconnect_task]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -626,7 +612,6 @@ def test_jacobian_pquasidiag_vs_pdense():
def test_jacobian_pquasidiag():
for get_task in [get_conv_task, get_fullyconnect_task]:
loader, lc, parameters, model, function, n_output = get_task()
model.train()
generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
@@ -655,4 +640,25 @@ def test_jacobian_pquasidiag():
regul = 1e-8
v_back = PMat_qd.solve(mv + regul * v, regul=regul)
check_tensors(v.get_flat_representation(),
v_back.get_flat_representation())
v_back.get_flat_representation())

def test_bn_eval_mode():
for get_task in [get_batchnorm_fc_linear_task,
get_batchnorm_conv_linear_task]:
loader, lc, parameters, model, function, n_output = get_task()

generator = Jacobian(layer_collection=lc,
model=model,
loader=loader,
function=function,
n_output=n_output)

model.eval()
FMat_dense = FMatDense(generator)

model.train()
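# the BN check raises NotImplementedError, which is a subclass of RuntimeError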
with pytest.raises(RuntimeError):
FMat_dense = FMatDense(generator)



2 changes: 2 additions & 0 deletions tests/test_jacobian_kfac.py
@@ -87,6 +87,7 @@ def get_fullyconnect_kfac_task(bs=300):

net = Net(in_size=18*18)
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
@@ -118,6 +119,7 @@ def get_convnet_kfc_task(bs=300):
shuffle=False)
net = ConvNet()
net.to(device)
net.eval()

def output_fn(input, target):
return net(to_device(input))
