Commit

Merge pull request #20 from f-dangel/development
Fix bug alpha custom optimizer & Add experiment repo
f-dangel authored Oct 26, 2021
2 parents 86f2ef4 + c00852e commit af91391
Showing 9 changed files with 193 additions and 23 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]

## [1.0.2] - 2021-10-26

### Added

- Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper.

### Fixed

- Protects the `grad_batch` field when a non-SGD optimizer is used together with other quantities that free `grad_batch` to save memory. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)]

## [1.0.1] - 2021-10-13

From this version on, `cockpit` will be available as `cockpit-for-pytorch` on
@@ -30,6 +40,7 @@ PyPI.

- First public release version of **Cockpit**.

[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.1...HEAD
[Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD
[1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2
[1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1
[1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0
3 changes: 2 additions & 1 deletion README.md
@@ -12,6 +12,7 @@
<p align="center">
<a href="#installation">Installation</a> •
<a href="https://cockpit.readthedocs.io/">Docs</a> •
<a href="https://github.com/fsschneider/cockpit-experiments">Experiments</a> •
<a href="#license">License</a> •
<a href="#citation">Citation</a>
</p>
@@ -60,7 +61,7 @@ The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on
<!-- Experiments -->
## Experiments

To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. The code for the experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).

<!-- LICENSE -->
## License
27 changes: 17 additions & 10 deletions cockpit/cockpit.py
@@ -2,6 +2,7 @@

import os
from collections import defaultdict
from typing import Set

import json_tricks
from backpack import disable
@@ -130,6 +131,22 @@ def _get_extension_hook(self, global_step):

return hook

def _get_protected_savefields(self, global_step: int) -> Set[str]:
"""Return names of protected BackPACK buffers.
Args:
global_step: Current iteration number.
Returns:
Set of protected savefields.
"""
protected = set()

for q in self.quantities:
protected.update(q.protected_savefields(global_step))

return protected

def __call__(self, global_step, *exts, info=None, debug=False):
"""Returns the backpack extensions that should be used in this iteration.
@@ -219,16 +236,6 @@ def _free_backpack_buffers(self, global_step, protected_savefields, verbose=False):
except AttributeError:
pass

@staticmethod
def _remove_module_io(module):
io_fields = ["input0", "output"]

for field in io_fields:
try:
delattr(module, field)
except AttributeError:
pass

def log(
self,
global_step,
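As an aside, the new `_get_protected_savefields` simply takes the union of each quantity's protected savefields. A minimal sketch of that aggregation pattern, using stand-in objects rather than real cockpit quantities:

```python
# Stand-ins for cockpit quantities; only the protected_savefields interface matters here.
class FakeQuantity:
    def __init__(self, fields):
        self._fields = set(fields)

    def protected_savefields(self, global_step):
        return self._fields


quantities = [FakeQuantity({"grad_batch"}), FakeQuantity(set())]

# Same aggregation as Cockpit._get_protected_savefields: union over all quantities.
protected = set()
for q in quantities:
    protected.update(q.protected_savefields(global_step=0))

print(protected)  # {'grad_batch'}
```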
53 changes: 51 additions & 2 deletions cockpit/context.py
@@ -1,9 +1,14 @@
"""Cockpit Context."""

import warnings
from typing import Any, Callable, Union

from backpack import backpack, disable
from backpack.core.derivatives.convnd import weight_jac_t_save_memory
from torch.nn import Module

from cockpit.quantities.hooks.cleanup import CleanupHook
from cockpit.quantities.utils_transforms import BatchGradTransformsHook


class CockpitCTX:
@@ -99,11 +104,24 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
self.cp = cp
self.global_step = global_step

self.protected_savefields = [e.savefield for e in custom_exts]
self.protected_savefields = {e.savefield for e in custom_exts}
self.protected_savefields.update(cp._get_protected_savefields(global_step))

# choose context
ext = cp._get_extensions(global_step, custom_exts=custom_exts)
ext_hook = cp._get_extension_hook(global_step)
compute_hook = cp._get_extension_hook(global_step)

# Delete 'grad_batch' during backprop if it's not protected
if (
isinstance(compute_hook, BatchGradTransformsHook)
and "grad_batch" not in self.protected_savefields
):
# TODO Implement all quantities with hooks and specify protected savefields
# Remove unprotected buffers in this cleanup hook during backpropagation
cleanup_hook = CleanupHook({"grad_batch"})
ext_hook = self._combine_hooks(compute_hook, cleanup_hook)
else:
ext_hook = compute_hook

save_memory = cp.BACKPACK_CONV_SAVE_MEMORY

@@ -122,6 +140,7 @@ def __init__(self, cp, global_step, custom_exts, info, debug=False):
print(f" ↪Hooks : {ext_hook}")
print(f" ↪Create graph: {cp.create_graph(global_step)}")
print(f" ↪Save memory : {save_memory}")
print(f" ↪Protect : {self.protected_savefields}")

def __enter__(self):
"""Enter cockpit context(s)."""
@@ -136,3 +155,33 @@ def __exit__(self, type, value, traceback):
self.cp.track(self.global_step, protected_savefields=self.protected_savefields)

CockpitCTX.erase()

@staticmethod
def _combine_hooks(
*hooks: Union[Callable[[Module], Any], None]
) -> Union[Callable[[Module], None], None]:
"""Combine multiple extension hooks into a single one.
Args:
hooks: List of extension hooks to be combined.
Returns:
Merged hook. ``None`` if all passed hooks were ``None``.
"""
non_empty_hooks = [h for h in hooks if h is not None]

if non_empty_hooks:

def hook(module: Module):
"""Sequentially execute all hooks on the module.
Args:
module: Module to run the extension hook on.
"""
for h in non_empty_hooks:
h(module)

return hook

else:
return None
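The hook-merging logic above can be illustrated in isolation. A self-contained sketch of the same pattern (plain callables stand in for BackPACK extension hooks; `combine_hooks` here mirrors `_combine_hooks` but is not the cockpit implementation):

```python
from typing import Callable, List, Optional, Tuple


def combine_hooks(
    *hooks: Optional[Callable[[str], None]]
) -> Optional[Callable[[str], None]]:
    """Merge hooks into one callable that runs them in order; None hooks are skipped."""
    non_empty = [h for h in hooks if h is not None]

    if not non_empty:
        return None

    def hook(module: str) -> None:
        for h in non_empty:
            h(module)

    return hook


calls: List[Tuple[str, str]] = []
merged = combine_hooks(
    lambda module: calls.append(("compute", module)),  # stands in for the compute hook
    None,                                              # hooks may be absent
    lambda module: calls.append(("cleanup", module)),  # stands in for the cleanup hook
)

merged("linear_layer")
print(calls)  # [('compute', 'linear_layer'), ('cleanup', 'linear_layer')]
```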
9 changes: 9 additions & 0 deletions cockpit/quantities/alpha.py
@@ -2,6 +2,7 @@

import itertools
import warnings
from typing import Set

import numpy as np
import torch
@@ -89,6 +90,14 @@ def extension_hooks(self, global_step):

return hooks

def protected_savefields(self, global_step: int) -> Set[str]: # noqa: D102
protected = set()
if self.is_start(global_step) or self.is_end(global_step):
if not self.__projection_with_backpack(global_step):
protected.update(["grad_batch"])

return protected

def is_start(self, global_step):
"""Return whether current iteration is start point.
30 changes: 30 additions & 0 deletions cockpit/quantities/hooks/cleanup.py
@@ -0,0 +1,30 @@
"""Contains hook that deletes BackPACK buffers during backpropagation."""

from typing import Set

from torch import Tensor

from cockpit.quantities.hooks.base import ParameterExtensionHook


class CleanupHook(ParameterExtensionHook):
"""Deletes specified BackPACK buffers during backpropagation."""

def __init__(self, delete_savefields: Set[str]):
"""Store savefields to be deleted in the backward pass.
Args:
delete_savefields: Names of the buffers to delete.
"""
super().__init__()
self._delete_savefields = delete_savefields

def param_hook(self, param: Tensor):
"""Delete BackPACK buffers in parameter.
Args:
param: Trainable parameter which hosts BackPACK quantities.
"""
for savefield in self._delete_savefields:
if hasattr(param, savefield):
delattr(param, savefield)
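A hedged usage sketch for the new hook (assumes `cockpit-for-pytorch` and its BackPACK dependency are installed; the `grad_batch` attribute is set by hand here as a stand-in for a real BackPACK buffer):

```python
from torch import zeros
from torch.nn import Parameter

from cockpit.quantities.hooks.cleanup import CleanupHook

param = Parameter(zeros(3))
param.grad_batch = zeros(4, 3)  # stand-in for a BackPACK buffer of 4 individual gradients

hook = CleanupHook({"grad_batch"})
hook.param_hook(param)  # deletes the listed savefields from this parameter

print(hasattr(param, "grad_batch"))  # False
```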
14 changes: 14 additions & 0 deletions cockpit/quantities/quantity.py
@@ -1,6 +1,7 @@
"""Base class for a tracked quantity."""

from collections import defaultdict
from typing import Set

import numpy
import torch
@@ -74,6 +75,19 @@ def extension_hooks(self, global_step):
"""
return []

def protected_savefields(self, global_step: int) -> Set[str]:
"""Return list of protected BackPACK buffers at the current step.
Protected buffers will not be freed during a backward pass.
Args:
global_step: The current iteration number.
Returns:
Set of protected savefields.
"""
return set()

def track(self, global_step, params, batch_loss):
"""Perform scheduled computations and store result.
12 changes: 3 additions & 9 deletions cockpit/quantities/utils_transforms.py
@@ -3,7 +3,7 @@
import string
import weakref

from torch import einsum
from torch import Tensor, einsum

from cockpit.quantities.hooks.base import ParameterExtensionHook

@@ -67,20 +67,14 @@ def __init__(self, transforms, savefield=None):
super().__init__(savefield=savefield)
self._transforms = transforms

def param_hook(self, param):
def param_hook(self, param: Tensor):
"""Execute all transformations and store results as dictionary in the parameter.
Delete individual gradients in the parameter.
Args:
param (torch.Tensor): Trainable parameter which hosts BackPACK quantities.
param: Trainable parameter which hosts BackPACK quantities.
"""
param.grad_batch._param_weakref = weakref.ref(param)
# TODO Delete after backward pass with Cockpit
param.grad_batch_transforms = {
key: func(param.grad_batch) for key, func in self._transforms.items()
}
# TODO Delete with a separate hook that also knows which savefield should be
# kept because it's protected by the user. See
# https://github.com/f-dangel/cockpit-paper/issues/197
del param.grad_batch
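For intuition, each entry of the transform dictionary built in `param_hook` maps a name to a statistic computed from the individual gradients. A standalone sketch with a random stand-in for `param.grad_batch` (the `batch_l2` transform is illustrative, not cockpit's exact implementation):

```python
from torch import einsum, rand

grad_batch = rand(8, 5)  # stand-in for param.grad_batch: batch size 8, 5 parameters

transforms = {
    # squared l2 norm of each individual gradient
    "batch_l2": lambda gb: einsum("ni,ni->n", gb, gb),
}

grad_batch_transforms = {key: func(grad_batch) for key, func in transforms.items()}
print(grad_batch_transforms["batch_l2"].shape)  # torch.Size([8])
```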
55 changes: 55 additions & 0 deletions tests/test_bugs/test_issue5.py
@@ -0,0 +1,55 @@
"""Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""

from backpack import extend
from torch import manual_seed, rand
from torch.nn import Flatten, Linear, MSELoss, Sequential
from torch.optim import Adam

from cockpit import Cockpit
from cockpit.quantities import Alpha, GradHist1d
from cockpit.utils.schedules import linear


def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
"""If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
``Alpha`` cannot access the attribute anymore. This leads to the error.
"""
manual_seed(0)

N, D_in, D_out = 2, 3, 1
model = extend(Sequential(Flatten(), Linear(D_in, D_out)))

opt_not_sgd = Adam(model.parameters(), lr=1e-3)
loss_fn = extend(MSELoss(reduction="mean"))
individual_loss_fn = MSELoss(reduction="none")

on_first = linear(1)
alpha = Alpha(on_first)
uses_BatchGradTransformsHook = GradHist1d(on_first)

cockpit = Cockpit(
model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook]
)

global_step = 0
inputs, labels = rand(N, D_in), rand(N, D_out)

# forward pass
outputs = model(inputs)
loss = loss_fn(outputs, labels)
losses = individual_loss_fn(outputs, labels)

# backward pass
with cockpit(
global_step,
info={
"batch_size": N,
"individual_losses": losses,
"loss": loss,
"optimizer": opt_not_sgd,
},
):
loss.backward(create_graph=cockpit.create_graph(global_step))
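A brief note on usage: assuming the repository's test dependencies are installed, this regression test should run on its own via `pytest tests/test_bugs/test_issue5.py`. Before this fix, the backward pass inside the cockpit context failed because `grad_batch` had already been freed during backpropagation.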
