This post walks through a PyTorch implementation of MultiHeadAttention.
First, a quick review of the MultiHeadAttention formulas.
Scaled dot-product attention:
$$
\text{Attention}(Q,K,V) = \text{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right)V \tag{1}
$$
Head output:
$$
head_i = \text{Attention}(QW^{Q}_i, KW^{K}_i, VW^{V}_i) \tag{2}
$$
Multihead attention:
$$
\text{MultiHead}(Q,K,V) = \text{Concat}(head_1, \ldots, head_h)\,W^{O} \tag{3}
$$
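As a concrete sizing convention (assumed here for illustration, following the usual Transformer setup): with model dimension $d_{\text{model}}$ and $h$ heads, each head works in dimension $d = d_{\text{model}} / h$, so the projection matrices have shapes

$$
W^{Q}_i,\, W^{K}_i,\, W^{V}_i \in \mathbb{R}^{d_{\text{model}} \times d}, \qquad
W^{O} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}
$$

For example, $d_{\text{model}} = 512$ and $h = 8$ give $d = 64$ per head, which is exactly the `embed_dim // num_heads` computation in the code below.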
First implementation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


def scaled_dot_prod_attention(query, key, value, mask=None):
    # query, key, value: (batch_size, seq_len, head_dim)
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        # positions where mask == 0 are not attended to
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # per-head projections W_i^Q, W_i^K, W_i^V
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        return scaled_dot_prod_attention(
            self.q(query), self.k(key), self.v(value), mask
        )


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        # output projection W^O
        self.output = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        # run each head independently and concatenate along the feature dim
        out = torch.cat([h(query, key, value, mask) for h in self.heads], dim=-1)
        return self.output(out)
```
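As a quick sanity check (a minimal usage sketch; the tensor sizes below are placeholders I chose, not part of the original post):

```python
# minimal shape check for the first implementation (sizes are arbitrary)
embed_dim, num_heads, batch_size, seq_len = 512, 8, 2, 10

mha = MultiHeadAttention(embed_dim, num_heads)
x = torch.randn(batch_size, seq_len, embed_dim)

# self-attention: query, key and value are all x
out = mha(x, x, x, mask=None)
print(out.shape)  # torch.Size([2, 10, 512])
```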
The implementation above takes $Q$, $K$, $V$ as its inputs. In practice, though, the actual input to the network is $X$, and $Q$, $K$, $V$ are themselves computed from $X$.
The second implementation therefore takes $X$ as its input:
- $Q$, $K$, $V$:
$$
Q = XW_Q, \quad
K = XW_K, \quad
V = XW_V \tag{4}
$$
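Note (my own remark, not from the original text): the second implementation does not build $h$ separate small projections. Each of $W_Q$, $W_K$, $W_V$ is a single $d_{\text{model}} \times d_{\text{model}}$ matrix, equivalent to stacking the per-head matrices side by side; the projected result is then reshaped and transposed into heads. A minimal sketch of this split, with arbitrary sizes:

```python
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
dim_head = d_model // num_heads           # 64 dimensions per head

x = torch.randn(2, 10, d_model)           # (batch_size, seq_len, d_model)
w_q = nn.Linear(d_model, d_model)         # one fused W_Q covering all heads

q = w_q(x)                                # (2, 10, 512): Q = X W_Q as in Eq. (4)
q = q.view(2, 10, num_heads, dim_head)    # split features into 8 heads of 64
q = q.transpose(1, 2)                     # (2, 8, 10, 64): per-head queries
```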
Second implementation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


def scaled_dot_prod_attention(query, key, value, mask=None):
    # query, key, value: (batch_size, num_heads, seq_len, dim_head)
    dim_k = query.size(-1)
    # matmul (rather than bmm) broadcasts over the extra head dimension
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dim_head = d_model // num_heads
        # fused projections: one d_model -> d_model linear each for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        bs = x.size(0)
        seq_len = x.size(1)
        # project x and split the feature dimension into heads
        q = self.q_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        k = self.k_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        v = self.v_linear(x).view(bs, seq_len, self.num_heads, self.dim_head)
        # move heads in front of the sequence dim: (bs, num_heads, seq_len, dim_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if mask is not None:
            # add a head dimension so a (bs, seq_len, seq_len) mask broadcasts over all heads
            mask = mask.unsqueeze(1)
        output = scaled_dot_prod_attention(q, k, v, mask)
        # merge the heads back: (bs, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, self.d_model)
        return self.out_linear(output)
```
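And a quick shape check for the second version (again a minimal sketch with arbitrary sizes; the causal mask is just one example of a mask you might pass in):

```python
# minimal shape check for the second implementation (sizes are arbitrary)
d_model, num_heads, batch_size, seq_len = 512, 8, 2, 10

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

# optional causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, seq_len, seq_len)

out = mha(x, causal_mask)
print(out.shape)  # torch.Size([2, 10, 512])
```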