
Commit 0dac0ff

vtantia authored and facebook-github-bot committed
Integrating SlowMo into fairseq and adding SlowMo to fbcode
Summary: This diff contains the following changes:

* Adding the SlowMo algorithm to fbcode (this contains the latest implementation, complete with reduced memory usage for slow momentum, a faster forward pass, and linting, among other things)
* Integration of the SlowMo algorithm into fairseq (includes the code changes needed for integration as well as command-line arguments for SlowMo)
* Scripts for calling SlowMo
* Addition of log-dir, in addition to save-dir, to allow different directories to be used for logging and saving

Reviewed By: myleott, mikerabbat

Differential Revision: D19184997

fbshipit-source-id: b42b298ac5297fb83a3335fa7ce262c8f48fb2bc
1 parent 91f7cf6 commit 0dac0ff
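As a rough usage sketch (not part of this commit), the new wrapper can be selected from the command line via the flags added in fairseq/options.py below. The entry point, dataset path, architecture, and save directory in this example are placeholder assumptions; only the SlowMo-related flags come from this diff.

```python
# Hypothetical invocation of fairseq training with the new SlowMo flags.
# Paths, the architecture, and the launch mechanism are placeholders; the
# --distributed-wrapper / --slowmo-* / --localsgd-frequency flags are the
# ones introduced in fairseq/options.py in this commit.
import subprocess

cmd = [
    "fairseq-train", "data-bin/wmt16_en_de",   # placeholder data path
    "--arch", "transformer_wmt_en_de_big",     # placeholder architecture
    "--distributed-wrapper", "SlowMo",
    "--slowmo-algorithm", "LocalSGD",          # or 'SGP'
    "--slowmo-momentum", "0.5",                # leave unset to use the tuned defaults
    "--localsgd-frequency", "3",
    "--save-dir", "checkpoints/slowmo",        # placeholder save directory
]
subprocess.run(cmd, check=True)
```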

File tree

3 files changed: +87 −22 lines

fairseq/models/distributed_fairseq_model.py
fairseq/options.py
fairseq/trainer.py

fairseq/models/distributed_fairseq_model.py

Lines changed: 39 additions & 2 deletions
@@ -11,6 +11,13 @@
 from fairseq.models import BaseFairseqModel
 
 
+_GOSSIP_DISABLED = False
+try:
+    import gossip
+except ImportError:
+    _GOSSIP_DISABLED = True
+
+
 def DistributedFairseqModel(args, model, process_group=None):
     """
     Wrap a *model* to support distributed data parallel training.
@@ -26,7 +33,7 @@ def DistributedFairseqModel(args, model, process_group=None):
     """
     # determine which DDP class to extend
     assert isinstance(model, nn.Module)
-    if args.ddp_backend == 'c10d':
+    if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
         ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
@@ -41,14 +48,44 @@
             init_kwargs['check_reduction'] = True
         if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
             init_kwargs['find_unused_parameters'] = args.find_unused_parameters
-    elif args.ddp_backend == 'no_c10d':
+    elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
         ddp_class = LegacyDistributedDataParallel
         init_kwargs = dict(
             module=model,
             world_size=args.distributed_world_size,
             buffer_size=2**28,
             process_group=process_group,
         )
+    elif args.distributed_wrapper == 'SlowMo':
+        if _GOSSIP_DISABLED:
+            raise ImportError(
+                'Cannot find gossip library. Please install from: '
+                'github.com/facebookresearch/stochastic_gradient_push'
+            )
+        ddp_class = gossip.GossipDataParallel
+
+        # The values of slowmo_momentum below were obtained by tuning on the
+        # En-De 16 dataset by training the transformer_wmt_en_de_large model
+        if args.slowmo_momentum is None:
+            if args.distributed_world_size <= 16:
+                args.slowmo_momentum = 0.0
+            elif args.distributed_world_size <= 32:
+                args.slowmo_momentum = 0.2
+            elif args.distributed_world_size <= 64:
+                args.slowmo_momentum = 0.5
+            else:
+                args.slowmo_momentum = 0.6
+
+        init_kwargs = dict(
+            module=model,
+            device_ids=[args.device_id],
+            output_device=args.device_id,
+            broadcast_buffers=args.broadcast_buffers,
+            nprocs_per_node=args.nprocs_per_node,
+            slowmo_momentum=args.slowmo_momentum,
+            localsgd=(args.slowmo_algorithm == 'LocalSGD'),
+            localsgd_frequency=args.localsgd_frequency
+        )
     else:
         raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)
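To make the SlowMo branch above easier to follow, here is a small, self-contained restatement of the slowmo_momentum defaulting rule. The helper name is invented for illustration and is not part of fairseq; only the thresholds and values mirror the diff.

```python
# Illustration only: the slowmo_momentum defaults chosen above, expressed as a
# pure function of the distributed world size. `default_slowmo_momentum` is a
# made-up name, not a fairseq API.
def default_slowmo_momentum(world_size: int) -> float:
    if world_size <= 16:
        return 0.0
    elif world_size <= 32:
        return 0.2
    elif world_size <= 64:
        return 0.5
    else:
        return 0.6


assert default_slowmo_momentum(8) == 0.0
assert default_slowmo_momentum(32) == 0.2
assert default_slowmo_momentum(64) == 0.5
assert default_slowmo_momentum(128) == 0.6
```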

fairseq/options.py

Lines changed: 17 additions & 0 deletions
@@ -389,6 +389,23 @@ def add_distributed_training_args(parser):
     group.add_argument('--broadcast-buffers', default=False, action='store_true',
                        help='Copy non-trainable parameters between GPUs, such as '
                             'batchnorm population statistics')
+
+    group.add_argument('--distributed-wrapper', default='DDP', type=str,
+                       choices=['DDP', 'SlowMo'],
+                       help='DistributedDataParallel backend')
+    # Add arguments for SlowMo - these will be used when SlowMo is enabled via above
+    group.add_argument('--slowmo-momentum', default=None, type=float,
+                       help='SlowMo momentum term; by default use 0.0 for 16 GPUs, '
+                            '0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs')
+    group.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'],
+                       help='whether to use LocalSGD or SGP')
+    group.add_argument('--localsgd-frequency', default=3, type=int,
+                       help='Local SGD allreduce frequency')
+    group.add_argument('--nprocs-per-node', type=int, metavar='N',
+                       default=max(1, torch.cuda.device_count()),
+                       help='number of GPUs in each node. An allreduce operation across GPUs in '
+                            'a node is very fast. Hence, we do allreduce across GPUs in a node, '
+                            'and gossip across different nodes')
     # fmt: on
     return group
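For reference, the new options can be exercised in isolation with a bare argparse parser. This is a standalone sketch that mirrors the add_argument calls above; it is not fairseq's parser plumbing, and it replaces the torch.cuda-based default for --nprocs-per-node with a constant so it runs without torch installed.

```python
# Standalone sketch of the new SlowMo options (mirrors the diff above).
# The --nprocs-per-node default here is a stand-in for
# max(1, torch.cuda.device_count()) to avoid a torch dependency.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-wrapper', default='DDP', choices=['DDP', 'SlowMo'])
parser.add_argument('--slowmo-momentum', default=None, type=float)
parser.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'])
parser.add_argument('--localsgd-frequency', default=3, type=int)
parser.add_argument('--nprocs-per-node', default=1, type=int, metavar='N')

args = parser.parse_args(['--distributed-wrapper', 'SlowMo', '--slowmo-momentum', '0.5'])
print(args.distributed_wrapper, args.slowmo_momentum, args.localsgd_frequency)
# prints: SlowMo 0.5 3
```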

fairseq/trainer.py

Lines changed: 31 additions & 20 deletions
@@ -412,6 +412,7 @@ def maybe_no_sync():
                 logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
             )
 
+        overflow = False
         try:
             # multiply gradients by (# GPUs / sample_size) since DDP
             # already normalizes by the number of GPUs. Thus we get
@@ -429,29 +430,11 @@ def maybe_no_sync():
             grad_norm = self.clip_grad_norm(self.args.clip_norm)
 
             # check that grad norms are consistent across workers
-            if not self.args.use_bmuf:
+            if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
                 self._check_grad_norms(grad_norm)
 
             # take an optimization step
             self.optimizer.step()
-            self.set_num_updates(self.get_num_updates() + 1)
-
-            # log stats
-            logging_output = self._reduce_and_log_stats(
-                logging_outputs, sample_size, grad_norm,
-            )
-
-            # clear CUDA cache to reduce memory fragmentation
-            if (
-                self.args.empty_cache_freq > 0
-                and (
-                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
-                    % self.args.empty_cache_freq
-                ) == 0
-                and torch.cuda.is_available()
-                and not self.args.cpu
-            ):
-                torch.cuda.empty_cache()
         except FloatingPointError:
             # re-run the forward and backward pass with hooks attached to print out where it fails
             with NanDetector(self.model):
@@ -461,15 +444,43 @@ def maybe_no_sync():
             )
             raise
         except OverflowError as e:
+            overflow = True
             logger.info("NOTE: overflow detected, " + str(e))
+            grad_norm = torch.tensor(0.).cuda()
             self.zero_grad()
-            logging_output = None
         except RuntimeError as e:
             if "out of memory" in str(e):
                 self._log_oom(e)
                 logger.error("OOM during optimization, irrecoverable")
                 raise e
 
+        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
+        if hasattr(self.model, 'perform_additional_optimizer_actions'):
+            if hasattr(self.optimizer, 'fp32_params'):
+                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
+            else:
+                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
+
+        if not overflow or self.args.distributed_wrapper == 'SlowMo':
+            self.set_num_updates(self.get_num_updates() + 1)
+
+            # log stats
+            logging_output = self._reduce_and_log_stats(
+                logging_outputs, sample_size, grad_norm,
+            )
+
+            # clear CUDA cache to reduce memory fragmentation
+            if (
+                self.args.empty_cache_freq > 0
+                and (
+                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
+                    % self.args.empty_cache_freq
+                ) == 0
+                and torch.cuda.is_available()
+                and not self.args.cpu
+            ):
+                torch.cuda.empty_cache()
+
         if self.args.fp16:
             metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
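The trainer-side hook is duck-typed: after optimizer.step(), any wrapped model that exposes perform_additional_optimizer_actions is handed the inner optimizer (and the fp32 master parameters when the optimizer exposes fp32_params, as fairseq's FP16 optimizer does). The sketch below shows only that dispatch pattern; the classes are stand-ins invented for illustration, and everything except the hook name and the hasattr checks is an assumption.

```python
# Minimal sketch of the post-step hook dispatch used above. FakeSlowMoWrapper
# and DummyOptimizer are stand-ins; in fairseq the wrapper would be
# gossip.GossipDataParallel and the optimizer a fairseq optimizer wrapper.
class FakeSlowMoWrapper:
    def perform_additional_optimizer_actions(self, optimizer, fp32_params=None):
        # A real SlowMo wrapper would run its slow-momentum / LocalSGD
        # averaging step here, after the local optimizer update.
        print("post-step hook called; fp32 params passed:", fp32_params is not None)


class DummyOptimizer:
    optimizer = None  # stand-in for the wrapped torch.optim optimizer


def call_post_step_hook(model, optimizer):
    # Same dispatch pattern as the trainer change above.
    if hasattr(model, 'perform_additional_optimizer_actions'):
        if hasattr(optimizer, 'fp32_params'):
            model.perform_additional_optimizer_actions(optimizer.optimizer, optimizer.fp32_params)
        else:
            model.perform_additional_optimizer_actions(optimizer.optimizer)


call_post_step_hook(FakeSlowMoWrapper(), DummyOptimizer())
```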
