mindspore.train.Metric

查看源文件
class mindspore.train.Metric[源代码]

用于计算评估指标的基类。

在计算评估指标时需要调用 clearupdateeval 三个方法,在继承该类自定义评估指标时,也需要实现这三个方法。其中,update 用于计算中间过程的内部结果,eval 用于计算最终评估结果,clear 用于重置中间结果。 请勿直接使用该类,需使用子类如 mindspore.train.MAEmindspore.train.Recall 等。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>>
>>> class MyMAE(ms.train.Metric):
...     def __init__(self):
...         super(MyMAE, self).__init__()
...         self.clear()
...
...     def clear(self):
...         self._abs_error_sum = 0
...         self._samples_num = 0
...
...     def update(self, *inputs):
...         y_pred = inputs[0].asnumpy()
...         y = inputs[1].asnumpy()
...         abs_error_sum = np.abs(y - y_pred)
...         self._abs_error_sum += abs_error_sum.sum()
...         self._samples_num += y.shape[0]
...
...     def eval(self):
...         return self._abs_error_sum / self._samples_num
>>>
>>> x = ms.Tensor(np.array([[0.1, 0.2, 0.6, 0.9], [0.1, 0.2, 0.6, 0.9]]), ms.float32)
>>> y = ms.Tensor(np.array([[0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1]]), ms.float32)
>>> y2 = ms.Tensor(np.array([[0.1, 0.25, 0.7, 0.9], [0.1, 0.25, 0.7, 0.9]]), ms.float32)
>>> metric = MyMAE().set_indexes([0, 2])
>>> metric.clear()
>>> # indexes is [0, 2], using x as logits, y2 as label.
>>> metric.update(x, y, y2)
>>> accuracy = metric.eval()
>>> print(accuracy)
1.399999976158142
>>> print(metric.indexes)
[0, 2]
abstract clear()[源代码]

清除内部评估结果。

说明

所有子类都必须重写此接口。

教程样例:
abstract eval()[源代码]

计算最终评估结果。

说明

所有子类都必须重写此接口。

教程样例:
property indexes

获取当前的 indexes 值。默认为None,调用 set_indexes 方法可修改 indexes 值。

set_indexes(indexes)[源代码]

该接口用于重排 update 的输入。

给定(label0, label1, logits)作为 update 的输入,将 indexes 设置为[2, 1],则最终使用(logits, label1)作为 update 的真实输入。

说明

在继承该类自定义评估函数时,需要用装饰器 mindspore.train.rearrange_inputs 修饰 update 方法,否则配置的 indexes 值不生效。

参数:
  • indexes (List(int)) - logits和标签的目标顺序。

输出:

Metric ,类实例本身。

异常:
  • ValueError - 如果输入的index类型不是list或其元素类型不全为int。

abstract update(*inputs)[源代码]

更新内部评估结果。

说明

所有子类都必须重写此接口。

参数:
  • inputs - 可变长度输入参数列表。通常是预测值和对应的真实标签。

教程样例: