[ "MindSpore Made Easy" ]
MindSpore Made Easy: Summary of Training with Mixed Precision
August 12, 2022
1. Overview
Mixed-precision training accelerates the deep neural network training process by mixing the single-precision (FP32) and half-precision (FP16) floating-point formats without compromising network accuracy. It speeds up computation, reduces memory usage and memory access, and enables a larger model or batch size to be trained on the same hardware. The mixed-precision computation process in MindSpore is as follows:
1. Parameters are stored in FP32 format.
2. During forward propagation, inputs and parameters of FP16 operators are cast from FP32 to FP16.
3. The loss layer is set to FP32 for computation.
4. During backward propagation, the loss is first multiplied by the loss scale value to avoid underflow caused by small gradients.
5. FP16 is used for gradient computation, and the results are cast back to FP32.
6. The gradients are divided by the loss scale value to restore their original magnitude.
7. The optimizer checks whether the gradients overflow. If they do, it skips the update; otherwise, it updates the original parameters in FP32.
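Steps 4 to 6 are the key to avoiding FP16 underflow. Their effect can be demonstrated with a minimal NumPy sketch (illustrative only; loss_scale and grad_fp32 are made-up example values, not MindSpore APIs):

```python
import numpy as np

loss_scale = 1024.0   # example loss scale value
grad_fp32 = 1e-8      # a tiny gradient that FP16 cannot represent

# Without scaling, casting the gradient to FP16 underflows to zero.
assert np.float16(grad_fp32) == 0.0

# Step 4: multiply by the loss scale before the FP16 computation.
scaled = np.float16(grad_fp32 * loss_scale)
assert scaled != 0.0  # the scaled value survives in FP16

# Steps 5-6: cast back to FP32 and divide by the loss scale.
restored = np.float32(scaled) / loss_scale
assert abs(restored - grad_fp32) / grad_fp32 < 0.05  # recovered within ~5%
```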
2. Application Scenarios
Mixed precision accelerates computation and reduces memory usage, so you can use it in the following scenarios:
1. Memory resources are insufficient.
2. The training speed is low.
3. Usage Rules
This blog is intended for users who:
1. Have a basic understanding of MindSpore and are about to start a MindSpore training code migration task.
2. Have completed a MindSpore training code migration task, that is, already have MindSpore training code.
4. Usage Samples
(1) MindSpore High-Level APIs with Mixed Precision
MindSpore encapsulates mixed precision in the mindspore.Model interface. The implementation procedure is the same as that of writing common training code; you only need to set the mixed-precision parameters of mindspore.Model, such as amp_level, loss_scale_manager, and keep_batchnorm_fp32.
In the high-level API code, set amp_level to "O3" in the mindspore.Model call. The network then uses mixed precision for training.
net = Model(net, loss, opt, metrics=metrics, amp_level="O3")
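A slightly fuller sketch of the same call, assuming a fixed loss scale (FixedLossScaleManager and its loss_scale and drop_overflow_update parameters are standard MindSpore APIs, while net, loss, opt, and metrics are the placeholder names from the snippet above):

```python
from mindspore import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# Use a fixed loss scale of 1024; keep updating parameters even when
# a step overflows (drop_overflow_update=False).
loss_scale_manager = FixedLossScaleManager(loss_scale=1024.0,
                                           drop_overflow_update=False)
model = Model(net, loss, opt, metrics=metrics,
              amp_level="O3",
              loss_scale_manager=loss_scale_manager,
              keep_batchnorm_fp32=True)
```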
(2) MindSpore Low-Level APIs with Mixed Precision
To use mixed precision with MindSpore low-level APIs, you only need to enable mixed-precision training in the model construction step of the low-level API code. The following compares the two construction modes.
Construct a model using the MindSpore low-level API code:
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell
from mindspore.ops import operations as P
from mindspore.ops import functional as F

class BuildTrainNetwork(nn.Cell):
    '''Build train network.'''
    def __init__(self, my_network, my_criterion, train_batch_size, class_num):
        super(BuildTrainNetwork, self).__init__()
        self.network = my_network
        self.criterion = my_criterion
        self.print = P.Print()
        # Initialize self.output
        self.output = mindspore.Parameter(
            Tensor(np.ones((train_batch_size, class_num)), mindspore.float32),
            requires_grad=False)

    def construct(self, input_data, label):
        output = self.network(input_data)
        # Get the network output and assign it to self.output
        self.output = output
        loss0 = self.criterion(output, label)
        return loss0

class TrainOneStepCellV2(TrainOneStepCell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellV2, self).__init__(network, optimizer, sens=sens)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Obtain self.output from BuildTrainNetwork
        output = self.network.output
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        # Get the gradients of the network parameters
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        # Optimize model parameters
        loss = F.depend(loss, self.optimizer(grads))
        return loss, output

model_constructed = BuildTrainNetwork(net, loss_function,
                                      TRAIN_BATCH_SIZE, CLASS_NUM)
model_constructed = TrainOneStepCellV2(model_constructed, opt)
Construct a model using the MindSpore low-level API code with mixed precision:
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.train import amp

class BuildTrainNetwork(nn.Cell):
    '''Build train network.'''
    def __init__(self, my_network, my_criterion, train_batch_size, class_num):
        super(BuildTrainNetwork, self).__init__()
        self.network = my_network
        self.criterion = my_criterion
        self.print = P.Print()
        # Initialize self.output
        self.output = mindspore.Parameter(
            Tensor(np.ones((train_batch_size, class_num)), mindspore.float32),
            requires_grad=False)

    def construct(self, input_data, label):
        output = self.network(input_data)
        # Get the network output and assign it to self.output
        self.output = output
        loss0 = self.criterion(output, label)
        return loss0

class TrainOneStepCellV2(TrainOneStepCell):
    '''Train network wrapper that also returns the network output.'''
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellV2, self).__init__(network, optimizer, sens=sens)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Obtain self.output from BuildTrainNetwork
        output = self.network.output
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        # Get the gradients of the network parameters
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        # Optimize model parameters
        loss = F.depend(loss, self.optimizer(grads))
        return loss, output

def build_train_network_step2(network, optimizer,
                              loss_fn=None, level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.
    """
    amp.validator.check_value_type('network', network, nn.Cell)
    amp.validator.check_value_type('optimizer', optimizer, nn.Optimizer)
    amp.validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"],
                        amp.Rel.IN)
    if level == "auto":
        device_target = context.get_context('device_target')
        if device_target == "GPU":
            level = "O2"
        elif device_target == "Ascend":
            level = "O3"
        else:
            raise ValueError(
                "Level `auto` is only supported when `device_target` is GPU or Ascend.")
    amp._check_kwargs(kwargs)
    config = dict(amp._config_level[level], **kwargs)
    config = amp.edict(config)
    if config.cast_model_type == mstype.float16:
        network.to_float(mstype.float16)
        if config.keep_batchnorm_fp32:
            amp._do_keep_batchnorm_fp32(network)
    if loss_fn:
        network = amp._add_loss_network(network, loss_fn,
                                        config.cast_model_type)
    if amp._get_parallel_mode() in (amp.ParallelMode.SEMI_AUTO_PARALLEL,
                                    amp.ParallelMode.AUTO_PARALLEL):
        network = amp._VirtualDatasetCell(network)
    loss_scale = 1.0
    if config.loss_scale_manager is not None:
        loss_scale_manager = config.loss_scale_manager
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        if update_cell is not None:
            # CPU does not support `TrainOneStepWithLossScaleCell`
            # because of control flow.
            if (not context.get_context("enable_ge") and
                    context.get_context("device_target") == "CPU"):
                raise ValueError("Only `loss_scale_manager=None` and "
                                 "`loss_scale_manager=FixedLossScaleManager` "
                                 "are supported in the current version. If you "
                                 "use the `O2` option, use "
                                 "`loss_scale_manager=None` or "
                                 "`FixedLossScaleManager`.")
            network = TrainOneStepCellV2(network, optimizer)
            return network
    network = TrainOneStepCellV2(network, optimizer)
    return network

model_constructed = BuildTrainNetwork(net, loss_function,
                                      TRAIN_BATCH_SIZE, CLASS_NUM)
model_constructed = build_train_network_step2(model_constructed, opt, level="O3")
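To make step 7 of the overview concrete, the overflow check that decides whether an update is skipped can be sketched in plain NumPy (illustrative only; MindSpore performs this check inside its loss-scale cells, not with these helper names):

```python
import numpy as np

def grads_overflow(grads):
    """Return True if any gradient contains inf or NaN after unscaling."""
    return any(not np.isfinite(g).all() for g in grads)

# A gradient that overflowed during the scaled FP16 backward pass:
bad_grads = [np.array([0.5, np.inf], dtype=np.float32)]
# A healthy set of gradients:
good_grads = [np.array([0.1, -0.2], dtype=np.float32)]

assert grads_overflow(bad_grads)       # the optimizer skips this update
assert not grads_overflow(good_grads)  # this update proceeds in FP32
```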
5. Performance Comparison
Compared with full-precision training, performance improves significantly after mixed precision is enabled:
Low-level APIs: 2000 imgs/sec; Low-level APIs with mixed precision: 3200 imgs/sec
High-level APIs: 2200 imgs/sec; High-level APIs with mixed precision: 3300 imgs/sec
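The throughput figures above correspond to roughly a 1.6x speedup for the low-level APIs and 1.5x for the high-level APIs:

```python
low_fp32, low_amp = 2000, 3200      # low-level APIs, imgs/sec
high_fp32, high_amp = 2200, 3300    # high-level APIs, imgs/sec

low_speedup = low_amp / low_fp32
high_speedup = high_amp / high_fp32
print(f"low-level: {low_speedup:.1f}x, high-level: {high_speedup:.1f}x")
# prints: low-level: 1.6x, high-level: 1.5x
```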