Merge pull request #20 from f-dangel/development
Fix bug: Alpha with custom optimizer & add experiment repo
Showing 9 changed files with 193 additions and 23 deletions.
@@ -0,0 +1,30 @@
"""Contains hook that deletes BackPACK buffers during backpropagation.""" | ||
|
||
from typing import Set | ||
|
||
from torch import Tensor | ||
|
||
from cockpit.quantities.hooks.base import ParameterExtensionHook | ||
|
||
|
||
class CleanupHook(ParameterExtensionHook): | ||
"""Deletes specified BackPACK buffers during backpropagation.""" | ||
|
||
def __init__(self, delete_savefields: Set[str]): | ||
"""Store savefields to be deleted in the backward pass. | ||
Args: | ||
delete_savefields: Name of buffers to delete. | ||
""" | ||
super().__init__() | ||
self._delete_savefields = delete_savefields | ||
|
||
def param_hook(self, param: Tensor): | ||
"""Delete BackPACK buffers in parameter. | ||
Args: | ||
param: Trainable parameter which hosts BackPACK quantities. | ||
""" | ||
for savefield in self._delete_savefields: | ||
if hasattr(param, savefield): | ||
delattr(param, savefield) |
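
For illustration, a minimal sketch of how this hook could be exercised in isolation, with CleanupHook as defined above in scope. It assumes the base class constructor takes no arguments (as suggested by the plain super().__init__() call) and uses a hypothetical stand-in tensor for a parameter carrying the grad_batch savefield; this is not part of the commit.

# Illustrative sketch (assumed usage, not part of this commit).
from torch import zeros

param = zeros(3, requires_grad=True)
param.grad_batch = zeros(2, 3)  # hypothetical stand-in for a BackPACK buffer

hook = CleanupHook({"grad_batch"})
hook.param_hook(param)  # deletes the listed savefields from the parameter

assert not hasattr(param, "grad_batch")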
@@ -0,0 +1,55 @@
"""Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5.""" | ||
|
||
from backpack import extend | ||
from torch import manual_seed, rand | ||
from torch.nn import Flatten, Linear, MSELoss, Sequential | ||
from torch.optim import Adam | ||
|
||
from cockpit import Cockpit | ||
from cockpit.quantities import Alpha, GradHist1d | ||
from cockpit.utils.schedules import linear | ||
|
||
|
||
def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha(): | ||
"""If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``. | ||
But if an extension that uses ``BatchGradTransformsHook`` is used at the same time, | ||
it will delete the ``grad_batch`` attribute during the backward pass. Consequently, | ||
``Alpha`` cannot access the attribute anymore. This leads to the error. | ||
""" | ||
manual_seed(0) | ||
|
||
N, D_in, D_out = 2, 3, 1 | ||
model = extend(Sequential(Flatten(), Linear(D_in, D_out))) | ||
|
||
opt_not_sgd = Adam(model.parameters(), lr=1e-3) | ||
loss_fn = extend(MSELoss(reduction="mean")) | ||
individual_loss_fn = MSELoss(reduction="none") | ||
|
||
on_first = linear(1) | ||
alpha = Alpha(on_first) | ||
uses_BatchGradTransformsHook = GradHist1d(on_first) | ||
|
||
cockpit = Cockpit( | ||
model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook] | ||
) | ||
|
||
global_step = 0 | ||
inputs, labels = rand(N, D_in), rand(N, D_out) | ||
|
||
# forward pass | ||
outputs = model(inputs) | ||
loss = loss_fn(outputs, labels) | ||
losses = individual_loss_fn(outputs, labels) | ||
|
||
# backward pass | ||
with cockpit( | ||
global_step, | ||
info={ | ||
"batch_size": N, | ||
"individual_losses": losses, | ||
"loss": loss, | ||
"optimizer": opt_not_sgd, | ||
}, | ||
): | ||
loss.backward(create_graph=cockpit.create_graph(global_step)) |
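
For orientation, a hedged sketch of how a training step would typically continue after this instrumented backward pass, continuing inside the test function; the calls below are standard PyTorch optimizer usage and are not part of the committed test.

    # Illustrative continuation (assumption, not part of the committed test):
    # a regular training step would now update the parameters and clear gradients.
    opt_not_sgd.step()
    opt_not_sgd.zero_grad()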