Skip to content

量化系统 — 代码走读

量化配置 — vllm/config/quantization.py

python
@dataclass
class QuantizationConfig:
    """Quantization settings parsed from a model's HuggingFace config."""

    quant_method: str  # "gptq", "awq", "fp8", etc.

    @staticmethod
    def from_config(model_config):
        """Build the method-specific quantization config for a model.

        Args:
            model_config: Object whose ``quantization_config`` attribute, when
                present, is a mapping with a ``"quant_method"`` entry.

        Returns:
            A ``GPTQConfig``, ``AWQConfig`` or ``FP8Config`` built from the
            mapping's entries, or ``None`` when the model carries no
            quantization config or the method matches none of the branches
            below.
        """
        # A model without quantization has no (or a null) quantization_config;
        # report "no quantization" instead of failing on `.get`.
        quant_cfg = getattr(model_config, "quantization_config", None)
        if not quant_cfg:
            return None

        method = quant_cfg.get("quant_method", "")

        if method == "gptq":
            return GPTQConfig(**quant_cfg)
        elif method == "awq":
            return AWQConfig(**quant_cfg)
        elif method == "fp8":
            return FP8Config(**quant_cfg)
        # Placeholder kept from the original snippet; any method not matched
        # above falls through and implicitly returns None.
        ...

量化线性层 — model_executor/layers/quantization/

GPTQ 实现

python
class GPTQMarlinLinearMethod(QuantizeMethodBase):
    """Linear method that swaps in quantized weights and runs the Marlin kernel."""

    def create_quantized_linear(self, linear):
        """Replace ``linear``'s weight with its packed form and attach scales.

        Args:
            linear: Layer to mutate in place.

        Returns:
            The same ``linear`` object, after its ``weight`` and ``scales``
            attributes have been reassigned.
        """
        # Store the weight in the quantized (packed) format.
        packed_weight = self._pack_qweight(linear.weight)
        linear.weight = Parameter(packed_weight, requires_grad=False)

        scale_values = self._compute_scales()
        linear.scales = Parameter(scale_values)
        return linear

    def apply(self, linear, x):
        """Multiply ``x`` by ``linear``'s quantized weight via the Marlin kernel.

        Reads ``weight``, ``scales``, ``workspace`` and ``w_qzeros`` from
        ``linear`` and forwards them, in that order, after ``x``.
        """
        qweight = linear.weight
        scales = linear.scales
        workspace = linear.workspace
        qzeros = linear.w_qzeros
        return gptq_marlin_gemm(x, qweight, scales, workspace, qzeros)

FP8 实现

python
class Fp8LinearMethod(QuantizeMethodBase):
    """Linear method that keeps weights in FP8 and multiplies with ``torch._scaled_mm``."""

    def process_weights_after_loading(self, model):
        """Convert the weight of every ``Linear`` layer in ``model`` to FP8 (E4M3).

        The conversion is a plain cast, i.e. an implicit scale of 1.0; this
        method does not compute or set ``weight_scale``.

        Args:
            model: Module whose ``modules()`` are scanned for ``Linear`` layers.
        """
        fp8_dtype = torch.float8_e4m3fn
        fp8_limits = torch.finfo(fp8_dtype)

        for layer in model.modules():
            if not isinstance(layer, Linear):
                continue
            # Converting storage is a dtype cast, not a matmul:
            # torch._scaled_mm multiplies two matrices and cannot be used here.
            # Clamp first so values beyond the FP8 range land on its limits.
            clamped = layer.weight.detach().clamp(
                min=fp8_limits.min, max=fp8_limits.max
            )
            layer.weight = Parameter(clamped.to(fp8_dtype), requires_grad=False)

    def apply(self, linear, x):
        """Run the FP8 matrix multiplication of ``x`` with ``linear.weight``.

        Args:
            linear: Layer providing ``weight``, ``input_scale`` and
                ``weight_scale``.
            x: Input activations, passed to the kernel unchanged.

        Returns:
            The product as a ``torch.float16`` tensor.
        """
        return torch._scaled_mm(
            x, linear.weight,
            scale_a=linear.input_scale,
            scale_b=linear.weight_scale,
            # The keyword accepted by torch._scaled_mm is `out_dtype`.
            out_dtype=torch.float16,
        )

量化框架集成

权重加载时自动量化

python
# 在模型加载时
class DefaultModelLoader:
    """Loader that routes a module through its quantization method when it has one."""

    def _load_module(self, module, weights):
        """Populate ``module`` exactly once.

        Args:
            module: Object exposing ``named_parameters()`` and either a
                ``quant_method`` attribute or a ``load_weights`` method.
            weights: Passed through to ``module.load_weights`` on the
                non-quantized path; unused on the quantized path.
        """
        # Modules that own no parameters have nothing to load.
        if next(iter(module.named_parameters()), None) is None:
            return

        # Act once per module, not once per parameter: the choice below depends
        # only on the module, and repeating it would re-quantize or re-load the
        # same weights. It also avoids changing the module's parameters while
        # they are being iterated.
        if hasattr(module, 'quant_method'):
            # The quantization method handles the weights.
            module.quant_method.create_quantized_linear(module)
        else:
            # Regular loading.
            module.load_weights(weights)

关键函数索引

| 函数/类 | 文件 | 职责 |
| --- | --- | --- |
| QuantizationConfig.from_config() | config/quantization.py | 解析量化配置 |
| GPTQMarlinLinearMethod.apply() | layers/quantization/gptq/ | GPTQ 矩阵乘法 |
| Fp8LinearMethod.apply() | layers/quantization/fp8/ | FP8 矩阵乘法 |
| AWQMarlinLinearMethod.apply() | layers/quantization/awq/ | AWQ 矩阵乘法 |
| process_weights_after_loading() | 各量化方法 | 权重后处理 |