
Commit f840564

nng555 authored and facebook-github-bot committed
initial light and dynamic convolution kernels (#547)
Summary: CUDA code for light/dynamicconv kernels, including pytorch modules. Modules can be built by running setup.py in each respective folder, and can then be imported and used like any other module.

Pull Request resolved: fairinternal/fairseq-py#547

Reviewed By: myleott, shubho

Differential Revision: D15703660

Pulled By: nng555

fbshipit-source-id: e9c913753be3a1cd571965f7200df6678b644520
1 parent b870468 commit f840564

23 files changed: +1958 -27 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -111,6 +111,8 @@ ENV/
 
 # Generated files
 fairseq/temporal_convolution_tbc
+fairseq/modules/*_layer/*_forward.cu
+fairseq/modules/*_layer/*_backward.cu
 
 # data
 data-bin/

examples/pay_less_attention_paper/README.md

Lines changed: 14 additions & 2 deletions
@@ -1,5 +1,5 @@
 # Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)
-This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://openreview.net/pdf?id=SkVhlh09tX)
+This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430)
 
 ## Citation:
 ```bibtex
@@ -8,7 +8,7 @@ This page contains pointers to pre-trained models as well as instructions on how
   author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli},
   booktitle = {International Conference on Learning Representations},
   year = {2019},
-  url = {https://openreview.net/forum?id=SkVhlh09tX},
+  url = {https://arxiv.org/abs/1901.10430},
 }
 ```
 
@@ -39,6 +39,18 @@ To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`.
 For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv.
 For best BLEU results, lenpen may need to be manually tuned.
 
+To use the CUDA kernels, first install the PyTorch modules using the commands below
+```sh
+# to install lightconv
+python fairseq/modules/lightconv_layer/cuda_function_gen.py
+python fairseq/modules/lightconv_layer/setup.py install
+
+# to install dynamicconv
+python fairseq/modules/dynamicconv_layer/cuda_function_gen.py
+python fairseq/modules/dynamicconv_layer/setup.py install
+```
+Once the CUDA modules are installed, they will automatically be used instead of the PyTorch modules.
+
 ### IWSLT14 De-En
 Training and evaluating DynamicConv (without GLU) on a GPU:
 ```sh
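As a quick sanity check of the instructions above, here is a minimal usage sketch (not part of the README) of the modules they install. The constructor arguments mirror the calls in `fairseq/models/lightconv.py` below; the time-first `(T x B x C)` input layout and the exact output shape are assumptions based on the existing `*1dTBC` modules, and a CUDA-capable GPU is assumed so the custom kernels are actually exercised.

```python
# Hypothetical smoke test: construct DynamicConv as lightconv.py does and run a
# forward pass. The (time, batch, channels) layout is assumed from the TBC modules.
import torch
from fairseq.modules import DynamicConv

conv = DynamicConv(64, 3,                # conv_dim, kernel_size (as in lightconv.py)
                   padding_l=2,          # kernel_size - 1, decoder-style left padding
                   weight_softmax=True,
                   num_heads=4,
                   weight_dropout=0.1).cuda()

x = torch.randn(10, 2, 64).cuda()        # (time, batch, channels)
y = conv(x)
print(y.shape)                           # expected: torch.Size([10, 2, 64])
```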

fairseq/models/lightconv.py

Lines changed: 19 additions & 19 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
+import sys
 
 import torch
 import torch.nn as nn
@@ -19,10 +20,10 @@
 )
 from fairseq.modules import (
     AdaptiveSoftmax,
-    DynamicConv1dTBC,
+    DynamicConv,
     LayerNorm,
     PositionalEmbedding,
-    LightweightConv1dTBC,
+    LightweightConv,
     MultiheadAttention,
 )
 
@@ -173,7 +174,6 @@ def build_embedding(dictionary, embed_dim, path=None):
         decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
         return LightConvModel(encoder, decoder)
 
-
 class LightConvEncoder(FairseqEncoder):
     """
     LightConv encoder consisting of *args.encoder_layers* layers. Each layer
@@ -447,15 +447,15 @@ def __init__(self, args, kernel_size=0):
         self.linear1 = Linear(self.embed_dim, self.conv_dim)
         self.act = None
         if args.encoder_conv_type == 'lightweight':
-            self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
-                                             weight_softmax=args.weight_softmax,
-                                             num_heads=args.encoder_attention_heads,
-                                             weight_dropout=args.weight_dropout)
+            self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l,
+                                        weight_softmax=args.weight_softmax,
+                                        num_heads=args.encoder_attention_heads,
+                                        weight_dropout=args.weight_dropout)
         elif args.encoder_conv_type == 'dynamic':
-            self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
-                                         weight_softmax=args.weight_softmax,
-                                         num_heads=args.encoder_attention_heads,
-                                         weight_dropout=args.weight_dropout)
+            self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l,
+                                    weight_softmax=args.weight_softmax,
+                                    num_heads=args.encoder_attention_heads,
+                                    weight_dropout=args.weight_dropout)
         else:
             raise NotImplementedError
         self.linear2 = Linear(self.conv_dim, self.embed_dim)
@@ -535,15 +535,15 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0):
         self.linear1 = Linear(self.embed_dim, self.conv_dim)
         self.act = None
         if args.decoder_conv_type == 'lightweight':
-            self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
-                                             weight_softmax=args.weight_softmax,
-                                             num_heads=args.decoder_attention_heads,
-                                             weight_dropout=args.weight_dropout)
+            self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
+                                        weight_softmax=args.weight_softmax,
+                                        num_heads=args.decoder_attention_heads,
+                                        weight_dropout=args.weight_dropout)
         elif args.decoder_conv_type == 'dynamic':
-            self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
-                                         weight_softmax=args.weight_softmax,
-                                         num_heads=args.decoder_attention_heads,
-                                         weight_dropout=args.weight_dropout)
+            self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
+                                    weight_softmax=args.weight_softmax,
+                                    num_heads=args.decoder_attention_heads,
+                                    weight_dropout=args.weight_dropout)
         else:
             raise NotImplementedError
         self.linear2 = Linear(self.conv_dim, self.embed_dim)
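The model now constructs `LightweightConv`/`DynamicConv` instead of the `*1dTBC` classes directly, which is what lets the CUDA kernels slot in transparently. This commit does not show the body of those wrappers, so the sketch below is only an illustration of the dispatch idea implied by the summary; the `lightconv_cuda` extension name and the `LightconvLayer` wiring are assumptions, not taken from this diff.

```python
# Illustrative only: not the actual fairseq/modules/lightweight_convolution.py.
# Assumed names: the built extension `lightconv_cuda` and the CUDA-backed module
# LightconvLayer referenced (commented out) in fairseq/modules/__init__.py.
def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
                    weight_dropout=0., weight_softmax=False):
    try:
        import lightconv_cuda  # noqa: F401 -- only importable after setup.py install
        from fairseq.modules.lightconv_layer import LightconvLayer
        return LightconvLayer(input_size, kernel_size=kernel_size, padding_l=padding_l,
                              num_heads=num_heads, weight_dropout=weight_dropout,
                              weight_softmax=weight_softmax)
    except ImportError:
        # fall back to the pure-PyTorch implementation that was already in the tree
        from fairseq.modules.lightweight_convolution import LightweightConv1dTBC
        return LightweightConv1dTBC(input_size, kernel_size=kernel_size, padding_l=padding_l,
                                    num_heads=num_heads, weight_dropout=weight_dropout,
                                    weight_softmax=weight_softmax)
```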

fairseq/modules/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -9,13 +9,15 @@
 from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
-from .dynamic_convolution import DynamicConv1dTBC
+from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
+#from .dynamicconv_layer import DynamicconvLayer
 from .gelu import gelu, gelu_accurate
 from .grad_multiply import GradMultiply
 from .highway import Highway
 from .layer_norm import LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
-from .lightweight_convolution import LightweightConv1dTBC
+from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
+#from .lightconv_layer import LightconvLayer
 from .linearized_convolution import LinearizedConvolution
 from .logsumexp_moe import LogSumExpMoE
 from .mean_pool_gating_network import MeanPoolGatingNetwork
@@ -36,14 +38,18 @@
     'CharacterTokenEmbedder',
     'ConvTBC',
     'DownsampledMultiHeadAttention',
+    # 'DyamicconvLayer',
     'DynamicConv1dTBC',
+    'DynamicConv',
     'gelu',
     'gelu_accurate',
     'GradMultiply',
     'Highway',
     'LayerNorm',
     'LearnedPositionalEmbedding',
+    # 'LightconvLayer',
     'LightweightConv1dTBC',
+    'LightweightConv',
     'LinearizedConvolution',
     'LogSumExpMoE',
     'MeanPoolGatingNetwork',
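The imports of the CUDA-backed layers are left commented out here, mirrored by commented entries in `__all__`. If one wanted to expose them only when the extensions from `setup.py install` are present, a guarded import is the usual pattern; the snippet below is a sketch of that pattern rather than what this commit does, and only the module and class names come from the commented lines.

```python
# Sketch of an optional-import guard; the commit itself keeps these imports commented out.
try:
    from .lightconv_layer import LightconvLayer      # needs the built lightconv extension
    from .dynamicconv_layer import DynamicconvLayer  # needs the built dynamicconv extension
except ImportError:
    LightconvLayer = None
    DynamicconvLayer = None
```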

fairseq/modules/cuda_utils.cu

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ */
+
+
+template <typename U, typename V>
+constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
+  return (a + b - 1) / b;
+}
+
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__inline__ __device__
+void zeroSharedMem(scalar_t* data) {
+  /*
+  Given an array of length FS + SB, zero out the first padding_l and last
+  (FS - padding_l) values in the array
+  */
+
+  int tid = threadIdx.x;
+
+  if (FS < SB) {
+
+    // zero all if we have enough threads in a block to do all of them
+    if (tid < padding_l || tid > SB - FS + padding_l - 1) {
+      data[tid] = scalar_t(0.0);
+    }
+  } else {
+
+    // otherwise zero out one block at a time
+    const int numIterations = divUp<int, int>(FS, SB);
+    for (int i = 0; i < numIterations; i++) {
+      int offset = i * SB;
+      if (tid + offset < padding_l) {
+        data[tid + offset] = scalar_t(0.0);
+      } else if (tid + offset < FS) {
+        data[SB + tid + offset] = scalar_t(0.0);
+      }
+    }
+  }
+}
+
+template<typename scalar_t>
+__inline__ __device__
+scalar_t warpReduce(scalar_t data) {
+  /*
+  Reduce an array within each warp. After processing, all values in the warp
+  will contain the sum of all original values in that warp.
+
+  data - pointer to data to reduce
+  */
+  data += __shfl_xor_sync(SHFL_MASK, data, 16);
+  data += __shfl_xor_sync(SHFL_MASK, data, 8);
+  data += __shfl_xor_sync(SHFL_MASK, data, 4);
+  data += __shfl_xor_sync(SHFL_MASK, data, 2);
+  data += __shfl_xor_sync(SHFL_MASK, data, 1);
+  return data;
+}
+
+template<typename scalar_t>
+__inline__ __device__
+scalar_t blockReduce(scalar_t data) {
+  /*
+  Reduce an entire array on the block level. After processing, the
+  first value in the array will contain the reduced sum.
+
+  data - pointer to data to reduce
+  */
+
+  static __shared__ scalar_t warpSum[32];
+  const int tid = threadIdx.x;
+  int wid = tid / 32;
+  int lane = tid % 32;
+
+  __syncthreads();
+
+  // reduce each warp then write to shared memory
+  scalar_t sum = warpReduce(data);
+  if (lane == 0) {
+    warpSum[wid] = sum;
+  }
+
+  __syncthreads();
+
+  scalar_t v;
+  // perform final sum of partial warp sums
+  if (tid < blockDim.x / 32) {
+    v = warpSum[lane];
+  } else {
+    v = scalar_t(0.0);
+  }
+
+  if (wid == 0) {
+    v = warpReduce(v);
+  }
+  __syncthreads();
+
+  return v;
+}
+
+void checkCudaStatus(cudaError_t status, int lineNumber = -1) {
+
+  if (status != cudaSuccess) {
+    std::cout << cudaGetErrorString(status)
+              << " at line " << lineNumber << std::endl;
+    std::cout << "Exiting" << std::endl;
+    exit(1);
+  }
+}
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__device__
+void load_input_to_shared(const scalar_t* input, // global memory
+                          int inputOffset, int sequenceLength,
+                          int iteration, int numIterations,
+                          bool no_prev, scalar_t* output /* shared memory */) {
+  /*
+  Load a block size of input into shared memory with
+  right and left overhang of total size FS. If previously
+  loaded memory, overlap will be shifted over to reduce
+  global memory access
+
+  input - pointer to start of channel sequence
+  inputOffset - how far in the sequence to start loading
+  sequenceLength - total length of sequence
+  iteration - which block of sequence we are loading
+  numIterations - total number of blocks to load
+  no_prev - whether to load the whole block if the previous block
+            wasn't loaded
+  output - shared memory to write input to
+  */
+
+  const int tid = threadIdx.x;
+
+  // Load the left "overhang" of input
+  if (iteration > 0) {
+    if (padding_l < SB) {
+
+      // load all at once
+      if (tid < padding_l) {
+        output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB];
+      }
+    } else {
+
+      // load in chunks of size SB
+      int numIterations = divUp<int, int>(padding_l, SB);
+      for (int i = 0; i < numIterations; i++) {
+        int offset = i * SB;
+        if ((tid + offset) < padding_l) {
+          output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB];
+        }
+      }
+    }
+  }
+
+  // Load the right "overhang" of input
+  if (iteration < (numIterations - 1)) {
+    const int elementsLeft = sequenceLength - (iteration+1) * SB;
+
+    if ((FS - padding_l) < SB) {
+
+      // load all at once
+      if (tid < (FS - padding_l)) {
+        output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0);
+      }
+    } else {
+
+      // load in chunks of size SB
+      int numIterations = divUp<int, int>(FS - padding_l, SB);
+      for (int i = 0; i < numIterations; i++) {
+        int offset = i * SB;
+        if ((tid + offset) < (FS - padding_l)) {
+          output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0);
+        }
+      }
+    }
+  }
+
+  // We should also clear out the right "overhang"
+  if (iteration == (numIterations - 1)) {
+    if ((FS - padding_l) < SB) {
+
+      // clear out all at once
+      if (tid < (FS - padding_l)) {
+        output[padding_l + SB + tid] = scalar_t(0.0);
+      }
+    } else {
+
+      // clear in chunks of size SB
+      int numIterations = divUp<int, int>(FS - padding_l, SB);
+      for (int i = 0; i < numIterations; i++) {
+        int offset = i * SB;
+        if ((tid + offset) < (FS - padding_l)) {
+          output[padding_l + SB + tid + offset] = scalar_t(0.0);
+        }
+      }
+    }
+  }
+  output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0);
+}
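To show how the helpers above compose, here is a small standalone sketch (not part of the commit) that sums each row of a matrix, one block per row, using `blockReduce` and `checkCudaStatus`. `SHFL_MASK` is not defined in `cuda_utils.cu` itself, so the example assumes the full-warp mask that the including kernels would supply; it uses a block size that is a multiple of the 32-thread warp size, which `blockReduce` relies on, and includes `<iostream>`/`<cstdlib>` first because `checkCudaStatus` uses `std::cout` and `exit`.

```cuda
// row_sums_example.cu -- hypothetical usage sketch, not part of this commit.
#include <iostream>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

#define SHFL_MASK 0xffffffff   // assumed full-warp mask, normally supplied by the including kernel
#include "cuda_utils.cu"       // the helpers added in this commit

__global__ void rowSumKernel(const float* in, float* out, int cols) {
  const int row = blockIdx.x;
  float partial = 0.0f;

  // each thread accumulates a strided slice of its row
  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
    partial += in[row * cols + c];
  }

  // combine per-thread partials across the block; thread 0 of warp 0 holds the total
  float total = blockReduce<float>(partial);
  if (threadIdx.x == 0) {
    out[row] = total;
  }
}

int main() {
  const int rows = 4, cols = 1000;
  std::vector<float> h_in(rows * cols, 1.0f);   // every row sums to 1000
  std::vector<float> h_out(rows, 0.0f);

  float *d_in, *d_out;
  checkCudaStatus(cudaMalloc(&d_in, h_in.size() * sizeof(float)), __LINE__);
  checkCudaStatus(cudaMalloc(&d_out, h_out.size() * sizeof(float)), __LINE__);
  checkCudaStatus(cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
                             cudaMemcpyHostToDevice), __LINE__);

  rowSumKernel<<<rows, 256>>>(d_in, d_out, cols);   // 256 threads = 8 full warps
  checkCudaStatus(cudaGetLastError(), __LINE__);

  checkCudaStatus(cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
                             cudaMemcpyDeviceToHost), __LINE__);
  for (int r = 0; r < rows; r++) {
    std::cout << "row " << r << " sum = " << h_out[r] << std::endl;
  }

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

Compiled with something like `nvcc -arch=sm_60 row_sums_example.cu -o row_sums` from the directory containing `cuda_utils.cu`, each row should report a sum of 1000.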
