@@ -425,7 +425,9 @@ def __str__(self):
425425class ELMoEmbeddings (TokenEmbeddings ):
426426 """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""
427427
428- def __init__ (self , model : str = "original" ):
428+ def __init__ (
429+ self , model : str = "original" , options_file : str = None , weight_file : str = None
430+ ):
429431 super ().__init__ ()
430432
431433 try :
@@ -442,22 +444,23 @@ def __init__(self, model: str = "original"):
442444 self .name = "elmo-" + model
443445 self .static_embeddings = True
444446
445- # the default model for ELMo is the 'original' model, which is very large
446- options_file = allennlp .commands .elmo .DEFAULT_OPTIONS_FILE
447- weight_file = allennlp .commands .elmo .DEFAULT_WEIGHT_FILE
448- # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
449- if model == "small" :
450- options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
451- weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
452- if model == "medium" :
453- options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
454- weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
455- if model == "pt" or model == "portuguese" :
456- options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
457- weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
458- if model == "pubmed" :
459- options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
460- weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
447+ if not options_file or not weight_file :
448+ # unless both options_file and weight_file are given, fall back to a named model; the default is the 'original' model, which is very large
449+ options_file = allennlp .commands .elmo .DEFAULT_OPTIONS_FILE
450+ weight_file = allennlp .commands .elmo .DEFAULT_WEIGHT_FILE
451+ # alternatively, a small, medium, portuguese or pubmed model can be selected by passing the appropriate model name
452+ if model == "small" :
453+ options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
454+ weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
455+ if model == "medium" :
456+ options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
457+ weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
458+ if model == "pt" or model == "portuguese" :
459+ options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
460+ weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
461+ if model == "pubmed" :
462+ options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
463+ weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
461464
462465 # put on Cuda if available
463466 from flair import device
@@ -492,7 +495,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
492495 sentence_embeddings = embeddings [i ]
493496
494497 for token , token_idx in zip (sentence .tokens , range (len (sentence .tokens ))):
495-
496498 word_embedding = torch .cat (
497499 [
498500 torch .FloatTensor (sentence_embeddings [0 , token_idx , :]),
0 commit comments