Hybrid LLM 之 Gated Attention

Qwen3-Next[1] 发布后,算是真正开启了 hybrid 序幕,原本还想着后面再慢慢补这块,现在看来是不行了,得提前了。好在东西也不多,我们就借着这次机会过一轮吧。

这是第一篇,我们简单点,从 Gated Attention 开始,来自 Paper:Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free[2],5 月份的一篇论文了,官方 GitHub[3] 关注的人不多,没想到这就成了 Qwen 新版本的标准配置了。

论文本身是比较简单的,就是在标准的 attention 后面加一个 sigmoid 激活门,就能始终提升效果(还能提升训练稳定性)。经过尝试各种位置和变体后,将其有效性归结为两个因素:在 softmax 注意力的低秩映射上引入非线性;使用依赖于 query 的稀疏门控分数来调节 SDPA 输出。另外,这种稀疏门控机制能够缓解“attention sink”问题,并提升长上下文外推能力。

关于 attention sink 可以查看:https://yam.gift/2025/08/06/NLP/2025-08-06-gpt-oss/

代码非常简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 来自transformer library: transformers/models/qwen3_next/modeling_qwen3_next.py
class Qwen3NextAttention(nn.Module):
def __init__(self, config: Qwen3NextConfig, layer_idx: int):
...
# 注意这里*2
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
)
...
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# 分出来gate和query_state
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
...
# gate sigmod后乘以attn分数
attn_output = attn_output * torch.sigmoid(gate)

好了,这篇文章到这里其实可以结束了,后面的可以不看了,主要是分析具体机理了。

实验验证

了解 NLP 历史的同学应该知道,门控机制最早出现在 RNN 的变体:LSTM[4] 中,它包括三个门:遗忘门、输入门和输出门。大模型时代,SSM 和一些 attention 机制也应用门机制,但对其功能和影响研究的并不深入。

文章举了 Switch Heads 的例子,它通过 sigmoid 选择前 K 个注意力头专家。实验结果表明,即便简化为仅包含单一专家,且门控仅作用于调节 value 输出时,仍能获得显著的性能提升。这说明门控机制本身提供了重要的内在价值。

本文仔细探讨了门控机制,在不同位置引入并进行评测,如下图所示:

结果表明:G1 能带来显著提升。最右边的图是 loss。

除了位置,还对比研究了:

  • 粒度
    • 逐 head:单个标量门控分数调节整个注意力 head 的输出。
    • 逐元素 (Elementwise):门控分数为与 attention 维度相同的向量,实现逐维度的精细调节。
  • 专用或共享 head
    • 专用 head:每个注意力头具有独立的门控分数,实现对各头的独立调节。
    • 共享 head:权重和门控分数在所有注意力 head 之间共享。
  • 乘性或加性
    • 乘性:Y′=Y⋅σ(Xθ)
    • 加性:Y′=Y + σ(Xθ)
  • 激活函数
    • SiLU 用于加性。
    • Sigmoid 用于乘性。

结果如下:

我们先看参数情况。模型基本配置如下:

  • 128 experts,激活 8 门控细粒度 experts
  • head_dim=128, q=32, k=4, kv_groups=32/4=8
  • hidden_dim = 128×32=4096
  • 层数:24

一个 Qwen3-30B-A3B 的 24 层变体,参数估计如下:

1
2
3
4
5
6
7
8
9
(24*(
1*(2048*128) + \ # 专家激活
128*3*(2048*768) + \ # 专家
2*2048 + \ # layer_norm
2*128 + \ # qk_norm
2*(2048*32*128) + \ # q, o
2*(4*128*2048) # k, v
) + \
1*2048+2048*151936*1)/1e9 = 15.26606233615B

不过按此配置,激活值仅有 1.7B(而不是论文里提的 2.54B):

1
2
3
4
5
6
7
8
9
(24*(
1*(2048*128) + \
8*3*(2048*768) + \
2*2048 + \
2*128 + \
2*(2048*32*128) + \
2*(4*128*2048)
) + \
1*2048+2048*151936*1)/1e9 = 1.6765173761.7B

