Function Differences with torch.utils.data.DataLoader

torch.utils.data.DataLoader

class torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
    num_workers=0, collate_fn=None, pin_memory=False, drop_last=False,
    timeout=0, worker_init_fn=None, multiprocessing_context=None)

For more information, see torch.utils.data.DataLoader.

MindSpore

Datasets in MindSpore do not require a loader.

Usage

PyTorch: The dataset and sampler objects, together with parameters for batching, shuffling, parallelism, and so on, are passed to this interface, which returns a parallel data iterator with sampling, batching, and shuffling capabilities.

MindSpore: The sampler object and the shuffling and parallelism parameters can either be passed in when the dataset object is created or be configured through methods of the dataset class. The dataset object then directly supports sampling, batching, shuffling, and parallel data iteration, so no extra loader is needed. Both configuration styles are sketched after the code example below.

Code Example

import torch
import mindspore.dataset as ds
import numpy as np

np.random.seed(0)
data = np.random.random((6, 2, 2))

# The following implements data iteration with PyTorch.
dataset = torch.utils.data.TensorDataset(torch.Tensor(data))
sampler = torch.utils.data.SequentialSampler(dataset)
# Pass the dataset, sampler, and batching parameters to the loader.
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=3)
print(next(iter(loader)))
# Out:
# [tensor([[[0.5488, 0.7152],
#          [0.6028, 0.5449]],
#
#         [[0.4237, 0.6459],
#          [0.4376, 0.8918]],
#
#         [[0.9637, 0.3834],
#          [0.7917, 0.5289]]])]

# The following implements data iteration with MindSpore.
# Sampling, batching, and iteration are configured on the dataset object itself.
sampler = ds.SequentialSampler()
dataset = ds.NumpySlicesDataset(data, sampler=sampler)
dataset = dataset.batch(batch_size=3)
iterator = dataset.create_dict_iterator()
print(next(iter(iterator)))
# Out:
# {'column_0': Tensor(shape=[3, 2, 2], dtype=Float64, value=
# [[[ 5.48813504e-01,  7.15189366e-01],
#   [ 6.02763376e-01,  5.44883183e-01]],
#  [[ 4.23654799e-01,  6.45894113e-01],
#   [ 4.37587211e-01,  8.91773001e-01]],
#  [[ 9.63662761e-01,  3.83441519e-01],
#   [ 7.91725038e-01,  5.28894920e-01]]])}
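
As noted in the Usage section, the shuffling and parallelism settings in MindSpore can be given either when the dataset object is created or through dataset methods afterwards. The following is a minimal sketch of both styles, continuing the script above (it reuses the ds import and the data array); the buffer_size and num_parallel_workers values here are illustrative choices, not required settings.

# Style 1: pass shuffling and parallelism parameters when creating the dataset.
dataset1 = ds.NumpySlicesDataset(data, shuffle=True, num_parallel_workers=2)
dataset1 = dataset1.batch(batch_size=3)

# Style 2: configure the same behavior through dataset methods.
dataset2 = ds.NumpySlicesDataset(data, shuffle=False)
dataset2 = dataset2.shuffle(buffer_size=6)  # buffer_size >= dataset size gives a full shuffle
dataset2 = dataset2.batch(batch_size=3, num_parallel_workers=2)

for item in dataset1.create_dict_iterator():
    print(item["column_0"].shape)  # each batch has shape (3, 2, 2)

Either style yields the same kind of dataset object, so the choice is a matter of preference; the method style is convenient when the configuration is decided after the dataset has been created.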