This post gives a PyTorch implementation of MultiHeadAttention.

First, a quick review of the MultiHeadAttention formulas.

  1. Scaled dot-product attention (see the numeric sketch after this list):

    $$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \tag{1}$$

  2. Head output:

    $$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V) \tag{2}$$

  3. Multi-head attention:

    $$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O \tag{3}$$

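As a quick numeric illustration of formula (1) (a minimal sketch; the tensor sizes below are made up), the scaled dot-product step can be written directly with a few torch operations:

```python
import torch
import torch.nn.functional as F
from math import sqrt

# made-up sizes: batch of 2, sequence length 5, key/value dimension 16
Q = torch.randn(2, 5, 16)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)

scores = Q @ K.transpose(-2, -1) / sqrt(Q.size(-1))  # (2, 5, 5)
attn = F.softmax(scores, dim=-1) @ V                 # (2, 5, 16)
print(attn.shape)  # torch.Size([2, 5, 16])
```
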
First implementation

```python
import torch
import torch.nn.functional as F
from math import sqrt
import torch.nn as nn


def scaled_dot_prod_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    # (batch, seq_len, seq_len) attention scores, scaled by sqrt(d_k)
    scores = torch.bmm(query, key.transpose(-2, -1)) / sqrt(dim_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))

    weights = F.softmax(scores, dim=-1)  # normalize each row over the key positions

    return torch.bmm(weights, value)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()

        # per-head projections from the model dimension down to the head dimension
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        return scaled_dot_prod_attention(
            self.q(query), self.k(key), self.v(value), mask
        )


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        head_dim = embed_dim // num_heads

        # nn.ModuleList (not torch.ModuleList) so the heads are registered as submodules
        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, head_dim) for _ in range(num_heads)
        ])

        self.output = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        # run every head, then concatenate along the feature dimension
        out = torch.cat([
            h(query, key, value, mask) for h in self.heads
        ], dim=-1)

        return self.output(out)
```
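
A quick shape check (a minimal sketch; the sizes are made up for illustration). For self-attention the same tensor is passed as query, key, and value:

```python
embed_dim, num_heads = 512, 8          # made-up sizes for illustration
mha = MultiHeadAttention(embed_dim, num_heads)

x = torch.randn(2, 10, embed_dim)      # (batch, seq_len, embed_dim)
out = mha(x, x, x)                     # self-attention: query = key = value = x
print(out.shape)                       # torch.Size([2, 10, 512])
```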

The implementation above takes Q, K, and V as inputs. In practice, however, the real input to the layer is X, and Q, K, and V are themselves computed from X.

The second implementation therefore takes X as the input and computes the projections internally (the sketch after the formula below shows why a single full-width projection per Q/K/V is equivalent to the per-head projections used above):

  1. Q, K, V:

    $$Q = XW^Q,\quad K = XW^K,\quad V = XW^V \tag{4}$$

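Here is a minimal sketch of that equivalence (hypothetical sizes, bias terms omitted): stacking the per-head weight matrices row-wise yields one d_model-by-d_model weight, and its output is exactly the concatenation of the per-head outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 8, 2                      # made-up sizes for illustration
head_dim = d_model // num_heads

x = torch.randn(3, d_model)

# first implementation: num_heads separate projections, outputs concatenated
per_head = [nn.Linear(d_model, head_dim, bias=False) for _ in range(num_heads)]
concat = torch.cat([p(x) for p in per_head], dim=-1)

# second implementation: one d_model x d_model projection whose weight
# stacks the per-head weights row-wise
big = nn.Linear(d_model, d_model, bias=False)
big.weight.data = torch.cat([p.weight.data for p in per_head], dim=0)

print(torch.allclose(concat, big(x), atol=1e-6))  # True
```
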
Second implementation

```python
import torch
import torch.nn.functional as F
from math import sqrt
import torch.nn as nn


def scaled_dot_prod_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    # matmul (rather than bmm) so the extra head dimension is handled:
    # (bs, num_heads, seq_len, dim_head) @ (bs, num_heads, dim_head, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        # the mask must broadcast to (bs, num_heads, seq_len, seq_len),
        # e.g. shape (bs, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)

    return torch.matmul(weights, value)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dim_head = d_model // num_heads

        # one full-width projection per Q/K/V; the result is split into heads later
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        # x has shape (batch_size, seq_len, embed_dim);
        # this assumes embed_dim == d_model
        bs = x.size(0)
        seq_len = x.size(1)

        # project, then split the last dimension into (num_heads, dim_head)
        q = self.q_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        k = self.k_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        v = self.v_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)

        # move the head dimension forward: (bs, num_heads, seq_len, dim_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output = scaled_dot_prod_attention(q, k, v, mask)

        # merge the heads back into (bs, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, self.d_model)

        return self.out_linear(output)
```
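
A usage sketch for the second implementation (made-up sizes; the causal mask is just one example of a mask that broadcasts over the head dimension):

```python
d_model, num_heads, bs, seq_len = 512, 8, 2, 10   # made-up sizes for illustration
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(bs, seq_len, d_model)

# lower-triangular causal mask, broadcast over batch and heads: (1, 1, seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

out = mha(x, mask=mask)
print(out.shape)  # torch.Size([2, 10, 512])
```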