import torch


#@save
def transpose_qkv(X, num_heads):
    """Reshape the input for parallel computation across attention heads."""
    # Input X shape: (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Output X shape: (batch_size, no. of queries or key-value pairs, num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Output X shape: (batch_size, num_heads, no. of queries or key-value pairs,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # Final output shape: (batch_size*num_heads, no. of queries or key-value pairs,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

Test:
# The goal is to parallelize the matrix computation across multiple heads.
# Take X with batch_size=2, where each batch element has 3 queries and each query
# has num_hiddens=8 features after the projection.
# transpose_qkv gives every attention head its own batch slot, but divides the
# feature count by the number of heads: with 2 heads, each query hands 4 features
# to head 1 and the other 4 to head 2.
# In short, any input Q, K, or V of shape
# (batch_size, no. of queries or key-value pairs, num_features)
# becomes (batch_size*num_heads, no. of queries or key-value pairs, num_features/num_heads).
X = torch.normal(0, 1, (2, 3, 8))
print(X.shape)

Xt = transpose_qkv(X, 2)
print(Xt.shape)
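As a quick sanity check (my own addition, not from the original snippet): the first dimension of Xt enumerates (batch, head) pairs in row-major order, so row 1 is head 2 of the first batch element and should hold features 4 through 7 of each of that element's queries.

# Row order of Xt is (batch 0/head 0, batch 0/head 1, batch 1/head 0, batch 1/head 1),
# so head 2 of batch element 0 holds features 4..7 of each of its 3 queries.
print(torch.equal(Xt[1], X[0, :, 4:]))  # True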
# After the parallel multi-head computation we have the per-head outputs, which must be
# merged back from (batch_size*num_heads, no. of queries or key-value pairs, num_features/num_heads)
# for the final linear projection. transpose_output undoes transpose_qkv, so Output matches
# the original X in shape (and, since no attention was applied in between, in values).
Output = transpose_output(Xt, 2)
print(Output.shape)
Result:
torch.Size([2, 3, 8])
torch.Size([4, 3, 4])
torch.Size([2, 3, 8])
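To show where these two helpers sit in practice, here is a minimal sketch of a multi-head attention layer built on them, assuming plain scaled dot-product attention (no masking, no dropout). The class name MultiHeadAttention and the projection names W_q, W_k, W_v, W_o are my own labels for this sketch, not taken from the code above.

import math
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Minimal multi-head attention built on transpose_qkv / transpose_output."""
    def __init__(self, num_hiddens, num_heads, bias=False):
        super().__init__()
        self.num_heads = num_heads
        # Learned projections; num_hiddens must be divisible by num_heads.
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values):
        # Project, then split the feature dimension across heads:
        # (batch_size, seq_len, num_hiddens) -> (batch_size*num_heads, seq_len, num_hiddens/num_heads)
        q = transpose_qkv(self.W_q(queries), self.num_heads)
        k = transpose_qkv(self.W_k(keys), self.num_heads)
        v = transpose_qkv(self.W_v(values), self.num_heads)

        # Scaled dot-product attention, computed for all heads in one batched matmul.
        d = q.shape[-1]
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)
        attn = torch.softmax(scores, dim=-1)
        output = torch.bmm(attn, v)

        # Merge the heads back and apply the final linear projection.
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

attention = MultiHeadAttention(num_hiddens=8, num_heads=2)
Y = attention(X, X, X)   # self-attention on the test tensor above
print(Y.shape)           # torch.Size([2, 3, 8])

The key design point: folding the heads into the batch dimension lets a single torch.bmm compute attention for every head at once, instead of looping over heads in Python.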