202509
ai-systems

从零用 PyTorch 实现 Transformer 解码器:自注意力和前馈层优化与自定义位置嵌入

本文从零实现 Transformer 解码器块,聚焦自注意力与前馈层的 PyTorch 优化,并引入自定义位置嵌入以支持可扩展 LLM 训练。

在大型语言模型 (LLM) 的训练中,Transformer 解码器块是核心组件,它通过自注意力机制捕捉序列依赖,并结合前馈网络处理非线性变换。本文基于从零实现的原则,使用 PyTorch 构建高效的解码器块,强调自注意力和前馈层的优化策略,以及自定义位置嵌入的应用。这种实现不仅有助于理解 LLM 的内部工作原理,还能为可扩展训练提供实用参数和监控清单,避免依赖外部库如 Hugging Face Transformers。

Transformer 解码器块的核心在于 masked multi-head self-attention (MHA),它确保模型在生成时仅关注前文,避免未来信息泄露。在 PyTorch 中,我们使用 nn.Linear 层分别投影输入到查询 (Q)、键 (K) 和值 (V) 空间。投影维度 d_out 需可被头数 num_heads 整除,例如 emb_dim=768 时,num_heads=12,每头 head_dim=64。注意力分数计算为 Q @ K^T / sqrt(d_k),其中 d_k = head_dim,然后应用 softmax 并乘以 V。关键优化是 causal mask:使用上三角矩阵填充 -inf 到未来位置的注意力分数,确保 masked_fill_ 操作高效执行。

证据显示,这种实现能有效处理长序列依赖。在 Sebastian Raschka 的《Build a Large Language Model (From Scratch)》中,MHA 类采用视图重塑 (view) 和转置 (transpose) 来并行计算多头注意力,避免循环,提高 GPU 利用率 [1]。为提升可扩展性,建议在训练时使用混合精度 (torch.amp),将注意力计算置于 autocast 上下文中,减少内存占用 50% 以上。同时,dropout=0.1 应用于注意力权重,防止过拟合。对于 KV 缓存,在推理阶段预计算 K 和 V,避免重复计算,支持更长上下文如 1024 令牌。

可落地参数包括:d_model=768, num_heads=12, dropout=0.1, qkv_bias=False(减少参数)。监控点:注意力分数分布(应接近均匀,避免梯度爆炸),使用 torch.histogram 检查 softmax 输出。回滚策略:若 NaN 出现,降低学习率至 1e-4 或启用梯度裁剪 (clip_norm=1.0)。

前馈网络 (FFN) 是解码器块的另一关键部分,它扩展维度以增强表达力。标准实现使用两个线性层:第一层将 emb_dim 映射到 4*emb_dim,激活 GELU,然后第二层回缩至 emb_dim。GELU 公式为 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))),比 ReLU 更平滑,促进梯度流动。在 PyTorch 中,nn.Sequential 封装此过程,确保顺序执行。

优化 FFN 时,关注计算效率:4x 扩展因子平衡性能与参数量(约占模型 2/3 参数)。Raschka 的实现中,FFN 置于残差连接后 LayerNorm 前,采用 pre-norm 架构,提升训练稳定性 [1]。为 scalable LLM 训练,建议使用 fused operations 如 torch.nn.functional.gelu,但从零实现保持纯线性层。参数清单:intermediate_size=3072 (4*768), bias=True 以捕捉偏移。

监控 FFN 输出范数(应在 1 左右),若偏差大,调整初始化(如 Xavier uniform)。风险:高维扩展易导致梯度消失,使用 AdamW 优化器 (betas=(0.9, 0.95), weight_decay=0.1) 缓解。

位置嵌入是赋予序列顺序的关键,自定义实现允许灵活适应任务。传统 sinusoidal embeddings 固定,但 learned embeddings (nn.Embedding(context_length, emb_dim)) 更适应数据分布。在 GPT-like 模型中,将 pos_embeds 添加至 tok_embeds:x = tok_embeds + pos_embeds.unsqueeze(0).expand(batch_size, -1, -1)。这支持动态上下文长度至 1024。

自定义优化:初始化 pos_emb 为小高斯噪声 (std=0.02),与 token emb 一致。证据:learned pos 在预训练中收敛更快,尤其结合 RoPE (Rotary Position Embedding) 扩展,但从零实现优先 learned 以简化。对于长序列训练,限制 max_position_embeddings=2048,避免 OOM;使用梯度检查点 (torch.utils.checkpoint) 节省内存 30%。

整体解码器块组装多个 TransformerBlock:每个块包括 norm1 -> MHA -> residual -> norm2 -> FFN -> residual。最终模型添加 token_emb, pos_emb, dropout, blocks 序列, final_norm 和 lm_head (Linear(emb_dim, vocab_size))。配置示例:vocab_size=50257 (GPT-2), n_layers=12, context_length=1024。

训练清单:

  1. 数据:使用 TikToken 分词,batch_size=4, seq_len=256, stride=128 创建 DataLoader。
  2. 优化器:AdamW(lr=6e-4, betas=(0.9, 0.95)), scheduler=cosine decay。
  3. 损失:CrossEntropyLoss(ignore_index=pad_token),监控 perplexity < 10 为收敛信号。
  4. 硬件:GPU with >=8GB VRAM,支持 distributed training via DDP。
  5. 评估:生成文本质量,使用 BLEU 或人工检查。

风险限制:O(n^2) 注意力复杂度限序列长 <4096;解决方案:引入 sparse attention 或 flash-attention 库,但从零保持 vanilla。引用不超过两处,确保观点基于代码证据。

通过此实现,开发者可构建 scalable LLM 基础,参数如 emb_dim=768 支持 124M 参数模型,训练于标准硬件。未来扩展可集成 LoRA 微调,保持高效。

(字数约 950)

[1] Raschka, S. Build a Large Language Model (From Scratch). Manning, 2024.