mindspore.dataset.Dataset.padded_batch

Dataset.padded_batch(batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None)

Combine batch_size consecutive rows into a batch, applying pad_info to the samples first.

Refer to the following figure for the execution process:

[Figure: padded_batch execution process (padded_batch_en.png)]

Note

The order in which repeat and padded_batch are applied affects the number of batches. It is recommended to apply the repeat operation after the padded_batch operation.
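The effect of the ordering can be checked with plain arithmetic. The sketch below does not call MindSpore; count_batches is a hypothetical helper, and it assumes that repeating before batching produces one continuous stream of rows, so that a batch may span the boundary between two passes over the data.

```python
import math

def count_batches(num_rows, batch_size, drop_remainder=False):
    """Number of batches produced from num_rows rows (plain arithmetic, not MindSpore code)."""
    if drop_remainder:
        return num_rows // batch_size
    return math.ceil(num_rows / batch_size)

num_rows, batch_size, repeats = 10, 3, 2

# padded_batch first, then repeat: each pass yields 4 batches (3 + 3 + 3 + 1 rows).
batch_then_repeat = count_batches(num_rows, batch_size) * repeats

# repeat first, then padded_batch: all 20 rows are batched as a single stream.
repeat_then_batch = count_batches(num_rows * repeats, batch_size)

print(batch_then_repeat)  # 8
print(repeat_then_batch)  # 7
```

Applying repeat after padded_batch keeps every pass over the data identical, which is usually what is wanted when each pass represents one epoch.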

Parameters
  • batch_size (Union[int, Callable]) – The number of rows in each batch, given either as an int or as a callable object that takes exactly 1 parameter, BatchInfo, and returns the batch size.

  • drop_remainder (bool, optional) – Determines whether to drop the last batch if it contains fewer rows than batch_size. Default: False. If True, and fewer than batch_size rows are available to make the last batch, those rows are dropped and not propagated to the child node.

  • num_parallel_workers (int, optional) – Number of workers (threads) used to process the dataset in parallel. Default: None.

  • pad_info (dict, optional) – Specifies how to pad each column before batching. Each key is a column name, and each value must be a tuple of 2 elements: the shape to pad to, and the value to pad with. A column that is not specified is padded to the shape of the longest sample in the current batch, with 0 as the padding value. For example, pad_info={"col1": ([224, 224], 0)} pads the column named col1 to shape (224, 224) and fills the missing values with 0; pad_info={} pads all samples in the batch to the shape of the largest sample in the current batch, filling with 0; pad_info={"col1": (None, 100)} pads col1 to the shape of the largest sample in the current batch and fills the missing values with 100. If no padding is wanted, set pad_info to None. Default: None.
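The padding rules described for pad_info can be illustrated in pure Python for the simple case of 1-D rows. This is only a sketch of the documented behaviour, not the library's implementation: pad_rows is a hypothetical helper, and the column name column1 in the comments is an illustrative assumption.

```python
def pad_rows(rows, pad_shape=None, pad_value=0):
    """Pad 1-D rows (lists) to a common length.

    pad_shape=None mimics "pad to the longest row in the current batch";
    otherwise pad_shape is the target length. Illustrative only.
    """
    target = max(len(r) for r in rows) if pad_shape is None else pad_shape
    if any(len(r) > target for r in rows):
        raise ValueError("a row is longer than the requested pad shape")
    return [list(r) + [pad_value] * (target - len(r)) for r in rows]

batch = [[1], [1, 2, 3]]

# Comparable to pad_info={}: pad to the longest row, fill with 0.
padded_default = pad_rows(batch)
# Comparable to pad_info={"column1": (None, 100)}: longest row, fill with 100.
padded_value = pad_rows(batch, None, 100)
# Comparable to pad_info={"column1": ([5], 0)}: fixed length 5, fill with 0.
padded_fixed = pad_rows(batch, 5, 0)

print(padded_default)  # [[1, 0, 0], [1, 2, 3]]
print(padded_value)    # [[1, 100, 100], [1, 2, 3]]
print(padded_fixed)    # [[1, 0, 0, 0, 0], [1, 2, 3, 0, 0]]
```

Once every row in a batch has the same length, the rows can be stacked into a single array, which is why padding is applied before the rows are combined.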

Returns

Dataset, a new dataset with the above operation applied.

Examples

>>> # 1) Pad every sample to the largest sample's shape and batch the samples
>>> import mindspore.dataset as ds
>>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1")
>>> dataset = dataset.padded_batch(2, True, pad_info={})
>>>
>>> # 2) Create a dataset where every 3 rows are combined into a batch
>>> # and the last incomplete batch, if there is one, is dropped.
>>> dataset = ds.NumpySlicesDataset([i for i in range(10)], "column1")
>>> dataset = dataset.padded_batch(3, True)
>>>
>>> # 3) Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function that increases the batch size by 1 each time.
>>> def add_one(batch_info):
...     return batch_info.get_batch_num() + 1
>>> dataset = dataset.padded_batch(batch_size=add_one, drop_remainder=True)
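To see which batches a dynamic batch size such as add_one would produce, the grouping can be simulated in pure Python. This sketch does not use MindSpore: simulate_dynamic_batches is a hypothetical helper, its callable receives the batch number as a plain int rather than a BatchInfo object, and it assumes that batch numbers start at 0 and that the callable always returns a positive size.

```python
def simulate_dynamic_batches(rows, batch_size_fn, drop_remainder=False):
    """Group rows into batches whose size is chosen per batch by batch_size_fn.

    batch_size_fn receives the 0-based batch number (an assumption standing in
    for BatchInfo.get_batch_num()) and must return a positive int.
    """
    batches, start, batch_num = [], 0, 0
    while start < len(rows):
        size = batch_size_fn(batch_num)
        chunk = rows[start:start + size]
        if len(chunk) < size and drop_remainder:
            break  # incomplete final batch is dropped
        batches.append(chunk)
        start += size
        batch_num += 1
    return batches

# Batch sizes grow as 1, 2, 3, 4, ... over 10 rows.
batches = simulate_dynamic_batches(list(range(10)), lambda n: n + 1, drop_remainder=True)
print(batches)  # [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

With 10 rows the sizes 1 + 2 + 3 + 4 consume the data exactly, so nothing is dropped; with 12 rows the fifth batch would hold only 2 of the 5 rows it asks for and would be discarded when drop_remainder is True.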