mindspore.dataset.dataloader._utils.collate.collate

mindspore.dataset.dataloader._utils.collate.collate(batch, *, collate_fn_map=None)

Collate the input batch by dispatching on element type: the collate function for the batch is looked up in collate_fn_map, a mapping from element type to collate function.

All the elements in the batch should be of the same type.

  • If the element type is in collate_fn_map, or the element is an instance of a subclass of a type in collate_fn_map, use the corresponding collate function to collate the batch;

  • If the element is a Mapping, collate by key: for each key, collect the values corresponding to that key from all mappings in the batch to form a new batch, recursively call this function on that batch, and use the result as the new value for that key. All mappings in the batch must have the same keys, and the types of values corresponding to each key must be the same;

  • If the element is a Sequence, collate by position: for each position, collect the elements at that position from all sequences in the batch to form a new batch, recursively call this function on that batch, and use the result as the new element at that position. All sequences in the batch must have the same length;

  • Otherwise, raise an exception to indicate that the element type is not supported.

Each collate function must accept the batch as a positional argument and collate_fn_map as a keyword argument.
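The dispatch rules above can be illustrated with a small pure-Python sketch. It is not the MindSpore implementation: sketch_collate and to_tuple are hypothetical names, the exception types are choices made for this sketch, and the guard that excludes str and bytes from the Sequence rule is an assumption added here so that the recursion terminates.

```python
from collections.abc import Mapping, Sequence


def sketch_collate(batch, *, collate_fn_map=None):
    """Hypothetical stand-in that mirrors the documented dispatch rules."""
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # Rule 1a: the element type itself is a key of the mapping.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        # Rule 1b: the element is an instance of a registered type
        # (a key may also be a tuple of types).
        for registered, fn in collate_fn_map.items():
            if isinstance(elem, registered):
                return fn(batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, Mapping):
        # Rule 2: collate by key, recursing on the values for each key.
        return {
            key: sketch_collate([m[key] for m in batch], collate_fn_map=collate_fn_map)
            for key in elem
        }

    # Assumption of this sketch: strings are not treated as sequences.
    if isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
        # Rule 3: collate by position; all sequences must have equal length.
        if any(len(seq) != len(elem) for seq in batch):
            raise ValueError("all sequences in the batch must have the same length")
        return [
            sketch_collate(list(samples), collate_fn_map=collate_fn_map)
            for samples in zip(*batch)
        ]

    # Rule 4: nothing matched (exception type chosen for this sketch).
    raise TypeError(f"unsupported element type: {elem_type}")


def to_tuple(batch, *, collate_fn_map):
    # Hypothetical collate function: batch is positional,
    # collate_fn_map is keyword-only, as the text requires.
    return tuple(batch)


fn_map = {int: to_tuple}

result_flat = sketch_collate([0, 1, 2], collate_fn_map=fn_map)
result_dict = sketch_collate(
    [{"data": 0, "label": 2}, {"data": 1, "label": 3}], collate_fn_map=fn_map
)
result_seq = sketch_collate([(0, 3), (1, 4), (2, 5)], collate_fn_map=fn_map)
# bool is a subclass of int, so the int entry handles it (rule 1b).
result_bool = sketch_collate([True, False], collate_fn_map=fn_map)
```

Passing collate_fn_map through to every collate function lets a custom function recurse into nested data using the same mapping, which is presumably why the keyword argument is required.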

Parameters

batch (list) – A batch of data to be collated.

Keyword Arguments

collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – Mapping from element type to the corresponding collate function. Default: None.

Returns

Any, the collated data.

Examples

>>> import mindspore
>>> from mindspore.dataset.dataloader._utils.collate import collate
>>>
>>> def collate_int_fn(batch, *, collate_fn_map):
...     return mindspore.tensor(batch)
>>>
>>> collate_map = {int: collate_int_fn}
>>>
>>> collate([0, 1, 2], collate_fn_map=collate_map)
Tensor(shape=[3], dtype=Int64, value= [0, 1, 2])
>>>
>>> collate([{"data": 0, "label": 2},
...          {"data": 1, "label": 3}], collate_fn_map=collate_map)
{'data': Tensor(shape=[2], dtype=Int64, value= [0, 1]), 'label': Tensor(shape=[2], dtype=Int64, value= [2, 3])}
>>>
>>> collate([(0, 3), (1, 4), (2, 5)], collate_fn_map=collate_map)
[Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]), Tensor(shape=[3], dtype=Int64, value= [3, 4, 5])]