Optimizer Parallel


Overview

In data parallel training, the parameter update part of the model is computed redundantly on every card. Optimizer parallelism spreads the optimizer computation across the cards of the data parallel dimension, which can effectively reduce memory consumption and improve performance on large-scale networks such as BERT and GPT.

When optimizer parallelism is enabled in data parallel mode, the framework distributes the parameters to be updated across different cards and, after each update, shares the weights across the cluster with a Broadcast operator. Note that the number of parameters must be greater than the number of machines, and currently only the Lamb and AdamWeightDecay optimizers are supported.

When optimizer parallelism is enabled in auto_parallel or semi_auto_parallel mode, the framework checks each parameter after its slicing strategy has been applied: if the resulting slice is still duplicated across machines and the highest (first) dimension of its shape is divisible by the number of duplicates, the framework stores the parameter as its minimal slice and updates that slice in the optimizer. All optimizers are supported in this mode.
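The rule above can be sketched in plain Python. This is a simplified model with a hypothetical function name, not MindSpore's implementation: it assumes the operator-level strategy gives one slice count per dimension and that the devices not used by that strategy hold duplicate copies of the slice.

```python
from math import prod


def optimizer_shard_shape(param_shape, strategy, device_num):
    """Hypothetical helper (not a MindSpore API): per-card shape the optimizer updates.

    param_shape: full shape of the parameter, e.g. (1024, 512)
    strategy: operator-level slice count per dimension, e.g. (1, 2)
    device_num: total number of devices
    """
    # Shape of the slice each card holds after operator-level slicing.
    sliced = [dim // cuts for dim, cuts in zip(param_shape, strategy)]
    # Devices not consumed by the strategy hold duplicates of that slice.
    duplicates = device_num // prod(strategy)
    # Optimizer slicing applies only if duplicates exist and the highest
    # (first) dimension divides evenly by the number of duplicates.
    if duplicates > 1 and sliced[0] % duplicates == 0:
        sliced[0] //= duplicates
    return tuple(sliced)


# Pure data parallelism on 8 devices: the first dimension is split 8 ways.
print(optimizer_shard_shape((1024, 512), (1, 1), 8))  # (128, 512)
# Strategy (1, 2) leaves 4 duplicates, so 1024 // 4 = 256.
print(optimizer_shard_shape((1024, 512), (1, 2), 8))  # (256, 256)
# First dimension 6 is not divisible by 4 duplicates: no optimizer slicing.
print(optimizer_shard_shape((6, 512), (1, 1), 4))     # (6, 512)
```

In the second case, operator-level slicing uses only 2 of the 8 devices, so the remaining factor of 4 is available to the optimizer.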

| Parallel mode | Parameter update mode | Optimizer support | Backend support |
|---|---|---|---|
| Data parallel | Parameter groups are updated on different cards, then broadcast to all cards | Lamb, AdamWeightDecay and AdaFactor | Ascend |
| Full/semi-automatic parallel | Parameters are sliced into N copies along the data parallel dimension, and each card updates the slice it holds | All optimizers | Ascend, GPU |

In either mode, optimizer parallelism does not change the computation graph of the original forward and backward network; it affects only the amount of computation and the computation logic of the parameter updates.

Basic Principles

The traditional data parallel model keeps a copy of the model parameters on each device, slices the training data, synchronizes the gradients with communication operators after each iteration, and finally updates the parameters through the optimizer. Although data parallelism effectively improves training throughput, it does not make full use of machine resources: the optimizer introduces redundant memory and computation, and eliminating this redundancy is an important optimization target.

In each training iteration, data parallelism introduces a communication operation that synchronizes gradients across cards, collecting the parameter gradients produced by the different samples on each card. Because model parallelism is not involved, the optimizer on every card performs the same update on the same parameters with the same gradients. The fundamental idea of eliminating optimizer redundancy is to spread this memory and computation across the cards to gain memory and performance.

There are two ways to parallelize the optimizer: weights grouping and weights sharding. Weights grouping divides the parameters and gradients within the optimizer between layers; the general training flow is shown in Figure 1. The parameters and gradients are grouped onto different cards for updating, and the updated weights are then shared among devices through a broadcast operation. The memory and performance gains of this scheme depend on the group holding the largest share of parameters. When the parameters are divided evenly, the theoretical gains are a (N-1)/N reduction in optimizer runtime and dynamic memory, and a (N-1)/N reduction in the memory occupied by optimizer state parameters, where N is the number of devices. The cost is the communication time needed to share the network weights.

Figure 1: Schematic diagram of the parameter grouping training process
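Weights grouping can be viewed as a partitioning problem. A minimal sketch, assuming a greedy largest-first assignment of whole parameters to cards (an illustrative heuristic, not necessarily the one MindSpore uses) and hypothetical parameter sizes: since each card updates only its own group, the busiest card determines the gain.

```python
def group_parameters(param_sizes, n):
    """Hypothetical helper: greedily assign whole parameters to n groups, largest first."""
    groups = [[] for _ in range(n)]
    loads = [0] * n
    for name, size in sorted(param_sizes.items(), key=lambda kv: kv[1], reverse=True):
        i = loads.index(min(loads))  # card with the lightest load so far
        groups[i].append(name)
        loads[i] += size
    return groups, loads


# Hypothetical parameter sizes (arbitrary units).
sizes = {"embedding": 400, "dense1": 300, "dense2": 200, "dense3": 100}
groups, loads = group_parameters(sizes, 3)
total = sum(sizes.values())
# Fraction of optimizer work saved on the busiest card: 1 - 400/1000, i.e. 0.6.
saved = 1 - max(loads) / total
print(groups)  # [['embedding'], ['dense1'], ['dense2', 'dense3']]
print(saved)
```

With 3 cards the ideal saving is 2/3, but the largest group holds 400 of 1000 units, so the busiest card saves only 0.6. Weights sharding, described next, avoids this imbalance by slicing within each parameter.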

The other approach, weights sharding, divides parameters within a layer: each device takes the slice of every parameter and gradient that corresponds to its device number and updates it, and a communication aggregation operation is then called to share the parameters among devices. This scheme naturally balances load, since every card holds the same number of parameters and performs the same amount of computation; its drawback is that the parameter shape must be divisible by the number of devices. Its theoretical gains match those of weights grouping, and the framework adds the following improvements to extend its advantages.
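The intra-layer scheme can be simulated in a single process with NumPy. This is a sketch under stated assumptions: plain SGD stands in for the real optimizer, the N cards are modelled as array slices, `np.concatenate` plays the role of the aggregation (AllGather) that shares updated slices, and `sharded_sgd_step` is a hypothetical name. It checks that updating slices and then gathering them matches a full update, and shows where the divisibility requirement comes from.

```python
import numpy as np


def sharded_sgd_step(weight, grad, lr, n_devices):
    """Hypothetical helper: simulate one optimizer-sharded SGD step across n_devices."""
    # Intra-layer slicing requires the first dimension to divide evenly.
    if weight.shape[0] % n_devices != 0:
        raise ValueError("first dimension must be divisible by the device count")
    w_slices = np.split(weight, n_devices, axis=0)
    g_slices = np.split(grad, n_devices, axis=0)
    # Each simulated card updates only the slice matching its rank.
    updated = [w - lr * g for w, g in zip(w_slices, g_slices)]
    # Simulated AllGather: every card reassembles the full updated parameter.
    return np.concatenate(updated, axis=0)


rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4))
grad = rng.standard_normal((8, 4))
sharded = sharded_sgd_step(weight, grad, lr=0.1, n_devices=4)
full = weight - 0.1 * grad
print(np.allclose(sharded, full))  # True
```

Each simulated card touches only a quarter of the weight and gradient, which is the source of the per-card memory and computation savings.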

First, slicing the weights of the network itself further reduces static memory. However, the shared-weight operation must then be performed after the current iteration ends and before the forward pass of the next iteration begins, so that the forward and backward operators still receive tensors in their original shapes. In addition, the main cost of optimizer parallelism is the communication time for sharing weights, so reducing or hiding it yields a performance gain. Executing this communication across iterations has an advantage: by fusing the communication operators into suitable groups, they can be interleaved with the forward network, hiding as much of the communication time as possible. Communication time also depends on communication volume. For networks that use mixed precision, communicating in fp16 halves the volume compared with fp32. Combining these characteristics, the implementation scheme of parameter slicing is shown in Figure 2.

Figure 2: Schematic diagram of the parameter slicing training process
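The halving of communication volume under fp16 follows from element sizes alone. A small arithmetic sketch with a hypothetical helper and an illustrative 1M-element weight: to rebuild a parameter sliced N ways, each card must receive the N-1 slices it does not already hold.

```python
import numpy as np


def allgather_bytes_received(num_elements, n_devices, dtype):
    """Hypothetical helper: bytes each card receives to rebuild a parameter sliced n_devices ways."""
    shard_elements = num_elements // n_devices
    # A card already holds its own slice and needs the other n_devices - 1.
    return (n_devices - 1) * shard_elements * np.dtype(dtype).itemsize


n = 1024 * 1024  # a 1M-element weight, chosen only for illustration
fp32 = allgather_bytes_received(n, 8, np.float32)
fp16 = allgather_bytes_received(n, 8, np.float16)
print(fp32, fp16)  # 3670016 1835008
```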

In tests on real network training, we found that the memory gain from parameter slicing is significant. Large-scale models are usually trained with the popular Adaptive Moment Estimation (Adam) or Layer-wise Adaptive Moments optimizer for Batch training (LAMB) optimizers, whose own parameter count and computation cannot be neglected. After parameter slicing, the weight parameters of the network and the two copies of optimizer state parameters are each reduced by (N-1)/N, that is, to 1/N of their original size, which greatly saves static memory. This makes it possible to increase the number of samples per iteration and improve overall training throughput, effectively relieving the memory pressure of large-scale network training.
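The static-memory saving can be checked with simple arithmetic. This sketch assumes fp32 storage (4 bytes per value), an Adam-style optimizer with two state copies, and a hypothetical 1B-parameter model; it ignores gradients and activations, and `static_memory_bytes` is an illustrative name.

```python
def static_memory_bytes(num_params, n_devices=1, bytes_per_value=4, state_copies=2):
    """Hypothetical helper: per-card bytes for weights plus optimizer states.

    With optimizer parameter slicing, each card keeps 1/n_devices of the
    weights and of every optimizer state copy (e.g. Adam's two moments).
    """
    total = num_params * bytes_per_value * (1 + state_copies)
    return total // n_devices


params = 1_000_000_000                            # hypothetical 1B-parameter model
before = static_memory_bytes(params)              # 12_000_000_000 bytes without slicing
after = static_memory_bytes(params, n_devices=8)  # 1_500_000_000 bytes on each of 8 cards
print(before, after, 1 - after / before)          # 12000000000 1500000000 0.875
```

The saved fraction, 0.875, is exactly (N-1)/N for N = 8.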

Optimizer parameter slicing in MindSpore can also be combined with operator-level parallelism. When operator-level model parallelism splits a parameter into fewer parts than there are devices, copies of each slice remain duplicated along the data parallel dimension, and the optimizer can slice the parameter further along that dimension, increasing machine resource utilization and thus improving end-to-end performance.

Operation Practice

Sample Code Description

The directory structure is as follows:

└─sample_code
    ├─distributed_optimizer_parallel
        ├── fusion_example.py
        ├── rank_table_2pcs.json
        ├── rank_table_8pcs.json
        └── run_fusion_example.sh

The role of each file is as follows:

  • fusion_example.py: Sample code for optimizer communication fusion, illustrating how to configure the fusion flag for the optimizer.

  • rank_table_2pcs.json: 2-card configuration file of RANK_TABLE_FILE.

  • rank_table_8pcs.json: 8-card configuration file of RANK_TABLE_FILE.

  • run_fusion_example.sh: Startup script for optimizer fusion code.

Turning on Optimizer Parallel

The enable_parallel_optimizer option is provided in mindspore.set_auto_parallel_context. Configure it to True to enable optimizer parallelism. By default, optimizer slicing is performed for all parameters that take up no less than 64KB of memory.

import mindspore as ms
ms.set_auto_parallel_context(enable_parallel_optimizer=True)

Configuring Parameter Optimizer Parallel

In addition, users can decide for individual parameters whether optimizer slicing is applied. Parameter provides a parallel_optimizer attribute that controls whether the current parameter is sliced by the optimizer, so optimizer parallelism can be turned on or off for each parameter individually, as follows:

import numpy as np
import mindspore as ms
param = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight1', parallel_optimizer=True)

# Another way to set the parallel_optimizer attribute
param2 = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight2')
param2.parallel_optimizer = False

The optimizer parallel feature also provides the configuration dictionary parallel_optimizer_config. Setting different keys of this dictionary in the context achieves different effects:

  • gradient_accumulation_shard(bool): If True, the accumulated gradient variables are sliced along the data parallel dimension. During gradient accumulation, an additional communication (ReduceScatter) is introduced in each accumulation iteration to keep the computation consistent, but a large amount of device memory (e.g. GPU memory) is saved, allowing the model to be trained with larger batches. This configuration takes effect only when the model uses pipeline parallel training or gradient accumulation and has a data parallel dimension. The default value is True.

    import mindspore as ms
    ms.set_auto_parallel_context(parallel_optimizer_config={"gradient_accumulation_shard": True}, enable_parallel_optimizer=True)
    
  • parallel_optimizer_threshold(int): The minimum memory size, in KB, that a parameter must occupy to be sliced. Parameters smaller than this value are not sliced.

    import numpy as np
    import mindspore as ms
    param = ms.Parameter(ms.Tensor(np.ones((10, 2)), dtype=ms.float32), name='weight1')
    # The float32 type occupies 4 bytes per element:
    # param_size = np.prod(list(param.shape)) * 4 = (10 * 2) * 4 = 80B < 24KB, so the parameter is not sliced
    ms.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 24})
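The size check behind parallel_optimizer_threshold can be sketched as follows. It assumes, as the comment in the example above does, that a parameter's size is its element count times the bytes per element, compared against the threshold in KB (64 by default). It also assumes 1 KB = 1024 bytes. The function `would_slice` is hypothetical and only mirrors that documented rule, not MindSpore's internal code; it ignores the separate divisibility requirement on the sliced dimension.

```python
from math import prod

import numpy as np


def would_slice(shape, dtype, threshold_kb=64):
    """Hypothetical helper: True if a parameter is large enough to be optimizer-sliced."""
    size_bytes = prod(shape) * np.dtype(dtype).itemsize
    # Assumes 1 KB = 1024 bytes; parameters below the threshold stay unsliced.
    return size_bytes >= threshold_kb * 1024


print(would_slice((10, 2), np.float32, threshold_kb=24))  # False: 80 B
print(would_slice((128, 128), np.float32))                # True: exactly 64 KB
print(would_slice((128, 128), np.float16))                # False: 32 KB
```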
    

Configuring Communication Fusion

The previous section describes how to configure the optimizer parallel property for each parameter. In full/semi-automatic mode, each parameter generates a corresponding AllGather operation and ReduceScatter operation, which the auto-parallel framework inserts automatically. As the number of parameters grows, so does the number of communication operators, and the extra operators add scheduling and launch overhead. Therefore, fusion tags can be configured for the AllGather and ReduceScatter operations of the parameters within each cell through the set_comm_fusion method provided by Cell.

As shown in the following code, the set_comm_fusion method is called for the instantiated DenseLayer to set the fusion value for each layer.

"""Parallel Optimizer Fusion Example"""
from mindspore.communication import init
from mindspore import nn
import mindspore as ms
init()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, enable_parallel_optimizer=True)

class DenseLayer(nn.Cell):
    """A base layer with two dense layers"""
    def __init__(self):
        super().__init__()
        self.input_mapping = nn.Dense(10, 10)
        self.output_mapping = nn.Dense(10, 10)
    def construct(self, x):
        x = self.input_mapping(x)
        return self.output_mapping(x)

class Net(nn.Cell):
    """A network with many dense layers"""
    def __init__(self):
        super().__init__()
        self.layer1 = DenseLayer()
        self.layer2 = DenseLayer()
        self.layer3 = DenseLayer()
        self.layer1.set_comm_fusion(0)
        self.layer2.set_comm_fusion(1)
        self.layer3.set_comm_fusion(2)
    def construct(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

net = Net()
for item in net.trainable_params():
    print(f"The parameter {item.name}'s fusion id is {item.comm_fusion}")

The corresponding output is as follows, showing the fusion value of each parameter in each dense layer:

The parameter layer1.input_mapping.weight's fusion id is 0
The parameter layer1.input_mapping.bias's fusion id is 0
The parameter layer1.output_mapping.weight's fusion id is 0
The parameter layer1.output_mapping.bias's fusion id is 0
The parameter layer2.input_mapping.weight's fusion id is 1
The parameter layer2.input_mapping.bias's fusion id is 1
The parameter layer2.output_mapping.weight's fusion id is 1
The parameter layer2.output_mapping.bias's fusion id is 1
The parameter layer3.input_mapping.weight's fusion id is 2
The parameter layer3.input_mapping.bias's fusion id is 2
The parameter layer3.output_mapping.weight's fusion id is 2
The parameter layer3.output_mapping.bias's fusion id is 2

During graph compilation, communication operations of the same type that carry the same fusion tag are fused into a single communication operation, which reduces the number of communication operations. Communication operators whose fusion tag is 0 are not fused during this optimization.
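The fusion rule can be modelled in a few lines of plain Python. This sketch assumes only what the paragraph states (operators of one type with the same non-zero fusion tag merge into one, and tag 0 leaves each operator separate) plus one AllGather per parameter; the parameter names and tags mirror the output above, and `count_fused_ops` is hypothetical.

```python
from collections import defaultdict


def count_fused_ops(fusion_ids):
    """Hypothetical helper: communication operators of one type left after fusion."""
    groups = defaultdict(list)
    unfused = 0
    for name, fusion_id in fusion_ids.items():
        if fusion_id == 0:
            unfused += 1  # tag 0: this operator is never fused
        else:
            groups[fusion_id].append(name)
    # Each non-zero tag collapses into a single fused operator.
    return unfused + len(groups)


# Fusion tags from the example output: 4 parameters per layer, tags 0, 1, 2.
fusion_ids = {
    f"layer{layer}.{dense}.{kind}": layer - 1
    for layer in (1, 2, 3)
    for dense in ("input_mapping", "output_mapping")
    for kind in ("weight", "bias")
}
print(len(fusion_ids), count_fused_ops(fusion_ids))  # 12 6
```

Under these assumptions, the 12 per-parameter AllGather operations shrink to 6: the four layer1 operators stay separate, while layer2 and layer3 each collapse into one.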

When optimizer slicing is turned on, a corresponding communication operator is generated for each parameter in the network, and calling communication operators this frequently increases operator launch overhead. The most direct way to reduce the number of communication operators is to fuse them all into a single one, but this wastes computational resources: with a single fused operator, the NPU can start the forward computation of the current iteration only after the aggregation of all sliced parameters has finished, so the device sits idle while it waits.

To avoid this, the network parameters can be fused in groups: the communication for the next group of parameters runs at the same time as the computation that uses the previous group, so that communication and computation overlap and hide each other's cost. This is why the above code sets different fusion values for layer2 and layer3.
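The benefit of grouping can be shown with a toy timeline. All figures and names here are hypothetical assumptions: each layer's AllGather takes `comm` time units and its forward computation takes `comp` units, communications run back to back on one channel, and a layer's forward pass starts only once its own parameters are gathered and the previous layer has finished.

```python
def forward_finish_time(comm, comp, num_layers, fused_all):
    """Hypothetical toy model: when the forward pass finishes under two fusion choices."""
    if fused_all:
        # One fused AllGather: all computation waits for every parameter.
        return num_layers * comm + num_layers * comp
    finish = 0.0
    for i in range(num_layers):
        gathered = (i + 1) * comm  # group i arrives after i + 1 back-to-back transfers
        # Start once both the parameters and the previous layer's output are ready.
        finish = max(finish, gathered) + comp
    return finish


print(forward_finish_time(2.0, 3.0, 3, fused_all=True))   # 15.0
print(forward_finish_time(2.0, 3.0, 3, fused_all=False))  # 11.0
```

The model ignores the per-operator launch overhead that motivates fusion in the first place, so in practice the best grouping balances fewer operators against more overlap.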

Running the Code

The above code requires distributed environment variables to be configured before it can run. On Ascend, configure RANK_TABLE_FILE, RANK_ID and DEVICE_ID; for the configuration process, refer to here. On GPU, configure OpenMPI, NCCL and HOST_FILE; for the configuration process, refer to here.

The environment variables related to Ascend distributed training are:

  • RANK_TABLE_FILE: Path to the network information file. The rank_table_file file can be generated by using hccl_tools.py in the models code repository, which can be obtained from here.

  • DEVICE_ID: The actual serial number of the current card on the machine.

  • RANK_ID: The logical serial number of the current card.

The environment variable related to GPU distributed training is:

  • HOST_FILE: Describes the IP and number of devices for multi-card training. Each line of the file has the format [hostname] slots=[slotnum], where hostname can be an IP address or a hostname. Note that the username must be the same on all machines, while hostnames must be distinct.
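A minimal sketch of reading a HOST_FILE in the format described above, useful for sanity-checking the total device count before launch. The host addresses are hypothetical, `parse_hostfile` is an illustrative name, and the parser handles only the `[hostname] slots=[slotnum]` form stated here.

```python
def parse_hostfile(text):
    """Hypothetical helper: parse '[hostname] slots=[slotnum]' lines into {hostname: slots}."""
    hosts = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        hostname, slots_field = line.split()
        key, value = slots_field.split("=")
        if key != "slots":
            raise ValueError(f"unexpected field: {slots_field}")
        if hostname in hosts:
            raise ValueError(f"duplicate hostname: {hostname}")  # hostnames must differ
        hosts[hostname] = int(value)
    return hosts


hostfile = """
192.168.0.1 slots=8
192.168.0.2 slots=8
"""
hosts = parse_hostfile(hostfile)
print(hosts, sum(hosts.values()))  # {'192.168.0.1': 8, '192.168.0.2': 8} 16
```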

The scripts used in this document can be obtained here. Run the following bash script to launch the program; the log is written to the device0/train.log0 file.

#!/bin/bash
set -e
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_fusion_example.sh RANK_SIZE"
echo "For example: bash run_fusion_example.sh 8"
echo "It is better to use the absolute path."
echo "This example is expected to run on the Ascend environment."
echo "=============================================================================================================="
RANK_SIZE=$1

EXEC_PATH=$(pwd)

test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs

for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp ./fusion_example.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    pytest -s -v ./fusion_example.py > train.log$i 2>&1 &
    cd ../
done
echo "The program launched successfully; the log is under device0/train.log0."

After placing RANK_TABLE_FILE in the current directory, run the following command, which requires 8 Ascend 910 devices:

bash run_fusion_example.sh 8