Skip to content

Commit fb76dac

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Switch to Python logging (+ lint) (#1627)
Summary: Pull Request resolved: #1627 Python logging offers a number of benefits, such as logging timestamps, better cross-library compatibility, ability to add multiple output handlers, etc. Pull Request resolved: fairinternal/fairseq-py#646 Reviewed By: spencerp Differential Revision: D15815620 Pulled By: myleott fbshipit-source-id: 5e64e9929b5e4b9dd5bb49bcdf7c510631907134
1 parent 1bb218f commit fb76dac

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+533
-237
lines changed

eval_lm.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
Evaluate the perplexity of a trained language model.
99
"""
1010

11+
import logging
1112
import math
1213

13-
import numpy as np
1414
import torch
1515

1616
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
@@ -19,6 +19,14 @@
1919
from fairseq.sequence_scorer import SequenceScorer
2020

2121

22+
logging.basicConfig(
23+
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
24+
datefmt='%Y-%m-%d %H:%M:%S',
25+
level=logging.INFO,
26+
)
27+
logger = logging.getLogger('fairseq_cli.eval_lm')
28+
29+
2230
class WordStat(object):
2331
def __init__(self, word, is_bpe):
2432
self.word = word
@@ -50,14 +58,14 @@ def main(parsed_args):
5058

5159
utils.import_user_module(parsed_args)
5260

53-
print(parsed_args)
61+
logger.info(parsed_args)
5462

5563
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
5664

5765
task = tasks.setup_task(parsed_args)
5866

5967
# Load ensemble
60-
print('| loading model(s) from {}'.format(parsed_args.path))
68+
logger.info('loading model(s) from {}'.format(parsed_args.path))
6169
models, args = checkpoint_utils.load_model_ensemble(
6270
parsed_args.path.split(':'),
6371
arg_overrides=eval(parsed_args.model_overrides),
@@ -85,7 +93,7 @@ def main(parsed_args):
8593
context_window=args.context_window,
8694
pad_idx=task.source_dictionary.pad(),
8795
)
88-
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
96+
logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
8997

9098
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
9199
for model in models:
@@ -97,7 +105,7 @@ def main(parsed_args):
97105

98106
assert len(models) > 0
99107

100-
print('| num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
108+
logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
101109

102110
itr = task.get_batch_iterator(
103111
dataset=dataset,
@@ -123,11 +131,11 @@ def main(parsed_args):
123131
raise NotImplementedError
124132
else:
125133
bpe_cont = args.remove_bpe.rstrip()
126-
bpe_toks = set(
134+
bpe_toks = {
127135
i
128136
for i in range(len(task.source_dictionary))
129137
if task.source_dictionary[i].endswith(bpe_cont)
130-
)
138+
}
131139
bpe_len = len(bpe_cont)
132140
else:
133141
bpe_toks = None
@@ -171,8 +179,10 @@ def main(parsed_args):
171179

172180
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
173181
if inf_scores.any():
174-
print('| Skipping tokens with inf scores:',
175-
task.target_dictionary.string(tokens[inf_scores.nonzero()]))
182+
logger.info(
183+
'skipping tokens with inf scores:',
184+
task.target_dictionary.string(tokens[inf_scores.nonzero()])
185+
)
176186
pos_scores = pos_scores[(~inf_scores).nonzero()]
177187
score_sum += pos_scores.sum().cpu()
178188
count += pos_scores.numel() - skipped_toks
@@ -202,7 +212,7 @@ def main(parsed_args):
202212
is_bpe = False
203213
w = ''
204214
if args.output_word_probs:
205-
print(
215+
logger.info(
206216
str(int(sample_id)) + " "
207217
+ ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
208218
)
@@ -211,12 +221,16 @@ def main(parsed_args):
211221
t.log({'wps': round(wps_meter.avg)})
212222

213223
avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2
214-
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
215-
print('| Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, 2**avg_nll_loss))
224+
logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
225+
gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
226+
))
227+
logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
228+
avg_nll_loss, 2**avg_nll_loss
229+
))
216230

217231
if args.output_word_stats:
218232
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
219-
print(ws)
233+
logger.info(ws)
220234

221235

222236
def cli_main():

fairseq/checkpoint_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from torch.serialization import default_restore_location
1818

1919

20+
logger = logging.getLogger(__name__)
21+
22+
2023
def save_checkpoint(args, trainer, epoch_itr, val_loss):
2124
from fairseq import distributed_utils, meters
2225

@@ -77,7 +80,7 @@ def is_better(a, b):
7780
PathManager.copy(checkpoints[0], cp, overwrite=True)
7881

7982
write_timer.stop()
80-
print(
83+
logger.info(
8184
"| saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
8285
checkpoints[0], epoch, updates, val_loss, write_timer.sum
8386
)
@@ -231,7 +234,7 @@ def torch_persistent_save(*args, **kwargs):
231234
return torch.save(*args, **kwargs)
232235
except Exception:
233236
if i == 2:
234-
logging.error(traceback.format_exc())
237+
logger.error(traceback.format_exc())
235238

236239

237240
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
@@ -388,8 +391,8 @@ def prune_state_dict(state_dict, args):
388391
return state_dict
389392

390393
# apply pruning
391-
print(
392-
"| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
394+
logger.info(
395+
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
393396
)
394397

395398
def create_pruning_pass(layers_to_keep, layer_name):
@@ -485,7 +488,7 @@ def verify_checkpoint_directory(save_dir: str) -> None:
485488
with open(temp_file_path, "w"):
486489
pass
487490
except OSError as e:
488-
print("| Unable to access checkpoint save directory: {}".format(save_dir))
491+
logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
489492
raise e
490493
else:
491494
os.remove(temp_file_path)

fairseq/data/data_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
from collections import Iterable
1010
import contextlib
1111
import itertools
12+
import logging
1213
import os
1314
import sys
1415
import types
1516

1617
import numpy as np
1718

1819

20+
logger = logging.getLogger(__name__)
21+
22+
1923
def infer_language_pair(path):
2024
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
2125
src, dst = None, None
@@ -78,7 +82,7 @@ def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False, def
7882
)
7983
if dataset is None:
8084
break
81-
print('| loaded {} examples from: {}'.format(len(dataset), path_k))
85+
logger.info('loaded {} examples from: {}'.format(len(dataset), path_k))
8286
datasets.append(dataset)
8387
if not combine:
8488
break
@@ -187,8 +191,8 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
187191
'skip this example with --skip-invalid-size-inputs-valid-test'
188192
).format(ignored[0], dataset.size(ignored[0]), max_positions))
189193
if len(ignored) > 0:
190-
print((
191-
'| WARNING: {} samples have invalid sizes and will be skipped, '
194+
logger.warn((
195+
'{} samples have invalid sizes and will be skipped, '
192196
'max_positions={}, first few sample ids={}'
193197
).format(len(ignored), max_positions, ignored[:10]))
194198
return indices

fairseq/data/language_pair_dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
7+
68
import numpy as np
79
import torch
810

911
from . import data_utils, FairseqDataset
1012

1113

14+
logger = logging.getLogger(__name__)
15+
16+
1217
def collate(
1318
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
1419
input_feeding=True,
@@ -26,7 +31,7 @@ def check_alignment(alignment, src_len, tgt_len):
2631
if alignment is None or len(alignment) == 0:
2732
return False
2833
if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
29-
print("| alignment size mismatch found, skipping alignment!")
34+
logger.warning("alignment size mismatch found, skipping alignment!")
3035
return False
3136
return True
3237

fairseq/data/subsample_dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
7+
68
import numpy as np
79

810
from . import BaseWrapperDataset
911

1012

13+
logger = logging.getLogger(__name__)
14+
15+
1116
class SubsampleDataset(BaseWrapperDataset):
1217
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
1318
@@ -23,7 +28,7 @@ def __init__(self, dataset, size_ratio):
2328
self.indices = np.random.choice(
2429
list(range(len(self.dataset))), self.actual_size, replace=False
2530
)
26-
print(
31+
logger.info(
2732
"subsampled dataset from {} to {} (ratio={})".format(
2833
len(self.dataset), self.actual_size, size_ratio
2934
)

fairseq/distributed_utils.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
67
import os
78
import pickle
89
import socket
@@ -16,6 +17,9 @@
1617
from fairseq import utils
1718

1819

20+
logger = logging.getLogger(__name__)
21+
22+
1923
def is_master(args):
2024
return args.distributed_rank == 0
2125

@@ -76,42 +80,34 @@ def distributed_init(args):
7680
if torch.distributed.is_initialized():
7781
warnings.warn('Distributed is already initialized, cannot initialize twice!')
7882
else:
79-
print('| distributed init (rank {}): {}'.format(
80-
args.distributed_rank, args.distributed_init_method), flush=True)
83+
logger.info('distributed init (rank {}): {}'.format(
84+
args.distributed_rank, args.distributed_init_method,
85+
))
8186
dist.init_process_group(
8287
backend=args.distributed_backend,
8388
init_method=args.distributed_init_method,
8489
world_size=args.distributed_world_size,
8590
rank=args.distributed_rank,
8691
)
87-
print('| initialized host {} as rank {}'.format(
88-
socket.gethostname(), args.distributed_rank), flush=True)
92+
logger.info('initialized host {} as rank {}'.format(
93+
socket.gethostname(), args.distributed_rank,
94+
))
8995

9096
# perform a dummy all-reduce to initialize the NCCL communicator
9197
if torch.cuda.is_available():
9298
dist.all_reduce(torch.zeros(1).cuda())
9399
else:
94100
dist.all_reduce(torch.zeros(1))
95101

96-
suppress_output(is_master(args))
102+
if is_master(args):
103+
logging.getLogger().setLevel(logging.INFO)
104+
else:
105+
logging.getLogger().setLevel(logging.WARNING)
97106

98107
args.distributed_rank = torch.distributed.get_rank()
99108
return args.distributed_rank
100109

101110

102-
def suppress_output(is_master):
103-
"""Suppress printing on the current device. Force printing with `force=True`."""
104-
import builtins as __builtin__
105-
builtin_print = __builtin__.print
106-
107-
def print(*args, **kwargs):
108-
force = kwargs.pop('force', False)
109-
if is_master or force:
110-
builtin_print(*args, **kwargs)
111-
112-
__builtin__.print = print
113-
114-
115111
def get_rank():
116112
return dist.get_rank()
117113

fairseq/file_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def load_archive_file(archive_file):
5454
try:
5555
resolved_archive_file = cached_path(archive_file, cache_dir=None)
5656
except EnvironmentError:
57-
print(
57+
logger.info(
5858
"Archive name '{}' was not found in archive name list. "
5959
"We assumed '{}' was a path or URL but couldn't find any file "
6060
"associated to this path or URL.".format(
@@ -65,16 +65,16 @@ def load_archive_file(archive_file):
6565
return None
6666

6767
if resolved_archive_file == archive_file:
68-
print("loading archive file {}".format(archive_file))
68+
logger.info("loading archive file {}".format(archive_file))
6969
else:
70-
print("loading archive file {} from cache at {}".format(
70+
logger.info("loading archive file {} from cache at {}".format(
7171
archive_file, resolved_archive_file))
7272

7373
# Extract archive to temp dir and replace .tar.bz2 if necessary
7474
tempdir = None
7575
if not os.path.isdir(resolved_archive_file):
7676
tempdir = tempfile.mkdtemp()
77-
print("extracting archive file {} to temp dir {}".format(
77+
logger.info("extracting archive file {} to temp dir {}".format(
7878
resolved_archive_file, tempdir))
7979
ext = os.path.splitext(archive_file)[1][1:]
8080
with tarfile.open(resolved_archive_file, 'r:' + ext) as archive:

fairseq/hub_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import argparse
88
import copy
9+
import logging
910
import os
1011
from typing import List, Dict, Iterator, Tuple, Any
1112

@@ -16,6 +17,9 @@
1617
from fairseq.data import encoders
1718

1819

20+
logger = logging.getLogger(__name__)
21+
22+
1923
def from_pretrained(
2024
model_name_or_path,
2125
checkpoint_file='model.pt',
@@ -172,15 +176,15 @@ def getarg(name, default):
172176

173177
for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
174178
src_str_with_unk = self.string(source_tokens)
175-
print('S\t{}'.format(src_str_with_unk))
179+
logger.info('S\t{}'.format(src_str_with_unk))
176180
for hypo in target_hypotheses:
177181
hypo_str = self.decode(hypo['tokens'])
178-
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
179-
print('P\t{}'.format(
182+
logger.info('H\t{}\t{}'.format(hypo['score'], hypo_str))
183+
logger.info('P\t{}'.format(
180184
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
181185
))
182186
if hypo['alignment'] is not None and getarg('print_alignment', False):
183-
print('A\t{}'.format(
187+
logger.info('A\t{}'.format(
184188
' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
185189
))
186190
return outputs

0 commit comments

Comments
 (0)