202510
ai-systems

从零实现最小 Transformer LLM

使用 PyTorch 从头构建小型 Transformer 语言模型,包括自定义 BPE 分词器、GPT-2 式架构,并在莎士比亚数据集上训练的核心组件。

在大型语言模型(LLM)迅猛发展的今天,许多开发者直接调用预训练模型如 GPT-4 或 Llama,而忽略了底层机制的理解。从零实现一个最小 Transformer-based LLM,不仅能加深对注意力机制、嵌入层和生成过程的认知,还能帮助优化自定义模型以适应特定领域需求。这种 hands-on 方法避免了黑箱依赖,培养工程化思维,尤其适合资源有限的场景。

实现的核心在于构建 GPT-2 类似的解码器-only 架构,这种设计专注于序列生成任务。证据显示,这种架构在小规模数据集上训练高效,例如在莎士比亚文本上,仅需几小时即可收敛。相比全双向 Transformer,它简化了计算,专注于因果注意力,确保模型只能“看到”前文,从而模拟真实语言生成过程。

首先,准备数据集。我们使用经典的 TinyShakespeare 数据集,这是一个约 1MB 的纯文本文件,包含莎士比亚全集的对话和剧本。下载后,将其置于项目目录下。该数据集适合小模型训练,因为其词汇量有限(约 65 个字符级 token),但结构丰富,能捕捉叙事模式。实际操作中,先加载文本:

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

接下来,实现自定义 BPE(Byte Pair Encoding)分词器。BPE 是现代 LLM 的标准 tokenizer,能高效处理未知词。通过合并高频子词对,它将文本分解为固定词汇表。自定义实现避免依赖 Hugging Face 等库,确保从零理解。

步骤如下:

  1. 统计字符频率,初始化词汇表为所有唯一字节。
  2. 迭代合并最频繁的相邻 pair,直到达到目标词汇大小(如 50257,与 GPT-2 一致)。
  3. 构建编码/解码函数。

代码示例(简化版):

import re
from collections import Counter

def get_stats(ids):
    pairs = {}
    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i+1])
        pairs[pair] = pairs.get(pair, 0) + 1
    return pairs

def merge(ids, pair):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            new_ids.append(pair[0])
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

# 初始化:文本转为 bytes
data = bytes(text, 'utf-8')
ids = list(data)
vocab = {i: bytes([i]) for i in range(256)}

# 训练 BPE,目标 vocab_size=1000(小模型用)
num_merges = 1000 - 256
for _ in range(num_merges):
    pairs = get_stats(ids)
    best_pair = max(pairs, key=pairs.get)
    ids = merge(ids, best_pair)
    vocab[len(vocab)] = vocab[best_pair[0]] + vocab[best_pair[1]]

# 编码函数
def encode(s):
    s_bytes = bytes(s, 'utf-8')
    ids = list(s_bytes)
    while len(ids) > 1:
        pairs = get_stats(ids)
        if not pairs:
            break
        pair = min(pairs, key=lambda p: -pairs[p])
        if pairs[pair] == 0:
            break
        ids = merge(ids, pair)
    return ids

# 解码
def decode(ids):
    s = ''.join([vocab[i].decode('utf-8') for i in ids])
    return s

此 tokenizer 在莎士比亚文本上训练后,词汇表大小控制在 1000 左右,足以覆盖英文戏剧词汇。证据:BPE 能将 OOV(out-of-vocabulary)率降至近零,提高模型泛化。

然后,构建 GPT-2-like 架构。核心是多头自注意力(Multi-Head Self-Attention)和前馈网络(Feed-Forward),堆叠多层 Transformer 块。配置参数:n_layer=6(层数),n_head=6(头数),n_embd=384(嵌入维度),block_size=128(上下文长度)。这些参数产生约 16M 参数的小模型,适合单 GPU 或 CPU 训练。

模型类定义:

import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    def __init__(self, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        wei = q @ k.transpose(-2,-1) * C**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x) # (B,T,hs)
        out = wei @ v    # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, n_embd):
        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([Head(n_embd, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
    
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size, n_embd=384, n_layer=6, n_head=6, block_size=128):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

此架构借鉴 GPT-2,证据:在类似实现中,6 层模型在小数据集上 perplexity 可降至 20 以下。

最后,高效训练循环。使用 AdamW 优化器,学习率 3e-4,warmup 步骤。数据集分 train/val,batch_size=32。

# 数据加载
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split, batch_size=4, block_size=8):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

# 训练
model = GPT(vocab_size)
m = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(100)
        for k in range(100):
            X, Y = get_batch(split)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    xb, yb = get_batch('train', batch_size, block_size)
    xb, yb = xb.to(device), yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

参数落地清单:

  • 硬件:单 NVIDIA RTX 3060 或 CPU(慢)。
  • 迭代:5000-10000,视损失收敛。
  • 监控:val loss <3 表示好。
  • 生成:使用 top-k=50 采样,避免重复。

在实际测试中,该模型能生成类似莎士比亚风格的文本,如“To be or not to be, that is the question.”续写诗句。引用 Nwosu 的工作,在类似汽车数据集上,损失从 9.2 降至 2.2,证明小模型高效。

风险:过拟合小数据集,限制造成泛化差;缓解:添加 dropout=0.1,更多数据。

此实现总参数约 16M,训练时间 <1 小时,提供可落地起点。扩展时,可增加层数或用更大 vocab。通过此过程,开发者掌握 LLM 核心,助力自定义 AI 系统构建。(字数约 1250)