Commit 2e5d78a

add extra methods to wsd scheduler
1 parent bab411b commit 2e5d78a

File tree

2 files changed: 18 additions & 6 deletions

pytorch_accelerated/callbacks.py

Lines changed: 1 addition & 6 deletions
@@ -834,12 +834,7 @@ def on_training_run_start(self, trainer, **kwargs):
     def on_train_step_end(self, trainer, step: int, **kwargs):
         """Handle checkpoint saving and progress logging"""
 
-        # Calculate global step accounting for distributed training and gradient accumulation
-        total_steps = (
-            (trainer.run_history.current_epoch - 1)
-            * trainer.run_config.num_update_steps_per_epoch
-            + step // trainer.run_config.gradient_accumulation_steps
-        )
+        total_steps = trainer.scheduler.get_current_step()
 
         # Skip if we've already saved at this step
         if total_steps == self.last_checkpoint_step:
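
The removed arithmetic derived the global update count from epoch and step counters; the new call delegates to the scheduler, which already tracks optimizer updates, leaving a single source of truth. A minimal sketch of the two computations side by side — every name and value below is an illustrative assumption mirroring the diff, not the library's actual trainer API:

# Sketch: the replaced arithmetic vs. the new accessor.
# All names and values below are illustrative assumptions.

current_epoch = 3                    # 1-indexed epoch from run history
num_update_steps_per_epoch = 100     # optimizer updates per epoch
gradient_accumulation_steps = 4      # micro-batches per optimizer update
step = 48                            # micro-batch index within the epoch

# Old: recompute the global optimizer-update count by hand
total_steps_old = (
    (current_epoch - 1) * num_update_steps_per_epoch
    + step // gradient_accumulation_steps
)

# New: ask the scheduler, which increments its own counter on every update
class SchedulerStub:
    def __init__(self, num_updates: int):
        self._num_updates = num_updates

    def get_current_step(self) -> int:
        return self._num_updates

total_steps_new = SchedulerStub(212).get_current_step()

assert total_steps_old == total_steps_new == 212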

pytorch_accelerated/schedulers/wsd_scheduler.py

Lines changed: 17 additions & 0 deletions
@@ -217,6 +217,23 @@ def get_decay_info(self) -> List[dict]:
         """
         return self.checkpoint_decay_info
 
+    def get_current_step(self) -> int:
+        """Get the current step count of the scheduler.
+
+        Returns:
+            int: The current number of optimizer updates completed
+        """
+        return self._num_updates
+
+    def get_current_phase_info(self) -> dict:
+        """Get phase information for the current step.
+
+        Returns:
+            dict: Phase information containing period_start, period_end,
+            decay_steps, and pre_decay_step for current position
+        """
+        return self.get_phase_info(self._num_updates)
+
     @lru_cache(maxsize=1)
     def _get_checkpoint_info(self, num_updates):
         """Get information about the current checkpoint period."""
