mindspore.dataset.WaitedDSCallback

class mindspore.dataset.WaitedDSCallback(step_size=1)[source]

Abstract base class used to build a dataset callback class that is synchronized with the training callback.

This class can be used to execute user-defined logic right after the previous step or epoch has finished. For example, an augmentation operation may need the loss from the previously trained epoch to update some of its parameters.
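Conceptually, the synchronization this class models is a wait/notify handshake between two threads: the dataset pipeline blocks until the training side has finished the previous step and published its feedback. The following framework-free sketch illustrates that pattern with `threading.Event`; the class and method names here are illustrative assumptions, not MindSpore internals.

```python
import threading

class StepSync:
    """Illustrative sketch of the handshake WaitedDSCallback models."""

    def __init__(self):
        self._event = threading.Event()
        self._feedback = None

    def step_end(self, feedback):
        # Called by the "training" side after a step finishes:
        # publish the feedback and wake up the waiting dataset side.
        self._feedback = feedback
        self._event.set()

    def wait_step_begin(self, timeout=5.0):
        # Called by the "dataset" side before producing the next step:
        # block until the previous training step has finished.
        if not self._event.wait(timeout):
            raise TimeoutError("training step did not finish in time")
        self._event.clear()
        return self._feedback

sync = StepSync()
trainer = threading.Thread(target=sync.step_end, args=({"loss": 0.3},))
trainer.start()
fb = sync.wait_step_begin()  # blocks until step_end has run
trainer.join()
```

Here `fb` carries the feedback published by the training thread, which is the kind of information a `sync_step_begin` override receives via `train_run_context`.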

Parameters

step_size (int, optional) – The number of rows in each step, which is usually equal to the batch size. Default: 1.

Examples

>>> import mindspore.dataset as ds
>>> import mindspore.nn as nn
>>> from mindspore.dataset import WaitedDSCallback
>>> from mindspore import context
>>> from mindspore.train import Model
>>> from mindspore.train.callback import Callback
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>>
>>> # custom callback class for data synchronization in data pipeline
>>> class MyWaitedCallback(WaitedDSCallback):
...     def __init__(self, events, step_size=1):
...         super().__init__(step_size)
...         self.events = events
...
...     # callback method to be executed by data pipeline before the epoch starts
...     def sync_epoch_begin(self, train_run_context, ds_run_context):
...         event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
...
...     # callback method to be executed by data pipeline before the step starts
...     def sync_step_begin(self, train_run_context, ds_run_context):
...         event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom callback class for data synchronization in network training
>>> class MyMSCallback(Callback):
...     def __init__(self, events):
...         self.events = events
...
...     # callback method to be executed by network training after the epoch ends
...     def epoch_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
...
...     # callback method to be executed by network training after the step ends
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom network
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x
>>>
>>> # define a parameter that needs to be synchronized between data pipeline and network training
>>> events = []
>>>
>>> # define callback classes of data pipeline and network training
>>> my_cb1 = MyWaitedCallback(events, 1)
>>> my_cb2 = MyMSCallback(events)
>>> arr = [1, 2, 3, 4]
>>>
>>> # construct data pipeline
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
>>> # map the data callback object into the pipeline
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
>>>
>>> net = Net()
>>> model = Model(net)
>>>
>>> # add the data and network callback objects to the model training callback list
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

sync_epoch_begin(train_run_context, ds_run_context)[source]

Called before a new dataset epoch starts and after the previous training epoch ends.

Parameters
  • train_run_context – Contains information about the model, including feedback from the previously finished epoch.

  • ds_run_context – Contains information about the dataset pipeline.

sync_step_begin(train_run_context, ds_run_context)[source]

Called before a new dataset step starts and after the previous training step ends.

Parameters
  • train_run_context – Contains information about the model, including feedback from the previously finished step.

  • ds_run_context – Contains information about the dataset pipeline.
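To make the feedback path concrete, the sketch below shows how a subclass-style callback might cache the previous step's loss to scale an augmentation parameter. This is a hypothetical, framework-free illustration: a `loss` attribute on `train_run_context` and the scaling rule are assumptions for the example, not the real MindSpore callback payload.

```python
from types import SimpleNamespace

class LossAwareAugmentCallback:
    """Hypothetical sketch: caches the previous step's loss so an
    augmentation op could scale its strength accordingly."""

    def __init__(self, base_strength=0.5):
        self.base_strength = base_strength
        self.strength = base_strength

    def sync_step_begin(self, train_run_context, ds_run_context):
        # Before the first training step finishes there is no feedback yet.
        if train_run_context is not None and hasattr(train_run_context, "loss"):
            # Toy rule (an assumption): augment harder while the loss is high.
            self.strength = min(
                1.0, self.base_strength * (1.0 + train_run_context.loss)
            )

cb = LossAwareAugmentCallback(base_strength=0.5)
# Simulate feedback from a finished training step.
cb.sync_step_begin(SimpleNamespace(loss=0.8), ds_run_context=None)
```

After the simulated step, `cb.strength` becomes `min(1.0, 0.5 * 1.8) = 0.9`, while a callback that has received no feedback keeps its base strength.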