数据集加载总览

Ascend GPU CPU 数据准备

在线运行下载Notebook下载样例代码查看源文件

概述

MindSpore支持加载图像领域常用的数据集,用户可以直接使用mindspore.dataset中对应的类实现数据集的加载。支持的常用数据集及对应的数据集类举例如下,更多数据集支持情况请参考API文档。

图像数据集

数据集类

数据集简介

MNIST

MnistDataset

MNIST是一个大型手写数字图像数据集,拥有60,000张训练图像和10,000张测试图像,常用于训练各种图像处理系统。

CIFAR-10

Cifar10Dataset

CIFAR-10是一个微小图像数据集,包含10种类别下的60,000张32x32大小彩色图像,平均每种类别6,000张,其中5,000张为训练集,1,000张为测试集。

CIFAR-100

Cifar100Dataset

CIFAR-100与CIFAR-10类似,但拥有100种类别,平均每种类别600张,其中500张为训练集,100张为测试集。

CelebA

CelebADataset

CelebA是一个大型人脸图像数据集,包含超过200,000张名人人脸图像,每张图像拥有40个特征标记。

PASCAL-VOC

VOCDataset

PASCAL-VOC是一个常用图像数据集,被广泛用于目标检测、图像分割等计算机视觉领域。

COCO

CocoDataset

COCO是一个大型目标检测、图像分割、姿态估计数据集。

CLUE

CLUEDataset

CLUE是一个大型中文语义理解数据集。

Manifest

ManifestDataset

Manifest是华为ModelArts支持的一种数据格式,描述了原始文件和标注信息,可用于标注、训练、推理场景。

MindSpore支持加载文本领域常用的数据集,用户可以直接使用mindspore.dataset中对应的类实现数据集的加载。支持的常用数据集及对应的数据集类举例如下,更多数据集支持情况请参考API文档。

文本数据集

数据集类

数据集简介

IMDB

IMDBDataset

IMDB 数据集包含来自互联网电影数据库(IMDB)的 50,000 条严重两极分化的评论,25,000 条用于训练,25000条用于测试。

Wiki Text

WikiTextDataset

WikiText 英语词库数据是一个包含1亿个词汇的英文词库数据,这些词汇是从Wikipedia的优质文章和标杆文章中提取得到。

Yahoo Answers

YahooAnswersDataset

数据集的 10 个主要分类数据。每个类 别分别包含 140,000 个训练样本和 5,000 个测试样本。

Text File

TextFileDataset

文本文件数据集,其中每行文本是一个样本。

MindSpore支持加载音频领域常用的数据集,用户可以直接使用mindspore.dataset中对应的类实现数据集的加载。支持的常用数据集及对应的数据集类举例如下,更多数据集支持情况请参考API文档。

音频数据集

数据集类

数据集简介

LJSpeech

LJSpeechDataset

这是一个公共领域数据集语音数据集,由 13,100 个短音频剪辑组成,单个发言者阅读 7 本非小说类书籍段落。

Speech Commands

SpeechCommandsDataset

是一个有声单词的音频数据集,旨在帮助训练和评估关键字识别系统。

Ted-Lium

TedliumDataset

TED-LIUM语料库是英语TED演讲,带有转录,采样频率为 16kHZ,它包含大约 118 个小时的演讲时间。

MindSpore还支持加载多种数据存储格式下的数据集,用户可以直接使用mindspore.dataset中对应的类加载磁盘中的数据文件。目前支持的数据格式及对应加载方式如下表所示。

数据格式

数据集类

数据格式简介

MindRecord

MindDataset

MindRecord是MindSpore的自研数据格式,具有读写高效、易于分布式处理等优势。

TFRecord

TFRecordDataset

TFRecord是TensorFlow定义的一种二进制数据文件格式。

CSV File

CSVDataset

CSV指逗号分隔值,其文件以纯文本形式存储表格数据。

MindSpore也同样支持使用GeneratorDataset自定义数据集的加载方式,用户可以根据需要实现自己的数据集类。

自定义数据集类

数据格式简介

GeneratorDataset

用户自定义的数据集读取、处理的方式。

NumpySlicesDataset

