mindspore.nn.DynamicLossScaleUpdateCell

查看源文件
class mindspore.nn.DynamicLossScaleUpdateCell(loss_scale_value, scale_factor, scale_window)[源代码]

用于动态更新损失缩放系数(loss scale)的神经元。

使用混合精度功能进行训练时,初始损失缩放系数值为 loss_scale_value。在每个训练步骤中,当出现溢出时,通过计算公式 loss_scale/scale_factor 减小损失缩放系数。如果连续 scale_window 步(step)未溢出,则将通过 loss_scale * scale_factor 增大损失缩放系数。

该类是 mindspore.amp.DynamicLossScaleManagerget_update_cell 方法的返回值。训练过程中,类 mindspore.nn.TrainOneStepWithLossScaleCell 会调用该Cell来更新损失缩放系数。

参数:
  • loss_scale_value (float) - 初始的损失缩放系数。

  • scale_factor (int) - 增减系数。

  • scale_window (int) - 未溢出时,增大损失缩放系数的最大连续训练步数。

输入:
  • loss_scale (Tensor) - 训练期间的损失缩放系数,是一个标量,shape为 \(()\)

  • overflow (bool) - 是否发生溢出。

输出:

Bool,即输入 overflow

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> in_features, out_features = 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
>>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
>>> output = train_network(input, labels)
get_loss_scale()[源代码]

获取当前损失缩放系数。

返回:

float,损失缩放系数。

样例:

>>> from mindspore import nn
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=212, scale_factor=2, scale_window=1000)
>>> output = manager.get_loss_scale()
>>> print(output)
212