
Commit ca5b792

fix docs errors
1 parent 5e0ce19 commit ca5b792

File tree

1 file changed: +37 −22 lines changed

pytorch_accelerated/schedulers/wsd_scheduler.py

Lines changed: 37 additions & 22 deletions
@@ -9,7 +9,7 @@
 
 class WSDLrScheduler(StatefulSchedulerBase):
     """
-    Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in
+    Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in
     'Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective'.
 
     The schedule has three phases:
@@ -20,22 +20,16 @@ class WSDLrScheduler(StatefulSchedulerBase):
     This scheduler is designed to create intermediate model checkpoints during training. Each checkpoint
     involves decaying the learning rate to get better model performance.
 
-    Selecting num_checkpoints:
-    - Use 1 checkpoint if you only want the final trained model
-    - Use multiple checkpoints (typically 2-3) if:
-        * Training on large datasets (>100B tokens) where intermediate models
-          are useful for development/testing
-        * You want to evaluate model performance vs training data size
-          (e.g., does your model need full training?)
-        * You might need to continue training later but want flexibility
-          about when to stop
-
-
+    Use multiple checkpoints (typically 2-3) if:
+    - Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
+    - You want to evaluate model performance vs training data size (e.g., does your model need full training?)
+    - You might need to continue training later but want flexibility about when to stop training
+
     The scheduler uses geometric progression to space checkpoints evenly on a log scale:
     - First checkpoint is placed at 25% of total steps
-    - Each subsequent checkpoint is ~2x steps from previous
+    - Each subsequent checkpoint is ~2x steps from previous checkpoint
 
-    Examples:
+    Examples:
     - 2 checkpoints for 100K steps: [50K, 100K]
     - 3 checkpoints for 200K steps: [50K, 100K, 200K]
     - 4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]
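The doubling rule behind the checkpoint examples above can be reproduced in a few lines. This is only an illustration of the documented spacing, not the scheduler's internal code, and the function name checkpoint_steps is hypothetical:

def checkpoint_steps(total_steps: int, num_checkpoints: int) -> list[int]:
    # The final checkpoint lands on the last step; each earlier checkpoint sits
    # half as far into training, giving the ~2x spacing described above.
    return [total_steps // 2 ** (num_checkpoints - 1 - i) for i in range(num_checkpoints)]

print(checkpoint_steps(100_000, 2))  # [50000, 100000]
print(checkpoint_steps(200_000, 4))  # [25000, 50000, 100000, 200000]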
@@ -51,15 +45,15 @@ class WSDLrScheduler(StatefulSchedulerBase):
        - Derived from theoretical analysis on quadratic functions
        - Steeper initial decay, more gradual approach to lr_min
        - Optimal for quadratic loss landscapes
-
+
     2. Sqrt Decay:
        lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
        - Similar to traditional cosine decay patterns
        - More gradual initial decay, consistent decay rate
        - May be more robust across different architectures
 
     Continuation Behavior:
-    - Training can be continued from a pre-decay checkpoint
+    - Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
     - When continuing, scheduler starts a fresh stable phase with new total_steps
     - Decay phase ratio applies to new training length
     - No warmup is applied during continuation
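As a quick numeric check of the sqrt decay formula quoted above, here is a minimal sketch. It assumes that t is the fraction of the decay phase completed and that the result scales the base learning rate down to lr_min; the function name is illustrative:

import math

def sqrt_decay_factor(t: float, lr_min: float = 0.0) -> float:
    # Equals 1.0 when the decay phase starts (t=0) and lr_min when it ends (t=1).
    return lr_min + (1 - lr_min) * (1 - math.sqrt(t))

print(sqrt_decay_factor(0.0, lr_min=0.1))  # 1.0
print(sqrt_decay_factor(0.5, lr_min=0.1))  # ~0.364
print(sqrt_decay_factor(1.0, lr_min=0.1))  # 0.1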
@@ -69,14 +63,15 @@ class WSDLrScheduler(StatefulSchedulerBase):
     Initial run (1000 steps, 0.1 decay ratio):
     - Steps 0-50: Optional warmup
     - Steps 50-900: Stable high learning rate
-    - Steps 900-1000: Decay
+    - Steps 900-1000: Decay to lr_min
 
     Continuation (500 new steps, 0.1 decay ratio):
     - Steps 0-450: Stable high learning rate
-    - Steps 450-500: Decay
+    - Steps 450-500: Decay to lr_min
 
     .. Note:: This scheduler is designed to be used with the :class:`~pytorch_accelerated.callbacks.WSDCheckpointCallback` class,
-        which handles saving and loading checkpoints.
+        which handles saving and loading checkpoints.
+
     """
 
     def __init__(
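The two timelines in the hunk above follow directly from the decay ratio. Below is a small sketch of how the phase boundaries fall out; the parameter names and the separate num_warmup_steps argument are assumptions for illustration, not the scheduler's actual signature:

def phase_boundaries(total_steps: int, decay_ratio: float,
                     num_warmup_steps: int = 0,
                     is_continuation: bool = False) -> tuple[int, int, int]:
    # Warmup is skipped when continuing from a checkpoint; the decay phase covers
    # the final `decay_ratio` fraction of the (new) training length.
    warmup_end = 0 if is_continuation else num_warmup_steps
    decay_start = int(total_steps * (1 - decay_ratio))
    return warmup_end, decay_start, total_steps

print(phase_boundaries(1000, 0.1, num_warmup_steps=50))  # (50, 900, 1000)
print(phase_boundaries(500, 0.1, is_continuation=True))  # (0, 450, 500)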
@@ -371,10 +366,30 @@ def create_scheduler_fn(
         num_checkpoints: int = 1,
         is_continuation_from_checkpoint: bool = False,
     ) -> Callable:
-        """Creates a scheduler function that the trainer can use.
+        """
+        An alternative constructor which returns a function that accepts an optimizer and creates an instance of
+        ``WSDLrScheduler``. This is primarily intended to be used with the :class:`~pytorch_accelerated.trainer.Trainer`
+        as illustrated below::
+
+            trainer = Trainer(
+                ...,
+                callbacks=[
+                    WSDCheckpointCallback(
+                        save_dir="checkpoints",
+                        initial_checkpoint="checkpoint_45000_pre_decay.pt",
+                    )
+                ],
+            )
+
+            trainer.train(
+                train_dataset=train_dataset,
+                num_epochs=num_epochs,
+                per_device_batch_size=batch_size,
+                create_scheduler_fn=WSDLrScheduler.create_scheduler_fn(is_continuation_from_checkpoint=True),
+            )
 
-        Returns a function that accepts an optimizer and creates an instance of WSDScheduler.
-        The trainer will replace TrainerPlaceholderValues with actual values at runtime.
+        By default, the ``total_num_epochs`` and ``num_iterations_per_epoch`` arguments will be set by the
+        :class:`~pytorch_accelerated.trainer.Trainer` with the correct values at runtime.
         """
 
         return partial(
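For readers unfamiliar with the pattern that create_scheduler_fn relies on, here is a minimal, self-contained sketch of deferred construction via functools.partial. The SketchWSDLrScheduler class and its body are hypothetical stand-ins; the point is only that the returned callable accepts an optimizer later, as the replaced docstring text described:

from functools import partial


class SketchWSDLrScheduler:
    # Hypothetical stand-in for WSDLrScheduler, reduced to the construction pattern.
    def __init__(self, optimizer, num_checkpoints=1, is_continuation_from_checkpoint=False):
        self.optimizer = optimizer
        self.num_checkpoints = num_checkpoints
        self.is_continuation_from_checkpoint = is_continuation_from_checkpoint

    @classmethod
    def create_scheduler_fn(cls, num_checkpoints=1, is_continuation_from_checkpoint=False):
        # Bind the configuration now; the Trainer supplies the prepared optimizer later.
        return partial(
            cls,
            num_checkpoints=num_checkpoints,
            is_continuation_from_checkpoint=is_continuation_from_checkpoint,
        )


scheduler_fn = SketchWSDLrScheduler.create_scheduler_fn(num_checkpoints=3)
# scheduler = scheduler_fn(optimizer)  # this is effectively what the Trainer does at runtime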
