diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index a6699e5..bc96950 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -200,15 +200,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda( num_cycles: float, min_lr_rate: float = 0.0, ): - if ( - warmup_outerstep is not None - and current_step > num_warmup_steps - and current_step % num_inner_steps < warmup_outerstep - ): - return 0 - if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) + + if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep: + return 0 + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) factor = factor * (1 - min_lr_rate) + min_lr_rate