diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index c920f4b069..e287f4b2de 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -315,7 +315,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -626,10 +626,15 @@ def _setup_optimizer(

     def _setup_lr_scheduler(
         self,
-        cfg_lr_scheduler: DictConfig,
+        cfg_lr_scheduler: Optional[DictConfig],
         num_training_steps: int,
         last_epoch: int,
-    ) -> Optimizer:
+    ) -> Optional[Optimizer]:
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
             self._optimizer,
@@ -886,7 +891,8 @@ def train(self) -> None:
                     kd_loss_to_log = running_kd_loss.item() / num_tokens
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
-                    self._lr_scheduler.step()
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()

                     # Update the number of steps when the weights are updated
                     self.global_step += 1
diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py
index ef238da44d..844d7a6d89 100644
--- a/recipes/knowledge_distillation_single_device.py
+++ b/recipes/knowledge_distillation_single_device.py
@@ -306,7 +306,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -495,10 +495,16 @@ def _setup_optimizer(

     def _setup_lr_scheduler(
         self,
-        cfg_lr_scheduler: DictConfig,
+        cfg_lr_scheduler: Optional[DictConfig],
         num_training_steps: int,
         last_epoch: int,
-    ) -> Optimizer:
+    ) -> Optional[Optimizer]:
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
             self._optimizer,
@@ -727,7 +733,8 @@ def train(self) -> None:
                     )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
-                    self._lr_scheduler.step()
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()

                     # Update the number of steps when the weights are updated
                     self.global_step += 1
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index ab37623cc1..ba2146ffbe 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -284,7 +284,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -426,10 +426,16 @@ def _setup_optimizer(

     def _setup_lr_scheduler(
         self,
-        cfg_lr_scheduler: DictConfig,
+        cfg_lr_scheduler: Optional[DictConfig],
         num_training_steps: int,
         last_epoch: int,
-    ) -> Optimizer:
+    ) -> Optional[Optimizer]:
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
             self._optimizer,
@@ -679,7 +685,8 @@ def train(self) -> None:
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
-                    self._lr_scheduler.step()
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()

                     # Update the number of steps when the weights are updated
                     self.global_step += 1
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 45209814a0..7a9eb97bee 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -334,7 +334,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -563,10 +563,16 @@ def _setup_optimizer(

     def _setup_lr_scheduler(
         self,
-        cfg_lr_scheduler: DictConfig,
+        cfg_lr_scheduler: Optional[DictConfig],
         num_training_steps: int,
         last_epoch: int,
-    ) -> Optimizer:
+    ) -> Optional[Optimizer]:
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
             self._optimizer,
@@ -837,7 +843,8 @@ def train(self) -> None:
                     )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
-                    self._lr_scheduler.step()
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()

                     # Update the number of steps when the weights are updated
                     self.global_step += 1
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index fcdb3e4ea5..5e7aaa6c16 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -331,7 +331,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -497,10 +497,16 @@ def _setup_optimizer(

     def _setup_lr_scheduler(
         self,
-        cfg_lr_scheduler: DictConfig,
+        cfg_lr_scheduler: Optional[DictConfig],
         num_training_steps: int,
         last_epoch: int,
-    ) -> Optimizer:
+    ) -> Optional[Optimizer]:
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
             self._optimizer,
@@ -717,7 +723,8 @@ def train(self) -> None:
                     )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
-                    self._lr_scheduler.step()
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()

                     # Update the number of steps when the weights are updated
                     self.global_step += 1