By strint, Depeng Liang and Luyang Zhao
S : 总参数个数,等于 ZeRO文章中的符号 ψ,当下文描述通信量的时候,只统计总的通信参数个数与 Zero 文章对应。 g : fp32梯度,4S字节 p : fp32参数,4S 字节 m : adam momentum, 4S 字节 v : adam variance, 4S 字节 n : 数据并行设备数也就是进程数
对于普通的混合精度的数据并行,假设模型总参数个数是 S,则显存占用量是 (K + 2 + 2)* S 字节。其中 K=12 (fp32类型的 adam states m 和 v 和 p,每个都是4个字节,加起来等于12),接下来两个2 分别表示是 ,fp16 的参数和梯度。
普通混合精度数据并行的通信量, 反向之后做一次 all-reduce(reduce-scatter + allgather),同步fp16 梯度 。总通信量 2S。
概括下: g_fp16[2S Bytes] → Communication(all_reduce, 2S) → g_fp16[2S Bytes,inplace] → g[4S Bytes] → m[4S Bytes] + v[4S Bytes] → p[4S Bytes, 释放g] → p_fp16[2S Bytes]
K = 4 * 3 = 12 total_mem_size = 2S + K*S + 2S= 16S Bytes total_comm_volume = 2S
首先把一个group的 fp16 的所有参数 flatten 成一维的向量,然后按顺序拼接在一起,变成一个完整的一维向量, 对应 代码。之后对这个一维度向量做切分,形成子分片为计算和通信做准备。 stage1 的切分的方式是先按通信次数切成多次通信的通信分片,再对通信分片切分出进程分片。切分通信分片的依据是构建了一个最大单次总通信数据量 max_elems_per_comm,则每个进程单次通信量是 sub_partition_size = max_elems_per_comm / n。
不同设备所负责更新的参数部分配方式如下图,对应代码:
首先把拼成一维的 fp16 参数 按照 sub_partition_size 分段,接着按顺序给不同的进程分配其负责更新的参数分段:
接着每个 rank 将自己负责的 fp16 参数分段拷贝一份转换成 fp32 ,这样子就得到了本地需要负责更新的 fp32 参数分片。
接着把本地的 fp32 参数分片替换掉 optimizer 中的参数 ,对应代码。然后初始化 optimizer states,Adam 的 m和 v,由于optimizer 中的参数已经是本地分片大小,所以创建的m和v就是分片大小。到这里就完成了 fp32 的 参数, m 和 v 的按设备数均分。而fp16的参数和梯度还是每张卡都有一份完整的。
这时候显存开销降为 (2 + 2 + K / n ) * S Bytes。
每个进程在反向完成得到所有 fp16 梯度之后,再进行 reduce-scatter 同步并分发梯度,每个 rank 拿到本地负责的梯度,转为 fp32 在参数更新之后就释放掉,而 fp16 的梯度也会释放掉但是因为是后向完成之后才释放的,所以峰值显存占用还是需要考虑 fp16 的梯度。接着做一次 all-gather 收集更新之后的 fp16 参数,所以总通信量与普通混合精度数据并行一致 2S 。
概括下: g_fp16[2S Bytes] → Communication(reduce_scatter, S) → part_g_fp16[2S Bytes, inplace] → part_g[4S/n Bytes,释放无关的和有关的part_g_fp16,这里是整个后向完成后才释放] → part_m[4S/n Bytes] + part_v[4S/n Bytes] → part_p[4S/n Bytes,释放part_g] → part_p_fp16[2S/n Bytes]→ Communication(all_gather, S) → p_fp16[2S Bytes]
tatal_mem_size(os) = 2S + 12S/n + 2S = 4S + 12S/n Bytes total_comm_volume = 2S