File tree (expand / collapse): 1 file changed, +10 -7 lines changed
Expand file tree / Collapse file tree: 1 file changed, +10 -7 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -640,14 +640,17 @@ def _create_run_config(
640640
641641 if self ._train_dataloader is not None :
642642 local_batches = len (self ._train_dataloader )
643+ total_batches = local_batches * self ._accelerator .num_processes
643644 num_update_steps_per_epoch = math .ceil (
644- local_batches / gradient_accumulation_steps
645+ total_batches / gradient_accumulation_steps
645646 )
647+
646648 else :
647649 num_update_steps_per_epoch = 0
648650
649651 if max_num_train_steps is None :
650- max_num_train_steps = num_epochs * num_update_steps_per_epoch
652+ # Add 1 to ensure we don't stop early due to rounding
653+ max_num_train_steps = (num_epochs * num_update_steps_per_epoch ) + 1
651654 else :
652655 num_epochs = math .ceil (max_num_train_steps / num_update_steps_per_epoch )
653656
@@ -749,11 +752,11 @@ def _run_training(self):
749752 )
750753 break
751754
752- # if reached_max_steps:
753- # self.print(
754- # f"Reached max number of training steps {self.run_config.max_num_train_steps} in epoch {epoch + 1}"
755- # )
756- # break
755+ if reached_max_steps :
756+ self .print (
757+ f"Reached max number of training steps { self .run_config .max_num_train_steps } in epoch { epoch + 1 } "
758+ )
759+ break
757760
758761 self .training_run_end ()
759762 self .callback_handler .call_event (
You can’t perform that action at this time.
0 commit comments