mindscience.distributed.modules.ColumnParallelLinear

class mindscience.distributed.modules.ColumnParallelLinear(in_features, out_features, bias=True, gather_output=True, use_sequence_parallel=False, weight_init=None, bias_init=None, param_init_dtype=ms.float32, compute_dtype=ms.bfloat16)[source]

Column-parallel linear layer that shards the output feature dimension across TP ranks.

Parameters
  • in_features (int) – Number of input features.

  • out_features (int) – Total number of output features across all TP ranks.

  • bias (bool, optional) – Whether to create a bias parameter, partitioned consistently with the output sharding. Default: True.

  • gather_output (bool, optional) – Whether to gather the TP-local outputs into the full output via all-gather along the output feature dimension (see the sketch after this parameter list). Default: True.

  • use_sequence_parallel (bool, optional) – Whether to use all-gather on the sequence-sharded input instead of broadcast, i.e. enable sequence parallelism. Default: False.

  • weight_init (Union[Initializer, str], optional) – Weight initialization method. Default: None.

  • bias_init (Union[Initializer, str], optional) – Bias initialization method. Default: None.

  • param_init_dtype (mstype.dtype, optional) – Parameter initialization data type. Default: ms.float32.

  • compute_dtype (mstype.dtype, optional) – Computation data type. Default: ms.bfloat16.
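
Conceptually, the layer splits the weight along the output feature dimension and, when gather_output is True, concatenates the per-rank results. Below is a minimal single-process NumPy sketch that only emulates the per-rank arithmetic; the variable names and the (out_features, in_features) weight layout are illustrative assumptions, not the layer's actual implementation.

>>> import numpy as np
>>> tp = 2                                          # tensor-parallel world size
>>> x = np.random.randn(32, 512)                    # full input, replicated on every rank
>>> w = np.random.randn(1024, 512)                  # full weight: (out_features, in_features)
>>> shards = np.split(w, tp, axis=0)                # each rank keeps out_features // tp rows
>>> local_outs = [x @ w_i.T for w_i in shards]      # per-rank output: (32, out_features // tp)
>>> full_out = np.concatenate(local_outs, axis=-1)  # gather_output=True stitches the shards
>>> bool(np.allclose(full_out, x @ w.T))            # matches the unsharded computation
True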

Inputs:
  • x (Tensor): Input tensor of shape (seq_len // TP, in_features) when use_sequence_parallel is True, or (seq_len, in_features) otherwise.

Outputs:
  • output (Tensor): Output tensor of shape (seq_len, out_features) when gather_output is True, or (seq_len, out_features // TP) otherwise; see the shape sketch below.
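
The shape rules above can be summarized with a small helper (hypothetical, not part of the API; TP denotes the tensor-parallel world size):

>>> def expected_shapes(seq_len, in_features, out_features, tp,
...                     gather_output=True, use_sequence_parallel=False):
...     # sequence parallelism feeds each rank a seq_len // tp slice of the input
...     x_shape = (seq_len // tp if use_sequence_parallel else seq_len, in_features)
...     # without gather_output each rank keeps only its out_features // tp feature shard
...     y_shape = (seq_len, out_features if gather_output else out_features // tp)
...     return x_shape, y_shape
...
>>> expected_shapes(seq_len=32, in_features=512, out_features=1024, tp=2)
((32, 512), (32, 1024))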

Examples

>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindscience.distributed import initialize_parallel
>>> from mindscience.distributed.modules import ColumnParallelLinear
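>>> # Launch this example on 2 devices/processes so that tensor_parallel_size=2 below is valid.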
>>> init()
>>> initialize_parallel(tensor_parallel_size=2)
>>> linear = ColumnParallelLinear(in_features=512, out_features=1024, bias=True)
>>> input_tensor = ms.mint.randn(32, 512)
>>> output = linear(input_tensor)
>>> print(output.shape)
(32, 1024)
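
When the output feeds a subsequent layer that consumes sharded features (a common pattern in tensor-parallel blocks), the all-gather can be skipped by keeping the output sharded. The snippet below reuses the setup from the example above and sketches that usage:

>>> # With tensor_parallel_size=2, each rank keeps its own 1024 // 2 = 512 output features.
>>> sharded_linear = ColumnParallelLinear(in_features=512, out_features=1024,
...                                       bias=True, gather_output=False)
>>> sharded_output = sharded_linear(input_tensor)
>>> print(sharded_output.shape)
(32, 512)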