11 changes: 8 additions & 3 deletions examples/speech_recognition/criterions/cross_entropy_acc.py
@@ -16,8 +16,13 @@
 
 @register_criterion("cross_entropy_acc")
 class CrossEntropyWithAccCriterion(FairseqCriterion):
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.sentence_avg)
 
     def compute_loss(self, model, net_output, target, reduction, log_probs):
         # N, T -> N * T
@@ -50,7 +55,7 @@ def get_logging_output(self, sample, target, lprobs, loss):
         )
         total = torch.sum(mask)
         sample_size = (
-            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
 
         logging_output = {
10 changes: 7 additions & 3 deletions fairseq/criterions/adaptive_loss.py
@@ -17,15 +17,19 @@ class AdaptiveLoss(FairseqCriterion):
     graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
     (http://arxiv.org/abs/1609.04309)."""
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+
+    @classmethod
+    def from_args(cls, task, args):
         if args.ddp_backend == 'c10d':
             raise Exception(
                 'AdaptiveLoss is not compatible with the c10d '
                 'version of DistributedDataParallel. Please use '
                 '`--ddp-backend=no_c10d` instead.'
             )
+        return cls(task, args.sentence_avg)
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -64,7 +68,7 @@ def forward(self, model, sample, reduce=True):
 
         orig = utils.strip_pad(orig_target, self.padding_idx)
         ntokens = orig.numel()
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
+        sample_size = sample['target'].size(0) if self.sentence_avg else ntokens
         logging_output = {
             'loss': loss.data,
             'ntokens': ntokens,
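Note: `from_args` is also where args-dependent validation now lives, as with the c10d guard above, so `__init__` stays free of `args` entirely. A minimal standalone sketch of that shape (hypothetical `AdaptiveLossLike` class and `Namespace` args, not fairseq code):

```python
from argparse import Namespace


class AdaptiveLossLike:
    """Sketch: construction-time validation happens in the factory, not __init__."""

    def __init__(self, task, sentence_avg):
        self.task = task
        self.sentence_avg = sentence_avg  # plain value; no self.args kept around

    @classmethod
    def from_args(cls, task, args):
        # Reject unsupported configurations before the instance exists.
        if args.ddp_backend == 'c10d':
            raise Exception('incompatible with c10d; use `--ddp-backend=no_c10d`')
        return cls(task, args.sentence_avg)


crit = AdaptiveLossLike.from_args(task=None, args=Namespace(ddp_backend='no_c10d', sentence_avg=True))
assert crit.sentence_avg
```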
17 changes: 12 additions & 5 deletions fairseq/criterions/binary_cross_entropy.py
@@ -15,11 +15,18 @@
 @register_criterion('binary_cross_entropy')
 class BinaryCrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.infonce = getattr(args, "infonce", False)
-        self.loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights)
-        self.log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys)
+    def __init__(self, task, infonce, loss_weights, log_keys):
+        super().__init__(task)
+        self.infonce = infonce
+        self.loss_weights = loss_weights
+        self.log_keys = log_keys
+
+    @classmethod
+    def from_args(cls, task, args):
+        infonce = getattr(args, "infonce", False)
+        loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights)
+        log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys)
+        return cls(task, infonce, loss_weights, log_keys)
 
     @staticmethod
     def add_args(parser):
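The string-typed flags here reach `from_args` as raw strings and are `eval`'d into Python objects; `__init__` then receives the already-parsed values. A small sketch of that parsing step (the flag names come from the diff above, the values are made up):

```python
from argparse import Namespace

# Hypothetical CLI values: --loss-weights and --log-keys arrive as strings.
args = Namespace(infonce=False, loss_weights='[0.1, 10]', log_keys="['prob_perplexity']")

# Same parsing as from_args above: eval turns the strings into Python objects.
loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights)
log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys)

assert loss_weights == [0.1, 10]
assert log_keys == ['prob_perplexity']
```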
15 changes: 12 additions & 3 deletions fairseq/criterions/composite_loss.py
@@ -14,6 +14,10 @@ class CompositeLoss(FairseqCriterion):
     """This is a composite loss that, given a list of model outputs and a list of targets,
     computes an average of losses for each output-target pair"""
 
+    def __init__(self, task, underlying_criterion):
+        super().__init__(task)
+        self.underlying_criterion = underlying_criterion
+
     @staticmethod
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
@@ -22,6 +26,11 @@ def add_args(parser):
                             help='underlying criterion to use for the composite loss')
         # fmt: on
 
+    @classmethod
+    def from_args(cls, task, args):
+        underlying_criterion = task.build_criterion(args)
+        return cls(task, underlying_criterion)
+
     @staticmethod
     def build_underlying_criterion(args, task):
         saved_criterion = args.criterion
@@ -58,8 +67,8 @@ def decoder(self):
 
     class _CompositeLoss(FairseqCriterion):
 
-        def __init__(self, args, task, underlying_criterion):
-            super().__init__(args, task)
+        def __init__(self, task, underlying_criterion):
+            super().__init__(task)
             self.underlying_criterion = underlying_criterion
 
         def forward(self, model, sample, reduce=True):
@@ -92,4 +101,4 @@ def aggregate_logging_outputs(logging_outputs):
         def reduce_metrics(logging_outputs) -> None:
             underlying_criterion.__class__.reduce_metrics(logging_outputs)
 
-    return _CompositeLoss(args, task, underlying_criterion)
+    return _CompositeLoss(task, underlying_criterion)
11 changes: 8 additions & 3 deletions fairseq/criterions/cross_entropy.py
@@ -14,8 +14,13 @@
 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.sentence_avg)
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -27,7 +32,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data,
             'ntokens': sample['ntokens'],
11 changes: 8 additions & 3 deletions fairseq/criterions/fairseq_criterion.py
@@ -12,9 +12,8 @@
 
 class FairseqCriterion(_Loss):
 
-    def __init__(self, args, task):
+    def __init__(self, task):
         super().__init__()
-        self.args = args
         self.task = task
         self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
 
@@ -23,9 +22,15 @@ def add_args(parser):
         """Add criterion-specific arguments to the parser."""
         pass
 
+    @classmethod
+    def from_args(cls, task, args):
+        """Construct a criterion from command-line args."""
+        raise NotImplementedError
+
     @classmethod
     def build_criterion(cls, args, task):
-        return cls(args, task)
+        return cls.from_args(task, args)
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
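For orientation, the flow after this change: `build_criterion(args, task)` on the base class now dispatches to each subclass's `from_args`, so `__init__` takes explicit values and never stores `args`. A minimal runnable sketch of the pattern (standalone stand-in classes, not the fairseq ones):

```python
from argparse import Namespace


class CriterionBase:
    """Stand-in for FairseqCriterion: construction from args is funneled through from_args."""

    def __init__(self, task):
        self.task = task

    @classmethod
    def from_args(cls, task, args):
        raise NotImplementedError

    @classmethod
    def build_criterion(cls, args, task):
        # Old call sites keep working; new classes only implement from_args.
        return cls.from_args(task, args)


class CrossEntropyLike(CriterionBase):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg  # explicit dependency instead of self.args

    @classmethod
    def from_args(cls, task, args):
        return cls(task, args.sentence_avg)


crit = CrossEntropyLike.build_criterion(Namespace(sentence_avg=True), task=None)
assert crit.sentence_avg
```

Keeping `args` off the instance makes the constructor's parameter list the full dependency list, which is also what lets the test change at the bottom of this diff build criteria uniformly.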
13 changes: 9 additions & 4 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -33,9 +33,14 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.eps = args.label_smoothing
+    def __init__(self, task, sentence_avg, label_smoothing):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+        self.eps = label_smoothing
+
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.sentence_avg, args.label_smoothing)
 
     @staticmethod
     def add_args(parser):
@@ -55,7 +60,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data,
             'nll_loss': nll_loss.data,
12 changes: 8 additions & 4 deletions fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -14,9 +14,13 @@
 @register_criterion('label_smoothed_cross_entropy_with_alignment')
 class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
-        self.alignment_lambda = args.alignment_lambda
+    def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
+        super().__init__(task, sentence_avg, label_smoothing)
+        self.alignment_lambda = alignment_lambda
+
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.sentence_avg, args.label_smoothing, args.alignment_lambda)
 
     @staticmethod
     def add_args(parser):
@@ -36,7 +40,7 @@ def forward(self, model, sample, reduce=True):
         """
         net_output = model(**sample['net_input'])
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
14 changes: 10 additions & 4 deletions fairseq/criterions/legacy_masked_lm.py
@@ -48,8 +48,10 @@ class LegacyMaskedLmLoss(FairseqCriterion):
     an argument.
     """
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    def __init__(self, task, masked_lm_only, nsp_loss_weight):
+        super().__init__(task)
+        self.masked_lm_only = masked_lm_only
+        self.nsp_loss_weight = nsp_loss_weight
 
     @staticmethod
     def add_args(parser):
@@ -61,6 +63,10 @@ def add_args(parser):
                             help='weight for next sentence prediction'
                                  ' loss (default 1)')
 
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.masked_lm_only, args.nsp_loss_weight)
+
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
         Returns a tuple with three elements:
@@ -85,7 +91,7 @@ def forward(self, model, sample, reduce=True):
 
         # Compute sentence loss if masked_lm_only is False
         sentence_loss = None
-        if not self.args.masked_lm_only:
+        if not self.masked_lm_only:
             sentence_logits = output_metadata['sentence_logits']
             sentence_targets = sample['sentence_target'].view(-1)
             # This needs to be recomputed due to some differences between
@@ -102,7 +108,7 @@ def forward(self, model, sample, reduce=True):
             sentence_loss = compute_cross_entropy_loss(
                 sentence_logits, sentence_targets)
 
-            loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)
+            loss += self.nsp_loss_weight * (sentence_loss / nsentences)
 
         # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
         # we don't need to use sample_size as denominator for the gradient
4 changes: 4 additions & 0 deletions fairseq/criterions/masked_lm.py
@@ -18,6 +18,10 @@ class MaskedLmLoss(FairseqCriterion):
     Implementation for the loss used in masked language model (MLM) training.
     """
 
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task)
+
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
8 changes: 8 additions & 0 deletions fairseq/criterions/nat_loss.py
@@ -15,6 +15,10 @@
 
 @register_criterion("nat_loss")
 class LabelSmoothedDualImitationCriterion(FairseqCriterion):
+    def __init__(self, task, label_smoothing):
+        super().__init__(task)
+        self.label_smoothing = label_smoothing
+
     @staticmethod
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
@@ -27,6 +31,10 @@ def add_args(parser):
                             help='epsilon for label smoothing, 0 means no label smoothing')
         # fmt: on
 
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.label_smoothing)
+
     def _compute_loss(
         self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
     ):
8 changes: 8 additions & 0 deletions fairseq/criterions/sentence_prediction.py
@@ -15,6 +15,10 @@
 @register_criterion('sentence_prediction')
 class SentencePredictionCriterion(FairseqCriterion):
 
+    def __init__(self, task, classification_head_name):
+        super().__init__(task)
+        self.classification_head_name = classification_head_name
+
     @staticmethod
     def add_args(parser):
         # fmt: off
@@ -23,6 +27,10 @@ def add_args(parser):
                             help='name of the classification head to use')
         # fmt: on
 
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.classification_head_name)
+
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
20 changes: 13 additions & 7 deletions fairseq/criterions/sentence_ranking.py
@@ -15,12 +15,18 @@
 @register_criterion('sentence_ranking')
 class SentenceRankingCriterion(FairseqCriterion):
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
-        if self.args.save_predictions is not None:
-            self.prediction_h = open(self.args.save_predictions, 'w')
+    def __init__(self, task, ranking_head_name, save_predictions, num_classes):
+        super().__init__(task)
+        self.ranking_head_name = ranking_head_name
+        if save_predictions is not None:
+            self.prediction_h = open(save_predictions, 'w')
         else:
             self.prediction_h = None
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_args(cls, task, args):
+        return cls(task, args.ranking_head_name, args.save_predictions, args.num_classes)
 
     def __del__(self):
         if self.prediction_h is not None:
@@ -46,14 +52,14 @@ def forward(self, model, sample, reduce=True):
         """
         assert (
             hasattr(model, 'classification_heads')
-            and self.args.ranking_head_name in model.classification_heads
+            and self.ranking_head_name in model.classification_heads
         ), 'model must provide sentence ranking head for --criterion=sentence_ranking'
 
         scores = []
-        for idx in range(self.args.num_classes):
+        for idx in range(self.num_classes):
             score, _ = model(
                 **sample['net_input{idx}'.format(idx=idx+1)],
-                classification_head_name=self.args.ranking_head_name,
+                classification_head_name=self.ranking_head_name,
             )
             scores.append(score)
 
2 changes: 1 addition & 1 deletion tests/speech_recognition/asr_test_base.py
@@ -516,7 +516,7 @@ def setUpArgs(self):
     def setUp(self):
         args = self.setUpArgs()
         self.model = DummyEncoderModel(encoder=DummyEncoder())
-        self.criterion = self.criterion_cls(args=args, task=DummyTask(args))
+        self.criterion = self.criterion_cls.build_criterion(args=args, task=DummyTask(args))
 
     def get_src_tokens(self, correct_prediction, aggregate):
         """
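With per-class `__init__` signatures, direct instantiation would need per-class knowledge, so the test goes through the uniform `build_criterion` classmethod instead. A hypothetical harness loop showing why that matters (illustrative names, not the fairseq test):

```python
# Hypothetical: any criterion following the from_args pattern can be built the
# same way, regardless of what its __init__ actually takes.
def build_all(criterion_classes, args, task):
    return [criterion_cls.build_criterion(args, task) for criterion_cls in criterion_classes]
```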