mindscience.distributed.modules.ColumnParallelLinear
- class mindscience.distributed.modules.ColumnParallelLinear(in_features, out_features, bias=True, gather_output=True, use_sequence_parallel=False, weight_init=None, bias_init=None, param_init_dtype=ms.float32, compute_dtype=ms.bfloat16)
Column-parallel linear layer that shards the output feature dimension across TP ranks.
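Conceptually, the weight matrix of shape (in_features, out_features) is split along the output (column) dimension, one shard per TP rank; each rank computes a local matmul on the full input, and when gather_output is True the per-rank results are concatenated back into the full output. The following single-process sketch uses plain NumPy (no MindSpore or communication involved) purely to illustrate that equivalence; it is not part of the library API.
>>> import numpy as np
>>> x = np.random.randn(4, 8)                          # (seq_len, in_features)
>>> w = np.random.randn(8, 6)                          # (in_features, out_features)
>>> w0, w1 = np.split(w, 2, axis=1)                    # column shards, one per TP rank
>>> y0, y1 = x @ w0, x @ w1                            # local outputs: (seq_len, out_features // TP)
>>> y = np.concatenate([y0, y1], axis=1)               # gather_output=True: concat shards
>>> np.allclose(y, x @ w)
True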
- Parameters
in_features (int) – Number of input features.
out_features (int) – Total number of output features across all TP ranks.
bias (bool, optional) – Whether to create and partition a bias parameter consistent with the output sharding. Default: True.
gather_output (bool, optional) – Whether to gather the TP-local outputs into the full output via all-gather. Default: True.
use_sequence_parallel (bool, optional) – Whether to use all-gather on the input instead of broadcast. Default: False.
weight_init (Union[Initializer, str], optional) – Weight initialization method. Default: None.
bias_init (Union[Initializer, str], optional) – Bias initialization method. Default: None.
param_init_dtype (mstype.dtype, optional) – Parameter initialization data type. Default: ms.float32.
compute_dtype (mstype.dtype, optional) – Computation data type. Default: ms.bfloat16.
- Inputs:
x (Tensor): Input tensor of shape (seq_len // TP, in_features) if use_sequence_parallel is True, otherwise (seq_len, in_features).
- Outputs:
output (Tensor): Output tensor of shape (seq_len, out_features) if gather_output is True, otherwise (seq_len, out_features // TP).
Examples
>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindscience.distributed import initialize_parallel
>>> from mindscience.distributed.modules import ColumnParallelLinear
>>> init()
>>> initialize_parallel(tensor_parallel_size=2)
>>> linear = ColumnParallelLinear(in_features=512, out_features=1024, bias=True)
>>> input_tensor = ms.mint.randn(32, 512)
>>> output = linear(input_tensor)
>>> print(output.shape)
(32, 1024)
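When gather_output is False, each rank keeps only its local shard of the output, so the last dimension is out_features // TP. A minimal sketch, assuming the same two-rank setup and the input_tensor created in the example above (shown per rank; the printed shape reflects the local shard):
>>> linear_sharded = ColumnParallelLinear(in_features=512, out_features=1024,
...                                       bias=True, gather_output=False)
>>> output_local = linear_sharded(input_tensor)
>>> print(output_local.shape)  # out_features // TP = 1024 // 2
(32, 512)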