mindspore.graph.register_op_wrapper

mindspore.graph.register_op_wrapper(op_name)

Register a custom wrapper function that replaces an operator's node in the computation graph.

This decorator lets users register a custom wrapper function that replaces the computation logic of an operator in the computation graph. When the JIT compiler compiles the specified operator, it compiles the registered wrapper function instead of the original operator.

The wrapper function receives the same arguments as the original operator and must return a callable.
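The wrapper is therefore a factory: it is called with the operator's arguments and hands back the function that actually computes the result. The plain-Python sketch below models that two-step call. `substitute_op` is a hypothetical stand-in for the JIT compiler, and it assumes the returned callable is invoked with the same arguments as the original operator (as in the example at the end of this page); it is not MindSpore's implementation.

```python
from typing import Any, Callable


def double_add_wrapper(x: Any, y: Any) -> Callable[[Any, Any], Any]:
    """Wrapper factory: receives the operator's arguments, returns the replacement."""
    def func(x: Any, y: Any) -> Any:
        # Replacement logic: x + 2*y instead of the original x + y.
        return x + 2 * y
    return func


def substitute_op(wrapper: Callable[..., Callable[..., Any]], *args: Any) -> Any:
    """Hypothetical stand-in for the compiler (assumption, not MindSpore code):
    ask the wrapper for a callable, then call it with the operator's arguments."""
    replacement = wrapper(*args)
    # Sketch-only check mirroring "must return a callable"; the real error type may differ.
    if not callable(replacement):
        raise TypeError("op wrapper must return a callable")
    return replacement(*args)


print(substitute_op(double_add_wrapper, 1, 3))  # 1 + 2*3 = 7
```

Keeping the factory step separate from the computation means the wrapper sees the operator's inputs before choosing what to return; how MindSpore uses that step during compilation is not specified here.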

Note

  • Only one wrapper function can be registered per operator. Attempting to register multiple wrappers for the same operator will raise an exception.

  • The replacement affects all instances of the specified operator within JIT-compiled Cells in the current process.

  • The wrapper function arguments should match the original operator's signature.

Parameters

op_name (str) – The name of the operator to register the wrapper function for. Must be a non-empty string.

Returns

callable. The decorated function, registered as the wrapper function for op_name.

Raises

  • ValueError – If op_name is not a non-empty string.

  • RuntimeError – If a wrapper has already been registered for the same op_name.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import Tensor, nn, ops, jit
>>> from mindspore.graph import register_op_wrapper
>>> # Replace Add operation with custom logic (x + 2*y):
>>> @register_op_wrapper("Add")
... def double_add_wrapper(x, y):
...     def func(x, y):
...         return x + 2 * y
...     return func
>>>
>>> class MyNet(nn.Cell):
...     @jit
...     def construct(self, x, y):
...         return ops.Add()(x, y)  # Will execute x + 2*y
>>>
>>> net = MyNet()
>>> x = Tensor([1.0, 2.0])
>>> y = Tensor([3.0, 4.0])
>>> out = net(x, y)  # computed as x + 2*y by the registered wrapper