From e1a0b23da99dae0e0f1edf43d0b945914a1b5095 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Wed, 8 Jan 2025 14:14:45 +0000 Subject: [PATCH 01/11] Adds validation loss to LoRA fine tune single device --- recipes/lora_finetune_single_device.py | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 9a3f3eacfb..e05beca6bc 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -310,6 +310,14 @@ def setup(self, cfg: DictConfig) -> None: collate_fn=collate_name, ) + if "dataset_validation" in cfg: + self._sampler_val, self._dataloader_val = self._setup_data( + cfg_dataset=cfg.dataset_validation, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + # Finally update the recipe state which can only be correctly set after all of the # other components have been initialized and updated. @@ -670,6 +678,7 @@ def train(self) -> None: with self._profiler as prof: # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): + # TRAINING LOOP # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True self._sampler.set_epoch(curr_epoch) @@ -779,6 +788,38 @@ def train(self) -> None: ) ) + # VALIDATION LOOP + self._sampler_val.set_epoch(curr_epoch) + + pbar = tqdm(total=len(self._dataloader_val)) + val_losses = [] + for idx, batch in enumerate(self._dataloader_val): + utils.batch_to_device(batch, self._device) + + current_loss = self._loss_step(batch) + val_losses.append(current_loss.item()) + + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{idx}|Loss: {current_loss.item()}" + ) + + log_dict = { + "val_loss": current_loss.item(), + } + self._metric_logger.log_dict( + log_dict, + step=(curr_epoch + 1) * idx, + ) + + self._metric_logger.log_dict( + { + "avg_val_loss": sum(val_losses) / len(val_losses), + "epoch": curr_epoch + 1, + }, + step=self.global_step, + ) + def cleanup(self) -> None: self._metric_logger.close() From ef8084654ed9188e1d348f152b10df9356c1267c Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Fri, 10 Jan 2025 21:55:36 +0000 Subject: [PATCH 02/11] Fixes execution without validation set --- recipes/lora_finetune_single_device.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index e05beca6bc..f43fcb6b28 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -310,7 +310,8 @@ def setup(self, cfg: DictConfig) -> None: collate_fn=collate_name, ) - if "dataset_validation" in cfg: + self.run_validation = "dataset_validation" in cfg + if self.run_validation: self._sampler_val, self._dataloader_val = self._setup_data( cfg_dataset=cfg.dataset_validation, shuffle=cfg.shuffle, @@ -789,6 +790,8 @@ def train(self) -> None: ) # VALIDATION LOOP + if not self.run_validation: + continue self._sampler_val.set_epoch(curr_epoch) pbar = tqdm(total=len(self._dataloader_val)) @@ -812,13 +815,14 @@ def train(self) -> None: step=(curr_epoch + 1) * idx, ) - self._metric_logger.log_dict( - { - "avg_val_loss": sum(val_losses) / len(val_losses), - "epoch": curr_epoch + 1, - }, - step=self.global_step, - ) + if self.run_validation: + self._metric_logger.log_dict( + { + "avg_val_loss": sum(val_losses) / len(val_losses), + "epoch": curr_epoch + 
1, + }, + step=self.global_step, + ) def cleanup(self) -> None: self._metric_logger.close() From df8cd1e4a0dd3020ed4d51ecb9cc7eddf2b838bd Mon Sep 17 00:00:00 2001 From: "Massimo.Frasson" Date: Wed, 15 Jan 2025 13:03:04 +0000 Subject: [PATCH 03/11] Moves validate to separate method; adds call after n steps and early stopping --- recipes/lora_finetune_single_device.py | 109 +++++++++++++++---------- 1 file changed, 65 insertions(+), 44 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index f43fcb6b28..89a45b5d76 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -310,15 +310,6 @@ def setup(self, cfg: DictConfig) -> None: collate_fn=collate_name, ) - self.run_validation = "dataset_validation" in cfg - if self.run_validation: - self._sampler_val, self._dataloader_val = self._setup_data( - cfg_dataset=cfg.dataset_validation, - shuffle=cfg.shuffle, - batch_size=cfg.batch_size, - collate_fn=collate_name, - ) - # Finally update the recipe state which can only be correctly set after all of the # other components have been initialized and updated. @@ -344,6 +335,29 @@ def setup(self, cfg: DictConfig) -> None: last_epoch=self.global_step - 1, ) + # Setup the validation dataset + self.run_validation = "dataset_validation" in cfg + if self.run_validation: + self._sampler_val, self._dataloader_val = self._setup_data( + cfg_dataset=cfg.dataset_validation, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) + if self.run_validation: + if self.run_val_every_n_steps is None: + log.warning( + f"""Validation is enabled but run_val_every_n_steps is not set. + Will be set to steps_per_epoch. ({self._steps_per_epoch})""" + ) + self.run_val_every_n_steps = self._steps_per_epoch + elif self.run_val_every_n_steps < 1: + raise ValueError("run_val_every_n_steps must be greater than 0.") + + self.max_validation_batches = cfg.get("max_validation_batches", -1) + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) # if cfg is missing profiler key or if `cfg.profiler.enabled = False self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) @@ -661,6 +675,43 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: return loss + def validate(self, curr_epoch) -> None: + pbar = tqdm(total=max(len(self._dataloader_val), self.max_validation_batches)) + val_losses = [] + for idx, batch in enumerate(self._dataloader_val): + if ( + self.max_validation_batches > 0 + and idx == self.max_validation_batches - 1 + ): + break + utils.batch_to_device(batch, self._device) + + current_loss = self._loss_step(batch) + val_losses.append(current_loss.item()) + + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{idx}| Validation Loss: {current_loss.item()}" + ) + + # This bit allows to see the loss for each batch. Not sure about step indexing. + log_dict = { + "val_loss": current_loss.item(), + } + self._metric_logger.log_dict( + log_dict, + step=(curr_epoch + 1) * idx + idx, + ) + + if self.run_validation: + self._metric_logger.log_dict( + { + "avg_val_loss": sum(val_losses) / len(val_losses), + "epoch": curr_epoch + 1, + }, + step=self.global_step, + ) + def train(self) -> None: """ The core training loop. 
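For orientation, the validate() method added in the hunk above reduces to the standalone sketch below. The names model, val_loader, loss_step and device are placeholders rather than the recipe's attributes, and the eval()/torch.no_grad() handling is an illustrative addition: the patch itself calls self._loss_step directly without disabling gradients or leaving train mode.

import torch

def run_validation_pass(model, val_loader, loss_step, device, max_batches=-1):
    # Minimal sketch of the validation pass, not the recipe verbatim.
    model.eval()
    losses = []
    with torch.no_grad():
        for idx, batch in enumerate(val_loader):
            # The patch breaks at idx == max_batches - 1, so at most
            # max_batches - 1 batches are evaluated when a cap is set.
            if max_batches > 0 and idx == max_batches - 1:
                break
            # Stand-in for utils.batch_to_device.
            batch = {k: v.to(device) for k, v in batch.items()}
            losses.append(loss_step(batch).item())
    model.train()
    # Logged by the recipe as avg_val_loss at the current global_step.
    return sum(losses) / len(losses)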
@@ -679,7 +730,6 @@ def train(self) -> None: with self._profiler as prof: # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): - # TRAINING LOOP # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True self._sampler.set_epoch(curr_epoch) @@ -789,40 +839,11 @@ def train(self) -> None: ) ) - # VALIDATION LOOP - if not self.run_validation: - continue - self._sampler_val.set_epoch(curr_epoch) - - pbar = tqdm(total=len(self._dataloader_val)) - val_losses = [] - for idx, batch in enumerate(self._dataloader_val): - utils.batch_to_device(batch, self._device) - - current_loss = self._loss_step(batch) - val_losses.append(current_loss.item()) - - pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{idx}|Loss: {current_loss.item()}" - ) - - log_dict = { - "val_loss": current_loss.item(), - } - self._metric_logger.log_dict( - log_dict, - step=(curr_epoch + 1) * idx, - ) - - if self.run_validation: - self._metric_logger.log_dict( - { - "avg_val_loss": sum(val_losses) / len(val_losses), - "epoch": curr_epoch + 1, - }, - step=self.global_step, - ) + if ( + self.run_validation + and self.global_step % self.run_val_every_n_steps == 0 + ): + self.validate(curr_epoch=curr_epoch) def cleanup(self) -> None: self._metric_logger.close() From 555b670512625801eb9284500ac9240633116693 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 11:25:43 +0000 Subject: [PATCH 04/11] Removes duplicate if --- recipes/lora_finetune_single_device.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 89a45b5d76..47ece6f3f5 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -336,6 +336,7 @@ def setup(self, cfg: DictConfig) -> None: ) # Setup the validation dataset + self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) self.run_validation = "dataset_validation" in cfg if self.run_validation: self._sampler_val, self._dataloader_val = self._setup_data( @@ -345,8 +346,6 @@ def setup(self, cfg: DictConfig) -> None: collate_fn=collate_name, ) - self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) - if self.run_validation: if self.run_val_every_n_steps is None: log.warning( f"""Validation is enabled but run_val_every_n_steps is not set. 
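The cadence handling that patches 03 and 04 converge on can be read as the small helper below; the function name and bare variables are illustrative, not part of the recipe.

def resolve_val_cadence(run_val_every_n_steps, steps_per_epoch):
    # Mirrors the setup logic above: default to once per epoch when no
    # cadence is configured, reject non-positive values otherwise.
    if run_val_every_n_steps is None:
        # The recipe logs a warning before falling back.
        return steps_per_epoch
    if run_val_every_n_steps < 1:
        raise ValueError("run_val_every_n_steps must be greater than 0.")
    return run_val_every_n_steps

The training loop then triggers validation on optimizer-step boundaries, i.e. when global_step % run_val_every_n_steps == 0.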
From 2ac741e44288d007c45109c40684537823ffff51 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 11:30:09 +0000 Subject: [PATCH 05/11] Removes always true check --- recipes/lora_finetune_single_device.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 47ece6f3f5..0b80e497b7 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -702,14 +702,13 @@ def validate(self, curr_epoch) -> None: step=(curr_epoch + 1) * idx + idx, ) - if self.run_validation: - self._metric_logger.log_dict( - { - "avg_val_loss": sum(val_losses) / len(val_losses), - "epoch": curr_epoch + 1, - }, - step=self.global_step, - ) + self._metric_logger.log_dict( + { + "avg_val_loss": sum(val_losses) / len(val_losses), + "epoch": curr_epoch + 1, + }, + step=self.global_step, + ) def train(self) -> None: """ From 61ac4f67157730ec3b440ccbde16e00394cb2bb1 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 11:38:44 +0000 Subject: [PATCH 06/11] Groups all validation config under the validation key --- recipes/lora_finetune_single_device.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 0b80e497b7..3bf1c86878 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -336,11 +336,12 @@ def setup(self, cfg: DictConfig) -> None: ) # Setup the validation dataset - self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) - self.run_validation = "dataset_validation" in cfg + validation_config = cfg.get("validation") + self.run_validation = validation_config is not None + self.run_val_every_n_steps = validation_config.get("run_every_n_steps") if self.run_validation: self._sampler_val, self._dataloader_val = self._setup_data( - cfg_dataset=cfg.dataset_validation, + cfg_dataset=validation_config.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, collate_fn=collate_name, @@ -355,7 +356,9 @@ def setup(self, cfg: DictConfig) -> None: elif self.run_val_every_n_steps < 1: raise ValueError("run_val_every_n_steps must be greater than 0.") - self.max_validation_batches = cfg.get("max_validation_batches", -1) + self.max_validation_batches = validation_config.get( + "max_validation_batches", -1 + ) # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) # if cfg is missing profiler key or if `cfg.profiler.enabled = False From 9ba88328958abf909cb9ab436dbc99ce403d566c Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 11:51:59 +0000 Subject: [PATCH 07/11] Validation run frequency expressed in epochs instead of steps to user --- recipes/lora_finetune_single_device.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 3bf1c86878..24f8198045 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -338,7 +338,9 @@ def setup(self, cfg: DictConfig) -> None: # Setup the validation dataset validation_config = cfg.get("validation") self.run_validation = validation_config is not None - self.run_val_every_n_steps = validation_config.get("run_every_n_steps") + self.run_val_every_n_steps = int( + self._steps_per_epoch * validation_config.get("run_every_n_epochs") + ) if self.run_validation: 
self._sampler_val, self._dataloader_val = self._setup_data( cfg_dataset=validation_config.dataset, From 1c35995a68e4475f96b23190db6e9b4eb8c4bc2b Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 12:40:51 +0000 Subject: [PATCH 08/11] Fixes wrong arguments behaviours --- recipes/lora_finetune_single_device.py | 28 ++++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 24f8198045..a784af4627 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -338,8 +338,13 @@ def setup(self, cfg: DictConfig) -> None: # Setup the validation dataset validation_config = cfg.get("validation") self.run_validation = validation_config is not None + actual_steps_per_epoch = ( + self._steps_per_epoch + if self.max_steps_per_epoch is None + else min(self._steps_per_epoch, self.max_steps_per_epoch) + ) self.run_val_every_n_steps = int( - self._steps_per_epoch * validation_config.get("run_every_n_epochs") + actual_steps_per_epoch * validation_config.get("run_every_n_epochs", 0) ) if self.run_validation: self._sampler_val, self._dataloader_val = self._setup_data( @@ -358,9 +363,7 @@ def setup(self, cfg: DictConfig) -> None: elif self.run_val_every_n_steps < 1: raise ValueError("run_val_every_n_steps must be greater than 0.") - self.max_validation_batches = validation_config.get( - "max_validation_batches", -1 - ) + self.max_validation_batches = validation_config.get("max_batches", -1) # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) # if cfg is missing profiler key or if `cfg.profiler.enabled = False @@ -680,7 +683,7 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: return loss def validate(self, curr_epoch) -> None: - pbar = tqdm(total=max(len(self._dataloader_val), self.max_validation_batches)) + pbar = tqdm(total=min(len(self._dataloader_val), self.max_validation_batches)) val_losses = [] for idx, batch in enumerate(self._dataloader_val): if ( @@ -710,7 +713,6 @@ def validate(self, curr_epoch) -> None: self._metric_logger.log_dict( { "avg_val_loss": sum(val_losses) / len(val_losses), - "epoch": curr_epoch + 1, }, step=self.global_step, ) @@ -737,7 +739,7 @@ def train(self) -> None: # in case shuffle is True self._sampler.set_epoch(curr_epoch) - pbar = tqdm(total=self._steps_per_epoch) + pbar = tqdm(total=self._steps_per_epoch, position=0) for idx, batch in enumerate(self._dataloader): if ( self.max_steps_per_epoch is not None @@ -816,6 +818,12 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + if ( + self.run_validation + and self.global_step % self.run_val_every_n_steps == 0 + ): + self.validate(curr_epoch=curr_epoch) + # Stop tracking CUDA memory now that active steps are complete if ( curr_epoch == 0 @@ -842,12 +850,6 @@ def train(self) -> None: ) ) - if ( - self.run_validation - and self.global_step % self.run_val_every_n_steps == 0 - ): - self.validate(curr_epoch=curr_epoch) - def cleanup(self) -> None: self._metric_logger.close() From a51c0ff3b35059e383558b874138ba863a430294 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 12:52:53 +0000 Subject: [PATCH 09/11] Fixes indexing issues and makes tqdm more readable --- recipes/lora_finetune_single_device.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 
a784af4627..712209f310 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -683,7 +683,9 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: return loss def validate(self, curr_epoch) -> None: - pbar = tqdm(total=min(len(self._dataloader_val), self.max_validation_batches)) + pbar = tqdm( + total=min(len(self._dataloader_val), self.max_validation_batches - 1) + ) val_losses = [] for idx, batch in enumerate(self._dataloader_val): if ( @@ -698,7 +700,7 @@ def validate(self, curr_epoch) -> None: pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{idx}| Validation Loss: {current_loss.item()}" + f"{curr_epoch + 1}|{idx}|Validation Loss: {current_loss.item()}" ) # This bit allows to see the loss for each batch. Not sure about step indexing. @@ -788,7 +790,7 @@ def train(self) -> None: loss_to_log = running_loss.item() / num_tokens pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Training Loss: {loss_to_log}" ) # Log per-step metrics From 05efbcb1cd6dfc4d92d1461da5ce4b02256fc9ed Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 18:26:20 +0000 Subject: [PATCH 10/11] Adds validation set configuration to 1B_lora_single_device --- recipes/configs/llama3_2/1B_lora_single_device.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index 72c03f55d9..fff39d461f 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -52,6 +52,13 @@ seed: null shuffle: True batch_size: 4 +validation: + dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False + run_every_n_epochs: .5 + max_batches: 5 + # Optimizer and Scheduler optimizer: _component_: torch.optim.AdamW From 073af720280e120c668b2c5abf651250a0054d61 Mon Sep 17 00:00:00 2001 From: Massimo Frasson Date: Sun, 26 Jan 2025 18:33:09 +0000 Subject: [PATCH 11/11] Removes logging val loss for all batches --- recipes/lora_finetune_single_device.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 712209f310..964f75fbd5 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -703,15 +703,6 @@ def validate(self, curr_epoch) -> None: f"{curr_epoch + 1}|{idx}|Validation Loss: {current_loss.item()}" ) - # This bit allows to see the loss for each batch. Not sure about step indexing. - log_dict = { - "val_loss": current_loss.item(), - } - self._metric_logger.log_dict( - log_dict, - step=(curr_epoch + 1) * idx + idx, - ) - self._metric_logger.log_dict( { "avg_val_loss": sum(val_losses) / len(val_losses),
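Taken together with the validation block added to 1B_lora_single_device.yaml in patch 10, the setup code from patch 08 resolves the epoch-based cadence into optimizer steps once at startup. The sketch below works that arithmetic through with assumed numbers; the dataset size and gradient_accumulation_steps are illustrative, and dividing the dataloader length by gradient accumulation follows the recipe's usual steps_per_epoch computation (an assumption, since that line is not shown in the diffs above).

# Assumed values, for illustration only.
train_batches_per_epoch = 12_940     # len(dataloader): ~51.7k examples at batch_size 4
gradient_accumulation_steps = 8      # assumption, not taken from the config shown
max_steps_per_epoch = None

steps_per_epoch = train_batches_per_epoch // gradient_accumulation_steps    # 1617
actual_steps_per_epoch = (
    steps_per_epoch
    if max_steps_per_epoch is None
    else min(steps_per_epoch, max_steps_per_epoch)
)

run_every_n_epochs = 0.5             # validation.run_every_n_epochs in the YAML above
run_val_every_n_steps = int(actual_steps_per_epoch * run_every_n_epochs)    # 808

# With max_batches: 5, each validation pass evaluates batches 0..3 (the loop
# breaks at idx == max_batches - 1), so four batches run every 808 optimizer
# steps and their mean is logged as avg_val_loss at the current global_step.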