mindspore.amp.auto_mixed_precision

mindspore.amp.auto_mixed_precision(network, amp_level='O0')[源代码]

对Cell进行自动混合精度处理。

参数:
  • network (Cell) - 定义网络结构。

  • amp_level (str) - 支持[“O0”, “O1”, “O2”, “O3”]。默认值:”O0”。

    • “O0” - 不变化。

    • “O1” - 将白名单内的Cell和运算转为float16精度,其余部分保持float32精度。

    • “O2” - 将黑名单内的Cell和运算保持float32精度,其余部分转为float16精度。

    • “O3” - 将网络全部转为float16精度。

异常:
  • ValueError - amp_level 不在支持范围内。

样例:

>>> from mindspore import amp, nn
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)