mindspore.graph.register_op_wrapper

查看源文件
mindspore.graph.register_op_wrapper(op_name)[源代码]

注册自定义包装函数以替换算子在计算图中的节点。

此装饰器允许用户注册自定义包装函数来替换计算图中算子的计算逻辑。当 JIT 编译器编译指定算子时,它将编译注册的包装函数而不是原始算子。

包装函数接收与原始算子相同的参数,并且必须返回一个可调用的函数。

说明

  • 每个算子只能注册一个包装函数。尝试为同一算子注册多个包装函数将抛出异常。

  • 替换影响当前进程中 JIT 编译 Cell 内指定算子的所有实例。

  • 包装函数参数应与原始算子的签名匹配。

参数:
  • op_name (str) - 要注册包装函数的算子名称。必须为非空字符串。

返回:

callable,被装饰器的,被注册为自定义包装函数的函数。

异常:
  • ValueError - 当 op_name 不是非空字符串时抛出。

  • RuntimeError - 当尝试为同一 op_name 注册包装函数多次时抛出。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import Tensor, ops, jit
>>> from mindspore.graph import register_op_wrapper
>>> # Replace Add operation with custom logic (x + 2*y):
>>> @register_op_wrapper("Add")
>>> def double_add_wrapper(x, y):
...     def func(x, y):
...         return x + 2 * y
...     return func
>>> 
>>> class MyNet(nn.Cell):
...     @jit
...     def construct(self, x, y):
...         return ops.Add()(x, y)  # Will execute x + 2*y