离线构建自定义算子

查看源文件

概述

MindSpore Lite的转换工具除了基本的模型转换功能之外,还支持用户对模型进行自定义的优化与构建,生成用户自定义算子的模型。

我们提供了一套注册机制,允许用户基于转换工具进行能力扩展:包括节点解析扩展、模型解析扩展以及图优化扩展,用户可以根据自身的需要对模型实现自定义的解析与融合优化。

节点解析扩展:用户自定义模型中某一节点的解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考NodeParserNodeParserRegistry。 模型解析扩展:用户自定义模型的整个解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考ModelParserModelParserRegistry。 图优化扩展:模型解析之后,将获得MindSpore定义的图结构,用户可基于此结构自定义图的优化过程。接口可参考PassBasePassPositionPassRegistry

节点解析扩展需要依赖flatbuffers和protobuf及三方框架的序列化文件,并且flatbuffers和protobuf需要与发布件采用的版本一致,序列化文件需保证兼容发布件采用的序列化文件。发布件中不提供flatbuffers、protobuf及序列化文件,用户需自行编译,并生成序列化文件。用户可以从MindSpore仓中获取flatbuffersprobobufONNX原型文件CAFFE原型文件TF原型文件TFLITE原型文件

MindSpore Lite还提供了一系列的注册宏,以便于用户侧的扩展接入转换工具。注册宏包括节点解析注册REG_NODE_PARSER、模型解析注册REG_MODEL_PARSER、图优化注册REG_PASS、图优化调度注册REG_SCHEDULED_PASS

MindSpore Lite转换工具的扩展能力,目前仅支持Linux系统。

本章节将通过MindSpore Lite转换工具扩展功能的示例程序,涵盖节点扩展案例、优化扩展案例以及编译链接全流程,来使用户能够快速了解转换工具的扩展功能的使用。

模型解析扩展,鉴于是模块化的扩展能力,本章不做详细介绍,但会提供一个简化的单元案例,以供用户参考。

本章节以add.tflite模型为例。该模型仅包含一个简单的Add算子,通过自定义的节点解析、图优化,将Add算子转化为Custom算子,最终输出Custom单算子模型。

相关代码放置在mindspore/lite/examples/converter_extend目录。

节点扩展

  1. 自定义节点解析:用户需继承NodeParser,继而根据不同的框架,选择不同的重载接口。

  2. 节点解析注册:用户调用注册接口REG_NODE_PARSER,完成自定义的节点解析接入转换工具。

class AddParserTutorial : public NodeParser {  // 继承基类
 public:
  AddParserTutorial() = default;
  ~AddParserTutorial() = default;
  ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,            // 重载接口
                         const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                         const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>());     // 调用注册接口

示例代码请参考node_parser

模型扩展

示例代码请参考MindSpore仓的模型扩展的单元案例ModelParserRegistryTest

优化扩展

  1. 自定义优化:用户需继承PassBase,重载Execute接口函数Execute

  2. 优化注册:调用优化的注册接口REG_PASS,完成自定义把用户自己实现的Pass类注册进MindSpore Lite里。

class PassTutorial : public registry::PassBase {  // 继承基类
 public:
  PassTutorial() : PassBase("PassTutorial") {}

  ~PassTutorial() = default;

  bool Execute(const api::FuncGraphPtr &func_graph) override;     // 重载接口

 private:
  AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
};

using mindspore::registry::POSITION_BEGIN;            // 选择调度位置
REG_PASS(PassTutorial, opt::PassTutorial)             // 注册扩展类
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})  // 注册调度逻辑

示例代码可参考pass

在离线转换阶段,我们会对模型的每一个节点的输出张量进行推断,包括输出张量的Format、DataType以及Shape,因此,离线转换阶段,用户需提供自己实现的算子的推断过程,这里用户可以参考算子Infershape扩展说明,示例代码可参考infer

示例演示

编译

  • 环境要求

    • 系统环境:Linux x86_64,推荐使用Ubuntu 18.04.02LTS

    • 编译依赖:

  • 编译准备

    MindSpore Lite的发布件不会提供其他框架下的序列化文件,因此,用户需自行编译获得,请参考概述

    本示例采用的是tflite模型,用户需编译flatbuffers,从MindSpore仓中获取TFLITE原型文件,最终生成tflite的序列化文件。

    mindspore/lite/examples/converter_extend目录下创建schema文件目录,继而将生成的序列化文件置于schema目录下。

  • 编译构建

    mindspore/lite/examples/converter_extend目录下执行build.sh,将自动下载MindSpore Lite发布件并编译Demo。

    bash build.sh
    

    若使用该build脚本下载MindSpore Lite发布件失败,请手动下载硬件平台为CPU、操作系统为Ubuntu-x64的MindSpore Lite发布件mindspore-lite-{version}-linux-x64.tar.gz,将解压后tools/converter/lib目录、tools/converter/include目录拷贝到mindspore/lite/examples/converter_extend目录下。

    通过手动下载并且将文件放到指定位置后,需要再次执行build.sh脚本才能完成编译构建。

  • 编译输出

    mindspore/lite/examples/converter_extend/build目录下生成了libconverter_extend_tutorial.so的动态库。

执行程序

  1. 拷贝动态库

    将生成的libconverter_extend_tutorial.so动态库文件拷贝到发布件的tools/converter/lib下。

  2. 进入发布件的转换目录

    cd ${PACKAGE_ROOT_PATH}/tools/converter/converter
    
  3. 创建converter的配置文件(converter.cfg),文件内容如下:

    [registry]
    plugin_path=libconverter_extend_tutorial.so      # 用户请配置动态库的正确路径
    
  4. 将转换工具需要的动态链接库加入环境变量LD_LIBRARY_PATH

    export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/tools/converter/lib
    
  5. 执行converter

    ./converter_lite --fmk=TFLITE --modelFile=add.tflite --configFile=converter.cfg --outputFile=add_extend
    

执行完后,将生成名为add_extend.ms的模型文件,文件路径由参数outputFile决定。