用户自定义的由NumPy构建数据集的方式。

更多详细的数据集加载接口说明,参见API文档

常用数据集加载

下面将介绍几种常用数据集的加载方式。

CIFAR-10/100数据集

下载CIFAR-10数据集并解压到指定位置,以下示例代码将数据集下载并解压到指定位置。

[ ]:
import os
import requests
import tarfile
import zipfile
import shutil

requests.packages.urllib3.disable_warnings()

def download_dataset(url, target_path):
    """download and decompress dataset"""
    if not os.path.exists(target_path):
        os.makedirs(target_path)
    download_file = url.split("/")[-1]
    if not os.path.exists(download_file):
        res = requests.get(url, stream=True, verify=False)
        if download_file.split(".")[-1] not in ["tgz", "zip", "tar", "gz"]:
            download_file = os.path.join(target_path, download_file)
        with open(download_file, "wb") as f:
            for chunk in res.iter_content(chunk_size=512):
                if chunk:
                    f.write(chunk)
    if download_file.endswith("zip"):
        z = zipfile.ZipFile(download_file, "r")
        z.extractall(path=target_path)
        z.close()
    if download_file.endswith(".tar.gz") or download_file.endswith(".tar") or download_file.endswith(".tgz"):
        t = tarfile.open(download_file)
        names = t.getnames()
        for name in names:
            t.extract(name, target_path)
        t.close()
    print("The {} file is downloaded and saved in the path {} after processing".format(os.path.basename(url), target_path))

download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz", "./datasets")
test_path = "./datasets/cifar-10-batches-bin/test"
train_path = "./datasets/cifar-10-batches-bin/train"
os.makedirs(test_path, exist_ok=True)
os.makedirs(train_path, exist_ok=True)
if not os.path.exists(os.path.join(test_path, "test_batch.bin")):
    shutil.move("./datasets/cifar-10-batches-bin/test_batch.bin", test_path)
[shutil.move("./datasets/cifar-10-batches-bin/"+i, train_path) for i in os.listdir("./datasets/cifar-10-batches-bin/") if os.path.isfile("./datasets/cifar-10-batches-bin/"+i) and not i.endswith(".html") and not os.path.exists(os.path.join(train_path, i))]

解压后数据集文件的目录结构如下:

./datasets/cifar-10-batches-bin
├── readme.html
├── test
│   └── test_batch.bin
└── train
    ├── batches.meta.txt
    ├── data_batch_1.bin
    ├── data_batch_2.bin
    ├── data_batch_3.bin
    ├── data_batch_4.bin
    └── data_batch_5.bin

下面的样例通过Cifar10Dataset接口加载CIFAR-10数据集,使用顺序采样器获取其中5个样本,然后展示了对应图片的形状和标签。

CIFAR-100数据集和MNIST数据集的加载方式也与之类似。

[2]:
import mindspore.dataset as ds

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 1

VOC数据集

VOC数据集有多个版本,此处以VOC2012为例。下载VOC2012数据集并解压,如果点击下载不成功,请尝试复制链接地址后下载。目录结构如下。

└─ VOCtrainval_11-May-2012
    └── VOCdevkit
        └── VOC2012
            ├── Annotations
            ├── ImageSets
            ├── JPEGImages
            ├── SegmentationClass
            └── SegmentationObject

下面的样例通过VOCDataset接口加载VOC2012数据集,分别演示了将任务指定为分割(Segmentation)和检测(Detection)时的原始图像形状和目标形状。

import mindspore.dataset as ds

DATA_DIR = "VOCtrainval_11-May-2012/VOCdevkit/VOC2012/"

dataset = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", num_samples=2, decode=True, shuffle=False)

print("[Segmentation]:")
for data in dataset.create_dict_iterator():
    print("image shape:", data["image"].shape)
    print("target shape:", data["target"].shape)

dataset = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", num_samples=1, decode=True, shuffle=False)

print("[Detection]:")
for data in dataset.create_dict_iterator():
    print("image shape:", data["image"].shape)
    print("bbox shape:", data["bbox"].shape)

输出结果:

