# Copyright © 2021 Chris Hughes
+ from datetime import datetime
import inspect
import logging
import sys
import time
from abc import ABC
+ from pathlib import Path
+ from typing import Optional, Union
+

import numpy as np
import torch
@@ -341,7 +345,7 @@ def __init__(
        reset_on_train: bool = True,
        save_optimizer: bool = True,
        save_scheduler: bool = True,
-         load_saved_checkpoint: bool = True
+         load_saved_checkpoint: bool = True,
    ):
        """

@@ -522,6 +526,7 @@ def on_training_run_start(self, trainer, **kwargs):
    def on_evaluation_run_start(self, trainer, **kwargs):
        self._move_modules_to_device(trainer)

+
class LimitBatchesCallback(TrainerCallback):
    """
    A callback that limits the number of batches used during training and evaluation.
@@ -679,9 +684,9 @@ class LimitEvalStepsCallback(TrainerCallback):
    :param limit_intermediate_only: whether to limit the number of intermediate evaluations only

    .. Note::
-         When used together this callback should be placed before :class:`StepBasedEvaluationCallback` and
+         When used together this callback should be placed before :class:`StepBasedEvaluationCallback` and
        :class:`ProgressBarCallback` in the list of callbacks.
-
+

    """

    def __init__(self, num_eval_steps: int, limit_intermediate_only=True):
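To make the ordering note above concrete, here is a small sketch of a callback list that satisfies it; only the relative order is the point. The Trainer's `callbacks` keyword and the constructor arguments of the other callbacks are assumptions for illustration, not taken from this diff.

# Editorial sketch: LimitEvalStepsCallback placed ahead of the evaluation and
# progress-bar callbacks, as required by the note above. `step_based_eval` is a
# hypothetical, already-constructed StepBasedEvaluationCallback instance.
callbacks = [
    LimitEvalStepsCallback(num_eval_steps=50),
    step_based_eval,
    ProgressBarCallback(),
]
trainer = Trainer(
    model=model, loss_func=loss_func, optimizer=optimizer, callbacks=callbacks
)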
@@ -705,3 +710,166 @@ def on_eval_epoch_start(self, trainer, is_intermediate=False, **kwargs):
    def on_eval_epoch_end(self, trainer, is_intermediate=False, **kwargs):
        if is_intermediate or not self.limit_intermediate_only:
            trainer._eval_dataloader = self._original_eval_dataloader
+
+
+ class WSDCheckpointCallback(TrainerCallback):
+     """Manages checkpointing for WSD and WSD-S learning rate schedules.
+
+     This callback saves both pre-decay and post-decay checkpoints during training with WSD-style
+     schedules and automatically syncs with :class:`~pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler` for checkpoint timing.
+
+     For single checkpoint configurations:
+     - Pre-decay checkpoint is saved just before learning rate decay starts
+     - Post-decay checkpoint is saved at the end of training
+
+     For multiple checkpoints:
+     - Pre-decay checkpoint saved before each decay phase
+     - Post-decay checkpoint saved after each decay phase
+
+     For WSD vs WSD-S:
+     - WSD resumes from pre-decay checkpoints (discarding decay progress)
+     - WSD-S resumes from post-decay checkpoints (preserving decay progress)
+
+     :param save_dir: Directory to save checkpoints
+     :type save_dir: str
+     :param save_optimizer: Whether to save optimizer state
+     :type save_optimizer: bool
+     :param save_scheduler: Whether to save scheduler state
+     :type save_scheduler: bool
+     :param initial_checkpoint: Path to checkpoint to load at start of training. For WSD-S,
+         use post-decay checkpoint. For WSD, use pre-decay checkpoint.
+     :type initial_checkpoint: Union[str, Path], optional
+
+     :raises ValueError: If trainer's scheduler doesn't implement get_checkpoint_steps()
+
+     Example:
+         WSD-S usage:
+         >>> callback = WSDCheckpointCallback(
+         ...     save_dir="checkpoints",
+         ...     initial_checkpoint="checkpoint_50000_post_decay.pt"
+         ... )
+
+         WSD usage:
+         >>> callback = WSDCheckpointCallback(
+         ...     save_dir="checkpoints",
+         ...     initial_checkpoint="checkpoint_45000_pre_decay.pt"
+         ... )
+     """
+
+     def __init__(
+         self,
+         save_dir: str = "checkpoints",
+         save_optimizer: bool = True,
+         save_scheduler: bool = True,
+         initial_checkpoint: Optional[Union[str, Path]] = None,
+     ):
+         """
+         :param save_dir: Directory to save checkpoints
+         :param save_optimizer: Whether to save optimizer state
+         :param save_scheduler: Whether to save scheduler state
+         :param initial_checkpoint: Optional path to checkpoint to load at start of training
+         """
+         self.save_dir = Path(save_dir)
+         self.save_optimizer = save_optimizer
+         self.save_scheduler = save_scheduler
+         self.initial_checkpoint = (
+             Path(initial_checkpoint) if initial_checkpoint else None
+         )
+
+         # Tracking state
+         self.last_checkpoint_step = None
+         self.checkpoint_steps = None
+         self.decay_fraction = None
+         self.decay_info = None
+
+         # Create save directory if it doesn't exist
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+     def _get_checkpoint_path(self, step: int, checkpoint_type: str) -> Path:
+         return self.save_dir / f"checkpoint_{step}_{checkpoint_type}.pt"
+
+     def _save_checkpoint(self, trainer, step: int, checkpoint_type: str):
+         checkpoint_path = self._get_checkpoint_path(step, checkpoint_type)
+         trainer.save_checkpoint(
+             checkpoint_path,
+             save_optimizer=self.save_optimizer,
+             save_scheduler=self.save_scheduler,
+             checkpoint_kwargs={
+                 "step": step,
+                 "checkpoint_type": checkpoint_type,
+                 "total_steps": trainer.run_config.max_num_train_steps,
+                 "decay_fraction": self.decay_fraction,
+                 "timestamp": datetime.now().isoformat(),
+             },
+         )
+         trainer.print(f"\nSaved {checkpoint_type} checkpoint at step {step}")
+
+     def on_training_run_start(self, trainer, **kwargs):
+         """Initialize checkpoint tracking state and load initial checkpoint if specified."""
+         if not hasattr(trainer.scheduler, "get_checkpoint_steps"):
+             raise ValueError(
+                 "Scheduler must implement get_checkpoint_steps(). "
+                 "Are you using WSDLrScheduler?"
+             )
+
+         # Get checkpoint steps and decay info from scheduler
+         self.checkpoint_steps = set(trainer.scheduler.get_checkpoint_steps())
+         self.decay_fraction = trainer.scheduler.decay_phase_ratio
+         self.decay_info = trainer.scheduler.get_decay_info()
+
+         # Load initial checkpoint if specified
+         if self.initial_checkpoint and self.initial_checkpoint.exists():
+             trainer.print(f"\nLoading checkpoint from {self.initial_checkpoint}")
+             checkpoint = trainer.load_checkpoint(self.initial_checkpoint)
+             self.last_checkpoint_step = checkpoint.get("step")
+             trainer.print(
+                 f"Loaded {checkpoint['checkpoint_type']} checkpoint from step {self.last_checkpoint_step}"
+             )
+
+     def on_train_step_end(self, trainer, step: int, **kwargs):
+         """Handle checkpoint saving and progress logging"""
+
+         # Calculate global step accounting for distributed training and gradient accumulation
+         total_steps = (
+             (trainer.run_history.current_epoch - 1)
+             * trainer.run_config.num_update_steps_per_epoch
+             + step // trainer.run_config.gradient_accumulation_steps
+         )
+
+         # Skip if we've already saved at this step
+         if total_steps == self.last_checkpoint_step:
+             return
+
+         # Get current phase info from scheduler
+         phase_info = trainer.scheduler.get_phase_info(total_steps)
+         pre_decay_step = phase_info["pre_decay_step"]
+         period_end = phase_info["period_end"]
+
+         # Save pre-decay checkpoint when entering decay phase
+         if total_steps == pre_decay_step:
+             trainer.print(
+                 f"\nWSD Lr Scheduler entering decay phase at step {total_steps}"
+             )
+             self._save_checkpoint(trainer, total_steps, "wsd_pre_decay")
+             self.last_checkpoint_step = total_steps
+
+         # If we've completed the decay phase
+         elif total_steps == period_end:
+             self._save_checkpoint(trainer, total_steps, "wsd_post_decay")
+             self.last_checkpoint_step = total_steps
+
+     def on_training_run_end(self, trainer, **kwargs):
+         """Save final checkpoint if we haven't already"""
+         # Get the final step number
+         total_steps = trainer.run_config.max_num_train_steps
+
+         # If we haven't saved the final checkpoint yet
+         if self.last_checkpoint_step != total_steps:
+             # Get final phase info
+             phase_info = trainer.scheduler.get_phase_info(total_steps)
+             period_end = phase_info["period_end"]
+
+             # Verify this is actually the end of a period
+             if total_steps == period_end:
+                 self._save_checkpoint(trainer, total_steps, "wsd_post_decay")
+                 self.last_checkpoint_step = total_steps
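For orientation, a minimal end-to-end sketch of how the new callback might be wired into a run follows. The Trainer construction and trainer.train arguments mirror the library's usual quickstart pattern; the toy model and datasets, and the existence of a create_scheduler_fn helper on WSDLrScheduler, are assumptions to verify against the actual API.

# Editorial usage sketch, not part of this commit. Assumes Trainer accepts a
# `callbacks` argument and that WSDLrScheduler exposes a `create_scheduler_fn`
# helper like the library's other schedulers; in a real run you would likely
# also keep the library's default callbacks in the list.
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import WSDCheckpointCallback
from pytorch_accelerated.schedulers.wsd_scheduler import WSDLrScheduler

model = nn.Linear(10, 2)  # toy model for illustration
train_dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
eval_dataset = TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,)))

trainer = Trainer(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=optim.AdamW(model.parameters(), lr=3e-4),
    callbacks=[
        WSDCheckpointCallback(
            save_dir="checkpoints",
            # WSD-S resumes from a post-decay checkpoint; WSD would use a pre-decay one.
            initial_checkpoint="checkpoints/checkpoint_50000_post_decay.pt",
        ),
    ],
)

trainer.train(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    num_epochs=10,
    per_device_batch_size=32,
    create_scheduler_fn=WSDLrScheduler.create_scheduler_fn(),  # assumed helper
)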