class WSDLrScheduler(StatefulSchedulerBase):
    """
    Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in
    `Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective <https://arxiv.org/abs/2410.05192>`_.

    The schedule has three phases:
    1. Warmup: Linear warmup from warmup_starting_lr to base learning rate
    2. Stable: Maintains constant high learning rate
    3. Decay: Rapidly decays learning rate before each checkpoint

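As a rough, illustrative sketch only (not the class's actual implementation), the three phases can be read as a piecewise function of the step count for a single-checkpoint run; the function itself and the names num_warmup_steps, base_lr and decay_fn are assumptions, while warmup_starting_lr, decay_phase_ratio and lr_min come from the docstring::

    def wsd_lr(step, total_steps, num_warmup_steps, warmup_starting_lr, base_lr,
               decay_phase_ratio=0.1, lr_min=0.1,
               decay_fn=lambda t, lr_min: 1 / (t * (1 / lr_min - 1) + 1)):
        decay_start = int(total_steps * (1 - decay_phase_ratio))
        if step < num_warmup_steps:
            # Warmup: linear ramp from warmup_starting_lr to base_lr
            frac = step / max(num_warmup_steps, 1)
            return warmup_starting_lr + frac * (base_lr - warmup_starting_lr)
        if step < decay_start:
            # Stable: hold the base learning rate
            return base_lr
        # Decay: shrink towards lr_min * base_lr over the final portion of steps
        t = (step - decay_start) / max(total_steps - decay_start, 1)
        return base_lr * decay_fn(t, lr_min)
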
    This scheduler is designed to create intermediate model checkpoints during training. Each checkpoint
    involves decaying the learning rate to get better model performance.

    Use multiple checkpoints (typically 2-3) if:
    - Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
    - You want to evaluate model performance vs training data size (e.g., does your model need full training?)
    - You might need to continue training later but want flexibility about when to stop training

    The scheduler uses a geometric progression to space checkpoints evenly on a log scale:
    - The final checkpoint is placed at the last training step
    - Each earlier checkpoint is placed at roughly half the steps of the one that follows it
      (equivalently, each checkpoint is ~2x steps from the previous checkpoint)

    Examples:
    - 2 checkpoints for 100K steps: [50K, 100K]
    - 3 checkpoints for 200K steps: [50K, 100K, 200K]
    - 4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]

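A minimal sketch of this halving rule, reproducing the examples above (checkpoint_steps is a hypothetical helper, not the scheduler's internal method)::

    def checkpoint_steps(total_steps, num_checkpoints):
        # Final checkpoint at the last step; each earlier one at half the
        # steps of the next, giving even spacing on a log scale.
        return sorted(total_steps // (2 ** i) for i in range(num_checkpoints))

    print(checkpoint_steps(100_000, 2))  # [50000, 100000]
    print(checkpoint_steps(200_000, 3))  # [50000, 100000, 200000]
    print(checkpoint_steps(200_000, 4))  # [25000, 50000, 100000, 200000]
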
    For each checkpoint:
    - The stable phase continues until only the decay_phase_ratio portion of the steps remains
    - Then the learning rate decays to lr_min * base_lr using the selected decay formula

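In code, the split point within one checkpoint's stretch of steps might look like the sketch below; the helper name and the assumption that decay_phase_ratio applies per checkpoint segment are mine, so verify against the implementation::

    def decay_start_step(previous_checkpoint_step, checkpoint_step, decay_phase_ratio):
        # Portion of this checkpoint's segment reserved for the decay phase
        segment_steps = checkpoint_step - previous_checkpoint_step
        decay_steps = int(segment_steps * decay_phase_ratio)
        return checkpoint_step - decay_steps

    # Single final checkpoint at step 1000 with a 0.1 ratio -> decay begins at step 900
    print(decay_start_step(0, 1000, 0.1))  # 900
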
    Two decay formulas are provided. In both, t is the fraction of the decay phase completed
    (0 at the start of decay, 1 at the checkpoint) and the result is a multiplier applied to
    the base learning rate:

    1. Inverse Proportional Decay (paper's formula):
       lr = 1 / (t * (1/lr_min - 1) + 1)
       - Derived from theoretical analysis on quadratic functions
       - Steeper initial decay, more gradual approach to lr_min
       - Optimal for quadratic loss landscapes

    2. Sqrt Decay:
       lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
       - Similar to traditional cosine decay patterns
       - More gradual initial decay, consistent decay rate
       - May be more robust across different architectures

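Both formulas are easy to sanity-check as plain Python functions (the function names below are illustrative, not the class's API); each returns 1.0 at t=0 and lr_min at t=1::

    import math

    def inverse_proportional_decay(t, lr_min):
        # Paper's formula: 1.0 at t=0, lr_min at t=1
        return 1 / (t * (1 / lr_min - 1) + 1)

    def sqrt_decay(t, lr_min):
        # 1.0 at t=0, lr_min at t=1
        return lr_min + (1 - lr_min) * (1 - math.sqrt(t))

    for t in (0.0, 0.5, 1.0):
        print(t, inverse_proportional_decay(t, 0.1), sqrt_decay(t, 0.1))
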
    Continuation Behavior:
    - Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
    - When continuing, the scheduler starts a fresh stable phase with the new total_steps
    - The decay phase ratio applies to the new training length
    - No warmup is applied during continuation
    - State must be loaded via load_state_dict for continuation to work

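A hedged continuation sketch: load_state_dict is named in the docstring above, but the import path, constructor arguments, and checkpoint layout shown here are assumptions used to illustrate the flow, not the library's confirmed API::

    import torch
    from torch.optim import SGD
    from pytorch_accelerated.schedulers import WSDLrScheduler  # import path assumed

    model = torch.nn.Linear(10, 2)
    optimizer = SGD(model.parameters(), lr=3e-4)

    # Fresh stable phase sized to the new run; argument names are assumptions
    # based on the parameters mentioned in this docstring.
    scheduler = WSDLrScheduler(optimizer, total_steps=500, decay_phase_ratio=0.1, lr_min=0.1)

    checkpoint = torch.load("wsd_checkpoint.pt")               # assumed checkpoint layout
    scheduler.load_state_dict(checkpoint["scheduler_state"])   # required for continuation
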
    Example:
        Initial run (1000 steps, 0.1 decay ratio):
        - Steps 0-50: Optional warmup
        - Steps 50-900: Stable high learning rate
        - Steps 900-1000: Decay to lr_min

        Continuation (500 new steps, 0.1 decay ratio):
        - Steps 0-450: Stable high learning rate
        - Steps 450-500: Decay to lr_min

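The boundaries quoted in both runs follow directly from total_steps * (1 - decay_phase_ratio); a quick check::

    for name, (steps, ratio) in {"initial run": (1000, 0.1), "continuation": (500, 0.1)}.items():
        decay_start = int(steps * (1 - ratio))
        print(f"{name}: stable until step {decay_start}, decay over steps {decay_start}-{steps}")
    # initial run: stable until step 900, decay over steps 900-1000
    # continuation: stable until step 450, decay over steps 450-500
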
    .. Note:: This scheduler is designed to be used with the :class:`~pytorch_accelerated.callbacks.WSDCheckpointCallback` class,
        which handles saving and loading checkpoints.