关于这点和作者邮件确认了下,不过还没有回复。

下面逐个分析每种配置增加的参数:

  • (2) k=8:24*2*(8-4)*128*2048 = 50331648 ≈ 50m
  • (3) q=48: 24*2*2048*(48-32)*128 = 201326592 ≈ 201m
  • (4) 4experts: 24*4*3*2048*768 = 452984832 ≈ 450m,和表格有点出入。
  • (5, 8) G1: 24*32*128*2048 = 201326592 ≈ 201m
  • (6, 7) G2=G3: 24*4*128*2048 = 25165824 ≈ 25m
  • (9) G5: 24*2048*2048 = 100663296 ≈ 100m
  • (10) HW G1: 24*32*2048 = 1572864 ≈ 1.6m
  • (11) HW G2: 24*4*2048 = 196608 ≈ 0.2m
  • (12, 13) HS G1, G2: 同 5 6,取了平均。
  • (14, 15) Activation: 同 G5。

有个位置有点出入,不知道是作者笔误还是我的计算有误。根据上面表格里的实验结果:

  • G1 和 G2 其实都不错(第 5、6 行)、G1 的 PPL 更低一些。

  • G1 和 G2 逐头门控仅引入极少额外参数,但带来显著提升(10、11 行)。不同注意力头分配独立门控分数比较重要(12 对比 10,13 比 11)。

  • 乘性优于加性(5 比 14)。

  • Sigmoid 比 SiLU 更好(5 比 15)。

另外在 dense model 上,门控在多种设定下也均有效,而且能够提升稳定性并促进可扩展性。

为什么?

前面说了,两个因素:非线性和稀疏性。

非线性

在 MHA 中,第 i 个 token 在第 k 个 head 的输出如下:

oik=(j=0iSijkXjWVk)WOk=j=0iSijkXj(WVkWOk)(1)o_i^k=\left(\sum_{j=0}^i S_{i j}^k \cdot X_j W_V^k\right) W_O^k=\sum_{j=0}^i S_{i j}^k \cdot X_j\left(W_V^k W_O^k\right) \tag{1}

Wo^k 表示输出层中与第 k 个 注意力头相关的参数。S 是第 k 个 head 上,第 i 个 token 对第 j 个 token 的 attention。X_j 是 token j 的注意力输入,X_j W_v^k 表示 token j 在第 k 个头的值输出。

其中:

  • X_j (d_model) 是第 j 个 token 的输入一维向量。
  • W_v^k (d_model, dk) 是第 k 个注意力头的 value 投影矩阵。
  • X_j W_v^k (dk) 是第 j 个 token 在第 k 个注意力头的 value 向量。
  • S_ij 是第 i 个 token 对第 j 个 token 在第 k 注意力头注意力权重(标量)。
  • Σ S_ij X_j W_v^k (dk) 是加权后的 value 向量,其实就是在 j 个 token 上加权求和。
  • W_o^k (dk, d_model) 是第 k 个头的 Output 投影矩阵。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np

def softmax(x):
e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) # 防止溢出
return e_x / np.sum(e_x, axis=-1, keepdims=True)

def self_attention(Q, K, V):
scores = Q @ K.T # (n,d_k) @ (d_k,n) = (n,n)
d_k = Q.shape[1]
scores /= np.sqrt(d_k)
attention_weights = softmax(scores) # (n, n)
output = attention_weights @ V # (n,n) @ (n, d_k) = (n, d_k)
return output

def multi_head_attention(X, W_Q, W_K, W_V, W_O, k):
"""
X: (n, d)
W_Q, W_K, W_V: (d, d_k)
W_O: (d_k, d)
"""
d = X.shape[1]
# 对第k个头计算 Self-Attention
Q = np.dot(X, W_Q[k]) # (n, d) @ (d, d_k) = (n,dk)
K = np.dot(X, W_K[k]) # 同上
V = np.dot(X, W_V[k]) # 同上
head = self_attention(Q, K, V) # (n, d_k)
output = head @ W_O # (n, d_k) @ (d_k, d) (n, d)
return output

