1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
| #!/usr/bin/python # -*- coding: UTF-8 -*- import math import torch from torch import nn from d2l import torch as d2l import pandas as pd
# skip def masked_softmax(X, valid_lens): """通过在最后一个轴上掩蔽元素来执行softmax操作""" # X:3D张量,valid_lens:1D或2D张量 if valid_lens is None: return nn.functional.softmax(X, dim=-1) else: shape = X.shape if valid_lens.dim() == 1: valid_lens = torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0 X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) return nn.functional.softmax(X.reshape(shape), dim=-1)
# ok,self attention,输入q,k,v(query数目,和kv数目一致,d维度也一样) # 输出得score是q对应每个v的特征加权和 class DotProductAttention(nn.Module): """缩放点积注意力"""
def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] # 设置transpose_b=True为了交换keys的最后两个维度 scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) attention_weights_dropout = self.dropout(self.attention_weights) ## dropout 保持均值不变的情况下,把个别特征值设置为0
# 最关键是Q*K_T是什么意义? # 设置1个batch,每个batch2个query,一个query有4个特征值, (1,2,4) # 设置1个batch,每个bacth有2个kv值,每个kv值由4个特征值, (1,2,4)
# query: # q11,q12,q13,q14 # q21,q22,q23,q24
# key: # k11,k12,k13,k14 # k21,k22,k23,k24
# key_T # k11,k21 # k12,k22 # k13,k23 # k14,k24
# Q*K_T/sqrt(d) = (1,2,4) * (1,4,2) = (1,2,2), 也就是每个query对应的key的相似度 # a11=q11*k11 + q12*k12 + q13*k13 + q14*k14 q1和k1 # a12=q11*k21 + q12*k22 + q13*k23 + q14*k24 q1和k2 # a21=q21*k11 + q22*k12 + q23*k13 + q24*k14 q2和k1 # a22=q21*k21 + q22*k22 + q23*k23 + q24*k24 q2和k2
# a11,a12 # a21,a22 # a11是query1 对应 key1的相似度
# AT * Val = (1,2,2) * (1,2,4) = (1, 2, 4) # value: # v11, v12, v13, v14 # v21, v22, v23, v24
# a11*v11+a12*v21, a11*v12+a12*v22, a11*v13+a12*v23, a11*v14+a12*v24 # 解读: # score11 = q1和k1相似度*v1的特征1 + q1和k2相似度*v2的特征1 # score12 = q1和k1相似度*v1的特征2 + q1和k2相似度*v2的特征2 # score13 = q1和k1相似度*v1的特征3 + q1和k2相似度*v2的特征3 # score14 = q1和k1相似度*v1的特征4 + q1和k2相似度*v2的特征4
# a21*v11+a22*v21, a21*v12+a22*v22, a21*v13+a22*v23, a21*v14+a22*v24 # 解读: # score21 = q2和k1相似度*v1的特征1 + q2和k2相似度*v2的特征1 # score22 = q2和k1相似度*v1的特征2 + q2和k2相似度*v2的特征2 # score23 = q2和k1相似度*v1的特征3 + q2和k2相似度*v2的特征3 # score24 = q2和k1相似度*v1的特征4 + q2和k2相似度*v2的特征4 # 也就是一个score(i,j)第一个索引是代表query(i),第二个索引是代表所有的value在j特征的加权和,
res = torch.bmm(attention_weights_dropout, values) return res
|