
Commit 46b773a

erip authored and facebook-github-bot committed
refactor namespaces in criterion interface (#1729)
Summary:

## Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?

Fixes #1672 in part (part 1: [context](#1714 (comment)))

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues, there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding.

Pull Request resolved: #1729
Differential Revision: D20049353
Pulled By: myleott
fbshipit-source-id: 732077a1cc339c9f7ebe26dae42a7e8d7b5a07b4
1 parent aa79bb9 commit 46b773a

15 files changed (+123 −53 lines)

examples/speech_recognition/criterions/cross_entropy_acc.py

Lines changed: 4 additions & 3 deletions
@@ -16,8 +16,9 @@

 @register_criterion("cross_entropy_acc")
 class CrossEntropyWithAccCriterion(FairseqCriterion):
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg

     def compute_loss(self, model, net_output, target, reduction, log_probs):
         # N, T -> N * T
@@ -50,7 +51,7 @@ def get_logging_output(self, sample, target, lprobs, loss):
         )
         total = torch.sum(mask)
         sample_size = (
-            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )

         logging_output = {
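The `self.args.sentence_avg` to `self.sentence_avg` substitution above recurs in most of the hunks below; the flag itself only picks the unit the summed loss is normalized by. A toy illustration with made-up sizes:

```python
import torch

# --sentence-avg divides the loss by the number of sentences (batch rows)
# instead of the number of target tokens.
sample = {'target': torch.zeros(8, 50, dtype=torch.long), 'ntokens': 400}

for sentence_avg, expected in [(True, 8), (False, 400)]:
    sample_size = sample['target'].size(0) if sentence_avg else sample['ntokens']
    assert sample_size == expected
```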

fairseq/criterions/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import os

 from fairseq import registry
-from fairseq.criterions.fairseq_criterion import FairseqCriterion
+from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion


 build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(

fairseq/criterions/adaptive_loss.py

Lines changed: 7 additions & 3 deletions
@@ -17,15 +17,19 @@ class AdaptiveLoss(FairseqCriterion):
     graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
     (http://arxiv.org/abs/1609.04309)."""

-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg

+    @classmethod
+    def build_criterion(cls, args, task):
         if args.ddp_backend == 'c10d':
             raise Exception(
                 'AdaptiveLoss is not compatible with the c10d '
                 'version of DistributedDataParallel. Please use '
                 '`--ddp-backend=no_c10d` instead.'
             )
+        return cls(task, args.sentence_avg)

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -64,7 +68,7 @@ def forward(self, model, sample, reduce=True):

         orig = utils.strip_pad(orig_target, self.padding_idx)
         ntokens = orig.numel()
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
+        sample_size = sample['target'].size(0) if self.sentence_avg else ntokens
         logging_output = {
             'loss': loss.data,
             'ntokens': ntokens,
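`AdaptiveLoss` is the escape-hatch case: its construction needs a `ddp_backend` check that cannot be inferred from the `__init__` signature, so it overrides `build_criterion` itself, validates, then calls `cls(...)` with explicit arguments. Any criterion with similar needs can follow the same shape; a condensed sketch with a hypothetical `StrictCriterion`:

```python
from fairseq.criterions import FairseqCriterion, register_criterion

@register_criterion('strict_criterion')  # hypothetical, for illustration only
class StrictCriterion(FairseqCriterion):

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    @classmethod
    def build_criterion(cls, args, task):
        # Reject incompatible settings before construction, mirroring the
        # AdaptiveLoss check above; then build with explicit arguments.
        if getattr(args, 'ddp_backend', None) == 'c10d':
            raise Exception('strict_criterion requires --ddp-backend=no_c10d')
        return cls(task, args.sentence_avg)
```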

fairseq/criterions/binary_cross_entropy.py

Lines changed: 5 additions & 5 deletions
@@ -15,11 +15,11 @@
 @register_criterion('binary_cross_entropy')
 class BinaryCrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.infonce = getattr(args, "infonce", False)
-        self.loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights)
-        self.log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys)
+    def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
+        super().__init__(task)
+        self.infonce = infonce
+        self.loss_weights = None if loss_weights is None else eval(loss_weights)
+        self.log_keys = [] if log_keys is None else eval(log_keys)

     @staticmethod
     def add_args(parser):
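One subtlety above: `loss_weights` and `log_keys` are still Python-literal *strings* when they arrive from the command line, and the constructor keeps the old behavior of `eval()`ing them. A sketch of that contract, with made-up flag values:

```python
# Hypothetical CLI values, e.g. --loss-weights '[0.1, 10]'
loss_weights_str = '[0.1, 10]'
log_keys_str = "['key_a', 'key_b']"

loss_weights = None if loss_weights_str is None else eval(loss_weights_str)
log_keys = [] if log_keys_str is None else eval(log_keys_str)

assert loss_weights == [0.1, 10]
assert log_keys == ['key_a', 'key_b']
```

`ast.literal_eval` would be the defensive choice here, but the diff deliberately preserves the existing `eval` semantics while only moving where the strings come from.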

fairseq/criterions/composite_loss.py

Lines changed: 7 additions & 3 deletions
@@ -14,6 +14,10 @@ class CompositeLoss(FairseqCriterion):
     """This is a composite loss that, given a list of model outputs and a list of targets,
     computes an average of losses for each output-target pair"""

+    def __init__(self, task, underlying_criterion):
+        super().__init__(task)
+        self.underlying_criterion = underlying_criterion
+
     @staticmethod
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
@@ -58,8 +62,8 @@ def decoder(self):

         class _CompositeLoss(FairseqCriterion):

-            def __init__(self, args, task, underlying_criterion):
-                super().__init__(args, task)
+            def __init__(self, task, underlying_criterion):
+                super().__init__(task)
                 self.underlying_criterion = underlying_criterion

             def forward(self, model, sample, reduce=True):
@@ -92,4 +96,4 @@
             def reduce_metrics(logging_outputs) -> None:
                 underlying_criterion.__class__.reduce_metrics(logging_outputs)

-        return _CompositeLoss(args, task, underlying_criterion)
+        return _CompositeLoss(task, underlying_criterion)

fairseq/criterions/cross_entropy.py

Lines changed: 4 additions & 3 deletions
@@ -14,8 +14,9 @@
 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -27,7 +28,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data,
             'ntokens': sample['ntokens'],

fairseq/criterions/fairseq_criterion.py

Lines changed: 52 additions & 4 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import inspect
 from typing import Any, Dict, List

 from torch.nn.modules.loss import _Loss
@@ -12,11 +13,12 @@

 class FairseqCriterion(_Loss):

-    def __init__(self, args, task):
+    def __init__(self, task):
         super().__init__()
-        self.args = args
         self.task = task
-        self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
+        if hasattr(task, 'target_dictionary'):
+            tgt_dict = task.target_dictionary
+            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

     @staticmethod
     def add_args(parser):
@@ -25,7 +27,35 @@ def add_args(parser):

     @classmethod
     def build_criterion(cls, args, task):
-        return cls(args, task)
+        """Construct a criterion from command-line args."""
+        # Criterions can override this, but for convenience we also try
+        # to automatically map argparse.Namespace keys to corresponding
+        # arguments in the __init__.
+        init_args = {}
+        for p in inspect.signature(cls).parameters.values():
+            if (
+                p.kind == p.POSITIONAL_ONLY
+                or p.kind == p.VAR_POSITIONAL
+                or p.kind == p.VAR_KEYWORD
+            ):
+                # we haven't implemented inference for these argument types,
+                # but PRs welcome :)
+                raise NotImplementedError('{} not supported'.format(p.kind))
+
+            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
+
+            if p.name == 'task':
+                init_args['task'] = task
+            elif hasattr(args, p.name):
+                init_args[p.name] = getattr(args, p.name)
+            elif p.default != p.empty:
+                pass  # we'll use the default value
+            else:
+                raise NotImplementedError(
+                    'Unable to infer Criterion arguments, please implement '
+                    '{}.build_criterion'.format(cls.__name__)
+                )
+        return cls(**init_args)

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -69,3 +99,21 @@ def logging_outputs_can_be_summed() -> bool:
         to True will improves distributed training speed.
         """
         return False
+
+
+class LegacyFairseqCriterion(FairseqCriterion):
+
+    def __init__(self, args, task):
+        super().__init__(task=task)
+        self.args = args
+
+        utils.deprecation_warning(
+            'Criterions should take explicit arguments instead of an '
+            'argparse.Namespace object, please update your criterion by '
+            'extending FairseqCriterion instead of LegacyFairseqCriterion.'
+        )
+
+    @classmethod
+    def build_criterion(cls, args, task):
+        """Construct a criterion from command-line args."""
+        return cls(args, task)
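This new `build_criterion` is what lets most criterions in this commit drop their construction boilerplate entirely: it walks the subclass's `__init__` signature and fills each parameter from the `argparse.Namespace` attribute of the same name. The mechanism, reduced to a self-contained sketch (the `ToyCriterion` class is illustrative, and the unsupported-parameter-kind errors are elided):

```python
import inspect
from argparse import Namespace

class ToyCriterion:
    def __init__(self, task, sentence_avg, label_smoothing=0.0):
        self.task = task
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing

    @classmethod
    def build_criterion(cls, args, task):
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if p.name == 'task':
                init_args['task'] = task                   # special-cased
            elif hasattr(args, p.name):
                init_args[p.name] = getattr(args, p.name)  # matched by name
            # else: the parameter's declared default is used
        return cls(**init_args)

# Unrelated Namespace entries ('lr' here) are simply ignored.
args = Namespace(sentence_avg=True, label_smoothing=0.1, lr=0.25)
crit = ToyCriterion.build_criterion(args, task='dummy-task')
assert crit.sentence_avg is True and crit.eps == 0.1
```

Criterions whose construction needs validation override `build_criterion` directly, as `AdaptiveLoss` does above, and old-style criterions can extend `LegacyFairseqCriterion` to keep the `(args, task)` signature at the cost of a deprecation warning.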

fairseq/criterions/label_smoothed_cross_entropy.py

Lines changed: 5 additions & 4 deletions
@@ -33,9 +33,10 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.eps = args.label_smoothing
+    def __init__(self, task, sentence_avg, label_smoothing):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+        self.eps = label_smoothing

     @staticmethod
     def add_args(parser):
@@ -55,7 +56,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data,
             'nll_loss': nll_loss.data,

fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py

Lines changed: 5 additions & 6 deletions
@@ -14,15 +14,14 @@
 @register_criterion('label_smoothed_cross_entropy_with_alignment')
 class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):

-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.alignment_lambda = args.alignment_lambda
+    def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
+        super().__init__(task, sentence_avg, label_smoothing)
+        self.alignment_lambda = alignment_lambda

     @staticmethod
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
-        super(LabelSmoothedCrossEntropyCriterionWithAlignment,
-              LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
+        LabelSmoothedCrossEntropyCriterion.add_args(parser)
         parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
                             help='weight for the alignment loss')

@@ -36,7 +35,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,

fairseq/criterions/legacy_masked_lm.py

Lines changed: 6 additions & 4 deletions
@@ -48,8 +48,10 @@ class LegacyMaskedLmLoss(FairseqCriterion):
     an argument.
     """

-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, masked_lm_only, nsp_loss_weight):
+        super().__init__(task)
+        self.masked_lm_only = masked_lm_only
+        self.nsp_loss_weight = nsp_loss_weight

     @staticmethod
     def add_args(parser):
@@ -85,7 +87,7 @@ def forward(self, model, sample, reduce=True):

         # Compute sentence loss if masked_lm_only is False
         sentence_loss = None
-        if not self.args.masked_lm_only:
+        if not self.masked_lm_only:
             sentence_logits = output_metadata['sentence_logits']
             sentence_targets = sample['sentence_target'].view(-1)
             # This needs to be recomputed due to some differences between
@@ -102,7 +104,7 @@ def forward(self, model, sample, reduce=True):
             sentence_loss = compute_cross_entropy_loss(
                 sentence_logits, sentence_targets)

-            loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)
+            loss += self.nsp_loss_weight * (sentence_loss / nsentences)

         # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
         # we don't need to use sample_size as denominator for the gradient
