为什么需要位置编码
在之前介绍的:
Transformer之Token的通俗理解
Transformer之Attention的通俗理解
两篇文章中,我们介绍了Token被作为一个整体送入Attention中进行计算,这样才能得到各个Token之间的关联。
在NLP中,词语的顺序至关重要,比如说"爱做"和"做爱",相同的词语所表达的意思却天差地别,所以编码器会把带有顺序信息的向量一同送入Attention中;在CV中,图像被nn.Conv2d切成一个个小块,然后把小块变成 [ B , 1 , 1 , C ] [B, 1, 1, C] [B,1,1,C]的点,这些点共同构成送入Attention的patch_embedding,虽然对顺序的要求没有那么高,但是也是有一定要求的。
所以就需要体现顺序的位置编码,融合进要送入Attention的Token之中。
位置编码的本质
位置编码本身是一种偏置,这样就能看出来了
Q u e r y = W q × ( X T o k e n + P E ) = W q × X T o k e n + W q × P E \begin{array}{ccl} Query &= &W_q \times \Big(X_{Token} + PE \Big) \\ &&\\ & = & W_q \times X_{Token} + W_q \times PE \end{array} Query==Wq×(XToken+PE)Wq×XToken+Wq×PE
也就是说 X T o k e n X_{Token} XToken是在做一维平面运动,那么我们就可以这么说,位置编码除了提供序列的时空信息,还具有提供偏置信息的功能。
位置编码是如何起作用的?
这里是重点,建议仔细看看
这里有两句话
“我爱你”和“你爱我”
我(0.5) | 爱(0.6) | 你(0.7) | |
---|---|---|---|
我(0.5) | 0.25 | 0.3 | 0.35 |
爱(0.6) | 0.3 | 0.36 | 0.42 |
你(0.7) | 0.35 | 0.42 | 0.49 |
通过注意力计算,只能计算出“我”、“爱”、“你”之间的关联,例如,如果"我"的嵌入向量是 V e c 我 Vec_{我} Vec我,不管在“我爱你”还是“你爱我”中,其嵌入向量和其他向量所计算的值都是相同的,正如上表所示“爱我” 和 “我爱”的注意力值都是0.3,这就无法区分出语义了;
但是,但是啊,如果加上位置编码,那么:
在“我爱你”中,“我”的嵌入向量就变成了 V 我 + P E 1 V_{我}+PE_1 V我+PE1,
在“你爱我”中,“我”的嵌入向量就变成了 V 我 + P E 3 V_{我}+PE_3 V我+PE3,
这时,相同的Token所计算到的注意力值就是不同的,因为位置变了,假设 P E 0 = 0 , P E 1 = 1 , P E 2 = 2 PE_{0}=0, PE_1=1, PE_2=2 PE0=0,PE1=1,PE2=2,那么上表就会变成下面这张表:
我(0.5+0=0.5) | 爱(0.6+1=1.6) | 你(0.7+2=2.7) | |
---|---|---|---|
我(0.5) | 0.25 | 0.8 | 1.35 |
爱(0.6) | 0.3 | 0.96 | 1.62 |
你(0.7) | 0.35 | 1.12 | 1.89 |
从上表上可以看出,当加上位置编码之后,“爱我”的注意力值是0.8,而“我爱”的注意力值就变成了0.3,这样就很能区分出来是哪种含义了。
Token如何与Position-Embeding融合
通常来说是有两种方法,一种是把Position-Embeding(以后都称之为PE)和Token直接相加,另一种是PE和Token做阿达玛积(对应位置一一相乘),如图所示,其中PE需要具有与Token相同的维度
位置编码有哪些?
1.绝对位置编码
三角式绝对位置编码
P E = { s i n ( n 1000 0 2 × i D T o k e n ) , d = 2 i c o s ( n 1000 0 2 × i D T o k e n ) , d = 2 i + 1 PE= \begin{cases} sin\Big(\frac{n}{10000^{2 \times \frac{i}{D_{Token}}}} \Big), & d=2i \\ & \\ cos \Big(\frac{n}{10000^{2 \times \frac{i}{D_{Token}}}} \Big), & d=2i+1 \end{cases} PE=⎩
⎨
⎧sin(100002×DTokenin),cos(100002×DTokenin),d=2id=2i+1
具体形式如图所示,Token的维度是 [ B , N , D i m ] [B, N, Dim] [B,N,Dim],对应的PE也是 [ B , N , D i m ] [B, N, Dim] [B,N,Dim]
学习式位置编码
这时最简单的一种位置编码,例如Token的维度是 [ B , N , D i m ] [B, N, Dim] [B,N,Dim],那么就在__init__()函数中用nn.Parameter()生成一个维度为 [ B , N , D i m ] [B, N, Dim] [B,N,Dim]的初始位置编码,然后在训练中参与更新,最后学习到一组位置编码。
简单如是:
import torch
import timm
# 为什么没有用[B, N, Dim]呢,因为加上Token的时候,PE会因广播机制而复制
# 所以,本质上还是[B, N, Dim]
self.absolute_position_embedding = nn.Parameter(torch.zeros(1, N, Dim))
timm.models.layers.trunc_normal_(self.absolute_position_embedding, std=0.02)
两种绝对编码的对比
- 三角式相对于学习式具有良好的外推性:
三角式的位置编码具有三角函数的周期性,所以当文本或者patch等Token在Inference的长度要比在Train中要长上数倍时,位置编码可以周期性增长,也会有相对良好的效果;
- 三角式位置编码每词向量对应的位置编码( [ 1 , D i m ] [1 , Dim] [1,Dim])之间是正交的,这就意味着他们之间是相互独立的,不会相互干扰;
- 三角式位置编码因为是周期函数,而且频率很高(周期很长),可以容纳相当多的Tokens;
- 三角式位置编码除了关注了相对关系的距离,还有相对关系的角度信息,包含得更丰富;
- 两者都是关注单个位置信息,在输入层之上,简单地和输入向量(Token)相加,区别于相对位置模型,往往是信息对(增加了位置信息的维度);
相对位置编码
A t t = s o f t m a x ( Q × K T D i m + r e l a t i v e _ p o s i t i o n _ b i a s ) × V Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V Att=softmax(DimQ×KT+relative_position_bias)×V
相对位置编码主要是对序列中元素的相对位置关系处理得会更好,但是处理方式也就和绝对位置编码不同了,上面的relative_position_bias就是所谓的相对位置编码
- 绝对位置编码:是一个矩阵,加在Token上的
- 相对位置编码:是一个矩阵,加在注意力得分上的
Q E + P E × K E + P E T = X E + P E × W q × [ X E + P E × W k ] T = X E + P E × W q × W k T × X E + P E T = ( X q + P E q ) × W q × W k T × ( X k + P E k ) T = X q × W q ⏞ Q u e r y × W k T × X k T ⏞ K e y ⏟ 第一项 + P E q × W q ⏞ a × W k T × X k T ⏞ K e y ⏟ 第二项 + X q × W q ⏞ Q u e r y × W k T × P E k T ⏞ b ⏟ 第三项 + P E q × W q ⏞ a × W k T × P E k T ⏞ b ⏟ 第四项 \begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array} QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项
Xq×Wq
Query×WkT×XkT
Key+第二项
PEq×Wq
a×WkT×XkT
Key+第三项
Xq×Wq
Query×WkT×PEkT
b+第四项
PEq×Wq
a×WkT×PEkT
b
从这里可以看出,除了第一项是一个二次型,其余三项都是跟位置偏置相关的一次变换,也就是一个数据和数据本身有关,其他的都是跟位置编码有关,我们想办法把两种PE: P E q PE_q PEq和 P E k PE_k PEk通过换元法,换成一个:
P E q × W q × W k T × P E k T = P E q × W × P E k T = P E q × W × [ P E q − ( P E q − P E k ) ] T \begin{array}{ccl} PE_q \times W_q \times W_k^T \times PE^T_k &= &PE_q \times W \times PE^T_k \\ && \\ & = &PE_q \times W \times [PE_q - (PE_{q}-PE_{k} )]^T \\ \end{array} PEq×Wq×WkT×PEkT==PEq×W×PEkTPEq×W×[PEq−(PEq−PEk)]T
我们可以发现,只需要 Q u e r y Query Query的PE和 Q u e r y Query Query与 K e y Key Key的PE的相对位置,就可以计算了
那么是不是可以这样:
- 首先计算维度为 [ B , N q , C ] [B, N_q, C] [B,Nq,C]的 Q Q Q和 [ B , N k , C ] [B, N_k, C] [B,Nk,C] 的 K K K
- 然后计算维度为 [ B , N q , N k ] [B, N_q, N_k] [B,Nq,Nk]的 Q × K T Q \times K^T Q×KT
- 最后加上相对位置编码
具体过程如下:
- 生成index矩阵对,分别对应下面两张图
coords_h = torch.arange(h)
coords_w = torch.arange(w)
coords = torch.meshgrid([coord_h, coord_w])
2. 拼接在一起,形成一张图,如下图所示
coords = torch.stack(coords)
3. 拉平,维度变成 [ 2 , W × H ] [2, W\times H] [2,W×H]
coords_flatten = torch.flatten(coords, 1)
4. 转置相减,利用BroadCast机制去复制维度
# 这里的None就是添加了一个维度
relative_coords = coords_flatten[:, :,None] - coords_flatten[:, None, :]
并得到结果如图:
5. 增加偏置
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += W -1
relative_coords[:, :, 1] += H -1
relative_coords[:, :, 0] *= 2 * H -1
- 求和,把上图括号里面的东西求和
relative_position_index = relative_coords.sun(-1)
以上都是索引,那么位置编码如何生成呢?
self.relative_position_bias_table = nn.Parameter(torch.zeros(2*W-1, 2*H-1, num_heads))
具体就是利用上面生成的Index来索引这里的relative_position_bias_table
relative_position_bias = self.relative_position_bias_table(relative_coords)