[{"data":1,"prerenderedAt":890},["ShallowReactive",2],{"content-query-t5VQfTlEnv":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":884,"_id":885,"_source":886,"_file":887,"_stem":888,"_extension":889},"/technology-blogs/zh/3459","zh",false,"","AI数据框架大横评（3）","开始本文的内容前，先简单回顾一下本系列的前两篇文章：","2024-11-08","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2024/11/28/218db861562e41038ee9f037b538dc59.png","technology-blogs","基础知识",{"type":15,"children":16,"toc":867},"root",[17,25,41,46,51,65,76,81,95,100,111,116,121,126,142,147,152,157,162,167,172,177,182,187,192,197,202,207,212,217,222,227,232,237,242,247,252,257,262,267,272,277,282,287,292,307,312,327,332,337,342,347,352,357,362,366,370,375,380,385,390,395,400,405,410,415,420,425,430,435,439,443,448,453,458,463,468,483,488,493,498,503,508,513,518,523,528,532,537,542,547,552,556,561,566,571,576,581,586,591,595,599,604,609,614,619,623,628,633,638,643,648,653,658,663,678,683,698,703,708,713,718,723,728,733,738,743,757,762,767,771,776,781,786,791,795,800,805,810,815,820,825,829,834,839,844,847,852,857,862],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"ai数据框架大横评3",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":29},"h2",{"id":28},"前言",[30],{"type":18,"tag":31,"props":32,"children":33},"strong",{},[34],{"type":18,"tag":31,"props":35,"children":36},{},[37],{"type":18,"tag":31,"props":38,"children":39},{},[40],{"type":24,"value":28},{"type":18,"tag":42,"props":43,"children":44},"p",{},[45],{"type":24,"value":9},{"type":18,"tag":42,"props":47,"children":48},{},[49],{"type":24,"value":50},"首先，我们简要对比了当前主流AI数据框架的架构设计，从中可以略微看出各家框架的主要设计理念和应用场景。",{"type":18,"tag":42,"props":52,"children":53},{},[54,56],{"type":24,"value":55},"AI数据框架大横评：",{"type":18,"tag":57,"props":58,"children":62},"a",{"href":59,"rel":60},"https://www.hiascend.com/developer/blog/details/02221509495186150
69",[61],"nofollow",[63],{"type":24,"value":64},"https://www.hiascend.com/developer/blog/details/0222150949518615069",{"type":18,"tag":42,"props":66,"children":67},{},[68,70],{"type":24,"value":69},"然后，我们简要对比了当前主流AI数据框架的数据加载方式。AI数据框架大横评（2）：",{"type":18,"tag":57,"props":71,"children":74},{"href":72,"rel":73},"https://www.hiascend.com/developer/blog/details/0286156408632279318",[61],[75],{"type":24,"value":72},{"type":18,"tag":42,"props":77,"children":78},{},[79],{"type":24,"value":80},"建议大家先阅读以上两篇文章，再开始下面的阅读。",{"type":18,"tag":26,"props":82,"children":84},{"id":83},"数据处理",[85],{"type":18,"tag":31,"props":86,"children":87},{},[88],{"type":18,"tag":31,"props":89,"children":90},{},[91],{"type":18,"tag":31,"props":92,"children":93},{},[94],{"type":24,"value":83},{"type":18,"tag":42,"props":96,"children":97},{},[98],{"type":24,"value":99},"当原始数据从存储设备加载到内存后，往往还需要进行一系列的处理，才能传递到网络进行训练。",{"type":18,"tag":42,"props":101,"children":102},{},[103],{"type":18,"tag":31,"props":104,"children":105},{},[106],{"type":18,"tag":31,"props":107,"children":108},{},[109],{"type":24,"value":110},"数据处理是对已加载的数据进行加工整理，形成适合网络学习的样式，它是模型训练前必不可少的阶段。",{"type":18,"tag":42,"props":112,"children":113},{},[114],{"type":24,"value":115},"针对不同的目的，数据处理的手段各不一样。当原始数据杂乱无章时，需要对其进行去重、清洗；当原始数据量偏少，或是分布不均匀时，需要对其进行增广；当原始数据特征难以学习时，需要对其进行特征变换、归一化。",{"type":18,"tag":42,"props":117,"children":118},{},[119],{"type":24,"value":120},"除此之外，不同领域的任务，如CV、NLP、Audio等，也都具有各自不同的经典处理方法，如图像的放缩、裁切，文本的分词、向量化，音频的频谱变换、滤波等。",{"type":18,"tag":42,"props":122,"children":123},{},[124],{"type":24,"value":125},"作为AI数据框架，很难全部满足用户差异化的处理需求，所以往往只提供了部分基础的数据处理接口，并开放自定义的能力，方便用户根据自己的需要编写个性化的数据处理方法。",{"type":18,"tag":127,"props":128,"children":130},"h3",{"id":129},"mindspore",[131],{"type":18,"tag":31,"props":132,"childre
n":133},{},[134],{"type":18,"tag":31,"props":135,"children":136},{},[137],{"type":18,"tag":31,"props":138,"children":139},{},[140],{"type":24,"value":141},"MindSpore",{"type":18,"tag":42,"props":143,"children":144},{},[145],{"type":24,"value":146},"MindSpore提供了一系列通用的数据处理API，本文将主要介绍较为常用的Map和Batch操作，其余API可自行参考官方文档。",{"type":18,"tag":42,"props":148,"children":149},{},[150],{"type":24,"value":151},"1）Map",{"type":18,"tag":42,"props":153,"children":154},{},[155],{"type":24,"value":156},"Map操作用于将指定的一系列数据变换（Transforms）依次作用于每个样本。",{"type":18,"tag":42,"props":158,"children":159},{},[160],{"type":24,"value":161},"MindSpore提供了一些常用的数据变换API，方便用户一键式使用，避免书写冗长的处理逻辑。",{"type":18,"tag":42,"props":163,"children":164},{},[165],{"type":24,"value":166},"以处理MNIST手写字识别数据集为例，首先按照上一篇文章的方法加载数据集。",{"type":18,"tag":42,"props":168,"children":169},{},[170],{"type":24,"value":171},"import mindspore.dataset as ds",{"type":18,"tag":42,"props":173,"children":174},{},[175],{"type":24,"value":176},"mnist_dataset_dir = \"/path/to/mnist_dataset_directory\"",{"type":18,"tag":42,"props":178,"children":179},{},[180],{"type":24,"value":181},"dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)",{"type":18,"tag":42,"props":183,"children":184},{},[185],{"type":24,"value":186},"然后即可使用Map指定想要对数据集中图像执行的数据变换。",{"type":18,"tag":42,"props":188,"children":189},{},[190],{"type":24,"value":191},"下列代码首先定义了一个随机裁切变换，并指定输出的图像大小为（5，5）。然后使用Map操作将该变换作用于数据集中的“image”列。这样一来，MNIST数据集中的每张图像都将在随机位置被裁切出一个（5，5）的子图像，用于后续处理。",{"type":18,"tag":42,"props":193,"children":194},{},[195],{"type":24,"value":196},"import mindspore.dataset.vision as vision",{"type":18,"tag":42,"props":198,"children":199},{},[200],{"type":24,"value":201},"random_crop = vision.RandomCrop(5)",{"type":18,"tag":42,"props":203,"children":204},{},[205],{"type":24,"value":206},"dataset = dataset.map(random_crop, 
input_columns=[\"image\"])",{"type":18,"tag":42,"props":208,"children":209},{},[210],{"type":24,"value":211},"如果官方提供的API无法满足需要，用户也可编写自定义处理逻辑，并通过Map执行。",{"type":18,"tag":42,"props":213,"children":214},{},[215],{"type":24,"value":216},"下列代码自定义了一个归一化变换，用于将MNIST数据集图像的像素值从 [0, 255] 归一化到 [0, 1] ，并通过Map执行。",{"type":18,"tag":42,"props":218,"children":219},{},[220],{"type":24,"value":221},"def normalize(image):",{"type":18,"tag":42,"props":223,"children":224},{},[225],{"type":24,"value":226},"return image / 255",{"type":18,"tag":42,"props":228,"children":229},{},[230],{"type":24,"value":231},"dataset = dataset.map(normalize, input_columns=[\"image\"])",{"type":18,"tag":42,"props":233,"children":234},{},[235],{"type":24,"value":236},"2）Batch",{"type":18,"tag":42,"props":238,"children":239},{},[240],{"type":24,"value":241},"Batch操作用于将连续batch_size个样本组合为1个批数据，有利于更好地利用Device的并行运算能力。",{"type":18,"tag":42,"props":243,"children":244},{},[245],{"type":24,"value":246},"下列代码使用Batch操作将MNIST数据集中连续的32个样本组合为1个批数据。若原始输入数据的shape为（28，28，3），则Batch后的数据的shape将变为（32，28，28，3）。",{"type":18,"tag":42,"props":248,"children":249},{},[250],{"type":24,"value":251},"dataset = dataset.batch(32)",{"type":18,"tag":42,"props":253,"children":254},{},[255],{"type":24,"value":256},"需要注意的是，若输入数据的shape不固定，则无法直接拼接，此时需要用户自行定义逻辑进行填充或截断，使得同一批内的样本的shape保持一致。",{"type":18,"tag":42,"props":258,"children":259},{},[260],{"type":24,"value":261},"下列代码自定义了一个截断函数，遍历获取到的连续batch_size个样本组成的列表，将其中的每个图片截取为（5，5，3）大小，再将样本的列表组装成numpy类型并返回。",{"type":18,"tag":42,"props":263,"children":264},{},[265],{"type":24,"value":266},"def truncate(images, labels, batch_info):",{"type":18,"tag":42,"props":268,"children":269},{},[270],{"type":24,"value":271},"truncate_images = []",{"type":18,"tag":42,"props":273,"children":274},{},[275],{"type":24,"value":276},"for image in images:",{"type":18,"tag":42,"props":278,"children":279},{},[280],{"type":24,"value":281},"truncate_images.append(image[:5, :5, 
:])",{"type":18,"tag":42,"props":283,"children":284},{},[285],{"type":24,"value":286},"return np.array(truncate_images), np.array(labels)",{"type":18,"tag":42,"props":288,"children":289},{},[290],{"type":24,"value":291},"dataset = dataset.batch(32, per_batch_map=truncate)",{"type":18,"tag":26,"props":293,"children":295},{"id":294},"pytorch",[296],{"type":18,"tag":31,"props":297,"children":298},{},[299],{"type":18,"tag":31,"props":300,"children":301},{},[302],{"type":18,"tag":31,"props":303,"children":304},{},[305],{"type":24,"value":306},"PyTorch",{"type":18,"tag":42,"props":308,"children":309},{},[310],{"type":24,"value":311},"PyTorch将所有数据处理功能都集成到DataLoader接口之上，对于未提供的功能，则需用户编写自定义逻辑实现。但为了对比方便，下文依旧从功能的角度讲解PyTorch的用法。",{"type":18,"tag":127,"props":313,"children":315},{"id":314},"transforms",[316],{"type":18,"tag":31,"props":317,"children":318},{},[319],{"type":18,"tag":31,"props":320,"children":321},{},[322],{"type":18,"tag":31,"props":323,"children":324},{},[325],{"type":24,"value":326},"Transforms",{"type":18,"tag":42,"props":328,"children":329},{},[330],{"type":24,"value":331},"PyTorch并未将数据变换视为一个单独的步骤，而是将其融合到了数据加载中，即在用户自定义数据加载类时，应先完成对数据的变换，再通过 __getitem__ 或 __next__ 返回。",{"type":18,"tag":42,"props":333,"children":334},{},[335],{"type":24,"value":336},"下列代码在编写自定义数据加载类时，除了完成对图片文件的读取外，还进行了随机裁切和归一化变换。",{"type":18,"tag":42,"props":338,"children":339},{},[340],{"type":24,"value":341},"其中随机裁切使用了torchvision官方提供的接口，这些接口底层基于torch的Tensor运算实现，具有更好的性能，更为推荐使用。",{"type":18,"tag":42,"props":343,"children":344},{},[345],{"type":24,"value":346},"如果官方提供的API无法满足需要，用户也可参考其中归一化函数的写法，自行编写处理逻辑。",{"type":18,"tag":42,"props":348,"children":349},{},[350],{"type":24,"value":351},"import torch",{"type":18,"tag":42,"props":353,"children":354},{},[355],{"type":24,"value":356},"from torchvision.io import decode_image",{"type":18,"tag":42,"props":358,"children":359},{},[360],{"type":24,"value":361},"from torchvision.transforms import 
v2",{"type":18,"tag":42,"props":363,"children":364},{},[365],{"type":24,"value":221},{"type":18,"tag":42,"props":367,"children":368},{},[369],{"type":24,"value":226},{"type":18,"tag":42,"props":371,"children":372},{},[373],{"type":24,"value":374},"class ImageDataset(torch.utils.data.Dataset):",{"type":18,"tag":42,"props":376,"children":377},{},[378],{"type":24,"value":379},"def __init__(self, image_dir):",{"type":18,"tag":42,"props":381,"children":382},{},[383],{"type":24,"value":384},"self.files = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]",{"type":18,"tag":42,"props":386,"children":387},{},[388],{"type":24,"value":389},"def __getitem__(self, index):",{"type":18,"tag":42,"props":391,"children":392},{},[393],{"type":24,"value":394},"image = decode_image(self.files[index], mode=\"RGB\")",{"type":18,"tag":42,"props":396,"children":397},{},[398],{"type":24,"value":399},"croped_image = v2.RandomCrop(5)(image)",{"type":18,"tag":42,"props":401,"children":402},{},[403],{"type":24,"value":404},"normalized_image = normalize(croped_image)",{"type":18,"tag":42,"props":406,"children":407},{},[408],{"type":24,"value":409},"return normalized_image",{"type":18,"tag":42,"props":411,"children":412},{},[413],{"type":24,"value":414},"def __len__(self):",{"type":18,"tag":42,"props":416,"children":417},{},[418],{"type":24,"value":419},"return len(self.files)",{"type":18,"tag":42,"props":421,"children":422},{},[423],{"type":24,"value":424},"需要特殊说明的是，在使用torch官方预封装的数据加载API时，由于加载逻辑无法修改，需要通过 transform 和 target_transform 参数传递需要执行的数据变换。",{"type":18,"tag":42,"props":426,"children":427},{},[428],{"type":24,"value":429},"下列代码使用torchvision官方提供的API加载MNIST数据集，并通过 transform 参数指定对图像执行随机裁切变换，通过 target_transform 参数指定对标签执行类型转换变换。",{"type":18,"tag":42,"props":431,"children":432},{},[433],{"type":24,"value":434},"import 
torchvision",{"type":18,"tag":42,"props":436,"children":437},{},[438],{"type":24,"value":361},{"type":18,"tag":42,"props":440,"children":441},{},[442],{"type":24,"value":176},{"type":18,"tag":42,"props":444,"children":445},{},[446],{"type":24,"value":447},"image_transform = v2.RandomCrop(5)",{"type":18,"tag":42,"props":449,"children":450},{},[451],{"type":24,"value":452},"label_transform = v2.ToDtype(torch.uint8)",{"type":18,"tag":42,"props":454,"children":455},{},[456],{"type":24,"value":457},"dataset = torchvision.datasets.MNIST(root=mnist_dataset_dir,",{"type":18,"tag":42,"props":459,"children":460},{},[461],{"type":24,"value":462},"transform=image_transform,",{"type":18,"tag":42,"props":464,"children":465},{},[466],{"type":24,"value":467},"target_transform=label_transform)",{"type":18,"tag":127,"props":469,"children":471},{"id":470},"batch",[472],{"type":18,"tag":31,"props":473,"children":474},{},[475],{"type":18,"tag":31,"props":476,"children":477},{},[478],{"type":18,"tag":31,"props":479,"children":480},{},[481],{"type":24,"value":482},"Batch",{"type":18,"tag":42,"props":484,"children":485},{},[486],{"type":24,"value":487},"PyTorch的Batch操作通过 DataLoader 实现，当前有两种执行Batch的方式。",{"type":18,"tag":42,"props":489,"children":490},{},[491],{"type":24,"value":492},"下列代码通过 DataLoader 的 batch_size 参数指定了将连续的32个样本组合为1个批数据，这也是Batch操作最为常见的用法。",{"type":18,"tag":42,"props":494,"children":495},{},[496],{"type":24,"value":497},"from torch.utils.data import DataLoader",{"type":18,"tag":42,"props":499,"children":500},{},[501],{"type":24,"value":502},"from torchvision.datasets import MNIST",{"type":18,"tag":42,"props":504,"children":505},{},[506],{"type":24,"value":507},"dataset = MNIST(root=\"/path/to/mnist_dataset_directory\")",{"type":18,"tag":42,"props":509,"children":510},{},[511],{"type":24,"value":512},"dataloader = DataLoader(dataset, 
batch_size=32)",{"type":18,"tag":42,"props":514,"children":515},{},[516],{"type":24,"value":517},"除此之外，用户也可以通过自行编写采样器逻辑，来指定如何组合批数据。",{"type":18,"tag":42,"props":519,"children":520},{},[521],{"type":24,"value":522},"下列代码编写了一个Batch Sampler，通过 __iter__ 函数一次返回一批索引值，DataLoader 将依次读取对应样本，然后将其组合为1个批数据。",{"type":18,"tag":42,"props":524,"children":525},{},[526],{"type":24,"value":527},"from torch.utils.data import DataLoader, Sampler",{"type":18,"tag":42,"props":529,"children":530},{},[531],{"type":24,"value":502},{"type":18,"tag":42,"props":533,"children":534},{},[535],{"type":24,"value":536},"class BatchSampler(Sampler):",{"type":18,"tag":42,"props":538,"children":539},{},[540],{"type":24,"value":541},"def __init__(self, data_size, batch_size):",{"type":18,"tag":42,"props":543,"children":544},{},[545],{"type":24,"value":546},"self.data_size = data_size",{"type":18,"tag":42,"props":548,"children":549},{},[550],{"type":24,"value":551},"self.batch_size = batch_size",{"type":18,"tag":42,"props":553,"children":554},{},[555],{"type":24,"value":414},{"type":18,"tag":42,"props":557,"children":558},{},[559],{"type":24,"value":560},"return self.data_size // self.batch_size",{"type":18,"tag":42,"props":562,"children":563},{},[564],{"type":24,"value":565},"def __iter__(self):",{"type":18,"tag":42,"props":567,"children":568},{},[569],{"type":24,"value":570},"batch_indices = []",{"type":18,"tag":42,"props":572,"children":573},{},[574],{"type":24,"value":575},"for index in range(self.data_size):",{"type":18,"tag":42,"props":577,"children":578},{},[579],{"type":24,"value":580},"batch_indices.append(index)",{"type":18,"tag":42,"props":582,"children":583},{},[584],{"type":24,"value":585},"if len(batch_indices) == self.batch_size:",{"type":18,"tag":42,"props":587,"children":588},{},[589],{"type":24,"value":590},"yield 
batch_indices",{"type":18,"tag":42,"props":592,"children":593},{},[594],{"type":24,"value":570},{"type":18,"tag":42,"props":596,"children":597},{},[598],{"type":24,"value":507},{"type":18,"tag":42,"props":600,"children":601},{},[602],{"type":24,"value":603},"batch_sampler = BatchSampler(data_size=60000, batch_size=32)",{"type":18,"tag":42,"props":605,"children":606},{},[607],{"type":24,"value":608},"dataloader = DataLoader(dataset, batch_sampler=batch_sampler)",{"type":18,"tag":42,"props":610,"children":611},{},[612],{"type":24,"value":613},"同样，若输入数据间的shape不统一，是无法直接组合的，此时需要用户自行编写组合批数据的逻辑。",{"type":18,"tag":42,"props":615,"children":616},{},[617],{"type":24,"value":618},"下列代码加载了一批具有动态shape的数据，然后通过自定义 collate_fn ，将特征列沿着最长维进行填充，并组合为1个批数据返回。",{"type":18,"tag":42,"props":620,"children":621},{},[622],{"type":24,"value":497},{"type":18,"tag":42,"props":624,"children":625},{},[626],{"type":24,"value":627},"from torch.nn.utils.rnn import pad_sequence",{"type":18,"tag":42,"props":629,"children":630},{},[631],{"type":24,"value":632},"def collate_fn(data):",{"type":18,"tag":42,"props":634,"children":635},{},[636],{"type":24,"value":637},"features, targets = zip(*data)",{"type":18,"tag":42,"props":639,"children":640},{},[641],{"type":24,"value":642},"features = pad_sequence(features, batch_first=True)",{"type":18,"tag":42,"props":644,"children":645},{},[646],{"type":24,"value":647},"targets = torch.stack(targets)",{"type":18,"tag":42,"props":649,"children":650},{},[651],{"type":24,"value":652},"return features, targets",{"type":18,"tag":42,"props":654,"children":655},{},[656],{"type":24,"value":657},"dataset = DynamicShapeDataset()",{"type":18,"tag":42,"props":659,"children":660},{},[661],{"type":24,"value":662},"dataloader = DataLoader(dataset, batch_size=32, 
collate_fn=collate_fn)",{"type":18,"tag":26,"props":664,"children":666},{"id":665},"tensorflow",[667],{"type":18,"tag":31,"props":668,"children":669},{},[670],{"type":18,"tag":31,"props":671,"children":672},{},[673],{"type":18,"tag":31,"props":674,"children":675},{},[676],{"type":24,"value":677},"TensorFlow",{"type":18,"tag":42,"props":679,"children":680},{},[681],{"type":24,"value":682},"TensorFlow的设计与MindSpore较为相似，当用户定义好数据集对象后，即可通过对象的方法执行各种数据处理操作。",{"type":18,"tag":127,"props":684,"children":686},{"id":685},"map",[687],{"type":18,"tag":31,"props":688,"children":689},{},[690],{"type":18,"tag":31,"props":691,"children":692},{},[693],{"type":18,"tag":31,"props":694,"children":695},{},[696],{"type":24,"value":697},"Map",{"type":18,"tag":42,"props":699,"children":700},{},[701],{"type":24,"value":702},"TensorFlow通过Map将指定的数据变换作用于每个样本。但其并未提供预置的数据变换，需要用户自行编写自定义逻辑。",{"type":18,"tag":42,"props":704,"children":705},{},[706],{"type":24,"value":707},"下列代码定义了一个包含两个字符串的数据集，然后通过自定义函数将字符串中的字符转为大写。",{"type":18,"tag":42,"props":709,"children":710},{},[711],{"type":24,"value":712},"import tensorflow as tf",{"type":18,"tag":42,"props":714,"children":715},{},[716],{"type":24,"value":717},"def upper_case_fn(data):",{"type":18,"tag":42,"props":719,"children":720},{},[721],{"type":24,"value":722},"return data.numpy().decode('utf-8').upper()",{"type":18,"tag":42,"props":724,"children":725},{},[726],{"type":24,"value":727},"dataset = tf.data.Dataset.from_tensor_slices(['hello', 'world'])",{"type":18,"tag":42,"props":729,"children":730},{},[731],{"type":24,"value":732},"dataset = dataset.map(lambda data: tf.py_function(func=upper_case_fn,",{"type":18,"tag":42,"props":734,"children":735},{},[736],{"type":24,"value":737},"inp=[data], Tout=tf.string))",{"type":18,"tag":42,"props":739,"children":740},{},[741],{"type":24,"value":742},"需要注意的是，由于TensorFlow会将自定义函数作为图来执行，当前并非所有的Python语法都支持自动转换为图，所以往往需要通过 tf.numpy_function 和 tf.py_function 
来包装Python代码。",{"type":18,"tag":127,"props":744,"children":746},{"id":745},"batch-1",[747],{"type":18,"tag":31,"props":748,"children":749},{},[750],{"type":18,"tag":31,"props":751,"children":752},{},[753],{"type":18,"tag":31,"props":754,"children":755},{},[756],{"type":24,"value":482},{"type":18,"tag":42,"props":758,"children":759},{},[760],{"type":24,"value":761},"TensorFlow提供了Batch操作来执行基础的批处理。",{"type":18,"tag":42,"props":763,"children":764},{},[765],{"type":24,"value":766},"下列代码实现了一个简单的数据集，然后通过Batch操作将连续的3个样本组合为1个批数据。",{"type":18,"tag":42,"props":768,"children":769},{},[770],{"type":24,"value":712},{"type":18,"tag":42,"props":772,"children":773},{},[774],{"type":24,"value":775},"dataset = tf.data.Dataset.range(8)",{"type":18,"tag":42,"props":777,"children":778},{},[779],{"type":24,"value":780},"dataset = dataset.batch(3)",{"type":18,"tag":42,"props":782,"children":783},{},[784],{"type":24,"value":785},"但当用户输入数据的shape不固定时，则需要改用 padded_batch 或 ragged_batch 操作。",{"type":18,"tag":42,"props":787,"children":788},{},[789],{"type":24,"value":790},"下列代码定义了一个变长数据集，然后通过 padded_batch 操作，首先使用指定的 padding_values 将每条数据的长度填充至 padded_shapes，然后再组合为批数据。",{"type":18,"tag":42,"props":792,"children":793},{},[794],{"type":24,"value":712},{"type":18,"tag":42,"props":796,"children":797},{},[798],{"type":24,"value":799},"dataset = tf.data.Dataset.range(1, 5, output_type=tf.int32) # [[1], [2], [3], [4]]",{"type":18,"tag":42,"props":801,"children":802},{},[803],{"type":24,"value":804},"dataset = dataset.map(lambda x: tf.fill([x], x)) # [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]",{"type":18,"tag":42,"props":806,"children":807},{},[808],{"type":24,"value":809},"dataset = dataset.padded_batch(2, padded_shapes=5, # [[1, 0, 0, 0, 0], [2, 2, 0, 0, 0]]",{"type":18,"tag":42,"props":811,"children":812},{},[813],{"type":24,"value":814},"padding_values=0) # [[3, 3, 3, 0, 0], [4, 4, 4, 4, 
0]]",{"type":18,"tag":42,"props":816,"children":817},{},[818],{"type":24,"value":819},"此外，TensorFlow还提供了一种特殊的Tensor结构，来表示变长数据，即 RaggedTensor。",{"type":18,"tag":42,"props":821,"children":822},{},[823],{"type":24,"value":824},"下列代码定义了一个变长数据集，然后通过 ragged_batch 操作将连续的 batch_size 条样本构造为1个RaggedTensor。",{"type":18,"tag":42,"props":826,"children":827},{},[828],{"type":24,"value":712},{"type":18,"tag":42,"props":830,"children":831},{},[832],{"type":24,"value":833},"dataset = tf.data.Dataset.range(1, 5) # [[1], [2], [3], [4]]",{"type":18,"tag":42,"props":835,"children":836},{},[837],{"type":24,"value":838},"dataset = dataset.map(lambda x: tf.range(x)) # [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]",{"type":18,"tag":42,"props":840,"children":841},{},[842],{"type":24,"value":843},"dataset = dataset.ragged_batch(2) #",{"type":18,"tag":19,"props":845,"children":846},{"id":7},[],{"type":18,"tag":42,"props":848,"children":849},{},[850],{"type":24,"value":851},"除了上述列举的操作外，TensorFlow还提供了以下几种Batch，本文将不再详细说明。",{"type":18,"tag":42,"props":853,"children":854},{},[855],{"type":24,"value":856},"sparse_batch：将变长数据组合为另一种稀疏的Tensor结构，即 SparseTensor。",{"type":18,"tag":42,"props":858,"children":859},{},[860],{"type":24,"value":861},"unbatch：将已经组装的批数据重新拆分为单个样本。",{"type":18,"tag":42,"props":863,"children":864},{},[865],{"type":24,"value":866},"rebatch： 将已经组装的批数据先拆分为单个样本，再根据新指定的参数重新组合。",{"title":7,"searchDepth":868,"depth":868,"links":869},4,[870,872,876,880],{"id":28,"depth":871,"text":28},2,{"id":83,"depth":871,"text":83,"children":873},[874],{"id":129,"depth":875,"text":141},3,{"id":294,"depth":871,"text":306,"children":877},[878,879],{"id":314,"depth":875,"text":326},{"id":470,"depth":875,"text":482},{"id":665,"depth":871,"text":677,"children":881},[882,883],{"id":685,"depth":875,"text":697},{"id":745,"depth":875,"text":482},"markdown","content:technology-blogs:zh:3459.md","content","technology-blogs/zh/3459.md","technology-blogs/zh/3459","md",1776506129968]