mindspore.nn.OptTFTWrapper
- class mindspore.nn.OptTFTWrapper(opt, **kwargs)[源代码]
- 实现TFT优化器封装器。该封装器将在优化器更新前向MindIO TFT上报状态。 - 说明 - 该优化器依赖于MindIO TFT特性。当前只支持Ascend后端的图模式,并且sink_size的配置必须小于等于1。 - 参数:
- opt (Optimizer) - 该参数必须为Optimizer的子类。 
 
- 输入:
- gradients (tuple[Tensor]) - 参数opt的 params 的梯度,shape与opt的 params shape 相同。 
 
- 输出:
- Tensor,优化器opt执行返回的结果。 
- 异常:
- TypeError - 如果opt不是Optimizer的子类。 
- ValueError - 如果不是运行在Ascend后端的图模式,或者用户不开启TFT特性。 
 
- 支持平台:
- Ascend
 - 样例: - >>> import mindspore as ms >>> from mindspore import nn >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> #1) All parameters use the same learning rate and weight decay >>> optim = nn.SGD(params=net.trainable_params()) >>> optim_wrapper = nn.OptTFTWrapper(optim) >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)