# n是序列长度!X是hidden state
np.random.seed(42)
n = 1
d = 8
X = np.random.rand(n, d)
num_heads = 2
d_k = d // num_heads

W_Q = np.random.rand(num_heads, d, d_k) # (num_heads, d, d_k)
W_K = np.random.rand(num_heads, d, d_k)
W_V = np.random.rand(num_heads, d, d_k)
W_O = np.random.rand(d, d) # (d, d)

output0 = multi_head_attention(X, W_Q, W_K, W_V, W_O[:d_k, :], 0) # (L, d), head0
output1 = multi_head_attention(X, W_Q, W_K, W_V, W_O[d_k:, :], 1) # (L, d), head1

output = output0 + output1 # 分块求和,这和标准attention结果是一样的

这里其实类似分块矩阵,和标准attention的结果是一样的,即:

Concat(head1,...,headh)WO=1hheadiWiO(2)\text{Concat}(\text{head}_1, ..., \text{head}_h)W^O = \sum_1^h \text{head}_i W_i^O \tag{2}

好了,言归正传,由于 head_dim < hidden_dim,可以将 Wv^kWo^k 合并为一个作用在所有 Xj 上的低秩线性映射(如果使用 GQA,Wv 在分组内共享,进一步压缩了参数,会加剧“低秩”问题)。鉴于在两个线性映射之间引入非线性能够提升其表达能力(来自 paper:On the Number of Linear Regions of Deep Neural Networks[5]),有 2 个可以缓解低秩问题的改动:

oik=(j=0iSijk Non-Linearity-Map (XjWVk))WOk(3)o_i^k=\left(\sum_{j=0}^i S_{i j}^k \cdot \text { Non-Linearity-Map }\left(X_j W_V^k\right)\right) W_O^k \tag{3}

oik= Non-Linearity-Map (j=0iSijkXjWVk)WOk(4)o_i^k=\text { Non-Linearity-Map }\left(\sum_{j=0}^i S_{i j}^k \cdot X_j W_V^k\right) W_O^k \tag{4}

式(3)对应 G2,式(4)则对应 G1。这同时解释了为啥 G5 没效果,因为它没有解决 WvWo 之间缺乏非线性的问题。总之,有效门控变体所带来的性能提升很可能归因于在 WvWo 之间引入了非线性。

稀疏性

在检查了 G1 和 G2 的门控分数后发现:

  • 有效的门控分数是稀疏的。G1 最低,并且在接近 0 的区域高度集中。
  • 注意力头特定的稀疏性很重要。强制多个注意力头共享门控分数会导致分数上升,同时性能提升减弱。
  • Query 依赖性很重要。G2 分数普遍高于 G1,性能更差。说明门控分数依赖 Query 时稀疏更有效,而不是由 key 和 value 决定的。意味着门控分数的稀疏性可能会过滤掉与当前 Query 无关的上下文信息。引入输入无关门控(上表第 6 行),结果有所提升(可能是非线性导致),但门控分数比较高。进一步表明,有效的稀疏性应当是 Query 依赖的。
  • 稀疏性不足的门控更差。将 sigmoid 换成非稀疏版本(上表第 7 行):NS-sigmoid(x) = 0.5+0.5⋅σ(x),引入非线性但没有稀疏性,能增益劣于 SDPA 输出的 sigmoid 门控(上表第 1 行)。

总的来说,sigmoid 引入的稀疏性是有效的另一个原因。

正外部性

如论文介绍,Gated Attention 还有两个正外部性。

缓解Attention-Sink

我们开头就提到了,在关于gpt-oss那些值得关注的点 | Yam[6] 中也做了分析。本文主要分析了注意力分数的分布(在所有注意力头上取平均),以及分配给首个 token 的注意力分数(上表 F-Attn 列),以及各层最大隐藏状态激活值的均值(上表 M-Act 列,已取整)。结果如下:

  • G1 显著降低首 token 注意力分数,同时减少大规模激活。
  • G2 同样会降低大激活值,但无法降低首 token 的注意力分数。另外,headwise 门控很重要,同时即便没有大激活值(上表第 4 行),依然有 attention sink,说明大激活值并不是其前提条件。
  • 当降低门控的输入(Query)依赖性(上表第 6 行),或采用 NS-sigmoid 减少稀疏性(上表第 7 行)时,会增加大激活值与 attention sink。

