在端侧 AI 场景中,如何用极小的模型完成精准的工具调用(Tool Calling),一直是工程化落地的核心挑战。传统方案往往依赖百兆甚至更大的语言模型,再通过提示工程或检索增强来驱动工具选择。这在服务器端可以接受,但在手机、手表、AR 眼镜等资源受限的设备上,内存占用与推理延迟就成了无法忽视的瓶颈。Needle 项目正是在这一背景下登场:它将 Gemini 3.1 的工具调用能力蒸馏为一个仅有 26M 参数的 Simple Attention Network,在单次调用(Single-shot)场景下超越 FunctionGemma-270m、Qwen-0.6B、Granite-350m 等模型,且能够在消费级 Mac/PC 上完成本地微调。本文将从架构设计、训练策略、工程权衡三个层面,详细拆解这一轻量化方案的技术路径。
移除 FFN:注意力机制足以完成工具路由
理解 Needle 的起点,是它对标准 Transformer 架构的核心假设提出了质疑:对于工具调用这类结构化输出任务,前馈网络(Feed-Forward Network)是否真的必要?
传统 Transformer 中,约三分之二的参数来自 FFN 层。这些层在每个位置执行非线性特征变换,被认为是模型表达能力的核心来源。然而,Needle 团队在实验中观察到:当模型规模缩小到 50M 参数以下时,FFN 参数的边际收益急剧下降。原因在于工具调用本质上是一个检索 - 组装过程 —— 模型需要将用户查询与工具名称对齐、提取参数值、组装 JSON 结构,这三个步骤都可以理解为输入与输出之间的对齐与复制操作,而非对每个位置的语义进行独立重写。注意力机制天然适合处理对齐任务,而 softmax 本身已经提供了数据依赖的非线性混合能力,因此 FFN 的非线性变换在工具调用场景中并非不可替代。
更重要的是,移除 FFN 直接削减了约三分之二的层间参数,这些参数正是内存带宽瓶颈的主要贡献者。在边缘设备上,推理延迟往往由内存带宽而非算力主导,因此减少参数体积可以直接转化为更低的延迟和更小的内存占用。Needle 的实验数据显示,移除 FFN 后,模型在 ARM CPU 上的推理速度提升了数倍,同时保持了与更大模型相当的工具调用准确率。
架构设计:门控残差与 ZCRMSNorm
既然移除了 FFN,模型就需要在残差连接和归一化层上做出更精细的设计,否则 12 层以上的深度堆叠会导致训练不稳定甚至梯度消失。Needle 采用了门控残差(Gated Residual)和 ZCRMSNorm 两大关键设计来解决这一问题。
门控残差的核心思想是为每个子层的输出引入一个可学习的标量门控:x = x + sigmoid (gate) * Attn (Norm (x))。在训练初期,门控被初始化为 0,此时 sigmoid (0) = 0.5,意味着每个注意力层只贡献一半的更新强度。这给了模型一个安全的起点 —— 即使某些层尚未收敛,其输出也不会完全覆盖原始输入的语义。随着训练的进行,模型可以学习将有用层的门控推向 1(完整激活),将无用层的门控推向 0(抑制输出),从而实现一种自适应的层间权重分配。这种机制与 nGPT 和 DeepSeek-V3 系列工作中的做法一脉相承,但 Needle 进一步将它应用到了无 FFN 的注意力 - only 架构中。
ZCRMSNorm 则是对标准 RMSNorm 的改进:标准 RMSNorm 的公式为 x * gamma / RMS (x),其中 gamma 初始化为 1;而 ZCRMSNorm 改为 x * (1 + gamma) / RMS (x),gamma 初始化为 0。在训练初始阶段,ZCRMSNorm 退化为一个仅做尺度归一化的恒等变换,与门控残差配合使用时,整个注意力块起始于一个衰减的恒等映射加衰减的归一化注意力。这种设计确保了没有任何组件在训练开始时就带有强偏置,为后续的梯度优化提供了一个干净的起点。Needle 还将 ZCRMSNorm 应用于 QK 投影(QK-norm),以进一步提升训练稳定性。
编码器 - 解码器架构:工具定义的双向建模
与许多小型语言模型采用纯解码器架构不同,Needle 选择了编码器 - 解码器结构。这一选择的考量在于:工具定义本身是结构化的对象,包含名称、描述、参数模式等多维信息,双向编码器能够一次性看到工具的完整定义,而因果解码器只能从左到右逐步推断结构,需要额外的上下文来弥补信息缺口。
编码器部分包含 12 层,使用分组查询注意力(GQA)和旋转位置编码(RoPE),但同样移除了 FFN 层。编码器的输出不仅驱动解码器的交叉注意力,还同时服务于对比工具选择头(Contrastive Tool Selection Head)—— 这是一个 CLIP 风格的检索头,通过对称对比损失在共享嵌入空间中学习将查询和工具投影到可比较的表示。训练时使用批次内的负样本作为对比信号,损失权重为主交叉熵损失的 0.1 倍。推理时,模型先将用户查询和所有候选工具分别编码为归一化向量,再通过余弦相似度排序选出 top-k 最相关的工具送入生成阶段。这一设计在工具集合较大时尤为有效,可以显著减少生成阶段的输入长度。
解码器部分包含 8 层,交替使用带掩码的自注意力(包含 RoPE)和交叉注意力。所有层共享同一个输出嵌入与线性层(tied embedding),最终输出工具调用结果和文本答案。
训练策略:数据生成与损失设计
Needle 的训练分为预训练和后训练两个阶段。预训练阶段在 16 块 TPU v6e 上消耗 200B tokens,耗时约 27 小时;后训练阶段则聚焦于单次工具调用数据集,仅用 2B tokens 训练 45 分钟即可达到可用状态。这种极快收敛的底气来自于数据合成策略:Needle 使用 Gemini 生成高质量的单次工具调用数据,确保每个样本都是清晰的查询 - 工具 - 参数三元组。
训练中采用了双优化器配置:Muon 优化器专门负责 Q/K/V/O 投影,学习率设为 0.02,权重衰减为 0.01;AdamW 负责其余参数,学习率为 3e-4。Muon 优化器的特点是通过 Newton-Schulz 正交化步骤约束权重更新方向,防止堆叠大量线性层时出现表征崩溃。对于一个纯注意力网络来说,这种正交性约束尤为重要 —— 因为缺少 FFN 提供的非线性重启机制,线性层的权重退化更容易导致表示坍缩。
INT4 量化感知训练(QAT)被用作正则化手段。每 100 步进行一次模拟量化,将权重按组(group_size=32)量化为 INT4 对称格式,直通估计器(STE)让梯度能够流过取整操作。这种做法带来了双重好处:量化噪声本身是一种权重扰动正则化,可以降低小模型的过拟合风险;同时模型在训练时就适应了推理时的量化表示,消除了训练 - 部署之间的量化差距。
损失函数层面,Needle 采用了多级 token 加权策略:JSON 结构 token 权重 1.0x,工具名称 token 权重 2.0x,参数键 token 权重 1.5x,参数值 token 权重 4.0x。加权策略与实际错误分布高度吻合 —— 模型在早期训练即可达到 99% 的 JSON 解析正确率,而真正的难点集中在参数值的生成质量上。因此将更多监督信号投向值 token,是提升整体可用性的关键。除此之外还加入了辅助的 z-loss 用于 logit 稳定性,以及前文提到的 CLIP 对比损失。
推理部署:速度与精度的工程权衡
生产环境中,Needle 运行在 Cactus 推理引擎上,实现了 6000 tokens / 秒的预填充速度和 1200 tokens / 秒的解码速度。这一性能得益于两个因素的叠加:模型体积的极致压缩(26M 参数在 INT4 量化下仅需约 13MB 存储),以及对 FFN 的彻底移除(削减了最大的矩阵乘法维度)。在边缘设备的 CPU 上推理时,内存带宽不再是瓶颈,因为整个模型可以完全容纳在 L2/L3 缓存中。
部署时建议关注以下参数阈值:预填充延迟目标应低于 50ms(针对典型工具描述的输入长度),解码延迟目标应低于 100ms/token(针对 JSON 结构化输出);内存占用在 INT4 量化下应控制在 20MB 以内,留出足够空间给输入 token 和 KV 缓存。对于需要多工具候选的场景,对比检索头的 top-k 值建议设为 4 到 8 之间,过小会遗漏相关工具,过大则增加编码开销。
本地微调方面,Needle 提供了开箱即用的 CLI 工具。通过 needle playground 可以启动 Web UI,在浏览器中上传自己的工具定义并进行交互式测试与微调;通过 needle finetune data.jsonl 可以在本地自动下载权重并完成微调,训练结果以 pkl 格式保存。数据格式采用标准的 JSONL,每行包含 query、tools 和 expected_output 三个字段。由于模型规模小,单卡消费级 GPU 甚至高端 Mac 的 CPU 都能完成完整的微调流程,非常适合需要针对特定业务工具集进行定制的场景。
关键参数速查
以下是在实际使用 Needle 时最需要关注的几个可配置参数及其推荐范围。工具集规模直接影响编码器输入长度和交叉注意力成本,当工具数量超过 20 个时,建议启用对比检索头并设置 top-k 为 8,以平衡召回率与延迟。参数值权重 4.0x 是经过实验验证的最优值,降低该权重会导致 JSON 解析率下降,提升则可能挤压其他 token 的学习空间。门控残差的初始化值固定为 0,不建议修改,因为这直接控制了训练初期的等效深度。量化感知训练间隔 100 步是精度与训练速度的折中,间隔过短会增加计算开销,过长则削弱正则化效果。
资料来源
本文技术细节主要来源于 Needle 官方 GitHub 仓库(https://github.com/cactus-compute/needle)及配套的 Simple Attention Networks 文档(https://github.com/cactus-compute/needle/blob/main/docs/simple_attention_networks.md)。
内容声明:本文无广告投放、无付费植入。
如有事实性问题,欢迎发送勘误至 i@hotdrydog.com。