mindspore.graph.register_op_infer
- mindspore.graph.register_op_infer(op_name, infer_name)[source]
Register a custom inference function for an operator.
This decorator lets users register custom inference functions for MindSpore operators. A registered inference function is called during graph compilation to compute a custom attribute for the operator's output. The result is stored on the operator's output, so the functions that subsequent operators register through register_op_wrapper or register_op_infer can access it.
Note
Multiple inference functions can be registered for the same operator using different infer_name values.
The inference function will receive the same arguments as the operator it is registered for.
The inference result will be propagated through the computation graph and can be accessed by subsequent operations.
This registration affects all instances of the specified operator in the current process.
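- Parameters
op_name (str) – Name of the operator to register the inference function for, for example "Add".
infer_name (str) – Name under which the inference function and its result are registered, for example 'layout'.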
- Returns
callable. A decorator that registers the decorated function as a custom inference function.
- Raises
ValueError – If op_name or infer_name is not a non-empty string.
RuntimeError – If an inference function is registered more than once for the same infer_name of the same operator.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> from mindspore import Tensor, nn, ops, jit
>>> from mindspore.graph import register_op_infer
>>> # Register a custom layout inference for the Add operator. The custom layout
>>> # can then be read while inferring the subsequent Mul operator:
>>> class MyNet(nn.Cell):
...     @jit
...     def construct(self, x, y):
...         z = ops.Add()(x, y)
...         result = ops.Mul()(z, x)
...         return result
>>>
>>> @register_op_infer("Add", 'layout')
... def infer_add_layout(x, y):
...     return x.shape + y.shape
>>>
>>> @register_op_infer("Mul", 'layout')
... def infer_mul_layout(x, y):
...     print(x.layout)
>>>
>>> net = MyNet()
>>> input_x = Tensor([[1., 2., 3.], [4., 5., 6.]])
>>> input_y = Tensor([[1., 2., 3.], [4., 5., 6.]])
>>> net(input_x, input_y)
(2, 3, 2, 3)