[Segmentation]:
image shape: (281, 500, 3)
target shape: (281, 500, 3)
image shape: (375, 500, 3)
target shape: (375, 500, 3)
[Detection]:
image shape: (442, 500, 3)
bbox shape: (2, 4)

COCO数据集

COCO数据集有多个版本,此处以COCO2017的验证数据集为例。下载COCO2017的验证集检测任务标注全景分割任务标注并解压,如果点击下载不成功,请尝试复制链接地址后下载。只取其中的验证集部分,按以下目录结构存放。

└─ COCO
    ├── val2017
    └── annotations
        ├── instances_val2017.json
        ├── panoptic_val2017.json
        └── person_keypoints_val2017.json

下面的样例通过CocoDataset接口加载COCO2017数据集,分别演示了将任务指定为目标检测(Detection)、背景分割(Stuff)、关键点检测(Keypoint)和全景分割(Panoptic)时获取到的不同数据。

import mindspore.dataset as ds

DATA_DIR = "COCO/val2017/"
ANNOTATION_FILE = "COCO/annotations/instances_val2017.json"
KEYPOINT_FILE = "COCO/annotations/person_keypoints_val2017.json"
PANOPTIC_FILE = "COCO/annotations/panoptic_val2017.json"

dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", num_samples=1)
for data in dataset.create_dict_iterator():
    print("Detection:", data.keys())

dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff", num_samples=1)
for data in dataset.create_dict_iterator():
    print("Stuff:", data.keys())

dataset = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint", num_samples=1)
for data in dataset.create_dict_iterator():
    print("Keypoint:", data.keys())

dataset = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", num_samples=1)
for data in dataset.create_dict_iterator():
    print("Panoptic:", data.keys())

输出结果:

Detection: dict_keys(['image', 'bbox', 'category_id', 'iscrowd'])
Stuff: dict_keys(['image', 'segmentation', 'iscrowd'])
Keypoint: dict_keys(['image', 'keypoints', 'num_keypoints'])
Panoptic: dict_keys(['image', 'bbox', 'category_id', 'iscrowd', 'area'])

Manifest数据格式

Manifest是华为ModelArts支持的数据格式文件,详细说明请参见Manifest文档

本次示例需下载测试数据test_manifest.zip并将其解压到指定位置,执行如下命令:

[ ]:
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_manifest.zip", "./datasets/mindspore_dataset_loading/test_manifest/")

解压后数据集文件的目录结构如下:

./datasets/mindspore_dataset_loading/test_manifest/
├── eval
│   ├── 1.JPEG
│   └── 2.JPEG
├── test_manifest.json
└── train
    ├── 1.JPEG
    └── 2.JPEG

下面的样例通过ManifestDataset接口加载Manifest文件test_manifest.json,并展示已加载数据的标签。

[6]:
import mindspore.dataset as ds

DATA_FILE = "./datasets/mindspore_dataset_loading/test_manifest/test_manifest.json"
manifest_dataset = ds.ManifestDataset(DATA_FILE)

for data in manifest_dataset.create_dict_iterator():
    print(data["label"])
0
1

特定格式数据集加载

下面将介绍几种特定格式数据集文件的加载方式。

MindRecord数据格式

MindRecord是MindSpore定义的一种数据格式,使用MindRecord能够获得更好的性能提升。

阅读数据格式转换章节,了解如何将数据集转化为MindSpore数据格式。

执行本例之前需下载对应的测试数据test_mindrecord.zip并解压到指定位置,执行如下命令:

[ ]:
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_mindrecord.zip", "./datasets/mindspore_dataset_loading/")

下载的数据集文件的目录结构如下:

./datasets/mindspore_dataset_loading/
├── test.mindrecord
└── test.mindrecord.db

下面的样例通过MindDataset接口加载MindRecord文件,并展示已加载数据的标签。

[4]:
import mindspore.dataset as ds

DATA_FILE = ["./datasets/mindspore_dataset_loading/test.mindrecord"]
mindrecord_dataset = ds.MindDataset(DATA_FILE)

for data in mindrecord_dataset.create_dict_iterator(output_numpy=True):
    print(data.keys())
dict_keys(['chinese', 'english'])
dict_keys(['chinese', 'english'])
dict_keys(['chinese', 'english'])

