一、为何使用 DeviceMesh?
在混合并行(DP/TP/PP/HSDP/…)中,需要管理多个子通信组(ProcessGroup),对应复杂的设备拓扑结构。DeviceMesh
提供了:
- 理论上无缝支持任意维度的多维拓扑;
- 自动拆分进程组(
new_group
/split_group
); - 灵活切片子 Mesh;
- 经历设计周全的高效初始化方案 (docs.pytorch.org, pytorch.org)。
二、初始化流程
init_device_mesh(...)
的作用
一个一行搞定的方法,它会:
- 初始化全局
init_process_group(...)
(若未初始化); - 根据
mesh_shape
自动构造 CPU 上的torch.arange(...).view(...)
; - 创建
DeviceMesh(...)
。内部完成子组拆分原理(见下一节)。
DeviceMesh.__init__()
+ _init_process_groups()
存储:
device_type
、mesh
、mesh_dim_names
;通信组拆分:遍历每个维度
dim
:- 使用
mesh.swapdims(-1, dim).reshape(-1, size(dim))
列出该维所有子组 rank; - 若 NCCL 已绑定 GPU,即可用
split_group
一次拆出全部子组; - 否则使用
new_group()
分 group 拆; - 并将当前 rank 属于的那组信息放入
self._dim_group_infos[dim]
;
- 使用
结果:每个维度对应一个包含当前 rank 的
ProcessGroup
信息列表。
1 | #pp |
三、核心接口与内部实现解析
1. 属性与方法
1 | mesh.shape # tuple(self.mesh.shape) |
用于获取 mesh 元结构和规模,适用于判断维度数量、循环迭代、并行策略配置等场景。
2. Rank 与坐标
get_rank()
:等价于torch.distributed.get_rank()
,返回全局 rank;get_local_rank(mesh_dim)
:内部调用get_rank(self.get_group(mesh_dim))
→ 当前维度的小组内编号;get_coordinate()
:返回self._coordinate_on_dim
,其在初始化中通过(self.mesh==global_rank).nonzero()
获得。
示例:mesh_shape=(4,2)
,rank=5 → local_pp=2、local_tp=1,coordinate [2,1]
。
3. 通信组获取
get_group(mesh_dim)
:- 若 1D 且不传参,直接返回唯一子进程组;
- 多维则根据
mesh_dim
(索引或名字)检索self._dim_group_infos[dim]
,用_find_pg_by_ranks_and_tag()
获取对应ProcessGroup
。
get_all_groups()
:返回所有维度的 group 列表;__getitem__(dims)
:切片接口调用_mesh_resources._get_slice_mesh_dims(...)
,创建新的子 mesh,保留底层 communicator,但维度降。- 支持单维或多维切片,且返回的 submesh 顺序按传入顺序排列 (discuss.ray.io, gemfury.com, pytorch.org)。
4. from_group(...)
方法
- 可接受单 group 或 group 列表;
- 创建新的
DeviceMesh
时不会调用 backend 初始化; - 会复用现有
ProcessGroup
,并填充_dim_group_infos
,因此get_group(...)
将直接返回传入的实例,避免重复创建 group。
四、完整单机 8 卡 Demo:tp=2, pp=4
下面演示如何调用所有接口并输出结果。注意:需在 torchrun --nproc_per_node=8
下运行。
1 | import os, torch, torch.distributed as dist |
💬 预期输出(例如 rank = 5):
rank=5, coord=[2,1], local_pp=2, local_tp=1
ndim=2, shape=(4,2), total=8, pp=4, tp=2
pp_group ranks: [4,5,6,7]
tp_group ranks: [5,7]
all_groups sizes: [4,2]
tp_mesh ndim, shape: 1 (2,)
pp_mesh ndim, shape: 1 (4,)
说明:
- rank=5 位于 pipeline 段 2,tp 内编号 1;
pp_group
包含与其同 segment 的 4 张卡;tp_group
包含同 segment tp 维度的两张卡;- 切片后
tp_mesh
、pp_mesh
成为 1 维结构,用于后续 parallelization。
👏 总结
DeviceMesh
构建自身通过init_device_mesh()
完成初始化与子组拆分;- 接口内部实现逻辑与 Group 管理机制清晰、高效;
__getitem__
为多维并行下子 Mesh 切片关键工具,对集成 parallel APIs 至关重要;- 通过该机制,可以简单地组织复杂的 hybrid-parallel pipelines,同时充分复用 communicator 资源并简化开发流程。