mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)[源代码]
- 返回一个经过自动混合精度处理的网络。 - 该接口会对输入网络进行自动混合精度处理,处理后的网络里的Cell和算子增加了精度转换操作,以低精度进行计算,如 - mstype.float16或- mstype.bfloat16。 Cell和算子的输入和参数被转换成低精度浮点数,计算结果被转换回全精度浮点数,即- mstype.float32。- 框架内置了一组黑名单和白名单, amp_level 决定了具体对哪些Cell和算子进行精度转换。 - 当前的内置白名单内容为: - [ - mindspore.nn.Conv1d,- mindspore.nn.Conv2d,- mindspore.nn.Conv3d,- mindspore.nn.Conv1dTranspose,- mindspore.nn.Conv2dTranspose,- mindspore.nn.Conv3dTranspose,- mindspore.nn.Dense,- mindspore.nn.LSTMCell,- mindspore.nn.RNNCell,- mindspore.nn.GRUCell,- mindspore.ops.Conv2D,- mindspore.ops.Conv3D,- mindspore.ops.Conv2DTranspose,- mindspore.ops.Conv3DTranspose,- mindspore.ops.MatMul,- mindspore.ops.BatchMatMul,- mindspore.ops.PReLU,- mindspore.ops.ReLU,- mindspore.ops.Ger]- 当前的内置黑名单内容为: - [ - mindspore.nn.BatchNorm1d,- mindspore.nn.BatchNorm2d,- mindspore.nn.BatchNorm3d,- mindspore.nn.LayerNorm]- 关于自动混合精度的详细介绍,请参考 自动混合精度 。 - 说明 - 重复调用混合精度接口,如 custom_mixed_precision 和 auto_mixed_precision ,可能导致网络层数增大,性能降低。 
- 如果使用 - mindspore.train.Model和- mindspore.amp.build_train_network()等接口来训练经 过 custom_mixed_precision 和 auto_mixed_precision 等混合精度接口转换后的网络,则需要将 amp_level 配置 为- O0以避免重复的精度转换。
 - 参数:
- network (Cell) - 定义网络结构。 
- amp_level (str) - 支持[“O0”, “O1”, “O2”, “O3”]。默认值: - "O0"。- “O0” - 不变化。 
- “O1” - 仅将白名单内的Cell和算子转换为低精度运算,其余部分保持全精度运算。 
- “O2” - 黑名单内的Cell和算子保持全精度运算,其余部分都转换为低精度运算。 
- “O3” - 将网络全部转为低精度运算。 
 
- dtype (Type) - 低精度计算时使用的数据类型,可以是 - mstype.float16或- mstype.bfloat16。默认值:- mstype.float16。
 
- 异常:
- TypeError - network 不是Cell。 
- ValueError - amp_level 不在支持范围内。 
- ValueError - dtype 既不是 - mstype.float16也不是- mstype.bfloat16。
 
 - 样例: - >>> from mindspore import amp >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/code/lenet.py >>> network = LeNet5() >>> amp_level = "O1" >>> net = amp.auto_mixed_precision(network, amp_level)