[{"data":1,"prerenderedAt":612},["ShallowReactive",2],{"content-query-L0VW0ajvMQ":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":606,"_id":607,"_source":608,"_file":609,"_stem":610,"_extension":611},"/technology-blogs/zh/3547","zh",false,"","AI数据框架大横评（4）之采样器","开始本文的内容前，先简单回顾一下本系列的前几篇文章：","2024-12-17","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2025/01/08/0982be6f7e1c4c5fa82597922ec88012.png","technology-blogs","基础知识",{"type":15,"children":16,"toc":595},"root",[17,25,31,36,41,53,58,68,73,83,88,93,98,103,114,119,159,164,174,210,215,220,224,228,233,238,243,247,251,255,260,265,269,273,278,283,287,291,295,300,304,309,314,318,323,327,331,336,340,345,349,353,357,361,366,370,374,378,382,387,391,395,399,403,408,434,442,447,452,457,465,470,478,484,489,536,549,557,584,590],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"ai数据框架大横评4之采样器",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":29},"h2",{"id":28},"前言",[30],{"type":24,"value":28},{"type":18,"tag":32,"props":33,"children":34},"p",{},[35],{"type":24,"value":9},{"type":18,"tag":32,"props":37,"children":38},{},[39],{"type":24,"value":40},"第一篇，我们简要对比了当前主流AI数据框架的架构设计，从中可以略微看出各家框架的主要设计理念和应用场景。",{"type":18,"tag":32,"props":42,"children":43},{},[44],{"type":18,"tag":45,"props":46,"children":50},"a",{"href":47,"rel":48},"https://zhuanlan.zhihu.com/p/692425560",[49],"nofollow",[51],{"type":24,"value":52},"AI数据框架大横评之架构设计",{"type":18,"tag":32,"props":54,"children":55},{},[56],{"type":24,"value":57},"第二篇，我们简要对比了当前主流AI数据框架的数据加载方式。",{"type":18,"tag":32,"props":59,"children":60},{},[61],{"type":18,"tag":45,"props":62,"children":65},{"href":63,"rel":64},"https://zhuanlan.zhihu.com/p/707501176",[49],[66],{"type":24,"value":67},"AI数据框架大横评之数据加载",{"type":18,"tag":32,"props":69,"children":70},{},[71],{"type":24,"value":72},"第三篇，我们简要对比了当前主流AI数据框架的数据处理方式。",{"type":18,"tag":32
,"props":74,"children":75},{},[76],{"type":18,"tag":45,"props":77,"children":80},{"href":78,"rel":79},"https://zhuanlan.zhihu.com/p/3082846735",[49],[81],{"type":24,"value":82},"AI数据框架大横评之数据处理",{"type":18,"tag":32,"props":84,"children":85},{},[86],{"type":24,"value":87},"建议大家先阅读以上几篇文章，再开始下面的阅读。",{"type":18,"tag":26,"props":89,"children":91},{"id":90},"采样器",[92],{"type":24,"value":90},{"type":18,"tag":32,"props":94,"children":95},{},[96],{"type":24,"value":97},"对于可迭代的（Iterable Style）数据集，数据加载顺序完全由用户定义的迭代逻辑控制。",{"type":18,"tag":32,"props":99,"children":100},{},[101],{"type":24,"value":102},"而对于可随机访问的（Map Style）数据集，则可以通过采样器（Sampler）生成自定义顺序的索引/键，再通过随机访问的能力加载数据。",{"type":18,"tag":104,"props":105,"children":107},"h3",{"id":106},"mindspore",[108],{"type":18,"tag":109,"props":110,"children":111},"strong",{},[112],{"type":24,"value":113},"MindSpore",{"type":18,"tag":32,"props":115,"children":116},{},[117],{"type":24,"value":118},"MindSpore提供了丰富的采样器API供用户开箱即用。",{"type":18,"tag":120,"props":121,"children":122},"ul",{},[123,129,134,139,144,149,154],{"type":18,"tag":124,"props":125,"children":126},{},[127],{"type":24,"value":128},"mindspore.dataset.SequentialSampler：按顺序采样指定数量样本。",{"type":18,"tag":124,"props":130,"children":131},{},[132],{"type":24,"value":133},"mindspore.dataset.RandomSampler：按随机顺序采样指定数量样本。",{"type":18,"tag":124,"props":135,"children":136},{},[137],{"type":24,"value":138},"mindspore.dataset.DistributedSampler：将样本等分用于分布式训练。",{"type":18,"tag":124,"props":140,"children":141},{},[142],{"type":24,"value":143},"mindspore.dataset.PKSampler：在P个类别中各采样K个样本。",{"type":18,"tag":124,"props":145,"children":146},{},[147],{"type":24,"value":148},"mindspore.dataset.SubsetSampler：根据指定的索引列表进行采样。",{"type":18,"tag":124,"props":150,"children":151},{},[152],{"type":24,"value":153},"mindspore.dataset.SubsetRandomSampler：根据指定的索引列表进行随机采样。",{"type":18,"tag":124,"props":155,"children":156},{},[157],{"type":24,"value":158},"mindspore.dataset.WeightedRandomSampler：根据为每个样本指定的权重（概率）进行随机采样。",{"typ
e":18,"tag":32,"props":160,"children":161},{},[162],{"type":24,"value":163},"还是以处理MNIST手写数字识别数据集为例，首先根据自己的策略，定义采样器，然后传给数据加载接口即可。",{"type":18,"tag":165,"props":166,"children":168},"pre",{"code":167},"import mindspore.dataset as ds\n\nmnist_dataset_dir = \"/path/to/mnist_dataset_directory\"\nsampler = ds.RandomSampler()\ndataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=sampler)\n",[169],{"type":18,"tag":170,"props":171,"children":172},"code",{"__ignoreMap":7},[173],{"type":24,"value":167},{"type":18,"tag":32,"props":175,"children":176},{},[177,179,185,187,193,194,200,202,208],{"type":24,"value":178},"为了简化编码流程，用户也可直接通过数据加载接口的",{"type":18,"tag":170,"props":180,"children":182},{"className":181},[],[183],{"type":24,"value":184},"num_samples",{"type":24,"value":186},"、",{"type":18,"tag":170,"props":188,"children":190},{"className":189},[],[191],{"type":24,"value":192},"shuffle",{"type":24,"value":186},{"type":18,"tag":170,"props":195,"children":197},{"className":196},[],[198],{"type":24,"value":199},"num_shards",{"type":24,"value":201},"和",{"type":18,"tag":170,"props":203,"children":205},{"className":204},[],[206],{"type":24,"value":207},"shard_id",{"type":24,"value":209},"参数控制采样器的使用，具体如下：",{"type":18,"tag":32,"props":211,"children":212},{},[213],{"type":24,"value":214},"sampler",{"type":18,"tag":32,"props":216,"children":217},{},[218],{"type":24,"value":219},"num_shards/shard_id",{"type":18,"tag":32,"props":221,"children":222},{},[223],{"type":24,"value":192},{"type":18,"tag":32,"props":225,"children":226},{},[227],{"type":24,"value":184},{"type":18,"tag":32,"props":229,"children":230},{},[231],{"type":24,"value":232},"使用的采样器",{"type":18,"tag":32,"props":234,"children":235},{},[236],{"type":24,"value":237},"mindspore.dataset.Sampler
 类型",{"type":18,"tag":32,"props":239,"children":240},{},[241],{"type":24,"value":242},"None",{"type":18,"tag":32,"props":244,"children":245},{},[246],{"type":24,"value":242},{"type":18,"tag":32,"props":248,"children":249},{},[250],{"type":24,"value":242},{"type":18,"tag":32,"props":252,"children":253},{},[254],{"type":24,"value":214},{"type":18,"tag":32,"props":256,"children":257},{},[258],{"type":24,"value":259},"numpy.ndarray,list,tuple,int 类型",{"type":18,"tag":32,"props":261,"children":262},{},[263],{"type":24,"value":264},"/",{"type":18,"tag":32,"props":266,"children":267},{},[268],{"type":24,"value":264},{"type":18,"tag":32,"props":270,"children":271},{},[272],{"type":24,"value":184},{"type":18,"tag":32,"props":274,"children":275},{},[276],{"type":24,"value":277},"SubsetSampler(indices=sampler, num_samples=num_samples)",{"type":18,"tag":32,"props":279,"children":280},{},[281],{"type":24,"value":282},"iterable 类型",{"type":18,"tag":32,"props":284,"children":285},{},[286],{"type":24,"value":264},{"type":18,"tag":32,"props":288,"children":289},{},[290],{"type":24,"value":264},{"type":18,"tag":32,"props":292,"children":293},{},[294],{"type":24,"value":184},{"type":18,"tag":32,"props":296,"children":297},{},[298],{"type":24,"value":299},"IterSampler(sampler=sampler, num_samples=num_samples)",{"type":18,"tag":32,"props":301,"children":302},{},[303],{"type":24,"value":242},{"type":18,"tag":32,"props":305,"children":306},{},[307],{"type":24,"value":308},"num_shards / shard_id",{"type":18,"tag":32,"props":310,"children":311},{},[312],{"type":24,"value":313},"None / True",{"type":18,"tag":32,"props":315,"children":316},{},[317],{"type":24,"value":184},{"type":18,"tag":32,"props":319,"children":320},{},[321],{"type":24,"value":322},"DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=True,
 num_samples=num_samples)",{"type":18,"tag":32,"props":324,"children":325},{},[326],{"type":24,"value":242},{"type":18,"tag":32,"props":328,"children":329},{},[330],{"type":24,"value":308},{"type":18,"tag":32,"props":332,"children":333},{},[334],{"type":24,"value":335},"False",{"type":18,"tag":32,"props":337,"children":338},{},[339],{"type":24,"value":184},{"type":18,"tag":32,"props":341,"children":342},{},[343],{"type":24,"value":344},"DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=num_samples)",{"type":18,"tag":32,"props":346,"children":347},{},[348],{"type":24,"value":242},{"type":18,"tag":32,"props":350,"children":351},{},[352],{"type":24,"value":242},{"type":18,"tag":32,"props":354,"children":355},{},[356],{"type":24,"value":313},{"type":18,"tag":32,"props":358,"children":359},{},[360],{"type":24,"value":242},{"type":18,"tag":32,"props":362,"children":363},{},[364],{"type":24,"value":365},"RandomSampler(num_samples=num_samples)",{"type":18,"tag":32,"props":367,"children":368},{},[369],{"type":24,"value":242},{"type":18,"tag":32,"props":371,"children":372},{},[373],{"type":24,"value":242},{"type":18,"tag":32,"props":375,"children":376},{},[377],{"type":24,"value":313},{"type":18,"tag":32,"props":379,"children":380},{},[381],{"type":24,"value":184},{"type":18,"tag":32,"props":383,"children":384},{},[385],{"type":24,"value":386},"RandomSampler(replacement=True,
 num_samples=num_samples)",{"type":18,"tag":32,"props":388,"children":389},{},[390],{"type":24,"value":242},{"type":18,"tag":32,"props":392,"children":393},{},[394],{"type":24,"value":242},{"type":18,"tag":32,"props":396,"children":397},{},[398],{"type":24,"value":335},{"type":18,"tag":32,"props":400,"children":401},{},[402],{"type":24,"value":184},{"type":18,"tag":32,"props":404,"children":405},{},[406],{"type":24,"value":407},"SequentialSampler(num_samples=num_samples)",{"type":18,"tag":32,"props":409,"children":410},{},[411,413,418,419,424,426,432],{"type":24,"value":412},"例如下列代码指定了",{"type":18,"tag":170,"props":414,"children":416},{"className":415},[],[417],{"type":24,"value":199},{"type":24,"value":201},{"type":18,"tag":170,"props":420,"children":422},{"className":421},[],[423],{"type":24,"value":207},{"type":24,"value":425},"参数，则等同于先构造了",{"type":18,"tag":170,"props":427,"children":429},{"className":428},[],[430],{"type":24,"value":431},"DistributedSampler",{"type":24,"value":433},"，再传给数据加载接口：",{"type":18,"tag":165,"props":435,"children":437},{"code":436},"import mindspore.dataset as ds\n\nmnist_dataset_dir = \"/path/to/mnist_dataset_directory\"\n\n# 直接通过数据加载API的参数创建采样器\ndataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_shards=8, shard_id=0)\n# 先定义采样器，再传给数据加载API\ndataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, sampler=ds.DistributedSampler(num_shards=8, shard_id=0))\n",[438],{"type":18,"tag":170,"props":439,"children":440},{"__ignoreMap":7},[441],{"type":24,"value":436},{"type":18,"tag":32,"props":443,"children":444},{},[445],{"type":24,"value":446},"用户也可以根据自己的需要编写自定义采样逻辑。",{"type":18,"tag":32,"props":448,"children":449},{},[450],{"type":24,"value":451},"与自定义数据加载一样，采样器也可实现为可随机访问的（Map Style）和可迭代的（Iterable Style）两种。",{"type":18,"tag":32,"props":453,"children":454},{},[455],{"type":24,"value":456},"对于可随机访问的采样器，可编写自定义采样器类，提供 __getitem__ 和 __len__ 方法，例如：",{"type":18,"tag":165,"props":458,"children":460},{"code":459},"class MySampler():\n    def
 __init__(self):\n        self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]\n\n    def __getitem__(self, index):\n        return self.index_ids[index]\n\n    def __len__(self):\n        return len(self.index_ids)\n",[461],{"type":18,"tag":170,"props":462,"children":463},{"__ignoreMap":7},[464],{"type":24,"value":459},{"type":18,"tag":32,"props":466,"children":467},{},[468],{"type":24,"value":469},"对于可迭代的采样器，可编写自定义采样器类，提供 __iter__ 方法，例如：",{"type":18,"tag":165,"props":471,"children":473},{"code":472},"class MySampler:\n    def __iter__(self):\n        for i in range(0, 100, 2):\n            yield i\n",[474],{"type":18,"tag":170,"props":475,"children":476},{"__ignoreMap":7},[477],{"type":24,"value":472},{"type":18,"tag":104,"props":479,"children":481},{"id":480},"pytorch",[482],{"type":24,"value":483},"PyTorch",{"type":18,"tag":32,"props":485,"children":486},{},[487],{"type":24,"value":488},"PyTorch同样提供了丰富的采样器API供用户开箱即用。",{"type":18,"tag":120,"props":490,"children":491},{},[492,497,502,507,512,517,522],{"type":18,"tag":124,"props":493,"children":494},{},[495],{"type":24,"value":496},"torch.utils.data.Sampler：所有采样器的基类。",{"type":18,"tag":124,"props":498,"children":499},{},[500],{"type":24,"value":501},"torch.utils.data.SequentialSampler：按顺序采样样本。",{"type":18,"tag":124,"props":503,"children":504},{},[505],{"type":24,"value":506},"torch.utils.data.RandomSampler：按随机顺序采样指定数量样本。",{"type":18,"tag":124,"props":508,"children":509},{},[510],{"type":24,"value":511},"torch.utils.data.SubsetRandomSampler：根据指定的索引列表进行随机采样。",{"type":18,"tag":124,"props":513,"children":514},{},[515],{"type":24,"value":516},"torch.utils.data.WeightedRandomSampler：根据为每个样本指定的权重（概率）进行随机采样。",{"type":18,"tag":124,"props":518,"children":519},{},[520],{"type":24,"value":521},"torch.utils.data.BatchSampler：每次返回一个batch的样本索引。",{"type":18,"tag":124,"props":523,"children":524},{},[525,527,534],{"type":24,"value":526},"torch.utils.data.distributed.DistributedSampler：",{"type":18,"tag":45,"props":528,"chi
