Evaluate the perplexity of a trained language model.
"""
1010
11+ import logging
1112import math
1213
13- import numpy as np
1414import torch
1515
1616from fairseq import checkpoint_utils , options , progress_bar , tasks , utils
1919from fairseq .sequence_scorer import SequenceScorer
2020
2121
22+ logging .basicConfig (
23+ format = '%(asctime)s | %(levelname)s | %(name)s | %(message)s' ,
24+ datefmt = '%Y-%m-%d %H:%M:%S' ,
25+ level = logging .INFO ,
26+ )
27+ logger = logging .getLogger ('fairseq_cli.eval_lm' )
28+
29+
2230class WordStat (object ):
2331 def __init__ (self , word , is_bpe ):
2432 self .word = word
@@ -50,14 +58,14 @@ def main(parsed_args):
5058
5159 utils .import_user_module (parsed_args )
5260
53- print (parsed_args )
61+ logger . info (parsed_args )
5462
5563 use_cuda = torch .cuda .is_available () and not parsed_args .cpu
5664
5765 task = tasks .setup_task (parsed_args )
5866
5967 # Load ensemble
60- print ( '| loading model(s) from {}' .format (parsed_args .path ))
68+ logger . info ( ' loading model(s) from {}' .format (parsed_args .path ))
6169 models , args = checkpoint_utils .load_model_ensemble (
6270 parsed_args .path .split (':' ),
6371 arg_overrides = eval (parsed_args .model_overrides ),
@@ -85,7 +93,7 @@ def main(parsed_args):
8593 context_window = args .context_window ,
8694 pad_idx = task .source_dictionary .pad (),
8795 )
88- print ( '| {} {} {} examples' .format (args .data , args .gen_subset , len (dataset )))
96+ logger . info ( ' {} {} {} examples' .format (args .data , args .gen_subset , len (dataset )))
8997
9098 # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
9199 for model in models :
@@ -97,7 +105,7 @@ def main(parsed_args):
97105
98106 assert len (models ) > 0
99107
100- print ( '| num. model params: {}' .format (sum (p .numel () for p in models [0 ].parameters ())))
108+ logger . info ( ' num. model params: {}' .format (sum (p .numel () for p in models [0 ].parameters ())))
101109
102110 itr = task .get_batch_iterator (
103111 dataset = dataset ,
@@ -123,11 +131,11 @@ def main(parsed_args):
123131 raise NotImplementedError
124132 else :
125133 bpe_cont = args .remove_bpe .rstrip ()
126- bpe_toks = set (
134+ bpe_toks = {
127135 i
128136 for i in range (len (task .source_dictionary ))
129137 if task .source_dictionary [i ].endswith (bpe_cont )
130- )
138+ }
131139 bpe_len = len (bpe_cont )
132140 else :
133141 bpe_toks = None
@@ -171,8 +179,10 @@ def main(parsed_args):
171179
172180 inf_scores = pos_scores .eq (float ('inf' )) | pos_scores .eq (float ('-inf' ))
173181 if inf_scores .any ():
174- print ('| Skipping tokens with inf scores:' ,
175- task .target_dictionary .string (tokens [inf_scores .nonzero ()]))
182+ logger .info (
183+ 'skipping tokens with inf scores:' ,
184+ task .target_dictionary .string (tokens [inf_scores .nonzero ()])
185+ )
176186 pos_scores = pos_scores [(~ inf_scores ).nonzero ()]
177187 score_sum += pos_scores .sum ().cpu ()
178188 count += pos_scores .numel () - skipped_toks
@@ -202,7 +212,7 @@ def main(parsed_args):
202212 is_bpe = False
203213 w = ''
204214 if args .output_word_probs :
205- print (
215+ logger . info (
206216 str (int (sample_id )) + " "
207217 + ('\t ' .join ('{} [{:2f}]' .format (x [0 ], x [1 ]) for x in word_prob ))
208218 )
@@ -211,12 +221,16 @@ def main(parsed_args):
211221 t .log ({'wps' : round (wps_meter .avg )})
212222
213223 avg_nll_loss = - score_sum / count / math .log (2 ) # convert to base 2
214- print ('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)' .format (gen_timer .n , gen_timer .sum , 1. / gen_timer .avg ))
215- print ('| Loss (base 2): {:.4f}, Perplexity: {:.2f}' .format (avg_nll_loss , 2 ** avg_nll_loss ))
224+ logger .info ('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)' .format (
225+ gen_timer .n , gen_timer .sum , 1. / gen_timer .avg
226+ ))
227+ logger .info ('Loss (base 2): {:.4f}, Perplexity: {:.2f}' .format (
228+ avg_nll_loss , 2 ** avg_nll_loss
229+ ))
216230
217231 if args .output_word_stats :
218232 for ws in sorted (word_stats .values (), key = lambda x : x .count , reverse = True ):
219- print (ws )
233+ logger . info (ws )
220234
221235
222236def cli_main ():
0 commit comments