mindspore.nn.TrainOneStepWithLossScaleCell

class mindspore.nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense)

Network training with loss scaling.

This is a training step with loss scaling. It takes a network, an optimizer and a scale update Cell (or a Tensor) as arguments. The loss scale value can be updated on either the host side or the device side. To update it on the host side, pass a Tensor as scale_sense; otherwise, pass a Cell instance that updates the loss scale as scale_sense.
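As a rough illustration of what one loss-scaled step does, the pure-Python sketch below multiplies the gradients by the loss scale (as scaling the loss would), treats any non-finite gradient as an overflow and skips the update in that case, and otherwise divides the gradients by the same scale before a plain SGD update. It does not use MindSpore: `FP16_MAX`, `to_fp16_range` and `loss_scaled_sgd_step` are hypothetical names, float16 overflow is only simulated by mapping out-of-range values to infinity, and SGD stands in for whatever optimizer is passed to TrainOneStepWithLossScaleCell.

```python
import math

# Largest finite float16 value; used here only to *simulate* half-precision overflow.
FP16_MAX = 65504.0


def to_fp16_range(x):
    """Hypothetical helper: map values outside the float16 range to +/-inf."""
    return x if abs(x) <= FP16_MAX else math.copysign(math.inf, x)


def loss_scaled_sgd_step(params, grads, loss_scale, lr=0.1):
    """Hypothetical sketch of one loss-scaled SGD step.

    `grads` are the unscaled gradients of the loss; scaling the loss by
    `loss_scale` scales every gradient by the same factor.
    Returns (new_params, overflow).
    """
    # Backward pass on the scaled loss, computed in (simulated) float16.
    scaled = [to_fp16_range(g * loss_scale) for g in grads]
    overflow = not all(math.isfinite(g) for g in scaled)
    if overflow:
        # Skip the update: parameters are returned unchanged.
        return list(params), True
    # Unscale before the optimizer sees the gradients.
    unscaled = [g / loss_scale for g in scaled]
    return [p - lr * g for p, g in zip(params, unscaled)], False
```

For example, `loss_scaled_sgd_step([1.0, 2.0], [0.5, -0.25], loss_scale=1024.0)` applies the update, while a gradient of 100.0 with the same scale exceeds the simulated float16 range and leaves the parameters unchanged.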

Parameters
  • network (Cell) – The training network. The network must have a single output.

  • optimizer (Cell) – Optimizer for updating the network parameters.

  • scale_sense (Union[Tensor, Cell]) – If this value is a Cell, TrainOneStepWithLossScaleCell calls it to update the loss scale. If this value is a Tensor, the loss scale can be modified with set_sense_scale, and its shape must be \(()\) or \((1,)\).

Inputs:
  • (*inputs) (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

Tuple of 3 Tensors: the loss, the overflow flag and the current loss scale value.

  • loss (Tensor) - Tensor with shape \(()\).

  • overflow (Tensor) - Tensor with shape \(()\) and dtype bool.

  • loss scale (Tensor) - Tensor with shape \(()\).

Raises
  • TypeError – If scale_sense is neither Cell nor Tensor.

  • ValueError – If the shape of scale_sense is neither (1,) nor ().

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>> from mindspore import dtype as mstype
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> # 1) scale_sense is a Cell: the loss scale is updated on the device side
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> labels = Tensor(np.ones([size, out_features]).astype(np.float32))
>>> # output is a tuple: (loss, overflow flag, current loss scale)
>>> output = train_network(inputs, labels)
>>>
>>> # 2) scale_sense is a Tensor: the loss scale is updated on the host side
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor([1024], dtype=mstype.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
>>>
>>> # Reassign the loss scale on the host side, then train again
>>> scaling_sens = Tensor([1], dtype=mstype.float32)
>>> train_network.set_sense_scale(scaling_sens)
>>> output = train_network(inputs, label)

get_overflow_status(status, compute_output)

Get floating-point overflow status.

Get the overflow result after the computation being checked has executed. User-defined training networks based on this class can also call this interface to handle overflow.

Parameters
  • status (object) – A status instance used to detect the overflow.

  • compute_output – Overflow detection should be performed on a certain computation. Set compute_output as the output of that computation to ensure the overflow status is read only after the computation has executed.

Returns

bool, whether overflow occurred.
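The sketch below models, in plain Python, one way a single boolean overflow result can be derived from a set of gradients: each tensor contributes a 0/1 flag, the flags are summed locally and then across workers (a stand-in for an all-reduce in data-parallel training), and overflow is reported when the total is at least 1. The function names and the list-of-floats tensor representation are hypothetical; this is not the MindSpore implementation, whose status handling is backend-specific.

```python
import math


def tensor_overflow_flag(tensor):
    """Return 1.0 if any element of `tensor` (a list of floats) is inf or NaN, else 0.0."""
    return 0.0 if all(math.isfinite(x) for x in tensor) else 1.0


def get_overflow_status_sketch(grads_per_worker):
    """Hypothetical sketch: decide whether overflow occurred.

    `grads_per_worker` has one entry per worker/device; each entry is a list of
    gradient tensors, and each tensor is a list of floats. Per-tensor flags are
    summed locally, the local sums are summed across workers (simulating an
    all-reduce), and overflow is reported if the total is at least 1.
    """
    local_sums = [sum(tensor_overflow_flag(g) for g in grads) for grads in grads_per_worker]
    total = sum(local_sums)  # simulated all-reduce (sum) across workers
    return total >= 1.0
```

Reducing the flag across workers means every worker reaches the same decision about whether to skip the step.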

process_loss_scale(overflow)

Calculate the loss scale according to the overflow status.

User-defined training networks based on this class can also call this interface to handle overflow.

Parameters

overflow (bool) – Whether overflow occurred.

Returns

bool, the input overflow value.
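The following sketch shows a common dynamic loss-scaling rule, using the same parameter names as the nn.DynamicLossScaleUpdateCell example above (loss_scale_value, scale_factor, scale_window). The class DynamicLossScaleSketch is hypothetical and runs on the host in plain Python; treat the exact rule, including the lower bound of 1, as an assumption rather than a description of MindSpore's code.

```python
class DynamicLossScaleSketch:
    """Hypothetical host-side sketch of dynamic loss scaling.

    On overflow, the scale is divided by `scale_factor` (but not below 1) and
    the counter of good steps is reset. After `scale_window` consecutive steps
    without overflow, the scale is multiplied by `scale_factor`.
    """

    def __init__(self, loss_scale_value, scale_factor, scale_window):
        self.loss_scale = float(loss_scale_value)
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def process_loss_scale(self, overflow):
        if overflow:
            # Shrink the scale so the next step is less likely to overflow.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # Enough stable steps: try a larger scale again.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0
        # As documented for process_loss_scale, return the input overflow flag.
        return overflow
```

Shrinking immediately on overflow but growing only after a window of clean steps keeps the scale near the largest value that does not overflow.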

set_sense_scale(sens)

If scale_sense was set to a Tensor, call this function to reassign its value.

Parameters

sens (Tensor) – The new loss scale value. Its shape and dtype must be the same as those of the original scale_sense.
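To make the shape and dtype requirements concrete, the sketch below models a Tensor-typed scale_sense with a hypothetical ScaleValue record (shape, dtype name, flat data) and a hypothetical HostScaleSense holder. It rejects initial shapes other than () or (1,), as documented under Raises, and raises ValueError when set_sense_scale receives a value whose shape or dtype differs; how MindSpore itself reports such a mismatch is an assumption not covered by this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScaleValue:
    """Hypothetical stand-in for a loss-scale Tensor: shape, dtype name and flat data."""
    shape: tuple
    dtype: str
    data: tuple


class HostScaleSense:
    """Hypothetical sketch of a Tensor-typed scale_sense updated on the host."""

    def __init__(self, scale_sense):
        if not isinstance(scale_sense, ScaleValue):
            raise TypeError("scale_sense must be a ScaleValue in this sketch")
        if scale_sense.shape not in ((), (1,)):
            raise ValueError("scale_sense shape must be () or (1,)")
        self.scale_sense = scale_sense

    def set_sense_scale(self, sens):
        # The new value must keep the original shape and dtype.
        if sens.shape != self.scale_sense.shape or sens.dtype != self.scale_sense.dtype:
            raise ValueError("sens must match the shape and dtype of scale_sense")
        self.scale_sense = sens
```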

start_overflow_check(pre_cond, compute_input)

Start floating-point overflow detection. Create and clear the overflow detection state.

Specify the arguments pre_cond and compute_input to make sure the overflow status is cleared at the right time. For example, to clear the status after the loss has been calculated and then detect overflow during gradient calculation, pass the output of the loss function as pre_cond and the input of the gradient-computing function as compute_input. User-defined training networks based on this class can also call this interface to handle overflow.

Parameters
  • pre_cond (Tensor) – A precondition for starting overflow detection. It determines the execution order of overflow-status clearing and the preceding computation, ensuring that start_overflow_check clears the status only after the precondition has been computed.

  • compute_input (object) – The input of the subsequent computation. Overflow detection should be performed on a certain computation; set compute_input as the input of that computation to ensure the overflow status is cleared before the computation runs.

Returns

Tuple[object, object]. The first value is False on the GPU backend and an instance of NPUAllocFloatStatus on other backends; it is the status used by get_overflow_status to detect overflow. The second value is the same as compute_input but also carries execution-order information.
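The sketch below models why the status must be cleared before, and read after, the checked computation. In Python, statements already run in order, so it cannot show how pre_cond and compute_input enforce ordering in a compiled graph; it only illustrates a sticky status flag. StickyOverflowRegister and checked_step are hypothetical names, not MindSpore APIs; the real status object is backend-specific, as described above.

```python
import math


class StickyOverflowRegister:
    """Hypothetical model of an overflow status register.

    Once any recorded value is inf or NaN, the flag stays set until clear()
    is called, so a stale flag can leak into a later check.
    """

    def __init__(self):
        self.flag = False

    def clear(self):
        self.flag = False

    def record(self, values):
        if any(not math.isfinite(v) for v in values):
            self.flag = True


def checked_step(register, grad_fn, clear_first=True):
    """Hypothetical step mirroring the start_overflow_check / get_overflow_status order."""
    if clear_first:
        register.clear()  # role of start_overflow_check: clear before the computation
    register.record(grad_fn())  # the computation being checked
    return register.flag  # role of get_overflow_status: read after the computation
```

Skipping the clear (clear_first=False) after a step that overflowed makes the next, finite step report overflow as well, which is the kind of stale result correct ordering avoids.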