mindscience.distributed.mappings.gather_from_sequence
- mindscience.distributed.mappings.gather_from_sequence(x, group, tensor_parallel_output_grad=True)
Gathers sequence partitions along the first dimension.
- Parameters
x (Tensor) – Input tensor with sequence partitions along the first dimension.
group (Union[CommGroup, CommGroupBase]) – Communication group for the operation.
tensor_parallel_output_grad (bool, optional) – Selects whether the backward pass uses reduce-scatter (True) or scatter (False). Default: True.
- Returns
Tensor with all sequence partitions gathered along the first dimension.