mindspore.dataset.dataloader._utils.collate.collate

mindspore.dataset.dataloader._utils.collate.collate(batch, *, collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None)

General collate function that handles collection-type elements within each batch.

The function also exposes a function registry for handling specific element types. default_collate_fn_map provides default collate functions for tensors, NumPy arrays, numbers, and strings.
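The following is a minimal pure-Python sketch of how a general collate function of this kind can dispatch through a registry and recurse into collections. It assumes behavior modeled on PyTorch's collate in torch.utils.data._utils.collate, whose signature this function shares. It is not MindSpore's implementation. The names sketch_collate and collate_numbers are hypothetical, and plain Python lists stand in for stacked tensors.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Optional


def sketch_collate(batch: list, *, collate_fn_map: Optional[dict] = None) -> Any:
    """Hypothetical collate: registry dispatch first, then recurse into collections."""
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # An exact type match takes priority.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        # Otherwise use the first key, in insertion order, that the element is an instance of.
        for key in collate_fn_map:
            if isinstance(elem, key):
                return collate_fn_map[key](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, Mapping):
        # Collate each field across the batch, keeping the mapping's keys.
        return {k: sketch_collate([d[k] for d in batch], collate_fn_map=collate_fn_map)
                for k in elem}
    if isinstance(elem, Sequence) and not isinstance(elem, str):
        # Transpose [(a1, b1), (a2, b2)] into columns (a1, a2), (b1, b2) and collate each column.
        return [sketch_collate(list(column), collate_fn_map=collate_fn_map)
                for column in zip(*batch)]

    raise TypeError(f"no collate function registered for {elem_type}")


def collate_numbers(batch, *, collate_fn_map=None):
    # Stand-in for stacking tensors: gather the numbers into one list.
    return list(batch)


# A tuple of types is allowed as a key, matching the collate_fn_map annotation.
number_map = {(int, float): collate_numbers}

samples = [{"x": 1, "pair": (2, 3.0)}, {"x": 4, "pair": (5, 6.0)}]
print(sketch_collate(samples, collate_fn_map=number_map))
# {'x': [1, 4], 'pair': [[2, 5], [3.0, 6.0]]}
```

In this sketch the registry is consulted before the generic collection handling, so registering a type always overrides the default recursion for it.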

Parameters

batch – a single batch to be collated

Keyword Arguments

collate_fn_map – Optional dictionary mapping element types to their corresponding collate functions. If an element's type is not a key in this dictionary, the function iterates over the dictionary's keys in insertion order and invokes the collate function of the first key of which the element type is a subclass.
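The lookup rule above can be sketched in isolation. The helper find_collate_fn below is hypothetical and only reproduces the documented order: an exact type match wins; otherwise keys are scanned in insertion order and the first key the element type is a subclass of is used. Because bool is a subclass of int in Python, the example shows insertion order deciding which function handles a bool when only broader keys are registered.

```python
from typing import Callable, Optional


def find_collate_fn(elem_type: type, collate_fn_map: dict) -> Optional[Callable]:
    """Hypothetical helper mirroring the documented lookup order."""
    # 1. An exact type match takes priority over any subclass match.
    if elem_type in collate_fn_map:
        return collate_fn_map[elem_type]
    # 2. Otherwise scan keys in insertion order; a key may be a type or a tuple of types.
    for key in collate_fn_map:
        if issubclass(elem_type, key):
            return collate_fn_map[key]
    # 3. No applicable entry.
    return None


def handle_int(batch, *, collate_fn_map=None):
    return "int"


def handle_object(batch, *, collate_fn_map=None):
    return "object"


# bool has no entry of its own, so the first matching key in insertion order wins.
int_first = {int: handle_int, object: handle_object}
object_first = {object: handle_object, int: handle_int}

print(find_collate_fn(bool, int_first)([True]))     # int
print(find_collate_fn(bool, object_first)([True]))  # object
print(find_collate_fn(int, object_first)([1]))      # int (exact match beats insertion order)
```

A practical consequence: register specific types before broad ones such as object, or the broad key will shadow them for subclasses.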

Examples

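>>> import mindspore
>>> from mindspore.dataset.dataloader._utils.collate import collate, default_collate_fn_map
>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle a batch of tensors;
...     # here the tensors are stacked along a new leading axis
...     return mindspore.ops.stack(batch, 0)
>>> def custom_collate(batch):
...     # Use a registry that only knows how to collate tensors
...     collate_map = {mindspore.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({mindspore.Tensor: collate_tensor_fn})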

Note

Each collate function must accept the batch as a positional argument and the dictionary of collate functions as the keyword argument collate_fn_map.
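The calling convention in this note can be checked with plain Python. In this hypothetical sketch, call_registered stands in for the point where the dispatcher invokes an entry from the map, passing the batch positionally and the map as collate_fn_map=...; a function that omits that keyword parameter fails with a TypeError as soon as it is dispatched to.

```python
def call_registered(fn, batch, collate_fn_map):
    # Hypothetical stand-in for the dispatch step: batch is positional,
    # the registry is passed by keyword.
    return fn(batch, collate_fn_map=collate_fn_map)


def good_collate(batch, *, collate_fn_map=None):
    # Accepts the required keyword argument, even though it does not use it.
    return sum(batch)


def bad_collate(batch):
    # Missing the collate_fn_map keyword argument.
    return sum(batch)


registry = {int: good_collate}
print(call_registered(good_collate, [1, 2, 3], registry))  # 6

try:
    call_registered(bad_collate, [1, 2, 3], registry)
except TypeError as exc:
    print("TypeError:", exc)
```

Accepting collate_fn_map also lets a collate function for a container type pass the registry along when it collates nested elements.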