mindscience.distributed.modules.RowParallelLinear

class mindscience.distributed.modules.RowParallelLinear(in_features, out_features, bias=True, input_is_parallel=False, use_sequence_parallel=False, weight_init=None, bias_init=None, param_init_dtype=ms.float32, compute_dtype=ms.bfloat16)[source]

Row-parallel linear layer that shards the input feature dimension across tensor-parallel (TP) ranks.

Parameters
  • in_features (int) – Total number of input features across all TP ranks.

  • out_features (int) – Number of output features.

  • bias (bool, optional) – Whether to create a bias parameter. The bias is not sharded along the input dimension. Default: True.

  • input_is_parallel (bool, optional) – Whether the input is already partitioned across TP ranks. Default: False.

  • use_sequence_parallel (bool, optional) – Whether to apply a reduce-scatter to the output instead of an all-reduce (sequence parallelism). Default: False.

  • weight_init (Union[Initializer, str], optional) – Weight initialization method. Default: None.

  • bias_init (Union[Initializer, str], optional) – Bias initialization method. Default: None.

  • param_init_dtype (mstype.dtype, optional) – Parameter initialization data type. Default: ms.float32.

  • compute_dtype (mstype.dtype, optional) – Computation data type. Default: ms.bfloat16.

Inputs:
  • x (Tensor): Input tensor of shape (seq_len, in_features // TP) if input_is_parallel is True, otherwise (seq_len, in_features).

Outputs:
  • output (Tensor): Output tensor of shape (seq_len // TP, out_features) if use_sequence_parallel is True, otherwise (seq_len, out_features).
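
Conceptually, row parallelism splits the weight matrix along its input (row) dimension: each TP rank multiplies its input shard by its weight shard, and the partial products are summed across ranks (an all-reduce, or a reduce-scatter when use_sequence_parallel is True). The following single-process NumPy sketch is illustrative only (it is not part of this API, and the shapes are arbitrary); it shows why the sharded computation matches the full matmul:

>>> import numpy as np
>>> tp = 2
>>> x = np.random.randn(32, 1024).astype(np.float32)   # full input
>>> w = np.random.randn(1024, 512).astype(np.float32)  # full weight
>>> x_shards = np.split(x, tp, axis=1)                  # per-rank input shards, (32, 512) each
>>> w_shards = np.split(w, tp, axis=0)                  # per-rank weight shards, (512, 512) each
>>> partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]  # per-rank partial outputs
>>> y = sum(partials)                                   # emulates the all-reduce across ranks
>>> print(np.allclose(y, x @ w, atol=1e-3))
True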

Examples

>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindscience.distributed import initialize_parallel
>>> from mindscience.distributed.modules import RowParallelLinear
>>> init()
>>> initialize_parallel(tensor_parallel_size=2)
>>> linear = RowParallelLinear(in_features=1024, out_features=512, bias=True)
>>> from mindspore import mint
>>> input_tensor = mint.randn(32, 1024)
>>> output = linear(input_tensor)
>>> print(output.shape)
(32, 512)
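
A second sketch, assuming the same 2-way TP group initialized above and using only the parameters documented here (the initializer string and shapes are illustrative): with input_is_parallel=True each rank passes its local (seq_len, in_features // TP) shard, and with use_sequence_parallel=True the output is reduce-scattered to (seq_len // TP, out_features).

>>> linear_sp = RowParallelLinear(in_features=1024, out_features=512, bias=False,
...                               input_is_parallel=True, use_sequence_parallel=True,
...                               weight_init='normal', param_init_dtype=ms.float32,
...                               compute_dtype=ms.bfloat16)
>>> local_input = mint.randn(32, 512)   # in_features // TP = 512 on each of the 2 ranks
>>> output_sp = linear_sp(local_input)
>>> print(output_sp.shape)
(16, 512)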