Skip to content

Commit dea50fa

Browse files
authored
Merge pull request #661 from GiliGoldin/master
add option to support customized elmo embeddings
2 parents f01683d + 524ac08 commit dea50fa

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

flair/embeddings.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,9 @@ def __str__(self):
425425
class ELMoEmbeddings(TokenEmbeddings):
426426
"""Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""
427427

428-
def __init__(self, model: str = "original"):
428+
def __init__(
429+
self, model: str = "original", options_file: str = None, weight_file: str = None
430+
):
429431
super().__init__()
430432

431433
try:
@@ -442,22 +444,23 @@ def __init__(self, model: str = "original"):
442444
self.name = "elmo-" + model
443445
self.static_embeddings = True
444446

445-
# the default model for ELMo is the 'original' model, which is very large
446-
options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
447-
weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
448-
# alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
449-
if model == "small":
450-
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
451-
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
452-
if model == "medium":
453-
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
454-
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
455-
if model == "pt" or model == "portuguese":
456-
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
457-
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
458-
if model == "pubmed":
459-
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
460-
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
447+
if not options_file or not weight_file:
448+
# the default model for ELMo is the 'original' model, which is very large
449+
options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
450+
weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
451+
# alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
452+
if model == "small":
453+
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
454+
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
455+
if model == "medium":
456+
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
457+
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
458+
if model == "pt" or model == "portuguese":
459+
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
460+
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
461+
if model == "pubmed":
462+
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
463+
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
461464

462465
# put on Cuda if available
463466
from flair import device
@@ -492,7 +495,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
492495
sentence_embeddings = embeddings[i]
493496

494497
for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
495-
496498
word_embedding = torch.cat(
497499
[
498500
torch.FloatTensor(sentence_embeddings[0, token_idx, :]),

0 commit comments

Comments
 (0)