This post gives a PyTorch implementation of MultiHeadAttention.

First, a quick review of the MultiHeadAttention formulas.

  1. scaled dot product attention:
    $$
    \text{Attention}(Q,K,V) = softmax\left(\frac{Q K^{\top}}{\sqrt{d}}\right)V \tag{1}
    $$

  2. Head output:
    $$
    head_i = \text{Attention}(QW^{Q}_i, KW^{K}_i, VW^{V}_i) \tag{2}
    $$

  3. Multihead attention:
    $$
    \text{MultiHead}(Q,K,V) = \text{Concat}(head_1, \ldots, head_h)\,W^{O} \tag{3}
    $$
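
As a concrete shape check (taking $d_{model} = 512$ and $h = 8$ purely as example values), each head attends in a reduced dimension and the concatenation restores the model dimension:

$$
d_{head} = \frac{d_{model}}{h} = \frac{512}{8} = 64, \qquad
head_i \in \mathbb{R}^{n \times 64}, \qquad
\text{Concat}(head_1, \ldots, head_8) \in \mathbb{R}^{n \times 512}
$$

where $n$ is the sequence length.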

First implementation

```python
import torch
import torch.nn.functional as F
from math import sqrt
import torch.nn as nn


def scaled_dot_prod_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(-2, -1)) / sqrt(dim_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))

    weights = F.softmax(scores, dim=-1)  # normalize each row of the score matrix

    return torch.bmm(weights, value)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()

        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        return scaled_dot_prod_attention(
            self.q(query), self.k(key), self.v(value), mask
        )


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        head_dim = embed_dim // num_heads

        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, head_dim) for _ in range(num_heads)
        ])

        self.output = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        # each head returns (batch_size, seq_len, head_dim); concatenating
        # along the last dimension restores (batch_size, seq_len, embed_dim)
        out = torch.cat([
            h(query, key, value, mask) for h in self.heads
        ], dim=-1)

        return self.output(out)
```
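
A quick sanity check of the shapes, assuming the code above has been run (batch size, sequence length, and dimensions below are arbitrary illustrative values; in self-attention the same tensor is passed as query, key, and value):

```python
embed_dim, num_heads = 512, 8
mha = MultiHeadAttention(embed_dim, num_heads)

x = torch.randn(2, 10, embed_dim)   # (batch_size, seq_len, embed_dim)
out = mha(x, x, x)                  # self-attention: query = key = value = x
print(out.shape)                    # torch.Size([2, 10, 512])
```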

The implementation above takes $Q$, $K$, $V$ as its inputs. In practice, though, the real input to the network is $X$, and $Q$, $K$, $V$ are themselves computed from $X$.

The second implementation takes $X$ as the input directly:

  1. $Q$, $K$, $V$:
    $$
    Q = XW_Q, \quad
    K = XW_K, \quad
    V = XW_V \tag{4}
    $$
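
Note that each matrix in Eq. (4) has shape $d_{model} \times d_{model}$, so a single nn.Linear computes the projection for all heads at once; splitting into heads afterwards is just a reshape. A minimal sketch of this idea (dimensions are illustrative):

```python
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
dim_head = d_model // num_heads

x = torch.randn(2, 10, d_model)             # X: (batch_size, seq_len, d_model)
q = nn.Linear(d_model, d_model)(x)          # Q = X W_Q (plus a bias term), all heads at once
q = q.view(2, 10, num_heads, dim_head)      # split the last dimension into heads
q = q.transpose(1, 2)                       # (batch_size, num_heads, seq_len, dim_head)
```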

Second implementation

```python
import torch
import torch.nn.functional as F
from math import sqrt
import torch.nn as nn


def scaled_dot_prod_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    # matmul (rather than bmm) handles the 4-D (batch, heads, seq, dim) tensors
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)

    return torch.matmul(weights, value)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dim_head = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        # x has shape (batch_size, seq_len, embed_dim);
        # this assumes embed_dim == d_model
        bs = x.size(0)
        seq_len = x.size(1)

        # project X, then split the last dimension into heads
        q = self.q_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        k = self.k_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        v = self.v_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)

        # (batch_size, num_heads, seq_len, dim_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output = scaled_dot_prod_attention(q, k, v, mask)

        # merge the heads back: (batch_size, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, self.d_model)

        return self.out_linear(output)
```
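
Again a quick shape check, assuming the code above has been run, this time with a causal mask (all sizes are arbitrary illustrative values; the mask only needs to broadcast against the (batch_size, num_heads, seq_len, seq_len) score tensor):

```python
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(2, 10, d_model)   # (batch_size, seq_len, d_model)

# lower-triangular causal mask; positions where the mask is 0 are ignored
mask = torch.tril(torch.ones(10, 10)).view(1, 1, 10, 10)

out = mha(x, mask=mask)
print(out.shape)                  # torch.Size([2, 10, 512])
```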