mindspore.nn.WithEvalCell

class mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)

Wraps the forward network with the loss function for evaluation.

It returns the loss, the forward output, and the label, which are then used to compute the evaluation metrics.
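The evaluation flow described above (forward pass, loss, then returning all three values) can be sketched framework-free in NumPy. This is a minimal illustration under stated assumptions, not MindSpore's implementation: `with_eval`, `linear_net`, and `mse_loss` are hypothetical stand-ins for the wrapper, the forward `Cell`, and `loss_fn`.

```python
import numpy as np


def with_eval(network, loss_fn, data, label):
    """Hypothetical NumPy analogue of WithEvalCell: forward pass plus loss."""
    outputs = network(data)         # run the forward network only (no training step)
    loss = loss_fn(outputs, label)  # compute the loss on this batch
    return loss, outputs, label     # keep outputs and label for metric computation


# Hypothetical forward network and loss, for illustration only.
weights = np.array([[1.0, -1.0], [0.5, 2.0]])


def linear_net(x):
    return x @ weights


def mse_loss(pred, target):
    return np.mean((pred - target) ** 2)


data = np.array([[1.0, 0.0], [0.0, 1.0]])
label = np.array([[1.0, -1.0], [0.0, 0.0]])
loss, outputs, returned_label = with_eval(linear_net, mse_loss, data, label)
print(loss)  # 1.0625
```

Returning the outputs and label alongside the loss is what lets a metric such as accuracy be computed from the same forward pass.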

Parameters
  • network (Cell) – The forward network.

  • loss_fn (Cell) – The loss function.

  • add_cast_fp32 (bool) – Whether to adjust the data type to float32. If True, the network output and the label are cast to float32 before the loss is computed. Default: False.
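The effect of add_cast_fp32 can be sketched in NumPy, assuming it upcasts both the network output and the label to float32 before the loss is computed (useful when the network runs in float16). `prepare_for_loss` is a hypothetical helper, not a MindSpore API; its bool check mirrors the TypeError listed under Raises.

```python
import numpy as np


def prepare_for_loss(outputs, label, add_cast_fp32=False):
    """Hypothetical helper mimicking the add_cast_fp32 switch."""
    if not isinstance(add_cast_fp32, bool):
        raise TypeError("add_cast_fp32 must be a bool")
    if add_cast_fp32:
        # Assumed behavior: upcast both arrays so the loss is computed in float32.
        outputs = outputs.astype(np.float32)
        label = label.astype(np.float32)
    return outputs, label


# float16 arrays, e.g. outputs of a network run in reduced precision
half_outputs = np.array([0.1, 0.9], dtype=np.float16)
half_label = np.array([0.0, 1.0], dtype=np.float16)

kept = prepare_for_loss(half_outputs, half_label)
cast = prepare_for_loss(half_outputs, half_label, add_cast_fp32=True)
print(kept[0].dtype, cast[0].dtype)  # float16 float32
```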

Inputs:
  • data (Tensor) - The input data of the forward network, a Tensor of shape (N, ...).

  • label (Tensor) - The target label, a Tensor of shape (N, ...).

Outputs:

Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of shape (N, ...) and a label Tensor of shape (N, ...).
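The returned triple is what metrics consume. A minimal NumPy sketch of aggregating it over several batches, assuming integer class-index labels and a loss that is already averaged over each batch; `evaluate` is a hypothetical helper, not a MindSpore API.

```python
import numpy as np


def evaluate(batches):
    """Accumulate mean loss and accuracy from (loss, logits, label) triples.

    Hypothetical helper: each triple has the layout a WithEvalCell-style
    wrapper returns, with integer class-index labels.
    """
    total_loss, correct, seen = 0.0, 0, 0
    for loss, logits, label in batches:
        n = label.shape[0]
        total_loss += float(loss) * n  # loss assumed to be a per-batch mean
        correct += int(np.sum(np.argmax(logits, axis=1) == label))
        seen += n
    return total_loss / seen, correct / seen


batches = [
    (np.float32(0.5), np.array([[2.0, 1.0], [0.1, 3.0]]), np.array([0, 1])),
    (np.float32(1.0), np.array([[0.2, 0.8], [1.5, 0.5]]), np.array([0, 0])),
]
mean_loss, accuracy = evaluate(batches)
print(mean_loss, accuracy)  # 0.75 0.75
```

Weighting each batch loss by its batch size keeps the mean correct when the last batch is smaller than the others.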

Raises

TypeError – If add_cast_fp32 is not a bool.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import nn
>>> # Forward network without loss function; nn.Dense(3, 2) stands in for any Net()
>>> net = nn.Dense(3, 2)
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> eval_net = nn.WithEvalCell(net, loss_fn)