mindspore.amp.custom_mixed_precision
- mindspore.amp.custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16)[源代码]
- 通过配置白名单或黑名单,对Cell进行自定义混合精度处理。 当提供 white_list 时,网络中包含在 white_list 里的Primitive或Cell会进行精度转换。 当提供 black_list 时,网络中不包含在 black_list 里的Cell会进行精度转换。 需要提供 white_list 和 black_list 中的一个。 - 说明 - 重复调用混合精度接口,如 custom_mixed_precision 和 auto_mixed_precision ,可能导致网络层数增大,性能降低。 
- 如果使用 - mindspore.train.Model和- mindspore.amp.build_train_network()等接口来训练经 过 custom_mixed_precision 和 auto_mixed_precision 等混合精度接口转换后的网络,则需要将 amp_level 配置 为- O0以避免重复的精度转换。
- 当使用黑名单时,Primitive类型还未支持。 
 - 参数:
- network (Cell) - 定义网络结构。 
- white_list (list[Primitive, Cell], optional) - 自定义混合精度的白名单。默认值: - None。
- black_list (list[Cell], optional) - 自定义混合精度的黑名单。默认值: - None。
- dtype (Type) - 低精度计算时使用的数据类型,可以是 - mstype.float16或- mstype.bfloat16。默认值:- mstype.float16。
 
- 返回:
- network (Cell),支持混合精度的网络。 
- 异常:
- TypeError - network 的类型不是Cell。 
- ValueError - white_list 和 black_list 都没提供。 
- ValueError - 同时提供了 white_list 和 black_list 。 
- ValueError - dtype 既不是 - mstype.float16也不是- mstype.bfloat16。
 
 - 样例: - >>> from mindspore import amp, nn >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> custom_white_list = amp.get_white_list() >>> custom_white_list.append(nn.Flatten) >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)