|
 Translate pre-processed data with a trained model.
 """
 
+import ast
+from itertools import chain
 import logging
 import math
 import os
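
Note: the two new imports drive the changes below. ast.literal_eval replaces the
bare eval() call on args.model_overrides, so the override string is parsed as a
Python literal instead of being executed as arbitrary code, and itertools.chain
lets a single loop prepare the translation models and the optional language
model together. A minimal sketch of the literal_eval behavior (the override
string here is a made-up example value for fairseq's --model-overrides flag):

    import ast

    overrides_str = "{'beam': 5, 'fp16': True}"  # hypothetical --model-overrides value
    overrides = ast.literal_eval(overrides_str)  # parses Python literals only
    assert overrides['beam'] == 5
    # eval(overrides_str) would return the same dict, but would also execute
    # arbitrary expressions; literal_eval raises ValueError on anything that
    # is not a plain literal.
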
@@ -78,17 +80,39 @@ def _main(args, output_file):
     src_dict = None
     tgt_dict = task.target_dictionary
 
+    overrides = ast.literal_eval(args.model_overrides)
+
     # Load ensemble
     logger.info('loading model(s) from {}'.format(args.path))
     models, _model_args = checkpoint_utils.load_model_ensemble(
         utils.split_paths(args.path),
-        arg_overrides=eval(args.model_overrides),
+        arg_overrides=overrides,
         task=task,
         suffix=getattr(args, "checkpoint_suffix", ""),
     )
 
+    if args.lm_path is not None:
+        overrides['data'] = args.data
+
+        try:
+            lms, _ = checkpoint_utils.load_model_ensemble(
+                [args.lm_path],
+                arg_overrides=overrides,
+                task=None,
+            )
+        except Exception:
+            logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same "
+                           f"as the target dict and is located in the data dir ({args.data})")
+            raise
+
+        assert len(lms) == 1
+    else:
+        lms = [None]
+
     # Optimize ensemble for generation
-    for model in models:
+    for model in chain(models, lms):
+        if model is None:
+            continue
         model.prepare_for_inference_(args)
         if args.fp16:
             model.half()
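
Note: when args.lm_path is set, the language model is loaded through the same
ensemble loader with task=None, and overrides['data'] is pointed at args.data so
the LM resolves its dictionary from the translation data directory; that is why
the warning insists the LM dictionary match the target dictionary. The [None]
placeholder keeps the later lms[0] lookup valid when no LM is supplied, and
chain(models, lms) applies the same inference preparation to every real model.
A minimal sketch of that placeholder pattern, with stand-in objects rather than
fairseq models:

    from itertools import chain

    models = ['tm-1', 'tm-2']  # stand-ins for loaded translation models
    lms = [None]               # placeholder used when no language model is given

    for model in chain(models, lms):
        if model is None:
            continue  # skip the placeholder; only real models get prepared
        print('preparing', model)
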
@@ -124,7 +148,12 @@ def _main(args, output_file):
 
     # Initialize generator
     gen_timer = StopwatchMeter()
-    generator = task.build_generator(models, args)
+
+    extra_gen_cls_kwargs = {
+        'lm_model': lms[0],
+        'lm_weight': args.lm_weight,
+    }
+    generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
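
Note: extra_gen_cls_kwargs forwards the loaded language model and its weight to
the sequence generator class that task.build_generator instantiates. The diff
does not show how the generator uses them, but the usual rule for this kind of
shallow fusion is to add the LM log-probability, scaled by lm_weight, to the
translation model's log-probability at each decoding step. A sketch of that
standard combination rule (an assumption about the generator's behavior, not
code from this commit):

    import torch

    def fuse_lprobs(tm_lprobs: torch.Tensor, lm_lprobs: torch.Tensor,
                    lm_weight: float) -> torch.Tensor:
        # Shallow fusion in log space over a (batch, vocab) step distribution;
        # lm_weight = 0.0 recovers plain beam search over the translation model.
        return tm_lprobs + lm_weight * lm_lprobs
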
@@ -269,9 +298,11 @@ def decode_fn(x):
     if has_target:
         if args.bpe and not args.sacrebleu:
             if args.remove_bpe:
-                logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
+                logger.warning(
+                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
             else:
-                logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
+                logger.warning(
+                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
         # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
         print(
             'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
|