AI数据框架大横评(4)之采样器
AI数据框架大横评(4)之采样器
前言
开始本文的内容前,先简单回顾一下本系列的前几篇文章:
第一篇,我们简要对比了当前主流AI数据框架的架构设计,从中可以略微看出各家框架的主要设计理念和应用场景。
第二篇,我们简要对比了当前主流AI数据框架的数据加载方式。
第三篇,我们简要对比了当前主流AI数据框架的数据处理方式。
建议大家先阅读以上几篇文章,再开始下面的阅读。
采样器
对于可迭代的(Iterable Style)数据集,数据加载顺序完全由用户定义的迭代逻辑控制。
而对于可随机访问的(Map Style)数据集,则可以通过采样器(Sampler)生成自定义顺序的索引/键,再通过随机访问的能力加载数据。
MindSpore
MindSpore提供了丰富的采样器API供用户开箱即用。
- mindspore.dataset.SequentialSampler:按顺序采样指定数量样本。
- mindspore.dataset.RandomSampler:按随机顺序采样指定数量样本。
- mindspore.dataset.DistributedSampler:将样本等分用于分布式训练。
- mindspore.dataset.PKSampler:在P个类别中各采样K个样本。
- mindspore.dataset.SubsetSampler:根据指定的索引列表进行采样。
- mindspore.dataset.SubsetRandomSampler:根据指定的索引列表进行随机采样。
- mindspore.dataset.WeightedRandomSampler:根据指定的各个类别的概率采样样本。
还是以处理MNIST手写字识别数据集为例,首先根据自己的策略,定义采样器,然后传给数据加载接口即可。
import mindspore.dataset as ds
mnist_dataset_dir = "/path/to/mnist_dataset_directory"
sampler = RandomSampler()
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=sampler)
为了简化编码流程,用户也可直接通过数据加载接口的num_samples、shuffle、num_shards和shard_id参数控制采样器的使用,具体如下:
sampler
num_shards/shard_id
shuffle
num_samples
使用的采样器
mindspore.dataset.Sampler 类型
None
None
None
sampler
numpy.ndarray,list,tuple,int 类型
/
/
num_samples
SubsetSampler(indices=sampler, num_samples=num_samples)
iterable 类型
/
/
num_samples
IterSampler(sampler=sampler, num_samples=num_samples)
None
num_shards / shard_id
None / True
num_samples
DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=True, num_samples=num_samples)
None
num_shards / shard_id
False
num_samples
DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=num_samples)
None
None
None / True
None
RandomSampler(num_samples=num_samples)
None
None
None / True
num_samples
RandomSampler(replacement=True, num_samples=num_samples)
None
None
False
num_samples
SequentialSampler(num_samples=num_samples)
例如下列代码指定了num_shards和shard_id参数,则等同于先构造了DistributedSampler,再传给数据加载接口:
import mindspore.dataset as ds
mnist_dataset_dir = "/path/to/mnist_dataset_directory"
# 直接通过数据加载API的参数创建采样器
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_shards=8, shard_id=0)
# 先定义采样器,再传给数据加载API
dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=DistributedSampler(num_shards=8, shard_id=0))
用户也可以根据自己的需要编写自定义采样逻辑。
与自定义数据加载一样,采样器也可实现为可随机访问的(Map Style)和可迭代的(Iterable Style)两种。
对于可随机访问的采样器,可编写自定义采样器类,提供 __getitem__ 和 __len__ 方法,例如:
class MySampler():
def __init__(self):
self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
def __getitem__(self, index):
return self.index_ids[index]
def __len__(self):
return len(self.index_ids)
对于可迭代的采样器,可编写自定义采样器类,提供 __iter__ 方法,例如:
class MySampler:
def __iter__(self):
for i in range(0, 100, 2):
yield i
PyTorch
PyTorch同样提供了丰富的采样器API供用户开箱即用。
- torch.utils.data.Sampler:所有采样器的基类。
- torch.utils.data.SequentialSampler:按顺序采样样本。
- torch.utils.data.RandomSampler:按随机顺序采样指定数量样本。
- torch.utils.data.SubsetRandomSampler:根据指定的索引列表进行随机采样。
- torch.utils.data.WeightedRandomSampler:根据指定的各个类别的概率采样样本。
- torch.utils.data.BatchSampler:每次返回一个batch的样本索引。
- torch.utils.data.distributed.DistributedSampler:分布式采样器。
PyTorch在创建采样器时,需要先传入数据集对象,最后同时将数据集和采样器对象传给DataLoader即可。
from torch.utils.data import Dataset, RandomSampler, DataLoader
class MapStyleDataset(Dataset):
def __init__(self, dataset_dir):
self.files = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir)]
def __getitem__(self, index):
return np.load(self.files[index])
def __len__(self):
return len(self.files)
dataset = MapStyleDataset("/path/to/dataset_directory")
sampler = RandomSampler(dataset)
loader = DataLoader(dataset=dataset, sampler=sampler)
同样,PyTorch的DataLoader接口也提供了shuffle参数,方便用户快速创建RandomSampler,但对于其他数据集,就得手动构造了。
TensorFlow
TensorFlow未提供采样器的功能。