mindspore::kernel
接口汇总
类名 |
描述 |
---|---|
算子基类。 |
|
算子扩展能力基类。 |
|
Mindspore Kernel Mindspore算子基类。 |
|
IKernel 算子模板类。 |
Kernel
#include <kernel.h>
Kernel是算子实现的基类,定义了几个必须实现的接口。继承自IKernel。
构造函数
Kernel
Kernel()
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
Kernel的默认与带参构造函数,构造Kernel实例。
析构函数
~Kernel
virtual ~Kernel()
Kernel的析构函数。
公有成员函数
InferShape
virtual int InferShape()
在用户调用Model::Build
接口时,或是模型推理中需要推理算子形状时,会调用到该接口。
在自定义算子场景中,用户可以覆写该接口,实现自定义算子的形状推理逻辑。详见自定义算子章节。
在InferShape
函数中,一般需要实现算子的形状、数据类型和数据排布的推理逻辑。
返回值
是否成功。
type
virtual schema::PrimitiveType type()
返回算子的类型。
quant_type
virtual schema::QuantType quant_type()
返回算子的量化类型。
KernelInterface
#include <kernel_interface.h>
算子扩展能力基类。
~KernelInterface
virtual ~KernelInterface() = default;
析构函数。
KernelInterfaceCreator
using KernelInterfaceCreator = std::function<std::shared_ptr<KernelInterface>()>
创建KernelInterface的函数原型声明。
公有成员函数
Infer
算子用于推测输出shape的方法。
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive)
Infer
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive, const Kernel *kernel)
参数
推测自定义算子的输出shape。
MSKernel
#include <kernel.h>
Mindspore Kernel 算子类。是IKernel的父类。
构造函数
MSKernel() = default;
MSKernel(std::vector<mindspore::MSTensor> inputs, std::vector<mindspore::MSTensor> outputs,
const mindspore::Context *ctx)
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}
参数
inputs
: 输入。output
: 输出。ctx
: 算子对应Context。
析构函数
virtual ~MSKernel() = default;
公有成员函数
InferShape
virtual int InferShape()
在用户调用Model::Build
接口时,或是模型推理中需要推理算子形状时,会调用到该接口。
在自定义算子场景中,用户可以覆写该接口,实现自定义算子的形状推理逻辑。详见自定义算子章节。
在InferShape
函数中,一般需要实现算子的形状、数据类型和数据排布的推理逻辑。
返回值
状态码。
Prepare
virtual int Prepare() = 0;
进行算子运行前相关的准备工作,MindSpore Lite 框架运行时会对所有算子执行一遍Prepare后再执行Execute。
返回值
状态码。
Execute
virtual int Execute() = 0;
运行算子。
返回值
状态码。
ReSize
virtual int ReSize() = 0;
在用户调用Model::Resize
接口时,或是模型推理中需要重新推理算子形状时,会调用到该接口。
在ReSize
函数中,若有必要,根据输入的形状态重新推理输出形状,并分配算子运算中需要的内存。
返回值
状态码。
set_inputs
virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }
设置算子的输入列表。
参数
in_tensors
: 算子的所有输入MSTensor列表。
set_input
virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }
设置算子指定位置的输入。
参数
in_tensor
: 算子的输入MSTensor。index
: 算子输入在所有输入中的下标,从0开始计数。
set_outputs
virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }
设置算子的输出列表。
参数
out_tensor
: 算子的所有输出MSTensor列表。
set_output
virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }
设置算子指定位置的输出。
参数
out_tensor
: 算子的输出MSTensor。index
: 算子输出在所有输出中的下标,从0开始计数。
inputs
virtual const std::vector<mindspore::MSTensor *> &inputs()
返回算子的所有输入MSTensor列表。
返回值
算子的所有输入。
outputs
virtual const std::vector<mindspore::MSTensor *> &outputs()
返回算子的所有输出MSTensor列表。
返回值
算子的所有输出。
name
std::string name() const
返回算子的名称。
返回值
算子的名称。
set_name
void set_name(const std::string &name)
设置算子的名称。
参数
name
: 算子名称。
context
const lite::Context *context() const
返回算子对应的Context。
返回值
算子的Context。
GetAttr
std::string GetAttr(const std::string &key) const
获取指定配置名对应的配置。
参数
key
: 配置名。
SetConfig
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config)
保存配置内容的常量指针到kernel里,该接口当前是由框架在加载配置文件时自动触发调用的,不建议用户使用。
参数
config
: 配置的常量指针。
GetConfig
std::map<std::string, std::string> GetConfig(const std::string §ion) const
获取指定章节名对应的配置。
参数
section
: 配置的章节名称。
保护成员变量
std::string name_;
const mindspore::Context *context_ = nullptr;
std::vector<mindspore::MSTensor> inputs_;
std::vector<mindspore::MSTensor> outputs_;
std::map<std::string, std::string> attrs_;
const std::map<std::string, std::map<std::string, std::string>> *config_ = nullptr;
name_: 名字
context_: 训练context。
inputs_: 输入。
outputs_: 输出。
attrs_: 属性。
config_: 配置。
保护成员函数
void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }
设置算子的属性。
参数
key
: 属性键。value
: 属性值。
IKernel
继承自MSKernel
构造函数
IKernel() = default;
IKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const Primitive *primitive, const mindspore::Context *ctx)
: MSKernel(inputs, outputs, ctx), primitive_(primitive) {}
参数
inputs
: 输入Tensor。outputs
: 输出Tensor。primitive
: 算子经过flatbuffers反序化后的结果,存储算子属性。ctx
: 计算context。
析构函数
~IKernel() override = default;
公有成员函数
Primitive
const Primitive *primitive() const { return this->primitive_; }
获取IKernel算子的primitive。
返回值
IKernel算子经过flatbuffers反序化后的结果。
保护成员函数
const Primitive *primitive_ = nullptr;
IKernel算子经过flatbuffers反序化后的结果。