 )
 from fairseq.modules import AdaptiveSoftmax
 
+DEFAULT_MAX_SOURCE_POSITIONS = 1e5
+DEFAULT_MAX_TARGET_POSITIONS = 1e5
 
 
 @register_model('lstm')
 class LSTMModel(FairseqEncoderDecoderModel):
@@ -85,6 +87,9 @@ def build_model(cls, args, task):
         if args.encoder_layers != args.decoder_layers:
             raise ValueError('--encoder-layers must match --decoder-layers')
 
+        max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS)
+        max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS)
+
         def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             num_embeddings = len(dictionary)
             padding_idx = dictionary.pad()
@@ -149,6 +154,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
149154 dropout_out = args .encoder_dropout_out ,
150155 bidirectional = args .encoder_bidirectional ,
151156 pretrained_embed = pretrained_encoder_embed ,
157+ max_source_positions = max_source_positions
152158 )
153159 decoder = LSTMDecoder (
154160 dictionary = task .target_dictionary ,
@@ -166,6 +172,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                 if args.criterion == 'adaptive_loss' else None
             ),
+            max_target_positions=max_target_positions
         )
         return cls(encoder, decoder)
 
@@ -176,13 +183,15 @@ def __init__(
         self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
         dropout_in=0.1, dropout_out=0.1, bidirectional=False,
         left_pad=True, pretrained_embed=None, padding_value=0.,
+        max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS
     ):
         super().__init__(dictionary)
         self.num_layers = num_layers
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.bidirectional = bidirectional
         self.hidden_size = hidden_size
+        self.max_source_positions = max_source_positions
 
         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
@@ -269,7 +278,7 @@ def reorder_encoder_out(self, encoder_out, new_order):
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return int(1e5)  # an arbitrary large number
+        return self.max_source_positions
 
 
 class AttentionLayer(nn.Module):
@@ -312,13 +321,15 @@ def __init__(
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_output_units=512, pretrained_embed=None,
         share_input_output_embed=False, adaptive_softmax_cutoff=None,
+        max_target_positions=DEFAULT_MAX_TARGET_POSITIONS
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
         self.share_input_output_embed = share_input_output_embed
         self.need_attn = True
+        self.max_target_positions = max_target_positions
 
         self.adaptive_softmax = None
         num_embeddings = len(dictionary)
@@ -329,14 +340,18 @@ def __init__(
             self.embed_tokens = pretrained_embed
 
         self.encoder_output_units = encoder_output_units
-        if encoder_output_units != hidden_size:
+        if encoder_output_units != hidden_size and encoder_output_units != 0:
             self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
             self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
         else:
             self.encoder_hidden_proj = self.encoder_cell_proj = None
+
+        # disable input feeding if there is no encoder
+        # input feeding is described in arxiv.org/abs/1508.04025
+        input_feed_size = 0 if encoder_output_units == 0 else hidden_size
         self.layers = nn.ModuleList([
             LSTMCell(
-                input_size=hidden_size + embed_dim if layer == 0 else hidden_size,
+                input_size=input_feed_size + embed_dim if layer == 0 else hidden_size,
                 hidden_size=hidden_size,
             )
             for layer in range(num_layers)
@@ -355,7 +370,7 @@ def __init__(
         elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
         x, attn_scores = self.extract_features(
             prev_output_tokens, encoder_out, incremental_state
         )
@@ -367,16 +382,23 @@ def extract_features(
367382 """
368383 Similar to *forward* but only return features.
369384 """
370- encoder_padding_mask = encoder_out ['encoder_padding_mask' ]
371- encoder_out = encoder_out ['encoder_out' ]
385+ if encoder_out is not None :
386+ encoder_padding_mask = encoder_out ['encoder_padding_mask' ]
387+ encoder_out = encoder_out ['encoder_out' ]
388+ else :
389+ encoder_padding_mask = None
390+ encoder_out = None
372391
373392 if incremental_state is not None :
374393 prev_output_tokens = prev_output_tokens [:, - 1 :]
375394 bsz , seqlen = prev_output_tokens .size ()
376395
377396 # get outputs from encoder
378- encoder_outs , encoder_hiddens , encoder_cells = encoder_out [:3 ]
379- srclen = encoder_outs .size (0 )
397+ if encoder_out is not None :
398+ encoder_outs , encoder_hiddens , encoder_cells = encoder_out [:3 ]
399+ srclen = encoder_outs .size (0 )
400+ else :
401+ srclen = None
380402
381403 # embed tokens
382404 x = self .embed_tokens (prev_output_tokens )
@@ -389,20 +411,33 @@ def extract_features(
         cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
         if cached_state is not None:
             prev_hiddens, prev_cells, input_feed = cached_state
-        else:
+        elif encoder_out is not None:
+            # setup recurrent cells
             num_layers = len(self.layers)
             prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
             prev_cells = [encoder_cells[i] for i in range(num_layers)]
             if self.encoder_hidden_proj is not None:
                 prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
                 prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
             input_feed = x.new_zeros(bsz, self.hidden_size)
-
-        attn_scores = x.new_zeros(srclen, seqlen, bsz)
+        else:
+            # setup zero cells, since there is no encoder
+            num_layers = len(self.layers)
+            zero_state = x.new_zeros(bsz, self.hidden_size)
+            prev_hiddens = [zero_state for i in range(num_layers)]
+            prev_cells = [zero_state for i in range(num_layers)]
+            input_feed = None
+
+        assert srclen is not None or self.attention is None, \
+            "attention is not supported if there are no encoder outputs"
+        attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None
         outs = []
         for j in range(seqlen):
             # input feeding: concatenate context vector from previous time step
-            input = torch.cat((x[j, :, :], input_feed), dim=1)
+            if input_feed is not None:
+                input = torch.cat((x[j, :, :], input_feed), dim=1)
+            else:
+                input = x[j]
 
             for i, rnn in enumerate(self.layers):
                 # recurrent cell
@@ -423,7 +458,8 @@ def extract_features(
             out = F.dropout(out, p=self.dropout_out, training=self.training)
 
             # input feeding
-            input_feed = out
+            if input_feed is not None:
+                input_feed = out
 
             # save final output
             outs.append(out)
@@ -445,7 +481,7 @@ def extract_features(
             x = F.dropout(x, p=self.dropout_out, training=self.training)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        if not self.training and self.need_attn:
+        if not self.training and self.need_attn and self.attention is not None:
             attn_scores = attn_scores.transpose(0, 2)
         else:
             attn_scores = None
@@ -469,14 +505,17 @@ def reorder_incremental_state(self, incremental_state, new_order):
         def reorder_state(state):
             if isinstance(state, list):
                 return [reorder_state(state_i) for state_i in state]
-            return state.index_select(0, new_order)
+            elif state is not None:
+                return state.index_select(0, new_order)
+            else:
+                return None
 
         new_state = tuple(map(reorder_state, cached_state))
         utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
 
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return int(1e5)  # an arbitrary large number
+        return self.max_target_positions
 
     def make_generation_fast_(self, need_attn=False, **kwargs):
         self.need_attn = need_attn
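
For reference, here is a minimal sketch (not part of this commit) of how the changed LSTMDecoder can be instantiated without an encoder, e.g. for decoder-only language modeling. It assumes the rest of lstm.py matches upstream fairseq (in particular that attention=False leaves self.attention as None); the toy Dictionary and all dimensions below are illustrative only.

import torch
from fairseq.data import Dictionary
from fairseq.models.lstm import LSTMDecoder

# toy vocabulary, purely for illustration
vocab = Dictionary()
for word in ('hello', 'world'):
    vocab.add_symbol(word)

decoder = LSTMDecoder(
    dictionary=vocab,
    embed_dim=32,
    hidden_size=32,
    num_layers=1,
    attention=False,         # no encoder outputs to attend over
    encoder_output_units=0,  # "no encoder": disables input feeding, zero-initialises the state
    max_target_positions=128,
)
decoder.eval()

# fairseq convention: previous output tokens start with EOS (shifted target)
prev_output_tokens = torch.LongTensor(
    [[vocab.eos(), vocab.index('hello'), vocab.index('world')]]
)

logits, attn = decoder(prev_output_tokens)  # encoder_out defaults to None
print(logits.shape)  # (batch, tgt_len, vocab), e.g. torch.Size([1, 3, len(vocab)])
print(attn)          # None, since attention is disabled

With encoder_output_units=0 the first LSTMCell receives only the token embedding (input_feed_size is 0), so no context vector has to be carried between time steps, which is why input feeding can be dropped entirely in this setting.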