mindspore.ops.PrimitiveWithInfer
- class mindspore.ops.PrimitiveWithInfer(name)
- PrimitiveWithInfer is the base class of primitives in Python and defines functions for tracking inference in Python.
- There are four methods that can be overridden to define the inference logic of the primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it has the highest priority and is the method that gets called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe how the output shape and type are inferred. infer_value() is used for constant propagation.
- Parameters
- name (str) – Name of the current Primitive. 
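The priority rule described above can be illustrated with a plain-Python model. This is a minimal sketch, not MindSpore's implementation: it does not import MindSpore, and MockPrimitiveWithInfer, run_infer, ElementwiseOp, and CustomInferOp are hypothetical names. The layout of each input as a dict with "shape", "dtype", and "value" keys is an assumption made for illustration.

```python
# Hypothetical, simplified pure-Python model of the infer priority rule.
# It does NOT import MindSpore; every class and function name here is illustrative.


class MockPrimitiveWithInfer:
    """Stand-in for PrimitiveWithInfer that mimics only the infer dispatch."""

    def infer_shape(self, *shapes):
        return None  # subclasses override this to compute the output shape

    def infer_dtype(self, *dtypes):
        return None  # subclasses override this to compute the output dtype

    def infer_value(self, *values):
        return None  # None stands for "value not known before run time"

    def __infer__(self, *args):
        # Default behaviour in this model: call each per-track method on the
        # matching field of every input description and collect the results.
        return {
            "shape": self.infer_shape(*(a["shape"] for a in args)),
            "dtype": self.infer_dtype(*(a["dtype"] for a in args)),
            "value": self.infer_value(*(a["value"] for a in args)),
        }


def run_infer(prim, *args):
    # In this model the caller always goes through __infer__, so a subclass
    # that overrides __infer__ takes priority over the per-track methods.
    return prim.__infer__(*args)


class ElementwiseOp(MockPrimitiveWithInfer):
    # Only infer_shape and infer_dtype are overridden.
    def infer_shape(self, x, y):
        return x  # output shape same as first input

    def infer_dtype(self, x, y):
        return x  # output dtype same as first input


class CustomInferOp(MockPrimitiveWithInfer):
    # __infer__ is overridden, so this infer_shape is never consulted.
    def infer_shape(self, x, y):
        raise AssertionError("not called when __infer__ is overridden")

    def __infer__(self, x, y):
        return {"shape": [1], "dtype": x["dtype"], "value": None}


x = {"shape": [2, 3], "dtype": "float32", "value": None}
y = {"shape": [2, 3], "dtype": "float32", "value": None}
print(run_infer(ElementwiseOp(), x, y))  # {'shape': [2, 3], 'dtype': 'float32', 'value': None}
print(run_infer(CustomInferOp(), x, y))  # {'shape': [1], 'dtype': 'float32', 'value': None}
```

Because the model dispatches everything through __infer__, overriding that single method bypasses infer_shape(), infer_dtype(), and infer_value() entirely, which mirrors the priority order described for the real class.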
- Supported Platforms:
- Ascend, GPU, CPU
- Examples
>>> from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
>>> # define a Primitive subclass with shape and dtype inference
>>> class Add(PrimitiveWithInfer):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def infer_shape(self, x, y):
...         return x # output shape same as first input 'x'
...
...     def infer_dtype(self, x, y):
...         return x # output type same as first input 'x'
...
>>> # instantiate the Primitive
>>> add = Add()
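The example above overrides only infer_shape() and infer_dtype(). The text also mentions infer_value() for constant propagation; the following is a hedged pure-Python sketch of what such an override could look like for an addition primitive. MockAdd is a hypothetical stand-in that does not import MindSpore, and the convention that a missing value is passed and returned as None is an assumption of this sketch.

```python
# Hypothetical sketch of an infer_value override used for constant folding.
# MockAdd is illustrative only; it is not MindSpore's Add primitive.


class MockAdd:
    def infer_shape(self, x, y):
        return x  # output shape same as first input 'x'

    def infer_dtype(self, x, y):
        return x  # output type same as first input 'x'

    def infer_value(self, x, y):
        # Fold only when both inputs are known constants; returning None
        # signals that the result must be computed at run time.
        if x is None or y is None:
            return None
        return x + y


add = MockAdd()
print(add.infer_value(2, 3))     # 5
print(add.infer_value(2, None))  # None
```

Returning None whenever any input is unknown keeps the folding conservative: only expressions whose operands are all constants get a value ahead of execution.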