mindspore.dataset.WaitedDSCallback

class mindspore.dataset.WaitedDSCallback(step_size=1)

Abstract base class for building dataset callback classes that are synchronized with the training callback class mindspore.train.Callback.

It can be used to run a custom callback method before a step or an epoch starts. For example, in auto augmentation it can update the parameters of augmentation operations according to the loss of the previous training epoch.
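The waiting behaviour can be pictured with a small, stdlib-only simulation. This is a hedged sketch under stated assumptions, not MindSpore's implementation: the names `StepHandshake`, `producer` and `train`, the `"loss"` key, the fake loss values and the update rule `scale = 1 + previous loss` are all hypothetical. A producer thread stands in for the data pipeline; before every step except the first it blocks until the training loop reports that the previous step has ended, then uses that step's loss to adjust a parameter.

```python
import queue
import threading


class StepHandshake:
    """Hypothetical stand-in for the pipeline/training synchronization."""

    def __init__(self):
        # Each item is the context of one finished training step.
        self._finished = queue.Queue()

    def training_step_end(self, context):
        # Training side: announce that a step has ended.
        self._finished.put(context)

    def wait_for_previous_step(self, timeout=5.0):
        # Pipeline side: block until the previous training step has ended.
        return self._finished.get(timeout=timeout)


def producer(handshake, num_steps, batches, log):
    scale = 1.0  # hypothetical augmentation parameter
    for step in range(1, num_steps + 1):
        if step > 1:  # nothing to wait for before the first step
            ctx = handshake.wait_for_previous_step()
            scale = 1.0 + ctx["loss"]  # hypothetical update rule
            log.append(f"step {step} saw loss {ctx['loss']}")
        batches.put(scale * step)  # blocks while the previous item is unconsumed


def train(num_steps):
    handshake = StepHandshake()
    batches = queue.Queue(maxsize=1)
    log = []
    worker = threading.Thread(target=producer, args=(handshake, num_steps, batches, log))
    worker.start()
    seen = []
    for step in range(1, num_steps + 1):
        seen.append(batches.get())
        loss = 0.5 / step  # fake loss for illustration only
        handshake.training_step_end({"step": step, "loss": loss})
    worker.join()
    return seen, log


seen, log = train(3)
print(seen)  # [1.0, 3.0, 3.75]
print(log)   # ['step 2 saw loss 0.5', 'step 3 saw loss 0.25']
```

Because the queue between the two threads holds at most one item, the producer can never run more than one step ahead, so each value it emits after the first depends on the loss of the step that just finished.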

Users can obtain the network training context through train_run_context, such as network, train_network, epoch_num, batch_num, loss_fn, optimizer, parallel_mode, device_number, list_callback, cur_epoch_num, cur_step_num, dataset_sink_mode, net_outputs and others; see mindspore.train.Callback.

Users can obtain the dataset pipeline context through ds_run_context, including cur_epoch_num, cur_step_num_in_epoch and cur_step_num.
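The three pipeline counters are related in a simple way when every epoch has the same length. The sketch below is a hypothetical model, not MindSpore code: `PipelineContext` and `contexts` are illustrative names, and it assumes 1-based counting, a cur_step_num_in_epoch that restarts each epoch, and a cur_step_num that keeps counting across epochs.

```python
from dataclasses import dataclass


@dataclass
class PipelineContext:
    """Hypothetical mirror of the three ds_run_context counters."""
    cur_epoch_num: int
    cur_step_num_in_epoch: int
    cur_step_num: int


def contexts(num_epochs, steps_per_epoch):
    # Assumes 1-based counters and the same number of steps in every epoch.
    for epoch in range(1, num_epochs + 1):
        for step_in_epoch in range(1, steps_per_epoch + 1):
            global_step = (epoch - 1) * steps_per_epoch + step_in_epoch
            yield PipelineContext(epoch, step_in_epoch, global_step)


for ctx in contexts(num_epochs=2, steps_per_epoch=3):
    print(ctx.cur_epoch_num, ctx.cur_step_num_in_epoch, ctx.cur_step_num)
```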

Note

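The synchronized callback methods are triggered only from the beginning of the second step or epoch onward. No call is made before the first step or epoch, because there is no previous training step or epoch to wait for.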

Parameters

step_size (int, optional) – The number of rows in each step, usually set equal to the batch size. Default: 1.
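Why step_size is usually set to the batch size can be seen by counting rows. The function below is a hypothetical model, not MindSpore code; it assumes the callback is attached to a map operation that runs before batching, so one training step consumes step_size rows, and that a dataset step begins at rows 1, step_size + 1, 2 * step_size + 1, and so on. Every dataset step except the first must wait for the training step that consumed the preceding rows, which is consistent with the note that no call is made before the first step.

```python
def step_boundaries(num_rows, step_size):
    """Hypothetical model of where dataset steps begin, counted in rows.

    Returns (row, training step that must end first) pairs. None marks the
    first dataset step, for which no synchronized callback runs.
    """
    boundaries = []
    # k is the 0-based index of the dataset step; training steps 1..k have
    # consumed k * step_size rows before row `row` is processed.
    for k, row in enumerate(range(1, num_rows + 1, step_size)):
        previous_training_step = k if k > 0 else None
        boundaries.append((row, previous_training_step))
    return boundaries


print(step_boundaries(num_rows=8, step_size=4))  # [(1, None), (5, 1)]
print(step_boundaries(num_rows=4, step_size=1))  # [(1, None), (2, 1), (3, 2), (4, 3)]
```

If step_size were smaller than the batch size in this model, the pipeline would wait for a training step that cannot end until more rows have been produced, which is why matching the two is the usual choice.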

Examples

>>> import mindspore.nn as nn
>>> import mindspore as ms
>>> from mindspore.dataset import WaitedDSCallback
>>> import mindspore.dataset as ds
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
>>>
>>> # custom callback class for data synchronization in the data pipeline
>>> class MyWaitedCallback(WaitedDSCallback):
...     def __init__(self, events, step_size=1):
...         super().__init__(step_size)
...         self.events = events
...
...     # callback method executed by the data pipeline before the epoch starts
...     def sync_epoch_begin(self, train_run_context, ds_run_context):
...         event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
...
...     # callback method executed by the data pipeline before the step starts
...     def sync_step_begin(self, train_run_context, ds_run_context):
...         event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom callback class for data synchronization in network training
>>> class MyMSCallback(ms.Callback):
...     def __init__(self, events):
...         self.events = events
...
...     # callback method to be executed by network training after the epoch ends
...     def epoch_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
...
...     # callback method to be executed by network training after the step ends
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom network
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x
>>>
>>> # shared list to which both callbacks append events, so their interleaving can be inspected
>>> events = []
>>>
>>> # define callback objects for the data pipeline and network training
>>> my_cb1 = MyWaitedCallback(events, 1)
>>> my_cb2 = MyMSCallback(events)
>>> arr = [1, 2, 3, 4]
>>>
>>> # construct data pipeline
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
>>> # map the data callback object into the pipeline
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
>>>
>>> net = Net()
>>> model = ms.Model(net)
>>>
>>> # add the data and network callback objects to the model training callback list
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

sync_epoch_begin(train_run_context, ds_run_context)

Called before a new dataset epoch starts and after the previous training epoch has ended.

Parameters
  • train_run_context – Contains information about the model, including feedback from the previous training epoch.

  • ds_run_context – Contains information about the data pipeline.
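As a hedged illustration of the auto-augmentation use case, the sketch below adjusts a hypothetical augmentation strength from the previous epoch's loss. Assumptions: `LossAdaptiveMagnitude`, `magnitude`, `delta`, `target_loss` and `fake_train_context` are hypothetical names; the class does not subclass WaitedDSCallback so that it runs without MindSpore; the loss is read through `train_run_context.original_args().net_outputs` (the same `original_args()` pattern MyMSCallback uses in the example above), assuming the training network returns a scalar loss; and `types.SimpleNamespace` objects stand in for the real context objects.

```python
from types import SimpleNamespace


class LossAdaptiveMagnitude:
    """Hypothetical sync_epoch_begin logic. In MindSpore this class would
    subclass WaitedDSCallback; that is omitted so the sketch runs standalone."""

    def __init__(self, magnitude=0.5, delta=0.25, target_loss=1.0):
        self.magnitude = magnitude  # hypothetical augmentation strength
        self.delta = delta
        self.target_loss = target_loss

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # Assumes net_outputs holds the scalar loss of the previous epoch.
        cb_params = train_run_context.original_args()
        loss = float(cb_params.net_outputs)
        if loss < self.target_loss:
            # Previous epoch went well: strengthen augmentation, capped at 1.0.
            self.magnitude = min(1.0, self.magnitude + self.delta)
        else:
            # Previous epoch struggled: weaken augmentation, floored at 0.0.
            self.magnitude = max(0.0, self.magnitude - self.delta)


def fake_train_context(loss):
    # Stand-in for the training run context: original_args() returns params.
    cb_params = SimpleNamespace(net_outputs=loss)
    return SimpleNamespace(original_args=lambda: cb_params)


cb = LossAdaptiveMagnitude()
cb.sync_epoch_begin(fake_train_context(0.8), SimpleNamespace(cur_epoch_num=2))
print(cb.magnitude)  # 0.75
cb.sync_epoch_begin(fake_train_context(1.2), SimpleNamespace(cur_epoch_num=3))
print(cb.magnitude)  # 0.5
```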

sync_step_begin(train_run_context, ds_run_context)

Called before a new dataset step starts and after the previous training step has ended.

Parameters
  • train_run_context – Contains information about the model, including feedback from the previous training step.

  • ds_run_context – Contains information about the data pipeline.
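At step granularity the per-step loss is noisy, so a subclass might smooth it before acting on it. The sketch below keeps an exponential moving average (EMA) of the loss reported by the previous training step. It is hypothetical throughout: `StepLossTracker`, `momentum`, `history` and `fake_contexts` are illustrative names, the class does not subclass WaitedDSCallback so it runs without MindSpore, the loss is assumed to be available as `train_run_context.original_args().net_outputs`, and `types.SimpleNamespace` objects stand in for the real contexts.

```python
from types import SimpleNamespace


class StepLossTracker:
    """Hypothetical sync_step_begin logic: an EMA of the previous step's loss.
    In MindSpore it would subclass WaitedDSCallback; omitted to run standalone."""

    def __init__(self, momentum=0.5):
        self.momentum = momentum
        self.ema = None
        self.history = []  # (dataset step, EMA after the update)

    def sync_step_begin(self, train_run_context, ds_run_context):
        loss = float(train_run_context.original_args().net_outputs)
        if self.ema is None:
            self.ema = loss  # first synchronized call seeds the average
        else:
            self.ema = self.momentum * self.ema + (1 - self.momentum) * loss
        self.history.append((ds_run_context.cur_step_num, self.ema))


def fake_contexts(loss, ds_step):
    # Stand-ins for the training and pipeline run contexts.
    cb_params = SimpleNamespace(net_outputs=loss)
    train_ctx = SimpleNamespace(original_args=lambda: cb_params)
    return train_ctx, SimpleNamespace(cur_step_num=ds_step)


tracker = StepLossTracker()
# Synchronized calls start at the second dataset step, each with the previous step's loss.
for ds_step, prev_loss in [(2, 1.0), (3, 0.5), (4, 0.25)]:
    tracker.sync_step_begin(*fake_contexts(prev_loss, ds_step))
print(tracker.history)  # [(2, 1.0), (3, 0.75), (4, 0.5)]
```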