
Commit 5e0ce19

Wsd scheduler (#67)
* add scheduler and wip tests
* wip changes
* fix tests and format
* wip callback changes
* fixes and updates
* format
* wip changes
* fix failing tests
* refactoring
* updates
* updates
* final modifications
1 parent 6630924 commit 5e0ce19

File tree

7 files changed (+1510 / -6 lines)


docs/schedulers.rst

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,12 @@ Implemented Schedulers
 
     .. automethod:: __init__
 
+.. autoclass:: pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler
+    :show-inheritance:
+    :members:
+
+    .. automethod:: __init__
+
 
 Base Schedulers
 =======================

pytorch_accelerated/callbacks.py

Lines changed: 171 additions & 3 deletions
@@ -1,9 +1,13 @@
 # Copyright © 2021 Chris Hughes
+from datetime import datetime
 import inspect
 import logging
 import sys
 import time
 from abc import ABC
+from pathlib import Path
+from typing import Optional, Union
+
 
 import numpy as np
 import torch
@@ -341,7 +345,7 @@ def __init__(
         reset_on_train: bool = True,
         save_optimizer: bool = True,
         save_scheduler: bool = True,
-        load_saved_checkpoint: bool = True
+        load_saved_checkpoint: bool = True,
     ):
         """
@@ -522,6 +526,7 @@ def on_training_run_start(self, trainer, **kwargs):
     def on_evaluation_run_start(self, trainer, **kwargs):
         self._move_modules_to_device(trainer)
 
+
 class LimitBatchesCallback(TrainerCallback):
     """
     A callback that limits the number of batches used during training and evaluation.
@@ -679,9 +684,9 @@ class LimitEvalStepsCallback(TrainerCallback):
     :param limit_intermediate_only: whether to limit the number of intermediate evaluations only
 
     .. Note::
-        When used together this callback should be placed before :class:`StepBasedEvaluationCallback` and 
+        When used together this callback should be placed before :class:`StepBasedEvaluationCallback` and
         :class:`ProgressBarCallback` in the list of callbacks.
-        
+
     """
 
     def __init__(self, num_eval_steps: int, limit_intermediate_only=True):
@@ -705,3 +710,166 @@ def on_eval_epoch_start(self, trainer, is_intermediate=False, **kwargs):
     def on_eval_epoch_end(self, trainer, is_intermediate=False, **kwargs):
         if is_intermediate or not self.limit_intermediate_only:
            trainer._eval_dataloader = self._original_eval_dataloader
+
+
+class WSDCheckpointCallback(TrainerCallback):
+    """Manages checkpointing for WSD and WSD-S learning rate schedules.
+
+    This callback saves both pre-decay and post-decay checkpoints during training with WSD-style
+    schedules and automatically syncs with :class:`~pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler` for checkpoint timing.
+
+    For single checkpoint configurations:
+        - Pre-decay checkpoint is saved just before learning rate decay starts
+        - Post-decay checkpoint is saved at the end of training
+
+    For multiple checkpoints:
+        - Pre-decay checkpoint saved before each decay phase
+        - Post-decay checkpoint saved after each decay phase
+
+    For WSD vs WSD-S:
+        - WSD resumes from pre-decay checkpoints (discarding decay progress)
+        - WSD-S resumes from post-decay checkpoints (preserving decay progress)
+
+    :param save_dir: Directory to save checkpoints
+    :type save_dir: str
+    :param save_optimizer: Whether to save optimizer state
+    :type save_optimizer: bool
+    :param save_scheduler: Whether to save scheduler state
+    :type save_scheduler: bool
+    :param initial_checkpoint: Path to checkpoint to load at start of training. For WSD-S,
+        use post-decay checkpoint. For WSD, use pre-decay checkpoint.
+    :type initial_checkpoint: Union[str, Path], optional
+
+    :raises ValueError: If trainer's scheduler doesn't implement get_checkpoint_steps()
+
+    Example:
+        WSD-S usage:
+        >>> callback = WSDCheckpointCallback(
+        ...     save_dir="checkpoints",
+        ...     initial_checkpoint="checkpoint_50000_post_decay.pt"
+        ... )
+
+        WSD usage:
+        >>> callback = WSDCheckpointCallback(
+        ...     save_dir="checkpoints",
+        ...     initial_checkpoint="checkpoint_45000_pre_decay.pt"
+        ... )
+    """
+
+    def __init__(
+        self,
+        save_dir: str = "checkpoints",
+        save_optimizer: bool = True,
+        save_scheduler: bool = True,
+        initial_checkpoint: Optional[Union[str, Path]] = None,
+    ):
+        """
+        :param save_dir: Directory to save checkpoints
+        :param save_optimizer: Whether to save optimizer state
+        :param save_scheduler: Whether to save scheduler state
+        :param initial_checkpoint: Optional path to checkpoint to load at start of training
+        """
+        self.save_dir = Path(save_dir)
+        self.save_optimizer = save_optimizer
+        self.save_scheduler = save_scheduler
+        self.initial_checkpoint = (
+            Path(initial_checkpoint) if initial_checkpoint else None
+        )
+
+        # Tracking state
+        self.last_checkpoint_step = None
+        self.checkpoint_steps = None
+        self.decay_fraction = None
+        self.decay_info = None
+
+        # Create save directory if it doesn't exist
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+    def _get_checkpoint_path(self, step: int, checkpoint_type: str) -> Path:
+        return self.save_dir / f"checkpoint_{step}_{checkpoint_type}.pt"
+
+    def _save_checkpoint(self, trainer, step: int, checkpoint_type: str):
+        checkpoint_path = self._get_checkpoint_path(step, checkpoint_type)
+        trainer.save_checkpoint(
+            checkpoint_path,
+            save_optimizer=self.save_optimizer,
+            save_scheduler=self.save_scheduler,
+            checkpoint_kwargs={
+                "step": step,
+                "checkpoint_type": checkpoint_type,
+                "total_steps": trainer.run_config.max_num_train_steps,
+                "decay_fraction": self.decay_fraction,
+                "timestamp": datetime.now().isoformat(),
+            },
+        )
+        trainer.print(f"\nSaved {checkpoint_type} checkpoint at step {step}")
+
+    def on_training_run_start(self, trainer, **kwargs):
+        """Initialize checkpoint tracking state and load initial checkpoint if specified."""
+        if not hasattr(trainer.scheduler, "get_checkpoint_steps"):
+            raise ValueError(
+                "Scheduler must implement get_checkpoint_steps(). "
+                "Are you using WSDLrScheduler?"
+            )
+
+        # Get checkpoint steps and decay info from scheduler
+        self.checkpoint_steps = set(trainer.scheduler.get_checkpoint_steps())
+        self.decay_fraction = trainer.scheduler.decay_phase_ratio
+        self.decay_info = trainer.scheduler.get_decay_info()
+
+        # Load initial checkpoint if specified
+        if self.initial_checkpoint and self.initial_checkpoint.exists():
+            trainer.print(f"\nLoading checkpoint from {self.initial_checkpoint}")
+            checkpoint = trainer.load_checkpoint(self.initial_checkpoint)
+            self.last_checkpoint_step = checkpoint.get("step")
+            trainer.print(
+                f"Loaded {checkpoint['checkpoint_type']} checkpoint from step {self.last_checkpoint_step}"
+            )
+
+    def on_train_step_end(self, trainer, step: int, **kwargs):
+        """Handle checkpoint saving and progress logging"""
+
+        # Calculate global step accounting for distributed training and gradient accumulation
+        total_steps = (
+            (trainer.run_history.current_epoch - 1)
+            * trainer.run_config.num_update_steps_per_epoch
+            + step // trainer.run_config.gradient_accumulation_steps
+        )
+
+        # Skip if we've already saved at this step
+        if total_steps == self.last_checkpoint_step:
+            return
+
+        # Get current phase info from scheduler
+        phase_info = trainer.scheduler.get_phase_info(total_steps)
+        pre_decay_step = phase_info["pre_decay_step"]
+        period_end = phase_info["period_end"]
+
+        # Save pre-decay checkpoint when entering decay phase
+        if total_steps == pre_decay_step:
+            trainer.print(
+                f"\nWSD Lr Scheduler entering decay phase at step {total_steps}"
+            )
+            self._save_checkpoint(trainer, total_steps, "wsd_pre_decay")
+            self.last_checkpoint_step = total_steps
+
+        # If we've completed the decay phase
+        elif total_steps == period_end:
+            self._save_checkpoint(trainer, total_steps, "wsd_post_decay")
+            self.last_checkpoint_step = total_steps
+
+    def on_training_run_end(self, trainer, **kwargs):
+        """Save final checkpoint if we haven't already"""
+        # Get the final step number
+        total_steps = trainer.run_config.max_num_train_steps
+
+        # If we haven't saved the final checkpoint yet
+        if self.last_checkpoint_step != total_steps:
+            # Get final phase info
+            phase_info = trainer.scheduler.get_phase_info(total_steps)
+            period_end = phase_info["period_end"]
+
+            # Verify this is actually the end of a period
+            if total_steps == period_end:
+                self._save_checkpoint(trainer, total_steps, "wsd_post_decay")
+                self.last_checkpoint_step = total_steps
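
For context, here is a minimal, hypothetical usage sketch (not part of this commit) showing how the new callback could be wired together with the WSD scheduler on a pytorch-accelerated Trainer. It assumes WSDLrScheduler exposes a create_scheduler_fn helper like the library's other schedulers and accepts a decay_phase_ratio argument matching the attribute the callback reads; the toy dataset and model exist only to make the sketch self-contained.

# Hypothetical usage sketch: pairing WSDCheckpointCallback with WSDLrScheduler.
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset

from pytorch_accelerated.callbacks import WSDCheckpointCallback
from pytorch_accelerated.schedulers.wsd_scheduler import WSDLrScheduler
from pytorch_accelerated.trainer import DEFAULT_CALLBACKS, Trainer

# Toy data and model, just to make the sketch runnable
train_dataset = TensorDataset(torch.randn(512, 10), torch.randint(0, 2, (512,)))
model = nn.Linear(10, 2)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

trainer = Trainer(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    callbacks=(
        # Saves wsd_pre_decay / wsd_post_decay checkpoints at the steps the
        # scheduler reports via get_checkpoint_steps(); pass initial_checkpoint
        # here to resume (post-decay for WSD-S, pre-decay for WSD).
        WSDCheckpointCallback(save_dir="checkpoints"),
        *DEFAULT_CALLBACKS,
    ),
)

trainer.train(
    train_dataset=train_dataset,
    num_epochs=10,
    per_device_batch_size=32,
    # create_scheduler_fn and the decay_phase_ratio argument name are assumed,
    # mirroring the library's other schedulers and the attribute the callback reads.
    create_scheduler_fn=WSDLrScheduler.create_scheduler_fn(decay_phase_ratio=0.1),
)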
