-
Notifications
You must be signed in to change notification settings - Fork 516
/
multi_head_attention.py
317 lines (281 loc) · 11.1 KB
/
multi_head_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Optional
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from corenet.modeling.layers.base_layer import BaseLayer
from corenet.modeling.layers.dropout import Dropout
from corenet.modeling.layers.linear_layer import LinearLayer
from corenet.utils import logger
class MultiHeadAttention(BaseLayer):
    """
    This layer applies a multi-head self- or cross-attention as described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (Optional[float]): Attention dropout. Default: 0.0
        bias (Optional[bool]): Use bias or not. Default: ``True``
        output_dim (Optional[int]): Number of output features :math:`C_{out}` of the final
            projection. When ``None``, ``embed_dim`` is used. Default: ``None``
        coreml_compatible (Optional[bool]): When ``True``, ``forward`` always dispatches to
            ``forward_tracing``. Default: ``False``

    Shape:
        - Input:
            - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
              and :math:`C_{in}` is input embedding dim
            - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
        - Output: :math:`(N, S, C_{out})`, i.e. the same shape as the query input when
          ``output_dim`` is ``None`` or equal to ``embed_dim``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        output_dim: Optional[int] = None,
        coreml_compatible: Optional[bool] = False,
        *args,
        **kwargs
    ) -> None:
        # Extra positional/keyword arguments are accepted but not used here.
        if output_dim is None:
            output_dim = embed_dim
        super().__init__()
        if embed_dim % num_heads != 0:
            # NOTE(review): presumably `logger.error` aborts execution; if it only logs,
            # construction continues with a truncated head_dim -- confirm.
            logger.error(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        # A single fused projection produces query, key and value: rows [0, C) of its
        # weight belong to the query, rows [C, 2C) to the key and rows [2C, 3C) to the value.
        self.qkv_proj = LinearLayer(
            in_features=embed_dim, out_features=3 * embed_dim, bias=bias
        )

        self.attn_dropout = Dropout(p=attn_dropout)
        self.out_proj = LinearLayer(
            in_features=embed_dim, out_features=output_dim, bias=bias
        )

        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim) scaling applied to the query before computing QK^T
        self.scaling = self.head_dim**-0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.coreml_compatible = coreml_compatible
        self.use_separate_proj_weight = embed_dim != output_dim

    def __repr__(self):
        return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
            self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
        )

    def _project_query(self, x_q: Tensor) -> Tensor:
        """Project ``x_q`` with the query slice (first ``embed_dim`` rows) of ``qkv_proj``."""
        bias = self.qkv_proj.bias
        return F.linear(
            x_q,
            weight=self.qkv_proj.weight[: self.embed_dim, ...],
            bias=bias[: self.embed_dim] if bias is not None else None,
        )

    def _project_key_value(self, x_kv: Tensor) -> Tensor:
        """Project ``x_kv`` with the key/value slice (last ``2 * embed_dim`` rows) of ``qkv_proj``."""
        bias = self.qkv_proj.bias
        return F.linear(
            x_kv,
            weight=self.qkv_proj.weight[self.embed_dim :, ...],
            bias=bias[self.embed_dim :] if bias is not None else None,
        )

    def forward_tracing(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Attention computed head-by-head on chunked tensors (used when ``coreml_compatible`` is set).

        Args:
            x_q: Query tensor in batch-first format ``[N, S, C]``.
            x_kv: Optional key-value tensor ``[N, T, C]``. When ``None``, self-attention is computed.
            key_padding_mask: Optional mask ``[N, T]``; positions that are non-zero/``True`` are
                not attended to. Its shape is not validated in this path.
            attn_mask: Optional additive mask ``[N, S, T]`` added to the attention logits.
                Its shape is not validated in this path.

        Returns:
            Output of ``out_proj`` applied to the concatenated per-head results.
        """
        if x_kv is None:
            # [N, S, C] --> # [N, S, 3C] Here, T=S
            qkv = self.qkv_proj(x_q)
            # # [N, S, 3C] --> # [N, S, C] x 3
            query, key, value = torch.chunk(qkv, chunks=3, dim=-1)
        else:
            # [N, S, C]
            query = self._project_query(x_q)
            # [N, T, C] --> [N, T, 2C]
            kv = self._project_key_value(x_kv)
            key, value = torch.chunk(kv, chunks=2, dim=-1)

        query = query * self.scaling

        # [N, S, C] --> [N, S, c] x h, where C = c * h
        query = torch.chunk(query, chunks=self.num_heads, dim=-1)
        # [N, T, C] --> [N, T, c] x h, where C = c * h
        value = torch.chunk(value, chunks=self.num_heads, dim=-1)
        # [N, T, C] --> [N, T, c] x h, where C = c * h
        key = torch.chunk(key, chunks=self.num_heads, dim=-1)

        # The masks are shared by all heads, so prepare them once outside the loop.
        padding_mask = None
        if key_padding_mask is not None:
            # [N, T] --> [N, 1, T], broadcast over the source tokens
            padding_mask = key_padding_mask.unsqueeze(1).to(torch.bool)

        wt_out = []
        for query_h, key_h, value_h in zip(query, key, value):
            # [N, S, c] x [N, c, T] --> [N, S, T]
            attn_h = torch.matmul(query_h, key_h.transpose(-1, -2))
            if attn_mask is not None:
                attn_h = attn_h + attn_mask
            if padding_mask is not None:
                # Do not attend to padding positions
                attn_h = attn_h.masked_fill(padding_mask, float("-inf"))
            attn_h = self.softmax(attn_h)
            attn_h = self.attn_dropout(attn_h)
            # [N, S, T] x [N, T, c] --> [N, S, c]
            wt_out.append(torch.matmul(attn_h, value_h))

        # [N, S, c] x h --> [N, S, C]
        wt_out = torch.cat(wt_out, dim=-1)
        wt_out = self.out_proj(wt_out)
        return wt_out

    def forward_default(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Default batched multi-head attention in batch-first format.

        Args:
            x_q: Query tensor ``[N, S, C]``.
            x_kv: Optional key-value tensor ``[N, T, C]``. When ``None``, self-attention is computed.
            key_padding_mask: Optional mask ``[N, T]``; positions that are non-zero/``True`` are
                filled with ``-inf`` before the softmax.
            attn_mask: Optional additive mask ``[N, S, T]`` added to the attention logits.

        Returns:
            Output of ``out_proj`` applied to the attended values, ``[N, S, output_dim]``.

        Raises:
            AssertionError: If ``attn_mask`` or ``key_padding_mask`` has an unexpected shape.
        """
        # [N, S, C]
        b_sz, S_len, _ = x_q.shape

        if x_kv is None:
            # self-attention
            # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
            qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
            # [N, S, 3, h, c] --> [N, h, 3, S, c]
            qkv = qkv.transpose(1, 3).contiguous()

            # [N, h, 3, S, c] --> [N, h, S, c] x 3
            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        else:
            T_len = x_kv.shape[1]

            # cross-attention
            # [N, S, C]
            query = self._project_query(x_q)
            # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
            query = (
                query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
            )

            # [N, T, C] --> [N, T, 2C]
            kv = self._project_key_value(x_kv)
            # [N, T, 2C] --> [N, T, 2, h, c]
            kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
            # [N, T, 2, h, c] --> [N, h, 2, T, c]
            kv = kv.transpose(1, 3).contiguous()
            key, value = kv[:, :, 0], kv[:, :, 1]

        query = query * self.scaling

        # [N, h, T, c] --> [N, h, c, T]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
        attn = torch.matmul(query, key)

        batch_size, _, num_src_tokens, num_tgt_tokens = attn.shape
        if attn_mask is not None:
            # attn_mask must match attn in every dimension except the head dimension
            assert list(attn_mask.shape) == [
                batch_size,
                num_src_tokens,
                num_tgt_tokens,
            ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
                batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
            )
            # [N, S, T] --> [N, 1, S, T]
            attn_mask = attn_mask.unsqueeze(1)
            attn = attn + attn_mask

        if key_padding_mask is not None:
            # Do not attend to padding positions
            # key padding mask size is [N, T]
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
            attn = attn.masked_fill(
                key_padding_mask.unsqueeze(1)
                .unsqueeze(2)
                .to(torch.bool),  # [N, T] --> [N, 1, 1, T]
                float("-inf"),
            )

        # Run the softmax in float32 and cast back to the original dtype afterwards.
        attn_dtype = attn.dtype
        attn_as_float = self.softmax(attn.float())
        attn = attn_as_float.to(attn_dtype)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
        out = torch.matmul(attn, value)

        # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
        out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
        out = self.out_proj(out)

        return out

    def forward_pytorch(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Attention computed with ``torch.nn.functional.multi_head_attention_forward``.

        The fused ``qkv_proj`` weight is split into separate query, key and value weights,
        while the fused bias is passed as is. When ``x_kv`` is ``None``, ``x_q`` is used as
        key and value. Attention weights are not requested; only the output is returned.
        """
        # Key and value come from the same tensor: x_kv if given, else the query itself.
        kv_input = x_kv if x_kv is not None else x_q
        out, _ = F.multi_head_attention_forward(
            query=x_q,
            key=kv_input,
            value=kv_input,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            # Placeholder: the fused weight is unused because separate weights are supplied.
            in_proj_weight=torch.empty([0]),
            in_proj_bias=self.qkv_proj.bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.attn_dropout.p,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.qkv_proj.weight[: self.embed_dim, ...],
            k_proj_weight=self.qkv_proj.weight[
                self.embed_dim : 2 * self.embed_dim, ...
            ],
            v_proj_weight=self.qkv_proj.weight[2 * self.embed_dim :, ...],
        )
        return out

    def forward(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs
    ) -> Tensor:
        """
        Dispatch to one of the attention implementations.

        ``forward_tracing`` is used when ``coreml_compatible`` is set; otherwise
        ``forward_pytorch`` is used when the keyword argument ``use_pytorch_mha`` is truthy,
        and ``forward_default`` in all remaining cases.
        """
        if self.coreml_compatible:
            # For CoreML, we follow batch-first format. Make sure the input is of the form
            # [Batch , Sequence, Hidden_dim]
            return self.forward_tracing(
                x_q=x_q,
                x_kv=x_kv,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
        elif kwargs.get("use_pytorch_mha", False):
            # pytorch uses sequence-first format. Make sure that input is of the form [Sequence, Batch, Hidden dim]
            return self.forward_pytorch(
                x_q=x_q,
                x_kv=x_kv,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
        else:
            # our default implementation format follows batch-first format. Make sure the input is of the form
            # [Batch , Sequence, Hidden_dim]
            return self.forward_default(
                x_q=x_q,
                x_kv=x_kv,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )