01、注意力机制的引入
前面几篇文章总结了注意力机制、自注意力机制、QKV的原理,本篇继续向下推进,聊一聊多头注意力机制。
这里我想先介绍一下注意力机制的引入缘由,在注意力机制出现之前,序列建模的主要方法是基于循环神经网络(RNN)及其变体(如LSTM长短期记忆网络、GRU门控循环单元)。这些模型在处理序列数据时,通常将序列逐个时间步处理,并将历史信息压缩到一个固定大小的隐状态向量中。
那些输入是序列、输出也是序列的模型就是序列模型,这是从输入输出的角度来定义的,这些模型通常会使用编码器到解码器的架构,编码器用于处理输入序列,解码器用于生成输出序列。由于循环神经网络RNN在处理序列数据时的优势,很多模型都会选择循环神经网络作为编码器和解码器,其原理就是:循环神经网络会逐步遍历序列,并通过隐状态的传递保存序列的历史信息,从而达到理解序列数据的目的。
这种编码方式存在一些局限性:
1、信息压缩瓶颈:将整个序列的信息压缩到一个固定维度的向量中,可能会导致信息丢失,尤其是对于长序列。
2、长距离依赖问题:尽管LSTM和GRU通过门控机制缓解了梯度消失问题,但处理非常长的序列时,捕捉长距离依赖仍然困难。
3、并行化困难:RNN的序列依赖性导致无法并行处理整个序列,训练速度较慢。
但这是循环神经网络能够处理序列数据的关键,同时也是循环神经网络的局限所在。循环神经网络处理序列数据时需要逐步遍历,这是顺序性的,这个过程难以并行。此外,隐状态在长序列间进行传递容易丢失信息,使网络难以对长序列进行建模。
另一方面,研究人员发现,注意力机制在序列到序列模型中的应用对性能的提升有很大帮助。后来就有人提出一个开创性的想法,能不能整个序列到序列模型就用注意力机制,不再基于循环神经网络。当然,我们知道这个尝试是成功了,这就是大名鼎鼎的 Transformer,Attention Is All You Need,不仅使训练过程可以并行,而且还能对长序列进行建模!
关注公众号【阳光宅猿】回复【AIGC】领取最新AI人工智能学习资料,包含RAG、Agent、深度学习、模型微调等多种最新技术文档等你来选!!
关注公众号【阳光宅猿】回复【加群】进入大模型技术交流群一起学习成长!!!
02、注意力机制原理
多头注意力机制是整个Transformer模型的核心,且不说多头是什么操作,我们先来看注意力机制是个啥,之前文章也讲过注意力机制在Seq2Seq中的应用。深度学习之一篇文章带你深度理解注意力机制
注意力机制在本质上是根据当前查询以不同的权重抽取序列中的信息,序列中的每个元素会有一个键向量用于匹配,另有一个值向量表示其蕴含的信息。
当给定查询向量时,会通过以下两个步骤得到最终要抽取的信息向量
1、首先,计算查询向量与每个键向量之间的匹配分数
2、然后,以匹配分数作为权重对所有值向量加权求和

其中最关键的部分就是打分函数的设计,简单来说就是如何衡量查询向量和键向量之间的匹配程度
在 Transformer中所使用的打分函数是缩放点积算法,其计算公式如下:


当查询向量来自序列本身时,此时的注意力机制被称为自注意力机制,序列中的每个元素对应三个向量,一是Q向量,二是K向量,三是V向量,那么如何得到这些向量呢?深度学习之注意力机制中QKV原理解读
下面是自注意力机制的运作架构图(来源于网络)

在自注意力机制中,每个元素都会根据自身查询来获取序列的上下文信息,从而得到对应的向量表示,对比循环神经网络,二者都可用于编码序列,其不同之处在于自注意力机制能够并行地编码序列。
这是因为序列中每个元素的注意力计算之间并没有依赖关系,因此可以通过矩阵运算来并行加速。反观循环神经网络则需要通过在序列中逐步传递隐状态来获取序列的上下文信息,这个过程不能并行。
仅用一个注意力头往往难以同时捕捉多种语义关系(如词法、语义、句法等)。因此,Transformer 提出了多头注意力机制 (Multi-Head Attention, MHA)。
03、多头注意力机制原理
Transformer是基于注意力机制的模型架构,其核心创新点是多头注意力机制(Multi-Head Attention)。在多头注意力机制中,输入数据会被拆分成多个部分映射到多组查询、键、值向量,然后分别进行注意力计算,每个部分使用独立的注意力头进行处理,这样模型可以在多个不同的子空间中关注不同的信息。
多头注意力机制的工作原理可以描述为:将输入向量分别映射为查询(Query)、键(Key)和值(Value)向量,并通过多个注意力头计算每个子空间中的加权和,最后将所有头的输出拼接在一起,形成最终的输出。结构图如下(图片来源于网络):

多头,顾名思义,就是同时通过多个“自注意力机制”进行特征提取。图中每一个输入序列被分为了多组QKV,这就是多头,捕捉输入序列x,中不同子空间的特征。其实你可以将“多头”当做“单头”来理解“注意力机制”的计算,不影响你对模型整体的认识。下面我们一步一步的来描述“多头自注意力机制”的计算过程。
第1步:对Q、K、V线性变换
在每个“头”中,都有三组Linear线性层。
它们用于对输入张量Q、K、V进行线性变换。之前文章中已经详细讲过这块,这里就不过多赘述了,可以看这篇文章:深度学习之注意力机制中QKV原理解读
这里思考:为什么需要这些线性变换呢?为什么不直接用输入向量作为Q、K、V呢?原因如下:
1、提升模型的表达能力:线性变换相当于给模型增加了可学习的参数,使得模型能够学习如何将输入映射到不同的子空间,从而捕捉更丰富的特征。如果没有线性变换,那么每个头的Q、K、V就是原始的输入,这样模型的表达能力会受到限制。
2、降维或升维:线性变换可以改变Q、K、V的维度。在多头注意力中,我们通常会将输入的维度拆分成多个头,每个头的维度是原始维度除以头数。通过线性变换,我们可以将输入投影到适合每个头的维度。
3、多头注意力的需要:多头注意力的核心思想是将注意力分散到多个不同的子空间,让每个头关注不同的方面。如果不进行线性变换,那么每个头的Q、K、V都是一样的,这样就失去了多头的意义。通过不同的线性变换(不同的权重矩阵),每个头可以将输入投影到不同的子空间,从而学习到不同的特征表示。
4、增加模型的非线性:虽然线性变换本身是线性的,但注意力机制中会接softmax等非线性函数。而且,整个注意力机制是放在一个带有非线性激活函数的前馈网络中的。但是,单独看注意力机制,线性变换为模型提供了可学习的线性变换,这些变换可以与其他部分一起通过梯度下降进行优化。
第2步:计算缩放点积注意力得分
完成对Q、K、V的线性变换后,用新的Q、K、V这三组张量,计算缩放点积注意力纷得分。

自注意力机制中,查询向量Q与所有键向量K之间的点积被用来计算注意力得分。为了避免点积结果过大导致梯度问题,其实就是把我们通过一系列计算出来的结果,大的变得小一点,小的变得大一点,让它们距离变得更近。其中引入了一个缩放因子1/√dk,其中dk是键向量的维度。
关注公众号【阳光宅猿】回复【AIGC】领取最新AI人工智能学习资料,包含RAG、Agent、深度学习、模型微调等多种最新技术文档等你来选!!
关注公众号【阳光宅猿】回复【加群】进入大模型技术交流群一起学习成长!!!
第3步:融合多头计算的结果
每个头都会基于一组Q、K、V ,计算缩放点积注意力分数,我们需要将这些不同头的计算结果拼接(concat)起来,我们会使用一个线性层,将这些结果进行“Merge”,输出最终的一个与输入张量维度相同的新张量矩阵,变换后输出的张量就带有了全局信息。下面我们基于一个具体例子,说明多头注意力的计算过程。

首先,输入的数据是,经过位置编码后的“You are welcome PAD”;它对应最下方的黄色词向量矩阵。
第1步:对Q、K、V线性变换
4×6的输入张量,首先被复制为Query、Key和Value。

然后分别和3个Linear线性层与权重矩阵进行线性变换,生成三组新的Q、K和V,这里使用浅橙色、红色和深橙色来表示,尺寸都是4×6,这个过程就是最基础的线性层计算。
第2步:计算缩放点积注意力
接着将q、k和v,三个计算结果,先split为多头的形式,然后带入到Attention的计算公式中计算注意力分数。

计算注意力机制的输出。经过Attention的计算后,就会融合输入序列中的全局信息。
第3步:融合多头计算的结果
在示意图的最上方,Attention Score的输出是多组结果。

因此总结来说,整个多头注意力机制的作用,就是对输入张量进行特征变换,变换后的输出张量,就带有了全局信息。相比于自注意力机制,主要有以下几点优势:
1、增强表达能力
通过并行地学习多个子空间表示,多头注意力机制能够捕捉到更丰富的上下文信息和特征交互模式。
2、提高鲁棒性
由于每个头都独立地进行计算,因此即使某些头受到噪声或错误信息的干扰,整个模型仍然能够保持稳定的性能。
3、易于优化
多头注意力机制通过将问题分解为多个较小的子问题来处理,降低了优化的难度和复杂度。
多头注意力机制同样广泛应用于各种NLP任务中,并且已经成为Transformer及其变体等先进模型的重要组成部分。它不仅提高了模型的性能,还促进了NLP领域的快速发展和创新。
以下是一个多头注意力机制的Python代码实现(详细注释版本):
import torch
import torch.nn as nn
import torch.nn.functional as F
classMultiHeadAttention(nn.Module):
def__init__(self, d_model, num_heads, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0# ✅ 确保模型维度能被头数整除,这是多头的必要条件
self.d_model = d_model # ✅ 模型总维度,如512
self.num_heads = num_heads # ✅ 多头注意力的头数,如8
self.head_dim = d_model // num_heads # ✅ 每个头的维度,如512÷8=64
# ✅ 1. 创建线性变换层(多头注意力的核心组件)
# 将输入投影到不同的子空间,每个头有自己的权重
self.w_q = nn.Linear(d_model, d_model) # ✅ Q的投影:每个头学习不同的查询模式
self.w_k = nn.Linear(d_model, d_model) # ✅ K的投影:每个头学习不同的键模式
self.w_v = nn.Linear(d_model, d_model) # ✅ V的投影:每个头学习不同的值模式
self.w_o = nn.Linear(d_model, d_model) # ✅ 输出投影:整合所有头的信息
self.dropout = nn.Dropout(dropout) # ✅ 注意力dropout,防止过拟合
defforward(self, x, attn_mask=None):
"""
x: [B, L, d_model]
B: batch_size, L: sequence_length
"""
B, L, _ = x.size() # ✅ 获取batch_size和序列长度
# ✅ 1. 线性投影 - 多头的第一步:为不同头创建不同的表示
Q = self.w_q(x) # ✅ [B, L, d_model] -> [B, L, d_model] 每个头有不同的Q投影
K = self.w_k(x) # ✅ 同上,创建K
V = self.w_v(x) # ✅ 同上,创建V
# ✅ 此时Q、K、V还是合并的维度,下一步会分割成多个头
# ✅ 2. reshape 为 [B, H, L, Dh] - 多头的核心实现
defreshape_heads(t):
# ✅ 第一步:将t从[B, L, d_model]变为[B, L, H, Dh]
# ✅ 其中H=num_heads, Dh=head_dim
# ✅ 第二步:transpose(1,2)将H维度放到第2维,得到[B, H, L, Dh]
return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
Q = reshape_heads(Q) # ✅ 将Q分割成多个头: [B, H, L, Dh]
K = reshape_heads(K) # ✅ 将K分割成多个头: [B, H, L, Dh]
V = reshape_heads(V) # ✅ 将V分割成多个头: [B, H, L, Dh]
# ✅ 现在Q、K、V都分成了num_heads个头,每个头有自己的维度head_dim
# ✅ 3. 缩放点积注意力 - 每个头独立计算注意力
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5) # ✅ [B, H, L, L]
# ✅ @是矩阵乘法:每个头的Q和K转置相乘
# ✅ K.transpose(-2, -1)将K的最后两维转置,从[B,H,L,Dh]变为[B,H,Dh,L]
# ✅ 除法是缩放因子,防止softmax梯度消失
if attn_mask isnotNone:
# ✅ 应用注意力掩码(如因果掩码)
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
# ✅ 将被掩码的位置设为负无穷,softmax后会接近0
attn = F.softmax(scores, dim=-1) # ✅ 在最后一个维度(L)上做softmax
attn = self.dropout(attn) # ✅ 对注意力权重应用dropout
out = attn @ V # ✅ [B, H, L, L] @ [B, H, L, Dh] = [B, H, L, Dh]
# ✅ 每个头独立计算加权和,得到每个头的输出
# ✅ 4. 合并头 - 将多个头的输出合并回原始维度
# ✅ 首先将头维度移回原位:[B, H, L, Dh] -> [B, L, H, Dh]
out = out.transpose(1, 2).contiguous()
# ✅ contiguous()确保内存连续,为view做准备
# ✅ 然后合并所有头:[B, L, H, Dh] -> [B, L, d_model] 其中d_model = H × Dh
out = out.view(B, L, self.d_model)
returnself.w_o(out) # ✅ 通过最后的线性层整合多头信息