mindspore.ops.Primitive

class mindspore.ops.Primitive(name)[source]

Primitive is the base class of operator primitives in Python.

Parameters

name (str) – Name for the current Primitive.

Examples

>>> from mindspore.ops.primitive import prim_attr_register, Primitive
>>> add = Primitive('add')
>>>
>>> # or work with prim_attr_register:
>>> # init a Primitive class with attr1 and attr2
>>> class Add(Primitive):
...     @prim_attr_register
...     def __init__(self, attr1, attr2):
...         '''init for add'''
...         # check attr1 and attr2 or do some initializations
...
>>> # init a Primitive obj with attr1=1 and attr2=2
>>> add = Add(attr1=1, attr2=2)
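The following pure-Python sketch illustrates the general idea behind decorating `__init__` so that its arguments are recorded as attributes. It is a toy model and not MindSpore's implementation: `register_init_args`, `ToyAdd`, and the `attrs` dictionary are hypothetical names used only for illustration.

```python
import functools
import inspect


def register_init_args(init):
    """Toy decorator (hypothetical, not MindSpore code): record the
    arguments passed to __init__ in a dict named `attrs`."""
    signature = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Bind the call to the signature so positional arguments,
        # keyword arguments, and defaults are all captured by name.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        self.attrs = {
            name: value
            for name, value in bound.arguments.items()
            if name != "self"
        }
        # Run the original __init__, where validation would happen.
        init(self, *args, **kwargs)

    return wrapper


class ToyAdd:
    @register_init_args
    def __init__(self, attr1, attr2=2):
        """Checks on attr1 and attr2 would go here."""


toy_add = ToyAdd(1)
print(toy_add.attrs)  # {'attr1': 1, 'attr2': 2}
```

Binding against the signature, rather than reading `kwargs` alone, means positional arguments and default values are recorded as well.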
add_prim_attr(name, value)[source]

Add primitive attribute.

Parameters
  • name (str) – Attribute Name.

  • value (Any) – Attribute value.

Examples

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.add_prim_attr("attr", 1)
>>> out = a.attrs["attr"]
>>> print(out)
1
check_elim(*args)[source]

Check whether the primitive can be eliminated. Subclasses that need this behavior should override this method.

Parameters

args (Primitive args) – Same as the arguments of the current Primitive.

Returns

A tuple of two elements. The first element indicates whether the primitive can be evaluated at compile time; the second element is the computed result.

Examples

>>> import mindspore
>>> from mindspore.ops.primitive import prim_attr_register, Primitive
>>> from mindspore import Tensor
>>> import numpy as np
>>> class AddN(Primitive):
...     @prim_attr_register
...     def __init__(self):
...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
...     def check_elim(self, inputs):
...         if len(inputs) != 1:
...             return (False, None)
...         if isinstance(inputs[0], Tensor):
...             return (True, inputs[0])
...         return (False, None)
...
>>> addn = AddN()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> output = addn.check_elim((input_x,))
>>> print(output)
(True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00]))
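As a rough illustration of how the returned tuple could be consumed, the sketch below models a compile-time pass in plain Python. It does not show MindSpore's compiler: `ToyAddN` and `simplify_call` are hypothetical names, and plain lists stand in for Tensors.

```python
class ToyAddN:
    """Hypothetical stand-in for a primitive that sums a sequence of inputs."""

    def check_elim(self, inputs):
        # Summing a single input is the identity, so the call can be
        # replaced by that input without running the operator.
        if len(inputs) == 1:
            return (True, inputs[0])
        return (False, None)


def simplify_call(primitive, *args):
    """Toy pass: replace the call with its result when check_elim allows it."""
    can_eliminate, result = primitive.check_elim(*args)
    if can_eliminate:
        return ("constant", result)
    return ("call", primitive, args)


addn = ToyAddN()
print(simplify_call(addn, ([1, 2, 3],))[0])  # constant
print(simplify_call(addn, ([1], [2]))[0])    # call
```

Returning an explicit flag alongside the result lets the caller distinguish "cannot be eliminated" from a legitimately computed value of `None`.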
del_prim_attr(name)[source]

Delete primitive attribute.

Parameters

name (str) – Attribute Name.

Examples

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.add_prim_attr("attr", 1)
>>> a = a.del_prim_attr("attr")
>>> print(a.attrs)
{'input_names': ['x', 'y'], 'output_names': ['output']}
init_prim_io_names(inputs, outputs)[source]

Initialize the names of the inputs and outputs, which may be Tensors or attributes.

Parameters
  • inputs (list[str]) – List of input names.

  • outputs (list[str]) – List of output names.

Examples

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a.init_prim_io_names(["x", "y"], ["sum"])
>>> print(a.input_names)
['x', 'y']
>>> print(a.output_names)
['sum']
recompute(mode=True)[source]

Mark the primitive for recomputation. If a primitive marked for recomputation feeds into backward nodes that compute gradients, its intermediate activation is not stored during the forward pass; instead, it is recomputed during the backward pass.

Note

  • If the computation involves randomness or global variables, equivalence between the stored and recomputed results is currently not guaranteed.

  • Not supported in PyNative mode.

Parameters

mode (bool) – Specifies whether the primitive is recomputed. Default: True.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, ops, nn
>>> class NetRecompute(nn.Cell):
...     def __init__(self):
...         super(NetRecompute, self).__init__()
...         self.relu = ops.ReLU().recompute()
...         self.sqrt = ops.Sqrt()
...     def construct(self, x):
...         out = self.relu(x)
...         return self.sqrt(out)
...
>>> class GradNet(nn.Cell):
...     def __init__(self, network):
...         super(GradNet, self).__init__()
...         self.network = network
...         self.grad = ops.GradOperation()
...     def construct(self, x):
...         g_out = self.grad(self.network)(x)
...         return g_out
...
>>> x = Tensor(np.array([-1,1]).astype(np.float32))
>>> net = NetRecompute()
>>> grad = GradNet(net)
>>> a = grad(x)
>>> print(a)
[0. 0.5]
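The trade-off behind recomputation can be sketched without MindSpore. The toy functions below compute the gradient of sqrt(relu(x)) for a scalar x > 0 in two ways: one keeps relu(x) from the forward pass, the other keeps only x and rebuilds relu(x) during the backward pass. The function names and the `saved` dictionary are hypothetical and only model what is held in memory between the two passes.

```python
import math


def relu(x):
    return max(x, 0.0)


def grad_storing_activation(x):
    """d/dx sqrt(relu(x)) for x > 0, keeping relu(x) from the forward pass."""
    saved = {"x": x, "relu_out": relu(x)}  # forward: two values kept alive
    d_sqrt = 0.5 / math.sqrt(saved["relu_out"])
    d_relu = 1.0 if saved["x"] > 0 else 0.0
    return d_sqrt * d_relu, len(saved)


def grad_recomputing_activation(x):
    """Same gradient, but relu(x) is rebuilt during the backward pass."""
    saved = {"x": x}                       # forward: only the input is kept
    relu_out = relu(saved["x"])            # backward: recompute the activation
    d_sqrt = 0.5 / math.sqrt(relu_out)
    d_relu = 1.0 if saved["x"] > 0 else 0.0
    return d_sqrt * d_relu, len(saved)


print(grad_storing_activation(4.0))      # (0.25, 2)
print(grad_recomputing_activation(4.0))  # (0.25, 1)
```

Both variants return the same gradient; the recomputing variant holds fewer values between the passes at the cost of evaluating `relu` a second time.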
set_prim_instance_name(instance_name)[source]

Set the instance name of the primitive operator.

Note

This method is called by default when a user defines a primitive operator.

Parameters

instance_name (str) – Instance name of primitive operator set by user.

Examples

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.set_prim_instance_name("add")
>>> print(a.instance_name)
add
set_stage(stage)[source]

Add the stage id to the primitive's attributes.

Note

Valid only in semi auto parallel mode. In other parallel modes, set the stage to 0.

Parameters

stage (int) – The stage id for the current operation.

Examples

>>> from mindspore import ops
>>> add = ops.Add()
>>> print(add.set_stage(0))
Prim[Add]<stage=0>
shard(in_strategy=None, out_strategy=None)[source]

Add sharding strategies to the primitive's attributes.

Note

Valid only in semi auto parallel or auto parallel mode. In other parallel modes, strategies set here are ignored.

Parameters
  • in_strategy (tuple) – Describes the split strategy of the operator's inputs. Default: None.

  • out_strategy (tuple) – Describes the split strategy of the operator's outputs. It applies only to certain operators, such as MatMul. Default: None.

Examples

>>> from mindspore import ops
>>> add = ops.Add()
>>> print(add.shard(((1, 1), (1, 1))))
Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None>
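To make the effect of a split strategy concrete, the sketch below computes the per-device slice shape in plain Python. It assumes that each number in a strategy tuple is the count of slices along the corresponding tensor dimension; `slice_shape`, the example shapes, and the divisibility check are illustrative assumptions rather than MindSpore API.

```python
import math


def slice_shape(shape, strategy):
    """Shape of one slice when each dimension is cut into `strategy` parts.

    Assumption: strategy[i] is the number of slices along dimension i.
    """
    if len(shape) != len(strategy):
        raise ValueError("strategy needs one entry per tensor dimension")
    for dim, parts in zip(shape, strategy):
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
    return tuple(dim // parts for dim, parts in zip(shape, strategy))


# One strategy tuple per operator input, mirroring the nested in_strategy form.
in_strategy = ((2, 4), (2, 4))
input_shapes = ((8, 16), (8, 16))

for shape, strategy in zip(input_shapes, in_strategy):
    # math.prod(strategy) is the number of slices the input is cut into.
    print(slice_shape(shape, strategy), "in", math.prod(strategy), "slices")
```

Under this assumption, a strategy of `(1, 1)`, as in the example above, leaves the input unsplit.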
property update_parameter

Return whether the primitive updates the value of a parameter.