{ "cells": [ { "cell_type": "markdown", "source": [ "# Converting Datasets to the MindSpore Data Format\n", "\n", "`Ascend` `GPU` `CPU` `Data Preparation`\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX2RhdGFzZXRfY29udmVyc2lvbi5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_dataset_conversion.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_dataset_conversion.py) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.5/docs/mindspore/programming_guide/source_zh_cn/dataset_conversion.ipynb)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Overview\n", "\n", "Both non-standard and commonly used datasets can be converted to MindRecord, the MindSpore data format, so that they can be loaded into MindSpore conveniently for training. MindSpore also optimizes performance for MindRecord in some scenarios, so using MindRecord can yield better performance.\n", "\n", "## Converting Non-Standard Datasets to MindRecord\n", "\n", "This section describes how to convert CV and NLP data to MindRecord and how to read the resulting MindRecord files with `MindDataset`.\n", "\n", "### Converting a CV Dataset\n", "\n", "This example shows how to convert your own CV dataset to MindRecord and read it with `MindDataset`.\n", "\n", "The example first creates a MindRecord file containing 100 records, each with three fields: `file_name` (string), `label` (integer), and `data` (bytes). It then reads the file with `MindDataset`.\n", "\n", "1. 
Import the required modules." ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "from io import BytesIO\n", "import os\n", "import mindspore.dataset as ds\n", "from mindspore.mindrecord import FileWriter\n", "import mindspore.dataset.vision.c_transforms as vision\n", "from PIL import Image" ], "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:33:34.444561Z", "start_time": "2021-02-22T10:33:34.441434Z" } } }, { "cell_type": "markdown", "source": [ "2. Generate 100 images and convert them to MindRecord." ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "MINDRECORD_FILE = \"test.mindrecord\"\n", "\n", "if os.path.exists(MINDRECORD_FILE):\n", "    os.remove(MINDRECORD_FILE)\n", "    os.remove(MINDRECORD_FILE + \".db\")\n", "\n", "writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)\n", "\n", "cv_schema = {\"file_name\": {\"type\": \"string\"}, \"label\": {\"type\": \"int32\"}, \"data\": {\"type\": \"bytes\"}}\n", "writer.add_schema(cv_schema, \"it is a cv dataset\")\n", "\n", "writer.add_index([\"file_name\", \"label\"])\n", "\n", "data = []\n", "for i in range(100):\n", "    i += 1\n", "\n", "    sample = {}\n", "    white_io = BytesIO()\n", "    Image.new('RGB', (i*10, i*10), (255, 255, 255)).save(white_io, 'JPEG')\n", "    image_bytes = white_io.getvalue()\n", "    sample['file_name'] = str(i) + \".jpg\"\n", "    sample['label'] = i\n", "    sample['data'] = image_bytes\n", "\n", "    data.append(sample)\n", "    if i % 10 == 0:\n", "        writer.write_raw_data(data)\n", "        data = []\n", "\n", "if data:\n", "    writer.write_raw_data(data)\n", "\n", "writer.commit()" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "MSRStatus.SUCCESS" ] }, "metadata": {}, "execution_count": 2 } ], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:34:03.889515Z", "start_time": "2021-02-22T10:34:02.950207Z" } } }, { "cell_type": "markdown", "source": [ "**Parameter description:**\n", "\n", "- `MINDRECORD_FILE`: path of the output MindRecord file." ], "metadata": {} }, { 
"cell_type": "markdown", "source": [ "3. 通过`MindDataset`读取MindRecord。" ], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "decode_op = vision.Decode()\n", "data_set = data_set.map(operations=decode_op, input_columns=[\"data\"], num_parallel_workers=2)\n", "count = 0\n", "for item in data_set.create_dict_iterator(output_numpy=True):\n", " count += 1\n", "print(\"Got {} samples\".format(count))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Got 100 samples\n" ] } ], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:34:07.729322Z", "start_time": "2021-02-22T10:34:07.575711Z" } } }, { "cell_type": "markdown", "source": [ "\n", "### 转换NLP类数据集\n", "\n", "本示例主要介绍用户如何将自己的NLP类数据集转换成MindRecord,并使用`MindDataset`读取。为了方便展示,此处略去了将文本转换成字典序的预处理过程。\n", "\n", "本示例首先创建一个包含100条记录的MindRecord文件,其样本包含八个字段,均为整型数组,然后使用`MindDataset`读取该MindRecord文件。\n", "\n", "1. 导入相关模块。" ], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "import os\n", "import numpy as np\n", "import mindspore.dataset as ds\n", "from mindspore.mindrecord import FileWriter" ], "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:34:21.606147Z", "start_time": "2021-02-22T10:34:21.603094Z" } } }, { "cell_type": "markdown", "source": [ "2. 
Generate 100 text records and convert them to MindRecord." ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "MINDRECORD_FILE = \"test.mindrecord\"\n", "\n", "if os.path.exists(MINDRECORD_FILE):\n", "    os.remove(MINDRECORD_FILE)\n", "    os.remove(MINDRECORD_FILE + \".db\")\n", "\n", "writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)\n", "\n", "nlp_schema = {\"source_sos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"source_sos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"source_eos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"source_eos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"target_sos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"target_sos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"target_eos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n", "              \"target_eos_mask\": {\"type\": \"int64\", \"shape\": [-1]}}\n", "writer.add_schema(nlp_schema, \"it is a preprocessed nlp dataset\")\n", "\n", "data = []\n", "for i in range(100):\n", "    i += 1\n", "\n", "    sample = {\"source_sos_ids\": np.array([i, i + 1, i + 2, i + 3, i + 4], dtype=np.int64),\n", "              \"source_sos_mask\": np.array([i * 1, i * 2, i * 3, i * 4, i * 5, i * 6, i * 7], dtype=np.int64),\n", "              \"source_eos_ids\": np.array([i + 5, i + 6, i + 7, i + 8, i + 9, i + 10], dtype=np.int64),\n", "              \"source_eos_mask\": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),\n", "              \"target_sos_ids\": np.array([28, 29, 30, 31, 32], dtype=np.int64),\n", "              \"target_sos_mask\": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),\n", "              \"target_eos_ids\": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),\n", "              \"target_eos_mask\": np.array([48, 49, 50, 51], dtype=np.int64)}\n", "\n", "    data.append(sample)\n", "    if i % 10 == 0:\n", "        writer.write_raw_data(data)\n", "        data = []\n", "\n", "if data:\n", "    writer.write_raw_data(data)\n", "\n", "writer.commit()" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ 
"MSRStatus.SUCCESS" ] }, "metadata": {}, "execution_count": 5 } ], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:34:23.883130Z", "start_time": "2021-02-22T10:34:23.660213Z" } } }, { "cell_type": "markdown", "source": [ "**参数说明:**\n", "\n", "- `MINDRECORD_FILE`:输出的MindRecord文件路径。" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "3. 通过`MindDataset`读取MindRecord。" ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "count = 0\n", "for item in data_set.create_dict_iterator():\n", " count += 1\n", "print(\"Got {} samples\".format(count))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Got 100 samples\n" ] } ], "metadata": { "ExecuteTime": { "end_time": "2021-02-22T10:34:27.133717Z", "start_time": "2021-02-22T10:34:27.083785Z" } } }, { "cell_type": "markdown", "source": [ "## 常用数据集转换MindRecord\n", "\n", "MindSpore提供转换常用数据集的工具类,能够将常用的数据集转换为MindRecord。部分常用数据集及其对应的工具类列表如下。\n", "\n", "| 数据集 | 格式转换工具类 |\n", "| :------- | :----------- |\n", "| CIFAR-10 | Cifar10ToMR |\n", "| ImageNet | ImageNetToMR |\n", "| TFRecord | TFRecordToMR |\n", "| CSV File | CsvToMR |\n", "\n", "更多数据集转换的详细说明可参见[API文档](https://www.mindspore.cn/docs/api/zh-CN/r1.5/api_python/mindspore.mindrecord.html)。" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### 转换CIFAR-10数据集\n", "\n", "用户可以通过`Cifar10ToMR`类,将CIFAR-10原始数据转换为MindRecord,并使用`MindDataset`读取。\n", "\n", "1. 
Download the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) and extract it to the target directory. In Jupyter Notebook, run the following commands:" ], "metadata": {} }, { "cell_type": "code", "execution_count": null, "source": [ "!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-python.tar.gz --no-check-certificate\n", "!mkdir -p datasets\n", "!tar -xzf cifar-10-python.tar.gz -C datasets" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "After extraction, the dataset directory is structured as follows:\n", "\n", "```text\n", "./datasets/cifar-10-batches-py\n", "├── batches.meta\n", "├── data_batch_1\n", "├── data_batch_2\n", "├── data_batch_3\n", "├── data_batch_4\n", "├── data_batch_5\n", "├── readme.html\n", "└── test_batch\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "2. Import the required modules." ], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "import os\n", "import mindspore.dataset as ds\n", "import mindspore.dataset.vision.c_transforms as vision\n", "from mindspore.mindrecord import Cifar10ToMR" ], "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-02-18T02:27:04.856761Z", "start_time": "2021-02-18T02:26:46.536793Z" } } }, { "cell_type": "markdown", "source": [ "3. 
Create a `Cifar10ToMR` object and call its `transform` interface to convert the CIFAR-10 dataset to MindRecord." ], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "ds_target_path = \"./datasets/mindspore_dataset_conversion/\"\n", "# Clean up files left over from previous runs\n", "os.system(\"rm -f {}*\".format(ds_target_path))\n", "os.system(\"mkdir -p {}\".format(ds_target_path))\n", "\n", "CIFAR10_DIR = \"./datasets/cifar-10-batches-py\"\n", "MINDRECORD_FILE = \"./datasets/mindspore_dataset_conversion/cifar10.mindrecord\"\n", "cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)\n", "cifar10_transformer.transform(['label'])" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "MSRStatus.SUCCESS" ] }, "metadata": {}, "execution_count": 9 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "**Parameter description:**\n", "\n", "- `CIFAR10_DIR`: path of the CIFAR-10 dataset.\n", "\n", "- `MINDRECORD_FILE`: path of the output MindRecord file." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "4. Read the MindRecord file with `MindDataset`." ], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "decode_op = vision.Decode()\n", "data_set = data_set.map(operations=decode_op, input_columns=[\"data\"], num_parallel_workers=2)\n", "count = 0\n", "for item in data_set.create_dict_iterator(output_numpy=True):\n", "    count += 1\n", "print(\"Got {} samples\".format(count))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Got 50000 samples\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Converting the ImageNet Dataset\n", "\n", "The `ImageNetToMR` class converts the raw ImageNet data (images and annotations) to MindRecord, which can then be read with `MindDataset`.\n", "\n", "1. 
Download the [ImageNet dataset](http://image-net.org/download), store all images in the `images/` folder, and record the mapping between images and labels in a mapping file named `labels_map.txt`. The mapping file has two space-separated columns: the image directory of each class and the label ID. For example:\n", "\n", "```text\n", "n01440760 0\n", "n01443537 1\n", "n01484850 2\n", "n01491361 3\n", "n01494475 4\n", "n01496331 5\n", "\n", "```\n", "\n", "The file directory structure is as follows:\n", "\n", "```text\n", "├─ labels_map.txt\n", "└─ images\n", "    └─ ......\n", "```\n", "\n", "2. Import the required modules.\n", "\n", "```python\n", "import mindspore.dataset as ds\n", "import mindspore.dataset.vision.c_transforms as vision\n", "from mindspore.mindrecord import ImageNetToMR\n", "```\n", "\n", "3. Create an `ImageNetToMR` object and call its `transform` interface to convert the dataset to MindRecord.\n", "\n", "```python\n", "IMAGENET_MAP_FILE = \"./labels_map.txt\"\n", "IMAGENET_IMAGE_DIR = \"./images\"\n", "MINDRECORD_FILE = \"./imagenet.mindrecord\"\n", "imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, MINDRECORD_FILE, partition_number=1)\n", "imagenet_transformer.transform()\n", "```\n", "\n", "**Parameter description:**\n", "\n", "- `IMAGENET_MAP_FILE`: path of the ImageNet label mapping file.\n", "\n", "- `IMAGENET_IMAGE_DIR`: path of the folder containing all ImageNet images.\n", "\n", "- `MINDRECORD_FILE`: path of the output MindRecord file.\n", "\n", "\n", "4. Read the MindRecord file with `MindDataset`.\n", "\n", "```python\n", "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "decode_op = vision.Decode()\n", "data_set = data_set.map(operations=decode_op, input_columns=[\"image\"], num_parallel_workers=2)\n", "count = 0\n", "for item in data_set.create_dict_iterator(output_numpy=True):\n", "    count += 1\n", "print(\"Got {} samples\".format(count))\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Converting a CSV Dataset\n", "\n", "This example first creates a CSV file containing 5 records, then converts it to MindRecord with the `CsvToMR` utility class, and finally reads it back with `MindDataset`.\n", "\n", "1. 
Import the required modules." ], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "import csv\n", "import os\n", "import mindspore.dataset as ds\n", "from mindspore.mindrecord import CsvToMR" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "2. Generate a CSV file and convert it to MindRecord." ], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "CSV_FILE = \"test.csv\"\n", "MINDRECORD_FILE = \"test.mindrecord\"\n", "\n", "def generate_csv():\n", "    headers = [\"id\", \"name\", \"math\", \"english\"]\n", "    rows = [(1, \"Lily\", 78.5, 90),\n", "            (2, \"Lucy\", 99, 85.2),\n", "            (3, \"Mike\", 65, 71),\n", "            (4, \"Tom\", 95, 99),\n", "            (5, \"Jeff\", 85, 78.5)]\n", "    # newline='' avoids blank rows between records on some platforms\n", "    with open(CSV_FILE, 'w', encoding='utf-8', newline='') as f:\n", "        writer = csv.writer(f)\n", "        writer.writerow(headers)\n", "        writer.writerows(rows)\n", "\n", "generate_csv()\n", "\n", "if os.path.exists(MINDRECORD_FILE):\n", "    os.remove(MINDRECORD_FILE)\n", "    os.remove(MINDRECORD_FILE + \".db\")\n", "\n", "csv_transformer = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=1)\n", "\n", "csv_transformer.transform()\n", "\n", "assert os.path.exists(MINDRECORD_FILE)\n", "assert os.path.exists(MINDRECORD_FILE + \".db\")" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "**Parameter description:**\n", "\n", "- `CSV_FILE`: path of the CSV file.\n", "\n", "- `MINDRECORD_FILE`: path of the output MindRecord file." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "3. 
Read the MindRecord file with `MindDataset`." ], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "count = 0\n", "for item in data_set.create_dict_iterator(output_numpy=True):\n", "    count += 1\n", "print(\"Got {} samples\".format(count))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Got 5 samples\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Converting a TFRecord Dataset\n", "\n", "> TensorFlow 1.13.0-rc1 and later versions are currently supported." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "This example requires TensorFlow to be installed in advance. If it is not installed, run the following command to install it. When running this document as a Notebook, restart the kernel after the installation completes before executing the subsequent code." ], "metadata": {} }, { "cell_type": "code", "execution_count": 14, "source": [ "os.system('pip install tensorflow') if os.system('python -c \"import tensorflow\"') else print(\"TensorFlow installed\")" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "This example first creates a TFRecord file with TensorFlow, then converts it to MindRecord with the `TFRecordToMR` utility class, and finally reads it back with `MindDataset`, decoding the `image_bytes` field with the `Decode` operator.\n", "\n", "1. Import the required modules." ], "metadata": {} }, { "cell_type": "code", "execution_count": 15, "source": [ "import collections\n", "from io import BytesIO\n", "import os\n", "import mindspore.dataset as ds\n", "from mindspore.mindrecord import TFRecordToMR\n", "import mindspore.dataset.vision.c_transforms as vision\n", "from PIL import Image\n", "import tensorflow as tf" ], "outputs": [], "metadata": { "scrolled": true } }, { "cell_type": "markdown", "source": [ "2. 
Generate a TFRecord file." ], "metadata": {} }, { "cell_type": "code", "execution_count": 16, "source": [ "TFRECORD_FILE = \"test.tfrecord\"\n", "MINDRECORD_FILE = \"test.mindrecord\"\n", "\n", "def generate_tfrecord():\n", "    def create_int_feature(values):\n", "        if isinstance(values, list):\n", "            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))\n", "        else:\n", "            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))\n", "        return feature\n", "\n", "    def create_float_feature(values):\n", "        if isinstance(values, list):\n", "            feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))\n", "        else:\n", "            feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values]))\n", "        return feature\n", "\n", "    def create_bytes_feature(values):\n", "        if isinstance(values, bytes):\n", "            # Bytes input is replaced by a 10x10 white JPEG so that the field can be decoded later\n", "            white_io = BytesIO()\n", "            Image.new('RGB', (10, 10), (255, 255, 255)).save(white_io, 'JPEG')\n", "            image_bytes = white_io.getvalue()\n", "            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))\n", "        else:\n", "            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))\n", "        return feature\n", "\n", "    writer = tf.io.TFRecordWriter(TFRECORD_FILE)\n", "\n", "    example_count = 0\n", "    for i in range(10):\n", "        file_name = \"000\" + str(i) + \".jpg\"\n", "        image_bytes = bytes(str(\"aaaabbbbcccc\" + str(i)), encoding=\"utf-8\")\n", "        int64_scalar = i\n", "        float_scalar = float(i)\n", "        int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890]\n", "        float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2),\n", "                      float(i+4.4), float(i+123456.9), float(i+98765432.1)]\n", "\n", "        features = collections.OrderedDict()\n", "        features[\"file_name\"] = create_bytes_feature(file_name)\n", "        features[\"image_bytes\"] = create_bytes_feature(image_bytes)\n", "        features[\"int64_scalar\"] = create_int_feature(int64_scalar)\n", "        features[\"float_scalar\"] = 
create_float_feature(float_scalar)\n", "        features[\"int64_list\"] = create_int_feature(int64_list)\n", "        features[\"float_list\"] = create_float_feature(float_list)\n", "\n", "        tf_example = tf.train.Example(features=tf.train.Features(feature=features))\n", "        writer.write(tf_example.SerializeToString())\n", "        example_count += 1\n", "    writer.close()\n", "    print(\"Write {} rows in tfrecord.\".format(example_count))\n", "\n", "generate_tfrecord()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Write 10 rows in tfrecord.\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "**Parameter description:**\n", "\n", "- `TFRECORD_FILE`: path of the TFRecord file.\n", "\n", "- `MINDRECORD_FILE`: path of the output MindRecord file." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "3. Convert the TFRecord file to MindRecord." ], "metadata": {} }, { "cell_type": "code", "execution_count": 17, "source": [ "feature_dict = {\"file_name\": tf.io.FixedLenFeature([], tf.string),\n", "                \"image_bytes\": tf.io.FixedLenFeature([], tf.string),\n", "                \"int64_scalar\": tf.io.FixedLenFeature([], tf.int64),\n", "                \"float_scalar\": tf.io.FixedLenFeature([], tf.float32),\n", "                \"int64_list\": tf.io.FixedLenFeature([6], tf.int64),\n", "                \"float_list\": tf.io.FixedLenFeature([7], tf.float32),\n", "                }\n", "\n", "if os.path.exists(MINDRECORD_FILE):\n", "    os.remove(MINDRECORD_FILE)\n", "    os.remove(MINDRECORD_FILE + \".db\")\n", "\n", "tfrecord_transformer = TFRecordToMR(TFRECORD_FILE, MINDRECORD_FILE, feature_dict, [\"image_bytes\"])\n", "tfrecord_transformer.transform()\n", "\n", "assert os.path.exists(MINDRECORD_FILE)\n", "assert os.path.exists(MINDRECORD_FILE + \".db\")" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "4. 
Read the MindRecord file with `MindDataset`." ], "metadata": {} }, { "cell_type": "code", "execution_count": 18, "source": [ "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n", "decode_op = vision.Decode()\n", "data_set = data_set.map(operations=decode_op, input_columns=[\"image_bytes\"], num_parallel_workers=2)\n", "count = 0\n", "for item in data_set.create_dict_iterator(output_numpy=True):\n", "    count += 1\n", "print(\"Got {} samples\".format(count))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Got 10 samples\n" ] } ], "metadata": {} } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }