# Source code for mindspore.dataset.engine.serializer_deserializer

# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Functions to support dataset serialize and deserialize.
"""
import json
import os

from mindspore import log as logger
from . import datasets as de


def serialize(dataset, json_filepath=""):
    """
    Serialize a dataset pipeline, optionally writing it to a JSON file.

    Note:
        Some Python objects cannot be serialized yet, for example callable
        user-defined Python functions (Python UDFs) and GeneratorDataset.
        For such objects only a partial JSON description is produced. Later
        deserialization, execution of the deserialized pipeline and/or
        re-serialization of it may then fail. Serializing a Python UDF emits
        a warning, and any JSON produced for such a pipeline is not valid
        input for deserialization.

    Args:
        dataset (Dataset): The starting node of the pipeline.
        json_filepath (str): The filepath where a serialized JSON file will be
            generated (default="").

    Returns:
        Dict, the dictionary describing the serialized dataset graph.

    Raises:
        OSError: Cannot open a file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, num_samples=100)
        >>> one_hot_encode = transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # serialize it to JSON file
        >>> serialized_data = ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
    """
    # The work is delegated to the pipeline itself; the file path is forwarded unchanged.
    serialized_graph = dataset.to_json(json_filepath)
    return serialized_graph
def deserialize(input_dict=None, json_filepath=None):
    """
    Construct dataset pipeline from a JSON file produced by dataset serialize function.

    Args:
        input_dict (dict): A Python dictionary containing a serialized dataset graph (default=None).
        json_filepath (str): A path to the JSON file (default=None).

    Returns:
        de.Dataset, or None when neither `input_dict` nor `json_filepath` is provided.
        Both arguments are tested for truthiness, so an empty dict or an empty string
        counts as "not provided". If both are provided, `json_filepath` takes precedence,
        `input_dict` is ignored and a warning is logged.

    Raises:
        OSError: Can not open the JSON file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, num_samples=100)
        >>> one_hot_encode = transforms.OneHot(10)  # num_classes is input argument
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> # Case 1: to/from JSON file
        >>> serialized_data = ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> deserialized_dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
        >>> # Case 2: to/from Python dictionary
        >>> serialized_data = ds.serialize(dataset)
        >>> deserialized_dataset = ds.deserialize(input_dict=serialized_data)
    """
    if json_filepath:
        if input_dict:
            # Only one pipeline can be returned. Keep the historical precedence of the
            # file path, but don't build (and then throw away) a pipeline from input_dict,
            # which wasted work and could raise for an input that is ignored anyway.
            logger.warning("Both input_dict and json_filepath are provided; input_dict is ignored "
                           "and the pipeline is deserialized from json_filepath.")
        return de.DeserializedDataset(json_filepath)
    if input_dict:
        return de.DeserializedDataset(input_dict)
    # Nothing usable was supplied.
    return None
def expand_path(node_repr, key, val):
    """
    Store the absolute form of ``val`` under ``node_repr[key]`` (in place).

    Args:
        node_repr (dict): Mapping that is updated in place.
        key: Key under which the converted value is stored.
        val (Union[str, list[str]]): A single path, or a list of paths that is
            converted element-wise into a new list.
    """
    to_absolute = os.path.abspath
    node_repr[key] = [to_absolute(path) for path in val] if isinstance(val, list) else to_absolute(val)
def show(dataset, indentation=2):
    """
    Log the dataset pipeline graph, rendered as JSON, at INFO level.

    Args:
        dataset (Dataset): The starting node.
        indentation (int, optional): Indentation passed to the JSON renderer; no
            indentation is applied if it is None (default=2).

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, num_samples=100)
        >>> one_hot_encode = transforms.OneHot(10)
        >>> dataset = dataset.map(operations=one_hot_encode, input_columns="label")
        >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
        >>> ds.show(dataset)
    """
    rendered = json.dumps(dataset.to_json(), indent=indentation)
    logger.info(rendered)
def compare(pipeline1, pipeline2):
    """
    Check whether two dataset pipelines serialize to the same description.

    Args:
        pipeline1 (Dataset): a dataset pipeline.
        pipeline2 (Dataset): a dataset pipeline.

    Returns:
        bool, True if the serialized forms of the two pipelines are equal.

    Examples:
        >>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, num_samples=100)
        >>> pipeline2 = ds.Cifar10Dataset(cifar10_dataset_dir, num_samples=100)
        >>> res = ds.compare(pipeline1, pipeline2)
    """
    # Serialize left-to-right (pipeline1 first) and compare the resulting descriptions.
    first, second = (pipeline.to_json() for pipeline in (pipeline1, pipeline2))
    return first == second