Commit 9f4256e

joshim5 authored and facebook-github-bot committed
Standalone LSTM decoder language model (#934)
Summary:
Currently, the LSTM models in Fairseq master can only be used in an encoder/decoder setting, for example in `class LSTMModel(FairseqEncoderDecoderModel)`. This PR adds a standalone LSTM decoder language model.

Changes:
- adds support for `LSTMDecoder` in cases where an encoder is not present, for instance where `encoder_output_units=0`.
- fixes bugs in `LSTMDecoder` that only become apparent when using it in a standalone fashion, for example not handling `src_lengths` as an optional argument.
- adds `class LSTMLanguageModel(FairseqLanguageModel)` for training LSTM language models.
- tests for the `LSTMLanguageModel`. Changes to the `LSTMDecoder` are covered by existing test cases.

Pull Request resolved: fairinternal/fairseq-py#934

Reviewed By: myleott

Differential Revision: D18816310

Pulled By: joshim5

fbshipit-source-id: 4773695a7f5d36aa773da8a45db2e02f76c968a9
1 parent 1da061f commit 9f4256e
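Note (not part of the commit): the heart of the change is that LSTMDecoder now works without an encoder. Below is a minimal sketch of the new usage, assuming fairseq at this commit and a toy dictionary in place of a real preprocessed vocabulary; attention must be disabled because the decoder now asserts that attention requires encoder outputs.

# Illustrative sketch only; the Dictionary below is a stand-in for a real
# preprocessed vocabulary.
import torch
from fairseq.data import Dictionary
from fairseq.models.lstm import LSTMDecoder

d = Dictionary()
for sym in ['a', 'b', 'c']:
    d.add_symbol(sym)

decoder = LSTMDecoder(
    dictionary=d,
    encoder_output_units=0,  # new: 0 means "no encoder" and disables input feeding
    attention=False,         # attention is unsupported without encoder outputs
)

prev_output_tokens = torch.LongTensor([[d.eos(), d.index('a'), d.index('b')]])
logits, _ = decoder(prev_output_tokens)  # encoder_out now defaults to None
print(logits.shape)  # (batch, seq_len, vocab_size)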

File tree

fairseq/models/lstm.py
fairseq/models/lstm_lm.py
tests/test_binaries.py

3 files changed: +195 -16 lines changed

fairseq/models/lstm.py

Lines changed: 55 additions & 16 deletions
@@ -17,6 +17,8 @@
 )
 from fairseq.modules import AdaptiveSoftmax

+DEFAULT_MAX_SOURCE_POSITIONS = 1e5
+DEFAULT_MAX_TARGET_POSITIONS = 1e5

 @register_model('lstm')
 class LSTMModel(FairseqEncoderDecoderModel):
@@ -85,6 +87,9 @@ def build_model(cls, args, task):
         if args.encoder_layers != args.decoder_layers:
             raise ValueError('--encoder-layers must match --decoder-layers')

+        max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS)
+        max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS)
+
         def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             num_embeddings = len(dictionary)
             padding_idx = dictionary.pad()
@@ -149,6 +154,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             dropout_out=args.encoder_dropout_out,
             bidirectional=args.encoder_bidirectional,
             pretrained_embed=pretrained_encoder_embed,
+            max_source_positions=max_source_positions
         )
         decoder = LSTMDecoder(
             dictionary=task.target_dictionary,
@@ -166,6 +172,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                 if args.criterion == 'adaptive_loss' else None
             ),
+            max_target_positions=max_target_positions
         )
         return cls(encoder, decoder)

@@ -176,13 +183,15 @@ def __init__(
         self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
         dropout_in=0.1, dropout_out=0.1, bidirectional=False,
         left_pad=True, pretrained_embed=None, padding_value=0.,
+        max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS
     ):
         super().__init__(dictionary)
         self.num_layers = num_layers
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.bidirectional = bidirectional
         self.hidden_size = hidden_size
+        self.max_source_positions = max_source_positions

         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
@@ -269,7 +278,7 @@ def reorder_encoder_out(self, encoder_out, new_order):

     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return int(1e5)  # an arbitrary large number
+        return self.max_source_positions


 class AttentionLayer(nn.Module):
@@ -312,13 +321,15 @@ def __init__(
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_output_units=512, pretrained_embed=None,
         share_input_output_embed=False, adaptive_softmax_cutoff=None,
+        max_target_positions=DEFAULT_MAX_TARGET_POSITIONS
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
         self.share_input_output_embed = share_input_output_embed
         self.need_attn = True
+        self.max_target_positions = max_target_positions

         self.adaptive_softmax = None
         num_embeddings = len(dictionary)
@@ -329,14 +340,18 @@ def __init__(
             self.embed_tokens = pretrained_embed

         self.encoder_output_units = encoder_output_units
-        if encoder_output_units != hidden_size:
+        if encoder_output_units != hidden_size and encoder_output_units != 0:
             self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
             self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
         else:
             self.encoder_hidden_proj = self.encoder_cell_proj = None
+
+        # disable input feeding if there is no encoder
+        # input feeding is described in arxiv.org/abs/1508.04025
+        input_feed_size = 0 if encoder_output_units == 0 else hidden_size
         self.layers = nn.ModuleList([
             LSTMCell(
-                input_size=hidden_size + embed_dim if layer == 0 else hidden_size,
+                input_size=input_feed_size + embed_dim if layer == 0 else hidden_size,
                 hidden_size=hidden_size,
             )
             for layer in range(num_layers)
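Aside (not part of the diff): input feeding (arxiv.org/abs/1508.04025) concatenates the previous step's attentional output onto the current token embedding, which is why the first layer is input_feed_size + embed_dim wide and shrinks to embed_dim alone when encoder_output_units == 0. A shape-only sketch with plain torch.nn.LSTMCell and illustrative dimensions:

# Dimensions here are illustrative, not the model's defaults.
import torch
import torch.nn as nn

embed_dim, hidden_size, bsz = 32, 64, 4
cell_with_feed = nn.LSTMCell(input_size=hidden_size + embed_dim, hidden_size=hidden_size)
cell_no_feed = nn.LSTMCell(input_size=0 + embed_dim, hidden_size=hidden_size)

emb = torch.randn(bsz, embed_dim)           # current token embedding x[j]
input_feed = torch.randn(bsz, hidden_size)  # previous step's attentional output
h = c = torch.zeros(bsz, hidden_size)

h1, c1 = cell_with_feed(torch.cat((emb, input_feed), dim=1), (h, c))  # seq2seq path
h2, c2 = cell_no_feed(emb, (h, c))                                    # standalone LM path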
@@ -355,7 +370,7 @@ def __init__(
         elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
         x, attn_scores = self.extract_features(
             prev_output_tokens, encoder_out, incremental_state
         )
@@ -367,16 +382,23 @@ def extract_features(
         """
         Similar to *forward* but only return features.
         """
-        encoder_padding_mask = encoder_out['encoder_padding_mask']
-        encoder_out = encoder_out['encoder_out']
+        if encoder_out is not None:
+            encoder_padding_mask = encoder_out['encoder_padding_mask']
+            encoder_out = encoder_out['encoder_out']
+        else:
+            encoder_padding_mask = None
+            encoder_out = None

         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
         bsz, seqlen = prev_output_tokens.size()

         # get outputs from encoder
-        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
-        srclen = encoder_outs.size(0)
+        if encoder_out is not None:
+            encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
+            srclen = encoder_outs.size(0)
+        else:
+            srclen = None

         # embed tokens
         x = self.embed_tokens(prev_output_tokens)
@@ -389,20 +411,33 @@ def extract_features(
         cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
         if cached_state is not None:
             prev_hiddens, prev_cells, input_feed = cached_state
-        else:
+        elif encoder_out is not None:
+            # setup recurrent cells
             num_layers = len(self.layers)
             prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
             prev_cells = [encoder_cells[i] for i in range(num_layers)]
             if self.encoder_hidden_proj is not None:
                 prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
                 prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
             input_feed = x.new_zeros(bsz, self.hidden_size)
-
-        attn_scores = x.new_zeros(srclen, seqlen, bsz)
+        else:
+            # setup zero cells, since there is no encoder
+            num_layers = len(self.layers)
+            zero_state = x.new_zeros(bsz, self.hidden_size)
+            prev_hiddens = [zero_state for i in range(num_layers)]
+            prev_cells = [zero_state for i in range(num_layers)]
+            input_feed = None
+
+        assert srclen is not None or self.attention is None, \
+            "attention is not supported if there are no encoder outputs"
+        attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None
         outs = []
         for j in range(seqlen):
             # input feeding: concatenate context vector from previous time step
-            input = torch.cat((x[j, :, :], input_feed), dim=1)
+            if input_feed is not None:
+                input = torch.cat((x[j, :, :], input_feed), dim=1)
+            else:
+                input = x[j]

             for i, rnn in enumerate(self.layers):
                 # recurrent cell
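Aside (not part of the diff): the cached_state branch above is what enables token-at-a-time generation. When an incremental_state dict is threaded through successive calls, only the last token is embedded and the hidden/cell states, plus input_feed (now possibly None), are restored from the cache. A hedged sketch, reusing the toy decoder and dictionary d from the example near the top of this page:

# Greedy incremental decoding sketch; follows the diff's behavior, but the
# setup (decoder, d) comes from the earlier illustrative example.
import torch

incremental_state = {}  # fairseq threads this dict through successive calls
tokens = torch.LongTensor([[d.eos()]])
for _ in range(5):
    logits, _ = decoder(tokens, incremental_state=incremental_state)
    next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=1)
# With incremental_state set, each call only processes tokens[:, -1:].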
@@ -423,7 +458,8 @@ def extract_features(
                 out = F.dropout(out, p=self.dropout_out, training=self.training)

             # input feeding
-            input_feed = out
+            if input_feed is not None:
+                input_feed = out

             # save final output
             outs.append(out)
@@ -445,7 +481,7 @@ def extract_features(
         x = F.dropout(x, p=self.dropout_out, training=self.training)

         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        if not self.training and self.need_attn:
+        if not self.training and self.need_attn and self.attention is not None:
             attn_scores = attn_scores.transpose(0, 2)
         else:
             attn_scores = None
@@ -469,14 +505,17 @@ def reorder_incremental_state(self, incremental_state, new_order):
         def reorder_state(state):
             if isinstance(state, list):
                 return [reorder_state(state_i) for state_i in state]
-            return state.index_select(0, new_order)
+            elif state is not None:
+                return state.index_select(0, new_order)
+            else:
+                return None

         new_state = tuple(map(reorder_state, cached_state))
         utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return int(1e5)  # an arbitrary large number
+        return self.max_target_positions

     def make_generation_fast_(self, need_attn=False, **kwargs):
         self.need_attn = need_attn
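Aside (not part of the diff): the reorder_state change matters during beam search, where fairseq reorders every cached tensor with index_select; a standalone language model caches input_feed = None, which must pass through unchanged. A self-contained illustration of the pattern (plain torch, not fairseq code):

import torch

def reorder_state(state, new_order):
    if isinstance(state, list):
        return [reorder_state(s, new_order) for s in state]
    elif state is not None:
        return state.index_select(0, new_order)
    else:
        return None

hiddens = [torch.arange(6.).view(3, 2)]  # one layer of hidden states, batch size 3
input_feed = None                        # standalone LM: no input feeding
new_order = torch.tensor([2, 0, 1])      # beam/batch reordering indices
print(reorder_state(hiddens, new_order), reorder_state(input_feed, new_order))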

fairseq/models/lstm_lm.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq import options, utils
+from fairseq.models import (
+    FairseqLanguageModel, register_model, register_model_architecture
+)
+from fairseq.models.lstm import (
+    LSTMDecoder, Embedding
+)
+
+DEFAULT_MAX_TARGET_POSITIONS = 1e5
+
+@register_model('lstm_lm')
+class LSTMLanguageModel(FairseqLanguageModel):
+    def __init__(self, decoder):
+        super().__init__(decoder)
+
+    @staticmethod
+    def add_args(parser):
+        """Add model-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--dropout', type=float, metavar='D',
+                            help='dropout probability')
+        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
+                            help='decoder embedding dimension')
+        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
+                            help='path to pre-trained decoder embedding')
+        parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
+                            help='decoder hidden size')
+        parser.add_argument('--decoder-layers', type=int, metavar='N',
+                            help='number of decoder layers')
+        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
+                            help='decoder output embedding dimension')
+        parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
+                            help='decoder attention')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
+
+        # Granular dropout settings (if not specified these default to --dropout)
+        parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
+                            help='dropout probability for decoder input embedding')
+        parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
+                            help='dropout probability for decoder output')
+        parser.add_argument('--share-decoder-input-output-embed', default=False,
+                            action='store_true',
+                            help='share decoder input and output embeddings')
+
+    @classmethod
+    def build_model(cls, args, task):
+        """Build a new model instance."""
+
+        # make sure all arguments are present in older models
+        base_architecture(args)
+
+        if getattr(args, 'max_target_positions', None) is not None:
+            max_target_positions = args.max_target_positions
+        else:
+            max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
+
+        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
+            num_embeddings = len(dictionary)
+            padding_idx = dictionary.pad()
+            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+            embed_dict = utils.parse_embedding(embed_path)
+            utils.print_embed_overlap(embed_dict, dictionary)
+            return utils.load_embedding(embed_dict, dictionary, embed_tokens)
+
+        pretrained_decoder_embed = None
+        if args.decoder_embed_path:
+            pretrained_decoder_embed = load_pretrained_embedding_from_file(
+                args.decoder_embed_path,
+                task.target_dictionary,
+                args.decoder_embed_dim
+            )
+
+        if args.share_decoder_input_output_embed:
+            # double check all parameters combinations are valid
+            if task.source_dictionary != task.target_dictionary:
+                raise ValueError('--share-decoder-input-output-embeddings requires a joint dictionary')
+
+            if args.decoder_embed_dim != args.decoder_out_embed_dim:
+                raise ValueError(
+                    '--share-decoder-input-output-embeddings requires '
+                    '--decoder-embed-dim to match --decoder-out-embed-dim'
+                )
+
+        decoder = LSTMDecoder(
+            dictionary=task.dictionary,
+            embed_dim=args.decoder_embed_dim,
+            hidden_size=args.decoder_hidden_size,
+            out_embed_dim=args.decoder_out_embed_dim,
+            num_layers=args.decoder_layers,
+            dropout_in=args.decoder_dropout_in,
+            dropout_out=args.decoder_dropout_out,
+            attention=options.eval_bool(args.decoder_attention),
+            encoder_output_units=0,
+            pretrained_embed=pretrained_decoder_embed,
+            share_input_output_embed=args.share_decoder_input_output_embed,
+            adaptive_softmax_cutoff=(
+                options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
+                if args.criterion == 'adaptive_loss' else None
+            ),
+            max_target_positions=max_target_positions
+        )
+
+        return cls(decoder)
+
+
+@register_model_architecture('lstm_lm', 'lstm_lm')
+def base_architecture(args):
+    args.dropout = getattr(args, 'dropout', 0.1)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
+    args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
+    args.decoder_layers = getattr(args, 'decoder_layers', 1)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    args.decoder_attention = getattr(args, 'decoder_attention', '0')
+    args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
+    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
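Note (not part of the commit): because the model is registered under 'lstm_lm', it is reachable through fairseq's normal entry points (roughly fairseq-train data-bin/DATA --task language_modeling --arch lstm_lm, with the data path and extra flags as placeholders). A hedged sketch of building it programmatically; ToyTask is a made-up stand-in exposing only the attribute build_model reads in this configuration:

# Illustrative only; base_architecture() fills in every unset hyperparameter,
# so a sparse Namespace suffices.
from argparse import Namespace
from fairseq.data import Dictionary
from fairseq.models.lstm_lm import LSTMLanguageModel

class ToyTask:  # hypothetical stand-in for a language_modeling task
    def __init__(self):
        self.dictionary = Dictionary()
        for sym in ['a', 'b', 'c']:
            self.dictionary.add_symbol(sym)

args = Namespace(criterion='cross_entropy')  # not adaptive_loss, so no adaptive softmax
model = LSTMLanguageModel.build_model(args, ToyTask())
print(sum(p.numel() for p in model.parameters()))  # decoder-only parameter count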

tests/test_binaries.py

Lines changed: 15 additions & 0 deletions
@@ -503,6 +503,21 @@ def test_lightconv_lm(self):
             '--tokens-per-sample', '500',
         ])

+    def test_lstm_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_lstm_lm') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(
+                    data_dir, 'lstm_lm', ['--add-bos-token'], run_validation=True,
+                )
+                eval_lm_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'language_modeling',
+                    '--sample-break-mode', 'eos',
+                    '--tokens-per-sample', '500',
+                ])
+

 class TestMaskedLanguageModel(unittest.TestCase):
