比较与torch.utils.data.Dataset的功能差异

查看源文件

torch.utils.data.Dataset

class torch.utils.data.Dataset(*args, **kwds)

更多内容详见torch.utils.data.Dataset

mindspore.dataset.GeneratorDataset

class mindspore.dataset.GeneratorDataset(
    source,
    column_names=None,
    column_types=None,
    schema=None,
    num_samples=None,
    num_parallel_workers=1,
    shuffle=None,
    sampler=None,
    num_shards=None,
    shard_id=None,
    python_multiprocessing=True,
    max_rowsize=6
    )

更多内容详见mindspore.dataset.GeneratorDataset

使用方式

PyTorch:自定义数据集的抽象类。自定义数据集子类需要继承这个抽象类,并重写`__getitem__()`方法以支持按索引获取数据样本;还可以选择重写`__len__()`方法,用于返回数据集的大小。

MindSpore:每个epoch通过调用Python层自定义的数据源来生成数据集,生成数据集的列名和列类型取决于用户定义的Python数据。

代码示例

import numpy as np
import mindspore.dataset as ds
from torch.utils.data import Dataset

# In MindSpore, GeneratorDataset generates data from Python by invoking the Python data source each epoch. The column names and column types of the generated dataset depend on the Python data defined by the user.

class GetDatasetGenerator:
    """Random-access data source holding five seeded random samples.

    Implements ``__getitem__`` and ``__len__`` so that an instance can be
    handed to ``mindspore.dataset.GeneratorDataset`` as its ``source``.
    """

    def __init__(self):
        # Seed NumPy's global RNG so the generated values are reproducible.
        np.random.seed(58)
        # Draw the data first, then the labels; the draw order determines
        # the generated values.
        features = np.random.random_sample((5, 2))
        targets = np.random.random_sample((5, 1))
        self.__data = features
        self.__label = targets

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair stored at ``index``."""
        sample, target = self.__data[index], self.__label[index]
        return sample, target

    def __len__(self):
        """Return the number of stored samples."""
        return self.__data.shape[0]


# Wrap the random-access source in a GeneratorDataset. The two column names
# label the two elements of each tuple returned by __getitem__.
dataset_generator = GetDatasetGenerator()
dataset = ds.GeneratorDataset(
    source=dataset_generator,
    column_names=["data", "label"],
    shuffle=False,
)

# Each row is indexed by column name.
for row in dataset.create_dict_iterator():
    print(row["data"], row["label"])
# Out:
# [0.36510558 0.45120592] [0.78888122]
# [0.49606035 0.07562207] [0.38068183]
# [0.57176158 0.28963401] [0.16271622]
# [0.30880446 0.37487617] [0.54738768]
# [0.81585667 0.96883469] [0.77994068]


# In PyTorch, a subclass of torch.utils.data.Dataset should override `__getitem__()`, which supports fetching a data sample for a given key. Subclasses may also optionally override `__len__()`, which is expected to return the size of the dataset.

class GetDatasetGenerator1(Dataset):
    """Map-style PyTorch dataset holding five seeded random samples.

    Subclasses ``torch.utils.data.Dataset`` and provides ``__getitem__``
    for indexed access plus ``__len__`` for the dataset size.
    """

    def __init__(self):
        # Fix the global NumPy seed so every run produces the same arrays.
        np.random.seed(58)
        # Data is drawn before labels; the draw order determines the values.
        inputs = np.random.random_sample((5, 2))
        labels = np.random.random_sample((5, 1))
        self.__data, self.__label = inputs, labels

    def __getitem__(self, index):
        """Return the ``(data, label)`` tuple at position ``index``."""
        pair = (self.__data[index], self.__label[index])
        return pair

    def __len__(self):
        """Return how many samples the dataset contains."""
        count = self.__data.shape[0]
        return count


# The dataset defines no __iter__, so the for-loop calls __getitem__ with
# 0, 1, 2, ... and stops when an IndexError is raised.
dataset = GetDatasetGenerator1()
for sample in dataset:
    print("item:", sample)

# Out:
# item: (array([0.36510558, 0.45120592]), array([0.78888122]))
# item: (array([0.49606035, 0.07562207]), array([0.38068183]))
# item: (array([0.57176158, 0.28963401]), array([0.16271622]))
# item: (array([0.30880446, 0.37487617]), array([0.54738768]))
# item: (array([0.81585667, 0.96883469]), array([0.77994068]))