mindspore.ops.Custom ===================== .. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg :target: https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/ops/mindspore.ops.Custom.rst :alt: 查看源文件 .. py:class:: mindspore.ops.Custom(func, bprop=None, out_dtype=None, func_type="hybrid", out_shape=None, reg_info=None) `Custom` 算子是MindSpore自定义算子的统一接口。用户可以利用该接口自行定义MindSpore内置算子库尚未包含的算子。 根据输入函数的不同,你可以创建多个自定义算子,并且把它们用在神经网络中。 关于自定义算子的详细说明和介绍,包括参数的正确书写,见 `自定义算子教程 `_ 。 .. warning:: - 这是一个实验性API,后续可能修改或删除。 .. note:: 不同自定义算子的函数类型(func_type)支持的平台类型不同。每种类型支持的平台如下: - "hybrid": ["GPU", "CPU"]. - "akg": ["GPU", "CPU"]. - "aot": ["GPU", "CPU"]. - "pyfunc": ["CPU"]. - "julia": ["CPU"]. 参数: - **func** (Union[function, str]) - 自定义算子的函数表达。 - function:如果 `func` 是函数类型,那么 `func` 应该是一个Python函数,它描述了用户定义的操作符的计算逻辑。该函数可以是以下之一: 1. AKG操作符实现函数,可以使用ir builder/tvm compute/hybrid语法。 2. 纯Python函数。 3. 使用Hybrid DSL编写的带有装饰器的内核函数。 - 字符串:如果 `func` 是字符串类型,那么 `str` 应该是包含函数名的文件路径。当 `func_type` 是"aot"或"julia"时,可以使用这种方式。 1. 对于"aot": 目前"aot"支持GPU/CPU(仅Linux)平台。"aot"意味着提前编译,在这种情况下,Custom直接启动用户定义的"xxx.so"文件作为操作符。用户需要提前将手写的"xxx.cu"/"xxx.cc"文件编译成"xxx.so",并提供文件路径和函数名。 - "xxx.so"文件生成: 1) GPU平台:给定用户定义的"xxx.cu"文件(例如"{path}/add.cu"),使用nvcc命令进行编译(例如"nvcc --shared -Xcompiler -fPIC -o add.so add.cu")。 2) CPU平台:给定用户定义的"xxx.cc"文件(例如"{path}/add.cc"),使用g++/gcc命令进行编译(例如"g++ --shared -fPIC -o add.so add.cc")。 - 定义"xxx.cc"/"xxx.cu"文件: "aot"是一个跨平台的标识符。"xxx.cc"或"xxx.cu"中定义的函数具有相同的参数。通常,该函数应该像这样: .. code-block:: int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) 参数: - `nparam(int)` : 输入和输出的总数;假设操作符有2个输入和3个输出,那么 `nparam=5` 。 - `params(void **)` : 输入和输出指针的数组指针;输入和输出的指针类型为 `void *` ;假设操作符有2个输入和3个输出,那么第一个输入的指针是 `params[0]` ,第二个输出的指针是 `params[3]` 。 - `ndims(int *)` : 输入和输出维度数的数组指针;假设 `params[i]` 是一个1024x1024的张量, `params[j]` 是一个77x83x4的张量,那么 `ndims[i]=2` , `ndims[j]=3` 。 - `shapes(int64_t **)` : 输入和输出形状( `int64_t *` )的数组指针;第 `i` 个输入的第 `j` 个维度的大小是 `shapes[i][j]` (其中 `0<=j