TFRecord数据格式

TFRecord是TensorFlow定义的一种二进制数据文件格式。

下面的样例通过TFRecordDataset接口加载TFRecord文件,并介绍了两种不同的数据集格式设定方案。

下载tfrecord测试数据test_tftext.zip并解压到指定位置,执行如下命令:

[ ]:
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_tftext.zip", "./datasets/mindspore_dataset_loading/test_tfrecord/")

解压后数据集文件的目录结构如下:

./datasets/mindspore_dataset_loading/test_tfrecord/
└── test_tftext.tfrecord
  1. 传入数据集路径或TFRecord文件列表,本例使用test_tftext.tfrecord,创建TFRecordDataset对象。

[8]:
import mindspore.dataset as ds

DATA_FILE = "./datasets/mindspore_dataset_loading/test_tfrecord/test_tftext.tfrecord"
tfrecord_dataset = ds.TFRecordDataset(DATA_FILE)

for tf_data in tfrecord_dataset.create_dict_iterator():
    print(tf_data.keys())
dict_keys(['chinese', 'line', 'words'])
dict_keys(['chinese', 'line', 'words'])
dict_keys(['chinese', 'line', 'words'])
  1. 用户可以通过编写Schema文件或创建Schema对象,设定数据集格式及特征。

    • 编写Schema文件

      将数据集格式和特征按JSON格式写入Schema文件。

      • columns:列信息字段,需要根据数据集的实际列名定义。上面的示例中,数据集有三组数据,其列均为chineselinewords

        然后在创建TFRecordDataset时将Schema文件路径传入。

[9]:
import os
import json

data_json = {
    "columns": {
        "chinese": {
            "type": "uint8",
            "rank": 1
            },
        "line": {
            "type": "int8",
            "rank": 1
            },
        "words": {
            "type": "uint8",
            "rank": 0
            }
        }
    }

if not os.path.exists("dataset_schema_path"):
    os.mkdir("dataset_schema_path")
SCHEMA_DIR = "dataset_schema_path/schema.json"
with open(SCHEMA_DIR, "w") as f:
    json.dump(data_json, f, indent=4)

tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=SCHEMA_DIR)

for tf_data in tfrecord_dataset.create_dict_iterator():
    print(tf_data.values())
dict_values([Tensor(shape=[57], dtype=UInt8, value= [230, 177, 159, 229, 183, 158, 229, 184, 130, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 229, 143, 130,
 229, 138, 160, 228, 186, 134, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 231, 154, 132, 233, 128, 154,
 232, 189, 166, 228, 187, 170, 229, 188, 143]), Tensor(shape=[22], dtype=Int8, value= [ 71, 111, 111, 100,  32, 108, 117,  99, 107,  32, 116, 111,  32, 101, 118, 101, 114, 121, 111, 110, 101,  46]), Tensor(shape=[32], dtype=UInt8, value= [229, 165, 179,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32, 101, 118, 101, 114, 121, 111, 110, 101,
  99,  32,  32,  32,  32,  32,  32,  32])])
dict_values([Tensor(shape=[12], dtype=UInt8, value= [231, 148, 183, 233, 187, 152, 229, 165, 179, 230, 179, 170]), Tensor(shape=[19], dtype=Int8, value= [ 66, 101,  32, 104,  97, 112, 112, 121,  32, 101, 118, 101, 114, 121,  32, 100,  97, 121,  46]), Tensor(shape=[20], dtype=UInt8, value= [ 66, 101,  32,  32,  32, 104,  97, 112, 112, 121, 100,  97, 121,  32,  32,  98,  32,  32,  32,  32])])
dict_values([Tensor(shape=[48], dtype=UInt8, value= [228, 187, 138, 229, 164, 169, 229, 164, 169, 230, 176, 148, 229, 164, 170, 229, 165, 189, 228, 186, 134, 230, 136, 145,
 228, 187, 172, 228, 184, 128, 232, 181, 183, 229, 142, 187, 229, 164, 150, 233, 157, 162, 231, 142, 169, 229, 144, 167
 ]), Tensor(shape=[20], dtype=Int8, value= [ 84, 104, 105, 115,  32, 105, 115,  32,  97,  32, 116, 101, 120, 116,  32, 102, 105, 108, 101,  46]), Tensor(shape=[16], dtype=UInt8, value= [ 84, 104, 105, 115, 116, 101, 120, 116, 102, 105, 108, 101,  97,  32,  32,  32])])
  • 创建Schema对象

    创建Schema对象,为其添加自定义字段,然后在创建数据集对象时传入。

