Skip to content

通过 tiling 和 kernel fusion 减少 HBM 访问次数的 IO 感知注意力算法,是 vLLM 注意力计算的核心后端。

为什么需要 FlashAttention

标准注意力实现需将完整的 S = QK^T 和 P = softmax(S)V 写入 HBM,内存访问量为 O(N^2 d),成为计算瓶颈。FlashAttention 通过分块计算(tiling)将 Q/K/V 分块送入 SRAM,在 on-chip 上完成 softmax 和 matmul,仅将最终输出写回 HBM,内存访问量降至 O(N^2 d^2 / M),其中 M 是 SRAM 大小。

核心原理

  • Tiling:将 Q/K/V 按块大小切分,每块在 SRAM 中计算局部注意力后累积到输出。
  • Online Softmax:通过维护 running max 和 running sum 实现分块 softmax 的数值稳定计算。
  • IO 复杂度:相比标准注意力的 O(N^2 d) HBM 访问,FlashAttention 降至 O(N^2 d^2 / M),在长序列上提速 2-4x。
  • FlashAttention-2/3:进一步优化 parallelism(2D grid)和 async(与 Tensor Core 流水),性能持续提升。

在源码中的实现

  • vllm/attention/backends/flash_attn.py — FlashAttention 后端实现,封装 flash-attn 库调用。
  • vllm/attention/backends/rocm_flash_attn.py — AMD ROCm 版本的 FlashAttention。
  • vllm/attention/layer.py — Attention 层根据配置选择后端(FlashAttention、XFormers 等)。
  • vllm/attention/ops/ — 底层注意力 kernel 的自定义实现与封装。

相关概念

  • paged-attention — FlashAttention kernel 需适配 PagedAttention 的非连续 KV 寻址
  • kv-cache — FlashAttention 读取 KV Cache 进行注意力计算
  • tensor-parallelism — FlashAttention kernel 需感知 TP 的 QKV 切分
  • chunked-prefill — Chunked prefill 依赖 FlashAttention 的分块计算能力