ldren":531},{"href":529,"rel":530},"https://zhida.zhihu.com/search?content_id=251078289&content_type=Article&match_order=1&q=%E5%88%86%E5%B8%83%E5%BC%8F%E9%87%87%E6%A0%B7%E5%99%A8&zhida_source=entity",[49],[532],{"type":24,"value":533},"分布式采样器",{"type":24,"value":535},"。",{"type":18,"tag":32,"props":537,"children":538},{},[539,541,547],{"type":24,"value":540},"PyTorch在创建采样器时，需要先传入数据集对象，最后同时将数据集和采样器对象传给",{"type":18,"tag":170,"props":542,"children":544},{"className":543},[],[545],{"type":24,"value":546},"DataLoader",{"type":24,"value":548},"即可。",{"type":18,"tag":165,"props":550,"children":552},{"code":551},"import os\n\nimport numpy as np\nfrom torch.utils.data import Dataset, RandomSampler, DataLoader\n\nclass MapStyleDataset(Dataset):\n    def __init__(self, dataset_dir):\n        self.files = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir)]\n\n    def __getitem__(self, index):\n        return np.load(self.files[index])\n\n    def __len__(self):\n        return len(self.files)\n\ndataset = MapStyleDataset(\"/path/to/dataset_directory\")\nsampler = RandomSampler(dataset)\nloader = DataLoader(dataset=dataset,
 sampler=sampler)\n",[553],{"type":18,"tag":170,"props":554,"children":555},{"__ignoreMap":7},[556],{"type":24,"value":551},{"type":18,"tag":32,"props":558,"children":559},{},[560,562,567,569,574,576,582],{"type":24,"value":561},"同样，PyTorch的",{"type":18,"tag":170,"props":563,"children":565},{"className":564},[],[566],{"type":24,"value":546},{"type":24,"value":568},"接口也提供了",{"type":18,"tag":170,"props":570,"children":572},{"className":571},[],[573],{"type":24,"value":192},{"type":24,"value":575},"参数，方便用户快速创建",{"type":18,"tag":170,"props":577,"children":579},{"className":578},[],[580],{"type":24,"value":581},"RandomSampler",{"type":24,"value":583},"，但其他采样器则需要手动构造。",{"type":18,"tag":104,"props":585,"children":587},{"id":586},"tensorflow",[588],{"type":24,"value":589},"TensorFlow",{"type":18,"tag":32,"props":591,"children":592},{},[593],{"type":24,"value":594},"TensorFlow（tf.data）未提供独立的采样器API，数据的打乱和分片需通过Dataset.shuffle、Dataset.shard等方法实现。",{"title":7,"searchDepth":596,"depth":596,"links":597},4,[598,600],{"id":28,"depth":599,"text":28},2,{"id":90,"depth":599,"text":90,"children":601},[602,604,605],{"id":106,"depth":603,"text":113},3,{"id":480,"depth":603,"text":483},{"id":586,"depth":603,"text":589},"markdown","content:technology-blogs:zh:3547.md","content","technology-blogs/zh/3547.md","technology-blogs/zh/3547","md",1776506130809]