diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml
index 72c03f55d9..fff39d461f 100644
--- a/recipes/configs/llama3_2/1B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -52,6 +52,13 @@ seed: null
 shuffle: True
 batch_size: 4
 
+validation:
+  dataset:
+    _component_: torchtune.datasets.alpaca_cleaned_dataset
+    packed: False
+  run_every_n_epochs: 0.5
+  max_batches: 5
+
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index 9a3f3eacfb..964f75fbd5 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -335,6 +335,38 @@ def setup(self, cfg: DictConfig) -> None:
             last_epoch=self.global_step - 1,
         )
 
+        # Setup the validation dataset
+        validation_config = cfg.get("validation")
+        self.run_validation = validation_config is not None
+        if self.run_validation:
+            actual_steps_per_epoch = (
+                self._steps_per_epoch
+                if self.max_steps_per_epoch is None
+                else min(self._steps_per_epoch, self.max_steps_per_epoch)
+            )
+            # Convert run_every_n_epochs into a step interval; a missing key
+            # defaults to 0, which is treated as "not set" below
+            self.run_val_every_n_steps = int(
+                actual_steps_per_epoch * validation_config.get("run_every_n_epochs", 0)
+            )
+            self._sampler_val, self._dataloader_val = self._setup_data(
+                cfg_dataset=validation_config.dataset,
+                shuffle=cfg.shuffle,
+                batch_size=cfg.batch_size,
+                collate_fn=collate_name,
+            )
+
+            if self.run_val_every_n_steps == 0:
+                log.warning(
+                    "Validation is enabled but run_every_n_epochs is not set. "
+                    f"Defaulting to validating once per epoch ({actual_steps_per_epoch} steps)."
+                )
+                self.run_val_every_n_steps = actual_steps_per_epoch
+            elif self.run_val_every_n_steps < 0:
+                raise ValueError("run_every_n_epochs must be greater than 0.")
+
+            self.max_validation_batches = validation_config.get("max_batches", -1)
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -652,6 +684,40 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
 
         return loss
 
+    def validate(self, curr_epoch: int) -> None:
+        # Evaluate on the validation set without tracking gradients;
+        # train mode is restored once validation is done
+        self._model.eval()
+        num_val_batches = (
+            len(self._dataloader_val)
+            if self.max_validation_batches < 0
+            else min(len(self._dataloader_val), self.max_validation_batches)
+        )
+        pbar = tqdm(total=num_val_batches, position=1)
+        val_losses = []
+        with torch.no_grad():
+            for idx, batch in enumerate(self._dataloader_val):
+                if idx == num_val_batches:
+                    break
+                utils.batch_to_device(batch, self._device)
+
+                current_loss = self._loss_step(batch)
+                val_losses.append(current_loss.item())
+
+                pbar.update(1)
+                pbar.set_description(
+                    f"{curr_epoch + 1}|{idx}|Validation Loss: {current_loss.item()}"
+                )
+        pbar.close()
+
+        self._metric_logger.log_dict(
+            {
+                "avg_val_loss": sum(val_losses) / len(val_losses),
+            },
+            step=self.global_step,
+        )
+        self._model.train()
+
     def train(self) -> None:
         """
         The core training loop.
@@ -674,7 +740,7 @@ def train(self) -> None:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
 
-            pbar = tqdm(total=self._steps_per_epoch)
+            pbar = tqdm(total=self._steps_per_epoch, position=0)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -723,7 +789,7 @@ def train(self) -> None:
                     loss_to_log = running_loss.item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                        f"{curr_epoch + 1}|{self.global_step}|Training Loss: {loss_to_log}"
                     )
 
                     # Log per-step metrics
@@ -753,6 +819,12 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()
 
+                    if (
+                        self.run_validation
+                        and self.global_step % self.run_val_every_n_steps == 0
+                    ):
+                        self.validate(curr_epoch=curr_epoch)
+
                 # Stop tracking CUDA memory now that active steps are complete
                 if (
                     curr_epoch == 0
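With this patch, validation is opt-in: it runs only when a `validation` block is present in the recipe config. A minimal sketch of such a block, using the keys introduced above (the dataset builder is just an example; any torchtune dataset component can be substituted):

```yaml
validation:
  dataset:
    _component_: torchtune.datasets.alpaca_cleaned_dataset
    packed: False
  run_every_n_epochs: 0.5   # converted to a step interval: validate twice per (effective) epoch
  max_batches: 5            # cap each validation pass at 5 batches; -1 evaluates the full set
```

The recipe is then launched as usual, e.g. `tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device`. If `run_every_n_epochs` is omitted, the recipe logs a warning and defaults to validating once per epoch.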