
Commit c9b8da4

update docs
1 parent ca5b792 commit c9b8da4


3 files changed: +39, -31 lines


docs/callbacks.rst

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,9 @@ Implemented Callbacks
 .. autoclass:: LimitEvalStepsCallback
 :show-inheritance:
 
+.. autoclass:: WSDCheckpointCallback
+:show-inheritance:
+
 
 Creating New Callbacks
 ========================

pytorch_accelerated/callbacks.py

Lines changed: 9 additions & 4 deletions
@@ -719,12 +719,12 @@ class WSDCheckpointCallback(TrainerCallback):
 schedules and automatically syncs with :class:`~pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler` for checkpoint timing.
 
 For single checkpoint configurations:
-- Pre-decay checkpoint is saved just before learning rate decay starts
-- Post-decay checkpoint is saved at the end of training
+- Pre-decay checkpoint is saved just before learning rate decay starts
+- Post-decay checkpoint is saved at the end of training
 
 For multiple checkpoints:
-- Pre-decay checkpoint saved before each decay phase
-- Post-decay checkpoint saved after each decay phase
+- Pre-decay checkpoint saved before each decay phase
+- Post-decay checkpoint saved after each decay phase
 
 For WSD vs WSD-S:
 - WSD resumes from pre-decay checkpoints (discarding decay progress)
@@ -743,6 +743,11 @@ class WSDCheckpointCallback(TrainerCallback):
 :raises ValueError: If trainer's scheduler doesn't implement get_checkpoint_steps()
 
 Example:
+No Checkpoint:
+>>> callback = WSDCheckpointCallback(
+... save_dir="checkpoints",
+... )
+
 WSD-S usage:
 >>> callback = WSDCheckpointCallback(
 ... save_dir="checkpoints",
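
For orientation, here is a minimal sketch of how the new callback might be attached to a training run. The Trainer constructor and its callbacks argument follow pytorch-accelerated's documented usage; the toy model, optimizer, and the commented-out train() call are illustrative assumptions added here, not part of this commit.

# Sketch only: wiring WSDCheckpointCallback into a pytorch-accelerated run.
# The toy model/optimizer and the commented train() call are illustrative
# assumptions; consult the scheduler docs for the exact scheduler hook-up.
from torch import nn, optim

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import WSDCheckpointCallback

model = nn.Linear(10, 2)  # placeholder model for illustration
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

trainer = Trainer(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    # In practice you would usually keep the library's default callbacks
    # alongside this one; only the new callback is shown here.
    callbacks=(WSDCheckpointCallback(save_dir="checkpoints"),),
)

# trainer.train(
#     train_dataset=train_dataset,   # assumed to exist
#     num_epochs=8,
#     per_device_batch_size=32,
#     create_scheduler_fn=...,       # build a WSDLrScheduler here
# )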

pytorch_accelerated/schedulers/wsd_scheduler.py

Lines changed: 27 additions & 27 deletions
@@ -10,64 +10,64 @@
 class WSDLrScheduler(StatefulSchedulerBase):
 """
 Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in
-'Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective'.
+`Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective <https://arxiv.org/abs/2410.05192>`_.
 
 The schedule has three phases:
-1. Warmup: Linear warmup from warmup_starting_lr to base learning rate
-2. Stable: Maintains constant high learning rate
-3. Decay: Rapidly decays learning rate before each checkpoint
+1. Warmup: Linear warmup from warmup_starting_lr to base learning rate
+2. Stable: Maintains constant high learning rate
+3. Decay: Rapidly decays learning rate before each checkpoint
 
 This scheduler is designed to create intermediate model checkpoints during training. Each checkpoint
 involves decaying the learning rate to get better model performance.
 
 Use multiple checkpoints (typically 2-3) if:
-- Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
-- You want to evaluate model performance vs training data size (e.g., does your model need full training?)
-- You might need to continue training later but want flexibility about when to stop training
+- Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
+- You want to evaluate model performance vs training data size (e.g., does your model need full training?)
+- You might need to continue training later but want flexibility about when to stop training
 
 The scheduler uses geometric progression to space checkpoints evenly on a log scale:
-- First checkpoint is placed at 25% of total steps
-- Each subsequent checkpoint is ~2x steps from previous checkpoint
+- First checkpoint is placed at 25% of total steps
+- Each subsequent checkpoint is ~2x steps from previous checkpoint
 
 Examples:
 - 2 checkpoints for 100K steps: [50K, 100K]
 - 3 checkpoints for 200K steps: [50K, 100K, 200K]
 - 4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]
 
 For each checkpoint:
-- The stable phase continues until decay_phase_ratio portion of steps remain
-- Then learning rate decays to lr_min * base_lr using selected decay formula
+- The stable phase continues until decay_phase_ratio portion of steps remain
+- Then learning rate decays to lr_min * base_lr using selected decay formula
 
 Two decay formulas are provided:
 
 1. Inverse Proportional Decay (paper's formula):
 lr = 1 / (t * (1/lr_min - 1) + 1)
-- Derived from theoretical analysis on quadratic functions
-- Steeper initial decay, more gradual approach to lr_min
-- Optimal for quadratic loss landscapes
+- Derived from theoretical analysis on quadratic functions
+- Steeper initial decay, more gradual approach to lr_min
+- Optimal for quadratic loss landscapes
 
 2. Sqrt Decay:
 lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
-- Similar to traditional cosine decay patterns
-- More gradual initial decay, consistent decay rate
-- May be more robust across different architectures
+- Similar to traditional cosine decay patterns
+- More gradual initial decay, consistent decay rate
+- May be more robust across different architectures
 
 Continuation Behavior:
-- Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
-- When continuing, scheduler starts a fresh stable phase with new total_steps
-- Decay phase ratio applies to new training length
-- No warmup is applied during continuation
-- State must be loaded via load_state_dict for continuation to work
+- Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
+- When continuing, scheduler starts a fresh stable phase with new total_steps
+- Decay phase ratio applies to new training length
+- No warmup is applied during continuation
+- State must be loaded via load_state_dict for continuation to work
 
 Example:
 Initial run (1000 steps, 0.1 decay ratio):
-- Steps 0-50: Optional warmup
-- Steps 50-900: Stable high learning rate
-- Steps 900-1000: Decay to lr_min
+- Steps 0-50: Optional warmup
+- Steps 50-900: Stable high learning rate
+- Steps 900-1000: Decay to lr_min
 
 Continuation (500 new steps, 0.1 decay ratio):
-- Steps 0-450: Stable high learning rate
-- Steps 450-500: Decay to lr_min
+- Steps 0-450: Stable high learning rate
+- Steps 450-500: Decay to lr_min
 
 .. Note:: This scheduler is designed to be used with the :class:`~pytorch_accelerated.callbacks.WSDCheckpointCallback` class,
 which handles saving and loading checkpoints.
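
To make the two decay formulas and the checkpoint spacing described in the docstring above concrete, here is a small self-contained sketch. It is not the library's implementation; the function names and the halving rule for checkpoint placement are assumptions that simply reproduce the formulas and the example spacings quoted in the docstring.

# Standalone sketch of the docstring's two decay formulas and checkpoint
# spacing examples; not the library's actual implementation.

def inverse_proportional_decay(t: float, lr_min: float) -> float:
    # Paper's formula: lr = 1 / (t * (1/lr_min - 1) + 1), where t in [0, 1]
    # is the fraction of the decay phase completed and lr_min is a fraction
    # of the base learning rate. Steep initial drop, gradual approach to lr_min.
    return 1.0 / (t * (1.0 / lr_min - 1.0) + 1.0)


def sqrt_decay(t: float, lr_min: float) -> float:
    # lr = lr_min + (1 - lr_min) * (1 - sqrt(t)): more gradual initial decay.
    return lr_min + (1.0 - lr_min) * (1.0 - t ** 0.5)


def checkpoint_steps(total_steps: int, num_checkpoints: int) -> list[int]:
    # One way to reproduce the spacing in the docstring examples: the last
    # checkpoint lands on total_steps and each earlier one halves the step
    # count (e.g. 4 checkpoints over 200K steps -> [25K, 50K, 100K, 200K]).
    return sorted(total_steps // (2 ** i) for i in range(num_checkpoints))


if __name__ == "__main__":
    base_lr, lr_min = 1e-3, 0.1
    for t in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(
            f"t={t:.2f}  "
            f"inverse: {base_lr * inverse_proportional_decay(t, lr_min):.2e}  "
            f"sqrt: {base_lr * sqrt_decay(t, lr_min):.2e}"
        )
    print(checkpoint_steps(200_000, 4))  # [25000, 50000, 100000, 200000]

Both multipliers start at 1.0 when the decay phase begins (t=0) and reach lr_min at its end (t=1); multiplying by the base learning rate gives the scheduled value.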
