mindspore.nn.Accuracy

class mindspore.nn.Accuracy(eval_type='classification')

Calculates the accuracy for classification and multilabel data.

The accuracy class keeps two local variables, the number of correct predictions and the total number of samples. They are used to compute how often the predictions y_pred match the labels y. This frequency is returned as the accuracy: the correct number divided by the total number. The computation is idempotent, because the division only reads the two counters and does not change them.
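
The counting scheme above can be sketched in plain Python, with no MindSpore or NumPy dependency. The class `RunningAccuracy` and its `update()` argument (a list of per-sample True/False flags) are hypothetical simplifications that only mirror the `clear()`/`update()`/`eval()` pattern; they are not MindSpore's implementation. The sketch shows that after several batches the result is the pooled ratio of correct samples to all samples, not the average of per-batch accuracies.

```python
class RunningAccuracy:
    """Hypothetical accumulator mirroring the correct/total counters described above."""

    def __init__(self):
        self.clear()

    def clear(self):
        # Reset both local counters.
        self.correct_num = 0
        self.total_num = 0

    def update(self, is_correct):
        # is_correct: one boolean per sample in the batch (a simplification;
        # the real metric derives correctness from logits and labels).
        flags = list(is_correct)
        self.correct_num += sum(1 for f in flags if f)
        self.total_num += len(flags)

    def eval(self):
        # Only reads the counters, so repeated calls return the same value.
        if self.total_num == 0:
            raise RuntimeError("no samples have been accumulated")
        return self.correct_num / self.total_num


metric = RunningAccuracy()
metric.update([True, False])              # batch 1: 1 of 2 correct
metric.update([True, True, True, False])  # batch 2: 3 of 4 correct
print(metric.eval())                      # 4 correct out of 6 samples
```

Here the pooled value is 4 / 6, whereas averaging the two batch accuracies (0.5 and 0.75) would give 0.625.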

\[\text{accuracy} = \frac{\text{true\_positive} + \text{true\_negative}}{\text{true\_positive} + \text{true\_negative} + \text{false\_positive} + \text{false\_negative}}\]
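
As a worked instance of the formula for a binary task, the sketch below plugs in made-up confusion-matrix counts (illustrative numbers only). The numerator is the correct number and the denominator is the total number of samples. The helper `accuracy_from_counts` is hypothetical.

```python
def accuracy_from_counts(tp, tn, fp, fn):
    # (TP + TN) / (TP + TN + FP + FN): correct predictions over all predictions.
    total = tp + tn + fp + fn
    return (tp + tn) / total


# Hypothetical counts: 3 true positives, 4 true negatives,
# 1 false positive, 2 false negatives -> (3 + 4) / 10
print(accuracy_from_counts(3, 4, 1, 2))  # 0.7
```
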
Parameters

eval_type (str) – The type of evaluation used to compute the accuracy over a dataset. Supports ‘classification’ and ‘multilabel’. ‘classification’ means each sample has a single label; ‘multilabel’ means each sample can have multiple labels. Default: ‘classification’.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
>>> y = Tensor(np.array([1, 0, 1]), mindspore.float32)
>>> metric = nn.Accuracy('classification')
>>> metric.clear()
>>> metric.update(x, y)
>>> accuracy = metric.eval()
>>> print(accuracy)
0.6666666666666666
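
To see where 0.6666666666666666 comes from, the sketch below redoes the ‘classification’ rule by hand in plain Python on the same values: take the index of the largest score in each row of x and compare it with y. It is an illustrative re-computation, not a call into MindSpore.

```python
x = [[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]
y = [1, 0, 1]

# Index of the maximum score in each row (argmax per sample).
pred = [max(range(len(row)), key=row.__getitem__) for row in x]
correct = sum(int(p == t) for p, t in zip(pred, y))
print(pred, correct / len(y))  # [1, 0, 0] 0.6666666666666666
```

Only the third sample is wrong (predicted class 0, label 1), so 2 of 3 samples are correct.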

clear()

Clears the internal evaluation result.

eval()

Computes the accuracy.

Returns

Float, the computed accuracy (the correct number divided by the total number of samples).

Raises

RuntimeError – If the sample size is 0.

update(*inputs)

Updates the local variables. For ‘classification’, a prediction is counted as correct if the index of the maximum predicted value matches the label. For ‘multilabel’, a prediction is counted as correct if every predicted value of the sample matches the corresponding label.
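
The two per-sample correctness rules can be sketched in plain Python as follows. The function names are hypothetical, and the ‘multilabel’ check assumes an exact-match rule, where a sample is correct only when all of its 0/1 entries agree with the label.

```python
def classification_correct(scores, label):
    # Correct when the index of the highest score equals the class label.
    predicted = max(range(len(scores)), key=scores.__getitem__)
    return predicted == label


def multilabel_correct(pred_row, label_row):
    # Assumed exact-match rule: every 0/1 entry must agree with the label.
    return len(pred_row) == len(label_row) and all(
        p == t for p, t in zip(pred_row, label_row)
    )


print(classification_correct([0.1, 0.7, 0.2], 1))  # True
print(multilabel_correct([1, 0, 1], [1, 0, 1]))    # True
print(multilabel_correct([1, 0, 1], [1, 1, 1]))    # False: one entry differs
```
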

Parameters

inputs – Logits and labels: y_pred stands for the logits and y for the labels. y_pred and y must each be a Tensor, a list, or a numpy.ndarray. For the ‘classification’ evaluation type, y_pred is usually a list of floating-point numbers in the range \([0, 1]\) with shape \((N, C)\) (typical but not strictly required), where \(N\) is the number of samples and \(C\) is the number of categories. y must either be in one-hot format with shape \((N, C)\), or contain class indices with shape \((N,)\) that can be converted to one-hot format. For the ‘multilabel’ evaluation type, every value of y_pred and y must be 0 or 1, where 1 marks a positive category, and y_pred and y must both have shape \((N, C)\).
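
Because y may arrive either as one-hot rows of shape \((N, C)\) or as class indices of shape \((N,)\), both forms have to be brought to a common encoding before comparison. The plain-Python sketch below assumes one-hot rows are reduced to indices by taking the position of the largest entry (the 1); the helper `to_indices` is hypothetical.

```python
def to_indices(y):
    # One-hot rows of shape (N, C) -> class indices of shape (N,);
    # a flat list of indices of shape (N,) is returned unchanged.
    if y and isinstance(y[0], (list, tuple)):
        return [max(range(len(row)), key=row.__getitem__) for row in y]
    return list(y)


one_hot = [[0, 1], [1, 0], [0, 1]]
indices = [1, 0, 1]
print(to_indices(one_hot) == to_indices(indices))  # True: same labels, two encodings
```
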

Raises

ValueError – If the number of the inputs is not 2.