mindspore.ops.all_gather_matmul
- mindspore.ops.all_gather_matmul(input, x2, group, world_size, *, bias=None, gather_index=0, gather_output=True, comm_turn=0, trans_input=False, trans_x2=False) Tensor[源代码]
- TP 切分场景下,实现 allgather 和 matmul 的融合,融合算子内部实现通信和计算流水并行。 \[ \begin{align}\begin{aligned}output = allgather(input)@x2\\gather\_out = allgather(input)\end{aligned}\end{align} \]- 警告 - 这是一个实验性 API,后续可能修改或删除。 - 参数:
- input (Tensor) - matmul 的左矩阵,dtype 支持 float16、bfloat16,shape 支持二维,数据格式支持 ND。 
- x2 (Tensor) - matmul 的右矩阵,dtype 需要和 - input一致,shape 支持二维,数据格式支持 ND。
- group (str) - 通信组名称,可以由 - create_group方法创建,或者使用默认组- mindspore.communication.GlobalComm.WORLD_COMM_GROUP。
- world_size (int) - 通信组的总进程数,要求与实际运行的卡数一致,支持 - 2、- 4、- 8。
 
- 关键字参数:
- bias (Tensor, 可选) - 当前仅支持 - None。默认- None。
- gather_index (int, 可选) - 表示 allgather 操作对象, - 0表示对- input做 gather,- 1表示对- x2做 gather。当前仅支持- 0。默认- 0。
- gather_output (bool, 可选) - 表示是否需要 gather 输出。默认 - True。
- comm_turn (int, 可选) - 表示进程间通信切分粒度。当前仅支持 - 0。默认- 0。
- trans_input (bool, 可选) - 表示 - input是否转置。当前仅支持- False。默认- False。
- trans_x2 (bool, 可选) - 表示 - x2是否转置。默认- False。
 
- 返回:
- output (Tensor) - allgather 和 matmul 融合计算的结果。 
- gather_out (Tensor) - allgather 的结果。如果 gather_output 为 - False,gather_out 返回 shape 为 0 的 tensor。
 
 - 说明 - 使用该接口时,请确保驱动固件包和 CANN 包都为配套的 8.0.RC2 版本或者配套的更高版本,否则将会引发报错,比如 BUS ERROR 等。 
- input的 shape 为 (m, k),- x2的 shape 为 (k, n),要求 k 相等,且 k 的取值范围为 [256, 65535)。- output的 shape 为 (m * world_size, n),- gather_out的 shape 为 (m * world_size, k)。
- 一个模型中的通算融合算子仅支持相同通信组。 
 - 异常:
- TypeError - 参数的类型不对。 
- RuntimeError - - input或- x2的 dtype 不是 float16 或 bfloat16。
- RuntimeError - - input和- x2的 dtype 不一致。
- RuntimeError - - input或- x2的 shape 不是二维。
- RuntimeError - - inputshape 和- x2shape 的 k 不相等。
- RuntimeError - k 小于 - 256或大于等于- 65535。
- RuntimeError - - bias不是- None。
- RuntimeError - - group不存在。
- RuntimeError - - world_size与实际运行的卡数不一致。
- RuntimeError - - world_size不等于- 2、- 4、- 8。
- RuntimeError - - gather_index不是- 0。
- RuntimeError - - trans_input为- True。
 
- 支持平台:
- Ascend
 - 样例: - 说明 - 运行以下样例之前,需要配置好通信环境变量。 - 针对Ascend/GPU/CPU设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动 。 - 该样例需要在 2 卡环境下运行。 - >>> import mindspore as ms >>> import numpy as np >>> from mindspore import ops >>> ms.communication.init() >>> rank = ms.communication.get_rank() >>> np.random.seed(rank) >>> input = ms.Tensor(np.random.randn(128, 256).astype(np.float32), dtype=ms.float16) >>> x2 = ms.Tensor(np.random.randn(256, 512).astype(np.float32), dtype=ms.float16) >>> group = ms.communication.GlobalComm.WORLD_COMM_GROUP >>> world_size = ms.communication.get_group_size() >>> output, gather_out = ops.all_gather_matmul( ... input, ... x2, ... group, ... world_size, ... bias=None, ... gather_index=0, ... gather_output=True, ... comm_turn=0, ... trans_input=False, ... trans_x2=False, ... ) >>> print(output.shape) (256, 512) >>> print(gather_out.shape) (256, 256)