mindspore.ops.Morph
- class mindspore.ops.Morph(fn, infer_shape, infer_dtype, bprop_fn=None)
The Morph Primitive is used to encapsulate a user-defined function fn, allowing it to be used as a custom Primitive.
The Morph Primitive is primarily designed for custom graph optimization in GRAPH mode. For example, it supports encapsulating irregular collective communications (such as mindspore.ops.AlltoAllV()) in distributed auto-parallel training scenarios. When the Morph Primitive is applied to inputs, the encapsulated user-defined function fn is what is actually applied to them.
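Conceptually, calling a Morph instance is the same as calling fn directly, while infer_shape and infer_dtype only describe the output metadata. The pure-Python sketch below models that behavior. `ToyMorph`, the tuple shapes, and the string dtypes are hypothetical stand-ins for illustration, not MindSpore's implementation.

```python
from typing import Callable, Tuple


class ToyMorph:
    """Hypothetical model of Morph: applying it forwards the inputs to fn."""

    def __init__(self, fn: Callable, infer_shape: Callable, infer_dtype: Callable):
        self.fn = fn
        self.infer_shape = infer_shape
        self.infer_dtype = infer_dtype

    def infer(self, shape: Tuple[int, ...], dtype: str):
        # Output metadata comes only from the user-provided infer functions.
        return self.infer_shape(shape), self.infer_dtype(dtype)

    def __call__(self, *inputs):
        # The wrapped fn is what actually computes the result.
        return self.fn(*inputs)


def scale_by_100(x):
    # Elementwise scaling over a plain Python list (stand-in for a Tensor).
    return [100 * v for v in x]


# Identity infer functions, as in the example on this page.
morph = ToyMorph(scale_by_100, lambda s: s, lambda d: d)
print(morph([7.0, 16.0, 27.0]))      # [700.0, 1600.0, 2700.0]
print(morph.infer((3,), "float32"))  # ((3,), 'float32')
```

In the real primitive, shapes and dtypes are supplied by the framework during graph compilation; the sketch only shows the split of responsibilities between fn and the two infer functions.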
The main difference between the Morph Primitive and mindspore.ops.Custom() is that Morph is expanded and replaced by the user-defined fn before automatic differentiation, so no backward function needs to be implemented.
Note
This primitive is only supported in GRAPH_MODE.
A user-defined bprop function can optionally be supplied through the bprop_fn argument.
fn and bprop_fn must satisfy the syntax constraints of graph mode.
vararg, kwarg, kwonlyargs and free variables are not supported in user-defined functions.
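The signature-level part of these constraints can be illustrated with Python's standard `inspect` module. The helper `unsupported_params` below is hypothetical and is not how MindSpore validates functions; it only flags varargs, kwargs and keyword-only parameters, and does not check free variables.

```python
import inspect
from typing import Callable, List

# Parameter kinds that the Note lists as unsupported in user-defined functions.
_UNSUPPORTED_KINDS = {
    inspect.Parameter.VAR_POSITIONAL: "vararg (*args)",
    inspect.Parameter.VAR_KEYWORD: "kwarg (**kwargs)",
    inspect.Parameter.KEYWORD_ONLY: "kwonlyarg",
}


def unsupported_params(fn: Callable) -> List[str]:
    """Hypothetical helper: list parameters of fn whose kind is unsupported.

    Only the signature is inspected; free variables are not checked here.
    """
    return [
        f"{param.name}: {_UNSUPPORTED_KINDS[param.kind]}"
        for param in inspect.signature(fn).parameters.values()
        if param.kind in _UNSUPPORTED_KINDS
    ]


def ok_fn(x, y):
    # Plain positional parameters only.
    return x * y


def bad_fn(x, *rest, scale=1.0, **options):
    # Uses a vararg, a keyword-only argument and a kwarg.
    return x * scale


print(unsupported_params(ok_fn))   # []
print(unsupported_params(bad_fn))
# ['rest: vararg (*args)', 'scale: kwonlyarg', 'options: kwarg (**kwargs)']
```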
- Parameters
fn (Function) – The user-defined function to encapsulate.
infer_shape (Function) – User-defined function that infers the output shape of fn from its input shapes.
infer_dtype (Function) – User-defined function that infers the output data type of fn from its input data types.
bprop_fn (Function, optional) – User-defined backward (bprop) function. Default: None.
- Inputs:
The inputs of the user-defined fn.
- Outputs:
The outputs of the user-defined fn.
- Raises
RuntimeError – if not used in GRAPH_MODE.
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, nn, ops, Tensor, Parameter
>>>
>>> np_weight0 = np.array([1.0, 2.0, 3.0])
>>> np_weight1 = np.array([4.0, 5.0, 6.0])
>>> np_input_x = np.array([7.0, 8.0, 9.0])
>>>
>>> # The infer functions pass the input shape and dtype through unchanged.
>>> def infer_dtype(args):
...     return args
>>>
>>> def infer_shape(args):
...     return args
>>>
>>> # mul_by(n) returns a function computing n * x; Morph wraps mul_by(100).
>>> def mul_by(*args):
...     def inner(x):
...         return args[0] * x
...     return inner
>>>
>>> NUMBER_100 = 100
>>> class MorphNet(nn.Cell):
...     def __init__(self):
...         super(MorphNet, self).__init__()
...         self.weight0 = Parameter(Tensor(np_weight0, ms.float32), name="weight0")
...         self.weight1 = Parameter(Tensor(np_weight1, ms.float32), name="weight1")
...         self.mul_by_100 = ops.Morph(mul_by(NUMBER_100), infer_shape, infer_dtype)
...     def construct(self, x):
...         a = x * self.weight0
...         b = self.mul_by_100(a)
...         out = b * self.weight1
...         return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> input_x = Tensor(np_input_x, ms.float32)
>>> net = MorphNet()
>>> grad_op = ops.GradOperation(get_all=True, get_by_list=True)
>>> grad_net = grad_op(net, net.trainable_params())
>>> # bwd_out[0] holds the input gradients, bwd_out[1] the parameter gradients.
>>> bwd_out = grad_net(input_x)
>>> x_grad = bwd_out[0][0].asnumpy()
>>> weight0_grad = bwd_out[1][0].asnumpy()
>>> weight1_grad = bwd_out[1][1].asnumpy()
>>> print("x_grad", x_grad)
x_grad [ 400. 1000. 1800.]
>>> print("weight0_grad", weight0_grad)
weight0_grad [2800. 4000. 5400.]
>>> print("weight1_grad", weight1_grad)
weight1_grad [ 700. 1600. 2700.]
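Because Morph is expanded into fn before automatic differentiation, the gradients in the example are simply those of the inlined expression out = 100 * (x * weight0) * weight1. The plain-Python sketch below recomputes them with the chain rule to cross-check the printed values. It does not use MindSpore and assumes an all-ones output sensitivity, which is what GradOperation uses when sens_param is not set.

```python
# Values from the example.
x = [7.0, 8.0, 9.0]
weight0 = [1.0, 2.0, 3.0]
weight1 = [4.0, 5.0, 6.0]
SCALE = 100  # the constant captured by mul_by(NUMBER_100)

# Forward pass of the inlined expression: out = SCALE * (x * weight0) * weight1.
a = [xi * w0 for xi, w0 in zip(x, weight0)]
b = [SCALE * ai for ai in a]
out = [bi * w1 for bi, w1 in zip(b, weight1)]

# Backward pass with an all-ones sensitivity, element by element.
x_grad = [SCALE * w0 * w1 for w0, w1 in zip(weight0, weight1)]
weight0_grad = [SCALE * xi * w1 for xi, w1 in zip(x, weight1)]
weight1_grad = b  # d(out)/d(weight1) is just b

print("x_grad", x_grad)              # x_grad [400.0, 1000.0, 1800.0]
print("weight0_grad", weight0_grad)  # weight0_grad [2800.0, 4000.0, 5400.0]
print("weight1_grad", weight1_grad)  # weight1_grad [700.0, 1600.0, 2700.0]
```

The values match the MindSpore output, which is consistent with the gradient flowing through fn itself rather than through a separately defined backward function.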