From 6c3aee017567f0837a07acb169464c511f257f9b Mon Sep 17 00:00:00 2001
From: Thomas George
Date: Tue, 9 Feb 2021 22:25:38 -0500
Subject: [PATCH] jacobian generator raises an error when using batch norm in
 training mode instead of silently failing

---
 nngeometry/generator/jacobian.py | 42 ++++++++++++++++++++++++++------
 tests/tasks.py                   |  7 ++++++
 tests/test_jacobian.py           | 38 +++++++++++++++++------------
 tests/test_jacobian_kfac.py      |  2 ++
 4 files changed, 65 insertions(+), 24 deletions(-)

diff --git a/nngeometry/generator/jacobian.py b/nngeometry/generator/jacobian.py
index 0017b8a..22b36bb 100644
--- a/nngeometry/generator/jacobian.py
+++ b/nngeometry/generator/jacobian.py
@@ -610,6 +610,12 @@ def implicit_Jv(self, v):
 
         return FVector(vector_repr=Jv)
 
+    def _check_bn_training(self, mod):
+        # check that BN layers are in eval mode
+        if mod.training:
+            raise NotImplementedError('I don\'t know what to do with BN ' +
+                                      'layers in training mode')
+
     def _add_hooks(self, hook_x, hook_gy, mods):
         handles = []
         for m in mods:
@@ -653,6 +659,7 @@ def _hook_compute_flat_grad(self, mod, grad_input, grad_output):
                                start_p:start_p+mod.bias.numel()] \
                     .add_(gy.sum(dim=(2, 3)))
         elif mod_class == 'BatchNorm1d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training,
@@ -665,6 +672,7 @@
                                start_p:start_p+mod.bias.numel()] \
                     .add_(gy)
         elif mod_class == 'BatchNorm2d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training,
@@ -711,6 +719,7 @@ def _hook_compute_diag(self, mod, grad_input, grad_output):
                 self.diag_m[start_p:start_p+mod.bias.numel()] \
                     .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
         elif mod_class == 'BatchNorm1d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
@@ -720,6 +729,7 @@
                 self.diag_m[start_p: start_p+mod.bias.numel()] \
                     .add_((gy**2).sum(dim=0))
         elif mod_class == 'BatchNorm2d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
@@ -790,6 +800,7 @@ def _hook_compute_layer_blocks(self, mod, grad_input, grad_output):
                 gw = torch.cat([gw, gy.sum(dim=(2, 3)).view(bs, -1)], dim=1)
             block.add_(torch.mm(gw.t(), gw))
         elif mod_class == 'BatchNorm1d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
@@ -797,6 +808,7 @@
                 gw = torch.cat([gw, gy], dim=1)
             block.add_(torch.mm(gw.t(), gw))
         elif mod_class == 'BatchNorm2d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
@@ -891,10 +903,15 @@ def _hook_kxy(self, mod, grad_input, grad_output):
                     torch.mm(gy_inner.sum(dim=(2, 3)),
                              gy_outer.sum(dim=(2, 3)).t())
         elif mod_class == 'BatchNorm1d':
-            x_norm_inner = F.batch_norm(x_inner, None, None, None,
-                                        None, True)
-            x_norm_outer = F.batch_norm(x_outer, None, None, None,
-                                        None, True)
+            self._check_bn_training(mod)
+            x_norm_inner = F.batch_norm(x_inner, mod.running_mean,
+                                        mod.running_var,
+                                        None, None, mod.training,
+                                        momentum=0.)
+            x_norm_outer = F.batch_norm(x_outer, mod.running_mean,
+                                        mod.running_var,
+                                        None, None, mod.training,
+                                        momentum=0.)
             indiv_gw_inner = x_norm_inner * gy_inner
             indiv_gw_outer = x_norm_outer * gy_outer
             self.G[self.i_output_inner,
@@ -908,10 +925,15 @@
                    self.e_outer:self.e_outer+bs_outer] += \
                 torch.mm(gy_inner, gy_outer.t())
         elif mod_class == 'BatchNorm2d':
-            x_norm_inner = F.batch_norm(x_inner, None, None, None,
-                                        None, True)
-            x_norm_outer = F.batch_norm(x_outer, None, None, None,
-                                        None, True)
+            self._check_bn_training(mod)
+            x_norm_inner = F.batch_norm(x_inner, mod.running_mean,
+                                        mod.running_var,
+                                        None, None, mod.training,
+                                        momentum=0.)
+            x_norm_outer = F.batch_norm(x_outer, mod.running_mean,
+                                        mod.running_var,
+                                        None, None, mod.training,
+                                        momentum=0.)
             indiv_gw_inner = (x_norm_inner * gy_inner).sum(dim=(2, 3))
             indiv_gw_outer = (x_norm_outer * gy_outer).sum(dim=(2, 3))
             self.G[self.i_output_inner,
@@ -1014,6 +1036,7 @@ def _hook_compute_Jv(self, mod, grad_input, grad_output):
                 self._Jv[self.i_output, self.start:self.start+bs].add_(
                     torch.mv(gy.sum(dim=(2, 3)), v_bias))
         elif mod_class == 'BatchNorm1d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training,
@@ -1023,6 +1046,7 @@
                 self._Jv[self.i_output, self.start:self.start+bs].add_(
                     torch.mv(gy.contiguous(), v_bias))
         elif mod_class == 'BatchNorm2d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training,
@@ -1057,12 +1081,14 @@ def _hook_compute_trace(self, mod, grad_input, grad_output):
             if mod.bias is not None:
                 self._trace += (gy.sum(dim=(2, 3))**2).sum()
         elif mod_class == 'BatchNorm1d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
             self._trace += (gy**2 * x_normalized**2).sum()
             self._trace += (gy**2).sum()
         elif mod_class == 'BatchNorm2d':
+            self._check_bn_training(mod)
             x_normalized = F.batch_norm(x, mod.running_mean,
                                         mod.running_var,
                                         None, None, mod.training)
diff --git a/tests/tasks.py b/tests/tasks.py
index 51591cb..8c737bf 100644
--- a/tests/tasks.py
+++ b/tests/tasks.py
@@ -96,6 +96,7 @@ def get_linear_fc_task():
                               shuffle=False)
     net = LinearFCNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -128,6 +129,7 @@ def get_linear_conv_task():
                               shuffle=False)
     net = LinearConvNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -162,6 +164,7 @@ def get_batchnorm_fc_linear_task():
                               shuffle=False)
     net = BatchNormFCLinearNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -203,6 +206,7 @@ def get_batchnorm_conv_linear_task():
                               shuffle=False)
     net = BatchNormConvLinearNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -255,6 +259,7 @@ def get_batchnorm_nonlinear_task():
                               shuffle=False)
     net = BatchNormNonLinearNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -280,6 +285,7 @@ def get_fullyconnect_task(normalization='none'):
                               shuffle=False)
     net = FCNet(out_size=3, normalization=normalization)
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -302,6 +308,7 @@ def get_conv_task(normalization='none'):
                               shuffle=False)
     net = ConvNet(normalization=normalization)
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py
index 5a51df3..185bea9 100644
--- a/tests/test_jacobian.py
+++ b/tests/test_jacobian.py
@@ -48,7 +48,6 @@ def get_output_vector(loader, function):
 def test_jacobian_pushforward_dense_linear():
     for get_task in linear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -131,7 +130,6 @@ def test_jacobian_fdense_vs_pullback():
         print(get_task)
         for centering in [True, False]:
             loader, lc, parameters, model, function, n_output = get_task()
-            model.train()
             generator = Jacobian(layer_collection=lc,
                                  model=model,
                                  loader=loader,
@@ -169,7 +167,6 @@ def test_jacobian_pdense_vs_pushforward():
     for get_task in linear_tasks + nonlinear_tasks:
         for centering in [True, False]:
             loader, lc, parameters, model, function, n_output = get_task()
-            model.train()
             generator = Jacobian(layer_collection=lc,
                                  model=model,
                                  loader=loader,
@@ -209,7 +206,6 @@ def test_jacobian_pdense():
     for get_task in nonlinear_tasks:
         for centering in [True, False]:
             loader, lc, parameters, model, function, n_output = get_task()
-            model.train()
             generator = Jacobian(layer_collection=lc,
                                  model=model,
                                  loader=loader,
@@ -254,7 +250,6 @@ def test_jacobian_pdense():
 
         # Test add, sub, rmul
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -276,7 +271,6 @@ def test_jacobian_pdiag_vs_pdense():
     for get_task in nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -341,7 +335,6 @@ def test_jacobian_pdiag_vs_pdense():
 
         # Test add, sub, rmul
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -362,7 +355,6 @@ def test_jacobian_pblockdiag_vs_pdense():
     for get_task in linear_tasks + nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -392,7 +384,6 @@ def test_jacobian_pblockdiag():
     for get_task in nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -445,7 +436,6 @@ def test_jacobian_pblockdiag():
 
         # Test add, sub, rmul
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -468,7 +458,6 @@ def test_jacobian_pimplicit_vs_pdense():
     for get_task in linear_tasks + nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -508,7 +497,6 @@ def test_jacobian_plowrank_vs_pdense():
     for get_task in nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -527,7 +515,6 @@ def test_jacobian_plowrank():
     for get_task in nonlinear_tasks:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -573,7 +560,6 @@ def test_jacobian_plowrank():
 def test_jacobian_pquasidiag_vs_pdense():
     for get_task in [get_conv_task, get_fullyconnect_task]:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -626,7 +612,6 @@ def test_jacobian_pquasidiag():
     for get_task in [get_conv_task, get_fullyconnect_task]:
         loader, lc, parameters, model, function, n_output = get_task()
-        model.train()
         generator = Jacobian(layer_collection=lc,
                              model=model,
                              loader=loader,
@@ -655,4 +640,25 @@ def test_jacobian_pquasidiag():
         regul = 1e-8
         v_back = PMat_qd.solve(mv + regul * v, regul=regul)
         check_tensors(v.get_flat_representation(),
-                      v_back.get_flat_representation())
\ No newline at end of file
+                      v_back.get_flat_representation())
+
+def test_bn_eval_mode():
+    for get_task in [get_batchnorm_fc_linear_task,
+                     get_batchnorm_conv_linear_task]:
+        loader, lc, parameters, model, function, n_output = get_task()
+
+        generator = Jacobian(layer_collection=lc,
+                             model=model,
+                             loader=loader,
+                             function=function,
+                             n_output=n_output)
+
+        model.eval()
+        FMat_dense = FMatDense(generator)
+
+        model.train()
+        with pytest.raises(RuntimeError):
+            FMat_dense = FMatDense(generator)
+
+
+
diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py
index 118cffc..1a333fd 100644
--- a/tests/test_jacobian_kfac.py
+++ b/tests/test_jacobian_kfac.py
@@ -87,6 +87,7 @@ def get_fullyconnect_kfac_task(bs=300):
 
     net = Net(in_size=18*18)
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
@@ -118,6 +119,7 @@ def get_convnet_kfc_task(bs=300):
                               shuffle=False)
     net = ConvNet()
     net.to(device)
+    net.eval()
 
     def output_fn(input, target):
         return net(to_device(input))
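
For context, a minimal usage sketch of the behaviour this patch enforces. It is not part of the patch: the toy model, data and import paths (nngeometry.generator.Jacobian, nngeometry.layercollection.LayerCollection, nngeometry.object.FMatDense) are assumptions based on nngeometry's public API at the time. With BatchNorm layers, the model must be put in eval mode before building any representation; in training mode the generator now raises NotImplementedError (a subclass of RuntimeError, which is why test_bn_eval_mode uses pytest.raises(RuntimeError)) instead of silently computing quantities with batch statistics.

# Sketch only, assumed import paths and a made-up toy setup.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from nngeometry.generator import Jacobian                # assumed import path
from nngeometry.layercollection import LayerCollection   # assumed import path
from nngeometry.object import FMatDense                  # assumed import path

# Toy model containing a BatchNorm layer, and a small dataset.
model = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5), nn.Linear(5, 2))
x = torch.randn(30, 10)
y = torch.randint(0, 2, (30,))
loader = DataLoader(TensorDataset(x, y), batch_size=10, shuffle=False)

def output_fn(inputs, targets):
    return model(inputs)

generator = Jacobian(layer_collection=LayerCollection.from_model(model),
                     model=model,
                     loader=loader,
                     function=output_fn,
                     n_output=2)

model.eval()
F_dense = FMatDense(generator)      # fine: BN uses its running statistics

model.train()
try:
    FMatDense(generator)            # now raises instead of silently failing
except NotImplementedError as e:
    print('raised as expected:', e)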