mindscience.distributed.modules.RowParallelLinear
- class mindscience.distributed.modules.RowParallelLinear(in_features, out_features, bias=True, input_is_parallel=False, use_sequence_parallel=False, weight_init=None, bias_init=None, param_init_dtype=ms.float32, compute_dtype=ms.bfloat16)[source]
Row-parallel linear layer that shards the input feature dimension across TP ranks.
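As a rough illustration of the arithmetic (a minimal NumPy sketch, not the actual implementation; the variable names and sizes are illustrative assumptions): each TP rank holds a row slice of the weight and the matching column slice of the input, and the per-rank partial products sum to the full matmul. The all-reduce (or reduce-scatter) step performs this sum across ranks.
>>> import numpy as np
>>> tp_size = 2
>>> x = np.random.randn(32, 1024)            # full input: (seq_len, in_features)
>>> w = np.random.randn(1024, 512)           # full weight: (in_features, out_features)
>>> x_shards = np.split(x, tp_size, axis=1)  # per-rank input: (seq_len, in_features // tp_size)
>>> w_shards = np.split(w, tp_size, axis=0)  # per-rank weight: (in_features // tp_size, out_features)
>>> partial = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
>>> np.allclose(sum(partial), x @ w)         # summing the partials (the all-reduce) recovers the full matmul
True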
- Parameters
in_features (int) – Total number of input features across all TP ranks.
out_features (int) – Number of output features.
bias (bool, optional) – Whether to create a bias parameter (the bias is not sharded along the input dimension). Default: True.
input_is_parallel (bool, optional) – Whether the input is already partitioned across TP ranks. Default: False.
use_sequence_parallel (bool, optional) – Whether to use reduce-scatter for the output instead of all-reduce. Default: False.
weight_init (Union[Initializer, str], optional) – Weight initialization method. Default: None.
bias_init (Union[Initializer, str], optional) – Bias initialization method. Default: None.
param_init_dtype (mstype.dtype, optional) – Parameter initialization data type. Default: ms.float32.
compute_dtype (mstype.dtype, optional) – Computation data type. Default: ms.bfloat16.
- Inputs:
x (Tensor): Input tensor of shape (seq_len, in_features // TP) when input_is_parallel is True, or (seq_len, in_features) otherwise.
- Outputs:
output (Tensor): Output tensor of shape (seq_len // TP, out_features) when use_sequence_parallel is True, or (seq_len, out_features) otherwise. See the shape sketch below.
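For concreteness, the shape bookkeeping with illustrative values (tp_size = 2, seq_len = 32, in_features = 1024, out_features = 512; these numbers are assumptions, not defaults):
>>> tp_size, seq_len, in_features, out_features = 2, 32, 1024, 512
>>> (seq_len, in_features // tp_size)   # input when input_is_parallel=True
(32, 512)
>>> (seq_len, in_features)              # input when input_is_parallel=False
(32, 1024)
>>> (seq_len // tp_size, out_features)  # output when use_sequence_parallel=True
(16, 512)
>>> (seq_len, out_features)             # output when use_sequence_parallel=False
(32, 512)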
Examples
>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindscience.distributed import initialize_parallel
>>> from mindscience.distributed.modules import RowParallelLinear
>>> init()
>>> initialize_parallel(tensor_parallel_size=2)
>>> linear = RowParallelLinear(in_features=1024, out_features=512, bias=True)
>>> input_tensor = ms.mint.randn(32, 1024)
>>> output = linear(input_tensor)
>>> print(output.shape)
(32, 512)
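A sketch of the sequence-parallel variant under the same two-rank setup, assuming the behavior described above (a pre-sharded input and a reduce-scattered output); the shapes shown follow from the Inputs/Outputs section rather than from running this snippet:
>>> linear_sp = RowParallelLinear(in_features=1024, out_features=512,
...                               input_is_parallel=True, use_sequence_parallel=True)
>>> local_input = ms.mint.randn(32, 512)   # (seq_len, in_features // TP) on this rank
>>> output = linear_sp(local_input)
>>> print(output.shape)
(16, 512)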