Skip to content

Commit e81cd9c

Browse files
updates
1 parent 2a80d6e commit e81cd9c

File tree

3 files changed

+8 −5 lines changed

3 files changed

+8 −5 lines changed

pytorch_accelerated/schedulers/wsd_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def load_state_dict(self, state_dict: dict):
359359
def create_scheduler_fn(
360360
cls,
361361
total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS,
362-
num_update_steps_per_epoch: int = TrainerPlaceholderValues.NUM_LOCAL_UPDATE_STEPS_PER_EPOCH,
362+
num_update_steps_per_epoch: int = TrainerPlaceholderValues.PER_PROCESS_NUM_UPDATE_STEPS_PER_EPOCH,
363363
num_warmup_epochs: int = None,
364364
decay_phase_ratio: float = 0.1,
365365
lr_min: float = 1e-6,

pytorch_accelerated/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ def _run_train_epoch(self, train_dl):
808808
self,
809809
)
810810

811+
# updates across all processes
811812
max_total_update_steps = self.run_config.max_num_train_steps
812813

813814
updates_completed = (
@@ -849,9 +850,11 @@ def _run_train_epoch(self, train_dl):
849850
+ (step + 1) // self.run_config.gradient_accumulation_steps
850851
)
851852

853+
global_process_updates = process_updates * self._accelerator.num_processes
854+
852855
if (
853856
self.run_config.max_num_train_steps is not None
854-
and process_updates >= max_total_update_steps
857+
and global_process_updates >= max_total_update_steps
855858
):
856859
reached_max_steps = True
857860
# Synchronize reached_max_steps across processes
@@ -860,7 +863,7 @@ def _run_train_epoch(self, train_dl):
860863
[reached_max_steps], device=self.device
861864
)
862865
reached_max_steps = (
863-
self.gather(reached_max_steps_tensor).all().item()
866+
self.gather(reached_max_steps_tensor).any().item()
864867
)
865868
break
866869

test/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from pathlib import Path
2-
from tempfile import TemporaryFile, TemporaryDirectory
31
from unittest.mock import MagicMock, Mock, call
42
import pytest
53

@@ -402,6 +400,8 @@ def test_max_steps_stops_training_correctly():
402400
)
403401

404402

403+
404+
405405
def test_max_steps_overrides_num_epochs():
406406
"""Test that max_num_train_steps takes precedence over num_epochs"""
407407

0 commit comments

Comments (0)