mindspore.nn.FocalLoss

class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')

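Focal loss is a loss function that addresses class imbalance and the difference in classification difficulty between easy and hard examples. It was proposed by Kaiming He's team in the paper Focal Loss for Dense Object Detection, where it improved the performance of image object detection. The function is defined as follows, where \(p_t\) is the predicted probability of the ground-truth class and \(\gamma\) is the gamma parameter: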

\[FL(p_t) = -(1-p_t)^\gamma \log(p_t)\]
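
As a rough illustration of the formula, the following plain-Python sketch evaluates the focal term for a single example. The helper name `focal_term` is hypothetical and not part of MindSpore; it only mirrors the equation above. With gamma set to 0 the expression reduces to ordinary cross-entropy, and larger gamma values shrink the loss of well-classified examples (those with \(p_t\) close to 1) far more than that of hard ones.

```python
import math

def focal_term(p_t, gamma=2.0):
    """Evaluate FL(p_t) = -(1 - p_t)**gamma * log(p_t) for one example.

    Hypothetical helper for illustration; p_t is the predicted probability
    of the ground-truth class and must lie in (0, 1].
    """
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# gamma = 0 removes the modulating factor, leaving plain cross-entropy.
ce_easy = focal_term(0.9, gamma=0.0)

# With gamma = 2, an easy example (p_t = 0.9) is scaled by (0.1)**2 = 0.01,
# while a hard example (p_t = 0.1) is scaled by only (0.9)**2 = 0.81.
fl_easy = focal_term(0.9, gamma=2.0)
fl_hard = focal_term(0.1, gamma=2.0)

print(ce_easy, fl_easy, fl_hard)
```
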
Parameters
  • gamma (float) – Adjusts the steepness of the weighting curve in focal loss. Default: 2.0.

  • weight (Union[Tensor, None]) – A rescaling weight applied to the loss of each class (the example below passes one entry per class). It must be a 1-D Tensor. If None, no weight is applied. Default: None.

  • reduction (str, optional) –

    The reduction method applied to the output: 'none', 'mean' or 'sum'. Default: 'mean'.

    • 'none': no reduction will be applied.

    • 'mean': compute and return the weighted mean of elements in the output.

    • 'sum': the output elements will be summed.
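
The three reduction modes can be sketched in plain Python as follows. This is an illustrative assumption about what each mode does to a list of per-element losses, not MindSpore's implementation; `apply_reduction` is a hypothetical helper, and any class weighting is assumed to have been applied to the per-element losses already.

```python
def apply_reduction(losses, reduction="mean"):
    """Reduce a flat list of per-element losses (hypothetical helper)."""
    if reduction == "none":
        return list(losses)               # keep every element
    if reduction == "sum":
        return sum(losses)                # add all elements
    if reduction == "mean":
        return sum(losses) / len(losses)  # average over elements
    # Unknown modes are rejected, as the Raises section describes.
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")

per_element = [0.5, 1.0, 1.5]
print(apply_reduction(per_element, "none"))  # [0.5, 1.0, 1.5]
print(apply_reduction(per_element, "sum"))   # 3.0
print(apply_reduction(per_element, "mean"))  # 1.0
```
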

Inputs:
  • logits (Tensor) - Tensor of shape \((N, C)\), \((N, C, H)\) or \((N, C, H, W)\), where \(C\) is the number of classes and must be greater than 1. If the shape is \((N, C, H, W)\) or \((N, C, H)\), then \(H\), or the product of \(H\) and \(W\), should be the same as in labels.

  • labels (Tensor) - Tensor of shape \((N, C)\), \((N, C, H)\) or \((N, C, H, W)\). The value of \(C\) is either 1 or the same as the \(C\) of logits. If \(C\) is not 1, the shape of labels should be the same as that of logits. If the shape is \((N, C, H, W)\) or \((N, C, H)\), then \(H\), or the product of \(H\) and \(W\), should be the same as in logits. The values of labels should be in the range [-\(C\), \(C\)), where \(C\) is the number of classes in logits.

Outputs:

Tensor or Scalar. If reduction is 'none', its shape is the same as logits. Otherwise, a scalar value is returned.

Raises
  • TypeError – If gamma is not a float.

  • TypeError – If weight is neither a Tensor nor None.

  • ValueError – If the number of dimensions of labels differs from that of logits.

  • ValueError – If the channel dimension of labels is not 1 and the shape of labels differs from that of logits.

  • ValueError – If reduction is not one of 'none', 'mean', 'sum'.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> logits = ms.Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], ms.float32)
>>> labels = ms.Tensor([[1], [1], [0]], ms.int32)
>>> focalloss = nn.FocalLoss(weight=ms.Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
0.12516622
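
The printed value can be cross-checked without MindSpore. The sketch below is a plain-Python reading of the formula, under the assumptions that `p_t` is the softmax probability of the labelled class and that `weight` is indexed by class label. It is not the library's implementation, and `focal_loss_mean` is a hypothetical name.

```python
import math

def focal_loss_mean(logits, labels, weight, gamma=2.0):
    """Mean focal loss over (N, C) logits with class-index labels."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-softmax for the labelled class.
        m = max(row)
        log_sum = m + math.log(sum(math.exp(x - m) for x in row))
        log_p_t = row[label] - log_sum
        p_t = math.exp(log_p_t)
        # Assumption: the class weight scales the loss of each example.
        total += -weight[label] * (1.0 - p_t) ** gamma * log_p_t
    return total / len(logits)

# Same values as the MindSpore example above, as plain Python lists.
logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [1, 1, 0]
result = focal_loss_mean(logits, labels, weight=[1.0, 2.0])
print(round(result, 6))  # 0.125166, in line with the value printed above
```

Under these assumptions the three per-example losses are roughly 0.1099, 0.1652 and 0.1004, and their mean agrees with the documented output to within float32 precision.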