
Commit 2728f9b

Myle Ott authored and facebook-github-bot committed
Add huggingface submodule and GPT2 model (#1019)
Summary: Pull Request resolved: fairinternal/fairseq-py#1019

Differential Revision: D20044785

Pulled By: myleott

fbshipit-source-id: 022a49f696c0093d577422af5598f6f326022569
1 parent ed4aa2c commit 2728f9b

File tree

4 files changed: +188 additions, 0 deletions

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
[submodule "fairseq/models/huggingface/transformers"]
	path = fairseq/models/huggingface/transformers
	url = https://github.com/huggingface/transformers.git
fairseq/models/huggingface/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hf_gpt2 import *  # noqa
fairseq/models/huggingface/hf_gpt2.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
from typing import Dict, List, Optional

import torch
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


logger = logging.getLogger(__name__)


DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model('hf_gpt2')
class HuggingFaceGPT2LanguageModel(FairseqLanguageModel):

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--embed-dim', type=int, metavar='N',
                            help='embedding dimension')
        parser.add_argument('--num-attention-heads', type=int, metavar='N',
                            help='num attention heads')
        parser.add_argument('--num-layers', type=int, metavar='N',
                            help='num layers')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability for all fully connected layers '
                                 'in the embeddings, encoder, and pooler')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        default_architecture(args)
        return cls(HuggingFaceGPT2Decoder(args, task))


class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):

    def __init__(self, args, task):
        super().__init__(task.target_dictionary)

        try:
            # Prepend the transformers submodule to the path, so that
            # it's prioritized over other installations. This allows
            # making local changes in the submodule.
            sys.path.insert(
                0, os.path.join(os.path.dirname(__file__), 'transformers', 'src')
            )
            from transformers import GPT2Config, GPT2LMHeadModel
        except ImportError:
            raise ImportError(
                '\n\nPlease install huggingface/transformers with:'
                '\n\n  pip install transformers'
                '\n\nOr to make local edits, install the submodule:'
                '\n\n  git submodule update --init '
                'fairseq/models/huggingface/transformers'
            )

        config = GPT2Config(
            vocab_size=len(task.target_dictionary),
            n_positions=args.max_target_positions,
            n_ctx=args.max_target_positions,
            n_embd=args.embed_dim,
            n_layer=args.num_layers,
            n_head=args.num_attention_heads,
            resid_pdrop=args.dropout,
            embd_pdrop=args.dropout,
            attn_pdrop=args.attention_dropout,
            layer_norm_epsilon=1e-6,
        )
        self.model = GPT2LMHeadModel(config)

        # set zero embedding for padding symbol
        self.pad_idx = task.target_dictionary.pad()
        self.model.transformer.wte.weight.data[self.pad_idx].zero_()
        self.model.transformer.wpe.weight.data[0].zero_()

    def forward(
        self,
        prev_output_tokens,
        src_lengths=None,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
    ):
        features = self.extract_features(prev_output_tokens, incremental_state)
        lm_logits = self.model.lm_head(features)
        return (lm_logits, )

    def extract_features(
        self,
        prev_output_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
    ):
        if incremental_state is not None:
            past = self.get_incremental_state("past")
        else:
            past = None

        # don't attend to padding symbols
        attention_mask = prev_output_tokens.ne(self.pad_idx).int()

        # set position ids to exclude padding symbols
        position_ids = attention_mask * (
            torch.arange(1, 1 + prev_output_tokens.size(1))
            .to(prev_output_tokens)
            .repeat(prev_output_tokens.size(0), 1)
        )

        outputs = self.model.transformer(
            input_ids=prev_output_tokens,
            past=past,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        last_hidden_states = outputs[0]

        if incremental_state is not None:
            self.set_incremental_state(incremental_state, "past", outputs[1])

        return last_hidden_states

    def max_positions(self):
        return self.model.config.n_positions


@register_model_architecture('hf_gpt2', 'hf_gpt2')
def default_architecture(args):
    if getattr(args, 'max_target_positions', None) is None:
        args.max_target_positions = getattr(
            args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS
        )
    args.embed_dim = getattr(args, 'embed_dim', 768)
    args.num_attention_heads = getattr(args, 'num_attention_heads', 12)
    args.num_layers = getattr(args, 'num_layers', 12)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)


@register_model_architecture('hf_gpt2', 'hf_gpt2_medium')
def hf_gpt2_medium(args):
    args.embed_dim = getattr(args, 'embed_dim', 1024)
    args.num_attention_heads = getattr(args, 'num_attention_heads', 16)
    args.num_layers = getattr(args, 'num_layers', 24)
    default_architecture(args)


@register_model_architecture('hf_gpt2', 'hf_gpt2_large')
def hf_gpt2_large(args):
    args.embed_dim = getattr(args, 'embed_dim', 1280)
    args.num_attention_heads = getattr(args, 'num_attention_heads', 20)
    args.num_layers = getattr(args, 'num_layers', 36)
    default_architecture(args)


@register_model_architecture('hf_gpt2', 'hf_gpt2_xl')
def hf_gpt2_xl(args):
    args.embed_dim = getattr(args, 'embed_dim', 1600)
    args.num_attention_heads = getattr(args, 'num_attention_heads', 25)
    args.num_layers = getattr(args, 'num_layers', 48)
    default_architecture(args)
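
Note on the architecture presets: hf_gpt2_medium, hf_gpt2_large, and hf_gpt2_xl only pre-fill the sizes they change and then fall through to default_architecture, which fills in whatever is still unset; because every assignment goes through getattr with a fallback, values already present on args (e.g. supplied on the command line) are never overwritten. A minimal sketch of that cascade on a bare Namespace, inlining the two functions for illustration (the hyperparameter values come from the code above; the Namespace itself is hypothetical):

from argparse import Namespace

args = Namespace(tokens_per_sample=512)   # pretend this came from the CLI

# what hf_gpt2_medium(args) does first: pre-fill its own sizes
args.embed_dim = getattr(args, 'embed_dim', 1024)
args.num_attention_heads = getattr(args, 'num_attention_heads', 16)
args.num_layers = getattr(args, 'num_layers', 24)

# default_architecture(args) then fills in everything still missing;
# the getattr fallbacks leave the values set above untouched
if getattr(args, 'max_target_positions', None) is None:
    args.max_target_positions = getattr(args, 'tokens_per_sample', 1024)
args.embed_dim = getattr(args, 'embed_dim', 768)
args.num_attention_heads = getattr(args, 'num_attention_heads', 12)
args.num_layers = getattr(args, 'num_layers', 12)
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)

print(args.max_target_positions, args.embed_dim, args.num_layers)
# 512 1024 24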
Submodule transformers added at d426b58
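
One detail worth calling out in HuggingFaceGPT2Decoder.extract_features: padding tokens are both masked out of attention and forced to position id 0, whose positional embedding row is zeroed in the constructor (wpe.weight.data[0].zero_()), just as the pad token's row in wte is. A standalone sketch of that tensor arithmetic on a toy batch, assuming a hypothetical pad index of 1 (the real value comes from task.target_dictionary.pad()):

import torch

pad_idx = 1  # illustrative; fairseq reads this from the dictionary
prev_output_tokens = torch.tensor([
    [5, 8, 9, pad_idx, pad_idx],   # padded sequence
    [4, 6, 7, 3, 2],               # full-length sequence
])

# don't attend to padding symbols
attention_mask = prev_output_tokens.ne(pad_idx).int()
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)

# real positions count from 1; padding positions collapse to 0
position_ids = attention_mask * (
    torch.arange(1, 1 + prev_output_tokens.size(1))
    .to(prev_output_tokens)
    .repeat(prev_output_tokens.size(0), 1)
)
# tensor([[1, 2, 3, 0, 0],
#         [1, 2, 3, 4, 5]])

With the corresponding embedding rows zeroed in the constructor, padded positions contribute neither token nor positional information to the model.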