[10]:
from mindspore import dtype as mstype
schema = ds.Schema()
schema.add_column('chinese', de_type=mstype.uint8)
schema.add_column('line', de_type=mstype.uint8)
tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=schema)

for tf_data in tfrecord_dataset.create_dict_iterator():
    print(tf_data)
{'chinese': Tensor(shape=[12], dtype=UInt8, value= [231, 148, 183, 233, 187, 152, 229, 165, 179, 230, 179, 170]), 'line': Tensor(shape=[19], dtype=UInt8, value= [ 66, 101,  32, 104,  97, 112, 112, 121,  32, 101, 118, 101, 114, 121,  32, 100,  97, 121,  46])}
{'chinese': Tensor(shape=[48], dtype=UInt8, value= [228, 187, 138, 229, 164, 169, 229, 164, 169, 230, 176, 148, 229, 164, 170, 229, 165, 189, 228, 186, 134, 230, 136, 145,
 228, 187, 172, 228, 184, 128, 232, 181, 183, 229, 142, 187, 229, 164, 150, 233, 157, 162, 231, 142, 169, 229, 144, 167
 ]), 'line': Tensor(shape=[20], dtype=UInt8, value= [ 84, 104, 105, 115,  32, 105, 115,  32,  97,  32, 116, 101, 120, 116,  32, 102, 105, 108, 101,  46])}
{'chinese': Tensor(shape=[57], dtype=UInt8, value= [230, 177, 159, 229, 183, 158, 229, 184, 130, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 229, 143, 130,
 229, 138, 160, 228, 186, 134, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 231, 154, 132, 233, 128, 154,
 232, 189, 166, 228, 187, 170, 229, 188, 143]), 'line': Tensor(shape=[22], dtype=UInt8, value= [ 71, 111, 111, 100,  32, 108, 117,  99, 107,  32, 116, 111,  32, 101, 118, 101, 114, 121, 111, 110, 101,  46])}

对比上述中的编写和创建步骤,可以看出:

步骤

chinese

line

words

编写

UInt8

Int8

UInt8

创建

UInt8

UInt8

示例编写步骤中的columns中数据由chinese(UInt8)、line(Int8)和words(UInt8)变为了示例创建步骤中的chinese(UInt8)、line(UInt8),通过Schema对象,设定数据集的数据类型和特征,使得列中的数据类型和特征相应改变了。

CSV数据格式

下面的样例通过CSVDataset加载CSV格式数据集文件,并展示了已加载数据的keys

下载测试数据test_csv.zip并解压到指定位置,执行如下命令:

[ ]:
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_csv.zip", "./datasets/mindspore_dataset_loading/test_csv/")

解压后数据集文件的目录结构如下:

./datasets/mindspore_dataset_loading/test_csv/
├── test1.csv
└── test2.csv

传入数据集路径或CSV文件列表,Text格式数据集文件的加载方式与CSV文件类似。

[15]:
import mindspore.dataset as ds

DATA_FILE = ["./datasets/mindspore_dataset_loading/test_csv/test1.csv", "./datasets/mindspore_dataset_loading/test_csv/test2.csv"]
csv_dataset = ds.CSVDataset(DATA_FILE)

for csv_data in csv_dataset.create_dict_iterator(output_numpy=True):
    print(csv_data.keys())
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])

自定义数据集加载

对于目前MindSpore不支持直接加载的数据集,可以通过GeneratorDataset接口实现自定义方式的加载,或者将其转换成MindRecord数据格式。GeneratorDataset接口接收一个可随机访问对象或可迭代对象,由该对象自定义数据读取的方式。

  1. __getitem__函数的随机访问对象相比可迭代对象,不需进行index递增等操作,逻辑更精简,易于使用。

  2. 分布式训练场景需对数据集进行切片操作,GeneratorDataset初始化时可以接收sampler参数, 也可接收`num_shards、shard_id来指定切片份数和取第几份,后面这种方式更易于使用。

下面分别展示这两种不同的自定义数据集加载方法,为了便于对比,生成的随机数据保持相同。

构造可随机访问对象

可随机访问的对象具有__getitem__函数,能够随机访问指定索引位置的数据。定义数据集类的时候重写__getitem__函数,即可使得该类的对象支持随机访问。

[16]:
import numpy as np
import mindspore.dataset as ds

class GetDatasetGenerator:
    def __init__(self):
        np.random.seed(58)
        self.__data = np.random.sample((5, 2))
        self.__label = np.random.sample((5, 1))

    def __getitem__(self, index):
        return (self.__data[index], self.__label[index])

    def __len__(self):
        return len(self.__data)

dataset_generator = GetDatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)

for data in dataset.create_dict_iterator():
    print(data["data"], data["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]

构造可迭代对象

可迭代的对象具有__iter__函数和__next__函数,能够在每次调用时返回一条数据。定义数据集类的时候重写__iter__函数和__next__函数,通过__iter__函数返回迭代器,通过__next__函数定义数据集加载方式,即可使得该类的对象可迭代。

[17]:
import numpy as np
import mindspore.dataset as ds

class IterDatasetGenerator:
    def __init__(self):
        np.random.seed(58)
        self.__index = 0
        self.__data = np.random.sample((5, 2))
        self.__label = np.random.sample((5, 1))

    def __next__(self):
        if self.__index >= len(self.__data):
            raise StopIteration
        else:
            item = (self.__data[self.__index], self.__label[self.__index])
            self.__index += 1
            return item

    def __iter__(self):
        self.__index = 0
        return self

    def __len__(self):
        return len(self.__data)

dataset_generator = IterDatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)

for data in dataset.create_dict_iterator():
    print(data["data"], data["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]

需要注意的是,如果数据集本身并不复杂,直接定义一个可迭代的函数即可快速实现自定义加载功能。

[18]:
import numpy as np
import mindspore.dataset as ds

np.random.seed(58)
data = np.random.sample((5, 2))
label = np.random.sample((5, 1))

def GeneratorFunc():
    for i in range(5):
        yield (data[i], label[i])

dataset = ds.GeneratorDataset(GeneratorFunc, ["data", "label"])

for item in dataset.create_dict_iterator():
    print(item["data"], item["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]

NumPy数据格式

如果所有数据已经读入内存,可以直接使用NumpySlicesDataset类将其加载。

下面的样例分别介绍了通过NumpySlicesDataset加载arrays数据、 list数据和dict数据的方式。

  • 加载NumPy arrays数据

[11]:
import numpy as np
import mindspore.dataset as ds

np.random.seed(6)
features, labels = np.random.sample((4, 2)), np.random.sample((4, 1))

data = (features, labels)
dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)

for np_arr_data in dataset:
    print(np_arr_data[0], np_arr_data[1])
[0.89286015 0.33197981] [0.33540785]
[0.82122912 0.04169663] [0.62251943]
[0.10765668 0.59505206] [0.43814143]
[0.52981736 0.41880743] [0.73588211]
  • 加载Python list数据

[12]:
import mindspore.dataset as ds

data1 = [[1, 2], [3, 4]]

dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False)

for np_list_data in dataset:
    print(np_list_data[0])
[1 2]
[3 4]
  • 加载Python dict数据

[13]:
import mindspore.dataset as ds

data1 = {"a": [1, 2], "b": [3, 4]}

dataset = ds.NumpySlicesDataset(data1, column_names=["col1", "col2"], shuffle=False)

for np_dic_data in dataset.create_dict_iterator():
    print(np_dic_data)
{'col1': Tensor(shape=[], dtype=Int64, value= 1), 'col2': Tensor(shape=[], dtype=Int64, value= 3)}
{'col1': Tensor(shape=[], dtype=Int64, value= 2), 'col2': Tensor(shape=[], dtype=Int64, value= 4)}