Converting Dataset to MindRecord


In MindSpore, the dataset used to train a network model can be converted into a MindSpore-specific data format (MindSpore Record), which makes the data easier to save and load. The goal is to convert user datasets into a uniform format so that they can be read through the MindDataset interface and used during training.


In addition, MindSpore includes performance optimizations for some scenarios: using the MindSpore Record data format can reduce disk I/O and network I/O overhead and thus improve the user experience.

The MindSpore data format has the following features:

  1. User data is stored and accessed in a unified way, which simplifies loading training data.

  2. Data is stored in aggregated form, so it can be read, managed, and moved efficiently.

  3. Data encoding and decoding are efficient and transparent to users.

  4. The partition size can be controlled flexibly, which facilitates distributed training.

Record File Structure

As shown in the following figure, a MindSpore Record file consists of a data file and an index file.

[Figure: MindSpore Record file structure]

The data file contains the file header, scalar data pages, and block data pages, which store the user's training data after it has been converted to the uniform format. A single MindSpore Record file should preferably be smaller than 20 GB; a large dataset can be stored as multiple MindSpore Record files.
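To stay under the recommended size, the number of output files can be estimated from the total size of the data. The sketch below is a hypothetical helper, not part of MindSpore; it assumes that "20G" means 20 GiB and that samples are spread roughly evenly across the files. The result could then be passed as the shard_num argument of FileWriter, which the examples in this article set to 1.

```python
import math

# Assumption: the "20G" recommendation is interpreted as 20 GiB
MAX_FILE_BYTES = 20 * 1024 ** 3


def suggest_shard_num(total_bytes, max_file_bytes=MAX_FILE_BYTES):
    """Smallest file count that keeps each file at or below max_file_bytes.

    Hypothetical helper: assumes the data is split evenly across the files.
    """
    if total_bytes < 0:
        raise ValueError("total_bytes must be non-negative")
    # Always produce at least one file, even for an empty dataset
    return max(1, math.ceil(total_bytes / max_file_bytes))


# A 45 GiB dataset needs three files of at most 20 GiB each
print(suggest_shard_num(45 * 1024 ** 3))
```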

The index file contains index information generated from scalar data (such as image labels and image file names), which makes it convenient to retrieve samples and to collect dataset statistics.

The specific purposes of file headers, scalar data pages, and block data pages in data files are as follows:

  • File header: the meta information of the MindSpore Record file. It mainly stores the file header size, scalar data page size, block data page size, schema information, index fields, statistics, file segment information, the correspondence between scalar data and block data, and so on.

  • Scalar data page: mainly stores integer, string, and floating-point data, such as the image label, the image file name, and the image height and width. In other words, information that is suitable for storing as scalars is saved here.

  • Block data page: mainly stores binary strings, NumPy arrays, and other such data, for example the binary image files themselves or dictionaries converted to text.
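The split between scalar and block data pages follows from the type of each schema field. As an illustration only, the sketch below sorts the fields of a schema into the two groups using the rule described above: bytes fields and fields that declare a "shape" (NumPy arrays) go to block data pages, everything else to scalar data pages. split_schema_fields is a hypothetical helper, not a MindSpore API, and the rule is an approximation of the description in this section. Under this rule, file_name and label (the fields the CV example later indexes with add_index) are scalar fields.

```python
def split_schema_fields(schema):
    """Split a MindSpore Record-style schema into scalar and block field names.

    Hypothetical helper: a field counts as block data if its type is "bytes"
    or it declares a "shape" (an array); all other fields count as scalars.
    """
    scalar_fields, block_fields = [], []
    for name, spec in schema.items():
        if spec["type"] == "bytes" or "shape" in spec:
            block_fields.append(name)
        else:
            scalar_fields.append(name)
    return scalar_fields, block_fields


# Same schema as in the CV example
cv_schema = {"file_name": {"type": "string"},
             "label": {"type": "int32"},
             "data": {"type": "bytes"}}
print(split_schema_fields(cv_schema))
```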

Note that neither data files nor index files can currently be renamed.
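Because a data file and its index file always come as a pair, it can help to manage them together. In the examples in this article, the index file name is the data file name plus ".db". The cleanup code in those examples removes the ".db" file whenever the data file exists, which raises FileNotFoundError if the index file is missing. The sketch below is a slightly more defensive variant; remove_mindrecord_files is a hypothetical helper name, and it only handles the single-file (shard_num=1) naming used here.

```python
import os


def remove_mindrecord_files(file_name):
    """Delete a MindSpore Record data file and its ".db" index file if present.

    Hypothetical helper: assumes the single-file naming used in this article,
    where the index file is the data file name plus ".db".
    Returns the list of paths that were actually removed.
    """
    removed = []
    for path in (file_name, file_name + ".db"):
        if os.path.exists(path):
            os.remove(path)
            removed.append(path)
    return removed
```

For example, remove_mindrecord_files("test.mindrecord") can replace the three-line existence check at the start of each conversion cell, and it does nothing when the files do not exist.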

Converting Dataset to Record Format

The following sections describe how to convert computer vision (CV) data and natural language processing (NLP) data to the MindSpore Record file format, and how to read MindSpore Record files through the MindDataset interface.

Converting a CV Dataset

This example uses a CV dataset containing 100 records to show how to convert a CV dataset to the MindSpore Record file format and read it through the MindDataset interface.

First, create a dataset of 100 images and save it. Each sample contains three fields: file_name (string), label (integer), and data (binary). Then use the MindDataset interface to read the MindSpore Record file.
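Each sample passed to the writer is a plain Python dict whose keys should match the schema fields. Checking samples before writing can catch mistakes such as a missing field or a label stored as a string. The sketch below is illustrative only: check_sample and EXPECTED_TYPES are hypothetical names, the mapping from schema type names to Python types is an assumption, and fields declared with a "shape" (NumPy arrays) are not handled.

```python
# Assumed mapping from schema type names to the Python types of sample values
# (hypothetical; array fields with a "shape" are not covered)
EXPECTED_TYPES = {"string": str, "int32": int, "int64": int,
                  "float32": float, "float64": float, "bytes": bytes}


def check_sample(sample, schema):
    """Return a list of problems found when comparing one sample to a schema."""
    problems = []
    problems += [f"missing field: {name}" for name in sorted(schema.keys() - sample.keys())]
    problems += [f"unexpected field: {name}" for name in sorted(sample.keys() - schema.keys())]
    for name in sorted(schema.keys() & sample.keys()):
        expected = EXPECTED_TYPES[schema[name]["type"]]
        if not isinstance(sample[name], expected):
            problems.append(f"{name}: expected {expected.__name__}, "
                            f"got {type(sample[name]).__name__}")
    return problems


cv_schema = {"file_name": {"type": "string"},
             "label": {"type": "int32"},
             "data": {"type": "bytes"}}
good = {"file_name": "1.jpg", "label": 1, "data": b"\xff\xd8"}
bad = {"file_name": "2.jpg", "label": "2"}
print(check_sample(good, cv_schema))
print(check_sample(bad, cv_schema))
```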

  1. Generate 100 images and convert them to the MindSpore Record file format.

[1]:
import os
from PIL import Image
from io import BytesIO

import mindspore.mindrecord as record


# The full path to the output MindSpore Record file
MINDRECORD_FILE = "test.mindrecord"

if os.path.exists(MINDRECORD_FILE):
    os.remove(MINDRECORD_FILE)
    os.remove(MINDRECORD_FILE + ".db")

# Define the contained fields
cv_schema = {"file_name": {"type": "string"},
             "label": {"type": "int32"},
             "data": {"type": "bytes"}}

# Declare the MindSpore Record file format
writer = record.FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
writer.add_schema(cv_schema, "it is a cv dataset")
writer.add_index(["file_name", "label"])

# Build a dataset of 100 samples and write them in chunks of 10
data = []
for i in range(1, 101):
    sample = {}
    # Encode an (i*10) x (i*10) white image as JPEG bytes in memory
    white_io = BytesIO()
    Image.new('RGB', (i*10, i*10), (255, 255, 255)).save(white_io, 'JPEG')
    image_bytes = white_io.getvalue()
    sample['file_name'] = str(i) + ".jpg"
    sample['label'] = i
    sample['data'] = image_bytes

    data.append(sample)
    if i % 10 == 0:
        writer.write_raw_data(data)
        data = []

if data:
    writer.write_raw_data(data)

writer.commit()

The cell output MSRStatus.SUCCESS, returned by writer.commit(), indicates that the dataset conversion succeeded. The later examples in this article show the same output when a conversion succeeds.
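The conversion loop collects samples in data, calls writer.write_raw_data every 10 samples, and writes any remainder after the loop. As a sketch, that chunking can be factored into a small generator; batched is a hypothetical helper, not part of mindspore.mindrecord.

```python
def batched(samples, batch_size):
    """Yield lists of at most batch_size items from any iterable (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Flush the final, possibly shorter, batch
        yield batch


print([len(chunk) for chunk in batched(range(25), 10)])
```

With this helper, the write loop reduces to for chunk in batched(samples, 10): writer.write_raw_data(chunk), where samples is any iterable of sample dicts. Python 3.12 and later also provide itertools.batched, but it yields tuples rather than lists.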

  2. Read the MindSpore Record file through the MindDataset interface.

[2]:
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# Read the MindSpore Record file format
data_set = ds.MindDataset(dataset_files=MINDRECORD_FILE)
decode_op = vision.Decode()
data_set = data_set.map(operations=decode_op, input_columns=["data"], num_parallel_workers=2)

# Count the number of samples
print("Got {} samples".format(data_set.get_dataset_size()))

Converting an NLP Dataset

This example first creates a MindSpore Record file containing 100 records, where each sample contains eight fields, all of which are integer arrays. It then uses the MindDataset interface to read the MindSpore Record file.

For ease of presentation, the preprocessing step that converts text into vocabulary indices (token IDs) is omitted here.
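For readers who want to see what the omitted step might look like, the sketch below maps whitespace-separated tokens to integer IDs and builds int64 arrays of the kind stored in the NLP schema. Everything here is an assumption for illustration: the special-token IDs, the helper names build_vocab and encode, and the interpretation of the sos/eos fields (start-of-sequence ID prepended, end-of-sequence ID appended, masks of 1 for real tokens). The example dataset itself fills these fields with arbitrary integers.

```python
import numpy as np

# Hypothetical special-token IDs; a real pipeline defines these in its own vocabulary
PAD, SOS, EOS, UNK = 0, 1, 2, 3


def build_vocab(sentences):
    """Assign an integer ID to every whitespace-separated token after the special tokens."""
    vocab = {"<pad>": PAD, "<sos>": SOS, "<eos>": EOS, "<unk>": UNK}
    for sentence in sentences:
        for token in sentence.split():
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab


def encode(sentence, vocab):
    """Return (sos_ids, sos_mask, eos_ids, eos_mask) as int64 NumPy arrays.

    sos_ids prepends the start-of-sequence ID and eos_ids appends the
    end-of-sequence ID; every position is a real token, so masks are all ones.
    """
    ids = [vocab.get(token, UNK) for token in sentence.split()]
    sos_ids = np.array([SOS] + ids, dtype=np.int64)
    eos_ids = np.array(ids + [EOS], dtype=np.int64)
    return sos_ids, np.ones_like(sos_ids), eos_ids, np.ones_like(eos_ids)


vocab = build_vocab(["the cat sat", "the dog ran"])
source_sos_ids, source_sos_mask, source_eos_ids, source_eos_mask = encode("the cat ran", vocab)
print(source_sos_ids, source_eos_ids)
```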

  1. Generate 100 samples and convert them to the MindSpore Record file format.

[ ]:
import os
import numpy as np
import mindspore.mindrecord as record

# The full path of the output MindSpore Record file
MINDRECORD_FILE = "test.mindrecord"

if os.path.exists(MINDRECORD_FILE):
    os.remove(MINDRECORD_FILE)
    os.remove(MINDRECORD_FILE + ".db")

# Defines the fields that the sample data contains
nlp_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}

# Declare the MindSpore Record file format
writer = record.FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
writer.add_schema(nlp_schema, "Preprocessed nlp dataset.")

# Build a virtual dataset
data = []
for i in range(1, 101):
    sample = {"source_sos_ids": np.array([i, i + 1, i + 2, i + 3, i + 4], dtype=np.int64),
              "source_sos_mask": np.array([i * 1, i * 2, i * 3, i * 4, i * 5, i * 6, i * 7], dtype=np.int64),
              "source_eos_ids": np.array([i + 5, i + 6, i + 7, i + 8, i + 9, i + 10], dtype=np.int64),
              "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
              "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
              "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
              "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
              "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}
    data.append(sample)

    if i % 10 == 0:
        writer.write_raw_data(data)
        data = []

if data:
    writer.write_raw_data(data)

writer.commit()
  2. Read the MindSpore Record file through the MindDataset interface.

[4]:
import mindspore.dataset as ds

# Read MindSpore Record file format
data_set = ds.MindDataset(dataset_files=MINDRECORD_FILE, shuffle=False)

# Count the number of samples
print("Got {} samples".format(data_set.get_dataset_size()))

# Print source_sos_ids for the first 10 samples
count = 0
for item in data_set.create_dict_iterator():
    print("source_sos_ids:", item["source_sos_ids"])
    count += 1
    if count == 10:
        break
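Each field in this schema is declared with shape [-1], so different samples may hold arrays of different lengths (in the virtual dataset above, each field happens to have a fixed length). Grouping variable-length samples into one training batch requires padding them to a common length. MindSpore's dataset pipeline has its own support for padded batching, described in the API documentation; the NumPy sketch below, with the hypothetical helper pad_batch, only illustrates the idea.

```python
import numpy as np


def pad_batch(arrays, pad_value=0):
    """Stack 1-D int64 arrays of different lengths into one 2-D array (hypothetical helper).

    Returns the padded batch and a mask that is 1 for real values and 0 for padding.
    """
    max_len = max(len(a) for a in arrays)
    batch = np.full((len(arrays), max_len), pad_value, dtype=np.int64)
    mask = np.zeros((len(arrays), max_len), dtype=np.int64)
    for row, a in enumerate(arrays):
        batch[row, :len(a)] = a
        mask[row, :len(a)] = 1
    return batch, mask


batch, mask = pad_batch([np.array([5, 6, 7], dtype=np.int64), np.array([8], dtype=np.int64)])
print(batch)
print(mask)
```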

Converting Other Datasets

MindSpore provides tool classes that convert commonly used datasets to the MindSpore Record file format.

For more details on dataset conversion, refer to the API documentation.

Converting the CIFAR-10 Dataset

Users can convert the raw CIFAR-10 data into a MindSpore Record file through the Cifar10ToMR class and read it by using the MindDataset interface.

  1. Download the CIFAR-10 dataset and extract it. The following example code downloads the dataset and extracts it to ./datasets.

[ ]:
from mindvision import dataset

# Declare the dataset download address and dataset storage path
dl_path = "./datasets"
dl_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-python.tar.gz"

# Download and unzip the dataset
dl = dataset.DownLoad()
dl.download_and_extract_archive(url=dl_url, download_path=dl_path)

The directory structure of the extracted dataset files is as follows:

./datasets/cifar-10-batches-py
├── batches.meta
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
├── data_batch_5
├── readme.html
└── test_batch
  2. Create a Cifar10ToMR object and call its transform interface to convert the CIFAR-10 dataset to the MindSpore Record file format.

[ ]:
import os
import shutil
from mindspore.mindrecord import Cifar10ToMR

ds_target_path = "./datasets/mindspore_dataset_conversion/"

# Start from an empty output directory so that stale files do not interfere
shutil.rmtree(ds_target_path, ignore_errors=True)
os.makedirs(ds_target_path, exist_ok=True)

# CIFAR-10 dataset path
CIFAR10_DIR = "./datasets/cifar-10-batches-py"
# Output MindSpore Record file path
MINDRECORD_FILE = os.path.join(ds_target_path, "cifar10.mindrecord")

cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
# Convert the dataset and build an index on the "label" field
cifar10_transformer.transform(['label'])
  3. Read the MindSpore Record file through the MindDataset interface.

[ ]:
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# Read MindSpore Record file format
data_set = ds.MindDataset(dataset_files=MINDRECORD_FILE)
decode_op = vision.Decode()
data_set = data_set.map(operations=decode_op, input_columns=["data"], num_parallel_workers=2)

# Count the number of samples
print("Got {} samples".format(data_set.get_dataset_size()))
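Cifar10ToMR reads the raw Python batches listed in the directory structure shown earlier. To inspect one of those batch files directly, the sketch below loads it with pickle, following the format published with the CIFAR-10 Python version: a pickled dict whose b"data" entry holds one row of 3072 uint8 values per image (1024 red, then 1024 green, then 1024 blue values, each a row-major 32x32 plane) and whose b"labels" entry holds the class indices. load_cifar10_batch is a hypothetical helper and is not how Cifar10ToMR is implemented.

```python
import pickle

import numpy as np


def load_cifar10_batch(path):
    """Load one CIFAR-10 Python batch as (NHWC uint8 images, int64 labels)."""
    with open(path, "rb") as f:
        # The batches were pickled under Python 2, so the dict keys load as bytes
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    # Row layout is [R plane | G plane | B plane]; move channels last (HWC)
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels
```

For example, load_cifar10_batch("./datasets/cifar-10-batches-py/data_batch_1") should return 10000 images of shape (32, 32, 3) together with their labels.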

Converting a CSV Dataset

This example first creates a CSV file containing 5 records, then converts the CSV file to the MindSpore Record file format through the CsvToMR tool class, and finally reads it through the MindDataset interface.

This example depends on the third-party package pandas, which can be installed with the command pip install pandas. Because this document runs as a notebook, restart the kernel after the installation before running the subsequent code.

  1. Generate the CSV file and convert it to the MindSpore Record format.

[ ]:
import csv
import os
from mindspore import mindrecord as record

# Path to the CSV file
CSV_FILE = "test.csv"
# Path to the output MindSpore Record file
MINDRECORD_FILE = "test.mindrecord"

if os.path.exists(MINDRECORD_FILE):
    os.remove(MINDRECORD_FILE)
    os.remove(MINDRECORD_FILE + ".db")

def generate_csv():
    """Generate csv format file data"""
    headers = ["id", "name", "math", "english"]
    rows = [(1, "Lily", 78.5, 90),
            (2, "Lucy", 99, 85.2),
            (3, "Mike", 65, 71),
            (4, "Tom", 95, 99),
            (5, "Jeff", 85, 78.5)]
    # newline='' lets the csv module control line endings (avoids blank rows on Windows)
    with open(CSV_FILE, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(rows)

# Generate csv format file data
generate_csv()

# Convert csv format file
csv_transformer = record.CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=1)
csv_transformer.transform()

assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
  2. Read the MindSpore Record file through the MindDataset interface.

[ ]:
import mindspore.dataset as ds

data_set = ds.MindDataset(dataset_files=MINDRECORD_FILE)

# Count the number of samples
print("Got {} samples".format(data_set.get_dataset_size()))
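Converting a CSV file requires a type for each column, even though CSV itself stores only text. The standard-library sketch below illustrates per-column type inference on the same 5 rows: a column becomes an integer column only if every value parses as an integer, otherwise a float column if every value parses as a float, otherwise a string column. This is a hypothetical illustration, not CsvToMR's implementation; the type names it returns are assumptions, and the actual mapping CsvToMR uses should be checked in the API documentation.

```python
import csv
import io


def _all_parse(values, cast):
    """True if every value can be converted with cast (int or float)."""
    try:
        for value in values:
            cast(value)
    except ValueError:
        return False
    return True


def infer_column_types(csv_text):
    """Guess a type name for each CSV column (hypothetical helper)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    columns = {name: [] for name in reader.fieldnames}
    for row in reader:
        for name, value in row.items():
            columns[name].append(value)
    types = {}
    for name, values in columns.items():
        if _all_parse(values, int):
            types[name] = "int64"
        elif _all_parse(values, float):
            types[name] = "float64"
        else:
            types[name] = "string"
    return types


csv_text = ("id,name,math,english\n1,Lily,78.5,90\n2,Lucy,99,85.2\n"
            "3,Mike,65,71\n4,Tom,95,99\n5,Jeff,85,78.5\n")
print(infer_column_types(csv_text))
```

Because a single value such as 78.5 fails integer parsing, the whole math column is treated as floating point even though most of its values are whole numbers.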

Converting a TFRecord Dataset

This example requires TensorFlow to be installed in advance; TensorFlow 1.13.0-rc1 or later is required. Because this document runs as a notebook, restart the kernel after the installation before running the subsequent code.

This example first creates a TFRecord file with TensorFlow, then converts the TFRecord file into a MindSpore Record file through the TFRecordToMR tool class, and finally reads it through the MindDataset interface and decodes the image_bytes field with the Decode operation.

  1. Import the required modules.

[10]:
import collections
from io import BytesIO
import os

# Hide TensorFlow INFO and WARNING logs; set this before importing TensorFlow so it takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import mindspore.dataset as ds
import mindspore.mindrecord as record
import mindspore.dataset.vision as vision
from PIL import Image
import tensorflow as tf
  2. Generate the TFRecord file.

[ ]:
# Path to the TFRecord file
TFRECORD_FILE = "test.tfrecord"
# Path to the output MindSpore Record file
MINDRECORD_FILE = "test.mindrecord"

def generate_tfrecord():
    def create_int_feature(values):
        if isinstance(values, list):
            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
        else:
            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))
        return feature

    def create_float_feature(values):
        if isinstance(values, list):
            feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
        else:
            feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values]))
        return feature

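    def create_bytes_feature(values):
        if isinstance(values, bytes):
            # For bytes input, the content is ignored and replaced with a real
            # 10x10 white JPEG so that the Decode operation can decode it later
            white_io = BytesIO()
            Image.new('RGB', (10, 10), (255, 255, 255)).save(white_io, 'JPEG')
            image_bytes = white_io.getvalue()
            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
        else:
            # Strings (such as file names) are stored as UTF-8 encoded bytes
            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
        return feature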

    writer = tf.io.TFRecordWriter(TFRECORD_FILE)

    example_count = 0
    for i in range(10):
        # Create sample data (deterministic, derived from the loop index i)
        file_name = "000" + str(i) + ".jpg"
        image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8")
        int64_scalar = i
        float_scalar = float(i)
        int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890]
        float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2),
                      float(i+4.4), float(i+123456.9), float(i+98765432.1)]

        # Save the data in the TFRecord file format
        features = collections.OrderedDict()
        features["file_name"] = create_bytes_feature(file_name)
        features["image_bytes"] = create_bytes_feature(image_bytes)
        features["int64_scalar"] = create_int_feature(int64_scalar)
        features["float_scalar"] = create_float_feature(float_scalar)
        features["int64_list"] = create_int_feature(int64_list)
        features["float_list"] = create_float_feature(float_list)

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
        example_count += 1

    writer.close()
    print("Wrote {} rows to the TFRecord file.".format(example_count))

generate_tfrecord()
  3. Convert the TFRecord file to the MindSpore Record format.

[ ]:
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                "image_bytes": tf.io.FixedLenFeature([], tf.string),
                "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                "float_list": tf.io.FixedLenFeature([7], tf.float32),
                }

if os.path.exists(MINDRECORD_FILE):
    os.remove(MINDRECORD_FILE)
    os.remove(MINDRECORD_FILE + ".db")

tfrecord_transformer = record.TFRecordToMR(TFRECORD_FILE, MINDRECORD_FILE, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()

assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
  4. Read the MindSpore Record file through the MindDataset interface.

[ ]:
data_set = ds.MindDataset(dataset_files=MINDRECORD_FILE)
decode_op = vision.Decode()
data_set = data_set.map(operations=decode_op, input_columns=["image_bytes"], num_parallel_workers=2)

# Count the number of samples
print("Got {} samples".format(data_set.get_dataset_size()))
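The float values in this example pass through 32-bit floating point: tf.train.FloatList stores 32-bit floats, and feature_dict reads float_list back as tf.float32. Large values such as i + 98765432.1 therefore cannot be stored exactly, so values read back after conversion may differ slightly from the Python floats that were written. The NumPy sketch below (not part of the conversion pipeline, helper name hypothetical) shows the effect of a float32 round trip on values of the kind generated above.

```python
import numpy as np


def as_float32(values):
    """Return the values as Python floats after a round trip through float32."""
    return np.asarray(values, dtype=np.float64).astype(np.float32).astype(np.float64).tolist()


# Values of the kind written to float_list in generate_tfrecord()
original = [2.8, 123456.9, 98765432.1]
stored = as_float32(original)
for before, after in zip(original, stored):
    print(f"{before!r} -> {after!r}")
```

Near 1e8, adjacent float32 values are 8 apart, so the fractional part of 98765432.1 is lost entirely.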