mindspore.ops.PrimitiveWithInfer
- class mindspore.ops.PrimitiveWithInfer(name)[源代码]
- PrimitiveWithInfer是Python中的原语基类,在python中定义了跟踪推理的函数。 - 可以重写四个方法来定义Primitive的推断逻辑:__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__(),则__infer__()的优先级最高。 - 如果未定义__infer__(),则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。 - 关于如何自定义算子,请查看 自定义算子 。 - 参数:
- name (str) - 当前Primitive的名称。 
 
- 支持平台:
- Ascend- GPU- CPU
 - 样例: - >>> from mindspore.ops import prim_attr_register, PrimitiveWithInfer >>> # init a Primitive class with infer >>> class Add(PrimitiveWithInfer): ... @prim_attr_register ... def __init__(self): ... pass ... ... def infer_shape(self, x, y): ... return x # output shape same as first input 'x' ... ... def infer_dtype(self, x, y): ... return x # output type same as first input 'x' ... >>> # init a Primitive obj >>> add = Add()