mindspore.train.Callback

查看源文件
class mindspore.train.Callback[源代码]

用于构建Callback函数的基类。Callback函数是一个上下文管理器,在运行模型时被调用。 可以使用此机制进行一些自定义操作。

Callback类的每个方法对应了训练或推理过程的不同阶段,这些方法有相同的入参 run_context,用于保存训练或推理过程中模型的相关信息。定义Callback子类或自定义Callback时,请根据需要重写名称前缀为”on_train”或”on_eval”的方法,否则自定义的Callback在 model.fit 中使用时会产生错误。

自定义Callback场景下,在类方法中通过 RunContext.original_args() 方法可以获取模型训练或推理过程中已有 的上下文信息,此信息为一个存储了已有属性的字典型变量。用户也可以在此信息中添加其他的自定义属性。此外, 通过调用 request_stop 方法来停止训练过程。有关自定义Callback的具体用法,请查看 回调机制Callback

样例:

>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>> from mindspore.train import Model, Callback
>>> class Print_info(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step_num: ", cb_params.cur_step_num)
>>>
>>> print_cb = Print_info()
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset, callbacks=print_cb)
step_num: 1
step_num: 2
begin(run_context)[源代码]

在网络执行之前被调用一次。与 on_train_beginon_eval_begin 方法具有兼容性。

说明

begin 接口将在后续版本中废弃,请使用 on_train_beginon_eval_begin

参数:
end(run_context)[源代码]

网络执行后被调用一次。与 on_train_endon_eval_end 方法具有兼容性。

说明

end 接口将在后续版本中废弃,请使用 on_train_endon_eval_end

参数:
epoch_begin(run_context)[源代码]

在每个epoch开始之前被调用。与 on_train_epoch_beginon_eval_epoch_begin 方法具有兼容性。

说明

epoch_begin 接口将在后续版本中废弃,请使用 on_train_epoch_beginon_eval_epoch_begin

参数:
epoch_end(run_context)[源代码]

在每个epoch结束后被调用。与 on_train_epoch_endon_eval_epoch_end 方法具有兼容性。

说明

epoch_end 接口将在后续版本中废弃,请使用 on_train_epoch_endon_eval_epoch_end

参数:
on_eval_begin(run_context)[源代码]

在网络执行推理之前调用。

参数:
on_eval_end(run_context)[源代码]

网络执行推理之后调用。

参数:
on_eval_epoch_begin(run_context)[源代码]

在推理的epoch开始之前被调用。

参数:
on_eval_epoch_end(run_context)[源代码]

在推理的epoch结束后被调用。

参数:
on_eval_step_begin(run_context)[源代码]

在推理的每个step开始之前被调用。

参数:
on_eval_step_end(run_context)[源代码]

在推理的每个step完成后被调用。

参数:
on_train_begin(run_context)[源代码]

在网络执行训练之前调用。

参数:
on_train_end(run_context)[源代码]

网络训练执行结束时调用。

参数:
on_train_epoch_begin(run_context)[源代码]

在训练的每个epoch开始之前被调用。

参数:
on_train_epoch_end(run_context)[源代码]

在训练的每个epoch结束后被调用。

参数:
on_train_step_begin(run_context)[源代码]

在训练的每个step开始之前被调用。

参数:
on_train_step_end(run_context)[源代码]

在训练的每个step完成后被调用。

参数:
step_begin(run_context)[源代码]

在每个step开始之前被调用。与 on_train_step_beginon_eval_step_begin 方法具有兼容性。

说明

step_begin 接口将在后续版本中废弃,请使用 on_train_step_beginon_eval_step_begin

参数:
step_end(run_context)[源代码]

在每个step完成后被调用。与 on_train_step_endon_eval_step_end 方法具有兼容性。

说明

step_end 接口将在后续版本中废弃,请使用 on_train_step_endon_eval_step_end

参数: