diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 6e361564e0..b15b620ae1 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -471,6 +471,7 @@ class AutoUnit(
             this option to True is not needed and often can be worked around
             in a much more efficient way.
         enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
+        zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via a custom callback.
 
     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -506,6 +507,7 @@ def __init__(
         enable_compiled_autograd: bool = False,
         loss_backward_retain_graph: Optional[bool] = None,
         enable_prefetch: bool = True,
+        zero_grad_at_train_step_start: bool = False,
     ) -> None:
         super().__init__(
             module=module,
@@ -576,6 +578,10 @@ def __init__(
         self.lr_scheduler: Optional[TLRScheduler] = None
         self.swa_scheduler: Optional[SWALR] = None
 
+        self.zero_grad_at_train_step_start: bool = zero_grad_at_train_step_start
+        # keep track of when to zero grad at train step start
+        self._weight_updated_in_prev_step = False
+
     def __setattr__(self, name: str, value: object) -> None:
         if isinstance(value, torch.nn.Module):
             self._validate_module_attr(name, value)
@@ -653,6 +659,11 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
             self.train_progress.num_steps_completed_in_epoch + 1
         ) % self.gradient_accumulation_steps == 0 or self._is_last_batch
 
+        # zero the gradients if the previous step updated the weights
+        if self._weight_updated_in_prev_step and self.zero_grad_at_train_step_start:
+            self.zero_grad(state)
+            self._weight_updated_in_prev_step = False
+
         # for pyre, assign to local variable
         module = self.module
 
@@ -829,21 +840,24 @@ def step_lr_scheduler(self) -> None:
         """
         none_throws(self.lr_scheduler).step()
 
-    def zero_grad(self) -> None:
+    def zero_grad(self, state: State) -> None:
         """
-        Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
+        Zeroes the gradients of the module's parameters. You can override this if you want to log the gradients before zeroing them.
 
         Example of overriding:
         class CustomAutoUnit(MyAutoUnit):
             ...
 
-            def zero_grad(self):
+            def zero_grad(self, state):
                 # log before zeroing gradients
-                super().zero_grad()
+                super().zero_grad(state)
         """
 
         optimizer = none_throws(self.optimizer)
-        optimizer.zero_grad(set_to_none=True)
+        with get_timing_context(
+            state, f"{self.__class__.__name__}.optimizer_zero_grad"
+        ):
+            optimizer.zero_grad(set_to_none=True)
 
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
@@ -904,11 +918,12 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         else:
             optimizer.step()
 
-        # sets gradients to zero
-        with get_timing_context(
-            state, f"{self.__class__.__name__}.optimizer_zero_grad"
-        ):
-            self.zero_grad()
+        if self.zero_grad_at_train_step_start:
+            # mark that weights were updated in this step
+            # so in the next step we know to zero the gradients
+            self._weight_updated_in_prev_step = True
+        else:
+            self.zero_grad(state)
 
         if self.step_lr_interval == "step":
             self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
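
For context, a minimal sketch (not part of the patch) of how the new flag is intended to be used: a hypothetical `AutoUnit` subclass that passes `zero_grad_at_train_step_start=True` and overrides `zero_grad(state)` to log gradient norms while they are still populated from the previous optimizer step. The `GradLoggingAutoUnit` name, the loss/optimizer setup, and the norm logging are illustrative assumptions, not part of this diff.

```python
from typing import Optional, Tuple

import torch

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

# Hypothetical batch type for the sketch: (inputs, targets) tensors.
Batch = Tuple[torch.Tensor, torch.Tensor]


class GradLoggingAutoUnit(AutoUnit[Batch]):
    def __init__(self, *, module: torch.nn.Module, **kwargs) -> None:
        # Opt into the behavior added by this diff: gradients survive until the
        # start of the next train step instead of being zeroed right after
        # optimizer.step().
        super().__init__(module=module, zero_grad_at_train_step_start=True, **kwargs)

    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[object]]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None

    def zero_grad(self, state: State) -> None:
        # Gradients from the previous optimizer step are still populated here,
        # so they can be inspected/logged before being cleared.
        grads = [p.grad for p in self.module.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack([g.detach().norm() for g in grads]))
            print(f"grad norm before zeroing: {total_norm.item():.4f}")
        super().zero_grad(state)
```

With the default `zero_grad_at_train_step_start=False`, behavior is unchanged: `_update_weights` zeroes the gradients right after the optimizer step, as before.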