Adds validation loss to LoRA fine tune single device #2238

Open · wants to merge 13 commits into base: main
7 changes: 7 additions & 0 deletions recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -52,6 +52,13 @@ seed: null
shuffle: True
batch_size: 4

validation:
  dataset:
    _component_: torchtune.datasets.alpaca_cleaned_dataset
    packed: False
  run_every_n_epochs: .5
  max_batches: 5

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
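For context, a rough illustration of how these validation keys appear to be interpreted by the recipe changes further down (assumed semantics; the step count is only an example):

# Assumed semantics, mirroring the recipe code below:
# run_every_n_epochs is the fraction of an epoch between validation passes,
# and max_batches caps how many validation batches are evaluated (-1 = no cap).
steps_per_epoch = 200        # example value for this illustration
run_every_n_epochs = 0.5     # from the config above
max_batches = 5              # from the config above

run_val_every_n_steps = int(steps_per_epoch * run_every_n_epochs)
print(run_val_every_n_steps)  # 100 -> validate every 100 optimizer steps
print(max_batches)            # each validation pass uses at most 5 batches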
68 changes: 66 additions & 2 deletions recipes/lora_finetune_single_device.py
@@ -335,6 +335,36 @@ def setup(self, cfg: DictConfig) -> None:
            last_epoch=self.global_step - 1,
        )

        # Setup the validation dataset
        validation_config = cfg.get("validation")
        self.run_validation = validation_config is not None
        if self.run_validation:
            actual_steps_per_epoch = (
                self._steps_per_epoch
                if self.max_steps_per_epoch is None
                else min(self._steps_per_epoch, self.max_steps_per_epoch)
            )
            run_every_n_epochs = validation_config.get("run_every_n_epochs", None)
            self.run_val_every_n_steps = (
                int(actual_steps_per_epoch * run_every_n_epochs)
                if run_every_n_epochs is not None
                else None
            )
            self._sampler_val, self._dataloader_val = self._setup_data(
                cfg_dataset=validation_config.dataset,
                shuffle=cfg.shuffle,
                batch_size=cfg.batch_size,
                collate_fn=collate_name,
            )

            if self.run_val_every_n_steps is None:
                log.warning(
                    "Validation is enabled but run_every_n_epochs is not set. "
                    f"Defaulting to once per epoch ({self._steps_per_epoch} steps)."
                )
                self.run_val_every_n_steps = self._steps_per_epoch
            elif self.run_val_every_n_steps < 1:
                raise ValueError("run_val_every_n_steps must be greater than 0.")

            self.max_validation_batches = validation_config.get("max_batches", -1)

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -652,6 +682,34 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:

        return loss

    def validate(self, curr_epoch: int) -> None:
Contributor:
  1. Do we have model.eval() somewhere?

Usually we want to set the model to .eval() mode, because some layers, like dropout, behave differently at evaluation time.

By doing that, we also require less memory, because we only need the forward pass, which allows us to use a higher batch_size --> faster validation step.

I am not sure about the implications for compile/FSDP. For example, compile will have to create a new graph that doesn't require grad, so compile time will increase. If the number of graph breaks increases, we may have to manually raise the threshold for the maximum number of graph breaks (there is an example of that in one of our RL recipes).

  2. Not all recipes have self._loss_step. We would have to standardize and make sure that they all do, but that requires a separate PR.

IMO, if you have access to >1 GPU, I would encourage you to implement it in lora_distributed with a QLoRA config, add .eval(), and run it:

  • with eval + compile + opt_in_bwd + activation ckpt + activation offloading
  • without eval + compile + opt_in_bwd + activation ckpt + activation offloading

If nothing breaks, I would feel more confident in approving it.

PS: we would also have to add model.train() back in the training loop.
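To make the suggestion concrete, here is a minimal, self-contained sketch of the eval/no-grad/train toggling (illustrative stand-in names, not this PR's code):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the recipe's model, loss, and validation dataloader.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.1), nn.Linear(16, 1))
loss_fn = nn.MSELoss()
val_loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8
)

def validate(model: nn.Module, loader: DataLoader) -> float:
    model.eval()  # disable dropout etc. for deterministic evaluation
    losses = []
    with torch.no_grad():  # skip autograd bookkeeping -> less memory, faster forward
        for inputs, targets in loader:
            losses.append(loss_fn(model(inputs), targets).item())
    model.train()  # restore training behavior before the next optimizer step
    return sum(losses) / len(losses)

print(f"avg_val_loss: {validate(model, val_loader):.4f}")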

Author:
@felipemello1 Thanks for the detailed breakdown and suggestions. Should we also unload the model being trained before loading the eval one? Having just one in memory would allow for bigger batch sizes.

That said, I’m currently constrained on time and not very familiar with the implementation details for this. If I were to take this on, it would likely take me a significant amount of time to get it done properly.

Would you be able to take the lead on this?

Contributor (@felipemello1), Jan 17, 2025:

Hey @MaxFrax, completely understandable. Thanks for sharing it.

I don't think I will have bandwidth soon, but if I do, this PR is a good start.

@Ankur-singh, cc'ing you in case you are looking for more issues to contribute to! :D

Thank you guys!

        """Run a validation pass over the validation set and log the average loss."""
        pbar = tqdm(
            total=(
                len(self._dataloader_val)
                if self.max_validation_batches < 0
                else min(len(self._dataloader_val), self.max_validation_batches)
            )
        )
        val_losses = []
        for idx, batch in enumerate(self._dataloader_val):
            if (
                self.max_validation_batches > 0
                and idx >= self.max_validation_batches
            ):
                break
            utils.batch_to_device(batch, self._device)

            current_loss = self._loss_step(batch)
            val_losses.append(current_loss.item())

            pbar.update(1)
            pbar.set_description(
                f"{curr_epoch + 1}|{idx}|Validation Loss: {current_loss.item()}"
            )

        self._metric_logger.log_dict(
            {
                "avg_val_loss": sum(val_losses) / len(val_losses),
            },
            step=self.global_step,
        )

    def train(self) -> None:
        """
        The core training loop.
@@ -674,7 +732,7 @@ def train(self) -> None:
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

-            pbar = tqdm(total=self._steps_per_epoch)
+            pbar = tqdm(total=self._steps_per_epoch, position=0)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
@@ -723,7 +781,7 @@
                    loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                        f"{curr_epoch + 1}|{self.global_step}|Training Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
@@ -753,6 +811,12 @@
                    num_tokens = 0
                    t0 = time.perf_counter()

                    if (
                        self.run_validation
                        and self.global_step % self.run_val_every_n_steps == 0
                    ):
                        self.validate(curr_epoch=curr_epoch)

                # Stop tracking CUDA memory now that active steps are complete
                if (
                    curr_epoch == 0