
class mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')

Focal loss, proposed by Kaiming He’s team in the paper “Focal Loss for Dense Object Detection”, improves the performance of image object detection. It addresses class imbalance and the difference in classification difficulty between samples. For details, refer to the paper: https://arxiv.org/pdf/1708.02002.pdf. The function is defined as follows:

\[FL(p_t) = -(1 - p_t)^\gamma \log(p_t)\]
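
To see how the modulating factor behaves, the formula above can be evaluated directly for a single probability. The following is a minimal pure-Python sketch, not MindSpore code; the helper name `focal_term` is hypothetical and introduced only for illustration.

```python
import math


def focal_term(p_t: float, gamma: float = 2.0) -> float:
    """Evaluate FL(p_t) = -(1 - p_t)^gamma * log(p_t) for one probability."""
    if not 0.0 < p_t <= 1.0:
        raise ValueError("p_t must lie in (0, 1]")
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# With gamma = 0 the modulating factor is 1, leaving plain cross-entropy.
easy, hard = 0.9, 0.1
for gamma in (0.0, 2.0):
    print(gamma, focal_term(easy, gamma), focal_term(hard, gamma))
```

For a well-classified sample (p_t = 0.9), gamma = 2 scales the cross-entropy term by 0.1^2 = 0.01, while a hard sample (p_t = 0.1) is scaled by only 0.9^2 = 0.81, so hard samples dominate the loss.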

Parameters:

  • gamma (float) – Adjusts the steepness of the weighting curve in focal loss. Default: 2.0.

  • weight (Union[Tensor, None]) – A rescaling weight applied to the loss of each batch element. weight must be a 1-dimensional Tensor. If None, no weights are applied. Default: None.

  • reduction (str) – Type of reduction to be applied to loss. The optional values are “mean”, “sum”, and “none”. If “none”, do not perform reduction. Default: “mean”.
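
The three reduction modes can be illustrated on a plain list of per-sample loss values. This is a sketch under the assumption that reduction is applied to already computed per-sample losses; `reduce_loss` is a hypothetical helper, not part of MindSpore.

```python
from typing import List, Union


def reduce_loss(per_sample: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Apply the documented reduction modes to a list of per-sample losses."""
    if reduction == "none":
        return list(per_sample)          # keep one value per sample
    if reduction == "sum":
        return sum(per_sample)           # add all values together
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)  # average over samples
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


losses = [0.2, 0.4, 0.6]
print(reduce_loss(losses, "none"))
print(reduce_loss(losses, "sum"))
print(reduce_loss(losses, "mean"))
```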

Inputs:

  • logits (Tensor) - Tensor of shape (B, C), (B, C, H), or (B, C, H, W), where C is the number of classes and must be greater than 1. If the shape is (B, C, H) or (B, C, H, W), then H, or the product of H and W, should be the same as in labels.

  • labels (Tensor) - Tensor of shape (B, C), (B, C, H), or (B, C, H, W). The value of C is either 1 or the same as C in logits. If C is not 1, the shape of labels should be the same as that of logits, where C is the number of classes. If the shape is (B, C, H) or (B, C, H, W), then H, or the product of H and W, should be the same as in logits.
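
The shape rules for logits and labels can be written down as a small validation routine. The sketch below works on shape tuples only and is not MindSpore's own check; `check_focal_shapes` is a hypothetical name, and the batch-size comparison is an assumption that the text above does not state explicitly.

```python
from math import prod
from typing import Tuple


def check_focal_shapes(logits_shape: Tuple[int, ...], labels_shape: Tuple[int, ...]) -> None:
    """Raise ValueError if the shapes break the rules described above."""
    if len(logits_shape) not in (2, 3, 4):
        raise ValueError("logits must have shape (B, C), (B, C, H) or (B, C, H, W)")
    if len(labels_shape) != len(logits_shape):
        raise ValueError("labels and logits must have the same number of dimensions")
    if logits_shape[1] <= 1:
        raise ValueError("the class dimension C of logits must be greater than 1")
    if labels_shape[1] != 1 and labels_shape != logits_shape:
        raise ValueError("labels must have C == 1 or the same shape as logits")
    # Assumption: batch size and flattened spatial size must agree as well.
    if labels_shape[0] != logits_shape[0] or prod(labels_shape[2:]) != prod(logits_shape[2:]):
        raise ValueError("batch size and H (or H * W) must match")


check_focal_shapes((3, 2), (3, 1))              # class-index labels: accepted
check_focal_shapes((3, 2, 4, 4), (3, 2, 4, 4))  # same-shape labels: accepted
```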


Outputs:

Tensor with the same shape and type as the input logits.

Raises:

  • TypeError – If the data type of gamma is not float.

  • TypeError – If weight is not a Tensor.

  • ValueError – If the number of dimensions of labels differs from that of logits.

  • ValueError – If the channel dimension of labels is not 1 and the shape of labels differs from that of logits.

  • ValueError – If reduction is not one of ‘none’, ‘mean’, ‘sum’.

Supported Platforms:

Ascend GPU


Examples:

>>> import mindspore.common.dtype as mstype
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(logits, labels)
>>> print(output)
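
For readers who want to trace the arithmetic by hand, the following pure-Python sketch computes a focal loss for the same numbers as the example above. It is not MindSpore's implementation and rests on two assumptions that the text above does not confirm: logits are turned into probabilities with a softmax over the class axis, and the weight is looked up by class index. The name `focal_loss_reference` is hypothetical.

```python
import math
from typing import List, Optional


def focal_loss_reference(logits: List[List[float]], labels: List[int],
                         class_weight: Optional[List[float]] = None,
                         gamma: float = 2.0) -> float:
    """Mean focal loss for (B, C) logits and one class index per sample."""
    total = 0.0
    for row, target in zip(logits, labels):
        # Numerically stable log-softmax for the target class (assumed).
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        log_p = row[target] - log_norm
        p_t = math.exp(log_p)
        # Assumption: the weight is looked up by class index.
        w = 1.0 if class_weight is None else class_weight[target]
        total += -w * (1.0 - p_t) ** gamma * log_p
    return total / len(logits)


logits = [[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]
labels = [1, 1, 0]
# Under the assumptions above this evaluates to roughly 0.1252.
print(focal_loss_reference(logits, labels, class_weight=[1.0, 2.0]))
```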