mindspore.dataset.dataloader._utils.collate.collate

查看源文件
mindspore.dataset.dataloader._utils.collate.collate(batch, *, collate_fn_map=None)[源代码]

根据输入批数据元素的类型,从 collate_fn_map 所定义的类型到整理函数映射中,选择相应函数对批数据进行整理。

批数据中的所有元素应该是相同类型。

  • 如果元素的类型在 collate_fn_map 中,或者元素是 collate_fn_map 中类型的子类,则使用相应函数进行数据整理;

  • 如果元素是映射( Mapping )类型,则按键分组整理:对每个键,收集批数据所有映射中该键 对应的值,组成新的批数据,对其递归调用本函数,将结果作为该键新的值。批数据中各映射的键必须相同,各个键对应值的类型必须相同;

  • 如果元素是序列( Sequence )类型,则按位置分组整理:对每个位置,收集批数据所有序列中 该位置对应的元素,组成新的批数据,对其递归调用本函数,将结果作为该位置新的元素。批数据中各序列的长度必须相同;

  • 否则将抛出异常,表明不支持该元素类型。

每个整理函数需要一个 batch 位置参数和一个 collate_fn_map 关键字参数。

参数:
  • batch (list) - 要整理的批数据。

关键字参数:
  • collate_fn_map (Optional[dict[Union[type, tuple[type, …]], Callable]]) - 从元素类型到相应整理函数的映射。 默认值: None

返回:

Any ,整理后的数据。

样例:

>>> import mindspore
>>> from mindspore.dataset.dataloader._utils.collate import collate
>>>
>>> def collate_int_fn(batch, *, collate_fn_map):
...     return mindspore.tensor(batch)
>>>
>>> collate_map = {int: collate_int_fn}
>>>
>>> collate([0, 1, 2], collate_fn_map=collate_map)
Tensor(shape=[3], dtype=Int64, value= [0, 1, 2])
>>>
>>> collate([{"data": 0, "label": 2},
...          {"data": 1, "label": 3}], collate_fn_map=collate_map)
{'data': Tensor(shape=[2], dtype=Int64, value= [0, 1]), 'label': Tensor(shape=[2], dtype=Int64, value= [2, 3])}
>>>
>>> collate([(0, 3), (1, 4), (2, 5)], collate_fn_map=collate_map)
[Tensor(shape=[3], dtype=Int64, value= [0, 1, 2]), Tensor(shape=[3], dtype=Int64, value= [3, 4, 5])]