mindflow.data.mind_dataset 源代码

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#pylint: disable=W0223
#pylint: disable=W0221
#pylint: disable=W0212
"""
Create MindFlow datasets from MindRecord-type data
"""
from __future__ import absolute_import

import mindspore.dataset as ds
from mindspore import log as logger

from .data_base import Data, CONSTRAINT_TYPES
from ..utils.check_func import check_param_type, check_param_value, check_dict_type, check_dict_type_value


class MindDataset(Data):
    """
    Create dataset from MindRecord-type data.

    Args:
        dataset_files (Union[str, list[str]]): If `dataset_files` is a str, it represents a file name of one
            component of a mindrecord source; other files with identical source in the same path will be
            found and loaded automatically. If `dataset_files` is a list, it represents a list of dataset
            files to be read directly.
        dataset_name (str, optional): name of dataset. Default: ``"dataset"``.
        constraint_type (str, optional): constraint type of the specified dataset, used to get its
            corresponding loss function. Default: ``"Label"``. Other supported types can be found in
            `mindflow.data.Dataset`.
        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch.
            If `shuffle` is ``False``, no shuffling will be performed.
            If `shuffle` is ``True``, performs global shuffle. Default: ``True``.
            Otherwise, there are two levels of shuffling:

            - ``Shuffle.GLOBAL``: Shuffle both the files and sample.
            - ``Shuffle.FILES``: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None``.
            When this argument is specified, 'num_samples' reflects the maximum sample number of per shard.
        shard_id (int, optional): The shard ID within `num_shards`. Default: ``None``. This argument can only
            be specified when `num_shards` is also specified.
        sampler (Sampler, optional): Object used to choose samples from the dataset. Default: ``None``.
            `sampler` is exclusive with `shuffle`. Support list: ``SubsetRandomSampler``, ``PkSampler``,
            ``RandomSampler``, ``SequentialSampler``, ``DistributedSampler``.
        num_samples (int, optional): The number of samples to be included in the dataset.
            Default: ``None``, all samples.
        num_parallel_workers (int, optional): The number of readers. Default: ``None``.

    Raises:
        ValueError: If `dataset_files` are not valid or do not exist.
        TypeError: If `dataset_name` is not string.
        ValueError: If `constraint_type.lower()` not in [``"equation"``, ``"bc"``, ``"ic"``, ``"label"``,
            ``"function"``, ``"custom"``].
        RuntimeError: If `num_shards` is specified but `shard_id` is ``None``.
        RuntimeError: If `shard_id` is specified but `num_shards` is ``None``.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindflow.data import MindDataset
        >>> dataset_files = ["./data_dir"] # contains 1 or multiple MindRecord files
        >>> dataset = MindDataset(dataset_files=dataset_files)
    """

    def __init__(self,
                 dataset_files,
                 dataset_name="dataset",
                 constraint_type="Label",
                 shuffle=True,
                 num_shards=None,
                 shard_id=None,
                 sampler=None,
                 num_samples=None,
                 num_parallel_workers=None):
        super(MindDataset, self).__init__()
        check_param_type(dataset_name, "dataset_name", data_type=str)
        check_param_type(constraint_type, "constraint_type", data_type=str)
        check_param_value(constraint_type.lower(), "constraint_type", CONSTRAINT_TYPES)

        self.mind_dataset = ds.MindDataset(dataset_files,
                                           num_parallel_workers=num_parallel_workers,
                                           shuffle=shuffle,
                                           num_shards=num_shards,
                                           shard_id=shard_id,
                                           sampler=sampler,
                                           num_samples=num_samples)
        self.columns_list = self.mind_dataset.get_col_names()
        logger.info("Get MindRecord dataset with columns: {}".format(self.columns_list))

        # Until split_dataset() is called, the whole record set is one sub-dataset.
        self.dataset_columns_map = {dataset_name: self.columns_list}
        for position, column in enumerate(self.columns_list):
            self.column_index_map[column] = position
        self.dataset_constraint_map = {dataset_name: constraint_type}
        self.dataset_name = [dataset_name]

    def split_dataset(self, dataset_dict, constraint_dict=None):
        """
        Split the original dataset into named sub-datasets so that different loss functions can be set.

        Args:
            dataset_dict (dict): dictionary of each sub-dataset, the key is the labeled name while the value
                refers to the specified columns contained in the sub-dataset.
            constraint_dict (Union[None, str, dict]): The constraint type of specified dataset. If ``None``,
                "Label" will be set for all. If is string, all will be set to the same one. If is dict, the
                subdataset and its constraint type is specified by the pair (key, value). Default: ``None``.

        Raises:
            ValueError: If a constraint type is not one of the supported constraint types.
            ValueError: If `constraint_dict` is a dict whose keys differ from those of `dataset_dict`.

        Examples:
            >>> dataset.split_dataset({"Equation" : "inner_points", "BC" : "bc_points"})
        """
        check_dict_type_value(dataset_dict, "sub-dataset dict", key_type=str, value_type=str,
                              value_value=self.columns_list)

        if not constraint_dict:
            # No constraint given: every sub-dataset falls back to "Label".
            constraint_map = {name: "Label" for name in dataset_dict}
        elif isinstance(constraint_dict, str):
            # One constraint type shared by all sub-datasets.
            check_param_value(constraint_dict.lower(), "constraint_type", CONSTRAINT_TYPES)
            constraint_map = {name: constraint_dict for name in dataset_dict}
        else:
            check_dict_type(constraint_dict, "sub-constraint dict", key_type=str, value_type=str)
            for key, value in constraint_dict.items():
                if value.lower() not in CONSTRAINT_TYPES:
                    raise ValueError("constraint type for dataset {} should be in {}, but got {}".format(
                        key, CONSTRAINT_TYPES, value))
            if dataset_dict.keys() != constraint_dict.keys():
                raise ValueError("The sub-dataset name should be consistent, but got dataset_dict: {},"
                                 "while constraint_dict: {}".format(dataset_dict.keys(), constraint_dict.keys()))
            constraint_map = dict(constraint_dict)

        # Store copies so later changes to the caller's dicts cannot alter this object's state, and so that
        # resetting the constraint map never touches the column map.
        self.dataset_columns_map = dict(dataset_dict)
        self.dataset_constraint_map = constraint_map
        self.dataset_name = list(dataset_dict.keys())

    def set_constraint_type(self, constraint_type="Equation"):
        """
        Set constraint type of dataset.

        Args:
            constraint_type (Union[str, dict]): The constraint type of specified dataset. If is string, the
                constraint type of all subdataset will be set to the same one. If is dict, the subdataset and
                its constraint type is specified by the pair (key, value).

        Raises:
            ValueError: If a constraint type is not one of the supported constraint types.
            ValueError: If `constraint_type` is a dict that does not cover every sub-dataset.

        Examples:
            >>> dataset.set_constraint_type("Equation")
        """
        check_param_type(constraint_type, "constraint_type", data_type=(str, dict))

        if isinstance(constraint_type, str):
            check_param_value(constraint_type.lower(), "constraint_type", CONSTRAINT_TYPES)
            for name in self.dataset_name:
                self.dataset_constraint_map[name] = constraint_type
            return

        # Validate every entry before changing anything, so a bad entry cannot leave a half-updated map.
        for name in self.dataset_name:
            if name not in constraint_type:
                raise ValueError("constraint type of dataset {} was not defined in constraint_type {}".format(
                    name, constraint_type))
            value = constraint_type[name]
            check_param_type(value, "constraint_type[{}]".format(name), data_type=str)
            if value.lower() not in CONSTRAINT_TYPES:
                raise ValueError("constraint type for dataset {} should be in {}, but got {}".format(
                    name, CONSTRAINT_TYPES, value))

        for name in self.dataset_name:
            self.dataset_constraint_map[name] = constraint_type[name]

    def create_dataset(self,
                       batch_size=1,
                       preprocess_fn=None,
                       updated_columns_list=None,
                       drop_remainder=True,
                       prebatched_data=False,
                       num_parallel_workers=1,
                       python_multiprocessing=False):
        """
        Create the final mindspore type dataset.

        Args:
            batch_size (int, optional): An int number of rows each batch is created with. Default: ``1``.
            preprocess_fn (Union[list[TensorOp], list[functions]], optional): List of operations to be
                applied on the dataset. Operations are applied in the order they appear in this list.
                Default: ``None``.
            updated_columns_list (list, optional): List of columns to be applied on the dataset.
                Default: ``None``.
            drop_remainder (bool, optional): Determines whether or not to drop the last block whose data row
                number is less than batch size. If ``True``, and if there are less than batch_size rows
                available to make the last batch, then those rows will be dropped and not propagated to the
                child node. Default: ``True``.
            prebatched_data (bool, optional): Generate pre-batched data before data preprocessing.
                Default: ``False``.
            num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in
                parallel. Default: ``1``.
            python_multiprocessing (bool, optional): Parallelize Python function per_batch_map with
                multi-processing. This option could be beneficial if the function is computational heavy.
                Default: ``False``.

        Returns:
            BatchDataset, dataset batched.

        Examples:
            >>> data = dataset.create_dataset()
        """
        check_param_type(prebatched_data, "prebatched_data", data_type=bool)
        check_param_type(drop_remainder, "drop_remainder", data_type=bool)
        check_param_type(batch_size, "batch_size", data_type=int, exclude_type=bool)

        dataset = self.mind_dataset
        if prebatched_data:
            # Batch first so that preprocess_fn receives whole batches.
            dataset = dataset.batch(batch_size=batch_size,
                                    drop_remainder=drop_remainder,
                                    num_parallel_workers=num_parallel_workers)

        if preprocess_fn:
            # NOTE(review): `column_order` is believed to have been removed from Dataset.map in
            # MindSpore 2.0 - confirm against the MindSpore version this package targets.
            dataset = dataset.map(operations=preprocess_fn,
                                  input_columns=self.columns_list,
                                  output_columns=updated_columns_list,
                                  column_order=updated_columns_list,
                                  num_parallel_workers=num_parallel_workers,
                                  python_multiprocessing=python_multiprocessing)
            # NOTE(review): block nesting here was reconstructed from a flattened listing; the bookkeeping
            # is refreshed only when the map actually renames the columns - confirm against upstream.
            if updated_columns_list:
                self.columns_list = updated_columns_list
                dataset_name = "data"
                self.dataset_columns_map = {dataset_name: self.columns_list}
                for position, column in enumerate(self.columns_list):
                    self.column_index_map[column] = position

        if not prebatched_data:
            dataset = dataset.batch(batch_size=batch_size,
                                    drop_remainder=drop_remainder,
                                    num_parallel_workers=num_parallel_workers)

        logger.info("Final dataset size: {}".format(dataset.get_dataset_size()))
        logger.info("Dataset columns map after preprocess: {}".format(self.dataset_columns_map))
        logger.info("Dataset column index after preprocess: {}".format(self.column_index_map))
        logger.info("Dataset constraints after preprocess: {}".format(self.dataset_constraint_map))
        return dataset

    def get_columns_list(self):
        """
        Get columns list.

        Returns:
            list[str]. column names list of the final unified dataset.

        Examples:
            >>> columns_list = dataset.get_columns_list()
        """
        return self.columns_list

    def __getitem__(self, index):
        """
        Return the row(s) at `index` from a fresh pass over the underlying MindRecord dataset.

        A new tuple iterator is created on every call. For a non-negative integer index the pass stops as
        soon as the requested row is reached; negative indices and slices still need the full pass.

        Raises:
            IndexError: If `index` is out of range.
        """
        rows = self.mind_dataset.create_tuple_iterator()
        if isinstance(index, int) and index >= 0:
            for position, row in enumerate(rows):
                if position == index:
                    return row
            raise IndexError("dataset index out of range")
        return list(rows)[index]

    def __len__(self):
        """Return the size reported by the underlying MindRecord dataset."""
        return self.mind_dataset.get_dataset_size()