mindscience.distributed.mappings.gather_from_sequence

mindscience.distributed.mappings.gather_from_sequence(x, group, tensor_parallel_output_grad=True)

Gathers the sequence partitions held by the ranks in group along the first dimension.

Parameters
  • x (Tensor) – Input tensor with sequence partitions along the first dimension.

  • group (Union[CommGroup, CommGroupBase]) – Communication group for the operation.

  • tensor_parallel_output_grad (bool, optional) – Whether the backward pass uses reduce-scatter (True) or scatter (False). Default: True.

Returns

Tensor with all sequence partitions gathered along the first dimension.
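
The sketch below illustrates the semantics described above with a single-process simulation on nested Python lists. It does not call mindscience or any communication backend. The helper names (simulate_forward, simulate_backward) are hypothetical. It also assumes that reduce-scatter sums gradients across ranks before splitting them, and that scatter means each rank keeps only its own slice of its local gradient; the actual implementation may differ.

```python
# Single-process simulation of the gather semantics; no real communication.
# Each "rank" is an entry in a Python list; tensors are lists of rows.


def simulate_forward(shards):
    """All-gather: every rank receives all shards joined along dim 0."""
    gathered = [list(row) for shard in shards for row in shard]
    # Each rank gets its own copy of the full sequence.
    return [[list(row) for row in gathered] for _ in shards]


def simulate_backward(grads, tensor_parallel_output_grad=True):
    """Map per-rank gradients of the gathered output back to per-rank shards."""
    world_size = len(grads)
    rows_per_rank = len(grads[0]) // world_size

    if tensor_parallel_output_grad:
        # Assumed reduce-scatter: sum element-wise across ranks, then split.
        total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*grads)]
        return [
            total[r * rows_per_rank:(r + 1) * rows_per_rank]
            for r in range(world_size)
        ]

    # Assumed scatter: each rank keeps its own slice of its local gradient.
    return [
        grads[r][r * rows_per_rank:(r + 1) * rows_per_rank]
        for r in range(world_size)
    ]


# Two ranks, each holding a (2, 2) partition of a (4, 2) sequence.
shards = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
outputs = simulate_forward(shards)

# Hypothetical upstream gradients: all ones on rank 0, all twos on rank 1.
grads = [[[1, 1]] * 4, [[2, 2]] * 4]
grads_reduce_scatter = simulate_backward(grads, tensor_parallel_output_grad=True)
grads_scatter = simulate_backward(grads, tensor_parallel_output_grad=False)
```

Under these assumptions, True suits the case where each rank holds a different partial gradient of the gathered tensor, so the contributions have to be summed. False suits the case where the gradients are already identical on every rank and a plain split is enough.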