# 比较与torch.utils.data.DataLoader的差异 [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_zh_cn/note/api_mapping/pytorch_diff/DataLoader.md) ## torch.utils.data.DataLoader ```python class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False) ``` 更多内容详见[torch.utils.data.DataLoader](https://pytorch.org/docs/1.8.1/data.html#torch.utils.data.DataLoader)。 ## mindspore.dataset.GeneratorDataset ```python class mindspore.dataset.GeneratorDataset( source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6) ``` 更多内容详见[mindspore.dataset.GeneratorDataset](https://mindspore.cn/docs/zh-CN/master/api_python/dataset/mindspore.dataset.GeneratorDataset.html#mindspore.dataset.GeneratorDataset)。 ## 差异对比 PyTorch:DataLoader需要接收一个数据加载类、采样器,及批处理、混洗、多进程并行度等参数,以实现一个具有采样、分批、混洗等功能的数据迭代对象。 其中`dataset`参数支持继承自`torch.utils.data.Dataset`的自定义类,或传入由`torchvision.datasets`、`torchtext.datasets`、`torchaudio.datasets`等组件中预定义好的数据集加载类。 MindSpore:GeneratorDataset需要接收一个数据加载类、采样器、混洗、分片和多进程并行性来创建一个用于数据迭代的迭代器。 此API与PyTorch的DataLoader功能定位一样,均是用于加载自定义的数据集,但参数列表差异较大,下面的多个代码示例将演示如何使用2个API实现同样的功能。 | 分类 | 子类 |PyTorch | MindSpore | 差异 | | --- | --- | --- | --- |--- | |参数 | 参数1 | dataset | source | 定义数据集加载逻辑的对象 | | | 参数2 | batch_size | - | MindSpore通过 `mindspore.dataset.batch` 操作支持 | | | 参数3 | shuffle | shuffle | - | | | 参数4 | sampler | sampler | - | | | 参数5 | batch_sampler | - | MindSpore不支持 | | | 参数6 | num_workers | num_parallel_workers | - | | | 参数7 | collate_fn | - | MindSpore通过 `mindspore.dataset.batch` 操作支持 | | | 参数8 | pin_memory | - | MindSpore不支持 | | | 参数9 | drop_last | - | MindSpore通过 `mindspore.dataset.batch` 操作支持 | | | 参数10 | timeout | - | MindSpore不支持 | | | 参数11 | worker_init_fn | - | MindSpore不支持 | | | 参数12 | multiprocessing_context | - | 多进程上下文,MindSpore不支持 | | | 参数13 | generator | - | 自定义索引生成器,MindSpore不支持 | | | 参数14 | prefetch_factor | - | 定义在 `mindspore.dataset.config.set_prefetch_size` 中 | | | 参数15 | persistent_workers | - | 指定遍历完一次数据后是否释放数据集对象, MindSpore通过 `create_tuple_iterator` 的 `num_epoch` 参数支持,如果设置参数大于1则与 `persistent_workers` 为True一致 | | | 参数16 | - | column_names | 指定数据集生成的列名 | | | 参数17 | - | column_types | 指定生成数据集各个数据列的数据类型 | | | 参数18 | - | schema | 数据格式策略,用于指定读取数据列的数据类型、数据维度等信 | | | 参数19 | - | num_samples | 指定从数据集中读取的样本数 | | | 参数20 | - | num_shards | 指定分布式训练时将数据集进行划分的分片数 | | | 参数21 | - | shard_id | 指定分布式训练时使用的分片ID号 | | | 参数22 | - | python_multiprocessing | 指定是否启用Python多进程模式加速运算 | | | 参数23 | - | max_rowsize | 指定在多进程之间复制数据时,共享内存分配的最大空间 | ### 代码示例1 > 定义一个迭代类型的数据集类与一个随机访问类型的数据集类,并通过DataLoader/GeneratorDataset加载。注意DataLoader的shuffle参数默认行为是False,GeneratorDataset的shuffle默认行为是True。 ```python # Torch import torch class MyIterableDataset(torch.utils.data.IterableDataset): def __init__(self, start, end): super(MyIterableDataset).__init__() self.start = start self.end = end def __iter__(self): return iter(range(self.start, self.end)) ds = MyIterableDataset(start=3, end=7) # Single-process loading print(list(torch.utils.data.DataLoader(ds, num_workers=0, shuffle=False))) # Out: [tensor([3]), tensor([4]), tensor([5]), tensor([6])] class MyMapDataset(torch.utils.data.Dataset): def __init__(self): super(MyMapDataset).__init__() self.data = [1, 2, 3, 4] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) ds = MyMapDataset() # Single-process loading print(list(torch.utils.data.DataLoader(ds))) # Out: [tensor([1]), tensor([2]), tensor([3]), tensor([4])] ``` ```python # MindSpore import mindspore as ms class MyIterableDataset(): def __init__(self, start, end): self.start = start self.end = end def __iter__(self): return iter(range(self.start, self.end)) ds = MyIterableDataset(start=3, end=7) # Single-process loading print(list(ms.dataset.GeneratorDataset(ds, column_names=["data"], num_parallel_workers=1, shuffle=False))) # Out: [[Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)], [Tensor(shape=[], dtype=Int64, value= 5)], [Tensor(shape=[], dtype=Int64, value= 6)]] class MyMapDataset(): def __init__(self): super(MyMapDataset).__init__() self.data = [1, 2, 3, 4] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) ds = MyMapDataset() # Single-process loading print(list(ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False))) # Out: [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)]] ``` ### 代码示例2 > 定义一个数据集类,并对数据进行batch为2的批处理。 ```python # Torch import torch class MyMapDataset(torch.utils.data.Dataset): def __init__(self): super(MyMapDataset).__init__() self.data = [1, 2, 3, 4, 5] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) ds = MyMapDataset() dataloader = torch.utils.data.DataLoader(ds, batch_size=2, drop_last=True) print(list(dataloader)) # Out: [tensor([1, 2]), tensor([3, 4])] ``` ```python # MindSpore import mindspore as ms class MyMapDataset(): def __init__(self): super(MyMapDataset).__init__() self.data = [1, 2, 3, 4, 5] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) ds = MyMapDataset() dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False) dataloader = dataloader.batch(2, drop_remainder=True) print(list(dataloader)) # Out: [[Tensor(shape=[2], dtype=Int64, value= [1, 2])], [Tensor(shape=[2], dtype=Int64, value= [3, 4])]] ``` ### 代码示例3 > 定义一个数据集类,进行批处理时引入collate_fn逻辑。 ```python # Torch import torch class MyMapDataset(torch.utils.data.Dataset): def __init__(self): super(MyMapDataset).__init__() self.data = torch.Tensor([1, 2, 3, 4, 5]) def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) def my_collate_fn(batch): for i, _ in enumerate(batch): batch[i] = batch[i] + 2 return torch.stack(batch) ds = MyMapDataset() dataloader = torch.utils.data.DataLoader(ds, batch_size=2, drop_last=True, collate_fn=my_collate_fn) print(list(dataloader)) # Out: [tensor([3., 4.]), tensor([5., 6.])] ``` ```python # MindSpore import mindspore as ms import numpy as np class MyMapDataset(): def __init__(self): super(MyMapDataset).__init__() self.data = [1, 2, 3, 4, 5] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) def my_collate_fn(batch, batchinfo): for i, _ in enumerate(batch): batch[i] = batch[i] + 2 return np.stack(batch), ds = MyMapDataset() dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False) dataloader = dataloader.batch(2, drop_remainder=True, per_batch_map=my_collate_fn) print(list(dataloader)) # Out: [[Tensor(shape=[2], dtype=Int64, value= [3, 4])], [Tensor(shape=[2], dtype=Int64, value= [5, 6])]] ```