代码
基于MindSpore案例的香橙派开发板离线推理实践--使用Shufflenet模型对输入图片进行分类

基于MindSpore案例的香橙派开发板离线推理实践--使用Shufflenet模型对输入图片进行分类

基于MindSpore案例的香橙派开发板离线推理实践--使用Shufflenet模型对输入图片进行分类

样例介绍

功能:使用Shufflenet模型对输入图片进行分类。 样例输入:原始Cifar20数据集。 样例输出:图片分类的结果。

前期准备

基础镜像的样例目录中已包含转换后的om模型以及测试数据集,如果直接运行,可跳过此步骤。如果需要重新转换模型,可参考如下步骤:

利用atc工具将原始模型转换为om模型,转换命令如下:

atc --output_type=FP32 --input_format=NCHW--output="./shufflenet" --soc_version=Ascend310B4 --framework=1 --save_original_model=false --model="./shufflenet.air" --precision_mode=allow_fp32_to_fp16

其中各个参数具体含义如下:

  • --output_type:指定网络输出数据类型。
  • --input_format:输入Tensor的内存排列方式。
  • --output:输出的模型文件路径。
  • --soc_version:昇腾AI处理器型号。
  • --framework:原始框架类型, 0: Caffe, 1: MindSpore, 3: TensorFlow, 5: ONNX。
  • --save_original_model:转换后是否保留原始模型文件。
  • --model:原始模型文件路径。
  • --precision_mode:选择算子精度模式。

模型推理实现

1. 导入三方库

import numpy as np
import cv2
import matplotlib.pyplot as plt
import acl

import acllite_utils as utils
import constants as const
from acllite_model import AclLiteModel
from acllite_resource import resource_list

import mindspore as ms
import mindspore.dataset as ds
from mindspore.dataset import vision

2. 定义acllite资源初始化与去初始化类

class AclLiteResource:
    """
    AclLiteResource
    """
    def __init__(self, device_id=0):
        self.device_id = device_id
        self.context = None
        self.stream = None
        self.run_mode = None
        
    def init(self):
        """
        init resource
        """
        print("init resource stage:")
        ret = acl.init()

        ret = acl.rt.set_device(self.device_id)
        utils.check_ret("acl.rt.set_device", ret)

        self.context, ret = acl.rt.create_context(self.device_id)
        utils.check_ret("acl.rt.create_context", ret)

        self.stream, ret = acl.rt.create_stream()
        utils.check_ret("acl.rt.create_stream", ret)

        self.run_mode, ret = acl.rt.get_run_mode()
        utils.check_ret("acl.rt.get_run_mode", ret)

        print("Init resource success")

    def __del__(self):
        print("acl resource release all resource")
        resource_list.destroy()
        if self.stream:
            print("acl resource release stream")
            acl.rt.destroy_stream(self.stream)

        if self.context:
            print("acl resource release context")
            acl.rt.destroy_context(self.context)

        print("Reset acl device ", self.device_id)
        acl.rt.reset_device(self.device_id)
        print("Release acl resource success")

3. 定义分类类,包含前处理、推理、结果展示等操作

class Classification(object):
    """
    class for Classification
    """
    def __init__(self, model_path):
        self._model_path = model_path
        self.device_id = 0
        self._model = None

    def init(self):
        """
        Initialize
        """
        # Load model
        self._model = AclLiteModel(self._model_path)

        return const.SUCCESS

    def pre_process(self, input_path):
        # 读取原始Cifar10数据集
        dataset_predict = ds.Cifar10Dataset(dataset_dir=input_path, shuffle=False, num_samples=16)
        # 进行一系列的数据前处理操作
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
                ]
        dataset_predict = dataset_predict.map(image_trans, 'image')
        # 设置batch_size
        dataset_predict = dataset_predict.batch(1)
        data_loader = dataset_predict.create_dict_iterator()

        return data_loader
    

    def inference(self, data_loader):
        """
        model inference
        """  
        results = []
        for i, resized_image in enumerate(data_loader):
            # 读物数据集图片
            result1 = ms.Tensor(resized_image['image'])
            result = self._model.execute([result1.asnumpy(), ])
            
            # 选取概率最高的那个结果
            pred = np.argmax(result[0], axis=1)
            results.append(pred[0])
        
        return results, data_loader

    def show_process(self, infer_output, data_origin, input_path):
        """
        show process
        """
        dataset_predict = ds.Cifar10Dataset(dataset_dir=input_path, shuffle=False, num_samples=16)
        class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

        # 推理效果展示(上方为预测的结果,下方为推理效果图片)
        plt.figure(figsize=(16, 5))

        for i, resized_image in enumerate(dataset_predict.create_dict_iterator()):
            plt.subplot(2, 8, i+1)
            plt.title('{}'.format(class_dict[infer_output[0][i]]))
            plt.imshow(resized_image["image"].asnumpy())
            plt.axis("off")
        plt.show()

4. 构造主函数,串联整个代码逻辑

from download import download

# 获取数据集
dataset_url = "https://mindspore-courses.obs.cn-north-4.myhuaweicloud.com/orange-pi-mindspore/03-ResNet50/cifar-10-batches-bin.zip"
download(dataset_url, "./dataset", kind="zip", replace=True)

# 获取模型om文件
model_url = "https://mindspore-courses.obs.cn-north-4.myhuaweicloud.com/orange-pi-mindspore/06-ShuffleNet/shufflenet.zip"
download(model_url, "./", kind="zip", replace=True)


def main():
    # 模型
    MODEL_PATH = "shufflenet.om"
    # 数据集
    input_path = "./dataset/cifar-10-batches-bin"

    acl_resource = AclLiteResource()  # 初始化acl资源
    acl_resource.init()
    
    # Instantiation Classification object
    classification = Classification(MODEL_PATH)  # 构造模型对象
    
    # init
    ret = classification.init()  # 初始化模型类变量
    utils.check_ret("classification.init ", ret)  
    
    # preprocess
    crop_and_paste_image = classification.pre_process(input_path)  # 前处理
    
    # inference
    result = classification.inference(crop_and_paste_image)  # 推理

    # show
    classification.show_process(result, crop_and_paste_image, input_path)  # 结果展示

5. 运行

运行完成后,会显示推理后的图片,运行代码如下:

main()

6. 样例总结

我们来回顾一下以上代码,可以包括以下几个步骤:

  1. 初始化acl资源:在调用acl相关资源时,必须先初始化AscendCL,否则可能会导致后续系统内部资源初始化出错。
  2. 对图片进行前处理:在此样例中,我们首先根据图片路径,使用mindspore.dataset.Cifar10Dataset接口下加载CIFAR-10数据集,再利用mindspore.dataset.vision接口对数据集进行剪切、通道转换等操作,使得模型正常推理。然后基于数据集对象创建数据迭代器。
  3. 推理:利用AclLiteModel.execute接口对数据进行推理。
  4. 可视化推理结果:利用plt将结果画出。