mindspore.nn.ParameterUpdate

class mindspore.nn.ParameterUpdate(param)

Cell that updates a parameter.

With this Cell, one can manually assign the value of the input Tensor to param.

Parameters

param (Parameter) – The parameter to be updated manually.

Inputs:
  • x (Tensor) - A tensor whose shape and type are the same as those of param.
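To make the described semantics concrete (overwrite the parameter with the input, require matching shape and dtype, return the updated value), here is a minimal pure-NumPy sketch. ToyParameter and ToyParameterUpdate are hypothetical stand-ins written for illustration only; they are not MindSpore classes and do not reflect how MindSpore implements ParameterUpdate internally.

```python
import numpy as np


class ToyParameter:
    """Hypothetical stand-in for a named, mutable parameter (not mindspore.Parameter)."""

    def __init__(self, name, value):
        self.name = name
        self.value = np.asarray(value)


class ToyParameterUpdate:
    """Hypothetical sketch of the documented behavior: overwrite a parameter with the input."""

    def __init__(self, param):
        self._param = param

    def __call__(self, x):
        x = np.asarray(x)
        current = self._param.value
        # The input must match the parameter's shape and dtype, as the Inputs section requires.
        if x.shape != current.shape or x.dtype != current.dtype:
            raise ValueError(
                f"expected shape {current.shape} and dtype {current.dtype}, "
                f"got {x.shape} and {x.dtype}"
            )
        # Copy in place so every holder of this array observes the new value.
        current[...] = x
        return current


weight = ToyParameter("weight", np.zeros((4, 3), dtype=np.float32))
update = ToyParameterUpdate(weight)
output = update(np.arange(12, dtype=np.float32).reshape(4, 3))
print(weight.value)
```

Writing into the existing array (rather than rebinding weight.value) is a design choice of this sketch; it mimics an in-place assignment, so code that already holds a reference to the parameter's storage sees the update.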

Outputs:

Tensor, the updated value.

Raises

KeyError – If the parameter with the specified name does not exist.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> network = nn.Dense(3, 4)
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> output = update(weight)
>>> print(output)
[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]
 [ 9. 10. 11.]]