mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)
Returns a network processed with auto mixed precision.
This interface automatically applies mixed-precision processing to the input network: precision-conversion operations are inserted so that the affected cells and operators compute in lower precision, either mstype.float16 or mstype.bfloat16. Inputs and parameters of those cells and operators are cast to the lower-precision float type, and their results are cast back to full precision, i.e. mstype.float32. The framework has a built-in whitelist and blacklist, and amp_level determines which cells and operators are converted.
The current built-in whitelist contents are:

[mindspore.nn.Conv1d, mindspore.nn.Conv2d, mindspore.nn.Conv3d, mindspore.nn.Conv1dTranspose, mindspore.nn.Conv2dTranspose, mindspore.nn.Conv3dTranspose, mindspore.nn.Dense, mindspore.nn.LSTMCell, mindspore.nn.RNNCell, mindspore.nn.GRUCell, mindspore.ops.Conv2D, mindspore.ops.Conv3D, mindspore.ops.Conv2DTranspose, mindspore.ops.Conv3DTranspose, mindspore.ops.MatMul, mindspore.ops.BatchMatMul, mindspore.ops.PReLU, mindspore.ops.ReLU, mindspore.ops.Ger]

The current built-in blacklist contents are:

[mindspore.nn.BatchNorm1d, mindspore.nn.BatchNorm2d, mindspore.nn.BatchNorm3d, mindspore.nn.LayerNorm]

For details on automatic mixed precision, refer to Automatic Mixed Precision.
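The level-based selection described above can be sketched in plain Python. This is an illustrative model only, not MindSpore's implementation; WHITELIST, BLACKLIST, and should_convert are hypothetical names, and the sets shown are small subsets of the real lists:

```python
# Hypothetical sketch of how amp_level selects which cells run in lower
# precision. These names are NOT MindSpore internals; the sets below are
# small subsets of the built-in whitelist/blacklist shown above.
WHITELIST = {"Conv2d", "Dense", "MatMul"}
BLACKLIST = {"BatchNorm2d", "LayerNorm"}

def should_convert(cell_type: str, amp_level: str) -> bool:
    """Return True if a cell of this type should compute in lower precision."""
    if amp_level == "O0":
        return False                       # keep everything in float32
    if amp_level == "O1":
        return cell_type in WHITELIST      # convert whitelist cells only
    if amp_level == "O2":
        return cell_type not in BLACKLIST  # convert everything except blacklist
    if amp_level == "O3":
        return True                        # cast the whole network
    raise ValueError(f"unsupported amp_level: {amp_level}")

print(should_convert("Dense", "O1"))      # True
print(should_convert("LayerNorm", "O2"))  # False
```

Under this reading, "O1" is the conservative choice (only known-safe operators convert) while "O2" is the aggressive one (everything converts unless known to be precision-sensitive).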
Note
Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can deepen the network hierarchy and slow down performance.
If interfaces like Model and build_train_network are used to train a network that has already been converted by a mixed-precision interface such as custom_mixed_precision or auto_mixed_precision, amp_level needs to be configured to "O0" to avoid duplicated precision conversion.
- Parameters
network (Cell) – Definition of the network.
amp_level (str) – Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
"O0": Do not change.
"O1": Convert cells and operators in the whitelist to lower-precision operations, and keep full-precision operations for the rest.
"O2": Keep full-precision operations for cells and operators in the blacklist, and convert the rest to lower-precision operations.
"O3": Cast the whole network to lower precision.
dtype (Type) – The type used in lower-precision calculations; can be mstype.float16 or mstype.bfloat16. Default: mstype.float16.
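As a rough illustration of what computing in mstype.float16 means for precision, a Python float can be round-tripped through IEEE-754 half precision using the standard struct module ('e' format code). This is a sketch; to_float16 is a hypothetical helper, unrelated to MindSpore:

```python
import struct

def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE-754 half precision ('e' format)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# float16 keeps only about 3 significant decimal digits
print(to_float16(3.14159))  # 3.140625
print(to_float16(1.0))      # 1.0 -- small powers of two survive exactly
```

This precision loss is why results are cast back to mstype.float32 after lower-precision computation, and why precision-sensitive cells such as the normalization layers in the blacklist stay in full precision under "O2".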
- Raises
TypeError – If network is not a Cell.
ValueError – If dtype is neither mstype.float16 nor mstype.bfloat16.
ValueError – If amp_level is not within the supported range.
Examples
>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)