其实这点是很容易理解的,这种解决方案与 gpt-oss 的那个 bias 作用其实类似,让某些 token 的注意力可以为 0,而不是强行分配最终导致 sink。

关于 attention sink 现象,我们直接取来自 关于gpt-oss那些值得关注的点 | Yam[6] 中的原话:模型将极大比例的注意力权重集中在序列的第一个 token(通常是 <bos> 开始 token),即使它对下文语义并不重要。其原因当然也和softmax有关,即使当前 token 与前面的其他 token 并无关联,模型仍需要分配“冗余”的注意力分数。此时,最容易被“牺牲”的位置就是第一个 token,因为它对所有后续 token 都可见,容易被训练为 attention sink。关于这点,相信只要大家做过Attention分析就应该深有体会,无论哪一层,attention分数最大的几乎都是首Token。

有助于上下文长度扩展

实验观察结果如下:

  • 32k 设置下,带门控的模型表现略优于基线模型。说明在训练长度范围内,attention sink 现象可能并未显著损害模型的长上下文性能。
  • 使用 YaRN 将上下文长度扩展到 128k 时,基线模型与门控模型在原有的 32k 范围内性能均有所下降。不过,对于带门控的模型,这种下降趋势较不明显。而且在 64k128k 上下文长度下,带门控的注意力模型显著优于基线模型。

文章推测:添加门控有助于模型更好地适应上下文长度的扩展。一种可能的解释是:基线模型依赖 attention sink 来调节注意力分数的分布,当通过 YaRN 等方法修改 RoPE 基数时,attention sink 模式可能难以在“免训练”的情况下自适应,导致性能显著下降。相反,带门控的模型主要依赖输入依赖的门控分数来控制信息流,因此对这类变化表现出更强的鲁棒性。

个人觉得这个解释有一定道理。

扩展一下

非LLM时代的Gated Attention

其实,第一次看到 gated attention 去搜索后,除了这篇文章,还有另外一篇 19 年针对普通神经网络(非 LLM)的文章:Not All Attention Is Needed: Gated Attention Network for Sequence Data[7],对应的代码:GitHub [8]

它的观点如下:传统的注意力机制会关注输入句子的整个隐藏状态序列,但在大多数情况下,并不需要关注所有隐藏状态,尤其是对于长序列。文章提出了一种名为门控注意力网络(GA-Net)的新方法,该方法使用辅助网络动态地选择需要关注的元素子集,并计算注意力权重来聚合所选元素。这种方法避免了对未关注元素进行大量不必要的计算,并使模型能够关注序列中的重要部分。

如图所示,结合前面的观点,可以看到两篇文章的思想其实是一样的,只不过用在不同网络上。另外,这里用的是 sigmoid 后的 0 或 1,而不是连续值,如下式所示。不过,前文的参考文献中并没有看到这篇文章;)

ht=LSTM(ht1,xt)pt=sigmoid(Uht)(5)\begin{array}{r} h_t^{\prime}=L S T M\left(h_{t-1}^{\prime}, x_t\right) \\ p_t=\operatorname{sigmoid}\left(U h_t^{\prime}\right) \end{array} \tag{5}

看这个做法,是不是换成一个 dropout 也是有效果的,说到 dropout 就想起了曾经的 简单的对比学习框架:SimCSE | Yam[9],感觉好像都是上古时期的工作了……

非Attention的Gated

其实,除了上面这篇,还很容易让人想到另外几篇文章,没错,你可能已经猜到了,就是 GLU[10]Swish[11]。GLU 由两个线性投影的逐元素乘积构成,其中一个投影会先经过 sigmoid 函数。Swish 则是输入直接乘自己的 sigmoid。

