基于自定义算子接口调用第三方算子库

查看源文件

概述

当开发网络遇到内置算子不足以满足需求时,你可以利用MindSpore的Python API中的Custom原语方便快捷地进行不同类型自定义算子的定义和使用。

网络开发者可以根据需要选用不同的自定义算子开发方式。详情请参考Custom算子的使用指南

其中,自定义算子有一种开发方式aot方式有其特殊的使用方式。aot方式可以通过加载预编译的so来调用相应的cpp/cuda函数。因此,当第三方库提供了cpp/cuda函数API时,可以尝试将其函数接口在so中调用,以下以PyTorch的Aten库为例进行介绍。

PyTorch Aten算子对接

当迁移一张使用PyTorch Aten算子的网络遇到内置算子不足的情况时,我们可以利用Custom算子的aot开发方式调用PyTorch Aten的算子进行快速验证。

PyTorch提供了一种方式可以支持引入PyTorch的头文件,从而使用其相关的数据结构编写cpp/cuda代码,并编译成so。参考:https://pytorch.org/docs/stable/_modules/torch/utils/cpp_extension.html#CppExtension

将两种方式结合使用,自定义算子可以调用PyTorch Aten算子,使用方式如下:

1. 下载工程文件

工程文件可以通过这里下载。

使用以下命令解压压缩包,得到文件夹test_custom_pytorch

tar xvf test_custom_pytorch.tar

文件夹中包含以下几个文件:

test_custom_pytorch
├── env.sh                           # set PyTorch/lib into LD_LIBRARY_PATH
├── leaky_relu.cpp                   # an example of use Aten CPU operator
├── leaky_relu.cu                    # an example of use Aten GPU operator
├── ms_ext.cpp                       # convert Tensors between MindSpore and PyTorch
├── ms_ext.h                         # convert API
├── README.md
├── run_cpu.sh                       # a script to run cpu case
├── run_gpu.sh                       # a script to run gpu case
├── setup.py                         # a script to compile cpp/cu into so
├── test_cpu_op_in_gpu_device.py     # a test file to run Aten CPU operator on GPU device
├── test_cpu_op.py                   # a test file to run Aten CPU operator on CPU device
└── test_gpu_op.py                   # a test file to run Aten GPU operator on GPU device

使用PyTorch Aten算子主要关注env.sh、setup.py、leaky_relu.cpp/cu、test_*.py即可。

其中,env.sh用于设置环境变量,setup.py用于编译so,leaky_relu.cpp/cu用于参考编写调用PyTorch Aten算子的源码,test_*.py用于参考调用Custom算子。

2. 编写调用PyTorch Aten算子的源码文件

参考leaky_relu.cpp/cu,编写调用PyTorch Aten算子的源码文件。

由于aot类型的自定义算子采用AOT编译方式,要求网络开发者基于特定接口,手写算子实现函数对应的源码文件,并提前将源码文件编译为动态链接库,然后在网络运行时框架会自动调用执行动态链接库中的函数。在算子实现的开发语言方面,GPU平台支持CUDACPU平台支持CC++。源码文件中的算子实现函数的接口规范如下:

extern "C" int func_name(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra);

如果是调用cpu算子,以leaky_relu.cpp为例,该文件提供AOT需要的函数LeakyRelu,里面调用了PyTorch Aten的函数torch::leaky_relu_out

#include <string.h>
#include <torch/extension.h> // 头文件引用部分
#include "ms_ext.h"

extern "C" int LeakyRelu(
    int nparam,
    void** params,
    int* ndims,
    int64_t** shapes,
    const char** dtypes,
    void* stream,
    void* extra) {
    auto tensors = get_torch_tensors(nparam, params, ndims, shapes, dtypes, c10::kCPU);
    auto at_input = tensors[0];
    auto at_output = tensors[1];
    torch::leaky_relu_out(at_output, at_input);
    // 如果使用不带输出的版本,代码如下:
    // torch::Tensor output = torch::leaky_relu(at_input);
    // at_output.copy_(output);
  return 0;
}

如果是调用gpu算子,以leaky_relu.cu为例:

#include <string.h>
#include <torch/extension.h> // 头文件引用部分
#include "ms_ext.h"

