diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9b8a6e..428cb40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ## [Unreleased]
 
+## [1.0.2] - 2021-10-26
+
+### Added
+
+- Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper.
+
+### Fixed
+
+- Protects the `batch_grad` field in the case where non-SGD is used together with other quantities that free `batch_grad` for memory performance. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)]
+
 ## [1.0.1] - 2021-10-13
 
 From this version on, `cockpit` will be available as `cockpit-for-pytorch` on
@@ -30,6 +40,7 @@ PyPI.
 
 - First public release version of **Cockpit**.
 
-[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.1...HEAD
+[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD
+[1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2
 [1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1
 [1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0
diff --git a/README.md b/README.md
index 8b9ffbd..05a5310 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@

   Installation •
   Docs •
+  Experiments •
   License •
   Citation

@@ -60,7 +61,7 @@ The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on
 
 ## Experiments
 
-To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
+To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. The code for the experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
 
 ## License
 
diff --git a/cockpit/cockpit.py b/cockpit/cockpit.py
index 9693c0a..d9aa649 100644
--- a/cockpit/cockpit.py
+++ b/cockpit/cockpit.py
@@ -2,6 +2,7 @@
 
 import os
 from collections import defaultdict
+from typing import Set
 
 import json_tricks
 from backpack import disable
@@ -130,6 +131,22 @@ def _get_extension_hook(self, global_step):
 
         return hook
 
+    def _get_protected_savefields(self, global_step: int) -> Set[str]:
+        """Return names of protected BackPACK buffers.
+
+        Args:
+            global_step: Current iteration number.
+
+        Returns:
+            List of protected buffers.
+        """
+        protected = set()
+
+        for q in self.quantities:
+            protected.update(q.protected_savefields(global_step))
+
+        return protected
+
     def __call__(self, global_step, *exts, info=None, debug=False):
         """Returns the backpack extensions that should be used in this iteration.
 
@@ -219,16 +236,6 @@ def _free_backpack_buffers(self, global_step, protected_savefields, verbose=Fals
                 except AttributeError:
                     pass
 
-    @staticmethod
-    def _remove_module_io(module):
-        io_fields = ["input0", "output"]
-
-        for field in io_fields:
-            try:
-                delattr(module, field)
-            except AttributeError:
-                pass
-
     def log(
         self,
         global_step,
diff --git a/cockpit/context.py b/cockpit/context.py
index b2a9d4d..fcaf796 100644
--- a/cockpit/context.py
+++ b/cockpit/context.py
@@ -1,9 +1,14 @@
 """Cockpit Context."""
 
 import warnings
+from typing import Any, Callable, Union
 
 from backpack import backpack, disable
 from backpack.core.derivatives.convnd import weight_jac_t_save_memory
+from torch.nn import Module
+
+from cockpit.quantities.hooks.cleanup import CleanupHook
+from cockpit.quantities.utils_transforms import BatchGradTransformsHook
 
 
 class CockpitCTX:
@@ -99,11 +104,24 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
         self.cp = cp
         self.global_step = global_step
 
-        self.protected_savefields = [e.savefield for e in custom_exts]
+        self.protected_savefields = {e.savefield for e in custom_exts}
+        self.protected_savefields.update(cp._get_protected_savefields(global_step))
 
         # choose context
         ext = cp._get_extensions(global_step, custom_exts=custom_exts)
-        ext_hook = cp._get_extension_hook(global_step)
+        compute_hook = cp._get_extension_hook(global_step)
+
+        # Delete 'grad_batch' during backprop if it's not protected
+        if (
+            isinstance(compute_hook, BatchGradTransformsHook)
+            and "grad_batch" not in self.protected_savefields
+        ):
+            # TODO Implement all quantities with hooks and specify protected savefields
+            # Remove unprotected buffers in this cleanup hook during backpropagation
+            cleanup_hook = CleanupHook({"grad_batch"})
+            ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
+        else:
+            ext_hook = compute_hook
 
         save_memory = cp.BACKPACK_CONV_SAVE_MEMORY
 
@@ -122,6 +140,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
             print(f" ↪Hooks       : {ext_hook}")
print(f" ↪Create graph: {cp.create_graph(global_step)}") print(f" ↪Save memory : {save_memory}") + print(f" ↪Protect : {self.protected_savefields}") def __enter__(self): """Enter cockpit context(s).""" @@ -136,3 +155,33 @@ def __exit__(self, type, value, traceback): self.cp.track(self.global_step, protected_savefields=self.protected_savefields) CockpitCTX.erase() + + @staticmethod + def _combine_hooks( + *hooks: Union[Callable[[Module], Any], None] + ) -> Union[Callable[[Module], None], None]: + """Combine multiple extension hooks into a single one. + + Args: + hooks: List of extension hooks to be combined. + + Returns: + Merged hook. ``None`` if all passed hooks were ``None``. + """ + non_empty_hooks = [h for h in hooks if h is not None] + + if non_empty_hooks: + + def hook(module: Module): + """Sequentially execute all hooks on the module. + + Args: + module: Module to run the extension hook on. + """ + for h in non_empty_hooks: + h(module) + + return hook + + else: + return None diff --git a/cockpit/quantities/alpha.py b/cockpit/quantities/alpha.py index 5590ba2..474a0df 100644 --- a/cockpit/quantities/alpha.py +++ b/cockpit/quantities/alpha.py @@ -2,6 +2,7 @@ import itertools import warnings +from typing import Set import numpy as np import torch @@ -89,6 +90,14 @@ def extension_hooks(self, global_step): return hooks + def protected_savefields(self, global_step: int) -> Set[str]: # noqa: D102 + protected = set() + if self.is_start(global_step) or self.is_end(global_step): + if not self.__projection_with_backpack(global_step): + protected.update(["grad_batch"]) + + return protected + def is_start(self, global_step): """Return whether current iteration is start point. diff --git a/cockpit/quantities/hooks/cleanup.py b/cockpit/quantities/hooks/cleanup.py new file mode 100644 index 0000000..8c6f99c --- /dev/null +++ b/cockpit/quantities/hooks/cleanup.py @@ -0,0 +1,30 @@ +"""Contains hook that deletes BackPACK buffers during backpropagation.""" + +from typing import Set + +from torch import Tensor + +from cockpit.quantities.hooks.base import ParameterExtensionHook + + +class CleanupHook(ParameterExtensionHook): + """Deletes specified BackPACK buffers during backpropagation.""" + + def __init__(self, delete_savefields: Set[str]): + """Store savefields to be deleted in the backward pass. + + Args: + delete_savefields: Name of buffers to delete. + """ + super().__init__() + self._delete_savefields = delete_savefields + + def param_hook(self, param: Tensor): + """Delete BackPACK buffers in parameter. + + Args: + param: Trainable parameter which hosts BackPACK quantities. + """ + for savefield in self._delete_savefields: + if hasattr(param, savefield): + delattr(param, savefield) diff --git a/cockpit/quantities/quantity.py b/cockpit/quantities/quantity.py index 775134c..6baebdb 100644 --- a/cockpit/quantities/quantity.py +++ b/cockpit/quantities/quantity.py @@ -1,6 +1,7 @@ """Base class for a tracked quantity.""" from collections import defaultdict +from typing import Set import numpy import torch @@ -74,6 +75,19 @@ def extension_hooks(self, global_step): """ return [] + def protected_savefields(self, global_step: int) -> Set[str]: + """Return list of protected BackPACK buffers at the current step. + + Protected buffers will not be freed during a backward pass. + + Args: + global_step: The current iteration number. + + Returns: + Set of protected savefields. + """ + return set() + def track(self, global_step, params, batch_loss): """Perform scheduled computations and store result. 
diff --git a/cockpit/quantities/utils_transforms.py b/cockpit/quantities/utils_transforms.py
index 3c0e67d..e0517ad 100644
--- a/cockpit/quantities/utils_transforms.py
+++ b/cockpit/quantities/utils_transforms.py
@@ -3,7 +3,7 @@
 import string
 import weakref
 
-from torch import einsum
+from torch import Tensor, einsum
 
 from cockpit.quantities.hooks.base import ParameterExtensionHook
 
@@ -67,20 +67,14 @@ def __init__(self, transforms, savefield=None):
         super().__init__(savefield=savefield)
         self._transforms = transforms
 
-    def param_hook(self, param):
+    def param_hook(self, param: Tensor):
         """Execute all transformations and store results as dictionary in the parameter.
 
-        Delete individual gradients in the parameter.
-
         Args:
-            param (torch.Tensor): Trainable parameter which hosts BackPACK quantities.
+            param: Trainable parameter which hosts BackPACK quantities.
         """
         param.grad_batch._param_weakref = weakref.ref(param)
        # TODO Delete after backward pass with Cockpit
         param.grad_batch_transforms = {
             key: func(param.grad_batch) for key, func in self._transforms.items()
         }
-        # TODO Delete with a separate hook that also knows which savefield should be
-        # kept because it's protected by the user. See
-        # https://github.com/f-dangel/cockpit-paper/issues/197
-        del param.grad_batch
diff --git a/tests/test_bugs/test_issue5.py b/tests/test_bugs/test_issue5.py
new file mode 100644
index 0000000..9fb6841
--- /dev/null
+++ b/tests/test_bugs/test_issue5.py
@@ -0,0 +1,55 @@
+"""Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
+
+from backpack import extend
+from torch import manual_seed, rand
+from torch.nn import Flatten, Linear, MSELoss, Sequential
+from torch.optim import Adam
+
+from cockpit import Cockpit
+from cockpit.quantities import Alpha, GradHist1d
+from cockpit.utils.schedules import linear
+
+
+def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
+    """If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
+
+    But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
+    it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
+    ``Alpha`` cannot access the attribute anymore. This leads to the error.
+    """
+    manual_seed(0)
+
+    N, D_in, D_out = 2, 3, 1
+    model = extend(Sequential(Flatten(), Linear(D_in, D_out)))
+
+    opt_not_sgd = Adam(model.parameters(), lr=1e-3)
+    loss_fn = extend(MSELoss(reduction="mean"))
+    individual_loss_fn = MSELoss(reduction="none")
+
+    on_first = linear(1)
+    alpha = Alpha(on_first)
+    uses_BatchGradTransformsHook = GradHist1d(on_first)
+
+    cockpit = Cockpit(
+        model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook]
+    )
+
+    global_step = 0
+    inputs, labels = rand(N, D_in), rand(N, D_out)
+
+    # forward pass
+    outputs = model(inputs)
+    loss = loss_fn(outputs, labels)
+    losses = individual_loss_fn(outputs, labels)
+
+    # backward pass
+    with cockpit(
+        global_step,
+        info={
+            "batch_size": N,
+            "individual_losses": losses,
+            "loss": loss,
+            "optimizer": opt_not_sgd,
+        },
+    ):
+        loss.backward(create_graph=cockpit.create_graph(global_step))
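
For orientation, below is a minimal, self-contained sketch of the mechanism this patch introduces: quantities report which BackPACK savefields they still need (the "protected" ones), a cleanup hook frees the unprotected buffers, and the compute hook and the cleanup hook are merged into a single callable. This is an illustration only, not Cockpit's real API: the helpers `combine_hooks` and `make_cleanup_hook` are invented for this example, and the actual `CleanupHook` runs per parameter inside BackPACK's extension-hook machinery rather than iterating over `module.parameters()`.

```python
from typing import Any, Callable, Optional, Set

from torch import rand
from torch.nn import Linear, Module


def combine_hooks(
    *hooks: Optional[Callable[[Module], Any]]
) -> Optional[Callable[[Module], None]]:
    """Run several module hooks in sequence (same idea as ``_combine_hooks``)."""
    non_empty = [h for h in hooks if h is not None]
    if not non_empty:
        return None

    def hook(module: Module) -> None:
        for h in non_empty:
            h(module)

    return hook


def make_cleanup_hook(delete_savefields: Set[str]) -> Callable[[Module], None]:
    """Build a hook that frees the given BackPACK buffers on a module's parameters."""

    def hook(module: Module) -> None:
        for param in module.parameters():
            for savefield in delete_savefields:
                if hasattr(param, savefield):
                    delattr(param, savefield)

    return hook


if __name__ == "__main__":
    layer = Linear(3, 1)
    # Fake a BackPACK buffer; normally BatchGrad creates it during the backward pass.
    layer.weight.grad_batch = rand(4, 1, 3)

    protected: Set[str] = set()  # with Adam + Alpha, "grad_batch" would be in here
    cleanup = make_cleanup_hook({"grad_batch"} - protected)

    def compute(module: Module) -> None:
        # Placeholder for the hook that computes quantities from the buffers.
        pass

    merged = combine_hooks(compute, cleanup)
    assert merged is not None
    merged(layer)

    assert not hasattr(layer.weight, "grad_batch")  # unprotected buffer was freed
```

In the patch itself, when `Alpha` reports `grad_batch` as protected (for example with a non-SGD optimizer such as Adam, as in `test_issue5.py`), the `CleanupHook` for `grad_batch` is not installed at all, so the buffer survives the backward pass and `Alpha` can still read it.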