Appearance
推测解码 — 概念
基本原理
自回归推理每步只生成 1 个 token,GPU 利用率很低。推测解码通过"猜"多个 token 来提高利用率:
数学原理
对于每个候选 token,计算接受概率:
accept_prob = min(1, p_target(token) / p_draft(token))- 如果
p_target > p_draft:总是接受 - 如果
p_target < p_draft:以概率p_target/p_draft接受 - 被拒绝的 token 后重新采样,保证分布不变
推测方法
N-gram Proposer
最简单的方法,不需要额外模型:
- 维护历史 n-gram 频率表
- 根据前 n-1 个 token 预测下一个 token
- GPU 加速版本使用 trie 结构高效查找
优点:零额外计算开销 缺点:准确率低,接受率有限
EAGLE
基于模型特征的方法:
- 利用目标模型的中间层特征
- 草稿模型很小(通常只有 1-2 层)
- 准确率高(85-95%),接受率高
Medusa
多头预测方法:
- 在目标模型上添加多个预测头
- 每个头预测不同未来位置的 token
- 不需要单独的草稿模型
- 训练成本低
方法对比
| 方法 | 额外模型 | 准确率 | 加速比 | 适用场景 |
|---|---|---|---|---|
| N-gram | 无 | 低 | 1.2-1.5× | 简单场景 |
| EAGLE | 小模型 | 高 | 2-3× | 通用 |
| Medusa | 多头 | 中高 | 1.5-2.5× | 不想用草稿模型 |
| DFlash | 小模型 | 高 | 2-3× | Flash attention |
| Suffix Decoding | 无 | 中 | 1.3-1.8× | 重复性文本 |
Rejection Sampling
验证阶段的核心算法:
python
def rejection_sample(draft_tokens, draft_probs, target_probs):
accepted = []
for i, token in enumerate(draft_tokens):
p_draft = draft_probs[i][token]
p_target = target_probs[i][token]
# 接受概率
accept_prob = min(1.0, p_target / p_draft)
if random() < accept_prob:
accepted.append(token)
else:
# 从调整后的分布中重新采样
adjusted = max(0, p_target - p_draft)
new_token = sample(adjusted)
accepted.append(new_token)
break # 后续 token 全部丢弃
return accepted推测解码与连续批处理
推测解码需要与 Continuous Batching 协同工作:
- 每个请求可能有不同数量的候选 token
- 调度器需要为推测解码预留足够的计算预算
- 被拒绝的 token 的 KV 缓存需要回滚
混合模型的推测解码
Mamba + Attention 推测解码
混合模型(如 Mamba + Attention)的推测解码面临 CPU-GPU 同步瓶颈:Mamba 状态在验证失败后需要回滚,传统方法需要将状态从 GPU 拷贝到 CPU 再恢复。
vLLM 引入了融合 Triton 内核 postprocess_mamba_fused_kernel,完全在 GPU 上执行 Mamba 状态复制:
关键组件:
- MambaSpecDecodeGPUContext:预计算内存布局元数据(基地址、步幅、元素大小、卷积宽度)
- postprocess_mamba_align_gpu():align 模式缓存回滚
- postprocess_mamba_all():
mamba_cache_mode="all"模式回滚 - MambaBuffers:从
MambaCopyBuffers重构,增加 postprocess 子对象
EAGLE-3 后规范架构
EAGLE-3 推测模型现在支持后规范(post-norm)架构,扩展了可适配的模型范围。
相关概念
- Continuous Batching — 连续批处理
- KV Cache — 推测解码的 KV 缓存回滚
- CUDA Graph — 推测解码对图捕获的影响