
Commit 738a0c1

update step logic
1 parent 0109255 commit 738a0c1

File tree: 1 file changed (+5 −6 lines)


pytorch_accelerated/trainer.py

Lines changed: 5 additions & 6 deletions
@@ -428,7 +428,7 @@ def train(
         :param num_epochs: the number of epochs to train for
         :param eval_dataset: the dataset to use during evaluation epochs, if this is not provided, evaluation is skipped.
         :param per_device_batch_size: the batch size to use per device
-        :param max_num_train_steps: the maximum number of steps across all processes to train for. If provided, this will override num_epochs
+        :param max_num_train_steps: the maximum number of steps across all processes to train for. If both max_num_train_steps and num_epochs are provided, the smaller of the two limits is used.
         :param gradient_accumulation_steps: accumulate gradients to the specified number of steps to simulate a bigger batch size. By default, this is set to ``1``
         :param gradient_clip_value: if specified, the gradients of the model's parameters will be clipped to the range ``[-gradient_clip_value, gradient_clip_value]``
         :param create_scheduler_fn: a function which accepts an optimizer as an argument and returns a learning rate scheduler
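The revised docstring pins down what happens when both limits are given: training stops at whichever is reached first. A minimal sketch of that semantics with made-up numbers (in practice the trainer derives num_update_steps_per_epoch from the dataset and batch size):

    # Hypothetical values; only the min() relationship is the point here.
    num_epochs = 10
    num_update_steps_per_epoch = 100  # derived by the trainer in practice
    max_num_train_steps = 250

    # Training performs the smaller of the two step budgets.
    effective_updates = min(num_epochs * num_update_steps_per_epoch, max_num_train_steps)
    assert effective_updates == 250  # the step cap wins in this example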
@@ -808,9 +808,10 @@ def _run_train_epoch(self, train_dl):
             self,
         )
 
-        # updates across all processes
+        # max steps across all processes
         max_total_update_steps = self.run_config.max_num_train_steps
-
+
+        # updates across all processes
         updates_completed = (
             self.run_history.current_epoch - 1
         ) * self.run_config.num_update_steps_per_epoch
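This hunk splits the comments so each sits above the quantity it describes: max_total_update_steps is the configured cap, while updates_completed counts optimizer updates already performed in earlier epochs. A self-contained sketch of that bookkeeping, assuming a 1-based current_epoch as the current_epoch - 1 term implies:

    # Assumed values for illustration.
    current_epoch = 3                  # the epoch about to run (1-based)
    num_update_steps_per_epoch = 100

    # Updates performed in the epochs that have already finished.
    updates_completed = (current_epoch - 1) * num_update_steps_per_epoch
    assert updates_completed == 200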
@@ -850,11 +851,9 @@ def _run_train_epoch(self, train_dl):
                 + (step + 1) // self.run_config.gradient_accumulation_steps
             )
 
-            global_process_updates = process_updates * self._accelerator.num_processes
-
             if (
                 self.run_config.max_num_train_steps is not None
-                and global_process_updates >= max_total_update_steps
+                and process_updates >= max_total_update_steps
             ):
                 reached_max_steps = True
                 # Synchronize reached_max_steps across processes
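The final hunk drops the multiplication by self._accelerator.num_processes. A plausible reading, consistent with the comment changes above, is that optimizer updates happen in lockstep on every process, so the per-process counter already equals the global update count, and scaling it by the process count would trip the limit too early. A hedged sketch of the resulting check, mirroring the diff's names inside a simplified batch loop with assumed values:

    # Simplified batch loop; all values are assumptions for illustration.
    gradient_accumulation_steps = 4
    max_total_update_steps = 250
    updates_completed = 200
    reached_max_steps = False

    for step in range(400):  # hypothetical number of batches this epoch
        # One optimizer update per `gradient_accumulation_steps` batches.
        process_updates = updates_completed + (step + 1) // gradient_accumulation_steps
        if (
            max_total_update_steps is not None
            and process_updates >= max_total_update_steps
        ):
            reached_max_steps = True
            break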
