mindspore.ops.PrimitiveWithCheck

class mindspore.ops.PrimitiveWithCheck(name)

PrimitiveWithCheck is the base class for primitives defined in Python that check operator input arguments themselves but use the infer method registered in the C++ source code.

Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest priority and is the method that gets called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic for shapes and data types. As in PrimitiveWithInfer, the method infer_value() can also be defined for constant propagation.
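
The priority rule above can be illustrated with a small pure-Python stand-in. This is only a sketch under stated assumptions, not MindSpore's implementation: the class names SketchPrimitiveWithCheck, ShapeOnly and FullOverride are hypothetical, and each input is assumed to be described by a dict with "dtype" and "shape" keys.

```python
# Simplified pure-Python stand-in; hypothetical names, not the MindSpore implementation.
class SketchPrimitiveWithCheck:
    def __init__(self, name):
        self.name = name

    def check_shape(self, *args):
        return None  # default: accept any shapes

    def check_dtype(self, *args):
        return None  # default: accept any dtypes

    def __check__(self, *args):
        # Default behaviour in this sketch: delegate to the two narrower hooks.
        # Assumption: each input is a dict with "dtype" and "shape" keys.
        self.check_dtype(*(arg["dtype"] for arg in args))
        self.check_shape(*(arg["shape"] for arg in args))


class ShapeOnly(SketchPrimitiveWithCheck):
    """Overrides only check_shape, so the default __check__ reaches it."""

    def check_shape(self, input_x):
        if len(input_x) < 1:
            raise ValueError(f"For '{self.name}', input_x rank must be >= 1.")


class FullOverride(SketchPrimitiveWithCheck):
    """Overrides __check__, so check_shape is never consulted."""

    def __init__(self, name):
        super().__init__(name)
        self.calls = []

    def __check__(self, *args):
        self.calls.append("__check__")

    def check_shape(self, *args):
        self.calls.append("check_shape")


tensor_input = {"dtype": "float32", "shape": (2, 3)}
ShapeOnly("ShapeOnly").__check__(tensor_input)  # passes: rank is 2

full = FullOverride("FullOverride")
full.__check__(tensor_input)
print(full.calls)  # ['__check__']
```

In this sketch, overriding __check__() replaces the delegation entirely, which is why check_shape() is never recorded for FullOverride.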

Parameters

name (str) – Name of the current Primitive.

Examples

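>>> from mindspore import dtype as mstype
>>> from mindspore.ops import PrimitiveWithCheck, prim_attr_register
>>> # validator and Rel come from an internal module; this path may differ between versions
>>> from mindspore._checkparam import Validator as validator, Rel
>>> # Define a Primitive class with checks
>>> class Flatten(PrimitiveWithCheck):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def check_shape(self, input_x):
...         # Require at least one dimension
...         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
...
...     def check_dtype(self, input_x):
...         # Require a tensor type
...         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
...
>>> # Create a Primitive object
>>> flatten = Flatten()
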
check_dtype(*args)

Check the data types of the input args.

Parameters

args (mindspore.dtype) – Data types of the inputs.

Returns

None.
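
Because check_dtype() returns None, a failed check has to be reported some other way. The following is a minimal sketch that assumes failure is signalled by raising an exception and that one data type is passed per input. The function name check_same_dtype is hypothetical, and data types are modelled as plain strings rather than mindspore.dtype objects.

```python
# Hypothetical sketch: dtypes are modelled as plain strings, not mindspore.dtype objects.
def check_same_dtype(prim_name, *dtypes):
    """Raise TypeError unless every input shares one dtype; return None otherwise."""
    if len(set(dtypes)) > 1:
        raise TypeError(f"For '{prim_name}', input dtypes must match, got {dtypes}.")


result = check_same_dtype("AddSketch", "float32", "float32")
print(result)  # None
```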

check_shape(*args)

Check the shapes of the input args.

Note

The shape of a scalar is an empty tuple.

Parameters

args (tuple(int)) – Shapes of the input tensors.

Returns

None.
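
One consequence of the note above is that a rank check written with len() sees a scalar as rank 0. The sketch below shows this with plain tuples. The helper check_min_rank is hypothetical and only mirrors the kind of rank check used in the class example; it is not a MindSpore function.

```python
# Hypothetical helper; shapes are plain tuples of ints, as in the parameter description.
def check_min_rank(prim_name, shape, min_rank=1):
    rank = len(shape)  # a scalar has shape () and therefore rank 0
    if rank < min_rank:
        raise ValueError(f"For '{prim_name}', rank must be >= {min_rank}, got {rank}.")


check_min_rank("FlattenSketch", (2, 3))  # rank 2: accepted
try:
    check_min_rank("FlattenSketch", ())  # scalar: rejected
except ValueError as err:
    print(err)  # For 'FlattenSketch', rank must be >= 1, got 0.
```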