mindspore.dataset.Dataset.sync_wait

Dataset.sync_wait(condition_name, num_batch=1, callback=None)

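Add a blocking condition to the input dataset and apply a synchronization action to it. At the start of each epoch, up to num_batch batches are sent without blocking. After that, the pipeline waits until Dataset.sync_update is called with the same condition name.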

Parameters
  • condition_name (str) – The name of the condition that toggles sending the next row.

  • num_batch (int) – The number of batches that are sent without blocking at the start of each epoch. Default: 1.

  • callback (function) – The callback function that will be invoked when sync_update is called. Default: None.
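The interaction between num_batch, callback and sync_update can be modelled in plain Python. The sketch below is a hypothetical toy, not MindSpore's implementation. The class ToySyncWait, its timeout parameter, the additive "credit" bookkeeping and the choice to run the callback before releasing the pipeline are all assumptions made for illustration. It uses threading.Condition so that the iterator blocks once its budget of num_batch items is spent, and stays blocked until sync_update grants a new budget.

```python
import threading
from typing import Any, Callable, Iterable, Iterator, Optional


class ToySyncWait:
    """Hypothetical model of sync_wait/sync_update semantics (not MindSpore code)."""

    def __init__(self, source: Iterable[Any], num_batch: int = 1,
                 callback: Optional[Callable[[Any], None]] = None,
                 timeout: Optional[float] = None):
        self._source = iter(source)
        self._num_batch = num_batch
        self._callback = callback
        self._timeout = timeout          # assumption: lets the demo stop instead of hanging
        self._credits = num_batch        # items that may pass before blocking
        self._cond = threading.Condition()

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        with self._cond:
            # Block until sync_update() has granted at least one credit.
            if not self._cond.wait_for(lambda: self._credits > 0, timeout=self._timeout):
                raise TimeoutError("blocked: waiting for sync_update")
            self._credits -= 1
        return next(self._source)  # StopIteration ends the iteration as usual

    def sync_update(self, num_batch: Optional[int] = None, data: Any = None) -> None:
        # Assumption: the callback runs first, then the pipeline is released.
        if self._callback is not None:
            self._callback(data)
        with self._cond:
            self._credits += self._num_batch if num_batch is None else num_batch
            self._cond.notify_all()


# Usage mirroring the example: update after every item keeps the pipeline flowing.
losses = []
pipe = ToySyncWait(range(6), num_batch=1, callback=losses.append, timeout=0.1)
seen = []
for x in pipe:
    seen.append(x)
    pipe.sync_update(data={"loss": x})

# Without sync_update, only the first num_batch items pass; the next request blocks.
blocked = ToySyncWait(range(3), num_batch=1, timeout=0.05)
first = next(blocked)
try:
    next(blocked)
    timed_out = False
except TimeoutError:
    timed_out = True
```

The first loop follows the same pattern as the MindSpore example: consume an item, then call sync_update with new data for the callback. In the second case a real pipeline would simply keep waiting, whereas the toy raises TimeoutError so the demonstration terminates.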

Returns

SyncWaitDataset, the dataset with the blocking condition added.

Raises

RuntimeError – If the condition name already exists.

Examples

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> def gen():
...     for i in range(100):
...         yield (np.array(i),)
>>>
>>> class Augment:
...     def __init__(self, loss):
...         self.loss = loss
...
...     def preprocess(self, input_):
...         return input_
...
...     def update(self, data):
...         self.loss = data["loss"]
>>>
>>> batch_size = 4
>>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
>>>
>>> aug = Augment(0)
>>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
>>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
>>> dataset = dataset.batch(batch_size)
>>> count = 0
>>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     assert data["input"][0] == count
...     count += batch_size
...     data = {"loss": count}
...     # Release the "policy" condition; aug.update receives data as its argument.
...     dataset.sync_update(condition_name="policy", data=data)