From beaaeda3119da085549d77285cf3bb8d1b614848 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Tue, 6 Jul 2021 14:29:36 +0200
Subject: [PATCH 01/11] [BUG] Reproduce bug described in #5

---
 tests/test_bugs/test_issue5.py | 55 ++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 tests/test_bugs/test_issue5.py

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
new file mode 100644
index 0000000..9fb6841
--- /dev/null
+++ b/tests/test_bugs/test_issue5.py
@@ -0,0 +1,55 @@
+"""Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
+
+from backpack import extend
+from torch import manual_seed, rand
+from torch.nn import Flatten, Linear, MSELoss, Sequential
+from torch.optim import Adam
+
+from cockpit import Cockpit
+from cockpit.quantities import Alpha, GradHist1d
+from cockpit.utils.schedules import linear
+
+
+def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
+    """If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
+
+    But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
+    it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
+    ``Alpha`` cannot access the attribute anymore, which leads to an ``AttributeError``.
+    """
+    manual_seed(0)
+
+    N, D_in, D_out = 2, 3, 1
+    model = extend(Sequential(Flatten(), Linear(D_in, D_out)))
+
+    opt_not_sgd = Adam(model.parameters(), lr=1e-3)
+    loss_fn = extend(MSELoss(reduction="mean"))
+    individual_loss_fn = MSELoss(reduction="none")
+
+    on_first = linear(1)
+    alpha = Alpha(on_first)
+    uses_BatchGradTransformsHook = GradHist1d(on_first)
+
+    cockpit = Cockpit(
+        model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook]
+    )
+
+    global_step = 0
+    inputs, labels = rand(N, D_in), rand(N, D_out)
+
+    # forward pass
+    outputs = model(inputs)
+    loss = loss_fn(outputs, labels)
+    losses = individual_loss_fn(outputs, labels)
+
+    # backward pass
+    with cockpit(
+        global_step,
+        info={
+            "batch_size": N,
+            "individual_losses": losses,
+            "loss": loss,
+            "optimizer": opt_not_sgd,
+        },
+    ):
+        loss.backward(create_graph=cockpit.create_graph(global_step))

From 8b7cc0009656ac2f59406dbbf618a8c2bbdd1d8a Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 14:38:46 +0200
Subject: [PATCH 02/11] [REF] Remove dead code

---
 cockpit/cockpit.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/cockpit/cockpit.py b/cockpit/cockpit.py
index 9693c0a..6a154e2 100644
--- a/cockpit/cockpit.py
+++ b/cockpit/cockpit.py
@@ -219,16 +219,6 @@ def _free_backpack_buffers(self, global_step, protected_savefields, verbose=Fals
             except AttributeError:
                 pass
 
-    @staticmethod
-    def _remove_module_io(module):
-        io_fields = ["input0", "output"]
-
-        for field in io_fields:
-            try:
-                delattr(module, field)
-            except AttributeError:
-                pass
-
     def log(
         self,
         global_step,

From 9cf4b96278c0d815b1eb7a747357fcd57bc3b323 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:28:45 +0200
Subject: [PATCH 03/11] [FIX] Catch AttributeError in bug reproduction

---
 tests/test_bugs/test_issue5.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index 9fb6841..10f6470 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -1,6 +1,7 @@
 """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
 
 from backpack import extend
+from pytest import raises
 from torch import manual_seed, rand
 from torch.nn import Flatten, Linear, MSELoss, Sequential
 from torch.optim import Adam
@@ -43,13 +44,14 @@ def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
     losses = individual_loss_fn(outputs, labels)
 
     # backward pass
-    with cockpit(
-        global_step,
-        info={
-            "batch_size": N,
-            "individual_losses": losses,
-            "loss": loss,
-            "optimizer": opt_not_sgd,
-        },
-    ):
-        loss.backward(create_graph=cockpit.create_graph(global_step))
+    with raises(AttributeError):
+        with cockpit(
+            global_step,
+            info={
+                "batch_size": N,
+                "individual_losses": losses,
+                "loss": loss,
+                "optimizer": opt_not_sgd,
+            },
+        ):
+            loss.backward(create_graph=cockpit.create_graph(global_step))

From 13a05749e3b9edb7a68d2a413dd6aaf7c27048f9 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:32:16 +0200
Subject: [PATCH 04/11] [ADD] Allow quantities to protect buffers

---
 cockpit/cockpit.py             | 17 +++++++++++++++++
 cockpit/context.py             |  4 +++-
 cockpit/quantities/quantity.py | 14 ++++++++++++++
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/cockpit/cockpit.py b/cockpit/cockpit.py
index 6a154e2..d9aa649 100644
--- a/cockpit/cockpit.py
+++ b/cockpit/cockpit.py
@@ -2,6 +2,7 @@
 
 import os
 from collections import defaultdict
+from typing import Set
 
 import json_tricks
 from backpack import disable
@@ -130,6 +131,22 @@ def _get_extension_hook(self, global_step):
 
         return hook
 
+    def _get_protected_savefields(self, global_step: int) -> Set[str]:
+        """Return names of protected BackPACK buffers.
+
+        Args:
+            global_step: Current iteration number.
+
+        Returns:
+            Set of protected buffers.
+        """
+        protected = set()
+
+        for q in self.quantities:
+            protected.update(q.protected_savefields(global_step))
+
+        return protected
+
     def __call__(self, global_step, *exts, info=None, debug=False):
         """Returns the backpack extensions that should be used in this iteration.
 
diff --git a/cockpit/context.py b/cockpit/context.py
index b2a9d4d..3d1f43d 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -99,7 +99,9 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = [e.savefield for e in custom_exts]
+        self.protected_savefields = set(
+            e.savefield for e in custom_exts
+        ) | cp._get_protected_savefields(global_step)
 
         # choose context
         ext = cp._get_extensions(global_step, custom_exts=custom_exts)
diff --git a/cockpit/quantities/quantity.py b/cockpit/quantities/quantity.py
index 775134c..6baebdb 100644
--- a/cockpit/quantities/quantity.py
+++ b/cockpit/quantities/quantity.py
@@ -1,6 +1,7 @@
 """Base class for a tracked quantity."""
 
 from collections import defaultdict
+from typing import Set
 
 import numpy
 import torch
@@ -74,6 +75,19 @@ def extension_hooks(self, global_step):
         """
         return []
 
+    def protected_savefields(self, global_step: int) -> Set[str]:
+        """Return the set of protected BackPACK buffers at the current step.
+
+        Protected buffers will not be freed during a backward pass.
+
+        Args:
+            global_step: The current iteration number.
+
+        Returns:
+            Set of protected savefields.
+        """
+        return set()
+
     def track(self, global_step, params, batch_loss):
         """Perform scheduled computations and store result.
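A sketch (not part of the series) of how a quantity opts in to the protection
mechanism added in PATCH 04; the cleanup side that honors it lands in the next
patch. The class and its computation are hypothetical, and the sketch assumes
cockpit's existing ``SingleStepQuantity`` base class with its
``_track_schedule``, ``extensions``, and ``_compute`` API:

    from typing import Set

    from backpack.extensions import BatchGrad

    from cockpit.quantities.quantity import SingleStepQuantity


    class GradBatchConsumer(SingleStepQuantity):
        """Hypothetical quantity that reads individual gradients after backprop."""

        def extensions(self, global_step):
            # Request BackPACK's individual gradients on tracked iterations.
            return [BatchGrad()] if self._track_schedule(global_step) else []

        def protected_savefields(self, global_step: int) -> Set[str]:
            # Keep 'grad_batch' from being freed while this quantity needs it.
            return {"grad_batch"} if self._track_schedule(global_step) else set()

        def _compute(self, global_step, params, batch_loss):
            # 'grad_batch' is still attached to each parameter at this point.
            return sum(p.grad_batch.abs().sum().item() for p in params)

Cockpit takes the union of these sets over all quantities
(``_get_protected_savefields``), so one quantity's buffers cannot be clobbered
by another quantity's cleanup.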
From 6b88c3264a0eeefbbcbd17d84f294aad528d7f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:46:15 +0200
Subject: [PATCH 05/11] [REF] Extract deletion of 'grad_batch' into separate
 hook

---
 cockpit/context.py                     | 55 ++++++++++++++++++++++++--
 cockpit/quantities/hooks/cleanup.py    | 30 ++++++++++++++
 cockpit/quantities/utils_transforms.py | 12 ++----
 3 files changed, 84 insertions(+), 13 deletions(-)
 create mode 100644 cockpit/quantities/hooks/cleanup.py

diff --git a/cockpit/context.py b/cockpit/context.py
index 3d1f43d..a63cb8c 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -1,9 +1,14 @@
 """Cockpit Context."""
 
 import warnings
+from typing import Any, Callable, Union
 
 from backpack import backpack, disable
 from backpack.core.derivatives.convnd import weight_jac_t_save_memory
+from torch.nn import Module
+
+from cockpit.quantities.hooks.cleanup import CleanupHook
+from cockpit.quantities.utils_transforms import BatchGradTransformsHook
 
 
 class CockpitCTX:
@@ -99,13 +104,24 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = set(
-            e.savefield for e in custom_exts
-        ) | cp._get_protected_savefields(global_step)
+        self.protected_savefields = set(e.savefield for e in custom_exts)
+        self.protected_savefields.update(cp._get_protected_savefields(global_step))
 
         # choose context
         ext = cp._get_extensions(global_step, custom_exts=custom_exts)
-        ext_hook = cp._get_extension_hook(global_step)
+        compute_hook = cp._get_extension_hook(global_step)
+
+        # Delete 'grad_batch' during backprop if it's not protected
+        if (
+            isinstance(compute_hook, BatchGradTransformsHook)
+            and "grad_batch" not in self.protected_savefields
+        ):
+            # TODO Implement all quantities with hooks and specify protected savefields
+            # Remove unprotected buffers in this cleanup hook during backpropagation
+            cleanup_hook = CleanupHook(set(["grad_batch"]))
+            ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
+        else:
+            ext_hook = compute_hook
 
         save_memory = cp.BACKPACK_CONV_SAVE_MEMORY
 
@@ -124,6 +140,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
             print(f" ↪Hooks        : {ext_hook}")
             print(f" ↪Create graph : {cp.create_graph(global_step)}")
             print(f" ↪Save memory  : {save_memory}")
+            print(f" ↪Protect      : {self.protected_savefields}")
 
     def __enter__(self):
         """Enter cockpit context(s)."""
@@ -138,3 +155,33 @@ def __exit__(self, type, value, traceback):
         self.cp.track(self.global_step, protected_savefields=self.protected_savefields)
 
         CockpitCTX.erase()
+
+    @staticmethod
+    def _combine_hooks(
+        *hooks: Union[Callable[[Module], Any], None]
+    ) -> Union[Callable[[Module], None], None]:
+        """Combine multiple extension hooks into a single one.
+
+        Args:
+            hooks: Extension hooks to be combined.
+
+        Returns:
+            Merged hook, or ``None`` if all passed hooks were ``None``.
+        """
+        non_empty_hooks = [h for h in hooks if h is not None]
+
+        if non_empty_hooks:
+
+            def hook(module: Module):
+                """Sequentially execute all hooks on the module.
+
+                Args:
+                    module: Module to run the extension hook on.
+                """
+                for h in non_empty_hooks:
+                    h(module)
+
+            return hook
+
+        else:
+            return None
diff --git a/cockpit/quantities/hooks/cleanup.py b/cockpit/quantities/hooks/cleanup.py
new file mode 100644
index 0000000..8c6f99c
--- /dev/null
+++ b/cockpit/quantities/hooks/cleanup.py
@@ -0,0 +1,30 @@
+"""Contains hook that deletes BackPACK buffers during backpropagation."""
+
+from typing import Set
+
+from torch import Tensor
+
+from cockpit.quantities.hooks.base import ParameterExtensionHook
+
+
+class CleanupHook(ParameterExtensionHook):
+    """Deletes specified BackPACK buffers during backpropagation."""
+
+    def __init__(self, delete_savefields: Set[str]):
+        """Store savefields to be deleted in the backward pass.
+
+        Args:
+            delete_savefields: Names of the buffers to delete.
+        """
+        super().__init__()
+        self._delete_savefields = delete_savefields
+
+    def param_hook(self, param: Tensor):
+        """Delete BackPACK buffers in parameter.
+
+        Args:
+            param: Trainable parameter which hosts BackPACK quantities.
+        """
+        for savefield in self._delete_savefields:
+            if hasattr(param, savefield):
+                delattr(param, savefield)
diff --git a/cockpit/quantities/utils_transforms.py b/cockpit/quantities/utils_transforms.py
index 3c0e67d..e0517ad 100644
--- a/cockpit/quantities/utils_transforms.py
+++ b/cockpit/quantities/utils_transforms.py
@@ -3,7 +3,7 @@
 import string
 import weakref
 
-from torch import einsum
+from torch import Tensor, einsum
 
 from cockpit.quantities.hooks.base import ParameterExtensionHook
 
@@ -67,20 +67,14 @@ def __init__(self, transforms, savefield=None):
         super().__init__(savefield=savefield)
         self._transforms = transforms
 
-    def param_hook(self, param):
+    def param_hook(self, param: Tensor):
         """Execute all transformations and store results as dictionary in the parameter.
 
-        Delete individual gradients in the parameter.
-
         Args:
-            param (torch.Tensor): Trainable parameter which hosts BackPACK quantities.
+            param: Trainable parameter which hosts BackPACK quantities.
         """
         param.grad_batch._param_weakref = weakref.ref(param)
-        # TODO Delete after backward pass with Cockpit
         param.grad_batch_transforms = {
             key: func(param.grad_batch) for key, func in self._transforms.items()
         }
-        # TODO Delete with a separate hook that also knows which savefield should be
-        # kept because it's protected by the user. See
-        # https://github.com/f-dangel/cockpit-paper/issues/197
-        del param.grad_batch
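The merged-hook semantics of ``_combine_hooks`` above, reduced to a
self-contained sketch (standard library only; all names are illustrative):
hooks run in order, ``None`` entries are skipped, and an all-``None`` input
collapses to ``None`` so that no hook is registered at all.

    from typing import Any, Callable, Optional


    def combine(
        *hooks: Optional[Callable[[Any], None]]
    ) -> Optional[Callable[[Any], None]]:
        non_empty = [h for h in hooks if h is not None]
        if not non_empty:
            return None

        def merged(target: Any) -> None:
            # Preserve call order: compute hooks first, cleanup hooks last.
            for h in non_empty:
                h(target)

        return merged


    log = []
    hook = combine(
        None,
        lambda m: log.append(f"compute({m})"),
        lambda m: log.append(f"cleanup({m})"),
    )
    hook("linear")
    assert log == ["compute(linear)", "cleanup(linear)"]
    assert combine(None, None) is None

The ordering matters for this patch: ``BatchGradTransformsHook`` must populate
``grad_batch_transforms`` before ``CleanupHook`` deletes ``grad_batch``.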
From 16abca47411a9f5ab76197c77159bf0816540c80 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:02:44 +0200
Subject: [PATCH 06/11] [FIX] flake8

---
 cockpit/context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cockpit/context.py b/cockpit/context.py
index a63cb8c..fcaf796 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -104,7 +104,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = set(e.savefield for e in custom_exts)
+        self.protected_savefields = {e.savefield for e in custom_exts}
         self.protected_savefields.update(cp._get_protected_savefields(global_step))
 
         # choose context
@@ -118,7 +118,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         ):
             # TODO Implement all quantities with hooks and specify protected savefields
             # Remove unprotected buffers in this cleanup hook during backpropagation
-            cleanup_hook = CleanupHook(set(["grad_batch"]))
+            cleanup_hook = CleanupHook({"grad_batch"})
             ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
         else:
             ext_hook = compute_hook

From 6803a8a4590849f24f62f729d857a689321301fd Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:10:21 +0200
Subject: [PATCH 07/11] [FIX] Protect `'grad_batch'` if step is unknown

---
 cockpit/quantities/alpha.py    |  9 +++++++++
 tests/test_bugs/test_issue5.py | 21 ++++++++++-----------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/cockpit/quantities/alpha.py b/cockpit/quantities/alpha.py
index 5590ba2..5e39774 100644
--- a/cockpit/quantities/alpha.py
+++ b/cockpit/quantities/alpha.py
@@ -2,6 +2,7 @@
 
 import itertools
 import warnings
+from typing import Set
 
 import numpy as np
 import torch
@@ -89,6 +90,14 @@ def extension_hooks(self, global_step):
 
         return hooks
 
+    def protected_savefields(self, global_step: int) -> Set[str]:
+        protected = set()
+        if self.is_start(global_step) or self.is_end(global_step):
+            if not self.__projection_with_backpack(global_step):
+                protected.update(["grad_batch"])
+
+        return protected
+
     def is_start(self, global_step):
         """Return whether current iteration is start point.
 
diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index 10f6470..c0a61d1 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -44,14 +44,13 @@ def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
     losses = individual_loss_fn(outputs, labels)
 
     # backward pass
-    with raises(AttributeError):
-        with cockpit(
-            global_step,
-            info={
-                "batch_size": N,
-                "individual_losses": losses,
-                "loss": loss,
-                "optimizer": opt_not_sgd,
-            },
-        ):
-            loss.backward(create_graph=cockpit.create_graph(global_step))
+    with cockpit(
+        global_step,
+        info={
+            "batch_size": N,
+            "individual_losses": losses,
+            "loss": loss,
+            "optimizer": opt_not_sgd,
+        },
+    ):
+        loss.backward(create_graph=cockpit.create_graph(global_step))
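What the protection from PATCH 07 means at the parameter level, as a sketch
that drives ``CleanupHook`` directly (assumes this series is applied; the
random tensor is a stand-in for a real BackPACK buffer, and the set difference
mirrors the guard in ``context.py``, which skips the cleanup hook entirely
when ``'grad_batch'`` is protected):

    from torch import rand
    from torch.nn import Linear

    from cockpit.quantities.hooks.cleanup import CleanupHook

    param = Linear(3, 1).weight
    param.grad_batch = rand(2, 1, 3)  # stand-in for BackPACK's individual gradients

    protected = {"grad_batch"}  # e.g. returned by Alpha.protected_savefields
    CleanupHook({"grad_batch"} - protected).param_hook(param)
    assert hasattr(param, "grad_batch")  # kept alive for Alpha

    CleanupHook({"grad_batch"}).param_hook(param)
    assert not hasattr(param, "grad_batch")  # freed once nothing protects it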
From 431f67f6bef19c8601446ab8d24a4fb9a90f717f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:14:20 +0200
Subject: [PATCH 08/11] [DEL] Remove unused import

---
 tests/test_bugs/test_issue5.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index c0a61d1..9fb6841 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -1,7 +1,6 @@
 """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
 
 from backpack import extend
-from pytest import raises
 from torch import manual_seed, rand
 from torch.nn import Flatten, Linear, MSELoss, Sequential
 from torch.optim import Adam

From 732a86a67c403b9331bd921186b8ca6862a3411f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:19:21 +0200
Subject: [PATCH 09/11] [DOC] Ignore missing docstring (same as parent)

---
 cockpit/quantities/alpha.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cockpit/quantities/alpha.py b/cockpit/quantities/alpha.py
index 5e39774..474a0df 100644
--- a/cockpit/quantities/alpha.py
+++ b/cockpit/quantities/alpha.py
@@ -90,7 +90,7 @@ def extension_hooks(self, global_step):
 
         return hooks
 
-    def protected_savefields(self, global_step: int) -> Set[str]:
+    def protected_savefields(self, global_step: int) -> Set[str]:  # noqa: D102
         protected = set()
         if self.is_start(global_step) or self.is_end(global_step):
             if not self.__projection_with_backpack(global_step):

From be7f567542e3e3e4b75fbf5ff4d673d921ab1dcd Mon Sep 17 00:00:00 2001
From: Frank Schneider <frank.stefan.schneider@gmail.com>
Date: Wed, 20 Oct 2021 15:11:41 +0200
Subject: [PATCH 10/11] [DOC] Link to cockpit-experiments repo

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8b9ffbd..05a5310 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@
 <p align="center">
   <a href="#installation">Installation</a> •
   <a href="https://cockpit.readthedocs.io/">Docs</a> •
+  <a href="https://github.com/fsschneider/cockpit-experiments">Experiments</a> •
   <a href="#license">License</a> •
   <a href="#citation">Citation</a>
 </p>
@@ -60,7 +61,7 @@ The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on
 <!-- Experiments -->
 ## Experiments
 
-To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
+To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. The code for the experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
 
 <!-- LICENSE -->
 ## License

From c00852e8bac3561bb7971a751588d52e5040fb78 Mon Sep 17 00:00:00 2001
From: Frank Schneider <frank.stefan.schneider@gmail.com>
Date: Tue, 26 Oct 2021 13:39:33 +0200
Subject: [PATCH 11/11] [DOC] Changelog for v1.0.2

---
 CHANGELOG.md | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9b8a6e..428cb40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ## [Unreleased]
 
+## [1.0.2] - 2021-10-26
+
+### Added
+
+- Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper.
+
+### Fixed
+
+- Protects the `grad_batch` field when a non-SGD optimizer is used together with other quantities that free `grad_batch` to save memory. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)]
+
 ## [1.0.1] - 2021-10-13
 
 From this version on, `cockpit` will be available as `cockpit-for-pytorch` on
@@ -30,6 +40,7 @@ PyPI.
 
 - First public release version of **Cockpit**.
 
-[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.1...HEAD
+[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD
+[1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2
 [1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1
 [1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0
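For reference, the user-facing behavior this series restores, condensed from
the issue #5 test above (assumes cockpit with these patches, plus torch and
backpack installed): a non-SGD optimizer combined with a quantity that uses
``BatchGradTransformsHook`` now completes the backward pass.

    from backpack import extend
    from torch import manual_seed, rand
    from torch.nn import Linear, MSELoss
    from torch.optim import Adam

    from cockpit import Cockpit
    from cockpit.quantities import Alpha, GradHist1d
    from cockpit.utils.schedules import linear

    manual_seed(0)
    N, D_in, D_out = 2, 3, 1
    model = extend(Linear(D_in, D_out))
    loss_fn = extend(MSELoss(reduction="mean"))
    individual_loss_fn = MSELoss(reduction="none")

    schedule = linear(1)
    cockpit = Cockpit(
        model.parameters(), quantities=[Alpha(schedule), GradHist1d(schedule)]
    )
    adam = Adam(model.parameters(), lr=1e-3)

    inputs, labels = rand(N, D_in), rand(N, D_out)
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)

    # Alpha now protects 'grad_batch', so GradHist1d's cleanup no longer
    # deletes it and the backward pass finishes without an AttributeError.
    with cockpit(
        0,
        info={
            "batch_size": N,
            "individual_losses": individual_loss_fn(outputs, labels),
            "loss": loss,
            "optimizer": adam,
        },
    ):
        loss.backward(create_graph=cockpit.create_graph(0))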