Skip to content

Conversation

@alanakbik
Copy link
Collaborator

@alanakbik alanakbik commented Dec 16, 2020

This PR adds FLERT as presented in our recent paper (closes #2015).

A number of changes are made:

  1. Sentence objects now have next_sentence() and previous_sentence() methods that are set automatically if loaded through ColumnCorpus. This is a pointer system to navigate through sentences in a corpus:
# load corpus
corpus = MIT_MOVIE_NER_SIMPLE(in_memory=False)

# get a sentence
sentence = corpus.test[123]
print(sentence)
# get the previous sentence
print(sentence.previous_sentence())
# get the sentence after that
print(sentence.next_sentence())
# get the sentence after the next sentence
print(sentence.next_sentence().next_sentence())

This allows dynamic computation of contexts in the embedding classes.

  1. Sentence objects now have the is_document_boundary field which is set through the ColumnCorpus. In some datasets, there are sentences like "-DOCSTART-" that just indicate document boundaries. This is now recorded as a boolean in the object.

  2. TransformerWordEmbeddings refactored for dynamic context, robustness to long sentences and readability. The names of some constructor arguments have changed for clarity: pooling_operation is now subtoken_pooling (to make clear that we pool subtokens), use_scalar_mean is now layer_mean (we only do a simple layer mean) and use_context can now optionally take an integer to indicate the length of the context. Default arguments are also changed.

For instance, to create embeddings with a document-level context of 64 subtokens, init like this:

embeddings = TransformerWordEmbeddings(
    model='bert-base-uncased',
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=64,
)

From my testing, it also seems that the new implementation is a bit faster.

  1. You can train FLERT like this:
import torch

from flair.data import Sentence
from flair.datasets import CONLL_03, WNUT_17
from flair.embeddings import TransformerWordEmbeddings, DocumentPoolEmbeddings, WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer


corpus = CONLL_03()

use_context = 64
hf_model = 'xlm-roberta-large'

embeddings = TransformerWordEmbeddings(
    model=hf_model,
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=use_context,
)

tag_dictionary = corpus.make_tag_dictionary('ner')

# init bare-bones tagger (no reprojection, LSTM or CRF)
tagger: SequenceTagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type='ner',
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

# train with XLM parameters (AdamW, 20 epochs, small LR)
trainer = ModelTrainer(tagger, corpus, optimizer=torch.optim.AdamW)
from torch.optim.lr_scheduler import OneCycleLR

context_string = '+context' if use_context else ''

trainer.train(f"resources/flert",
              learning_rate=5.0e-6,
              mini_batch_size=4,
              mini_batch_chunk_size=1,
              max_epochs=20,
              scheduler=OneCycleLR,
              embeddings_storage_mode='none',
              weight_decay=0.,
              )

@alanakbik alanakbik merged commit e66baf4 into master Dec 16, 2020
@alanakbik alanakbik deleted the GH-2015-flert branch December 16, 2020 22:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add FLERT approach to Flair

3 participants