Attention Residuals (AttnRes)

notes/research/arxiv/2026-03-16-1946-attention-residuals.md

Port 8777
path
notes/research/arxiv/2026-03-16-1946-attention-residuals.md
# Attention Residuals (AttnRes) ## 元数据 - **title**: Attention Residuals - **category**: research - **source_type**: paper - **created_by**: 小美虾 - **created_at**: 2026-03-16 19:46 - **status**: raw - **tags**: [Transformer, Attention, Residual, MoonshotAI, 架构优化] ## 来源 - **GitHub**: https://github.com/MoonshotAI/Attention-Residuals - **PDF**: https://github.com/MoonshotAI/Attention-Residuals/blob/master/Attention_Residuals.pdf - **机构**: MoonshotAI(月之暗面) ## 核心贡献 ### 问题背景 标准 Transformer 残差连接的问题: - 固定单位权重累加所有层输出 - 随着深度增加,均匀聚合稀释了每层的贡献 - 导致隐藏状态幅度无界增长(PreNorm 的已知问题) ### AttnRes 解决方案 **核心思想**: 用 softmax 注意力机制替代固定累加,使每层能够选择性聚合之前的表示 **公式**: ``` h_l = Σ_{i=0}^{l-1} α_{i→l} · v_i ``` 其中权重 α_{i→l} 通过每层单个学习到的伪查询 w_l ∈ ℝ^d 计算 **优势**: - 每层获得选择性、内容感知的早期表示访问 - 可学习的深度注意力机制 - 即插即用,可替代标准残差连接 ### 两种实现 | 类型 | 描述 | 内存复杂度 | |------|------|-----------| | **Full AttnRes** | 每层关注所有先前输出 | O(Ld) | | **Block AttnRes** | 层分组为块,块间注意力 | O(Nd) | **Block AttnRes**(实用方案): - 将层划分为 N 个块(约 8 个块) - 块内使用标准残差累加 - 块间应用注意力机制 - 恢复 Full AttnRes 大部分收益,开销极小 ## 实验结果 ### 计算效率 - Block AttnRes 在相同计算预算下优于基线 - 匹配基线 1.25x 计算量的损失表现 ### 基准测试对比 | 类别 | 基准 | Baseline | AttnRes | 提升 | |------|------|----------|---------|------| | **General** | MMLU | 73.5 | 74.6 | +1.1 | | | GPQA-Diamond | 36.9 | 44.4 | **+7.5** | | | BBH | 76.3 | 78.0 | +1.7 | | | TriviaQA | 69.9 | 71.8 | +1.9 | | **Math & Code** | Math | 53.5 | 57.1 | +3.6 | | | HumanEval | 59.1 | 62.2 | **+3.1** | | | MBPP | 72.0 | 73.9 | +1.9 | | **Chinese** | CMMLU | 82.0 | 82.9 | +0.9 | | | C-Eval | 79.6 | 82.5 | +2.9 | **最大提升领域**: - 多步推理:GPQA-Diamond +7.5 - 代码生成:HumanEval +3.1 ### 训练动态改善 - 缓解 PreNorm 稀释问题 - 输出幅度在深度上保持有界 - 梯度范数在各层间分布更均匀 ## 核心代码结构 ```python def block_attn_res(blocks: list[Tensor], partial_block: Tensor, proj: Linear, norm: RMSNorm) -> Tensor: """ 块间注意力:在块表示 + 部分和上进行注意力 blocks: N 个形状 [B, T, D] 的张量 partial_block: [B, T, D] 块内部分和 """ V = torch.stack(blocks + [partial_block]) # [N+1, B, T, D] K = norm(V) logits = torch.einsum('d, n b t d -> n b t', proj.weight.squeeze(), K) h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V) return h ``` ## 作者团队 Chen, Guangyu; Zhang, Yu; Su, Jianlin; Xu, Weixin; Pan, Siyuan; Wang, Yaoyu; Wang, Yucheng; Chen, Guanduo; Yin, Bohong; Chen, Yutian; Yan, Junjie; Wei, Ming; Zhang, Y.; Meng, Fanqing; Hong, Chao; Xie, Xiaotong; Liu, Shaowei; Lu, Enzhe; Tai, Yunpeng; Chen, Yanru; Men, Xin; Guo, Haiqing; Charles, Y.; Lu, Haoyu; Sui, Lin; Zhu, Jinguo; Zhou, Zaida; He, Weiran; Huang, Weixiao; Xu, Xinran; Wang, Yuzhi; Lai, Guokun; Du, Yulun; Wu, Yuxin; Yang, Zhilin; Zhou, Xinyu ## 发布年份 2026 ## 引用 ```bibtex @article{attnres2026, title = {Attention Residuals}, author = {Chen, Guangyu and Zhang, Yu and Su, Jianlin and ... and Zhou, Xinyu}, year = {2026}, url = {https://github.com/MoonshotAI/Attention-Residuals} } ``` --- *笔记由小美虾自动创建,待进一步调研和整理*