代码
AI数据框架大横评(4)之采样器

AI数据框架大横评(4)之采样器

AI数据框架大横评(4)之采样器

前言

开始本文的内容前,先简单回顾一下本系列的前几篇文章:

第一篇,我们简要对比了当前主流AI数据框架的架构设计,从中可以略微看出各家框架的主要设计理念和应用场景。

AI数据框架大横评之架构设计

第二篇,我们简要对比了当前主流AI数据框架的数据加载方式。

AI数据框架大横评之数据加载

第三篇,我们简要对比了当前主流AI数据框架的数据处理方式。

AI数据框架大横评之数据处理

建议大家先阅读以上几篇文章,再开始下面的阅读。

采样器

对于可迭代的(Iterable Style)数据集,数据加载顺序完全由用户定义的迭代逻辑控制。

而对于可随机访问的(Map Style)数据集,则可以通过采样器(Sampler)生成自定义顺序的索引/键,再通过随机访问的能力加载数据。

MindSpore

MindSpore提供了丰富的采样器API供用户开箱即用。

  • mindspore.dataset.SequentialSampler:按顺序采样指定数量样本。
  • mindspore.dataset.RandomSampler:按随机顺序采样指定数量样本。
  • mindspore.dataset.DistributedSampler:将样本等分用于分布式训练。
  • mindspore.dataset.PKSampler:在P个类别中各采样K个样本。
  • mindspore.dataset.SubsetSampler:根据指定的索引列表进行采样。
  • mindspore.dataset.SubsetRandomSampler:根据指定的索引列表进行随机采样。
  • mindspore.dataset.WeightedRandomSampler:根据指定的各个类别的概率采样样本。

还是以处理MNIST手写字识别数据集为例,首先根据自己的策略,定义采样器,然后传给数据加载接口即可。

import mindspore.dataset as ds

mnist_dataset_dir = "/path/to/mnist_dataset_directory"
sampler = RandomSampler()
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=sampler)

为了简化编码流程,用户也可直接通过数据加载接口的num_samplesshufflenum_shardsshard_id参数控制采样器的使用,具体如下:

sampler

num_shards/shard_id

shuffle

num_samples

使用的采样器

mindspore.dataset.Sampler 类型

None

None

None

sampler

numpy.ndarray,list,tuple,int 类型

/

/

num_samples

SubsetSampler(indices=sampler, num_samples=num_samples)

iterable 类型

/

/

num_samples

IterSampler(sampler=sampler, num_samples=num_samples)

None

num_shards / shard_id

None / True

num_samples

DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=True, num_samples=num_samples)

None

num_shards / shard_id

False

num_samples

DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=num_samples)

None

None

None / True

None

RandomSampler(num_samples=num_samples)

None

None

None / True

num_samples

RandomSampler(replacement=True, num_samples=num_samples)

None

None

False

num_samples

SequentialSampler(num_samples=num_samples)

例如下列代码指定了num_shardsshard_id参数,则等同于先构造了DistributedSampler,再传给数据加载接口:

import mindspore.dataset as ds

mnist_dataset_dir = "/path/to/mnist_dataset_directory"

# 直接通过数据加载API的参数创建采样器
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_shards=8, shard_id=0)
# 先定义采样器,再传给数据加载API
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=DistributedSampler(num_shards=8, shard_id=0))

用户也可以根据自己的需要编写自定义采样逻辑。

与自定义数据加载一样,采样器也可实现为可随机访问的(Map Style)和可迭代的(Iterable Style)两种。

对于可随机访问的采样器,可编写自定义采样器类,提供 __getitem__ 和 __len__ 方法,例如:

class MySampler():
    def __init__(self):
        self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]

    def __getitem__(self, index):
        return self.index_ids[index]

    def __len__(self):
        return len(self.index_ids)

对于可迭代的采样器,可编写自定义采样器类,提供 __iter__ 方法,例如:

class MySampler:
    def __iter__(self):
        for i in range(0, 100, 2):
            yield i

PyTorch

PyTorch同样提供了丰富的采样器API供用户开箱即用。

  • torch.utils.data.Sampler:所有采样器的基类。
  • torch.utils.data.SequentialSampler:按顺序采样样本。
  • torch.utils.data.RandomSampler:按随机顺序采样指定数量样本。
  • torch.utils.data.SubsetRandomSampler:根据指定的索引列表进行随机采样。
  • torch.utils.data.WeightedRandomSampler:根据指定的各个类别的概率采样样本。
  • torch.utils.data.BatchSampler:每次返回一个batch的样本索引。
  • torch.utils.data.distributed.DistributedSampler:分布式采样器

PyTorch在创建采样器时,需要先传入数据集对象,最后同时将数据集和采样器对象传给DataLoader即可。

from torch.utils.data import Dataset, RandomSampler, DataLoader

class MapStyleDataset(Dataset):
    def __init__(self, dataset_dir):
        self.files = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir)]

    def __getitem__(self, index):
        return np.load(self.files[index])

    def __len__(self):
        return len(self.files)

dataset = MapStyleDataset("/path/to/dataset_directory")
sampler = RandomSampler(dataset)
loader = DataLoader(dataset=dataset, sampler=sampler)

同样,PyTorch的DataLoader接口也提供了shuffle参数,方便用户快速创建RandomSampler,但对于其他数据集,就得手动构造了。

TensorFlow

TensorFlow未提供采样器的功能。