如图 (a) 所示,以一般矩阵乘法为例,假设我们有 $Y=XW$ ,其中 $X\in\mathbb R^{2\times2}$ 为输入数据,$W\in\mathbb R^{2\times2}$ 为参数矩阵,在两块 GPU 上进行并行计算,输入数据 $X$ 与权重向量 $W$ 进行矩阵相乘时,计算行列对之间的点积是相互独立的。列并行就是将权重参数 $W$ 沿列分割成 $W=[W_0\ W_1]$,每块 GPU 持有一列参数 $W_0, W_1 \in \mathbb R^{2\times1}$,如图 (b) 所示,我们将 $X$ 分别输入 rank 0 和 rank 1 的 GPU 上,然后与 GPU 上的参数进行矩阵相乘,我们将得到 $Y_0, Y_1$ ,然后通过一次拼接操作就可以得到和 (a) 等价的 $Y$。而行并行则是将 $W$ 沿行进行切分 $W=[W_0 \ W_1]^{\mathrm T}$ 放置,每块 GPU 持有一行参数 $W_0, W_1 \in \mathbb R^{1\times2}$ ,然后将输入 $X$ 也沿列切分为 $X_0, X_1 \in \mathbb R^{2\times1}$ 并输入到两块 GPU 上分别进行矩阵乘法运算得到 $Y^\prime,Y^{\prime\prime}\in\mathbb R^{2\times 2}$ ,然后按元素位置相加 $Y^\prime+Y^{\prime\prime}$ 也可以得到等价的 $Y$。
0 commit comments