Appearance
分布式计算 — 代码走读
Parallel State — vllm/distributed/parallel_state.py
python
class ParallelState:
    """Manage the process groups used by the different parallelism modes."""

    def __init__(self):
        # All groups start unset. Only tp_group is populated by the code
        # visible in this excerpt (see initialize()).
        self.tp_group = None  # Tensor Parallel groups
        self.pp_group = None  # Pipeline Parallel group
        self.dp_group = None  # Data Parallel group
        self.ep_group = None  # Expert Parallel group

    def initialize(self, parallel_config):
        """Create the tensor-parallel (TP) process groups.

        Ranks are split into ``world_size // tp_size`` consecutive blocks of
        ``tp_size`` ranks each. One group is created per block with
        ``dist.new_group`` and stored in ``self.tp_group`` under every rank
        that belongs to it.

        Args:
            parallel_config: Object exposing ``world_size`` and
                ``tensor_parallel_size``.
        """
        world_size = parallel_config.world_size
        tp_size = parallel_config.tensor_parallel_size

        # tp_group is None after __init__; it has to become a mapping before
        # it can be indexed by rank (item assignment on None raises
        # TypeError).
        self.tp_group = {}

        # Create the TP process groups.
        for i in range(world_size // tp_size):
            ranks = list(range(i * tp_size, (i + 1) * tp_size))
            group = dist.new_group(ranks)
            for rank in ranks:
                self.tp_group[rank] = group


# 通信操作 (Communication operations)
Tensor Parallel 通信
python
def tensor_model_parallel_all_reduce(input_):
    """All-reduce ``input_`` across the tensor-parallel (TP) group.

    The group comes from ``get_tp_group()`` and is passed to
    ``dist.all_reduce``. The same ``input_`` object that was passed in is
    returned.
    """
    tp_group = get_tp_group()
    # NOTE(review): presumably torch.distributed, where all_reduce works in
    # place on the tensor - confirm against the module's imports.
    dist.all_reduce(input_, group=tp_group)
    return input_
def tensor_model_parallel_all_gather(input_, dim=-1):
    """All-gather ``input_`` across the tensor-parallel (TP) group.

    Allocates one ``torch.empty_like(input_)`` buffer per TP rank, fills the
    buffers with ``dist.all_gather``, and concatenates them along ``dim``.

    Args:
        input_: Local tensor to gather.
        dim: Dimension along which the gathered pieces are concatenated
            (defaults to the last one).

    Returns:
        The result of ``torch.cat`` over the gathered buffers.
    """
    tp_group = get_tp_group()
    num_ranks = get_tp_world_size()

    # One receive buffer per rank, each shaped like the local input.
    gathered = []
    for _ in range(num_ranks):
        gathered.append(torch.empty_like(input_))

    dist.all_gather(gathered, input_, group=tp_group)
    return torch.cat(gathered, dim=dim)


# KV Transfer — vllm/distributed/kv_transfer/
python
class KVTransferAgent:
    """Transfer KV caches across GPUs."""

    def send_kv_cache(self, kv_cache, request_id):
        """Serialize ``kv_cache`` and send it to the target GPU.

        The payload is produced by ``self._serialize_kv_cache`` and handed to
        ``self.transport.send`` together with ``request_id``.
        """
        payload = self._serialize_kv_cache(kv_cache)
        channel = self.transport
        channel.send(payload, request_id)

    def recv_kv_cache(self, request_id):
        """Receive the KV cache for ``request_id``.

        Returns whatever ``self._deserialize_kv_cache`` produces from the
        data obtained through ``self.transport.recv``.
        """
        channel = self.transport
        payload = channel.recv(request_id)
        return self._deserialize_kv_cache(payload)


# 关键函数索引 (Key function index)
| 函数/类 | 文件 | 职责 |
|---|---|---|
| ParallelState.initialize() | distributed/parallel_state.py | 初始化进程组 |
| tensor_model_parallel_all_reduce() | distributed/parallel_state.py | TP All-Reduce |
| tensor_model_parallel_all_gather() | distributed/parallel_state.py | TP All-Gather |
| get_tp_group() | distributed/parallel_state.py | 获取 TP 进程组 |
| KVTransferAgent | distributed/kv_transfer/ | KV 缓存传输 |