
Commit bab411b

Counting bug fix (#69)
* add debug configs
* update step logic
1 parent 296b9f2 commit bab411b

File tree

2 files changed: +35 -6 lines changed


.vscode/launch.json

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        },
+        {
+            "name": "Python Debugger: accelerate",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "accelerate.commands.launch",
+            "args": [
+                // "--config_file",
+                // "PATH/TO/accelerate_config.yaml",
+                // "PATH/TO/train.py",
+
+            ],
+            "console": "integratedTerminal",
+            "justMyCode": false
+        },
+    ]
+}
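
The second configuration attaches the debugger to a script launched through Hugging Face accelerate by targeting the accelerate.commands.launch module, and "justMyCode": false lets the debugger step into installed library code. The commented-out entries mark where the launch arguments go; a minimal sketch of a filled-in "args" array, using hypothetical paths that are not part of this commit:

"args": [
    "--config_file",
    "accelerate_config.yaml",
    "train.py"
],

Arguments placed after the script path are forwarded to the training script itself, as with accelerate launch on the command line.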

pytorch_accelerated/trainer.py

Lines changed: 5 additions & 6 deletions
@@ -428,7 +428,7 @@ def train(
         :param num_epochs: the number of epochs to train for
         :param eval_dataset: the dataset to use during evaluation epochs, if this is not provided, evaluation is skipped.
         :param per_device_batch_size: the batch size to use per device
-        :param max_num_train_steps: the maximum number of steps across all processes to train for. If provided, this will override num_epochs
+        :param max_num_train_steps: the maximum number of steps across all processes to train for. If both max_num_train_steps and num_epochs are provided, the smaller of the two limits is used.
         :param gradient_accumulation_steps: accumulate gradients to the specified number of steps to simulate a bigger batch size. By default, this is set to ``1``
         :param gradient_clip_value: if specified, the gradients of the model's parameters will be clipped to the range ``[-gradient_clip_value, gradient_clip_value]``
         :param create_scheduler_fn: a function which accepts an optimizer as an argument and returns a learning rate scheduler
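
To make the new wording concrete, a small worked example with hypothetical numbers (not taken from this commit): with 3 epochs at 100 optimizer updates per epoch, the epoch limit alone would allow 300 updates, so a max_num_train_steps of 250 stops training first.

# Hypothetical values illustrating the "smaller of the two limits" behaviour
num_epochs = 3
num_update_steps_per_epoch = 100
max_num_train_steps = 250

effective_updates = min(num_epochs * num_update_steps_per_epoch, max_num_train_steps)
assert effective_updates == 250  # the step limit is reached before the epoch limit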
@@ -808,9 +808,10 @@ def _run_train_epoch(self, train_dl):
             self,
         )
 
-        # updates across all processes
+        # max steps across all processes
         max_total_update_steps = self.run_config.max_num_train_steps
-
+
+        # updates across all processes
         updates_completed = (
             self.run_history.current_epoch - 1
         ) * self.run_config.num_update_steps_per_epoch
@@ -850,11 +851,9 @@ def _run_train_epoch(self, train_dl):
                 + (step + 1) // self.run_config.gradient_accumulation_steps
             )
 
-            global_process_updates = process_updates * self._accelerator.num_processes
-
             if (
                 self.run_config.max_num_train_steps is not None
-                and global_process_updates >= max_total_update_steps
+                and process_updates >= max_total_update_steps
             ):
                 reached_max_steps = True
                 # Synchronize reached_max_steps across processes
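
Pieced together from the two hunks above, the corrected counting logic behaves roughly like the standalone sketch below. The parameter names mirror the run_config and run_history attributes used in the diff, but the function itself is simplified scaffolding for illustration, not the library's actual training loop.

def reached_step_limit(current_epoch, step, num_update_steps_per_epoch,
                       gradient_accumulation_steps, max_num_train_steps):
    """Return True once max_num_train_steps optimizer updates have been performed.

    step is the zero-based batch index within the current epoch; an optimizer
    update happens once every gradient_accumulation_steps batches.
    """
    if max_num_train_steps is None:
        return False

    # updates completed in previous epochs
    updates_completed = (current_epoch - 1) * num_update_steps_per_epoch

    # updates completed so far, including the current epoch
    process_updates = updates_completed + (step + 1) // gradient_accumulation_steps

    # Before this commit, the count was scaled by the number of processes before
    # the comparison; the fix compares process_updates directly against the limit.
    return process_updates >= max_num_train_steps


# e.g. epoch 2, 50 updates per epoch, no gradient accumulation, limit of 60 updates:
assert reached_step_limit(2, 9, 50, 1, 60) is True   # 50 + 10 >= 60
assert reached_step_limit(2, 8, 50, 1, 60) is False  # 50 + 9 < 60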

0 commit comments
