Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move zero grads logic at the beginning of train step #974

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions torchtnt/framework/auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ class AutoUnit(
this option to True is not needed and often can be worked around
in a much more efficient way.
enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.

Note:
Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
Expand Down Expand Up @@ -506,6 +507,7 @@ def __init__(
enable_compiled_autograd: bool = False,
loss_backward_retain_graph: Optional[bool] = None,
enable_prefetch: bool = True,
zero_grad_at_train_step_start: bool = False,
) -> None:
super().__init__(
module=module,
Expand Down Expand Up @@ -576,6 +578,10 @@ def __init__(
self.lr_scheduler: Optional[TLRScheduler] = None
self.swa_scheduler: Optional[SWALR] = None

self.zero_grad_at_train_step_start: bool = zero_grad_at_train_step_start
# keep track of when to zero grad at train step start
self._weight_updated_in_prev_step = False

def __setattr__(self, name: str, value: object) -> None:
if isinstance(value, torch.nn.Module):
self._validate_module_attr(name, value)
Expand Down Expand Up @@ -653,6 +659,11 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
self.train_progress.num_steps_completed_in_epoch + 1
) % self.gradient_accumulation_steps == 0 or self._is_last_batch

# zero the gradients if previous step updated weights
if self._weight_updated_in_prev_step and self.zero_grad_at_train_step_start:
self.zero_grad(state)
self._weight_updated_in_prev_step = False

# for pyre, assign to local variable
module = self.module

Expand Down Expand Up @@ -829,21 +840,24 @@ def step_lr_scheduler(self) -> None:
"""
none_throws(self.lr_scheduler).step()

def zero_grad(self) -> None:
def zero_grad(self, state: State) -> None:
"""
Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
Zeroes the gradients of the module's parameters. You can override this if you want to log the gradients before zeroing them.

Example of overriding:
class CustomAutoUnit(MyAutoUnit):
...

def zero_grad(self):
def zero_grad(self, state):
# log before zeroing gradients
super().zero_grad()
"""

optimizer = none_throws(self.optimizer)
optimizer.zero_grad(set_to_none=True)
with get_timing_context(
state, f"{self.__class__.__name__}.optimizer_zero_grad"
):
optimizer.zero_grad(set_to_none=True)

def _update_weights(self, state: State) -> Optional[torch.Tensor]:
"""
Expand Down Expand Up @@ -904,11 +918,12 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
else:
optimizer.step()

# sets gradients to zero
with get_timing_context(
state, f"{self.__class__.__name__}.optimizer_zero_grad"
):
self.zero_grad()
if self.zero_grad_at_train_step_start:
# mark that weights were updated in this step
# so in next step we know to zero the gradients
self._weight_updated_in_prev_step = True
else:
self.zero_grad(state)

if self.step_lr_interval == "step":
self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
Expand Down
Loading