mindspore.train.ConfusionMatrix ================================ .. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg :target: https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/train/mindspore.train.ConfusionMatrix.rst :alt: 查看源文件 .. py:class:: mindspore.train.ConfusionMatrix(num_classes, normalize="no_norm", threshold=0.5) 计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。 如果只想使用混淆矩阵,请使用该类。如果想计算"PPV"、"TPR"、"TNR"等,请使用 :class:`mindspore.train.ConfusionMatrixMetric` 类。 参数: - **num_classes** (int) - 数据集中的类别数量。 - **normalize** (str) - 计算ConfusionMatrix的参数支持四种归一化模式,默认值: ``"no_norm"`` 。 - ``"no_norm"`` :不使用标准化。 - ``"target"`` :基于目标值的标准化。 - ``"prediction"`` :基于预测值的标准化。 - ``"all"`` :整个矩阵的标准化。 - **threshold** (float) - 阈值,用于与输入Tensor进行比较。默认值: ``0.5`` 。 .. py:method:: clear() 重置评估结果。 .. py:method:: eval() 计算混淆矩阵。 返回: numpy.ndarray,计算的结果。 .. py:method:: update(*inputs) 使用y_pred和y更新内部评估结果。 参数: - ***inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。 `y_pred` 是预测值, `y` 是真实值, `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` , `y` 的shape是 :math:`(N, ...)` 。 异常: - **ValueError** - 输入参数的数量不等于2。 - **ValueError** - 如果预测值和标签的维度不一致。