3 files changed: 8 additions, 5 deletions

@@ -359,7 +359,7 @@ def load_state_dict(self, state_dict: dict):
     def create_scheduler_fn(
         cls,
         total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS,
-        num_update_steps_per_epoch: int = TrainerPlaceholderValues.NUM_LOCAL_UPDATE_STEPS_PER_EPOCH,
+        num_update_steps_per_epoch: int = TrainerPlaceholderValues.PER_PROCESS_NUM_UPDATE_STEPS_PER_EPOCH,
         num_warmup_epochs: int = None,
         decay_phase_ratio: float = 0.1,
         lr_min: float = 1e-6,
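The rename makes it explicit that this placeholder resolves to the number of optimizer updates each individual process performs per epoch, not a global count. As a rough sketch of how such a per-process value is typically derived (the function and parameter names below are illustrative, not part of the library; it assumes a per-process dataloader length and a gradient-accumulation setting):

import math

def per_process_update_steps_per_epoch(per_process_batches: int, grad_accum_steps: int) -> int:
    """Optimizer updates a single process performs in one epoch.

    Each process only sees its own shard of the data, so the count is based on
    the per-process dataloader length, divided by the number of batches
    accumulated before each optimizer step.
    """
    return math.ceil(per_process_batches / grad_accum_steps)

# e.g. 1000 batches per process with accumulation over 4 batches -> 250 updates per process per epoch
assert per_process_update_steps_per_epoch(1000, 4) == 250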
@@ -808,6 +808,7 @@ def _run_train_epoch(self, train_dl):
             self,
         )

+        # updates across all processes
         max_total_update_steps = self.run_config.max_num_train_steps

         updates_completed = (
@@ -849,9 +850,11 @@ def _run_train_epoch(self, train_dl):
                 + (step + 1) // self.run_config.gradient_accumulation_steps
             )

+            global_process_updates = process_updates * self._accelerator.num_processes
+
             if (
                 self.run_config.max_num_train_steps is not None
-                and process_updates >= max_total_update_steps
+                and global_process_updates >= max_total_update_steps
             ):
                 reached_max_steps = True
                 # Synchronize reached_max_steps across processes
@@ -860,7 +863,7 @@ def _run_train_epoch(self, train_dl):
                     [reached_max_steps], device=self.device
                 )
                 reached_max_steps = (
-                    self.gather(reached_max_steps_tensor).all().item()
+                    self.gather(reached_max_steps_tensor).any().item()
                 )
                 break

@@ -1,5 +1,3 @@
-from pathlib import Path
-from tempfile import TemporaryFile, TemporaryDirectory
 from unittest.mock import MagicMock, Mock, call
 import pytest

@@ -402,6 +400,8 @@ def test_max_steps_stops_training_correctly():
     )


+
+
 def test_max_steps_overrides_num_epochs():
     """Test that max_num_train_steps takes precedence over num_epochs"""

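For context, the behaviour this new test targets can be summarised with a small illustrative helper (not code from the repository): when max_num_train_steps is set, it caps the total number of optimizer updates regardless of what num_epochs alone would allow.

def expected_total_updates(num_epochs: int,
                           per_process_updates_per_epoch: int,
                           num_processes: int,
                           max_num_train_steps) -> int:
    """Total optimizer updates the run should perform, given both settings."""
    epoch_budget = num_epochs * per_process_updates_per_epoch * num_processes
    if max_num_train_steps is None:
        return epoch_budget
    # max_num_train_steps wins over the epoch-derived budget.
    return min(epoch_budget, max_num_train_steps)

# 3 epochs x 250 updates x 2 processes = 1500, but a budget of 400 stops training first.
assert expected_total_updates(3, 250, 2, 400) == 400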