GLU(x)=(xW+b)σ(xV+c)Swish(x)=xσ(βx)(6)\begin{array}{r} \operatorname{GLU}(x)=(x W+b) \odot \sigma(x V+c) \\ \operatorname{Swish}(x) = x \cdot \sigma(\beta x) \end{array} \tag{6}

除此之外,还有来自 Noam Shazeer 的 GLU 变体:GLU Variants Improve Transformer[12],作用的就是 Transformer 的 FFN,各种变体如下所示:

在 GLUE 多个任务实验结果(按平均)显示,ReGLU 最好,SwiGLU 紧随其后,然后是 GLU。如果按 TOP1 的任务数,排在第一位的是 SwiGLU(5/12),然后是 Bilinear(3/12)。而且,在 SuperGLUE 上 SwiGLU 最好。可能这也是为什么后面用 SwiGLU 的原因吧。

最有意思的是论文的结论部分:我们扩展了 GLU 家族的层结构,并提出了它们在 Transformer 中的应用。在迁移学习的设置下,这些新变体在预训练所用的去噪目标上表现出更低的困惑度,同时在许多下游语言理解任务上也取得了更好的结果。这些架构实现简单,且没有明显的计算代价。至于它们为何有效,我们并没有给出解释,只能像其他一切成功一样,将其归因于“上天的恩赐”。最后这句英文原话是:We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence.

哈哈,这个作者太逗了,后面搜了一下,发现他居然也是 MQA 的作者,论文在这里:Fast Transformer Decoding: One Write-Head is All You Need[13]。如果你不知道什么是 MQA,那肯定知道 GQA,这可是现在 LLM 的标配。GQA 就是在 MHA 和 MQA 之间取了个折中,如下图(来自GQA[14]论文)所示:

值得一提的是,Noam Shazeer 的这两篇论文都是独立作者,看起来更像是随手记了个 note……

小结

考虑到 Qwen 目前的全球影响力,Qwen3-Next 算是正式大规模开启了混合架构。我们也从 Qwen3-Next 的 Gated Attention 出发,开启混合架构的学习。Gated Attention 虽然思想和实现都非常简单,但其实分析起来还是有不少细节的。论文将其有效的原因归结为引入的非线性和稀疏性,根据实验结果显示,是比较有说服力的。除此之外,论文还提到两个正外部性:缓解 attention sink 和助力上下文长度扩展,根据实验结果看,确实是一个不错的机制。可以目测,Gated Attention 应该会成为标配。

在正式介绍为论文内容后,我们又稍微扩展了一下,一个是思想一致的非 LLM 时代的 Gated Attention,不过它把 sigmoid 后的值做了二值化操作。另一个是 FFN 上的 Gated 变体,实验结果也是显示有效,作者幽默地将其归为 “上天的恩赐”,他也是 MQA 的作者。

没想到又巴拉巴拉说了这么多,就到这里吧。

References

[1] Qwen3-Next: https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list
[2] Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free: https://arxiv.org/abs/2505.06708
[3] GitHub: https://github.com/qiuzh20/gated_attention
[4] LSTM: https://yam.gift/2019/06/17/NLP/SLP/2019-06-17-Ch09-Senquence-Processing-with-Recurrent-Networks/
[5] On the Number of Linear Regions of Deep Neural Networks: https://arxiv.org/abs/1402.1869
[6] 关于gpt-oss那些值得关注的点 | Yam: https://yam.gift/2025/08/06/NLP/2025-08-06-gpt-oss/
[7] Not All Attention Is Needed: Gated Attention Network for Sequence Data: https://arxiv.org/abs/1912.00349
[8] GitHub : https://github.com/keya-desai/Gated-Attention
[9] 简单的对比学习框架:SimCSE | Yam: https://yam.gift/2021/07/10/Paper/2021-07-10-SimCSE/
[10] GLU: https://arxiv.org/abs/1612.08083
[11] Swish: https://arxiv.org/abs/1710.05941
[12] GLU Variants Improve Transformer: https://arxiv.org/abs/2002.05202
[13] Fast Transformer Decoding: One Write-Head is All You Need: https://arxiv.org/abs/1911.02150
[14] GQA: https://arxiv.org/abs/2305.13245