
class mindspore.train.ConfusionMatrix(num_classes, normalize='no_norm', threshold=0.5)

Computes the confusion matrix, which is commonly used to evaluate the performance of classification models, both binary and multi-class.

If you only need the confusion matrix, use this class. To calculate other metrics such as ‘PPV’, ‘TPR’ or ‘TNR’, use the class mindspore.train.ConfusionMatrixMetric instead.

Parameters:

  • num_classes (int) – Number of classes in the dataset.

  • normalize (str) –

    Normalization mode for the confusion matrix. Default: "no_norm". Choose from:

    • "no_norm" : No normalization is applied.

    • "target" : Normalization based on target value.

    • "prediction" : Normalization based on predicted value.

    • "all" : Normalization over the whole matrix.

  • threshold (float) – The threshold against which the predicted values are compared. Default: 0.5.
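The four normalize modes can be illustrated on a plain count matrix. The following is a minimal NumPy sketch, not MindSpore's implementation: it assumes rows index the true (target) class and columns the predicted class, the helper `normalize_confusion` is hypothetical, and mapping 0/0 cells (a class with no samples) to 0 is an assumption of this sketch rather than documented behavior.

```python
import numpy as np


def normalize_confusion(cm, mode="no_norm"):
    """Hypothetical helper: normalize a confusion-matrix count array.

    Assumes rows = true (target) class, columns = predicted class.
    Cells whose denominator is 0 are set to 0 (an assumption of this sketch).
    """
    cm = np.asarray(cm, dtype=float)
    if mode == "no_norm":
        return cm
    if mode == "target":
        denom = cm.sum(axis=1, keepdims=True)  # divide each row by its true-class total
    elif mode == "prediction":
        denom = cm.sum(axis=0, keepdims=True)  # divide each column by its predicted-class total
    elif mode == "all":
        denom = cm.sum()  # divide every cell by the total number of samples
    else:
        raise ValueError(f"unknown normalize mode: {mode!r}")
    with np.errstate(divide="ignore", invalid="ignore"):
        out = cm / denom
    return np.nan_to_num(out)  # 0/0 -> NaN -> 0


cm = np.array([[3, 1], [2, 4]])  # rows: true class, columns: predicted class
for mode in ("no_norm", "target", "prediction", "all"):
    print(mode, normalize_confusion(cm, mode).tolist())
```

With this matrix, "target" turns the rows into [0.75, 0.25] and [1/3, 2/3], "prediction" divides each column by 5, and "all" divides every cell by 10.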

Supported Platforms:

Ascend GPU CPU

Examples:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.train import ConfusionMatrix
>>> x = Tensor(np.array([1, 0, 1, 0]))
>>> y = Tensor(np.array([1, 0, 0, 1]))
>>> metric = ConfusionMatrix(num_classes=2, normalize='no_norm', threshold=0.5)
>>> metric.clear()
>>> metric.update(x, y)
>>> output = metric.eval()
>>> print(output)
[[1. 1.]
 [1. 1.]]

clear()

Clears the internal evaluation result.


eval()

Computes the confusion matrix.

Returns:

numpy.ndarray, the computed result.


update(*inputs)

Updates the internal state with y_pred and y.

Parameters:

inputs (tuple) – Input y_pred and y, each a Tensor, list or numpy.ndarray. y_pred is the predicted value and y is the true value. The shape of y_pred is (N, C, ...) or (N, ...); the shape of y is (N, ...).

Raises:

  • ValueError – If the number of inputs is not 2.

  • ValueError – If the number of dimensions of y_pred is neither equal to that of y nor one greater.
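The counting that update() performs can be sketched in NumPy as follows. This is an illustration under stated assumptions, not MindSpore's code: y holds integer class indices; when y_pred has shape (N, C, ...), the predicted class is taken as the argmax over axis 1; and when y_pred has the same shape as y and holds floats, scores at or above threshold are assumed to count as class 1. The helper `confusion_counts` and the batch data are hypothetical.

```python
import numpy as np


def confusion_counts(y_pred, y, num_classes, threshold=0.5):
    """Hypothetical helper: (num_classes, num_classes) counts; rows = true, columns = predicted."""
    y_pred = np.asarray(y_pred)
    y = np.asarray(y)
    if y_pred.ndim == y.ndim + 1:
        # (N, C, ...) scores: pick the highest-scoring class along axis 1.
        y_pred = np.argmax(y_pred, axis=1)
    elif y_pred.ndim != y.ndim:
        raise ValueError("y_pred must have the same number of dimensions as y, or one more")
    elif np.issubdtype(y_pred.dtype, np.floating):
        # Same-shape float scores: binarize against the threshold (>= is an assumption).
        y_pred = (y_pred >= threshold).astype(int)
    # Encode each (true, predicted) pair as one flat index, then count occurrences.
    flat = y.reshape(-1).astype(int) * num_classes + y_pred.reshape(-1).astype(int)
    counts = np.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


# Accumulate over three hypothetical batches, as repeated update() calls might.
cm = np.zeros((2, 2), dtype=int)
cm += confusion_counts([1, 0, 1, 0], [1, 0, 0, 1], num_classes=2)         # class indices
cm += confusion_counts([[0.2, 0.8], [0.9, 0.1]], [1, 0], num_classes=2)  # (N, C) scores
cm += confusion_counts([0.7, 0.3, 0.5], [1, 0, 0], num_classes=2)        # thresholded scores
print(cm)
```

The first batch reproduces the [[1, 1], [1, 1]] counts from the example above; summing per-batch matrices reflects the usual metric pattern of accumulating state across update() calls until clear() resets it.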