mindspore.train.History

查看源文件
class mindspore.train.History[源代码]

将网络输出和评估指标的相关信息记录到 History 对象中。

用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 Tensornumpy.ndarray,则记录此返回值均值,如果返回 tuplelist,则记录第一个元素。

说明

通常使用在 mindspore.train.Model.trainmindspore.train.Model.fit 中。

样例:

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, History
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> history_cb = History()
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[history_cb])
>>> print(history_cb.epoch)
{'epoch': [1, 2]}
>>> print(history_cb.history)
{'net_output': [1.607877, 1.6033841]}
begin(run_context)[源代码]

训练开始时初始化History对象的epoch属性。

参数:
epoch_end(run_context)[源代码]

epoch结束时记录网络输出和评估指标的相关信息。

参数: