Appearance
通过 tiling 和 kernel fusion 减少 HBM 访问次数的 IO 感知注意力算法,是 vLLM 注意力计算的核心后端。
为什么需要 FlashAttention
标准注意力实现需将完整的 S = QK^T 和 P = softmax(S)V 写入 HBM,内存访问量为 O(N^2 d),成为计算瓶颈。FlashAttention 通过分块计算(tiling)将 Q/K/V 分块送入 SRAM,在 on-chip 上完成 softmax 和 matmul,仅将最终输出写回 HBM,内存访问量降至 O(N^2 d^2 / M),其中 M 是 SRAM 大小。
核心原理
- Tiling:将 Q/K/V 按块大小切分,每块在 SRAM 中计算局部注意力后累积到输出。
- Online Softmax:通过维护 running max 和 running sum 实现分块 softmax 的数值稳定计算。
- IO 复杂度:相比标准注意力的 O(N^2 d) HBM 访问,FlashAttention 降至 O(N^2 d^2 / M),在长序列上提速 2-4x。
- FlashAttention-2/3:进一步优化 parallelism(2D grid)和 async(与 Tensor Core 流水),性能持续提升。
在源码中的实现
vllm/attention/backends/flash_attn.py— FlashAttention 后端实现,封装 flash-attn 库调用。vllm/attention/backends/rocm_flash_attn.py— AMD ROCm 版本的 FlashAttention。vllm/attention/layer.py— Attention 层根据配置选择后端(FlashAttention、XFormers 等)。vllm/attention/ops/— 底层注意力 kernel 的自定义实现与封装。
相关概念
- paged-attention — FlashAttention kernel 需适配 PagedAttention 的非连续 KV 寻址
- kv-cache — FlashAttention 读取 KV Cache 进行注意力计算
- tensor-parallelism — FlashAttention kernel 需感知 TP 的 QKV 切分
- chunked-prefill — Chunked prefill 依赖 FlashAttention 的分块计算能力