mindspore.nn.MultiLabelSoftMarginLoss

class mindspore.nn.MultiLabelSoftMarginLoss(weight=None, reduction='mean')[source]

Calculates the MultiLabelSoftMarginLoss. The multi-label soft margin loss is a commonly used loss function in multi-label classification tasks, where a single input sample can belong to multiple classes. Given an input \(x\) and binary labels \(y\) of size \((N, C)\), where \(N\) denotes the number of samples and \(C\) denotes the number of classes, the loss is defined as:

\[\text{loss}\left( x, y \right) = - \frac{1}{N}\frac{1}{C}\sum_{i = 1}^{N} \sum_{j = 1}^{C}\left(y_{ij}\log\frac{1}{1 + e^{- x_{ij}}} + \left( 1 - y_{ij} \right)\log\frac{e^{-x_{ij}}}{1 + e^{-x_{ij}}} \right)\]

where \(x_{ij}\) represents the predicted score of sample \(i\) for class \(j\), and \(y_{ij}\) represents the binary label of sample \(i\) for class \(j\): sample \(i\) belongs to class \(j\) if \(y_{ij}=1\), and does not belong to class \(j\) if \(y_{ij}=0\). In a multi-label classification task, each sample may have multiple labels with a value of 1 in the binary label \(y\). If weight is given, it is multiplied by the loss of each class.
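
As a cross-check of the formula, the computation can be reproduced outside MindSpore with plain NumPy. This is a minimal sketch, not part of the mindspore API; numpy and the helper name multilabel_soft_margin_loss are assumptions for illustration:

    import numpy as np

    def multilabel_soft_margin_loss(x, y, weight=None, reduction='mean'):
        # log(sigmoid(x)) = log(1 / (1 + e^{-x}))
        log_sig = -np.log1p(np.exp(-x))
        # log(1 - sigmoid(x)) = log(e^{-x} / (1 + e^{-x})) = -x + log(sigmoid(x))
        log_one_minus_sig = -x + log_sig
        # Element-wise binary cross-entropy on the logits, shape (N, C)
        loss = -(y * log_sig + (1 - y) * log_one_minus_sig)
        if weight is not None:
            loss = loss * weight        # per-class rescaling, broadcast over (N, C)
        loss = loss.mean(axis=-1)       # average over the C classes -> shape (N,)
        if reduction == 'mean':
            return loss.mean()          # scalar, matching the formula above
        if reduction == 'sum':
            return loss.sum()
        return loss                     # 'none': per-sample losses, shape (N,)

On the inputs used in the example at the bottom of this page, this sketch evaluates to approximately 0.84694, matching the printed output there.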

Parameters
  • weight (Union[Tensor, int, float]) – The manual rescaling weight given to each class (see the sketch after this list). Default: None.

  • reduction (str) – Specifies the reduction to apply to the output. It must be one of ‘none’, ‘mean’, or ‘sum’, meaning no reduction, the mean of the output, or the sum of the output, respectively. Default: ‘mean’.
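
For example, a per-class weight can be passed as a Tensor of shape (C,) to rescale each class's contribution before reduction. A minimal sketch; the weight values below are illustrative assumptions only:

>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> class_weight = Tensor([1.0, 2.0, 0.5])  # hypothetical per-class weights, C = 3
>>> weighted_loss = nn.MultiLabelSoftMarginLoss(weight=class_weight, reduction='mean')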

Inputs:
  • x (Tensor) - A tensor of shape (N, C), where N is the batch size and C is the number of classes.

  • target (Tensor) - The label target Tensor, which has the same shape as x.

Outputs:

Tensor, with the same data type as x. If reduction is ‘none’, its shape is (N); otherwise, the output is a scalar (zero-dimensional) Tensor.

Raises

ValueError – If the rank of x or target is not 2.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> x = Tensor([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
>>> target = Tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
>>> loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
>>> out = loss(x, target)
>>> print(out.asnumpy())
0.84693956
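
With reduction set to ‘none’, the unreduced per-sample losses are returned instead. A short continuation of the example above; the printed shape follows directly from the (2, 3) inputs:

>>> loss_none = nn.MultiLabelSoftMarginLoss(reduction='none')
>>> out_none = loss_none(x, target)
>>> print(out_none.shape)
(2,)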