Skip to content

推测解码 — 概念

基本原理

自回归推理每步只生成 1 个 token,GPU 利用率很低。推测解码通过"猜"多个 token 来提高利用率:

数学原理

对于每个候选 token,计算接受概率:

accept_prob = min(1, p_target(token) / p_draft(token))
  • 如果 p_target > p_draft:总是接受
  • 如果 p_target < p_draft:以概率 p_target/p_draft 接受
  • 被拒绝的 token 后重新采样,保证分布不变

推测方法

N-gram Proposer

最简单的方法,不需要额外模型:

  • 维护历史 n-gram 频率表
  • 根据前 n-1 个 token 预测下一个 token
  • GPU 加速版本使用 trie 结构高效查找

优点:零额外计算开销 缺点:准确率低,接受率有限

EAGLE

基于模型特征的方法:

  • 利用目标模型的中间层特征
  • 草稿模型很小(通常只有 1-2 层)
  • 准确率高(85-95%),接受率高

Medusa

多头预测方法:

  • 在目标模型上添加多个预测头
  • 每个头预测不同未来位置的 token
  • 不需要单独的草稿模型
  • 训练成本低

方法对比

方法额外模型准确率加速比适用场景
N-gram1.2-1.5×简单场景
EAGLE小模型2-3×通用
Medusa多头中高1.5-2.5×不想用草稿模型
DFlash小模型2-3×Flash attention
Suffix Decoding1.3-1.8×重复性文本

Rejection Sampling

验证阶段的核心算法:

python
def rejection_sample(draft_tokens, draft_probs, target_probs):
    accepted = []
    for i, token in enumerate(draft_tokens):
        p_draft = draft_probs[i][token]
        p_target = target_probs[i][token]

        # 接受概率
        accept_prob = min(1.0, p_target / p_draft)

        if random() < accept_prob:
            accepted.append(token)
        else:
            # 从调整后的分布中重新采样
            adjusted = max(0, p_target - p_draft)
            new_token = sample(adjusted)
            accepted.append(new_token)
            break  # 后续 token 全部丢弃

    return accepted

推测解码与连续批处理

推测解码需要与 Continuous Batching 协同工作:

  • 每个请求可能有不同数量的候选 token
  • 调度器需要为推测解码预留足够的计算预算
  • 被拒绝的 token 的 KV 缓存需要回滚

混合模型的推测解码

Mamba + Attention 推测解码

混合模型(如 Mamba + Attention)的推测解码面临 CPU-GPU 同步瓶颈:Mamba 状态在验证失败后需要回滚,传统方法需要将状态从 GPU 拷贝到 CPU 再恢复。

vLLM 引入了融合 Triton 内核 postprocess_mamba_fused_kernel,完全在 GPU 上执行 Mamba 状态复制:

关键组件:

  • MambaSpecDecodeGPUContext:预计算内存布局元数据(基地址、步幅、元素大小、卷积宽度)
  • postprocess_mamba_align_gpu():align 模式缓存回滚
  • postprocess_mamba_all()mamba_cache_mode="all" 模式回滚
  • MambaBuffers:从 MambaCopyBuffers 重构,增加 postprocess 子对象

EAGLE-3 后规范架构

EAGLE-3 推测模型现在支持后规范(post-norm)架构,扩展了可适配的模型范围。

相关概念