Commit 8845dcf

Myle Ott authored and facebook-github-bot committed
Move MoE files into examples (#1040)
Summary:
Pull Request resolved: fairinternal/fairseq-py#1040
Differential Revision: D20030279
Pulled By: myleott
fbshipit-source-id: 76b48a62409020039225cf98e8fcf7a494d0b7f8
1 parent e1de989 commit 8845dcf

File tree

11 files changed: +43, -23 lines


examples/speech_recognition/criterions/ASG_loss.py

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,6 @@
 from fairseq.criterions import FairseqCriterion, register_criterion
 from examples.speech_recognition.data.replabels import pack_replabels
 
-from wav2letter.criterion import ASGLoss, CriterionScaleMode
-
 
 @register_criterion("asg_loss")
 class ASGCriterion(FairseqCriterion):
@@ -43,6 +41,8 @@ def add_args(parser):
         )
 
     def __init__(self, args, task):
+        from wav2letter.criterion import ASGLoss, CriterionScaleMode
+
         super().__init__(args, task)
         self.tgt_dict = task.target_dictionary
         self.eos = self.tgt_dict.eos()

examples/speech_recognition/datasets/asr_prep_json.py

Lines changed: 1 addition & 1 deletion
@@ -14,14 +14,14 @@
 import json
 import sentencepiece as spm
 import multiprocessing
-import torchaudio
 
 from fairseq.data import Dictionary
 
 MILLISECONDS_TO_SECONDS = 0.001
 
 
 def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
+    import torchaudio
     input = {}
     output = {}
     si, ei = torchaudio.info(aud_path)
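The two diffs above apply the same fix: a module-level import of an optional dependency (wav2letter, torchaudio) moves into the function that actually needs it, so the module itself can be imported, e.g. when fairseq scans its registries, without the dependency installed. A minimal sketch of the deferred-import pattern (the class here is illustrative, not part of this commit):

class LazyCriterion:
    def __init__(self):
        # The optional import runs only on instantiation, so any
        # ImportError surfaces at first real use, not at module import.
        from wav2letter.criterion import ASGLoss, CriterionScaleMode
        self.asg_loss_cls = ASGLoss
        self.scale_modes = CriterionScaleMode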

examples/speech_recognition/w2l_decoder.py

Lines changed: 16 additions & 10 deletions
@@ -13,16 +13,22 @@
 import torch
 from fairseq import utils
 from examples.speech_recognition.data.replabels import unpack_replabels
-from wav2letter.common import create_word_dict, load_words
-from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
-from wav2letter.decoder import (
-    CriterionType,
-    DecoderOptions,
-    KenLM,
-    SmearingMode,
-    Trie,
-    WordLMDecoder,
-)
+
+try:
+    from wav2letter.common import create_word_dict, load_words
+    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
+    from wav2letter.decoder import (
+        CriterionType,
+        DecoderOptions,
+        KenLM,
+        SmearingMode,
+        Trie,
+        WordLMDecoder,
+    )
+except ImportError:
+    # wav2letter is a required dependency for the speech_recognition
+    # example, but don't break on import
+    pass
 
 
 class W2lDecoder(object):
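With this guard, `import w2l_decoder` succeeds without wav2letter; a decoder that actually references the guarded names will then fail with a NameError when constructed. A slightly stricter variant of the same idea, sketched with hypothetical names (not part of this commit), records the failure and raises a clearer error at first use:

_W2L_IMPORT_ERROR = None
try:
    from wav2letter.decoder import KenLM  # optional dependency
except ImportError as err:
    _W2L_IMPORT_ERROR = err

def _require_wav2letter():
    # Call at the top of any code path that needs wav2letter, to get a
    # descriptive ImportError instead of a bare NameError.
    if _W2L_IMPORT_ERROR is not None:
        raise ImportError(
            "wav2letter is required for the speech_recognition example"
        ) from _W2L_IMPORT_ERROR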

examples/translation_moe/README.md

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@ The following command will train a `hMoElp` model with `3` experts:
 fairseq-train --ddp-backend='no_c10d' \
   data-bin/wmt17_en_de \
   --max-update 100000 \
-  --task translation_moe \
+  --task translation_moe --user-dir examples/translation_moe/src \
   --method hMoElp --mean-pool-gating-network \
   --num-experts 3 \
   --arch transformer_wmt_en_de --share-all-embeddings \
@@ -37,7 +37,7 @@ For example, to generate from expert 0:
 fairseq-generate data-bin/wmt17_en_de \
   --path checkpoints/checkpoint_best.pt \
   --beam 1 --remove-bpe \
-  --task translation_moe \
+  --task translation_moe --user-dir examples/translation_moe/src \
   --method hMoElp --mean-pool-gating-network \
   --num-experts 3 \
   --gen-expert 0
@@ -61,7 +61,7 @@ for EXPERT in $(seq 0 2); do \
     --beam 1 \
     --bpe subword_nmt --bpe-codes $BPE_CODE \
     --buffer-size 500 --max-tokens 6000 \
-    --task translation_moe \
+    --task translation_moe --user-dir examples/translation_moe/src \
     --method hMoElp --mean-pool-gating-network \
     --num-experts 3 \
     --gen-expert $EXPERT ; \
examples/translation_moe/src/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import translation_moe  # noqa
File renamed without changes.
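The new `src/__init__.py` above is what makes `--user-dir examples/translation_moe/src` work: fairseq imports the directory as a Python module, and that import pulls in `translation_moe.py`, whose `@register_task` decorator registers the task. A minimal sketch of a user-dir plugin of the same shape (all names illustrative, not from this commit):

# my_plugin/__init__.py -- pass the directory as --user-dir my_plugin
from . import my_task  # noqa

# my_plugin/my_task.py
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask

@register_task('my_translation')
class MyTranslationTask(TranslationTask):
    # Registration alone is enough to make `--task my_translation`
    # selectable on the command line; behavior is inherited unchanged.
    pass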

fairseq/tasks/translation_moe.py renamed to examples/translation_moe/src/translation_moe.py

Lines changed: 6 additions & 3 deletions
@@ -5,10 +5,13 @@
 
 import torch
 
-from fairseq import metrics, modules, utils
+from fairseq import metrics, utils
 from fairseq.tasks import register_task
 from fairseq.tasks.translation import TranslationTask
 
+from .logsumexp_moe import LogSumExpMoE
+from .mean_pool_gating_network import MeanPoolGatingNetwork
+
 
 @register_task('translation_moe')
 class TranslationMoETask(TranslationTask):
@@ -100,7 +103,7 @@ def build_model(self, args):
             else:
                 raise ValueError('Must specify --mean-pool-gating-network-dropout')
 
-            model.gating_network = modules.MeanPoolGatingNetwork(
+            model.gating_network = MeanPoolGatingNetwork(
                 encoder_dim, args.num_experts, dropout,
             )
         else:
@@ -171,7 +174,7 @@ def get_lprob_yz(winners=None):
             loss = -get_lprob_yz(winners)
         else:
             lprob_yz = get_lprob_yz()  # B x K
-            loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
+            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
 
         loss = loss.sum()
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
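After the rename, the task imports `LogSumExpMoE` and `MeanPoolGatingNetwork` relatively rather than from `fairseq.modules`. The call site `LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)` implies a custom `torch.autograd.Function` that log-sum-exps expert log-probabilities over the expert dimension while the posterior `prob_z_xy` shapes the backward pass; the actual implementation lives in the renamed `logsumexp_moe.py`, which this diff does not show. A hedged sketch of what such a function can look like:

import torch

class LogSumExpMoESketch(torch.autograd.Function):
    # Sketch only: forward is a plain logsumexp over the expert dim;
    # backward scales the incoming gradient by the supplied posterior.

    @staticmethod
    def forward(ctx, logp, posterior, dim):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        posterior, = ctx.saved_tensors
        # Weight each expert's gradient by its posterior probability.
        grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
        return grad_logp, None, None

Usage mirrors the call site above: loss = -LogSumExpMoESketch.apply(lprob_yz, prob_z_xy, 1), where lprob_yz and prob_z_xy are both B x K.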

fairseq/modules/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -17,8 +17,6 @@
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
-from .logsumexp_moe import LogSumExpMoE
-from .mean_pool_gating_network import MeanPoolGatingNetwork
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .scalar_bias import ScalarBias
@@ -47,8 +45,6 @@
     'LightweightConv1dTBC',
     'LightweightConv',
     'LinearizedConvolution',
-    'LogSumExpMoE',
-    'MeanPoolGatingNetwork',
     'MultiheadAttention',
     'PositionalEmbedding',
     'ScalarBias',

fairseq/options.py

Lines changed: 7 additions & 0 deletions
@@ -113,6 +113,13 @@ def parse_args_and_arch(
 
     from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 
+    # Before creating the true parser, we need to import optional user module
+    # in order to eagerly import custom tasks, optimizers, architectures, etc.
+    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
+    usr_parser.add_argument("--user-dir", default=None)
+    usr_args, _ = usr_parser.parse_known_args(input_args)
+    utils.import_user_module(usr_args)
+
     if modify_parser is not None:
         modify_parser(parser)
 
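The interesting part is the throwaway pre-parser: `--user-dir` has to be honored before the real parser is built, because importing the user module is what registers the custom tasks and architectures the real parser validates against. A self-contained sketch of the same two-stage pattern (the standalone `parse` wrapper and its flags are illustrative, not fairseq's API):

import argparse
import sys

def parse(argv=None):
    # Stage 1: a minimal parser extracts the bootstrap flag and ignores
    # everything it does not recognize.
    pre = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    pre.add_argument("--user-dir", default=None)
    pre_args, _ = pre.parse_known_args(argv)
    if pre_args.user_dir is not None:
        # fairseq calls utils.import_user_module() here; the import's side
        # effects register new tasks, criterions, and architectures.
        print(f"would import plugins from {pre_args.user_dir}")

    # Stage 2: the real parser, built after plugins could register choices.
    parser = argparse.ArgumentParser()
    parser.add_argument("--user-dir", default=None)  # re-declared for --help
    parser.add_argument("--task", default="translation")
    return parser.parse_args(argv)

if __name__ == "__main__":
    print(parse(sys.argv[1:]))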