extern "C" int LeakyRelu(
    int nparam,
    void** params,
    int* ndims,
    int64_t** shapes,
    const char** dtypes,
    void* stream,
    void* extra) {
    cudaStream_t custream = static_cast<cudaStream_t>(stream);
    cudaStreamSynchronize(custream);
    auto tensors = get_torch_tensors(nparam, params, ndims, shapes, dtypes, c10::kCUDA);
    auto at_input = tensors[0];
    auto at_output = tensors[1];
    torch::leaky_relu_out(at_output, at_input);
  return 0;
}

其中,PyTorch Aten提供了带输出的算子函数版本和不带输出的算子函数版本,带输出的算子函数有_out后缀,PyTorch Aten提供了300+常用算子的api

当调用torch::*_out时,不需要output拷贝。当调用不带_out后缀的版本,需要调用APItorch.Tensor.copy_进行结果拷贝。

想查看支持调用PyTorch Aten的哪些函数,CPU版本参考PyTorch安装路径下的:python*/site-packages/torch/include/ATen/CPUFunctions_inl.h ,相应的GPU版本参考python*/site-packages/torch/include/ATen/CUDAFunctions_inl.h

以上用例中使用了ms_ext.h提供的api,这里稍作介绍:

// 将 MindSpore kernel 的 inputs/outputs 转换为 PyTorch Aten 的 Tensor
std::vector<at::Tensor> get_torch_tensors(int nparam, void** params, int* ndims, int64_t** shapes, const char** dtypes, c10::Device device) ;

3. 使用编译脚本setup.py生成so

setup.py使用PyTorch Aten提供的cppextension将上述c++/cuda源码编译成so文件。

执行前需要确保已经安装PyTorch。

pip install torch

并将PyTorch的lib加入LD_LIBRARY_PATH

export LD_LIBRARY_PATH=$(python3 -c 'import torch, os; print(os.path.dirname(torch.__file__))')/lib:$LD_LIBRARY_PATH

执行:

cpu: python setup.py leaky_relu.cpp leaky_relu_cpu.so
gpu: python setup.py leaky_relu.cu leaky_relu_gpu.so

将得到我们需要的 so 文件。

4. 使用自定义算子

以CPU为例,使用Custom算子调用上述PyTorch Aten算子,代码见test_cpu_op.py:

import numpy as np
import mindspore as ms
from mindspore.nn import Cell
import mindspore.ops as ops

ms.set_context(device_target="CPU")

def LeakyRelu():
    return ops.Custom("./leaky_relu_cpu.so:LeakyRelu", out_shape=lambda x : x, out_dtype=lambda x : x, func_type="aot")

class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.leaky_relu = LeakyRelu()

    def construct(self, x):
        return self.leaky_relu(x)

if __name__ == "__main__":
    x0 = np.array([[0.0, -0.1], [-0.2, 1.0]]).astype(np.float32)
    net = Net()
    output = net(ms.Tensor(x0))
    print(output)

执行:

python test_cpu_op.py

结果:

[[ 0.    -0.001]
 [-0.002  1.   ]]

注意:

若使用的是PyTorch Aten GPU算子,device_target需设置为"GPU".

set_context(device_target="GPU")
op = ops.Custom("./leaky_relu_gpu.so:LeakyRelu", out_shape=lambda x : x, out_dtype=lambda x : x, func_type="aot")

若使用的是PyTorch Aten CPU算子,而device_target"GPU",需要增加设置如下:

set_context(device_target="GPU")
op = ops.Custom("./leaky_relu_cpu.so:LeakyRelu", out_shape=lambda x : x, out_dtype=lambda x : x, func_type="aot")
op.add_prim_attr("primitive_target", "CPU")
  1. 使用cppextension编译so需满足该工具需要的编译器版本,检查gcc/clang/nvcc是否存在。

  2. 使用cppextension编译so会在脚本路径生成一个build的文件夹,里面存放了so,脚本会将so拷贝到build外,但是cppextension如果发现build里已经有so会跳过编译,因此如果是新编译的so要记得清空build下的so。

  3. 以上测试基于PyTorch 1.9.1版本,cuda使用11.1,python3.7,下载链接:https://download.pytorch.org/whl/cu111/torch-1.9.1%2Bcu111-cp37-cp37m-linux_x86_64.whl,PyTorch Aten支持的cuda版本需和本地的cuda版本一致,其他版本是否支持需用户自行探索。