# Source code for mindspore.dataset.engine.obs.obs_mindrecord_dataset

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
The dataset module provides the internal Dataset API which loads MindRecord files from OBS.
"""


import math
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.managers import SyncManager
import os
import queue
import random
import sys
import time

from mindspore import log as logger
from ..datasets import Shuffle
from ...core.config import set_seed


class _Manager(SyncManager):
    pass


def _get_manager():
    """
    Create, start and return a ``_Manager`` that can serve ``queue.PriorityQueue`` objects.

    ``PriorityQueue`` is registered on ``_Manager`` on every call. Objects created through
    ``manager.PriorityQueue()`` are therefore proxies hosted by the manager's server.

    Returns:
        _Manager: The already started manager.
    """
    _Manager.register("PriorityQueue", queue.PriorityQueue)
    manager = _Manager()
    manager.start()
    return manager


def _init_cache_and_working_queue(cache, q, shard_files, local_path):
    """
    Initialize the download queue and the local cache that tracks local dataset file status.

    Each shard file is handed to ``init_cache_and_queue`` together with its position in
    ``shard_files``. That position is used as the file's index and priority.

    Args:
        cache: Shared mapping that records the status of each dataset file.
        q: Shared download queue.
        shard_files (list): ``(remote_file, start, end, is_full_dataset)`` tuples.
        local_path (str): Local directory where dataset files are stored.

    Returns:
        tuple: The same ``(cache, q)`` objects that were passed in.
    """
    from .util import init_cache_and_queue

    for position, (remote_file, _, _, full_dataset) in enumerate(shard_files):
        file_name = os.path.basename(remote_file)
        init_cache_and_queue(cache, q, os.path.join(local_path, file_name), remote_file,
                             position, full_dataset, lock_file=file_name)
    return cache, q


def _remove_unused_dataset(local_path, num_shards, shard_id, epoch_num):
    """
    Rank (rank_id mod 8 equal to 0) removes all dataset files.

    The function only acts in distributed runs with more than 8 shards, and only on ranks
    where ``shard_id % 8 == 0``. Such a rank proceeds in three steps:

    1. Wait until ``/cache/sync_data/<epoch_num>`` holds at least ``min(num_shards - 1, 7)``
       entries, i.e. the ready files written by the other ranks.
    2. Delete every file in ``local_path`` whose name does not end with ``.db``.
    3. Empty the sync directory.

    Args:
        local_path (str): Local directory holding the dataset files.
        num_shards (int): Total number of shards; falsy outside distributed training.
        shard_id (int): Shard (rank) ID of the caller.
        epoch_num (int): Epoch number, used to name the sync directory.
    """
    from .config_loader import config

    # With 8 shards or fewer, a single node (server) is assumed and nothing is removed.
    if not num_shards or num_shards <= 8 or shard_id % 8 != 0:
        return

    sync_dir = '/cache/sync_data/' + str(epoch_num)
    expected_ready = min(num_shards - 1, 7)
    while not (os.path.exists(sync_dir) and len(os.listdir(sync_dir)) >= expected_ready):
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Shard: {} wait for other rank ready in epoch: {}.".format(
            sys._getframe().f_code.co_name, shard_id, epoch_num))  # pylint: disable=W0212

    # Files ending with '.db' are kept (presumably MindRecord index/meta files - TODO confirm).
    for entry in os.listdir(local_path):
        if not entry.endswith('.db'):
            os.remove(os.path.join(local_path, entry))

    # Emptying the sync directory releases the ranks waiting in _wait_remove_datset.
    for entry in os.listdir(sync_dir):
        os.remove(os.path.join(sync_dir, entry))


def _wait_remove_datset(num_shards, shard_id, epoch_num):
    """
    Rank (rank_id mod 8 not equal to 0) waits for the dataset files to be removed.

    The function only acts in distributed runs with more than 8 shards, and only on ranks
    where ``shard_id % 8 != 0``. Such a rank does two things:

    1. Write ``ready_<shard_id>`` into ``/cache/sync_data/<epoch_num>``.
    2. Poll until that directory exists and is empty again.

    Args:
        num_shards (int): Total number of shards; falsy outside distributed training.
        shard_id (int): Shard (rank) ID of the caller.
        epoch_num (int): Epoch number, used to name the sync directory.
    """
    from .config_loader import config

    if not num_shards or num_shards <= 8 or shard_id % 8 == 0:
        return

    sync_dir = '/cache/sync_data/' + str(epoch_num)

    if not os.path.exists(sync_dir):
        try:
            os.makedirs(sync_dir)
        except FileExistsError:
            # Presumably created concurrently by another rank on the same node.
            pass

    with open(os.path.join(sync_dir, 'ready_' + str(shard_id)), 'w') as ready_file:
        ready_file.write('ok')

    # The directory becomes empty once the removing rank has deleted every ready file.
    while not (os.path.exists(sync_dir) and not os.listdir(sync_dir)):
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Shard: {} wait for removing dataset files in epoch: {}.".format(
            sys._getframe().f_code.co_name, shard_id, epoch_num))  # pylint: disable=W0212


def _init_shard_files(dataset_files, shuffle, seed, num_shards, shard_id, shard_equal_rows,
                      size_per_shard, local_path, current_epoch):
    """
    Calculate the dataset files required by each sharding and the corresponding index.

    Unless ``shuffle`` is ``False`` or ``Shuffle.INFILE``, ``dataset_files`` is shuffled
    **in place** after calling ``set_seed(seed)``.

    Args:
        dataset_files (list): Remote dataset file paths.
        shuffle: ``False``, a ``Shuffle`` member, or another truthy mode.
        seed (int): Seed applied before shuffling the file list.
        num_shards (int): Number of shards; falsy means standalone training.
        shard_id (int): Shard (rank) ID of the caller.
        shard_equal_rows (bool): Whether every shard gets the same number of samples.
        size_per_shard (int): Previously computed samples per shard, or ``None``.
        local_path (str): Local directory used for meta files.
        current_epoch (int): Epoch number, used for logging.

    Returns:
        tuple: ``(shard_files, size_per_shard)``. ``shard_files`` is a list of
        ``(dataset_file, start, end, is_full_dataset)`` tuples in standalone mode; in
        distributed mode it is whatever the util helpers return.
    """
    from .config_loader import config
    from .util import detect_all_meta_files, fetch_meta_files, make_dataset_tuple, make_shard_files, make_shard_samples

    if not (shuffle is False or shuffle == Shuffle.INFILE):
        set_seed(seed)
        random.shuffle(dataset_files)

    if not num_shards:
        # Standalone training: every file is consumed in full.
        shard_files = [(dataset_file, -1, -1, True) for dataset_file in dataset_files]
    elif not shard_equal_rows:
        shard_files = make_shard_files(dataset_files, num_shards, shard_id)
    else:
        # Equal rows per shard requires the meta files of every dataset file.
        if size_per_shard is None:
            if shard_id % 8 == 0:
                fetch_meta_files(dataset_files, local_path)
            else:
                while detect_all_meta_files(dataset_files, local_path) is False:
                    time.sleep(config.WAIT_META_TIME)
        full_dataset_size, dataset_file_size_list = make_dataset_tuple(dataset_files, local_path)
        size_per_shard = math.ceil(full_dataset_size / num_shards)
        shard_files = make_shard_samples(dataset_file_size_list, size_per_shard, shard_id)

    logger.info("[{} FUNCTION] Shard: {} expect dataset: {} in epoch: {}.".format(
        sys._getframe().f_code.co_name, shard_id, shard_files, current_epoch))  # pylint: disable=W0212
    return shard_files, size_per_shard


def _download_work(shard_id, current_idx, local_path, cache, q):
    """
    Background download loop (daemon work); it has no exit path other than an exception.

    Each iteration does the following:

    1. Take the next ``(idx, dataset_file)`` request from ``q``.
    2. While the disk usage reported by ``get_used_disk_per`` exceeds
       ``config.DISK_THRESHOLD``, free space with ``_delete_candidate_datasets``.
    3. Download the file from OBS into ``local_path``.
    4. Store ``idx`` for it in ``cache``, so readers see it as available.

    Args:
        shard_id (int): Shard (rank) ID, only used for logging.
        current_idx: Shared value whose ``.value`` is the index of the file being consumed.
        local_path (str): Local directory that receives the downloaded files.
        cache: Shared mapping from dataset file basename to ``(idx, is_shared)``.
        q: Shared queue of ``(idx, dataset_file)`` download requests.
    """
    from .config_loader import config
    from .util import try_load_from_obs, get_used_disk_per

    # Remote directory of every dataset file seen on the queue, keyed by basename.
    # _delete_candidate_datasets pushes evicted files back as ``(idx, basename)``.
    # Without this map, the re-download would be issued against an empty remote path.
    remote_dirs = {}

    while True:
        idx, dataset_file = q.get()
        used_disk = get_used_disk_per()
        while used_disk > float(config.DISK_THRESHOLD):
            logger.info("[{} FUNCTION] Used disk space is {}%, and the disk threshold is {}%.".format(
                sys._getframe().f_code.co_name, used_disk*100,  # pylint: disable=W0212
                float(config.DISK_THRESHOLD)*100))
            retry_cnt = 0
            has_deleted = _delete_candidate_datasets(
                current_idx.value, idx, cache, q, local_path)
            while not has_deleted:
                if retry_cnt > config.MAX_RETRY:
                    logger.warning("Delete operation retries times {} has exceeded threshold {}, "
                                   "please clear enough disk space.".format(retry_cnt, config.MAX_RETRY))
                has_deleted = _delete_candidate_datasets(
                    current_idx.value, idx, cache, q, local_path)
                retry_cnt += 1
                time.sleep(config.RETRY_DELTA_TIME)
            used_disk = get_used_disk_per()

        logger.info("[{} FUNCTION] Shard: {} try to download: {}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset_file))  # pylint: disable=W0212
        remote_path = os.path.dirname(dataset_file)
        dataset_file = os.path.basename(dataset_file)
        if remote_path:
            remote_dirs[dataset_file] = remote_path
        elif dataset_file in remote_dirs:
            # Re-queued by basename only: restore the directory it was first seen with.
            remote_path = remote_dirs[dataset_file]
        else:
            logger.warning("[{} FUNCTION] Shard: {} cannot resolve the remote directory of: {}.".format(
                sys._getframe().f_code.co_name, shard_id, dataset_file))  # pylint: disable=W0212
        # update cache
        _, is_shared = cache[dataset_file]
        try_load_from_obs(remote_path, dataset_file, local_path)
        cache[dataset_file] = (idx, is_shared)
        logger.info("[{} FUNCTION] Shard: {} finish to download: {}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset_file))  # pylint: disable=W0212


def _delete_candidate_datasets(current_idx, queue_top_idx, cache, q, local_path):
    """
    1. Try to delete all the datasets which have been loaded during the epoch.
    2. Otherwise, try to delete a low priority dataset in the epoch.
    3. As soon as the low priority data is deleted, it is placed in the download queue.
    """

    used_datasets = []
    low_priority_dataset = ''
    max_idx = -1
    delete = False
    for k, v in cache.items():
        idx, is_shared = v
        if is_shared is False and idx >= 0:
            if idx > max_idx:
                max_idx = idx
                low_priority_dataset = k
            if idx < current_idx:
                used_datasets.append(k)
    for used_dataset in used_datasets:
        dataset_path = os.path.join(local_path, used_dataset)
        if not os.path.exists(dataset_path):
            continue
        # update cache
        idx, is_shared = cache[used_dataset]
        cache[used_dataset] = (-1, is_shared)
        os.remove(dataset_path)
        delete = True
        logger.info("[{} FUNCTION] Delete used dataset file: {} and update the cache.".format(
            sys._getframe().f_code.co_name, used_dataset))  # pylint: disable=W0212

    if delete:
        return True
    if max_idx <= current_idx or max_idx <= queue_top_idx:
        return False
    dataset_path = os.path.join(local_path, low_priority_dataset)
    if not os.path.exists(dataset_path):
        return False
    # update cache
    idx, is_shared = cache[low_priority_dataset]
    cache[low_priority_dataset] = (-1, is_shared)
    os.remove(dataset_path)
    q.put((idx, low_priority_dataset))
    logger.info("[{} FUNCTION] Delete low priority dataset file: {} and update the cache.".format(
        sys._getframe().f_code.co_name, low_priority_dataset))  # pylint: disable=W0212
    return True


def _sync_up_for_obs_mindrecord_dataset(rank_id, current_epoch):
    """
    Upload this rank's synchronization (ready) file to ``config.SYNC_OBS_PATH``.

    The file ``download_dataset_ready_<rank_id>.txt`` is uploaded into the directory
    ``<BATCH_JOB_ID>/<current_epoch>/``. The job ID falls back to ``'unknown'`` when the
    environment variable is not set.

    Args:
        rank_id (int): Rank ID of the device.
        current_epoch (int): Current epoch number.
    """
    from .config_loader import config
    from .util import file_upload_to_obs

    job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
    epoch_dir = os.path.join(job_id, str(current_epoch) + "/")
    file_name = "".join(["download_dataset", "_ready_", str(rank_id), ".txt"])

    file_upload_to_obs(config.SYNC_OBS_PATH, epoch_dir, file_name)
    logger.info("[{} FUNCTION] Current rank:{}'s sync file:{} is ready for epoch:{}.".format(
        sys._getframe().f_code.co_name, rank_id, os.path.join(epoch_dir, file_name),  # pylint: disable=W0212
        current_epoch))


def sync_wait_for_dataset(rank_id, rank_size, current_epoch):
    """
    Wait until the dataset files required by all devices are downloaded.

    Note:
        It should be used together with `mindspore.dataset.OBSMindDataset` and be called before each epoch.

    Args:
        rank_id(int): Rank ID of the device.
        rank_size(int): Rank size.
        current_epoch(int): Number of current epochs.

    Examples:
        >>> # Create a synchronization callback
        >>> import mindspore as ms
        >>> from mindspore.dataset import sync_wait_for_dataset
        >>>
        >>> class SyncForDataset(ms.Callback):
        ...     def __init__(self):
        ...         super(SyncForDataset, self).__init__()
        ...     def epoch_begin(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         epoch_num = cb_params.cur_epoch_num
        ...         sync_wait_for_dataset(rank_id, rank_size, epoch_num)
    """
    from .config_loader import config
    from .util import obsClient, get_bucket_and_key

    bucket_name, object_key = get_bucket_and_key(config.SYNC_OBS_PATH)
    job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
    ready_dir = os.path.join(object_key, job_id, str(current_epoch) + "/")

    # Initialized before the loop: the progress log reads it even when a listing failed.
    # Previously, a first failed listing raised UnboundLocalError here.
    ready_num = 0
    while True:
        try:
            # no guarantee that the dir is included.
            resp = obsClient.listObjects(bucket_name, prefix=ready_dir)
            if resp.status < 300:
                ready_num = sum(1 for content in resp.body.contents if content.key.endswith(".txt"))
                if ready_num >= rank_size:
                    # Every rank has uploaded its ready file: no need to sleep again.
                    break
            else:
                logger.warning("[{} FUNCTION] OBS SDK errorCode:{}, errMsg: {}.".format(
                    sys._getframe().f_code.co_name, resp.errorCode, resp.errorMessage))  # pylint: disable=W0212
        except Exception:  # pylint: disable=W0703
            # Best-effort polling: log and retry on any SDK/network failure.
            import traceback
            logger.error(traceback.format_exc())
        time.sleep(config.RETRY_DELTA_TIME)
        logger.info("[{} FUNCTION] Waiting for sync dir:{} and current_rank:{}, total_rank:{}, "
                    "ready_rank:{} in epoch:{}.".format(sys._getframe().f_code.co_name,  # pylint: disable=W0212
                                                         ready_dir, rank_id, rank_size, ready_num, current_epoch))
    logger.info("[{} FUNCTION] Succeed to sync dir:{} and begin epoch:{}.".format(
        sys._getframe().f_code.co_name, ready_dir, current_epoch))  # pylint: disable=W0212
def _sync_for_obs_mindrecord_dataset(worker, shard_files, cache, num_shards, shard_id, current_epoch):
    """
    Synchronize all shardings.

    The function proceeds in three steps:

    1. Poll until the last file of ``shard_files`` has a non-negative index in ``cache``.
       While polling, ``worker.get()`` is called once the worker has finished; this
       re-raises any exception it ended with.
    2. Upload this rank's ready file.
    3. Block in ``sync_wait_for_dataset`` until ``num_shards`` ranks are ready.

    Args:
        worker: Async result of the background download work.
        shard_files (list): ``(dataset_file, start, end, is_full_dataset)`` tuples for this shard.
        cache: Shared mapping from dataset basename to ``(idx, is_shared)``.
        num_shards (int): Total number of shards.
        shard_id (int): Shard (rank) ID of the caller.
        current_epoch (int): Current epoch number.
    """
    from .config_loader import config

    while True:
        if worker.ready():
            worker.get()
        dataset, _, _, _ = shard_files[-1]
        current_dataset = os.path.basename(dataset)
        hit_cache = cache[current_dataset][0]
        if hit_cache >= 0:  # hit cache
            logger.info("[{} FUNCTION] Current_rank:{} has download:{} for epoch:{}.".format(
                sys._getframe().f_code.co_name, shard_id, dataset, current_epoch))  # pylint: disable=W0212
            _sync_up_for_obs_mindrecord_dataset(shard_id, current_epoch)
            break
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Current_rank:{} wait for downloading:{} in epoch:{}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset, current_epoch))  # pylint: disable=W0212

    sync_wait_for_dataset(shard_id, num_shards, current_epoch)


class MindRecordFromOBS:
    """
    Internal iterable that loads remote MindRecord dataset files from OBS.

    A single background thread downloads this shard's files into ``local_path``. Iteration
    then walks the shard files in order and builds a ``MindDataset`` pipeline for each file
    once it is available locally.
    """

    def __init__(self, dataset_files, columns_list, shuffle, num_shards, shard_id, shard_equal_rows, local_path):
        self._dataset_files = dataset_files
        self._columns_list = columns_list
        self._num_shards = num_shards
        self._shard_id = shard_id
        self._shard_equal_rows = shard_equal_rows
        self._local_path = os.path.realpath(local_path)
        self._shuffle = Shuffle.GLOBAL if shuffle is True else shuffle

        from .config_loader import config
        self._epoch_seed = config.SEED
        self._file_seed = config.SEED
        self._size_per_shard = None
        self._curr_epoch = 1
        self._curr_step = 1
        self._shard_files, self._size_per_shard = _init_shard_files(self._dataset_files, self._shuffle,
                                                                    self._epoch_seed, self._num_shards,
                                                                    self._shard_id, self._shard_equal_rows,
                                                                    self._size_per_shard, self._local_path,
                                                                    self._curr_epoch)

        # Queue, cache and current index are manager proxies, shared with the download thread.
        m = _get_manager()
        self._queue = m.PriorityQueue()
        self._cache = m.dict()
        self._index = 0
        self._current_idx = m.Value('i', self._index)
        self._cache, self._queue = _init_cache_and_working_queue(
            self._cache, self._queue, self._shard_files, self._local_path)

        self._first_epoch = True
        self._iteration = None
        self._cache_miss_times = 0
        self._pool = ThreadPool(processes=1)
        self._worker = self._pool.apply_async(
            _download_work, (self._shard_id, self._current_idx, self._local_path, self._cache, self._queue))
        _sync_for_obs_mindrecord_dataset(self._worker, self._shard_files, self._cache, self._num_shards,
                                         self._shard_id, self._curr_epoch)

    def __next__(self):
        from .config_loader import config
        from ..datasets_standard_format import MindDataset
        from .util import make_sampler

        # Loop instead of recursing into next(self). A long streak of cache misses
        # (waiting for a slow download) previously added one stack frame per retry.
        while True:
            if self._iteration:
                try:
                    self._curr_step += 1
                    return next(self._iteration)
                except StopIteration:
                    # Current dataset file is exhausted: advance to the next one.
                    self._index += 1
                    self._current_idx.value = self._index
                    self._iteration = None
                    if self._index >= len(self._shard_files):
                        # End of epoch.
                        self._first_epoch = False
                        self._curr_epoch += 1
                        self._curr_step = 0
                        raise StopIteration
                continue

            f, start, end, is_full_dataset = self._shard_files[self._index]
            current_dataset = os.path.basename(f)
            hit_cache = self._cache[current_dataset][0]
            if hit_cache >= 0:  # hit cache
                self._cache_miss_times = 0
                # launch pipeline
                set_seed(self._file_seed)
                sampler = make_sampler(self._shuffle, is_full_dataset, start, end)
                self._file_seed += 1
                path = os.path.join(self._local_path, current_dataset)
                logger.info("[{} FUNCTION] Shard:{} start to load dataset:{} in epoch:{}.".format(
                    sys._getframe().f_code.co_name, self._shard_id, path, self._curr_epoch))  # pylint: disable=W0212
                self._iteration = MindDataset(dataset_files=[path], columns_list=self._columns_list,
                                              sampler=sampler,
                                              shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)
            else:
                # cache miss: back off a little longer after each consecutive miss
                self._cache_miss_times += 1
                logger.info("[{} FUNCTION] Cache miss in shard {} for times {}, expect dataset {}.".format(
                    sys._getframe().f_code.co_name, self._shard_id, self._cache_miss_times,  # pylint: disable=W0212
                    current_dataset))
                time.sleep(self._cache_miss_times * config.WAIT_STEP_TIME)

    def __iter__(self):
        if self._first_epoch:
            # The first epoch was already prepared by __init__.
            self._index = 0
            self._current_idx.value = self._index
            self._iteration = None
            return self

        self._index = 0
        self._current_idx.value = self._index
        self._epoch_seed += 1
        self._iteration = None

        self._shard_files, self._size_per_shard = _init_shard_files(self._dataset_files, self._shuffle,
                                                                    self._epoch_seed, self._num_shards,
                                                                    self._shard_id, self._shard_equal_rows,
                                                                    self._size_per_shard, self._local_path,
                                                                    self._curr_epoch)
        self._cache.clear()

        # reset queue
        try:
            while True:
                self._queue.get_nowait()
        except queue.Empty:
            pass

        _remove_unused_dataset(self._local_path, self._num_shards, self._shard_id, self._curr_epoch)
        _wait_remove_datset(self._num_shards, self._shard_id, self._curr_epoch)

        self._cache, self._queue = _init_cache_and_working_queue(
            self._cache, self._queue, self._shard_files, self._local_path)
        _sync_for_obs_mindrecord_dataset(self._worker, self._shard_files, self._cache, self._num_shards,
                                         self._shard_id, self._curr_epoch)
        return self

    def __len__(self):
        from .util import fetch_meta_files, make_dataset_tuple

        if self._size_per_shard is None:
            # Not known yet (e.g. no shard_equal_rows): derive it from this shard's meta files.
            dataset_files = [dataset_file for dataset_file, _, _, _ in self._shard_files]
            fetch_meta_files(dataset_files, self._local_path)
            self._size_per_shard, _ = make_dataset_tuple(dataset_files, self._local_path)
        return self._size_per_shard

    def get_col_names(self):
        """ Get column names of Mindrecord format dataset."""
        from .config_loader import config
        from ..datasets_standard_format import MindDataset

        target_dataset = None
        while target_dataset is None:
            for f, _, _, _ in self._shard_files:
                current_dataset = os.path.basename(f)
                if self._cache[current_dataset][0] >= 0:
                    target_dataset = current_dataset
            if target_dataset is None:
                # No file is local yet: back off instead of busy-spinning on the shared cache.
                time.sleep(config.WARMINGUP_TIME)
        path = os.path.join(self._local_path, target_dataset)
        _iteration = MindDataset(dataset_files=[path], shuffle=False)
        return _iteration.get_col_names()