Skip to content

Commit c394ce1

Browse files
update callback to use new scheduler methods
1 parent 2e5d78a commit c394ce1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_accelerated/callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def on_training_run_start(self, trainer, **kwargs):
834834
def on_train_step_end(self, trainer, step: int, **kwargs):
835835
"""Handle checkpoint saving and progress logging"""
836836

837-
total_steps = trainer.scheduler.get_current_step()
837+
total_steps = trainer.scheduler.get_current_step() + 1
838838

839839
# Skip if we've already saved at this step
840840
if total_steps == self.last_checkpoint_step:
@@ -861,7 +861,7 @@ def on_train_step_end(self, trainer, step: int, **kwargs):
861861
def on_training_run_end(self, trainer, **kwargs):
862862
"""Save final checkpoint if we haven't already"""
863863
# Get the final step number
864-
total_steps = trainer.run_config.max_num_train_steps
864+
total_steps = trainer.scheduler.get_current_step() + 1
865865

866866
# If we haven't saved the final checkpoint yet
867867
if self.last_checkpoint_step != total_steps:

0 commit comments

Comments
 (0)