Hotdry.

Article

从 Gemini 蒸馏到 26M 工具调用模型:Simple Attention Network 架构与训练细节

解析 Needle 如何将 Gemini 3.1 Flash-Lite 的工具调用能力蒸馏到 26M 参数的 Simple Attention Network:无 FFN 设计、门控残差、INT4 QAT 正则化与 Token 级损失加权。

2026-05-12ai-systems

在大模型能力向边缘设备迁移的浪潮中,模型压缩始终是工程落地的核心挑战。Cactus Compute 开源的 Needle 项目提供了一个极具参考价值的案例:将从 Gemini 3.1 Flash-Lite 蒸馏出的工具调用能力,封装进一个仅 26M 参数的 Simple Attention Network(SAN)中,在消费级设备上实现本地微调与推理。本文聚焦该模型的核心架构设计、训练配方与工程实践,为需要在资源受限场景下部署工具调用能力的团队提供可落地的参数参考。

架构设计:无 FFN 的工具调用专用网络

传统 Transformer 的参数分布中,前馈网络(FFN)占据约 2/3 的参数量。然而 Needle 的设计哲学建立在这样一个洞察之上:工具调用本质上是一个检索与组装(Retrieval-and-Assembly)任务,而非通用语义建模。模型只需完成三件事 —— 将用户查询与工具名称对齐、提取参数值、组装 JSON 结构 —— 这三者都依赖注意力机制的输入输出对齐能力,而非 FFN 提供的逐位置非线性变换。

基于此判断,Needle 采用了 Encoder-Decoder 架构,Encoder 负责编码工具定义,Decoder 负责生成调用指令。关键设计参数如下:隐层维度 d=512,8 个注意力头配合 4 个 Key-Value 头(BPE 分词器词表 8192),Encoder 堆叠 12 层,Decoder 堆叠 8 层,Embedding 在编码器输入与解码器输出之间共享(即 tied embedding)。最核心的变化在于:Encoder 层内完全移除了 FFN,仅保留 ZCRMSNorm、配备 GQA+RoPE 的自注意力以及门控残差连接;Decoder 侧同样如此,仅在交叉注意力模块中与编码器输出交互。

Encoder-Decoder 架构的选择同样经过权衡。工具定义本身是结构化对象,双向编码使 Encoder 能一次性感知完整定义,无需像因果模型那样从左到右推断结构。此外,固定尺寸的编码器表示直接用于交叉注意力,避免了解码阶段每步重新 attending 完整输入的 KV 缓存开销,这对于边缘设备的内存带宽尤为友好。

门控残差与 ZCRMSNorm:注意力 only 网络的可训练性保障

移除 FFN 带来一个关键问题:每个注意力层只能对输入添加增量(Add),缺乏非线性重写能力。若仅使用标准残差 x = x + Attn(Norm(x)),深层网络的表达能力将严重受限;若完全放弃残差 x = Attn(Norm(x)),12+ 层网络的梯度传播将难以维持。Needle 采用的门控残差设计 x = x + sigmoid(gate) * Attn(Norm(x)) 提供了一个可学习的平衡点:每层有一个标量门控参数,初始化为 0(sigmoid (0) = 0.5),使训练起始阶段残差以半强度通过,模型可自主学习将有效层的门控推向 1、无用层推向 0,既保留了梯度高速公路,又允许灵活的特征精炼。

ZCRMSNorm(Zero-initialized Channel RMSNorm)则是门控残差的配套组件。标准 RMSNorm 为 x * gamma / RMS(x),gamma 初始化为 1;ZCRMSNorm 改为 x * (1 + gamma) / RMS(x),gamma 初始化为 0。在初始化阶段,整个注意力块表现为一个衰减的单位映射加上衰减的归一化注意力输出,没有任何组件以强偏置启动。Needle 还将 ZCRMSNorm 应用到 QK 投影上(即 QK-Norm),以提升训练稳定性。该设计思路来自 nGPT 与 DeepSeek-V3 系列工作,但在 Needle 的无 FFN 架构中扮演了更关键的初始化角色。

对比工具选择头:CLIP 风格的检索增强

工具调用场景常面临候选工具数量众多的情况,逐一进行生成前的交叉注意力成本高昂。Needle 引入了一个 CLIP 风格的对比较量选择头(Contrastive Tool Selection Head)来解决这一问题:编码器输出经过非填充位置均值池化后,依次通过 Dense (d_model/4) -> ReLU -> Dense (128) -> L2 归一化,生成单位向量表示。

