mindscience.distributed.modules.ColumnParallelLinear

class mindscience.distributed.modules.ColumnParallelLinear(in_features, out_features, bias=True, gather_output=True, use_sequence_parallel=False, weight_init=None, bias_init=None, param_init_dtype=ms.float32, compute_dtype=ms.bfloat16)[源代码]

列并行线性层,将输出特征维度在TP通信组中进行分片。

参数:
  • in_features (int) - 输入特征数量。

  • out_features (int) - TP通信组内所有卡的总输出特征数量。

  • bias (bool, 可选) - 是否创建并分割与输出分片一致的偏置参数。默认值:True

  • gather_output (bool, 可选) - 是否将TP本地输出收集到完整输出中。默认值:True

  • use_sequence_parallel (bool, 可选) - 是否对输入使用all gather而不是广播。默认值:False

  • weight_init (Union[Initializer, str], 可选) - 权重初始化方法。默认值:None

  • bias_init (Union[Initializer, str], 可选) - 偏置初始化方法。默认值:None

  • param_init_dtype (mstype.dtype, 可选) - 参数初始化数据类型。默认值:ms.float32

  • compute_dtype (mstype.dtype, 可选) - 计算数据类型。默认值:ms.bfloat16

输入:
  • x (Tensor) - 形状为 (seq_len // TP, in_features) 或 (seq_len, in_features) 的输入张量,具体取决于 use_sequence_parallel 是否为 True

输出:
  • output (Tensor) - 形状为 (seq_len, out_features) 或 (seq_len, out_features // TP) 的输出张量,具体取决于 gather_output 是否为 True

样例:

>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindscience.distributed import initialize_parallel
>>> from mindscience.distributed.modules import ColumnParallelLinear
>>> init()
>>> initialize_parallel(tensor_parallel_size=2)
>>> linear = ColumnParallelLinear(in_features=512, out_features=1024, bias=True)
>>> input_tensor = ms.mint.randn(32, 512)
>>> output = linear(input_tensor)
>>> print(output.shape)
(32, 1024)