mindspore.ops.Morph

class mindspore.ops.Morph(fn, infer_shape, infer_dtype, bprop_fn=None)

The Morph Primitive is used to encapsulate a user-defined function fn, allowing it to be used as a custom Primitive.

The Morph Primitive is primarily designed for custom graph optimization in GRAPH mode. For example, it supports encapsulation of irregular collective communications (such as mindspore.ops.AlltoAllV()) in distributed auto-parallel training scenarios.

When the Morph Primitive is applied to inputs, it is actually the encapsulated user-defined function fn that is applied to the inputs.

The main difference between the Morph Primitive and mindspore.ops.Custom() is that the former is expanded and replaced by the user-defined fn before automatic differentiation, so there is no need to implement a backward function.
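The idea that a wrapper which simply delegates to fn needs no backward function of its own can be illustrated with a toy model in plain Python. This is only a conceptual sketch, not how MindSpore implements Morph: the names Dual and ToyMorph are hypothetical, and forward-mode dual numbers stand in for MindSpore's automatic differentiation.

```python
class Dual:
    """A value paired with its derivative (toy forward-mode differentiation)."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __mul__(self, other):
        # Promote plain numbers to constants (derivative 0), then apply the product rule.
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)

    __rmul__ = __mul__


class ToyMorph:
    """Hypothetical stand-in for the wrapper: calling it just calls fn."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *inputs):
        return self.fn(*inputs)


def mul_by_100(x):
    return 100 * x


toy = ToyMorph(mul_by_100)
# Seed the derivative with 1.0 to differentiate with respect to the input.
result = toy(Dual(7.0, 1.0))
print(result.value, result.deriv)  # 700.0 100.0
```

Because the wrapper contributes no operations of its own, the derivative comes entirely from the operations inside fn, which is why no separate backward definition is written for ToyMorph in this sketch.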

Note

  • This primitive is only supported in GRAPH_MODE.

  • A user-defined bprop can optionally be supplied to Morph through the bprop_fn argument.

  • fn and bprop_fn must satisfy the syntax constraints of graph mode.

  • vararg, kwarg, kwonlyargs and free variables are not supported in the user-defined function.

Parameters
  • fn (Function) – User-defined function, written as a MindSpore function.

  • infer_shape (Function) – User-defined infer_shape function, written as a MindSpore function.

  • infer_dtype (Function) – User-defined infer_dtype function, written as a MindSpore function.

  • bprop_fn (Function, optional) – User-defined bprop function, written as a MindSpore function. Default: None.

Inputs:

The inputs of user-defined fn.

Outputs:

The outputs of user-defined fn.

Raises

RuntimeError – If the primitive is not used in GRAPH_MODE.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, nn, ops, Tensor, Parameter
>>>
>>> np_weight0 = np.array([1.0, 2.0, 3.0])
>>> np_weight1 = np.array([4.0, 5.0, 6.0])
>>> np_input_x = np.array([7.0, 8.0, 9.0])
>>>
>>> def infer_dtype(args):
...     return args
>>>
>>> def infer_shape(args):
...     return args
>>>
>>> def mul_by(*args):
...     def inner(x):
...         return args[0] * x
...     return inner
>>>
>>> NUMBER_100 = 100
>>> class MorphNet(nn.Cell):
...     def __init__(self):
...         super(MorphNet, self).__init__()
...         self.weight0 = Parameter(Tensor(np_weight0, ms.float32), name="weight0")
...         self.weight1 = Parameter(Tensor(np_weight1, ms.float32), name="weight1")
...         self.mul_by_100 = ops.Morph(mul_by(NUMBER_100), infer_shape, infer_dtype)
...     def construct(self, x):
...         a = x * self.weight0
...         b = self.mul_by_100(a)
...         out = b * self.weight1
...         return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> input_x = Tensor(np_input_x, ms.float32)
>>> net = MorphNet()
>>> grad_op = ops.GradOperation(get_all=True, get_by_list=True)
>>> grad_net = grad_op(net, net.trainable_params())
>>> bwd_out = grad_net(input_x)
>>> x_grad = bwd_out[0][0].asnumpy()
>>> weight0_grad = bwd_out[1][0].asnumpy()
>>> weight1_grad = bwd_out[1][1].asnumpy()
>>> print("x_grad", x_grad)
>>> print("weight0_grad", weight0_grad)
>>> print("weight1_grad", weight1_grad)
x_grad [ 400. 1000. 1800.]
weight0_grad [2800. 4000. 5400.]
weight1_grad [ 700. 1600. 2700.]
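The printed gradients can be checked by hand. The network computes out = x * weight0 * 100 * weight1 element-wise, so each gradient is 100 times the product of the other two factors. The following plain-Python sketch reproduces that arithmetic without MindSpore; it assumes the gradient is taken with an all-ones sensitivity on the output, and the variable names simply mirror the example above.

```python
# Values copied from the example above.
weight0 = [1.0, 2.0, 3.0]
weight1 = [4.0, 5.0, 6.0]
input_x = [7.0, 8.0, 9.0]
scale = 100.0  # the constant applied by mul_by(NUMBER_100)

# out[i] = input_x[i] * weight0[i] * scale * weight1[i]
# Each partial derivative is scale times the product of the remaining two factors.
x_grad = [scale * w0 * w1 for w0, w1 in zip(weight0, weight1)]
weight0_grad = [scale * x * w1 for x, w1 in zip(input_x, weight1)]
weight1_grad = [scale * x * w0 for x, w0 in zip(input_x, weight0)]

print("x_grad", x_grad)              # [400.0, 1000.0, 1800.0]
print("weight0_grad", weight0_grad)  # [2800.0, 4000.0, 5400.0]
print("weight1_grad", weight1_grad)  # [700.0, 1600.0, 2700.0]
```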