Commit e9b5c2a

Merge pull request #1494 from flairNLP/GH-1492-transformers
GH-1492: added new BERT embeddings implementation
2 parents: 222d8a3 + 181ce16 · commit e9b5c2a

7 files changed: +444 lines, -67 lines

flair/datasets.py

Lines changed: 57 additions & 37 deletions
@@ -148,8 +148,9 @@ def __init__(
         test_file=None,
         dev_file=None,
         tokenizer: Callable[[str], List[Token]] = space_tokenizer,
-        max_tokens_per_doc: int = -1,
-        max_chars_per_doc: int = -1,
+        truncate_to_max_tokens: int = -1,
+        truncate_to_max_chars: int = -1,
+        filter_if_longer_than: int = -1,
         in_memory: bool = False,
         encoding: str = 'utf-8',
     ):
@@ -161,8 +162,8 @@ def __init__(
         :param test_file: the name of the test file
         :param dev_file: the name of the dev file, if None, dev data is sampled from train
         :param use_tokenizer: If True, tokenizes the dataset, otherwise uses whitespace tokenization
-        :param max_tokens_per_doc: If set, truncates each Sentence to a maximum number of Tokens
-        :param max_chars_per_doc: If set, truncates each Sentence to a maximum number of chars
+        :param truncate_to_max_tokens: If set, truncates each Sentence to a maximum number of Tokens
+        :param truncate_to_max_chars: If set, truncates each Sentence to a maximum number of chars
         :param in_memory: If True, keeps dataset as Sentences in memory, otherwise only keeps strings
         :return: a Corpus with annotated train, dev and test data
         """
@@ -175,8 +176,9 @@ def __init__(
             train_file,
             label_type=label_type,
             tokenizer=tokenizer,
-            max_tokens_per_doc=max_tokens_per_doc,
-            max_chars_per_doc=max_chars_per_doc,
+            truncate_to_max_tokens=truncate_to_max_tokens,
+            truncate_to_max_chars=truncate_to_max_chars,
+            filter_if_longer_than=filter_if_longer_than,
             in_memory=in_memory,
             encoding=encoding,
         )
@@ -186,8 +188,9 @@ def __init__(
             test_file,
             label_type=label_type,
             tokenizer=tokenizer,
-            max_tokens_per_doc=max_tokens_per_doc,
-            max_chars_per_doc=max_chars_per_doc,
+            truncate_to_max_tokens=truncate_to_max_tokens,
+            truncate_to_max_chars=truncate_to_max_chars,
+            filter_if_longer_than=filter_if_longer_than,
             in_memory=in_memory,
             encoding=encoding,
         ) if test_file is not None else None
@@ -197,8 +200,9 @@ def __init__(
             dev_file,
             label_type=label_type,
             tokenizer=tokenizer,
-            max_tokens_per_doc=max_tokens_per_doc,
-            max_chars_per_doc=max_chars_per_doc,
+            truncate_to_max_tokens=truncate_to_max_tokens,
+            truncate_to_max_chars=truncate_to_max_chars,
+            filter_if_longer_than=filter_if_longer_than,
             in_memory=in_memory,
             encoding=encoding,
         ) if dev_file is not None else None
@@ -930,8 +934,9 @@ def __init__(
         self,
         path_to_file: Union[str, Path],
         label_type: str = 'class',
-        max_tokens_per_doc=-1,
-        max_chars_per_doc=-1,
+        truncate_to_max_tokens=-1,
+        truncate_to_max_chars=-1,
+        filter_if_longer_than: int = -1,
         tokenizer=segtok_tokenizer,
         in_memory: bool = True,
         encoding: str = 'utf-8',
@@ -943,9 +948,9 @@ def __init__(
         If you have a multi class task, you can have as many labels as you want at the beginning of the line, e.g.,
         __label__<class_name_1> __label__<class_name_2> <text>
         :param path_to_file: the path to the data file
-        :param max_tokens_per_doc: Takes at most this amount of tokens per document. If set to -1 all documents are taken as is.
+        :param truncate_to_max_tokens: Takes at most this amount of tokens per document. If set to -1 all documents are taken as is.
         :param max_tokens_per_doc: If set, truncates each Sentence to a maximum number of Tokens
-        :param max_chars_per_doc: If set, truncates each Sentence to a maximum number of chars
+        :param truncate_to_max_chars: If set, truncates each Sentence to a maximum number of chars
         :param in_memory: If True, keeps dataset as Sentences in memory, otherwise only keeps strings
         :return: list of sentences
         """
@@ -966,8 +971,9 @@ def __init__(
         self.indices = []

         self.total_sentence_count: int = 0
-        self.max_chars_per_doc = max_chars_per_doc
-        self.max_tokens_per_doc = max_tokens_per_doc
+        self.truncate_to_max_chars = truncate_to_max_chars
+        self.truncate_to_max_tokens = truncate_to_max_tokens
+        self.filter_if_longer_than = filter_if_longer_than

         self.path_to_file = path_to_file

@@ -980,6 +986,11 @@ def __init__(
                     line = f.readline()
                     continue

+                if 0 < self.filter_if_longer_than < len(line.split(' ')):
+                    position = f.tell()
+                    line = f.readline()
+                    continue
+
                 if self.in_memory:
                     sentence = self._parse_line_to_sentence(
                         line, self.label_prefix, tokenizer
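A standalone sketch of the new filter condition (the helper name and sample lines are illustrative): a raw line is skipped only when `filter_if_longer_than` is positive and the line's whitespace-token count exceeds it.

```python
def keep_line(line: str, filter_if_longer_than: int = -1) -> bool:
    # mirrors the chained comparison above: 0 < threshold < token count -> skip
    return not (0 < filter_if_longer_than < len(line.split(' ')))

assert keep_line("__label__POSITIVE short text", filter_if_longer_than=10)           # 3 tokens, kept
assert not keep_line("__label__POSITIVE " + "word " * 50, filter_if_longer_than=10)  # too long, skipped
assert keep_line("__label__POSITIVE " + "word " * 50)                                # -1 disables the filter
```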
@@ -1012,8 +1023,8 @@ def _parse_line_to_sentence(

         text = line[l_len:].strip()

-        if self.max_chars_per_doc > 0:
-            text = text[: self.max_chars_per_doc]
+        if self.truncate_to_max_chars > 0:
+            text = text[: self.truncate_to_max_chars]

         if text and labels:
             sentence = Sentence(text, use_tokenizer=tokenizer)
@@ -1023,9 +1034,9 @@ def _parse_line_to_sentence(

             if (
                 sentence is not None
-                and 0 < self.max_tokens_per_doc < len(sentence)
+                and 0 < self.truncate_to_max_tokens < len(sentence)
             ):
-                sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+                sentence.tokens = sentence.tokens[: self.truncate_to_max_tokens]

             return sentence
         return None
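Restated outside the class for clarity (values illustrative), the two truncation settings act at different points: character truncation is applied to the raw text before tokenization, token truncation to the parsed `Sentence` afterwards.

```python
from flair.data import Sentence

truncate_to_max_chars = 40   # illustrative values
truncate_to_max_tokens = 5

text = "this movie was far longer than it had any right to be"
if truncate_to_max_chars > 0:
    text = text[:truncate_to_max_chars]           # cut the raw string first

sentence = Sentence(text)
if 0 < truncate_to_max_tokens < len(sentence):
    sentence.tokens = sentence.tokens[:truncate_to_max_tokens]  # then cap the token count
```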
@@ -1586,7 +1597,7 @@ def __init__(
 class SENTEVAL_MR(ClassificationCorpus):
     def __init__(
         self,
-        in_memory: bool = True,
+        **corpusargs
     ):
         # this dataset name
         dataset_name = self.__class__.__name__.lower()
@@ -1619,7 +1630,7 @@ def __init__(
                         train_file.write(f"__label__NEGATIVE {line}")

         super(SENTEVAL_MR, self).__init__(
-            data_folder, label_type='sentiment', tokenizer=segtok_tokenizer, in_memory=in_memory
+            data_folder, label_type='sentiment', tokenizer=segtok_tokenizer, **corpusargs
         )

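With `**corpusargs`, the corpus-level options introduced earlier in this diff can be passed straight through to `ClassificationCorpus`; the values below are illustrative:

```python
from flair.datasets import SENTEVAL_MR

corpus = SENTEVAL_MR(
    in_memory=True,              # still accepted, now forwarded via **corpusargs
    filter_if_longer_than=500,   # new option, forwarded as well
    truncate_to_max_tokens=256,
)
```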

@@ -1703,13 +1714,13 @@ def __init__(
         )


-class SENTEVAL_SST_BINARY(CSVClassificationCorpus):
+class SENTEVAL_SST_BINARY(ClassificationCorpus):
     def __init__(
         self,
-        in_memory: bool = True,
+        **corpusargs
     ):
         # this dataset name
-        dataset_name = self.__class__.__name__.lower()
+        dataset_name = self.__class__.__name__.lower() + '_v2'

         # default dataset folder is the cache root
         data_folder = Path(flair.cache_root) / "datasets" / dataset_name
@@ -1718,17 +1729,21 @@ def __init__(
         if not (data_folder / "train.txt").is_file():

             # download senteval datasets if necessary und unzip
-            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-train', Path("datasets") / dataset_name)
-            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-test', Path("datasets") / dataset_name)
-            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-dev', Path("datasets") / dataset_name)
+            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-train', Path("datasets") / dataset_name / 'raw')
+            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-test', Path("datasets") / dataset_name / 'raw')
+            cached_path('https://raw.githubusercontent.com/PrincetonML/SIF/master/data/sentiment-dev', Path("datasets") / dataset_name / 'raw')
+
+            # create train.txt file by iterating over pos and neg file
+            with open(data_folder / "train.txt", "a") as out_file, open(data_folder / 'raw' / "sentiment-train") as in_file:
+                for line in in_file:
+                    fields = line.split('\t')
+                    label = 'POSITIVE' if fields[1].rstrip() == '1' else 'NEGATIVE'
+                    out_file.write(f"__label__{label} {fields[0]}\n")

         super(SENTEVAL_SST_BINARY, self).__init__(
             data_folder,
-            column_name_map={0: 'text', 1: 'label'},
             tokenizer=segtok_tokenizer,
-            in_memory=in_memory,
-            delimiter='\t',
-            quotechar=None,
+            **corpusargs,
         )

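The block added above rewrites the raw SST file, whose lines are tab-separated `<text>\t<0|1>` pairs, into the FastText-style format that `ClassificationCorpus` reads. A standalone sketch of that conversion (the sample line is illustrative):

```python
raw_line = "a gorgeous, witty, seductive movie\t1\n"           # <text>\t<0|1>
fields = raw_line.split('\t')
label = 'POSITIVE' if fields[1].rstrip() == '1' else 'NEGATIVE'
fasttext_line = f"__label__{label} {fields[0]}\n"
# -> "__label__POSITIVE a gorgeous, witty, seductive movie\n"
```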

@@ -1813,12 +1828,15 @@ def __init__(


 class IMDB(ClassificationCorpus):
-    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
+    def __init__(self, base_path: Union[str, Path] = None, rebalance_corpus: bool = True, **corpusargs):
         if type(base_path) == str:
             base_path: Path = Path(base_path)

         # this dataset name
-        dataset_name = self.__class__.__name__.lower()
+        dataset_name = self.__class__.__name__.lower() + '_v2'
+
+        if rebalance_corpus:
+            dataset_name = dataset_name + '-rebalanced'

         # default dataset folder is the cache root
         if not base_path:
@@ -1853,20 +1871,22 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
                             if f"{dataset}/{label}" in m.name
                         ],
                     )
-                    with open(f"{data_path}/{dataset}.txt", "at") as f_p:
+                    with open(f"{data_path}/train-all.txt", "at") as f_p:
                         current_path = data_path / "aclImdb" / dataset / label
                         for file_name in current_path.iterdir():
                             if file_name.is_file() and file_name.name.endswith(
                                 ".txt"
                             ):
+                                if label == "pos": sentiment_label = 'POSITIVE'
+                                if label == "neg": sentiment_label = 'NEGATIVE'
                                 f_p.write(
-                                    f"__label__{label} "
+                                    f"__label__{sentiment_label} "
                                     + file_name.open("rt", encoding="utf-8").read()
                                     + "\n"
                                 )

         super(IMDB, self).__init__(
-            data_folder, tokenizer=space_tokenizer, in_memory=in_memory
+            data_folder, tokenizer=space_tokenizer, **corpusargs
         )

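A usage sketch for the updated `IMDB` corpus (parameter values are illustrative): `rebalance_corpus` selects the `-rebalanced` cache folder, and any remaining keyword arguments are forwarded to `ClassificationCorpus` through `**corpusargs`.

```python
from flair.datasets import IMDB

corpus = IMDB(
    rebalance_corpus=True,        # cache under 'imdb_v2-rebalanced'
    in_memory=False,              # forwarded via **corpusargs
    truncate_to_max_tokens=512,   # forwarded via **corpusargs
)
print(corpus)
```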
