on-device执行

Ascend GPU CPU 模型运行

image0image1image2

概述

MindSpore支持的后端包括Ascend、GPU、CPU,所谓On Device中的Device通常指Ascend(昇腾)AI处理器。

昇腾芯片上集成了AICORE、AICPU和CPU。其中,AICORE负责大型Tensor Vector运算,AICPU负责标量运算,CPU负责逻辑控制和任务分发。

Host侧CPU负责将图或算子下发到昇腾芯片。昇腾芯片由于具备了运算、逻辑控制和任务分发的功能,所以不需要与Host侧的CPU进行频繁的交互,只需要将计算完的最终结果返回给Host侧,实现整图下沉到Device执行,避免Host-Device频繁交互,减小了开销。

计算图下沉

计算图整图下沉到Device上执行,减少Host-Device交互开销。可以结合循环下沉实现多个Step下沉,进一步减少Host和Device的交互次数。

循环下沉是在On Device执行的基础上的优化,目的是进一步减少Host侧和Device侧之间的交互次数。通常情况下,每个step都返回一个结果,循环下沉是控制每隔多少个step返回一次结果。

默认配置下是每一个epoch返回一次结果,这样每个epoch里,Host侧和Device侧只需要进行一次数据交互。

也可以结合train接口的dataset_sink_modesink_size控制每个epoch的下沉数据量。

数据下沉

Modeltrain接口参数dataset_sink_mode可以控制数据是否下沉。dataset_sink_mode为True表示数据下沉,否则为非下沉。所谓下沉即数据通过通道直接传送到Device上。

dataset_sink_mode参数可以配合sink_size控制每个epoch下沉的数据量大小。当dataset_sink_mode设置为True,即数据下沉模式时:

如果sink_size为默认值-1,则每一个epoch下沉的数据量为原始的整个数据集大小;

如果sink_size>0,此时原始数据集可以被无限次遍历,每个epoch下沉sink_size大小的数据量,下一个epoch继续从上次遍历的结束位置继续遍历。

下沉的总数据量由epochsink_size两个变量共同控制,即总数据量=epoch*sink_size

当使用LossMonitorTimeMonitor或其它Callback接口时,如果dataset_sink_mode设置为False,Host侧和Device侧之间每个step交互一次,所以会每个step返回一个结果,如果dataset_sink_mode为True,因为数据在Device上通过通道传输, Host侧和Device侧之间每个epoch进行一次数据交互,所以每个epoch只返回一次结果。

当前CPU和PyNative模式不支持数据下沉。

代码样例如下:

[4]:
import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Model
from mindspore import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.dataset.vision import Inter
from mindspore.nn import Accuracy
import mindspore.ops as ops
from mindspore.train.callback import LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """
    Lenet network
    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """

    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = ops.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ds_train_path = "./datasets/MNIST_Data/train/"
    ds_train = create_dataset(ds_train_path, 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train(epoch=10, train_dataset=ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True, sink_size=1000)
============== Starting Training ==============
epoch: 1 step: 1000, loss is 0.110185064
epoch: 2 step: 1000, loss is 0.12088283
epoch: 3 step: 1000, loss is 0.15903473
epoch: 4 step: 1000, loss is 0.030054657
epoch: 5 step: 1000, loss is 0.013846226
epoch: 6 step: 1000, loss is 0.052161213
epoch: 7 step: 1000, loss is 0.0050197737
epoch: 8 step: 1000, loss is 0.17207858
epoch: 9 step: 1000, loss is 0.010310417
epoch: 10 step: 1000, loss is 0.000672762

batch_size为32的情况下,数据集的大小为1875,当sink_size设置为1000时,表示每个epoch下沉1000个batch的数据,下沉次数为epoch=10,下沉的总数据量为:epoch*sink_size=10000。

dataset_sink_mode为True,所以每个epoch返回一次结果。

dataset_sink_mode为False时,sink_size参数设置无效。