数据集切分

查看源文件

概述

在进行分布式训练时,以图片数据为例,当单张图片的大小过大时,如遥感卫星等大幅面图片,即使一张图片也过大,需要对图片进行切分,每张卡读取一部分图片, 进行分布式训练。处理数据集切分的场景,需要配合模型并行一起才能达到预期的降低显存的效果,因此,基于自动并行提供了该项功能。本教程使用的样例为ResNet50,不是大幅面的网络,仅作示例。真实应用到大幅面的网络时,往往需要详细设计并行策略。

操作实践

样例代码说明

目录结构如下:

└─sample_code
    ├─distributed_training
    │      rank_table_16pcs.json
    │      rank_table_8pcs.json
    │      rank_table_2pcs.json
    │      resnet.py
    │      resnet50_distributed_training_dataset_slice.py
    │      run_dataset_slice.sh

创建数据集

数据集切分仅支持全/半自动模式,在数据并行模式下不涉及。

使用数据集切分时,需要同时调用数据集的SlicePatches接口去构造数据集,并且,为了保证各卡读入数据一致,需要对数据集固定随机数种子。

数据集定义部分如下。

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore.communication import init, get_rank, get_group_size

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
init()
ds.config.set_seed(1000) # set dataset seed to make sure that all cards read the same data
def create_dataset(data_path, repeat_num=1, batch_size=32, slice_h_num=1, slice_w_num=1):
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    rank_id = get_rank()

    # create a full dataset before slicing
    data_set = ds.Cifar10Dataset(data_path, shuffle=True)

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = transforms.TypeCast(ms.int32)

    c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op]

    # apply map operations on images
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    # in random map function, using num_parallel_workers=1 to avoid the dataset random seed not working.
    data_set = data_set.map(operations=c_trans, input_columns="image", num_parallel_workers=1)
    # slice image
    slice_patchs_img_op = vision.SlicePatchs(slice_h_num, slice_w_num)
    img_cols = ['img' + str(x) for x in range(slice_h_num * slice_w_num)]
    data_set = data_set.map(operations=slice_patchs_img_op, input_columns="image", output_columns=img_cols,
                            column_order=[img_cols[rank_id % (slice_h_num * slice_w_num)], "label"])
    # change hwc to chw
    data_set = data_set.map(operations=changeswap_op, input_columns=img_cols[rank_id % (slice_h_num * slice_w_num)])
    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set

配置数据集切分策略

数据集切分仅支持全/半自动模式,在数据并行模式下不涉及。

mindspore.auto_parallel_context中提供了dataset_strategy选项,为数据集配置切分策略。

dataset_strategy接口还有以下几点限制:

  1. 每个输入至多允许在一维进行切分。如支持set_auto_parallel_context(dataset_strategy=((1, 1, 1, 8), (8,))))或者dataset_strategy=((1, 1, 1, 8), (1,))),每个输入都仅仅切分了一维;但是不支持dataset_strategy=((1, 1, 4, 2), (1,)),其第一个输入切分了两维。

  2. 维度最高的一个输入,切分的数目,一定要比其他维度的多。如支持dataset_strategy=((1, 1, 1, 8), (8,)))或者dataset_strategy=((1, 1, 1, 1), (1,))),其维度最多的输入为第一个输入,切分份数为8,其余输入切分均不超过8;但是不支持dataset_strategy=((1, 1, 1, 1), (8,)),其维度最多的输入为第一维,切分份数为1,但是其第二个输入切分份数却为8,超过了第一个输入的切分份数。

import os
import mindspore as ms
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, gradients_mean=True)
slice_h_num = 1
slice_w_num = 8
batch_size = 256
ms.set_auto_parallel_context(dataset_strategy=(((1, 1, slice_h_num, slice_w_num), (1,))))
data_path = os.getenv('DATA_PATH')
dataset = create_dataset(data_path, batch_size=batch_size, slice_h_num=slice_h_num, slice_w_num=slice_w_num)

运行代码

上述流程的数据,代码和执行过程,可以参考:https://www.mindspore.cn/tutorials/experts/zh-CN/r1.8/parallel/train_ascend.html#单机多卡训练。差异点在于,将执行脚本更改为run_dataset_slice.sh。