Skip to content

分布式计算 — 代码走读

Parallel State — vllm/distributed/parallel_state.py

python
class ParallelState:
    """Hold the process groups used by each parallelism axis.

    Only tensor-parallel (TP) groups are built by ``initialize`` in this
    excerpt; the pipeline/data/expert attributes stay ``None`` here.
    """

    def __init__(self):
        # Every group is unset until ``initialize`` runs.
        self.tp_group = None    # Tensor Parallel: becomes a dict {rank: group} in initialize()
        self.pp_group = None    # Pipeline Parallel group
        self.dp_group = None    # Data Parallel group
        self.ep_group = None    # Expert Parallel group

    def initialize(self, parallel_config):
        """Create one TP process group per run of consecutive ranks.

        Ranks ``[0, tp_size)`` form the first group, ``[tp_size, 2*tp_size)``
        the second, and so on. Afterwards ``self.tp_group`` maps every global
        rank to the TP group that contains it.

        Args:
            parallel_config: object exposing ``world_size`` and
                ``tensor_parallel_size``.

        Raises:
            ValueError: if ``tensor_parallel_size`` is not positive or does
                not evenly divide ``world_size``.
        """
        world_size = parallel_config.world_size
        tp_size = parallel_config.tensor_parallel_size

        # Validate before any collective call. Otherwise a zero size raises
        # ZeroDivisionError, a negative size silently builds no groups, and a
        # non-divisor leaves the trailing ranks without a TP group.
        if tp_size <= 0 or world_size % tp_size != 0:
            raise ValueError(
                f"tensor_parallel_size ({tp_size}) must be a positive divisor "
                f"of world_size ({world_size})"
            )

        # The original code assigned into self.tp_group while it was still
        # None, which raised TypeError. Start from an empty mapping instead.
        self.tp_group = {}

        # dist.new_group must be entered by every process, including
        # non-members, and groups must be created in the same order
        # everywhere. That is why each process loops over ALL groups rather
        # than only its own.
        for start in range(0, world_size, tp_size):
            ranks = list(range(start, start + tp_size))
            group = dist.new_group(ranks)
            for rank in ranks:
                self.tp_group[rank] = group

通信操作

Tensor Parallel 通信

python
def tensor_model_parallel_all_reduce(input_):
    """All-reduce ``input_`` in place across the tensor-parallel group.

    ``dist.all_reduce`` is called with its default reduce op. The same tensor
    object that was passed in is returned, now holding the reduced values.
    """
    tp_process_group = get_tp_group()
    dist.all_reduce(input_, group=tp_process_group)
    return input_

def tensor_model_parallel_all_gather(input_, dim=-1):
    """Gather ``input_`` from every TP rank and concatenate along ``dim``.

    Args:
        input_: this rank's local tensor. The receive buffers are sized with
            ``empty_like(input_)``, so every rank must pass a tensor of the
            same shape and dtype.
        dim: dimension to concatenate along (default: the last one).

    Returns:
        A new tensor whose size along ``dim`` is ``world_size`` times that of
        ``input_``, with the per-rank pieces ordered by TP rank.
    """
    group = get_tp_group()
    world_size = get_tp_world_size()
    # Collective backends (e.g. NCCL) reject non-contiguous tensors, and
    # empty_like would otherwise copy a non-contiguous layout into every
    # receive buffer. For an already-contiguous tensor this is a no-op.
    input_ = input_.contiguous()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=group)
    return torch.cat(tensor_list, dim=dim)

KV Transfer — vllm/distributed/kv_transfer/

python
class KVTransferAgent:
    """Ship KV-cache data between GPUs through a transport object.

    ``self.transport`` and the ``_serialize_kv_cache`` /
    ``_deserialize_kv_cache`` hooks are not defined in this excerpt. They are
    presumably provided by a subclass or set up elsewhere — TODO confirm.
    """

    def send_kv_cache(self, kv_cache, request_id):
        """Serialize ``kv_cache``, then pass it to the transport keyed by ``request_id``."""
        # Serialize first, exactly as before, so the serializer runs before
        # the transport is touched.
        serialized = self._serialize_kv_cache(kv_cache)
        self.transport.send(serialized, request_id)

    def recv_kv_cache(self, request_id):
        """Pull the payload for ``request_id`` from the transport and deserialize it."""
        wire_payload = self.transport.recv(request_id)
        kv_cache = self._deserialize_kv_cache(wire_payload)
        return kv_cache

关键函数索引

| 函数/类 | 文件 | 职责 |
|---|---|---|
| `ParallelState.initialize()` | `distributed/parallel_state.py` | 初始化进程组 |
| `tensor_model_parallel_all_reduce()` | `distributed/parallel_state.py` | TP All-Reduce |
| `tensor_model_parallel_all_gather()` | `distributed/parallel_state.py` | TP All-Gather |
| `get_tp_group()` | `distributed/parallel_state.py` | 获取 TP 进程组 |
| `KVTransferAgent` | `distributed/kv_transfer/` | KV 缓存传输 |