class WSDLrScheduler(StatefulSchedulerBase):
    """
    Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in
    'Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective'.

    The schedule has three phases:
    1. Warmup: the learning rate is increased to its stable value (optional)
    2. Stable: the learning rate is held constant at its peak value
    3. Decay: the learning rate is decayed towards lr_min

    This scheduler is designed to create intermediate model checkpoints during training. Each checkpoint
    involves decaying the learning rate so that the saved model performs better at that point.

    Use multiple checkpoints (typically 2-3) if:
    - Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
    - You want to evaluate model performance vs training data size (e.g., does your model need full training?)
    - You might need to continue training later but want flexibility about when to stop training

    The scheduler uses geometric progression to space checkpoints evenly on a log scale:
    - First checkpoint is placed at 25% of total steps
    - Each subsequent checkpoint is ~2x steps from the previous checkpoint

    Examples (see also the sketch below):
    - 2 checkpoints for 100K steps: [50K, 100K]
    - 3 checkpoints for 200K steps: [50K, 100K, 200K]
    - 4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]
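
    As a minimal illustrative sketch (hypothetical helper, for intuition only; not necessarily the
    scheduler's internal code), a spacing consistent with the examples above halves the step count
    for each earlier checkpoint::

        def checkpoint_steps(total_steps, num_checkpoints):
            # Last checkpoint falls on the final step; each earlier checkpoint is half the next one.
            return [total_steps // 2 ** (num_checkpoints - 1 - i) for i in range(num_checkpoints)]

        checkpoint_steps(200_000, 4)  # [25000, 50000, 100000, 200000]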

    ...

       - Derived from theoretical analysis on quadratic functions
       - Steeper initial decay, more gradual approach to lr_min
       - Optimal for quadratic loss landscapes

    2. Sqrt Decay:
       lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
       - Similar to traditional cosine decay patterns
       - More gradual initial decay, consistent decay rate
       - May be more robust across different architectures

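    As a rough sketch of the sqrt decay above (assuming ``t`` is the fraction of the decay phase
    completed, in [0, 1], and ``lr_min`` is expressed as a fraction of the stable learning rate;
    the helper name is illustrative)::

        import math

        def sqrt_decay_factor(t, lr_min=0.0):
            # Multiplier applied to the stable learning rate at decay progress t.
            return lr_min + (1 - lr_min) * (1 - math.sqrt(t))

        sqrt_decay_factor(0.0)   # 1.0  -> decay starts at the full stable learning rate
        sqrt_decay_factor(0.25)  # 0.5
        sqrt_decay_factor(1.0)   # 0.0  -> decay ends at lr_min
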
    Continuation Behavior:
    - Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
    - When continuing, the scheduler starts a fresh stable phase with the new total_steps
    - The decay phase ratio applies to the new training length
    - No warmup is applied during continuation

    ...

    Initial run (1000 steps, 0.1 decay ratio):
    - Steps 0-50: Optional warmup
    - Steps 50-900: Stable high learning rate
    - Steps 900-1000: Decay to lr_min

    Continuation (500 new steps, 0.1 decay ratio):
    - Steps 0-450: Stable high learning rate
    - Steps 450-500: Decay to lr_min

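    A minimal sketch of how these boundaries follow from the decay ratio (hypothetical helper; warmup,
    when used, is handled separately as in the initial run above)::

        def decay_start_step(total_steps, decay_ratio=0.1):
            # The final decay_ratio fraction of training is spent decaying the learning rate.
            return int(total_steps * (1 - decay_ratio))

        decay_start_step(1000)  # 900 -> decay over steps 900-1000
        decay_start_step(500)   # 450 -> decay over steps 450-500 when continuing from a checkpoint
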
    .. Note:: This scheduler is designed to be used with the :class:`~pytorch_accelerated.callbacks.WSDCheckpointCallback` class,
        which handles saving and loading checkpoints.

    """

    def __init__(
        ...

    def create_scheduler_fn(
        ...,
        num_checkpoints: int = 1,
        is_continuation_from_checkpoint: bool = False,
    ) -> Callable:
374- """Creates a scheduler function that the trainer can use.
369+ """
370+ An alternative constructor which returns a function that accepts an optimizer and creates an instance of
371+ ``WSDLrScheduler``. This is primarily intended to be used with the :class:`~pytorch_accelerated.trainer.Trainer`
372+ as illustrated below::
373+
374+
375+ trainer = Trainer(
376+ ...,
377+ callbacks=[
378+ WSDCheckpointCallback(
379+ save_dir="checkpoints",
380+ initial_checkpoint="checkpoint_45000_pre_decay.pt",
381+ )
382+ ],)
383+
384+ trainer.train(
385+ train_dataset=train_dataset,
386+ num_epochs=num_epochs,
387+ per_device_batch_size=batch_size,
388+ create_scheduler_fn=CosineLrScheduler.WSDLrScheduler(is_continuation_from_checkpoint=True),
389+ )
375390
376- Returns a function that accepts an optimizer and creates an instance of WSDScheduler.
377- The trainer will replace TrainerPlaceholderValues with actual values at runtime.
391+ By default, the ``total_num_epochs`` and ``num_iterations_per_epoch`` arguments will be set by the
392+ :class:`~pytorch_accelerated. trainer.Trainer` with the correct values at runtime.
378393 """

        return partial(