move out zero grad logic into separate function (#969)
Summary:
Pull Request resolved: #969

# Context

Currently it isn't possible to log gradients from `AutoUnit` because they are zeroed out before `on_train_step_end()` is reached.

# This Diff

Moves the zero-grad logic out of `_update_weights` and into its own function, which can be overridden, e.g.
```
class MyAutoUnit(AutoUnit):
    ...

    def zero_grad(self) -> None:
        # log gradients before they are cleared
        # (self.logger is a user-provided logger)
        for param in self.module.parameters():
            self.logger.log(param.grad)
        super().zero_grad()
```
to log the gradients prior to zeroing them out.
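As a slightly fuller sketch along the same lines (the global-norm computation and print-based logging are illustrative, not part of this diff), an override could record the total gradient norm each step before the gradients are cleared:

```
import torch

from torchtnt.framework.auto_unit import AutoUnit


class GradNormAutoUnit(AutoUnit):
    # Assumes the usual AutoUnit abstract methods (e.g. compute_loss,
    # configure_optimizers_and_lr_scheduler) are implemented elsewhere.

    def zero_grad(self) -> None:
        # Gather all non-None parameter gradients before they are zeroed.
        grads = [p.grad for p in self.module.parameters() if p.grad is not None]
        if grads:
            # Global L2 norm across all parameter gradients.
            total_norm = torch.linalg.vector_norm(
                torch.stack([torch.linalg.vector_norm(g) for g in grads])
            )
            print(f"grad_norm: {total_norm.item():.4f}")  # swap in a real logger
        super().zero_grad()
```

Because `_update_weights` now routes through `self.zero_grad()`, the override runs once per optimizer step, and the default `set_to_none=True` behavior is preserved by the `super()` call.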

Reviewed By: galrotem, diego-urgell

Differential Revision: D68983117

fbshipit-source-id: 744b72c5634d8b6979ef1145fc3254ddde93d743
JKSenthil authored and facebook-github-bot committed Feb 3, 2025
1 parent 93347d9 commit 5bc1702
Showing 1 changed file with 17 additions and 1 deletion: torchtnt/framework/auto_unit.py
```
@@ -829,6 +829,22 @@ def step_lr_scheduler(self) -> None:
         """
         none_throws(self.lr_scheduler).step()
 
+    def zero_grad(self) -> None:
+        """
+        Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
+
+        Example of overriding:
+
+            class CustomAutoUnit(MyAutoUnit):
+                ...
+
+                def zero_grad(self):
+                    # log before zeroing gradients
+                    super().zero_grad()
+        """
+        optimizer = none_throws(self.optimizer)
+        optimizer.zero_grad(set_to_none=True)
+
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
         Updates weights of the module, handles clip gradient norm, etc.
@@ -892,7 +908,7 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
             with get_timing_context(
                 state, f"{self.__class__.__name__}.optimizer_zero_grad"
             ):
-                optimizer.zero_grad(set_to_none=True)
+                self.zero_grad()
 
             if self.step_lr_interval == "step":
                 self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
```
