From beaaeda3119da085549d77285cf3bb8d1b614848 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Tue, 6 Jul 2021 14:29:36 +0200
Subject: [PATCH 01/11] [BUG] Reproduce bug described in #5

---
 tests/test_bugs/test_issue5.py | 55 ++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 tests/test_bugs/test_issue5.py

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
new file mode 100644
index 0000000..9fb6841
--- /dev/null
+++ b/tests/test_bugs/test_issue5.py
@@ -0,0 +1,55 @@
+"""Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
+
+from backpack import extend
+from torch import manual_seed, rand
+from torch.nn import Flatten, Linear, MSELoss, Sequential
+from torch.optim import Adam
+
+from cockpit import Cockpit
+from cockpit.quantities import Alpha, GradHist1d
+from cockpit.utils.schedules import linear
+
+
+def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
+    """If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
+
+    But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
+    it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
+    ``Alpha`` cannot access the attribute anymore, which raises an ``AttributeError``.
+    """
+    manual_seed(0)
+
+    N, D_in, D_out = 2, 3, 1
+    model = extend(Sequential(Flatten(), Linear(D_in, D_out)))
+
+    opt_not_sgd = Adam(model.parameters(), lr=1e-3)
+    loss_fn = extend(MSELoss(reduction="mean"))
+    individual_loss_fn = MSELoss(reduction="none")
+
+    on_first = linear(1)
+    alpha = Alpha(on_first)
+    uses_BatchGradTransformsHook = GradHist1d(on_first)
+
+    cockpit = Cockpit(
+        model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook]
+    )
+
+    global_step = 0
+    inputs, labels = rand(N, D_in), rand(N, D_out)
+
+    # forward pass
+    outputs = model(inputs)
+    loss = loss_fn(outputs, labels)
+    losses = individual_loss_fn(outputs, labels)
+
+    # backward pass
+    with cockpit(
+        global_step,
+        info={
+            "batch_size": N,
+            "individual_losses": losses,
+            "loss": loss,
+            "optimizer": opt_not_sgd,
+        },
+    ):
+        loss.backward(create_graph=cockpit.create_graph(global_step))

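For readers unfamiliar with the mechanism this test exercises: BackPACK's
``BatchGrad`` extension stores per-sample gradients in a ``grad_batch``
attribute on each parameter during the backward pass, and ``Alpha`` reads
exactly that buffer. A minimal standalone sketch of the savefield (not part
of the patch, assuming backpack-for-pytorch is installed):

    from backpack import backpack, extend
    from backpack.extensions import BatchGrad
    from torch import manual_seed, rand
    from torch.nn import Linear, MSELoss

    manual_seed(0)
    N, D_in, D_out = 2, 3, 1
    model = extend(Linear(D_in, D_out))
    loss = extend(MSELoss())(model(rand(N, D_in)), rand(N, D_out))

    with backpack(BatchGrad()):
        loss.backward()

    for param in model.parameters():
        # One gradient per sample: shape [N, *param.shape].
        print(param.grad_batch.shape)

If another extension hook deletes ``grad_batch`` during backpropagation, the
attribute access above is exactly what fails for ``Alpha``.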
From 8b7cc0009656ac2f59406dbbf618a8c2bbdd1d8a Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 14:38:46 +0200
Subject: [PATCH 02/11] [REF] Remove dead code

---
 cockpit/cockpit.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/cockpit/cockpit.py b/cockpit/cockpit.py
index 9693c0a..6a154e2 100644
--- a/cockpit/cockpit.py
+++ b/cockpit/cockpit.py
@@ -219,16 +219,6 @@ def _free_backpack_buffers(self, global_step, protected_savefields, verbose=Fals
                 except AttributeError:
                     pass
 
-    @staticmethod
-    def _remove_module_io(module):
-        io_fields = ["input0", "output"]
-
-        for field in io_fields:
-            try:
-                delattr(module, field)
-            except AttributeError:
-                pass
-
     def log(
         self,
         global_step,

From 9cf4b96278c0d815b1eb7a747357fcd57bc3b323 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:28:45 +0200
Subject: [PATCH 03/11] [FIX] Catch AttributeError in bug reproduction

---
 tests/test_bugs/test_issue5.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index 9fb6841..10f6470 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -1,6 +1,7 @@
 """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
 
 from backpack import extend
+from pytest import raises
 from torch import manual_seed, rand
 from torch.nn import Flatten, Linear, MSELoss, Sequential
 from torch.optim import Adam
@@ -43,13 +44,14 @@ def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
     losses = individual_loss_fn(outputs, labels)
 
     # backward pass
-    with cockpit(
-        global_step,
-        info={
-            "batch_size": N,
-            "individual_losses": losses,
-            "loss": loss,
-            "optimizer": opt_not_sgd,
-        },
-    ):
-        loss.backward(create_graph=cockpit.create_graph(global_step))
+    with raises(AttributeError):
+        with cockpit(
+            global_step,
+            info={
+                "batch_size": N,
+                "individual_losses": losses,
+                "loss": loss,
+                "optimizer": opt_not_sgd,
+            },
+        ):
+            loss.backward(create_graph=cockpit.create_graph(global_step))

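Until the actual fix lands, the reproduction is pinned with ``pytest.raises``
so the test suite stays green while documenting the failure. The pattern in
isolation (a generic sketch, not taken from the patch):

    from pytest import raises

    def test_known_failure_is_pinned():
        # Passes only if the block raises AttributeError.
        with raises(AttributeError):
            object().does_not_exist

Patch 07 below drops this guard again once the underlying bug is fixed.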
From 13a05749e3b9edb7a68d2a413dd6aaf7c27048f9 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:32:16 +0200
Subject: [PATCH 04/11] [ADD] Allow quantities to protect buffers

---
 cockpit/cockpit.py             | 17 +++++++++++++++++
 cockpit/context.py             |  4 +++-
 cockpit/quantities/quantity.py | 14 ++++++++++++++
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/cockpit/cockpit.py b/cockpit/cockpit.py
index 6a154e2..d9aa649 100644
--- a/cockpit/cockpit.py
+++ b/cockpit/cockpit.py
@@ -2,6 +2,7 @@
 
 import os
 from collections import defaultdict
+from typing import Set
 
 import json_tricks
 from backpack import disable
@@ -130,6 +131,22 @@ def _get_extension_hook(self, global_step):
 
         return hook
 
+    def _get_protected_savefields(self, global_step: int) -> Set[str]:
+        """Return names of protected BackPACK buffers.
+
+        Args:
+            global_step: Current iteration number.
+
+        Returns:
+            Set of protected savefields.
+        """
+        protected = set()
+
+        for q in self.quantities:
+            protected.update(q.protected_savefields(global_step))
+
+        return protected
+
     def __call__(self, global_step, *exts, info=None, debug=False):
         """Returns the backpack extensions that should be used in this iteration.
 
diff --git a/cockpit/context.py b/cockpit/context.py
index b2a9d4d..3d1f43d 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -99,7 +99,9 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = [e.savefield for e in custom_exts]
+        self.protected_savefields = set(
+            e.savefield for e in custom_exts
+        ) | cp._get_protected_savefields(global_step)
 
         # choose context
         ext = cp._get_extensions(global_step, custom_exts=custom_exts)
diff --git a/cockpit/quantities/quantity.py b/cockpit/quantities/quantity.py
index 775134c..6baebdb 100644
--- a/cockpit/quantities/quantity.py
+++ b/cockpit/quantities/quantity.py
@@ -1,6 +1,7 @@
 """Base class for a tracked quantity."""
 
 from collections import defaultdict
+from typing import Set
 
 import numpy
 import torch
@@ -74,6 +75,19 @@ def extension_hooks(self, global_step):
         """
         return []
 
+    def protected_savefields(self, global_step: int) -> Set[str]:
+        """Return list of protected BackPACK buffers at the current step.
+
+        Protected buffers will not be freed during a backward pass.
+
+        Args:
+            global_step: The current iteration number.
+
+        Returns:
+            Set of protected savefields.
+        """
+        return set()
+
     def track(self, global_step, params, batch_loss):
         """Perform scheduled computations and store result.
 

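The new API is additive: every quantity reports a (possibly empty) set of
savefields it needs to survive the backward pass, and
``Cockpit._get_protected_savefields`` takes the union over all quantities.
A hypothetical quantity illustrating the override (class name and usage are
illustrative, not part of the patch):

    from typing import Set

    class NeedsPerSampleGradients:
        """Hypothetical quantity that always protects ``grad_batch``."""

        def protected_savefields(self, global_step: int) -> Set[str]:
            return {"grad_batch"}

    # Mirrors Cockpit._get_protected_savefields: union over all quantities.
    quantities = [NeedsPerSampleGradients()]
    protected = set()
    for q in quantities:
        protected.update(q.protected_savefields(global_step=0))
    print(protected)  # {'grad_batch'}

Since the base implementation returns ``set()``, existing quantities are
unaffected unless they opt in.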
From 6b88c3264a2e0eeefbbcbd17d84f294aad528d7f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 15:46:15 +0200
Subject: [PATCH 05/11] [REF] Extract deletion of 'grad_batch' into a
 separate hook

---
 cockpit/context.py                     | 55 ++++++++++++++++++++++++--
 cockpit/quantities/hooks/cleanup.py    | 30 ++++++++++++++
 cockpit/quantities/utils_transforms.py | 12 ++----
 3 files changed, 84 insertions(+), 13 deletions(-)
 create mode 100644 cockpit/quantities/hooks/cleanup.py

diff --git a/cockpit/context.py b/cockpit/context.py
index 3d1f43d..a63cb8c 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -1,9 +1,14 @@
 """Cockpit Context."""
 
 import warnings
+from typing import Any, Callable, Union
 
 from backpack import backpack, disable
 from backpack.core.derivatives.convnd import weight_jac_t_save_memory
+from torch.nn import Module
+
+from cockpit.quantities.hooks.cleanup import CleanupHook
+from cockpit.quantities.utils_transforms import BatchGradTransformsHook
 
 
 class CockpitCTX:
@@ -99,13 +104,24 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = set(
-            e.savefield for e in custom_exts
-        ) | cp._get_protected_savefields(global_step)
+        self.protected_savefields = set(e.savefield for e in custom_exts)
+        self.protected_savefields.update(cp._get_protected_savefields(global_step))
 
         # choose context
         ext = cp._get_extensions(global_step, custom_exts=custom_exts)
-        ext_hook = cp._get_extension_hook(global_step)
+        compute_hook = cp._get_extension_hook(global_step)
+
+        # Delete 'grad_batch' during backprop if it's not protected
+        if (
+            isinstance(compute_hook, BatchGradTransformsHook)
+            and "grad_batch" not in self.protected_savefields
+        ):
+            # TODO Implement all quantities with hooks and specify protected savefields
+            # Remove unprotected buffers in this cleanup hook during backpropagation
+            cleanup_hook = CleanupHook(set(["grad_batch"]))
+            ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
+        else:
+            ext_hook = compute_hook
 
         save_memory = cp.BACKPACK_CONV_SAVE_MEMORY
 
@@ -124,6 +140,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
             print(f" ↪Hooks       : {ext_hook}")
             print(f" ↪Create graph: {cp.create_graph(global_step)}")
             print(f" ↪Save memory : {save_memory}")
+            print(f" ↪Protect     : {self.protected_savefields}")
 
     def __enter__(self):
         """Enter cockpit context(s)."""
@@ -138,3 +155,33 @@ def __exit__(self, type, value, traceback):
         self.cp.track(self.global_step, protected_savefields=self.protected_savefields)
 
         CockpitCTX.erase()
+
+    @staticmethod
+    def _combine_hooks(
+        *hooks: Union[Callable[[Module], Any], None]
+    ) -> Union[Callable[[Module], None], None]:
+        """Combine multiple extension hooks into a single one.
+
+        Args:
+            hooks: List of extension hooks to be combined.
+
+        Returns:
+            Merged hook. ``None`` if all passed hooks were ``None``.
+        """
+        non_empty_hooks = [h for h in hooks if h is not None]
+
+        if non_empty_hooks:
+
+            def hook(module: Module):
+                """Sequentially execute all hooks on the module.
+
+                Args:
+                    module: Module to run the extension hook on.
+                """
+                for h in non_empty_hooks:
+                    h(module)
+
+            return hook
+
+        else:
+            return None
diff --git a/cockpit/quantities/hooks/cleanup.py b/cockpit/quantities/hooks/cleanup.py
new file mode 100644
index 0000000..8c6f99c
--- /dev/null
+++ b/cockpit/quantities/hooks/cleanup.py
@@ -0,0 +1,30 @@
+"""Contains hook that deletes BackPACK buffers during backpropagation."""
+
+from typing import Set
+
+from torch import Tensor
+
+from cockpit.quantities.hooks.base import ParameterExtensionHook
+
+
+class CleanupHook(ParameterExtensionHook):
+    """Deletes specified BackPACK buffers during backpropagation."""
+
+    def __init__(self, delete_savefields: Set[str]):
+        """Store savefields to be deleted in the backward pass.
+
+        Args:
+            delete_savefields: Names of the buffers to delete.
+        """
+        super().__init__()
+        self._delete_savefields = delete_savefields
+
+    def param_hook(self, param: Tensor):
+        """Delete BackPACK buffers in parameter.
+
+        Args:
+            param: Trainable parameter which hosts BackPACK quantities.
+        """
+        for savefield in self._delete_savefields:
+            if hasattr(param, savefield):
+                delattr(param, savefield)
diff --git a/cockpit/quantities/utils_transforms.py b/cockpit/quantities/utils_transforms.py
index 3c0e67d..e0517ad 100644
--- a/cockpit/quantities/utils_transforms.py
+++ b/cockpit/quantities/utils_transforms.py
@@ -3,7 +3,7 @@
 import string
 import weakref
 
-from torch import einsum
+from torch import Tensor, einsum
 
 from cockpit.quantities.hooks.base import ParameterExtensionHook
 
@@ -67,20 +67,14 @@ def __init__(self, transforms, savefield=None):
         super().__init__(savefield=savefield)
         self._transforms = transforms
 
-    def param_hook(self, param):
+    def param_hook(self, param: Tensor):
         """Execute all transformations and store results as dictionary in the parameter.
 
-        Delete individual gradients in the parameter.
-
         Args:
-            param (torch.Tensor): Trainable parameter which hosts BackPACK quantities.
+            param: Trainable parameter which hosts BackPACK quantities.
         """
         param.grad_batch._param_weakref = weakref.ref(param)
         # TODO Delete after backward pass with Cockpit
         param.grad_batch_transforms = {
             key: func(param.grad_batch) for key, func in self._transforms.items()
         }
-        # TODO Delete with a separate hook that also knows which savefield should be
-        # kept because it's protected by the user. See
-        # https://github.com/f-dangel/cockpit-paper/issues/197
-        del param.grad_batch

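``_combine_hooks`` guarantees that the cleanup runs strictly after the
compute hook on each module, so ``grad_batch_transforms`` is written before
``grad_batch`` is freed. The composition pattern in isolation (a sketch
under the same ``None``-tolerant contract, not taken verbatim from the
patch):

    from typing import Any, Callable, Optional

    from torch.nn import Linear, Module

    def combine(
        *hooks: Optional[Callable[[Module], Any]]
    ) -> Optional[Callable[[Module], None]]:
        """Sequentially compose hooks, skipping ``None`` entries."""
        non_empty = [h for h in hooks if h is not None]
        if not non_empty:
            return None

        def hook(module: Module) -> None:
            for h in non_empty:  # argument order: compute first, cleanup last
                h(module)

        return hook

    combined = combine(lambda m: print("compute"), None, lambda m: print("cleanup"))
    combined(Linear(2, 2))  # prints "compute", then "cleanup"

Ordering matters: passing the cleanup hook first would delete ``grad_batch``
before ``BatchGradTransformsHook`` reads it.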
From 16abca47411a9f5ab76197c77159bf0816540c80 Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:02:44 +0200
Subject: [PATCH 06/11] [FIX] flake8

---
 cockpit/context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cockpit/context.py b/cockpit/context.py
index a63cb8c..fcaf796 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -104,7 +104,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = set(e.savefield for e in custom_exts)
+        self.protected_savefields = {e.savefield for e in custom_exts}
         self.protected_savefields.update(cp._get_protected_savefields(global_step))
 
         # choose context
@@ -118,7 +118,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         ):
             # TODO Implement all quantities with hooks and specify protected savefields
             # Remove unprotected buffers in this cleanup hook during backpropagation
-            cleanup_hook = CleanupHook(set(["grad_batch"]))
+            cleanup_hook = CleanupHook({"grad_batch"})
             ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
         else:
             ext_hook = compute_hook

From 6803a8a4590849f24f62f729d857a689321301fd Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:10:21 +0200
Subject: [PATCH 07/11] [FIX] Protect `'grad_batch'` if update step is unknown

---
 cockpit/quantities/alpha.py    |  9 +++++++++
 tests/test_bugs/test_issue5.py | 21 ++++++++++-----------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/cockpit/quantities/alpha.py b/cockpit/quantities/alpha.py
index 5590ba2..5e39774 100644
--- a/cockpit/quantities/alpha.py
+++ b/cockpit/quantities/alpha.py
@@ -2,6 +2,7 @@
 
 import itertools
 import warnings
+from typing import Set
 
 import numpy as np
 import torch
@@ -89,6 +90,14 @@ def extension_hooks(self, global_step):
 
         return hooks
 
+    def protected_savefields(self, global_step: int) -> Set[str]:
+        protected = set()
+        if self.is_start(global_step) or self.is_end(global_step):
+            if not self.__projection_with_backpack(global_step):
+                protected.update(["grad_batch"])
+
+        return protected
+
     def is_start(self, global_step):
         """Return whether current iteration is start point.
 
diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index 10f6470..c0a61d1 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -44,14 +44,13 @@ def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
     losses = individual_loss_fn(outputs, labels)
 
     # backward pass
-    with raises(AttributeError):
-        with cockpit(
-            global_step,
-            info={
-                "batch_size": N,
-                "individual_losses": losses,
-                "loss": loss,
-                "optimizer": opt_not_sgd,
-            },
-        ):
-            loss.backward(create_graph=cockpit.create_graph(global_step))
+    with cockpit(
+        global_step,
+        info={
+            "batch_size": N,
+            "individual_losses": losses,
+            "loss": loss,
+            "optimizer": opt_not_sgd,
+        },
+    ):
+        loss.backward(create_graph=cockpit.create_graph(global_step))

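The rule ``Alpha`` now implements: ``grad_batch`` only has to survive the
backward pass at a start or end point of the tracked interval, and only when
the projection onto the update direction cannot be computed inside BackPACK
(i.e. the update step is unknown, as with Adam). Condensed into a standalone
predicate (illustrative only, not part of the patch):

    def alpha_protects_grad_batch(
        is_start: bool, is_end: bool, projection_with_backpack: bool
    ) -> bool:
        """Illustrative mirror of Alpha.protected_savefields."""
        return (is_start or is_end) and not projection_with_backpack

With the buffer protected, the ``pytest.raises(AttributeError)`` guard from
patch 03 is removed and the reproduction becomes a plain regression test.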
From 431f67f6bef19c8601446ab8d24a4fb9a90f717f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:14:20 +0200
Subject: [PATCH 08/11] [DEL] Remove unused import

---
 tests/test_bugs/test_issue5.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
index c0a61d1..9fb6841 100644
--- a/tests/test_bugs/test_issue5.py
+++ b/tests/test_bugs/test_issue5.py
@@ -1,7 +1,6 @@
 """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
 
 from backpack import extend
-from pytest import raises
 from torch import manual_seed, rand
 from torch.nn import Flatten, Linear, MSELoss, Sequential
 from torch.optim import Adam

From 732a86a67c403b9331bd921186b8ca6862a3411f Mon Sep 17 00:00:00 2001
From: Felix Dangel <felix.dangel@tuebingen.mpg.de>
Date: Fri, 15 Oct 2021 16:19:21 +0200
Subject: [PATCH 09/11] [DOC] Ignore missing docstring (same as parent)

---
 cockpit/quantities/alpha.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cockpit/quantities/alpha.py b/cockpit/quantities/alpha.py
index 5e39774..474a0df 100644
--- a/cockpit/quantities/alpha.py
+++ b/cockpit/quantities/alpha.py
@@ -90,7 +90,7 @@ def extension_hooks(self, global_step):
 
         return hooks
 
-    def protected_savefields(self, global_step: int) -> Set[str]:
+    def protected_savefields(self, global_step: int) -> Set[str]:  # noqa: D102
         protected = set()
         if self.is_start(global_step) or self.is_end(global_step):
             if not self.__projection_with_backpack(global_step):

From be7f567542e3e3e4b75fbf5ff4d673d921ab1dcd Mon Sep 17 00:00:00 2001
From: Frank Schneider <frank.stefan.schneider@gmail.com>
Date: Wed, 20 Oct 2021 15:11:41 +0200
Subject: [PATCH 10/11] [DOC] Link to cockpit-experiments repo

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8b9ffbd..05a5310 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@
 <p align="center">
   <a href="#installation">Installation</a> •
   <a href="https://cockpit.readthedocs.io/">Docs</a> •
+  <a href="https://github.com/fsschneider/cockpit-experiments">Experiments</a> •
   <a href="#license">License</a> •
   <a href="#citation">Citation</a>
 </p>
@@ -60,7 +61,7 @@ The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on
 <!-- Experiments -->
 ## Experiments
 
-To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
+To showcase the capabilities of **Cockpit**, we performed several experiments illustrating the usefulness of our debugging tool. The code for these experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments, please refer to our [paper](https://arxiv.org/abs/2102.06604).
 
 <!-- LICENSE -->
 ## License

From c00852e8bac3561bb7971a751588d52e5040fb78 Mon Sep 17 00:00:00 2001
From: Frank Schneider <frank.stefan.schneider@gmail.com>
Date: Tue, 26 Oct 2021 13:39:33 +0200
Subject: [PATCH 11/11] [DOC] Changelog for v1.0.2

---
 CHANGELOG.md | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9b8a6e..428cb40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ## [Unreleased]
 
+## [1.0.2] - 2021-10-26
+
+### Added
+
+- Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper.
+
+### Fixed
+
+- Protects the `grad_batch` field when a non-SGD optimizer is used together with other quantities that free `grad_batch` to save memory. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)]
+
 ## [1.0.1] - 2021-10-13
 
 From this version on, `cockpit` will be available as `cockpit-for-pytorch` on
@@ -30,6 +40,7 @@ PyPI.
 
 - First public release version of **Cockpit**.
 
-[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.1...HEAD
+[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD
+[1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2
 [1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1
 [1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0