mindspore.ops.PrimitiveWithCheck

查看源文件
class mindspore.ops.PrimitiveWithCheck(name)[源代码]

PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。

可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。

如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。

了解更多如何自定义算子,请查看 自定义算子

参数:
  • name (str) - 当前Primitive的名称。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import dtype as mstype
>>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
>>> # init a Primitive class with check
>>> class Flatten(PrimitiveWithCheck):
...     @prim_attr_register
...     def __init__(self):
...         pass
...     def check_shape(self, input_x):
...         Validator.check_int(len(input_x), 1, validator.GE, 'input_x rank', self.name)
...
...     def check_dtype(self, input_x):
...         Validator.check_subclass("input_x", input_x, mstype.tensor_type, self.name)
...
>>> # init a Primitive obj
>>> add = Flatten()
check_dtype(*args)[源代码]

检查输入参数的数据类型。

参数:
返回:

None。

check_shape(*args)[源代码]

检查输入参数的shape。

说明

Scalar的shape是一个空元组。

参数:
  • args (tuple(int)) - 输入tensor的shape。

返回:

None。