Skip to content

Commit a541b19

Browse files
Mandeep Bainesfacebook-github-bot
authored andcommitted
Add dummy task for translation benchmarking (#1212)
Summary: Pull Request resolved: fairinternal/fairseq-py#1212 Test Plan: python train.py \ -a transformer \ --clip-norm 0.4 --optimizer adam --lr 0.001 \ --dropout 0.0 \ --decoder-layers 7 \ --encoder-layers 7 \ --encoder-ffn-embed-dim 2048 \ --decoder-ffn-embed-dim 2048 \ --encoder-embed-dim 1024 \ --decoder-embed-dim 1024 \ --max-tokens 8192 \ --criterion cross_entropy --max-update 50 \ --attention-dropout 0.0 \ --adam-betas '(0.9, 0.98)' \ --disable-validation --no-save \ --task dummy_mt # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Reviewed By: myleott Differential Revision: D22484873 Pulled By: msbaines fbshipit-source-id: bc61165ab91290d0b6aa2077c968ab537bce8a6a
1 parent ffecb4e commit a541b19

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

fairseq/benchmark/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
dummy_lm,
99
dummy_masked_lm,
1010
dummy_model,
11+
dummy_mt,
1112
)

fairseq/benchmark/dummy_mt.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
8+
import numpy as np
9+
import torch
10+
11+
from fairseq.data import Dictionary, FairseqDataset
12+
from fairseq.tasks import FairseqTask, register_task
13+
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
@register_task('dummy_mt')
19+
class DummyMTTask(FairseqTask):
20+
21+
@staticmethod
22+
def add_args(parser):
23+
"""Add task-specific arguments to the parser."""
24+
parser.add_argument('--dict-size', default=49996, type=int)
25+
parser.add_argument('--dataset-size', default=100000, type=int)
26+
parser.add_argument('--tokens-per-sample', default=512, type=int,
27+
help='max number of total tokens over all segments '
28+
'per sample for BERT dataset')
29+
30+
def __init__(self, args, dictionary):
31+
super().__init__(args)
32+
self.dictionary = dictionary
33+
self.seed = args.seed
34+
35+
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
36+
37+
seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
38+
39+
self.dummy_src = seq[:-1]
40+
self.dummy_tgt = seq[1:]
41+
42+
@classmethod
43+
def setup_task(cls, args, **kwargs):
44+
"""Setup the task. """
45+
dictionary = Dictionary()
46+
for i in range(args.dict_size):
47+
dictionary.add_symbol('word{}'.format(i))
48+
logger.info('dictionary: {} types'.format(len(dictionary)))
49+
return cls(args, dictionary)
50+
51+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
52+
"""Load a given dataset split.
53+
Args:
54+
split (str): name of the split (e.g., train, valid, test)
55+
"""
56+
if self.args.max_sentences is not None:
57+
bsz = self.args.max_sentences
58+
else:
59+
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
60+
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
61+
self.datasets[split] = DummyDataset(
62+
{
63+
'id': 1,
64+
'net_input': {
65+
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
66+
'src_lengths': torch.full(
67+
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
68+
),
69+
'prev_output_tokens': tgt.clone(),
70+
},
71+
'target': tgt,
72+
'nsentences': bsz,
73+
'ntokens': bsz * self.args.tokens_per_sample,
74+
},
75+
num_items=self.args.dataset_size,
76+
item_size=self.args.tokens_per_sample,
77+
)
78+
79+
@property
80+
def source_dictionary(self):
81+
return self.dictionary
82+
83+
@property
84+
def target_dictionary(self):
85+
return self.dictionary
86+
87+
88+
class DummyDataset(FairseqDataset):
89+
90+
def __init__(self, batch, num_items, item_size):
91+
super().__init__()
92+
self.batch = batch
93+
self.num_items = num_items
94+
self.item_size = item_size
95+
96+
def __getitem__(self, index):
97+
return index
98+
99+
def __len__(self):
100+
return self.num_items
101+
102+
def collater(self, samples):
103+
return self.batch
104+
105+
@property
106+
def sizes(self):
107+
return np.array([self.item_size] * self.num_items)
108+
109+
def num_tokens(self, index):
110+
return self.item_size
111+
112+
def size(self, index):
113+
return self.item_size
114+
115+
def ordered_indices(self):
116+
return np.arange(self.num_items)
117+
118+
@property
119+
def supports_prefetch(self):
120+
return False

0 commit comments

Comments
 (0)