该头的训练采用对称 CLIP 损失函数,每个 batch 中查询与对应的正例工具构成配对,批次内其他样本作为负例,通过余弦相似度对比学习。学习到的温度参数 log_temp 可调。推理阶段,查询与各工具候选分别编码后在共享嵌入空间计算余弦相似度,取 top-k 最相关工具送入主生成流程。该对比头与主交叉熵损失联合训练,权重系数为 0.1x,两者共享编码器参数,兼顾检索效率与生成精度。

训练配方:预训练 + 后训练的两阶段策略

Needle 的完整训练分为预训练与后训练两个阶段。预训练消耗 200B tokens,在 16 块 TPU v6e 上耗时约 27 小时;后训练使用 2B tokens 的单次工具调用数据集,训练时长仅 45 分钟。这一配比说明:通用语言建模能力通过大规模预训练建立,工具调用专项能力通过短时后训练快速注入。

优化器采用双优化器策略:注意力投影(Q/K/V/O)使用 Muon 优化器,学习率 0.02,权重衰减 0.01;其余参数使用 AdamW,学习率 3e-4。Muon 通过 Newton-Schulz 正交化约束权重更新,防止堆叠多层线性层(无 FFN 提供非线性间隔)时出现表征坍缩。该优化策略在注意力 only 架构中尤为重要。

INT4 QAT 作为正则化手段

Needle 采取了一个非常规但极具工程价值的设计:每隔 100 步进行一次伪量化(Fake Quantization)。权重按 group_size=32 进行 INT4 对称量化,通过直通估计器(STE)将梯度流经取整操作,前向传播使用量化后的权重计算。该设计的核心价值并非单纯压缩,而在于其正则化效果:量化噪声等价于在权重上注入噪声,防止小容量模型(本身已去除 FFN,参数更少)过度拟合训练分布。模型在训练全程始终以推理时的量化状态运行,消除了训练 - 推理量化精度落差(PTQ gap)。

Token 级损失加权:匹配误差分布的动态学习

标准交叉熵对所有 token 一视同仁,但工具调用任务的误差分布高度不均衡:JSON 结构错误最先被消除(模型在训练早期即达到约 99% 的 JSON 可解析率),真正的难点在于参数值的准确提取,其次是工具名称,再次是参数键名。Needle 的 token 级损失加权方案直接针对这一分布设计:基础 JSON 结构 token 权重 1.0x,工具名称 token 加权 2.0x,参数键名加权 1.5x,参数值 token 加权最高达 4.0x。此外,辅助损失包括 z-loss(用于 logit 稳定性)与 CLIP 对比损失(权重 0.1x),协同保障训练收敛。

推理性能与部署参数

在生产环境中,Needle 运行于 Cactus 推理引擎,Prefill 阶段吞吐达到 6000 tokens/sec,Decode 阶段 1200 tokens/decoding-step。模型权重完全开源于 HuggingFace(Cactus-Compute/needle),支持本地微调。通过 needle playground 命令可启动 Web UI 界面进行交互测试与一键微调;needle finetune data.jsonl 支持 CLI 方式的本地数据微调。测试表明,该模型在单次工具调用任务上优于 FunctionGemma-270m、Qwen-0.6B、Granite-350m 与 LFM2.5-350m 等更大规模基线,但适用场景限于单次调用导向,不适合长程对话设置。

工程落地参考要点

对于需要在自有工具集上部署类似蒸馏模型的团队,以下参数值得关注:隐层维度 512 是小模型的效率 - 能力平衡点;无 FFN 架构可将每层参数量削减约 2/3,显著缓解边缘设备的内存带宽瓶颈;门控残差的门控参数建议初始化为 0(即 sigmoid (0) = 0.5),配合 ZCRMSNorm 的 gamma=0 初始化;INT4 QAT 的 group_size 设为 32、伪量化频率 100 步是经过验证的配置;Token 级损失加权中参数值 4.0x、工具名称 2.0x 的比例可作为初始超参基准,根据实际错误分布动态调整。

资料来源:GitHub - cactus-compute/needle (https://github.com/cactus-compute/needle) 及 Simple Attention Networks 文档 (https://github.com/cactus-compute/needle/blob/main/docs/simple_attention_networks.md)。

ai-systems

内容声明:本文无广告投放、无付费植入。

如有事实性问题,欢迎发送勘误至 i@hotdrydog.com