mindspore.dataset.dataloader._utils.collate.collate
- mindspore.dataset.dataloader._utils.collate.collate(batch, *, collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None)
General collate function that handles collection-type elements within each batch.
The function also exposes a function registry for handling specific element types: default_collate_fn_map provides default collate functions for tensors, NumPy arrays, numbers, and strings.
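To show how a registry of per-type leaf functions can combine with recursion into collection types, here is a simplified pure-Python model. It is loosely based on the PyTorch collate function that this API closely resembles and is not MindSpore's source. toy_collate, toy_collate_fn_map, collate_number_fn, and collate_str_fn are hypothetical names, and plain lists stand in for stacked tensors.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Dict, List, Optional


def collate_number_fn(batch: List[Any], *, collate_fn_map: Optional[Dict]) -> List[Any]:
    # Hypothetical leaf handler: a plain list stands in for a stacked tensor.
    return list(batch)


def collate_str_fn(batch: List[str], *, collate_fn_map: Optional[Dict]) -> List[str]:
    # Hypothetical leaf handler: strings are kept whole, not split into characters.
    return list(batch)


# Hypothetical registry playing the role of default_collate_fn_map.
# A key may be a single type or a tuple of types.
toy_collate_fn_map: Dict[Any, Callable] = {
    (int, float): collate_number_fn,
    str: collate_str_fn,
}


def toy_collate(batch: List[Any], *, collate_fn_map: Optional[Dict[Any, Callable]] = None) -> Any:
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # Registry lookup: exact type first, then the first key (in insertion
        # order) that the element's type is a subclass of.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        for key, fn in collate_fn_map.items():
            if issubclass(elem_type, key):
                return fn(batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, Mapping):
        # Collate each key's values across samples, reusing the same registry.
        return {k: toy_collate([d[k] for d in batch], collate_fn_map=collate_fn_map) for k in elem}

    if isinstance(elem, Sequence) and not isinstance(elem, str):
        # Transpose: a batch of samples becomes one collated entry per field.
        if any(len(sample) != len(elem) for sample in batch):
            raise RuntimeError("each element in the batch should be of equal size")
        return [toy_collate(list(field), collate_fn_map=collate_fn_map) for field in zip(*batch)]

    raise TypeError(f"no collate function registered for {elem_type}")


batch = [
    {"x": 1, "label": "a", "pair": (0.5, 2)},
    {"x": 3, "label": "b", "pair": (1.5, 4)},
]
print(toy_collate(batch, collate_fn_map=toy_collate_fn_map))
```

For this input the model yields {"x": [1, 3], "label": ["a", "b"], "pair": [[0.5, 1.5], [2, 4]]}. Dictionaries keep their keys and tuples are transposed field by field, while numbers and strings are handed to the registered leaf functions.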
- Parameters
batch – a single batch to be collated
- Keyword Arguments
collate_fn_map – Optional dictionary mapping an element type (or a tuple of types) to the corresponding collate function. If the element type is not itself a key of this dictionary, the function walks the keys in insertion order and invokes the collate function of the first key of which the element type is a subclass.
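The lookup order described for collate_fn_map can be modeled in a few lines. This is a sketch of the documented behavior, not MindSpore's code; lookup_collate_fn, MyInt, collate_int, and collate_any are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Dict, Optional


def lookup_collate_fn(elem: Any, collate_fn_map: Dict[Any, Callable]) -> Optional[Callable]:
    """Hypothetical helper modeling the documented lookup order."""
    elem_type = type(elem)
    # 1) An exact type match wins outright.
    if elem_type in collate_fn_map:
        return collate_fn_map[elem_type]
    # 2) Otherwise, scan keys in insertion order; the first key that the
    #    element's type is a subclass of wins (tuple keys also work here).
    for key, fn in collate_fn_map.items():
        if issubclass(elem_type, key):
            return fn
    # 3) No match: the caller must fall back to other handling.
    return None


class MyInt(int):  # hypothetical subclass with no entry of its own
    pass


def collate_int(batch, *, collate_fn_map):
    return "int"


def collate_any(batch, *, collate_fn_map):
    return "any"


int_first = {int: collate_int, object: collate_any}
object_first = {object: collate_any, int: collate_int}

print(lookup_collate_fn(MyInt(1), int_first) is collate_int)     # subclass of int, int listed first
print(lookup_collate_fn(MyInt(1), object_first) is collate_any)  # object listed first, so it wins
print(lookup_collate_fn(1, object_first) is collate_int)         # exact match beats insertion order
```

All three lines print True. Because subclass matching stops at the first hit, broad keys such as object placed early in the dictionary can shadow more specific ones for subclasses, whereas an exact type key is always preferred.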
Examples
>>> import mindspore
>>> from mindspore.dataset.dataloader._utils.collate import collate, default_collate_fn_map
>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle a batch of tensors
...     return mindspore.ops.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {mindspore.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by modifying `default_collate_fn_map` in place
>>> default_collate_fn_map.update({mindspore.Tensor: collate_tensor_fn})
Note
Each collate function must accept the batch as a positional argument and the dictionary of collate functions as the keyword argument collate_fn_map.
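As a sketch of that calling convention, here is a collate function for a hypothetical Point type. Point and collate_point_fn are illustrative names, not part of MindSpore. Making collate_fn_map keyword-only mirrors collate_tensor_fn in the example above. The function receives the registry so it could dispatch nested fields through it, but this sketch does not use it.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class Point:  # hypothetical element type
    x: float
    y: float


def collate_point_fn(batch: List[Point], *, collate_fn_map: Optional[Dict[Any, Callable]]) -> Dict[str, List[float]]:
    # `batch` arrives positionally; `collate_fn_map` must be passed by keyword.
    # This sketch ignores the registry and groups each field into a list.
    return {"x": [p.x for p in batch], "y": [p.y for p in batch]}


points = [Point(1.0, 2.0), Point(3.0, 4.0)]
print(collate_point_fn(points, collate_fn_map=None))  # {'x': [1.0, 3.0], 'y': [2.0, 4.0]}

try:
    collate_point_fn(points, None)  # registry passed positionally
except TypeError as err:
    print("rejected:", err)
```

To use it with collate, it could be registered the same way as collate_tensor_fn in the example above, for instance through a mapping such as {Point: collate_point_fn}.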