mindspore.ops

Operators can be used in the construct function of Cell.

Examples

>>> from mindspore.ops import operations as P
>>> from mindspore.ops import composite as C
>>> from mindspore.ops import functional as F

Note

  • The Primitive operators in operations need to be used after instantiation.

  • The composite operators are the pre-defined combination of operators.

  • The functional operators are the pre-instantiated Primitive operators, which can be used directly as a function.

  • For functional operators usage, please refer to https://gitee.com/mindspore/mindspore/blob/master/mindspore/ops/functional.py

class mindspore.ops.ACos(*args, **kwargs)[source]

Computes arccosine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> acos = P.ACos()
>>> input_x = Tensor(np.array([0.74, 0.04, 0.30, 0.56]), mindspore.float32)
>>> output = acos(input_x)
class mindspore.ops.Abs(*args, **kwargs)[source]

Returns absolute value of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([-1.0, 1.0, 0.0]), mindspore.float32)
>>> abs = P.Abs()
>>> abs(input_x)
[1.0, 1.0, 0.0]
class mindspore.ops.AccumulateNV2(*args, **kwargs)[source]

Computes accumulation of all input tensors element-wise.

AccumulateNV2 is similar to AddN, but there is a significant difference between them: AccumulateNV2 does not wait for all of its inputs to be ready before summing. As a result, AccumulateNV2 can save memory when inputs become ready at different times, since the minimum temporary storage is proportional to the output size rather than the input size.

Inputs:
  • input_x (Union(tuple[Tensor], list[Tensor])) - A tuple or list of tensors with a numeric dtype, to be added together.

Outputs:

Tensor, has the same shape and dtype as each entry of the input_x.

Examples

>>> class NetAccumulateNV2(nn.Cell):
>>>     def __init__(self):
>>>         super(NetAccumulateNV2, self).__init__()
>>>         self.accumulateNV2 = P.AccumulateNV2()
>>>
>>>     def construct(self, *z):
>>>         return self.accumulateNV2(z)
>>>
>>> net = NetAccumulateNV2()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32)
>>> net(input_x, input_y, input_x, input_y)
Tensor([10., 14., 18.], shape=(3,), dtype=mindspore.float32)
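
The memory argument above can be illustrated without MindSpore. The following is a minimal NumPy sketch, not the operator's implementation: `accumulate_incrementally` is a hypothetical helper that adds each input into one running buffer as it arrives, so the temporary storage is a single output-sized array regardless of how many inputs there are.

```python
import numpy as np

def accumulate_incrementally(inputs):
    """Sum same-shaped arrays using one output-sized buffer (illustrative helper)."""
    iterator = iter(inputs)
    # Copy the first input so the caller's array is never modified.
    total = np.array(next(iterator), dtype=np.float32)
    for x in iterator:
        total += x  # in-place add: no additional buffer per input
    return total

input_x = np.array([1, 2, 3], dtype=np.float32)
input_y = np.array([4, 5, 6], dtype=np.float32)
result = accumulate_incrementally([input_x, input_y, input_x, input_y])
print(result)  # [10. 14. 18.]
```
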
class mindspore.ops.Acosh(*args, **kwargs)[source]

Computes inverse hyperbolic cosine of the input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> acosh = P.Acosh()
>>> input_x = Tensor(np.array([1.0, 1.5, 3.0, 100.0]), mindspore.float32)
>>> output = acosh(input_x)
class mindspore.ops.Adam(*args, **kwargs)[source]

Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

The Adam algorithm is proposed in Adam: A Method for Stochastic Optimization.

The updating formulas are as follows,

\[\begin{split}\begin{array}{ll} \\ m = \beta_1 * m + (1 - \beta_1) * g \\ v = \beta_2 * v + (1 - \beta_2) * g * g \\ l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ w = w - l * \frac{m}{\sqrt{v} + \epsilon} \end{array}\end{split}\]

\(m\) represents the 1st moment vector, \(v\) represents the 2nd moment vector, \(g\) represents gradient, \(l\) represents scaling factor lr, \(\beta_1, \beta_2\) represent beta1 and beta2, \(t\) represents updating step while \(\beta_1^t\) and \(\beta_2^t\) represent beta1_power and beta2_power, \(\alpha\) represents learning_rate, \(w\) represents var, \(\epsilon\) represents epsilon.

Parameters
  • use_locking (bool) – Whether to enable a lock to protect variable tensors from being updated. If true, updates of the var, m, and v tensors will be protected by a lock. If false, the result is unpredictable. Default: False.

  • use_nesterov (bool) – Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. If true, update the gradients using NAG. If false, update the gradients without using NAG. Default: False.

Inputs:
  • var (Tensor) - Weights to be updated.

  • m (Tensor) - The 1st moment vector in the updating formula, has the same type as var.

  • v (Tensor) - The 2nd moment vector in the updating formula. Mean square gradients with the same type as var.

  • beta1_power (float) - \(\beta_1^t\) in the updating formula.

  • beta2_power (float) - \(\beta_2^t\) in the updating formula.

  • lr (float) - \(l\) in the updating formula.

  • beta1 (float) - The exponential decay rate for the 1st moment estimations.

  • beta2 (float) - The exponential decay rate for the 2nd moment estimations.

  • epsilon (float) - Term added to the denominator to improve numerical stability.

  • gradient (Tensor) - Gradient, has the same type as var.

Outputs:

Tuple of 3 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • m (Tensor) - The same shape and data type as m.

  • v (Tensor) - The same shape and data type as v.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_adam = P.Adam()
>>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
>>>         self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
>>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
>>>         out = self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
>>>                               epsilon, grad)
>>>         return out
>>> net = Net()
>>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient)
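
To make the updating formulas above concrete, here is a minimal NumPy sketch of a single step. It follows the four formulas literally, taking \(\alpha\) as an argument and computing the scaling factor \(l\) from it; `adam_step` and `scaled_lr` are illustrative names, and the sketch makes no claim about how the operator maps its own `lr` input onto these symbols or about `use_locking` and `use_nesterov`.

```python
import numpy as np

def adam_step(var, m, v, beta1_power, beta2_power, alpha, beta1, beta2, epsilon, grad):
    """One Adam update written from the documented formulas (illustrative only)."""
    m = beta1 * m + (1 - beta1) * grad                # 1st moment
    v = beta2 * v + (1 - beta2) * grad * grad         # 2nd moment
    scaled_lr = alpha * np.sqrt(1 - beta2_power) / (1 - beta1_power)  # l in the formula
    var = var - scaled_lr * m / (np.sqrt(v) + epsilon)
    return var, m, v

var = np.ones((3, 3, 3))
m = np.ones((3, 3, 3))
v = np.ones((3, 3, 3))
grad = np.full((3, 3, 3), 0.5)
new_var, new_m, new_v = adam_step(var, m, v, 0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, grad)
print(new_var[0, 0, 0], new_m[0, 0, 0], new_v[0, 0, 0])
```
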
class mindspore.ops.AddN(*args, **kwargs)[source]

Computes addition of all input tensors element-wise.

All input tensors must have the same shape.

Inputs:
  • input_x (Union(tuple[Tensor], list[Tensor])) - A tuple or list of tensors with a numeric or bool dtype, to be added together.

Outputs:

Tensor, has the same shape and dtype as each entry of the input_x.

Examples

>>> class NetAddN(nn.Cell):
>>>     def __init__(self):
>>>         super(NetAddN, self).__init__()
>>>         self.addN = P.AddN()
>>>
>>>     def construct(self, *z):
>>>         return self.addN(z)
>>>
>>> net = NetAddN()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32)
>>> net(input_x, input_y, input_x, input_y)
[10.0, 14.0, 18.0]
class mindspore.ops.AiCPURegOp(op_name)[source]

Class for registering AiCPU op information.

attr(name=None, value_type=None, value=None, **kwargs)[source]

Register AiCPU op attribute information.

Parameters
  • name (str) – Name of the attribute. Default: None.

  • value_type (str) – Value type of the attribute. Default: None.

  • value (str) – Value of the attribute. Default: None.

  • kwargs (dict) – Other information of the attribute.

input(index=None, name=None, param_type=None, **kwargs)[source]

Register AiCPU op input information.

Parameters
  • index (int) – Order of the input. Default: None.

  • name (str) – Name of the input. Default: None.

  • param_type (str) – Param type of the input. Default: None.

  • kwargs (dict) – Other information of the input.

output(index=None, name=None, param_type=None, **kwargs)[source]

Register AiCPU op output information.

Parameters
  • index (int) – Order of the output. Default: None.

  • name (str) – Name of the output. Default: None.

  • param_type (str) – Param type of the output. Default: None.

  • kwargs (dict) – Other information of the output.

class mindspore.ops.AllGather(*args, **kwargs)[source]

Gathers tensors from the specified communication group.

Note

The tensors must have the same shape and format in all processes of the collection.

Parameters

group (str) – The communication group to work on. Default: “hccl_world_group”.

Raises
  • TypeError – If group is not a string.

  • ValueError – If the local rank id of the calling process in the group is larger than the group’s rank size.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor. If the number of devices in the group is N, then the shape of output is \((N, x_1, x_2, ..., x_R)\).

Examples

>>> import numpy as np
>>> import mindspore.ops.operations as P
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>>
>>> init()
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.allgather = P.AllGather(group="nccl_world_group")
>>>
>>>     def construct(self, x):
>>>         return self.allgather(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
class mindspore.ops.AllReduce(*args, **kwargs)[source]

Reduces the tensor data across all devices in such a way that all devices will get the same final result.

Note

The operation of AllReduce does not support “prod” currently. The tensors must have the same shape and format in all processes of the collection.

Parameters
  • op (str) – Specifies an operation used for element-wise reductions, like sum, max, and min. Default: ReduceOp.SUM.

  • group (str) – The communication group to work on. Default: “hccl_world_group”.

Raises
  • TypeError – If op or group is not a string, or fusion is not an integer, or the input’s dtype is bool.

  • ValueError – If the operation is “prod”.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as the input, i.e., \((x_1, x_2, ..., x_R)\). The contents depend on the specified operation.

Examples

>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> from mindspore.ops.operations.comm_ops import ReduceOp
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init()
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.allreduce_sum = P.AllReduce(ReduceOp.SUM, group="nccl_world_group")
>>>
>>>     def construct(self, x):
>>>         return self.allreduce_sum(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
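
The example above needs an initialized communication group to run. As a rough single-process illustration of the semantics described here, the following NumPy sketch reduces a list of per-device arrays element-wise and hands every "device" the same result. `simulated_all_reduce` is a hypothetical helper, not part of MindSpore; it rejects "prod" only to mirror the restriction stated in the note.

```python
import numpy as np

def simulated_all_reduce(per_device_tensors, op="sum"):
    """Element-wise reduction across simulated devices (illustrative helper)."""
    reducers = {"sum": np.sum, "max": np.max, "min": np.min}
    if op not in reducers:
        raise ValueError(f"unsupported op: {op}")
    stacked = np.stack(per_device_tensors)      # shape (N, x_1, ..., x_R)
    reduced = reducers[op](stacked, axis=0)     # shape (x_1, ..., x_R)
    # Every device receives its own copy of the same reduced tensor.
    return [reduced.copy() for _ in per_device_tensors]

# Four simulated devices, each holding a (2, 8) tensor filled with rank + 1.
device_inputs = [np.full((2, 8), rank + 1, dtype=np.float32) for rank in range(4)]
outputs = simulated_all_reduce(device_inputs, op="sum")
print(outputs[0][0, 0], outputs[0].shape)
```
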
class mindspore.ops.ApplyAdaMax(*args, **kwargs)[source]

Updates relevant entries according to the adamax scheme.

The updating formulas are as follows,

\[\begin{split}\begin{array}{ll} \\ m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\ v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\ var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon} \end{array}\end{split}\]

\(t\) represents the updating step, \(m\) represents the 1st moment vector, \(m_{t-1}\) is the value of \(m_{t}\) at the previous step, \(v\) represents the 2nd moment vector, \(v_{t-1}\) is the value of \(v_{t}\) at the previous step, \(l\) represents the scaling factor lr, \(g\) represents grad, \(\beta_1, \beta_2\) represent beta1 and beta2, \(\beta_1^t\) represents beta1_power, \(var\) represents the variable to be updated, and \(\epsilon\) represents epsilon.

Inputs of var, m, v and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Inputs:
  • var (Parameter) - Variable to be updated. With float32 or float16 data type.

  • m (Parameter) - The 1st moment vector in the updating formula, has the same shape and type as var. With float32 or float16 data type.

  • v (Parameter) - The 2nd moment vector in the updating formula. Mean square gradients with the same shape and type as var. With float32 or float16 data type.

  • beta1_power (Union[Number, Tensor]) - \(\beta_1^t\) in the updating formula, must be scalar. With float32 or float16 data type.

  • lr (Union[Number, Tensor]) - Learning rate, \(l\) in the updating formula, must be scalar. With float32 or float16 data type.

  • beta1 (Union[Number, Tensor]) - The exponential decay rate for the 1st moment estimations, must be scalar. With float32 or float16 data type.

  • beta2 (Union[Number, Tensor]) - The exponential decay rate for the 2nd moment estimations, must be scalar. With float32 or float16 data type.

  • epsilon (Union[Number, Tensor]) - A small value added for numerical stability, must be scalar. With float32 or float16 data type.

  • grad (Tensor) - A tensor for gradient, has the same shape and type as var. With float32 or float16 data type.

Outputs:

Tuple of 3 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • m (Tensor) - The same shape and data type as m.

  • v (Tensor) - The same shape and data type as v.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_ada_max = P.ApplyAdaMax()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
>>>         self.v = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="v")
>>>     def construct(self, beta1_power, lr, beta1, beta2, epsilon, grad):
>>>         out = self.apply_ada_max(self.var, self.m, self.v, beta1_power, lr, beta1, beta2, epsilon, grad)
>>>         return out
>>> net = Net()
>>> beta1_power = Tensor(0.9, mstype.float32)
>>> lr = Tensor(0.001, mstype.float32)
>>> beta1 = Tensor(0.9, mstype.float32)
>>> beta2 = Tensor(0.99, mstype.float32)
>>> epsilon = Tensor(1e-10, mstype.float32)
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> result = net(beta1_power, lr, beta1, beta2, epsilon, grad)
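
As a cross-check on the AdaMax formulas above, the following is a minimal NumPy sketch of one update step. `adamax_step` is an illustrative name, and the sketch only transcribes the three documented equations; it is not the operator's implementation.

```python
import numpy as np

def adamax_step(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad):
    """One AdaMax update written from the documented formulas (illustrative only)."""
    m = beta1 * m + (1 - beta1) * grad              # 1st moment
    v = np.maximum(beta2 * v, np.abs(grad))         # infinity-norm style 2nd moment
    var = var - lr / (1 - beta1_power) * m / (v + epsilon)
    return var, m, v

var = np.ones(2)
m = np.zeros(2)
v = np.zeros(2)
grad = np.array([1.0, -2.0])
new_var, new_m, new_v = adamax_step(var, m, v, beta1_power=0.9, lr=0.001,
                                    beta1=0.9, beta2=0.99, epsilon=1e-10, grad=grad)
print(new_var, new_m, new_v)
```
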
class mindspore.ops.ApplyAdadelta(*args, **kwargs)[source]

Updates relevant entries according to the adadelta scheme.

\[accum = \rho * accum + (1 - \rho) * grad^2\]
\[\text{update} = \sqrt{\text{accum_update} + \epsilon} * \frac{grad}{\sqrt{accum + \epsilon}}\]
\[\text{accum_update} = \rho * \text{accum_update} + (1 - \rho) * update^2\]
\[var -= lr * update\]

Inputs of var, accum, accum_update and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Inputs:
  • var (Parameter) - Weights to be updated. With float32 or float16 data type.

  • accum (Parameter) - Accumulation to be updated, has the same shape and type as var. With float32 or float16 data type.

  • accum_update (Parameter) - Accum_update to be updated, has the same shape and type as var. With float32 or float16 data type.

  • lr (Union[Number, Tensor]) - Learning rate, must be scalar. With float32 or float16 data type.

  • rho (Union[Number, Tensor]) - Decay rate, must be scalar. With float32 or float16 data type.

  • epsilon (Union[Number, Tensor]) - A small value added for numerical stability, must be scalar. With float32 or float16 data type.

  • grad (Tensor) - Gradients, has the same shape and type as var. With float32 or float16 data type.

Outputs:

Tuple of 3 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

  • accum_update (Tensor) - The same shape and data type as accum_update.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_adadelta = P.ApplyAdadelta()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.accum_update = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum_update")
>>>     def construct(self, lr, rho, epsilon, grad):
>>>         out = self.apply_adadelta(self.var, self.accum, self.accum_update, lr, rho, epsilon, grad)
>>>         return out
>>> net = Net()
>>> lr = Tensor(0.001, mstype.float32)
>>> rho = Tensor(0.0, mstype.float32)
>>> epsilon = Tensor(1e-6, mstype.float32)
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> result = net(lr, rho, epsilon, grad)
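
The four Adadelta equations above depend on each other in order: `update` uses the new `accum` but the old `accum_update`. A minimal NumPy sketch, assuming exactly that ordering (`adadelta_step` is an illustrative name, not a MindSpore function):

```python
import numpy as np

def adadelta_step(var, accum, accum_update, lr, rho, epsilon, grad):
    """One Adadelta update written from the documented formulas (illustrative only)."""
    accum = rho * accum + (1 - rho) * grad ** 2
    # Uses the freshly updated accum and the previous accum_update.
    update = np.sqrt(accum_update + epsilon) * grad / np.sqrt(accum + epsilon)
    accum_update = rho * accum_update + (1 - rho) * update ** 2
    var = var - lr * update
    return var, accum, accum_update

rng = np.random.default_rng(0)
var, accum, accum_update, grad = (rng.random((3, 3)) for _ in range(4))
new_var, new_accum, new_accum_update = adadelta_step(
    var, accum, accum_update, lr=0.001, rho=0.9, epsilon=1e-6, grad=grad)
print(new_var.shape, new_accum.shape, new_accum_update.shape)
```
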
class mindspore.ops.ApplyAdagrad(*args, **kwargs)[source]

Updates relevant entries according to the adagrad scheme.

\[accum += grad * grad\]
\[var -= lr * grad * \frac{1}{\sqrt{accum}}\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Parameters

update_slots (bool) – If True, accum will be updated. Default: True.

Inputs:
  • var (Parameter) - Variable to be updated. With float32 or float16 data type.

  • accum (Parameter) - Accumulation to be updated. The shape and dtype must be the same as var. With float32 or float16 data type.

  • lr (Union[Number, Tensor]) - The learning rate value, must be scalar. With float32 or float16 data type.

  • grad (Tensor) - A tensor for gradient. The shape and dtype must be the same as var. With float32 or float16 data type.

Outputs:

Tuple of 2 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_adagrad = P.ApplyAdagrad()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>     def construct(self, lr, grad):
>>>         out = self.apply_adagrad(self.var, self.accum, lr, grad)
>>>         return out
>>> net = Net()
>>> lr = Tensor(0.001, mstype.float32)
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> result = net(lr, grad)
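
The Adagrad formulas and the `update_slots` flag can be sketched in NumPy as follows. This is an illustration only: `adagrad_step` is a hypothetical name, and the behavior for `update_slots=False` (leave `accum` untouched but still use it to scale the step) is an assumption drawn from the parameter description.

```python
import numpy as np

def adagrad_step(var, accum, lr, grad, update_slots=True):
    """One Adagrad update written from the documented formulas (illustrative only)."""
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / np.sqrt(accum)
    return var, accum

var = np.array([1.0, 1.0])
accum = np.array([1.0, 1.0])
grad = np.array([2.0, 0.5])
new_var, new_accum = adagrad_step(var, accum, lr=0.1, grad=grad)
print(new_var, new_accum)
```
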
class mindspore.ops.ApplyAdagradV2(*args, **kwargs)[source]

Updates relevant entries according to the adagradv2 scheme.

\[accum += grad * grad\]
\[var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Parameters
  • epsilon (float) – A small value added for numerical stability.

  • update_slots (bool) – If True, accum will be updated. Default: True.

Inputs:
  • var (Parameter) - Variable to be updated. With float16 or float32 data type.

  • accum (Parameter) - Accumulation to be updated. The shape and dtype must be the same as var. With float16 or float32 data type.

  • lr (Union[Number, Tensor]) - The learning rate value, must be a float number or a scalar tensor with float16 or float32 data type.

  • grad (Tensor) - A tensor for gradient. The shape and dtype must be the same as var. With float16 or float32 data type.

Outputs:

Tuple of 2 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6)
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>     def construct(self, lr, grad):
>>>         out = self.apply_adagrad_v2(self.var, self.accum, lr, grad)
>>>         return out
>>> net = Net()
>>> lr = Tensor(0.001, mstype.float32)
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> result = net(lr, grad)
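
The only difference from the plain Adagrad formulas is the \(\epsilon\) added to the denominator. A minimal NumPy sketch (with `adagrad_v2_step` as an illustrative name, not the operator itself) shows the effect on an entry whose accumulator and gradient are both zero, where the formula without \(\epsilon\) would evaluate 0/0:

```python
import numpy as np

def adagrad_v2_step(var, accum, lr, grad, epsilon, update_slots=True):
    """One AdagradV2 update written from the documented formulas (illustrative only)."""
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / (np.sqrt(accum) + epsilon)
    return var, accum

var = np.array([1.0, 1.0])
accum = np.array([0.0, 0.0])
grad = np.array([0.0, 2.0])   # first entry: zero gradient and zero accumulator
new_var, new_accum = adagrad_v2_step(var, accum, lr=0.1, grad=grad, epsilon=1e-6)
print(new_var, new_accum)     # the first entry of var stays finite and unchanged
```
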
class mindspore.ops.ApplyAddSign(*args, **kwargs)[source]

Updates relevant entries according to the AddSign algorithm.

\[\begin{split}\begin{array}{ll} \\ m_{t} = \beta * m_{t-1} + (1 - \beta) * g \\ \text{update} = (\alpha + \text{sign_decay} * sign(g) * sign(m)) * g \\ var = var - lr_{t} * \text{update} \end{array}\end{split}\]

\(t\) represents the updating step, \(m\) represents the 1st moment vector, \(m_{t-1}\) is the value of \(m_{t}\) at the previous step, \(lr\) represents the scaling factor lr, and \(g\) represents grad.

Inputs of var, m and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Inputs:
  • var (Parameter) - Variable tensor to be updated. With float32 or float16 data type.

  • m (Parameter) - Variable tensor to be updated, has the same dtype as var.

  • lr (Union[Number, Tensor]) - The learning rate value, must be a scalar. With float32 or float16 data type.

  • alpha (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.

  • sign_decay (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.

  • beta (Union[Number, Tensor]) - The exponential decay rate, must be a scalar. With float32 or float16 data type.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

Outputs:

Tuple of 2 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • m (Tensor) - The same shape and data type as m.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_add_sign = P.ApplyAddSign()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
>>>         self.lr = 0.001
>>>         self.alpha = 1.0
>>>         self.sign_decay = 0.99
>>>         self.beta = 0.9
>>>     def construct(self, grad):
>>>         out = self.apply_add_sign(self.var, self.m, self.lr, self.alpha, self.sign_decay, self.beta, grad)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
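
The AddSign formulas scale the gradient up when its sign agrees with the moving average and down when it disagrees. Below is a minimal NumPy sketch of one step; `add_sign_step` is an illustrative name, and using the freshly updated \(m_{t}\) inside \(sign(m)\) is an assumption, since the formula does not say which \(m\) is meant.

```python
import numpy as np

def add_sign_step(var, m, lr, alpha, sign_decay, beta, grad):
    """One AddSign update written from the documented formulas (illustrative only)."""
    m = beta * m + (1 - beta) * grad
    # Assumption: sign(m) refers to the updated moment m_t.
    update = (alpha + sign_decay * np.sign(grad) * np.sign(m)) * grad
    var = var - lr * update
    return var, m

var = np.array([1.0, 1.0])
m = np.array([0.0, -10.0])    # second entry: history disagrees with the new gradient
grad = np.array([2.0, 2.0])
new_var, new_m = add_sign_step(var, m, lr=0.1, alpha=1.0, sign_decay=0.5, beta=0.9, grad=grad)
print(new_var, new_m)
```

With these values the first entry takes a step scaled by 1.5 and the second by 0.5.
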
class mindspore.ops.ApplyCenteredRMSProp(*args, **kwargs)[source]

Optimizer that implements the centered RMSProp algorithm. Please refer to the usage in source code of nn.RMSProp.

Note

Update var according to the centered RMSProp algorithm.

\[g_{t} = \rho g_{t-1} + (1 - \rho)\nabla Q_{i}(w)\]
\[s_{t} = \rho s_{t-1} + (1 - \rho)(\nabla Q_{i}(w))^2\]
\[m_{t} = \beta m_{t-1} + \frac{\eta} {\sqrt{s_{t} - g_{t}^2 + \epsilon}} \nabla Q_{i}(w)\]
\[w = w - m_{t}\]

where \(w\) represents var, which will be updated. \(g_{t}\) represents mean_gradient, \(g_{t-1}\) is the value of \(g_{t}\) at the previous step. \(s_{t}\) represents mean_square, \(s_{t-1}\) is the value of \(s_{t}\) at the previous step. \(m_{t}\) represents moment, \(m_{t-1}\) is the value of \(m_{t}\) at the previous step. \(\rho\) represents decay. \(\beta\) is the momentum term, represents momentum. \(\epsilon\) is a smoothing term to avoid division by zero, represents epsilon. \(\eta\) represents learning_rate. \(\nabla Q_{i}(w)\) represents grad.

Parameters

use_locking (bool) – Whether to enable a lock to protect the variable and accumulation tensors from being updated. Default: False.

Inputs:
  • var (Tensor) - Weights to be updated.

  • mean_gradient (Tensor) - Mean gradients, must have the same type as var.

  • mean_square (Tensor) - Mean square gradients, must have the same type as var.

  • moment (Tensor) - Delta of var, must have the same type as var.

  • grad (Tensor) - Gradient, must have the same type as var.

  • learning_rate (Union[Number, Tensor]) - Learning rate. Must be a float number or a scalar tensor with float16 or float32 data type.

  • decay (float) - Decay rate.

  • momentum (float) - Momentum.

  • epsilon (float) - Ridge term.

Outputs:

Tensor, parameters to be updated.

Examples

>>> centered_rms_prop = P.ApplyCenteredRMSProp()
>>> input_x = Tensor(np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> mean_grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> mean_square = Tensor(np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> moment = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> learning_rate = Tensor(0.9, mindspore.float32)
>>> decay = 0.0
>>> momentum = 1e-10
>>> epsilon = 0.05
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
>>>                            learning_rate, decay, momentum, epsilon)
[[[ -6.        -9.024922]
  [-12.049845 -15.074766]
  [-18.09969  -21.124613]]
 [[-24.149532 -27.174456]
  [-30.199379 -33.2243  ]
  [-36.249226 -39.274143]]]
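
The formulas in the note can be checked against the example values with plain NumPy. The sketch below is an illustration rather than the operator's implementation (`centered_rms_prop_step` is a hypothetical name); applied to the same inputs as the example, it yields the same updated var up to floating-point rounding.

```python
import numpy as np

def centered_rms_prop_step(var, mean_gradient, mean_square, moment, grad,
                           learning_rate, decay, momentum, epsilon):
    """One centered RMSProp update written from the documented formulas."""
    mean_gradient = decay * mean_gradient + (1 - decay) * grad      # g_t
    mean_square = decay * mean_square + (1 - decay) * grad ** 2     # s_t
    moment = momentum * moment + learning_rate / np.sqrt(
        mean_square - mean_gradient ** 2 + epsilon) * grad          # m_t
    var = var - moment                                              # w
    return var, mean_gradient, mean_square, moment

var = np.arange(-6, 6, dtype=np.float64).reshape(2, 3, 2)
mean_grad = np.arange(12, dtype=np.float64).reshape(2, 3, 2)
mean_square = np.arange(-8, 4, dtype=np.float64).reshape(2, 3, 2)
moment = np.arange(12, dtype=np.float64).reshape(2, 3, 2)
grad = np.arange(12, dtype=np.float64).reshape(2, 3, 2)

new_var, _, _, _ = centered_rms_prop_step(
    var, mean_grad, mean_square, moment, grad,
    learning_rate=0.9, decay=0.0, momentum=1e-10, epsilon=0.05)
print(new_var)
```
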
class mindspore.ops.ApplyFtrl(*args, **kwargs)[source]

Updates relevant entries according to the FTRL scheme.

Parameters

use_locking (bool) – Use locks for updating operation if true. Default: False.

Inputs:
  • var (Parameter) - The variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - The accumulation to be updated, must be same type and shape as var.

  • linear (Parameter) - The linear coefficient to be updated, must be same type and shape as var.

  • grad (Tensor) - Gradient. The data type must be float16 or float32.

  • lr (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001. It must be a float number or a scalar tensor with float16 or float32 data type.

  • l1 (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero. Default: 0.0. It must be a float number or a scalar tensor with float16 or float32 data type.

  • l2 (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero. Default: 0.0. It must be a float number or a scalar tensor with float16 or float32 data type.

  • lr_power (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.

Outputs:

Tensor, represents the updated var.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class ApplyFtrlNet(nn.Cell):
>>>     def __init__(self):
>>>         super(ApplyFtrlNet, self).__init__()
>>>         self.apply_ftrl = P.ApplyFtrl()
>>>         self.lr = 0.001
>>>         self.l1 = 0.0
>>>         self.l2 = 0.0
>>>         self.lr_power = -0.5
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
>>>
>>>     def construct(self, grad):
>>>         out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2,
>>>                               self.lr_power)
>>>         return out
>>>
>>> net = ApplyFtrlNet()
>>> input_x = Tensor(np.random.randint(-4, 4, (3, 3)), mindspore.float32)
>>> result = net(input_x)
[[0.67455846   0.14630564   0.160499  ]
 [0.16329421   0.00415689   0.05202988]
 [0.18672481   0.17418946   0.36420345]]
class mindspore.ops.ApplyGradientDescent(*args, **kwargs)[source]

Updates relevant entries according to the following formula.

\[var = var - \alpha * \delta\]

Inputs of var and delta comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. A RuntimeError will be raised if data type conversion of a Parameter is required.

Inputs:
  • var (Parameter) - Variable tensor to be updated. With float32 or float16 data type.

  • alpha (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.

  • delta (Tensor) - A tensor for the change, has the same type as var.

Outputs:

Tensor, represents the updated var.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_gradient_descent = P.ApplyGradientDescent()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.alpha = 0.001
>>>     def construct(self, delta):
>>>         out = self.apply_gradient_descent(self.var, self.alpha, delta)
>>>         return out
>>> net = Net()
>>> delta = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(delta)
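
The update rule above is a single subtraction, which a short NumPy sketch makes explicit. Here `gradient_descent_step` is an illustrative name, and the choice of minimizing \(x^2\) (whose gradient is \(2x\)) is only a toy example, not something the operator prescribes.

```python
import numpy as np

def gradient_descent_step(var, alpha, delta):
    """var = var - alpha * delta, as in the documented formula (illustrative only)."""
    return var - alpha * delta

# Toy usage: three steps on f(x) = x**2, where delta is the gradient 2 * x.
x = np.array([1.0])
for _ in range(3):
    x = gradient_descent_step(x, alpha=0.1, delta=2 * x)
print(x)  # each step multiplies x by 0.8, giving roughly 0.512
```
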
class mindspore.ops.ApplyMomentum(*args, **kwargs)[source]

Optimizer that implements the Momentum algorithm.

Refer to the paper On the importance of initialization and momentum in deep learning for more details.

Inputs of variable, accumulation and gradient comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the highest-priority data type among them. Data type conversion of a Parameter is not supported; a RuntimeError will be raised if it is required.

Parameters
  • use_locking (bool) – Whether to enable a lock to protect the variable and accumulation tensors from being updated. Default: False.

  • use_nesterov (bool) – Enable Nesterov momentum. Default: False.

  • gradient_scale (float) – The scale of the gradient. Default: 1.0.

Inputs:
  • variable (Parameter) - Weights to be updated. Data type must be float.

  • accumulation (Parameter) - Accumulated gradient value by moment weight. Has the same data type as variable.

  • learning_rate (Union[Number, Tensor]) - The learning rate value, must be a float number or a scalar tensor with float data type.

  • gradient (Tensor) - Gradient, has the same data type as variable.

  • momentum (Union[Number, Tensor]) - Momentum, must be a float number or a scalar tensor with float data type.

Outputs:

Tensor, parameters to be updated.

Examples

Please refer to the usage in nn.ApplyMomentum.

class mindspore.ops.ApplyPowerSign(*args, **kwargs)[source]

Updates relevant entries according to the PowerSign algorithm.

\[\begin{split}\begin{array}{ll} \\ m_{t} = \beta * m_{t-1} + (1 - \beta) * g \\ \text{update} = \exp(\text{logbase} * \text{sign_decay} * sign(g) * sign(m)) * g \\ var = var - lr_{t} * \text{update} \end{array}\end{split}\]

\(t\) represents the updating step, \(m\) represents the 1st moment vector, \(m_{t-1}\) is the last moment of \(m_{t}\), \(lr\) represents the scaling factor lr, and \(g\) represents grad.

All of inputs comply with the implicit type conversion rules to make the data types consistent. If lr, logbase, sign_decay or beta is a number, the number is automatically converted to Tensor, and the data type is consistent with the Tensor data type involved in the operation. If inputs are tensors and have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • var (Parameter) - Variable tensor to be updated. With float32 or float16 data type. If data type of var is float16, all inputs must have the same data type as var.

  • m (Parameter) - Variable tensor to be updated, has the same dtype as var.

  • lr (Union[Number, Tensor]) - The learning rate value, must be a scalar. With float32 or float16 data type.

  • logbase (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.

  • sign_decay (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.

  • beta (Union[Number, Tensor]) - The exponential decay rate, must be a scalar. With float32 or float16 data type.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

Outputs:

Tuple of 2 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • m (Tensor) - The same shape and data type as m.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_power_sign = P.ApplyPowerSign()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
>>>         self.lr = 0.001
>>>         self.logbase = np.e
>>>         self.sign_decay = 0.99
>>>         self.beta = 0.9
>>>     def construct(self, grad):
>>>         out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase,
>>>                                     self.sign_decay, self.beta, grad)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
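
As a rough illustration of the update rule above, the following plain-Python sketch applies one step to lists of floats. The helper names `sign` and `power_sign_step` are hypothetical and not part of MindSpore. The sketch assumes that \(sign(m)\) in the formula refers to the freshly updated \(m_{t}\), and it ignores data type conversion and in-place Parameter updates.

```python
import math

def sign(v):
    """Return -1, 0 or 1 according to the sign of v."""
    return (v > 0) - (v < 0)

def power_sign_step(var, m, lr, logbase, sign_decay, beta, grad):
    """One element-wise PowerSign step on lists of floats (hypothetical helper)."""
    new_var, new_m = [], []
    for v, m_prev, g in zip(var, m, grad):
        m_t = beta * m_prev + (1 - beta) * g
        # Assumption: sign(m) in the formula is taken on the updated m_t.
        update = math.exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
        new_var.append(v - lr * update)
        new_m.append(m_t)
    return new_var, new_m

var, m = power_sign_step([1.0, 2.0], [0.5, -2.0], lr=0.1, logbase=1.0,
                         sign_decay=1.0, beta=0.5, grad=[1.0, 1.0])
```

When the gradient and the moment agree in sign (first element), the step is scaled up by \(\exp(\text{logbase} * \text{sign_decay})\); when they disagree (second element), it is scaled down by the reciprocal.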
class mindspore.ops.ApplyProximalAdagrad(*args, **kwargs)[source]

Updates relevant entries according to the proximal adagrad algorithm.

\[accum += grad * grad\]
\[\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}\]
\[var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters

use_locking (bool) – If true, the var and accumulation tensors will be protected from being updated. Default: False.

Inputs:
  • var (Parameter) - Variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - Accumulation to be updated. Must have the same shape and dtype as var.

  • lr (Union[Number, Tensor]) - The learning rate value, must be scalar. The data type must be float16 or float32.

  • l1 (Union[Number, Tensor]) - l1 regularization strength, must be scalar. The data type must be float16 or float32.

  • l2 (Union[Number, Tensor]) - l2 regularization strength, must be scalar. The data type must be float16 or float32.

  • grad (Tensor) - Gradient with the same shape and dtype as var.

Outputs:

Tuple of 2 Tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.lr = 0.01
>>>         self.l1 = 0.0
>>>         self.l2 = 0.0
>>>     def construct(self, grad):
>>>         out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
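
The three equations above can be traced by hand with a small plain-Python sketch. The helper `proximal_adagrad_step` is hypothetical and not part of MindSpore; it works on lists of floats and assumes non-negative l1 and l2.

```python
import math

def proximal_adagrad_step(var, accum, lr, l1, l2, grad):
    """One element-wise proximal Adagrad step on lists of floats (hypothetical helper)."""
    new_var, new_accum = [], []
    for v, a, g in zip(var, accum, grad):
        a = a + g * g                              # accum += grad * grad
        prox_v = v - lr * g / math.sqrt(a)         # step before regularization
        shrunk = max(abs(prox_v) - lr * l1, 0.0)   # soft-threshold by lr * l1
        new_var.append(math.copysign(1.0, prox_v) / (1 + lr * l2) * shrunk)
        new_accum.append(a)
    return new_var, new_accum

new_var, new_accum = proximal_adagrad_step([1.0], [3.0], lr=0.5, l1=0.1,
                                           l2=1.0, grad=[1.0])
```

Here accum becomes 4.0, prox_v is 0.75, the l1 term shrinks it to 0.7, and the l2 term divides by 1.5.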
class mindspore.ops.ApplyProximalGradientDescent(*args, **kwargs)[source]

Updates relevant entries according to the FOBOS (Forward Backward Splitting) algorithm.

\[\text{prox_v} = var - \alpha * \delta\]
\[var = \frac{sign(\text{prox_v})}{1 + \alpha * l2} * \max(\left| \text{prox_v} \right| - \alpha * l1, 0)\]

Inputs of var and delta comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • var (Parameter) - Variable tensor to be updated. With float32 or float16 data type.

  • alpha (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.

  • l1 (Union[Number, Tensor]) - l1 regularization strength, must be scalar. With float32 or float16 data type.

  • l2 (Union[Number, Tensor]) - l2 regularization strength, must be scalar. With float32 or float16 data type.

  • delta (Tensor) - A tensor for the change, has the same type as var.

Outputs:

Tensor, represents the updated var.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.apply_proximal_gradient_descent = P.ApplyProximalGradientDescent()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.alpha = 0.001
>>>         self.l1 = 0.0
>>>         self.l2 = 0.0
>>>     def construct(self, delta):
>>>         out = self.apply_proximal_gradient_descent(self.var, self.alpha, self.l1, self.l2, delta)
>>>         return out
>>> net = Net()
>>> delta = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(delta)
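
To make the role of l1 and l2 concrete, here is a minimal plain-Python sketch of the two equations above. The helper `proximal_gd_step` is hypothetical and not part of MindSpore; it operates on lists of floats.

```python
def proximal_gd_step(var, alpha, l1, l2, delta):
    """One element-wise FOBOS step on lists of floats (hypothetical helper)."""
    out = []
    for v, d in zip(var, delta):
        prox_v = v - alpha * d
        sign = (prox_v > 0) - (prox_v < 0)
        out.append(sign / (1 + alpha * l2) * max(abs(prox_v) - alpha * l1, 0.0))
    return out

# With l1 = l2 = 0 the step reduces to plain gradient descent: var - alpha * delta.
plain = proximal_gd_step([1.0, -2.0], alpha=0.5, l1=0.0, l2=0.0, delta=[1.0, 1.0])
# A larger l1 shrinks small entries to exactly zero.
sparse = proximal_gd_step([1.0, -2.0], alpha=0.5, l1=1.5, l2=0.0, delta=[1.0, 1.0])
```

In the second call the threshold alpha * l1 is 0.75, so the entry with |prox_v| = 0.5 becomes 0 while the entry with |prox_v| = 2.5 is only reduced in magnitude.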
class mindspore.ops.ApplyRMSProp(*args, **kwargs)[source]

Optimizer that implements the Root Mean Square Propagation (RMSProp) algorithm. Please refer to its usage in the source code of nn.RMSProp.

Note

Update var according to the RMSProp algorithm.

\[s_{t} = \rho s_{t-1} + (1 - \rho)(\nabla Q_{i}(w))^2\]
\[m_{t} = \beta m_{t-1} + \frac{\eta} {\sqrt{s_{t} + \epsilon}} \nabla Q_{i}(w)\]
\[w = w - m_{t}\]

where \(w\) represents var, which will be updated. \(s_{t}\) represents mean_square, \(s_{t-1}\) is the last moment of \(s_{t}\), \(m_{t}\) represents moment, \(m_{t-1}\) is the last moment of \(m_{t}\). \(\rho\) represents decay. \(\beta\) is the momentum term, represents momentum. \(\epsilon\) is a smoothing term to avoid division by zero, represents epsilon. \(\eta\) represents learning_rate. \(\nabla Q_{i}(w)\) represents grad.

Parameters

use_locking (bool) – Whether to enable a lock to protect the variable and accumulation tensors from being updated. Default: False.

Inputs:
  • var (Tensor) - Weights to be updated.

  • mean_square (Tensor) - Mean square gradients, must have the same type as var.

  • moment (Tensor) - Delta of var, must have the same type as var.

  • learning_rate (Union[Number, Tensor]) - Learning rate. Must be a float number or a scalar tensor with float16 or float32 data type.

  • grad (Tensor) - Gradient, must have the same type as var.

  • decay (float) - Decay rate. Only constant value is allowed.

  • momentum (float) - Momentum. Only constant value is allowed.

  • epsilon (float) - Ridge term. Only constant value is allowed.

Outputs:

Tensor, parameters to be updated.

Examples

>>> apply_rms = P.ApplyRMSProp()
>>> input_x = Tensor(1., mindspore.float32)
>>> mean_square = Tensor(2., mindspore.float32)
>>> moment = Tensor(1., mindspore.float32)
>>> grad = Tensor(2., mindspore.float32)
>>> learning_rate = Tensor(0.9, mindspore.float32)
>>> decay = 0.0
>>> momentum = 1e-10
>>> epsilon = 0.001
>>> result = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon)
(-2.9977674, 0.80999994, 1.9987665)
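
The three RMSProp equations in the note can be followed step by step with a scalar plain-Python sketch. The helper `rmsprop_step` is hypothetical and not part of MindSpore; it only mirrors the formulas as written, uses the same argument order as the operator, and its input values are chosen for easy arithmetic rather than taken from the example above.

```python
import math

def rmsprop_step(var, mean_square, moment, learning_rate, grad,
                 decay, momentum, epsilon):
    """One scalar RMSProp step following the three equations (hypothetical helper)."""
    mean_square = decay * mean_square + (1 - decay) * grad * grad
    moment = momentum * moment + learning_rate / math.sqrt(mean_square + epsilon) * grad
    var = var - moment
    return var, mean_square, moment

var, mean_square, moment = rmsprop_step(var=1.0, mean_square=4.0, moment=0.0,
                                        learning_rate=0.1, grad=2.0,
                                        decay=0.5, momentum=0.0, epsilon=0.0)
```

With these inputs mean_square stays at 4.0, moment becomes 0.1 / 2 * 2 = 0.1, and var moves from 1.0 to 0.9.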
class mindspore.ops.ApproximateEqual(*args, **kwargs)[source]

Returns true if abs(x1-x2) is smaller than tolerance element-wise, otherwise false.

Inputs of x1 and x2 comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters

tolerance (float) – The maximum deviation for which two elements can be considered equal. Default: 1e-05.

Inputs:
  • x1 (Tensor) - A tensor. Must be one of the following types: float32, float16.

  • x2 (Tensor) - A tensor of the same type and shape as ‘x1’.

Outputs:

Tensor, the shape is the same as the shape of ‘x1’, and the data type is bool.

Examples

>>> x1 = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> x2 = Tensor(np.array([2, 4, 6]), mindspore.float32)
>>> approximate_equal = P.ApproximateEqual(2.)
>>> result = approximate_equal(x1, x2)
[True  False  False]
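
The comparison can be sketched in plain Python as follows. The helper `approximate_equal` is hypothetical and not part of MindSpore; it follows the wording above, where the difference must be strictly smaller than tolerance.

```python
def approximate_equal(x1, x2, tolerance=1e-05):
    """Element-wise abs(x1 - x2) < tolerance on equal-length lists (hypothetical helper)."""
    return [abs(a - b) < tolerance for a, b in zip(x1, x2)]

result = approximate_equal([1.0, 2.0, 3.0], [1.5, 4.0, 3.0], tolerance=2.0)
```

Under this strict reading, a pair whose difference equals tolerance exactly (the middle pair here) is not considered approximately equal.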
class mindspore.ops.ArgMaxWithValue(*args, **kwargs)[source]

Calculates the maximum value with the corresponding index.

Calculates the maximum value along with the given axis for the input tensor. It returns the maximum values and indices.

Note

In auto_parallel and semi_auto_parallel mode, the first output index cannot be used.

Parameters
  • axis (int) – The dimension to reduce. Default: 0.

  • keep_dims (bool) – Whether to keep the reduced dimension. If true, the output keeps the same number of dimensions as the input; if false, the reduced dimension is removed. Default: False.

Inputs:
  • input_x (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as \((x_1, x_2, ..., x_N)\).

Outputs:

Tuple of 2 Tensors, containing the corresponding index and the maximum value of the input tensor.

  • index (Tensor) - The index for the maximum value of the input tensor. If keep_dims is true, the shape of output tensors is \((x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)\). Otherwise, the shape is \((x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)\).

  • output_x (Tensor) - The maximum value of input tensor, with the same shape as index.

Examples

>>> input_x = Tensor(np.random.rand(5), mindspore.float32)
>>> index, output = P.ArgMaxWithValue()(input_x)
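
The effect of axis and keep_dims on the output shapes can be illustrated with a plain-Python sketch for 2-D nested lists. The helper `argmax_with_value` is hypothetical and not part of MindSpore; it only handles axis 0 or 1 and returns the first index on ties, which is an assumption about tie handling.

```python
def argmax_with_value(rows, axis=0, keep_dims=False):
    """Index and value of the maximum along `axis` of a 2-D list (hypothetical helper)."""
    # Reducing along axis 0 walks the columns; along axis 1 it walks the rows.
    lines = [list(col) for col in zip(*rows)] if axis == 0 else rows
    index = [max(range(len(line)), key=line.__getitem__) for line in lines]
    value = [line[i] for line, i in zip(lines, index)]
    if keep_dims:
        # Keep the reduced axis with length 1 so the rank matches the input.
        if axis == 0:
            return [index], [value]
        return [[i] for i in index], [[v] for v in value]
    return index, value

x = [[1.0, 5.0, 3.0],
     [4.0, 2.0, 6.0]]
index, value = argmax_with_value(x, axis=1)
```

For the (2, 3) input, reducing along axis 1 gives outputs of shape (2,), or (2, 1) with keep_dims=True; reducing along axis 0 gives (3,), or (1, 3) with keep_dims=True.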
class mindspore.ops.ArgMinWithValue(*args, **kwargs)[source]

Calculates the minimum value with corresponding index, return indices and values.

Calculates the minimum value along with the given axis for the input tensor. It returns the minimum values and indices.

Note

In auto_parallel and semi_auto_parallel mode, the first output index cannot be used.

Parameters
  • axis (int) – The dimension to reduce. Default: 0.

  • keep_dims (bool) – Whether to keep the reduced dimension. If true, the output keeps the same number of dimensions as the input; if false, the reduced dimension is removed. Default: False.

Inputs:
  • input_x (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as \((x_1, x_2, ..., x_N)\).

Outputs:

Tuple of 2 Tensors, containing the corresponding index and the minimum value of the input tensor.

  • index (Tensor) - The index for the minimum value of the input tensor. If keep_dims is true, the shape of output tensors is \((x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)\). Otherwise, the shape is \((x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)\).

  • output_x (Tensor) - The minimum value of input tensor, with the same shape as index.

Examples

>>> input_x = Tensor(np.random.rand(5), mindspore.float32)
>>> index, output = P.ArgMinWithValue()(input_x)
class mindspore.ops.Argmax(*args, **kwargs)[source]

Returns the indices of the max value of a tensor across the axis.

If the shape of input tensor is \((x_1, ..., x_N)\), the shape of the output tensor will be \((x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)\).

Parameters
  • axis (int) – Axis where the Argmax operation applies to. Default: -1.

  • output_type (mindspore.dtype) – An optional data type of mindspore.dtype.int32. Default: mindspore.dtype.int32.

Inputs:
  • input_x (Tensor) - Input tensor.

Outputs:

Tensor, indices of the max value of input tensor across the axis.

Examples

>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
>>> index = P.Argmax(output_type=mindspore.int32)(input_x)
1
class mindspore.ops.Argmin(*args, **kwargs)[source]

Returns the indices of the min value of a tensor across the axis.

If the shape of input tensor is \((x_1, ..., x_N)\), the shape of the output tensor is \((x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)\).

Parameters
  • axis (int) – Axis where the Argmin operation applies to. Default: -1.

  • output_type (mindspore.dtype) – An optional data type of mindspore.dtype.int32. Default: mindspore.dtype.int32.

Inputs:
  • input_x (Tensor) - Input tensor.

Outputs:

Tensor, indices of the min value of input tensor across the axis.

Examples

>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
>>> index = P.Argmin()(input_x)
>>> assert index == Tensor(2, mindspore.int32)
class mindspore.ops.Asin(*args, **kwargs)[source]

Computes arcsine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> asin = P.Asin()
>>> input_x = Tensor(np.array([0.74, 0.04, 0.30, 0.56]), mindspore.float32)
>>> output = asin(input_x)
[0.8331, 0.0400, 0.3047, 0.5944]
class mindspore.ops.Asinh(*args, **kwargs)[source]

Compute inverse hyperbolic sine of the input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> asinh = P.Asinh()
>>> input_x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), mindspore.float32)
>>> output = asinh(input_x)
[-2.3212, 1.1976, 1.8184, 5.2983]
class mindspore.ops.Assert(*args, **kwargs)[source]

Asserts that the given condition is true. If the input condition evaluates to false, prints the list of tensors in input_data.

Parameters

summarize (int) – Print this many entries of each tensor.

Inputs:
  • condition (Union[Tensor[bool], bool]) - The condition to evaluate.

  • input_data (Union(tuple[Tensor], list[Tensor])) - The tensors to print out when condition is false.

Examples

>>> class AssertDemo(nn.Cell):
>>>     def __init__(self):
>>>         super(AssertDemo, self).__init__()
>>>         self.assert1 = P.Assert(summarize=10)
>>>         self.add = P.TensorAdd()
>>>
>>>     def construct(self, x, y):
>>>         data = self.add(x, y)
>>>         self.assert1(True, [data])
>>>         return data
class mindspore.ops.Assign(*args, **kwargs)[source]

Assigns Parameter with a value.

Inputs of variable and value comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • variable (Parameter) - The Parameter.

  • value (Tensor) - The value to be assigned.

Outputs:

Tensor, has the same type as original variable.

Examples

>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.y = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="y")
>>>
>>>     def construct(self, x):
>>>         P.Assign()(self.y, x)
>>>         return x
>>> x = Tensor([2.0], mindspore.float32)
>>> net = Net()
>>> net(x)
class mindspore.ops.AssignAdd(*args, **kwargs)[source]

Updates a Parameter by adding a value to it.

Inputs of variable and value comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. If value is a number, the number is automatically converted to Tensor, and the data type is consistent with the Tensor data type involved in the operation. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • variable (Parameter) - The Parameter.

  • value (Union[numbers.Number, Tensor]) - The value to be added to the variable. It must have the same shape as variable if it is a Tensor.

Examples

>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.AssignAdd = P.AssignAdd()
>>>         self.variable = mindspore.Parameter(initializer(1, [1], mindspore.int64), name="global_step")
>>>
>>>     def construct(self, x):
>>>         self.AssignAdd(self.variable, x)
>>>         return self.variable
>>>
>>> net = Net()
>>> value = Tensor(np.ones([1]).astype(np.int64)*100)
>>> net(value)
class mindspore.ops.AssignSub(*args, **kwargs)[source]

Updates a Parameter by subtracting a value from it.

Inputs of variable and value comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. If value is a number, the number is automatically converted to Tensor, and the data type is consistent with the Tensor data type involved in the operation. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • variable (Parameter) - The Parameter.

  • value (Union[numbers.Number, Tensor]) - The value to be subtracted from the variable. It must have the same shape as variable if it is a Tensor.

Examples

>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.AssignSub = P.AssignSub()
>>>         self.variable = mindspore.Parameter(initializer(1, [1], mindspore.int32), name="global_step")
>>>
>>>     def construct(self, x):
>>>         self.AssignSub(self.variable, x)
>>>         return self.variable
>>>
>>> net = Net()
>>> value = Tensor(np.ones([1]).astype(np.int32)*100)
>>> net(value)
class mindspore.ops.Atan(*args, **kwargs)[source]

Computes the trigonometric inverse tangent of the input element-wise.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

A Tensor, has the same type as the input.

Examples

>>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
>>> tan = P.Tan()
>>> output_y = tan(input_x)
>>> atan = P.Atan()
>>> atan(output_y)
[[1.047, 0.7850001]]
class mindspore.ops.Atan2(*args, **kwargs)[source]

Returns arctangent of input_x/input_y element-wise.

It returns \(\theta\ \in\ [-\pi, \pi]\) such that \(x = r*\sin(\theta), y = r*\cos(\theta)\), where \(r = \sqrt{x^2 + y^2}\).

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Inputs:
  • input_x (Tensor) - The input tensor.

  • input_y (Tensor) - The input tensor.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the same as input_x.

Examples

>>> input_x = Tensor(np.array([[0, 1]]), mindspore.float32)
>>> input_y = Tensor(np.array([[1, 1]]), mindspore.float32)
>>> atan2 = P.Atan2()
>>> atan2(input_x, input_y)
[[0. 0.7853982]]
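
The argument convention can be checked with Python's standard `math.atan2`, which takes the "sine-side" coordinate first, just as input_x comes first here. The helper `atan2_elementwise` is hypothetical and not part of MindSpore; it works on flat lists and does no broadcasting.

```python
import math

def atan2_elementwise(input_x, input_y):
    """Element-wise arctangent of input_x / input_y on flat lists (hypothetical helper)."""
    return [math.atan2(x, y) for x, y in zip(input_x, input_y)]

# Same values as the example above, plus one point in the third quadrant.
angles = atan2_elementwise([0.0, 1.0, -1.0], [1.0, 1.0, -1.0])

# Each angle satisfies x = r * sin(theta) and y = r * cos(theta).
r = math.sqrt(1.0 ** 2 + 1.0 ** 2)
x_back = r * math.sin(angles[1])
y_back = r * math.cos(angles[1])
```

Unlike a plain arctangent of the quotient, the two-argument form keeps the quadrant: the third pair has the same quotient as the second but yields an angle of \(-3\pi/4\) instead of \(\pi/4\).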
class mindspore.ops.Atanh(*args, **kwargs)[source]

Computes inverse hyperbolic tangent of the input element-wise.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

A Tensor, has the same type as the input.

Examples

>>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
>>> atanh = P.Atanh()
>>> atanh(input_x)
[[1.8869909 1.058268]]
class mindspore.ops.AvgPool(*args, **kwargs)[source]

Average pooling operation.

Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes. Typically the input is of shape \((N_{in}, C_{in}, H_{in}, W_{in})\), and AvgPool outputs the regional average in the \((H_{in}, W_{in})\)-dimension. Given kernel size \(ks = (h_{ker}, w_{ker})\) and stride \(s = (s_0, s_1)\), the operation is as follows.

\[\text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1} \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)\]

Parameters
  • ksize (Union[int, tuple[int]]) – The size of kernel used to take the average value, is an int number that represents height and width are both ksize, or a tuple of two int numbers that represent height and width respectively. Default: 1.

  • strides (Union[int, tuple[int]]) – The distance of kernel moving, an int number that represents the height and width of movement are both strides, or a tuple of two int numbers that represent height and width of movement respectively. Default: 1.

  • padding (str) –

    The optional value for pad mode, is “same” or “valid”, not case sensitive. Default: “valid”.

    • same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the last extra padding will be done from the bottom and the right side.

    • valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded.

Inputs:
  • input (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\).

Outputs:

Tensor, with shape \((N, C_{out}, H_{out}, W_{out})\).

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.avgpool_op = P.AvgPool(padding="VALID", ksize=2, strides=1)
>>>
>>>     def construct(self, x):
>>>         result = self.avgpool_op(x)
>>>         return result
>>>
>>> input_x = Tensor(np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4), mindspore.float32)
>>> net = Net()
>>> result = net(input_x)
[[[[ 2.5   3.5   4.5]
   [ 6.5   7.5   8.5]]
  [[ 14.5  15.5  16.5]
   [ 18.5  19.5  20.5]]
  [[ 26.5  27.5  28.5]
   [ 30.5  31.5  32.5]]]]
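
The pooling formula can be reproduced for a single plane with a short plain-Python sketch. The helper `avg_pool_valid` is hypothetical and not part of MindSpore; it covers only the "valid" padding mode and takes one H x W plane as nested lists.

```python
def avg_pool_valid(plane, ksize, strides):
    """'valid' 2D average pooling of one H x W plane (hypothetical helper)."""
    h_ker, w_ker = ksize
    s0, s1 = strides
    out_h = (len(plane) - h_ker) // s0 + 1
    out_w = (len(plane[0]) - w_ker) // s1 + 1
    out = []
    for h in range(out_h):
        row = []
        for w in range(out_w):
            # Gather the h_ker x w_ker window whose top-left corner is (s0*h, s1*w).
            window = [plane[s0 * h + m][s1 * w + n]
                      for m in range(h_ker) for n in range(w_ker)]
            row.append(sum(window) / (h_ker * w_ker))
        out.append(row)
    return out

# First channel of np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4) from the example above.
plane = [[0.0, 1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0, 7.0],
         [8.0, 9.0, 10.0, 11.0]]
pooled = avg_pool_valid(plane, ksize=(2, 2), strides=(1, 1))
```

The result matches the first channel of the output printed above: a 3 x 4 plane pooled with a 2 x 2 kernel and stride 1 yields a 2 x 3 plane.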
class mindspore.ops.BNTrainingReduce(*args, **kwargs)[source]

For the BatchNorm operator, this operator updates the moving averages for training and is used in conjunction with BNTrainingUpdate.

Inputs:
  • x (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape \((N, C, A, B)\).

Outputs:
  • sum (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape \((C,)\).

  • square_sum (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape \((C,)\).

Examples

>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> bn_training_reduce = P.BNTrainingReduce()
>>> output = bn_training_reduce(input_x)
class mindspore.ops.BNTrainingUpdate(*args, **kwargs)[source]

For the BatchNorm operator, this operator updates the moving averages for training and is used in conjunction with BNTrainingReduce.

Parameters
  • isRef (bool) – If a ref. Default: True.

  • epsilon (float) – A small value added to the variance to avoid dividing by zero. Default: 1e-5.

  • factor (float) – A weight for updating the mean and variance. Default: 0.1.

Inputs:
  • x (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape \((N, C, A, B)\).

  • sum (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce. Tensor of shape \((C,)\).

  • square_sum (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce. Tensor of shape \((C,)\).

  • scale (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor. Tensor of shape \((C,)\).

  • offset (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset. Tensor of shape \((C,)\).

  • mean (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape \((C,)\).

  • variance (Tensor) - A 1-D Tensor with float16 or float32, for the update variance. Tensor of shape \((C,)\).

Outputs:
  • y (Tensor) - Tensor, has the same shape and data type as x.

  • mean (Tensor) - Tensor for the updated mean, with float32 data type. Has the same shape as variance.

  • variance (Tensor) - Tensor for the updated variance, with float32 data type. Has the same shape as variance.

  • batch_mean (Tensor) - Tensor for the mean of x, with float32 data type. Has the same shape as variance.

  • batch_variance (Tensor) - Tensor for the mean of variance, with float32 data type. Has the same shape as variance.

Examples

>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> sum = Tensor(np.ones([64]), mindspore.float32)
>>> square_sum = Tensor(np.ones([64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> offset = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> bn_training_update = P.BNTrainingUpdate()
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
class mindspore.ops.BasicLSTMCell(*args, **kwargs)[source]

Applies the long short-term memory (LSTM) to the input.

\[\begin{split}\begin{array}{ll} \\ i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\ \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\ o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\ h_t = o_t * \tanh(c_t) \\ \end{array}\end{split}\]

Here \(\sigma\) is the sigmoid function, and \(*\) is the Hadamard product. \(W, b\) are learnable weights between the output and the input in the formula. For instance, \(W_{ix}, b_{ix}\) are the weight and bias used to transform from input \(x\) to \(i\). Details can be found in paper LONG SHORT-TERM MEMORY and Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling.

Parameters
  • keep_prob (float) – If not 1.0, append Dropout layer on the outputs of each LSTM layer except the last layer. Default: 1.0. The range of dropout is [0.0, 1.0].

  • forget_bias (float) – Add forget bias to forget gate biases in order to decrease former scale. Default: 1.0.

  • state_is_tuple (bool) – If true, the state is a tuple of 2 tensors, containing h and c; if false, the state is a tensor and it needs to be split first. Default: True.

  • activation (str) – Activation. Default: “tanh”. Only “tanh” is currently supported.

Inputs:
  • x (Tensor) - Current words. Tensor of shape (batch_size, input_size). The data type must be float16 or float32.

  • h (Tensor) - Hidden state last moment. Tensor of shape (batch_size, hidden_size). The data type must be float16 or float32.

  • c (Tensor) - Cell state last moment. Tensor of shape (batch_size, hidden_size). The data type must be float16 or float32.

  • w (Tensor) - Weight. Tensor of shape (input_size + hidden_size, 4 x hidden_size). The data type must be float16 or float32.

  • b (Tensor) - Bias. Tensor of shape (4 x hidden_size). The data type must be the same as c.

Outputs:
  • ct (Tensor) - Forward \(c_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

  • ht (Tensor) - Cell output. Tensor of shape (batch_size, hidden_size). With data type of float16.

  • it (Tensor) - Forward \(i_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

  • jt (Tensor) - Forward \(j_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

  • ft (Tensor) - Forward \(f_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

  • ot (Tensor) - Forward \(o_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

  • tanhct (Tensor) - Forward \(\tanh c_t\) cache at moment t. Tensor of shape (batch_size, hidden_size). Has the same type as input c.

Examples

>>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
>>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
>>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
>>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
>>> b = Tensor(np.random.rand(256, ).astype(np.float16))
>>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
>>> lstm(x, h, c, w, b)
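
The gate equations above can be traced with a scalar plain-Python sketch in which input_size and hidden_size are both 1. The helpers `sigmoid` and `lstm_cell_step` and the dictionary layout of the weights are hypothetical and not part of MindSpore; the sketch folds the two biases of each equation into one value and does not model keep_prob or forget_bias.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_cell_step(x, h_prev, c_prev, weights, biases):
    """One scalar LSTM step (hypothetical helper).

    weights[gate] = (w_x, w_h) and biases[gate] = b for gate in "i", "f", "c", "o".
    """
    def pre(gate):
        w_x, w_h = weights[gate]
        return w_x * x + w_h * h_prev + biases[gate]

    i_t = sigmoid(pre("i"))            # input gate
    f_t = sigmoid(pre("f"))            # forget gate
    c_tilde = math.tanh(pre("c"))      # candidate cell state
    o_t = sigmoid(pre("o"))            # output gate
    c_t = f_t * c_prev + i_t * c_tilde
    h_t = o_t * math.tanh(c_t)
    return h_t, c_t

zero_w = {gate: (0.0, 0.0) for gate in "ifco"}
zero_b = {gate: 0.0 for gate in "ifco"}
h_t, c_t = lstm_cell_step(x=1.0, h_prev=0.0, c_prev=2.0, weights=zero_w, biases=zero_b)
```

With all weights and biases at zero every gate evaluates to 0.5 and the candidate state to 0, so the cell state is simply halved (2.0 to 1.0) and h_t equals 0.5 * tanh(1.0).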
class mindspore.ops.BatchMatMul(*args, **kwargs)[source]

Computes matrix multiplication between two tensors by batch:

result[…, :, :] = tensor(a[…, :, :]) * tensor(b[…, :, :]).

The two input tensors must have the same rank and the rank must be not less than 3.

Parameters
  • transpose_a (bool) – If true, the last two dimensions of a are transposed before multiplication. Default: False.

  • transpose_b (bool) – If true, the last two dimensions of b are transposed before multiplication. Default: False.

Inputs:
  • input_x (Tensor) - The first tensor to be multiplied. The shape of the tensor is \((*B, N, C)\), where \(*B\) represents the batch size which can be multidimensional, \(N\) and \(C\) are the size of the last two dimensions. If transpose_a is True, its shape must be \((*B, C, N)\).

  • input_y (Tensor) - The second tensor to be multiplied. The shape of the tensor is \((*B, C, M)\). If transpose_b is True, its shape must be \((*B, M, C)\).

Outputs:

Tensor, the shape of the output tensor is \((*B, N, M)\).

Examples

>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul()
>>> output = batmatmul(input_x, input_y)
>>>
>>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul(transpose_a=True)
>>> output = batmatmul(input_x, input_y)
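
The shape rule \((*B, N, C) \times (*B, C, M) \rightarrow (*B, N, M)\) can be illustrated with a plain-Python sketch for a single batch dimension. The helper `batch_matmul` is hypothetical and not part of MindSpore; it works on 3-D nested lists and does not support the transpose flags.

```python
def batch_matmul(a, b):
    """Multiply matching matrices of two 3-D nested lists (hypothetical helper).

    a has shape (B, N, C) and b has shape (B, C, M); the result has shape (B, N, M).
    """
    return [[[sum(x * y for x, y in zip(row, col)) for col in zip(*mat_b)]
             for row in mat_a]
            for mat_a, mat_b in zip(a, b)]

a = [[[1.0, 2.0, 3.0]],                      # shape (2, 1, 3)
     [[4.0, 5.0, 6.0]]]
b = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],   # shape (2, 3, 2)
     [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]]
out = batch_matmul(a, b)                     # shape (2, 1, 2)
```

Each batch entry is multiplied independently: the first pair gives [[4.0, 5.0]] and the second gives [[15.0, 15.0]].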
class mindspore.ops.BatchNorm(*args, **kwargs)[source]

Batch Normalization for input data and updated parameters.

Batch Normalization is widely used in convolutional neural networks. This operation applies Batch Normalization over input to avoid internal covariate shift as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the features using a mini-batch of data and the learned parameters which can be described in the following formula,

\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

where \(\gamma\) is scale, \(\beta\) is bias, \(\epsilon\) is epsilon.

Parameters
  • is_training (bool) – If is_training is True, mean and variance are computed during training. If is_training is False, they’re loaded from checkpoint during inference. Default: False.

  • epsilon (float) – A small value added for numerical stability. Default: 1e-5.

Inputs:
  • input_x (Tensor) - Tensor of shape \((N, C)\), with float16 or float32 data type.

  • scale (Tensor) - Tensor of shape \((C,)\), with float16 or float32 data type.

  • bias (Tensor) - Tensor of shape \((C,)\), has the same data type with scale.

  • mean (Tensor) - Tensor of shape \((C,)\), with float16 or float32 data type.

  • variance (Tensor) - Tensor of shape \((C,)\), has the same data type with mean.

Outputs:

Tuple of 5 Tensor, the normalized inputs and the updated parameters.

  • output_x (Tensor) - The same type and shape as the input_x. The shape is \((N, C)\).

  • updated_scale (Tensor) - Tensor of shape \((C,)\).

  • updated_bias (Tensor) - Tensor of shape \((C,)\).

  • reserve_space_1 (Tensor) - Tensor of shape \((C,)\).

  • reserve_space_2 (Tensor) - Tensor of shape \((C,)\).

Examples

>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> bias = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> batch_norm = P.BatchNorm()
>>> output = batch_norm(input_x, scale, bias, mean, variance)
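The normalization formula can be sketched in NumPy as follows. This is an illustration only: the function name batch_norm_inference is hypothetical, it covers just the case where the given mean and variance are used directly (is_training=False), and it does not model the updated parameters or reserve spaces returned by the operator.

```python
import numpy as np

def batch_norm_inference(x, scale, bias, mean, variance, epsilon=1e-5):
    """Apply y = (x - mean) / sqrt(variance + epsilon) * scale + bias per channel.

    x has shape (N, C, ...); scale, bias, mean and variance have shape (C,).
    """
    # Reshape (C,) to (1, C, 1, ...) so it broadcasts over every axis except the channel axis.
    stat_shape = (1, -1) + (1,) * (x.ndim - 2)
    mean = mean.reshape(stat_shape)
    variance = variance.reshape(stat_shape)
    scale = scale.reshape(stat_shape)
    bias = bias.reshape(stat_shape)
    return (x - mean) / np.sqrt(variance + epsilon) * scale + bias

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y = batch_norm_inference(
    x,
    scale=np.array([1.0, 2.0], dtype=np.float32),
    bias=np.array([0.0, 1.0], dtype=np.float32),
    mean=np.array([2.0, 3.0], dtype=np.float32),
    variance=np.array([1.0, 1.0], dtype=np.float32),
)
print(y)
```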
class mindspore.ops.BatchToSpace(*args, **kwargs)[source]

Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.

This operation divides the batch dimension N into blocks according to block_size; the N dimension of the output tensor is the number of blocks after division. The H and W dimensions of the output tensor are the original H and W dimensions multiplied by block_size, minus the given crop amounts.

Parameters
  • block_size (int) – The block size of the division. The value must not be less than 2.

  • crops (Union[list(int), tuple(int)]) – The crop values for the H and W dimensions, given as 2 lists of 2 integers each. All values must be not less than 0. crops[i] specifies the crop values for the spatial dimension i, which corresponds to the input dimension i+2. It is required that input_shape[i+2]*block_size >= crops[i][0]+crops[i][1].

Inputs:
  • input_x (Tensor) - The input tensor. It must be a 4-D tensor, and dimension 0 must be divisible by block_size * block_size.

Outputs:

Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_size and crops. The output shape will be (n’, c’, h’, w’), where

\(n' = n//(block\_size*block\_size)\)

\(c' = c\)

\(h' = h*block\_size-crops[0][0]-crops[0][1]\)

\(w' = w*block\_size-crops[1][0]-crops[1][1]\)

Examples

>>> block_size = 2
>>> crops = [[0, 0], [0, 0]]
>>> op = P.BatchToSpace(block_size, crops)
>>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
>>> output = op(input_x)
[[[[1., 2.], [3., 4.]]]]
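The output shape formulas above can be evaluated with a small pure-Python helper. The function name batch_to_space_output_shape is hypothetical and only mirrors the documented constraints; it is a sketch, not the operator's own shape inference.

```python
def batch_to_space_output_shape(input_shape, block_size, crops):
    """Return (n', c', h', w') for a 4-D input shape (n, c, h, w)."""
    n, c, h, w = input_shape
    if block_size < 2:
        raise ValueError("block_size must not be less than 2")
    if n % (block_size * block_size) != 0:
        raise ValueError("dimension 0 must be divisible by block_size * block_size")
    out_h = h * block_size - crops[0][0] - crops[0][1]
    out_w = w * block_size - crops[1][0] - crops[1][1]
    if out_h < 0 or out_w < 0:
        raise ValueError("crops exceed the expanded spatial size")
    return (n // (block_size * block_size), c, out_h, out_w)

# Shape of the example input above: four 1x1 images with one channel.
shape = batch_to_space_output_shape((4, 1, 1, 1), 2, [[0, 0], [0, 0]])
print(shape)
```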
class mindspore.ops.BatchToSpaceND(*args, **kwargs)[source]

Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.

This operation divides the batch dimension N into blocks according to block_shape; the N dimension of the output tensor is the number of blocks after division. The H and W dimensions of the output tensor are the original H and W dimensions multiplied by the corresponding entries of block_shape, minus the given crop amounts.

Parameters
  • block_shape (Union[list(int), tuple(int)]) – The block shape of the division, with all values >= 1. The length of block_shape is M, corresponding to the number of spatial dimensions. M must be 2.

  • crops (Union[list(int), tuple(int)]) – The crop values for the H and W dimensions, given as 2 lists of 2 integers each. All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].

Inputs:
  • input_x (Tensor) - The input tensor. It must be a 4-D tensor, dimension 0 must be divisible by product of block_shape.

Outputs:

Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape and crops. The output shape will be (n’, c’, h’, w’), where

\(n' = n//(block\_shape[0]*block\_shape[1])\)

\(c' = c\)

\(h' = h*block\_shape[0]-crops[0][0]-crops[0][1]\)

\(w' = w*block\_shape[1]-crops[1][0]-crops[1][1]\)

Examples

>>> block_shape = [2, 2]
>>> crops = [[0, 0], [0, 0]]
>>> batch_to_space_nd = P.BatchToSpaceND(block_shape, crops)
>>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
>>> output = batch_to_space_nd(input_x)
[[[[1., 2.], [3., 4.]]]]
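The rearrangement itself can be sketched in NumPy with a reshape, a transpose and a crop. The helper below is hypothetical and assumes the batch ordering used by TensorFlow's batch_to_space (block row, then block column, then output batch); that ordering is an assumption, although it reproduces the example output above.

```python
import numpy as np

def batch_to_space_nd(x, block_shape, crops):
    """NumPy sketch for a 4-D (n, c, h, w) input and a 2-element block_shape."""
    n, c, h, w = x.shape
    bh, bw = block_shape
    n_out = n // (bh * bw)
    # Split the batch axis into (bh, bw, n_out).
    y = x.reshape(bh, bw, n_out, c, h, w)
    # Move each block factor next to the spatial axis it expands.
    y = y.transpose(2, 3, 4, 0, 5, 1)        # (n_out, c, h, bh, w, bw)
    y = y.reshape(n_out, c, h * bh, w * bw)
    (top, bottom), (left, right) = crops
    return y[:, :, top:h * bh - bottom, left:w * bw - right]

x = np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]], dtype=np.float32)
result = batch_to_space_nd(x, [2, 2], [[0, 0], [0, 0]])
print(result)
```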
class mindspore.ops.BesselI0e(*args, **kwargs)[source]

Computes BesselI0e, the exponentially scaled modified Bessel function of order 0, of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x. Data type must be float16 or float32.

Examples

>>> bessel_i0e = P.BesselI0e()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = bessel_i0e(input_x)
[0.7979961, 0.5144438, 0.75117415, 0.9157829]
class mindspore.ops.BesselI1e(*args, **kwargs)[source]

Computes BesselI1e, the exponentially scaled modified Bessel function of order 1, of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x. Data type must be float16 or float32.

Examples

>>> bessel_i1e = P.BesselI1e()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = bessel_i1e(input_x)
[0.09507662, 0.19699717, 0.11505538, 0.04116856]
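The values in the BesselI0e and BesselI1e examples can be reproduced with a short pure-Python sketch. It assumes the standard definition exp(-|x|) * I_v(x) and evaluates the modified Bessel function by its power series; the helper name bessel_ie is hypothetical, and the truncated series is only suitable for small inputs such as those in the examples.

```python
import math

def bessel_ie(x, order, terms=20):
    """Exponentially scaled modified Bessel function: exp(-|x|) * I_order(x)."""
    half = x / 2.0
    # Power series: I_v(x) = sum_k (x/2)^(2k+v) / (k! * (k+v)!)
    series = sum(
        half ** (2 * k + order) / (math.factorial(k) * math.factorial(k + order))
        for k in range(terms)
    )
    return math.exp(-abs(x)) * series

for value in (0.24, 0.83, 0.31, 0.09):
    print(value, bessel_ie(value, 0), bessel_ie(value, 1))
```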
class mindspore.ops.BiasAdd(*args, **kwargs)[source]

Returns sum of input and bias tensor.

Adds the 1-D bias tensor to the input tensor, broadcasting it over all axes except the channel axis.

Inputs:
  • input_x (Tensor) - The input tensor. The shape can be 2-4 dimensions.

  • bias (Tensor) - The bias tensor, with shape \((C)\). The shape of bias must be the same as input_x in the second dimension.

Outputs:

Tensor, with the same shape and type as input_x.

Examples

>>> input_x = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32)
>>> bias = Tensor(np.random.random(3).reshape((3,)), mindspore.float32)
>>> bias_add = P.BiasAdd()
>>> bias_add(input_x, bias)
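The broadcasting rule can be illustrated with NumPy. The helper bias_add below is a hypothetical reference function; it assumes the channel axis is the second dimension, as stated for bias above.

```python
import numpy as np

def bias_add(x, bias):
    """Add a 1-D bias of length C along axis 1 of an input with 2 to 4 dimensions."""
    if bias.ndim != 1 or bias.shape[0] != x.shape[1]:
        raise ValueError("bias must be 1-D and match the second dimension of x")
    # Reshape (C,) to (1, C, 1, ...) so it broadcasts over all other axes.
    shape = (1, -1) + (1,) * (x.ndim - 2)
    return x + bias.reshape(shape)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
bias = np.array([1.0, 2.0, 3.0], dtype=np.float32)
result = bias_add(x, bias)
print(result)
```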
class mindspore.ops.BinaryCrossEntropy(*args, **kwargs)[source]

Computes the Binary Cross Entropy between the target and the output.

Note

Sets input as \(x\), input label as \(y\), output as \(\ell(x, y)\). Let,

\[L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]\]

Then,

\[\begin{split}\ell(x, y) = \begin{cases} L, & \text{if reduction} = \text{`none';}\\ \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}\end{split}\]
Parameters

reduction (str) – Specifies the reduction to be applied to the output. Its value must be one of ‘none’, ‘mean’, ‘sum’. Default: ‘mean’.

Inputs:
  • input_x (Tensor) - The input Tensor. The data type must be float16 or float32.

  • input_y (Tensor) - The label Tensor which has same shape and data type as input_x.

  • weight (Tensor, optional) - A rescaling weight applied to the loss of each batch element. It must have the same shape and data type as input_x. Default: None.

Outputs:

Tensor or Scalar, if reduction is ‘none’, then output is a tensor and has the same shape as input_x. Otherwise, the output is a scalar.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.binary_cross_entropy = P.BinaryCrossEntropy()
>>>     def construct(self, x, y, weight):
>>>         result = self.binary_cross_entropy(x, y, weight)
>>>         return result
>>>
>>> net = Net()
>>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
>>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
>>> weight = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> result = net(input_x, input_y, weight)
0.38240486
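The formula in the note can be evaluated directly with NumPy. The function below is a hypothetical reference sketch of that formula, not the operator itself; it reproduces the value shown in the example for the default ‘mean’ reduction.

```python
import numpy as np

def binary_cross_entropy(x, y, weight=None, reduction="mean"):
    """l_n = -w_n * (y_n * log(x_n) + (1 - y_n) * log(1 - x_n)), then reduce."""
    if weight is None:
        weight = np.ones_like(x)
    loss = -weight * (y * np.log(x) + (1.0 - y) * np.log(1.0 - x))
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError("reduction must be 'none', 'mean' or 'sum'")

x = np.array([0.2, 0.7, 0.1], dtype=np.float32)
y = np.array([0.0, 1.0, 0.0], dtype=np.float32)
w = np.array([1.0, 2.0, 2.0], dtype=np.float32)
mean_loss = binary_cross_entropy(x, y, w)
print(mean_loss)
```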
class mindspore.ops.BitwiseAnd(*args, **kwargs)[source]

Returns bitwise and of two tensors element-wise.

The inputs input_x1 and input_x2 follow the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type is converted to the higher-priority one. A RuntimeError is raised when data type conversion of a Parameter is required.

Inputs:
  • input_x1 (Tensor) - The input tensor with int16, int32 or uint16 data type.

  • input_x2 (Tensor) - The input tensor with same type as the input_x1.

Outputs:
  • y (Tensor) - The same type as the input_x1.

Examples

>>> input_x1 = Tensor(np.array([0, 0, 1, -1, 1, 1, 1]), mindspore.int16)
>>> input_x2 = Tensor(np.array([0, 1, 1, -1, -1, 2, 3]), mindspore.int16)
>>> bitwise_and = P.BitwiseAnd()
>>> bitwise_and(input_x1, input_x2)
[0, 0, 1, -1, 1, 0, 1]
class mindspore.ops.BitwiseOr(*args, **kwargs)[source]

Returns bitwise or of two tensors element-wise.

The inputs input_x1 and input_x2 follow the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type is converted to the higher-priority one. A RuntimeError is raised when data type conversion of a Parameter is required.

Inputs:
  • input_x1 (Tensor) - The input tensor with int16, int32 or uint16 data type.

  • input_x2 (Tensor) - The input tensor with same type as the input_x1.

Outputs:
  • y (Tensor) - The same type as the input_x1.

Examples

>>> input_x1 = Tensor(np.array([0, 0, 1, -1, 1, 1, 1]), mindspore.int16)
>>> input_x2 = Tensor(np.array([0, 1, 1, -1, -1, 2, 3]), mindspore.int16)
>>> bitwise_or = P.BitwiseOr()
>>> bitwise_or(input_x1, input_x2)
[0, 1, 1, -1, -1, 3, 3]
class mindspore.ops.BitwiseXor(*args, **kwargs)[source]

Returns bitwise xor of two tensors element-wise.

The inputs input_x1 and input_x2 follow the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type is converted to the higher-priority one. A RuntimeError is raised when data type conversion of a Parameter is required.

Inputs:
  • input_x1 (Tensor) - The input tensor with int16, int32 or uint16 data type.

  • input_x2 (Tensor) - The input tensor with same type as the input_x1.

Outputs:
  • y (Tensor) - The same type as the input_x1.

Examples

>>> input_x1 = Tensor(np.array([0, 0, 1, -1, 1, 1, 1]), mindspore.int16)
>>> input_x2 = Tensor(np.array([0, 1, 1, -1, -1, 2, 3]), mindspore.int16)
>>> bitwise_xor = P.BitwiseXor()
>>> bitwise_xor(input_x1, input_x2)
[0, 1, 0, 0, -2, 3, 2]
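The results shown for BitwiseAnd, BitwiseOr and BitwiseXor follow from two's-complement arithmetic on int16 values, where -1 has all bits set. A NumPy sketch, assuming NumPy's int16 operators behave like these element-wise operators, reproduces the three outputs:

```python
import numpy as np

x1 = np.array([0, 0, 1, -1, 1, 1, 1], dtype=np.int16)
x2 = np.array([0, 1, 1, -1, -1, 2, 3], dtype=np.int16)

and_result = x1 & x2   # a bit is set only if it is set in both inputs
or_result = x1 | x2    # a bit is set if it is set in either input
xor_result = x1 ^ x2   # a bit is set if it is set in exactly one input

print(and_result)
print(or_result)
print(xor_result)
```

For instance, 1 ^ -1 clears the lowest bit of the all-ones pattern, which is the int16 value -2.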
class mindspore.ops.BoundingBoxDecode(*args, **kwargs)[source]

Decodes bounding boxes locations.

Parameters
  • means (tuple) – The means of deltas calculation. Default: (0.0, 0.0, 0.0, 0.0).

  • stds (tuple) – The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).

  • max_shape (tuple) – The max size limit for decoding box calculation.

  • wh_ratio_clip (float) – The limit of width and height ratio for decoding box calculation. Default: 0.016.

Inputs:
  • anchor_box (Tensor) - Anchor boxes. The shape of anchor_box must be (n, 4).

  • deltas (Tensor) - Deltas of the boxes, with the same shape as anchor_box.

Outputs:

Tensor, decoded boxes.

Examples

>>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
>>> deltas = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
>>> boundingbox_decode = P.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
>>>                                          max_shape=(768, 1280), wh_ratio_clip=0.016)
>>> boundingbox_decode(anchor_box, deltas)
[[4.1953125  0.  0.  5.1953125]
 [2.140625  0.  3.859375  60.59375]]
class mindspore.ops.BoundingBoxEncode(*args, **kwargs)[source]

Encodes bounding boxes locations.

Parameters
  • means (tuple) – Means for encoding bounding boxes calculation. Default: (0.0, 0.0, 0.0, 0.0).

  • stds (tuple) – The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).

Inputs:
  • anchor_box (Tensor) - Anchor boxes. The shape of anchor_box must be (n, 4).

  • groundtruth_box (Tensor) - Ground truth boxes, with the same shape as anchor_box.

Outputs:

Tensor, encoded bounding boxes.

Examples

>>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
>>> groundtruth_box = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
>>> boundingbox_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
>>> boundingbox_encode(anchor_box, groundtruth_box)
[[5.0000000e-01  5.0000000e-01  -6.5504000e+04  6.9335938e-01]
 [-1.0000000e+00  2.5000000e-01  0.0000000e+00  4.0551758e-01]]
class mindspore.ops.Broadcast(*args, **kwargs)[source]

Broadcasts the tensor to the whole group.

Note

The tensors must have the same shape and format in all processes of the collection.

Parameters
  • root_rank (int) – Source rank. Required in all processes except the one that is sending the data.

  • group (str) – The communication group to work on. Default: “hccl_world_group”.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape of the input, i.e., \((x_1, x_2, ..., x_R)\). The contents depend on the data of the root_rank device.

Raises

TypeError – If root_rank is not an integer or group is not a string.

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init()
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.broadcast = P.Broadcast(1)
>>>
>>>     def construct(self, x):
>>>         return self.broadcast((x,))
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
class mindspore.ops.BroadcastTo(*args, **kwargs)[source]

Broadcasts input tensor to a given shape. The input shape can be broadcast to the target shape if, for each pair of dimensions, the two sizes are equal or the input size is 1. Dimensions are compared starting from the trailing dimensions.

Parameters

shape (tuple) – The target shape to broadcast.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, with the given shape and the same data type as input_x.

Examples

>>> shape = (2, 3)
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> broadcast_to = P.BroadcastTo(shape)
>>> broadcast_to(input_x)
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
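The compatibility rule can be written as a small check. The helper can_broadcast below is hypothetical, and np.broadcast_to is used only as a stand-in to show the resulting values, on the assumption that it follows the same trailing-dimension rule.

```python
import numpy as np

def can_broadcast(input_shape, target_shape):
    """Return True if input_shape can be broadcast to target_shape."""
    if len(input_shape) > len(target_shape):
        return False
    # Compare dimension pairs starting from the trailing dimensions.
    for in_dim, out_dim in zip(reversed(input_shape), reversed(target_shape)):
        if in_dim != out_dim and in_dim != 1:
            return False
    return True

x = np.array([1, 2, 3], dtype=np.float32)
print(can_broadcast(x.shape, (2, 3)))
expanded = np.broadcast_to(x, (2, 3))
print(expanded)
```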
class mindspore.ops.CTCGreedyDecoder(*args, **kwargs)[source]

Performs greedy decoding on the logits given in inputs.

Parameters

merge_repeated (bool) – If true, merge repeated classes in output. Default: True.

Inputs:
  • inputs (Tensor) - The input Tensor must be a 3-D tensor whose shape is (max_time, batch_size, num_classes). num_classes must be num_labels + 1 classes, num_labels indicates the number of actual labels. Blank labels are reserved. Default blank label is num_classes - 1. Data type must be float32 or float64.

  • sequence_length (Tensor) - A tensor containing sequence lengths with the shape of (batch_size). The type must be int32. Each value in the tensor must not be greater than max_time.

Outputs:
  • decoded_indices (Tensor) - A tensor with shape of (total_decoded_outputs, 2). Data type is int64.

  • decoded_values (Tensor) - A tensor with shape of (total_decoded_outputs), it stores the decoded classes. Data type is int64.

  • decoded_shape (Tensor) - The value of tensor is [batch_size, max_decoded_length]. Data type is int64.

  • log_probability (Tensor) - A tensor with shape of (batch_size, 1), containing sequence log-probability, has the same type as inputs.

Examples

>>> class CTCGreedyDecoderNet(nn.Cell):
>>>     def __init__(self):
>>>         super(CTCGreedyDecoderNet, self).__init__()
>>>         self.ctc_greedy_decoder = P.CTCGreedyDecoder()
>>>         self.assert_op = P.Assert(300)
>>>
>>>     def construct(self, inputs, sequence_length):
>>>         out = self.ctc_greedy_decoder(inputs, sequence_length)
>>>         self.assert_op(True, (out[0], out[1], out[2], out[3]))
>>>         return out[2]
>>>
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> net = CTCGreedyDecoderNet()
>>> output = net(inputs, sequence_length)
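Greedy decoding can be sketched in NumPy as follows. This is an illustration under assumptions: it takes the best class at each time step, optionally merges consecutive repeats, and drops the blank label num_classes - 1, which is the conventional CTC greedy procedure. The helper ctc_greedy_decode is hypothetical and returns plain Python lists rather than the sparse outputs and log-probabilities of the operator.

```python
import numpy as np

def ctc_greedy_decode(logits, sequence_length, merge_repeated=True):
    """logits: (max_time, batch_size, num_classes). Returns one list of class ids per batch item."""
    blank = logits.shape[2] - 1
    decoded = []
    for b, length in enumerate(sequence_length):
        best = logits[:length, b, :].argmax(axis=1)  # best class per time step
        labels = []
        previous = None
        for cls in best:
            cls = int(cls)
            # Skip blanks; skip a repeat of the previous step when merging is enabled.
            if cls != blank and not (merge_repeated and cls == previous):
                labels.append(cls)
            previous = cls
        decoded.append(labels)
    return decoded

# One sequence whose per-step best classes are 0, 0, blank, 0.
logits = np.zeros((4, 1, 3), dtype=np.float32)
for t, cls in enumerate([0, 0, 2, 0]):
    logits[t, 0, cls] = 1.0
merged = ctc_greedy_decode(logits, [4])
unmerged = ctc_greedy_decode(logits, [4], merge_repeated=False)
print(merged, unmerged)
```

The blank between the second and fourth steps separates the two occurrences of class 0, so merging removes only the immediate repeat.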
class mindspore.ops.CTCLoss(*args, **kwargs)[source]

Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.

Parameters
  • preprocess_collapse_repeated (bool) – If true, repeated labels will be collapsed prior to the CTC calculation. Default: False.

  • ctc_merge_repeated (bool) – If false, during CTC calculation, repeated non-blank labels will not be merged and these labels will be interpreted as individual ones. This is a simplified version of CTC. Default: True.

  • ignore_longer_outputs_than_inputs (bool) – If true, sequences with longer outputs than inputs will be ignored. Default: False.

Inputs:
  • inputs (Tensor) - The input Tensor must be a 3-D tensor whose shape is (max_time, batch_size, num_classes). num_classes must be num_labels + 1 classes, num_labels indicates the number of actual labels. Blank labels are reserved. Default blank label is num_classes - 1. Data type must be float16, float32 or float64.

  • labels_indices (Tensor) - The indices of labels. labels_indices[i, :] == [b, t] means labels_values[i] stores the id for (batch b, time t). The type must be int64 and rank must be 2.

  • labels_values (Tensor) - A 1-D input tensor. The values are associated with the given batch and time. The type must be int32. labels_values[i] must be in the range of [0, num_classes).

  • sequence_length (Tensor) - A tensor containing sequence lengths with the shape of (batch_size). The type must be int32. Each value in the tensor must not be greater than max_time.

Outputs:
  • loss (Tensor) - A tensor containing log-probabilities, the shape is (batch_size). The tensor has the same type as inputs.

  • gradient (Tensor) - The gradient of loss, has the same type and shape as inputs.

Examples

>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
>>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64)
>>> labels_values = Tensor(np.array([2, 2]), mindspore.int32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> ctc_loss = P.CTCLoss()
>>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
class mindspore.ops.Cast(*args, **kwargs)[source]

Returns a tensor with the new specified data type.

Inputs:
  • input_x (Union[Tensor, Number]) - The shape of tensor is \((x_1, x_2, ..., x_R)\). The tensor to be cast.

  • type (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.

Outputs:

Tensor, the shape of tensor is the same as input_x, \((x_1, x_2, ..., x_R)\).

Examples

>>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
>>> input_x = Tensor(input_np)
>>> type_dst = mindspore.float16
>>> cast = P.Cast()
>>> result = cast(input_x, type_dst)
class mindspore.ops.Ceil(*args, **kwargs)[source]

Rounds a tensor up to the closest integer element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. Its element data type must be float16 or float32.

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> input_x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32)
>>> ceil_op = P.Ceil()
>>> ceil_op(input_x)
[2.0, 3.0, -1.0]
class mindspore.ops.CheckBprop(*args, **kwargs)[source]

Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.

Raises

TypeError – If tuples x and y are not the same.

Inputs:
  • input_x (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked.

  • input_y (tuple[Tensor]) - The input_y contains the inputs of bprop to check against.

Outputs:

(tuple[Tensor]), the input_x, if data type and shape of corresponding elements from input_x and input_y are the same.

Examples

>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> out = P.CheckBprop()(input_x, input_y)
class mindspore.ops.CheckValid(*args, **kwargs)[source]

Checks bounding box.

Checks whether each bounding box lies within the data border.

Inputs:
  • bboxes (Tensor) - Bounding boxes tensor with shape (N, 4). Data type must be float16 or float32.

  • img_metas (Tensor) - Raw image size information with the format of (height, width, ratio). Data type must be float16 or float32.

Outputs:

Tensor, the validity flags of the bounding boxes.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.check_valid = P.CheckValid()
>>>     def construct(self, x, y):
>>>         valid_result = self.check_valid(x, y)
>>>         return valid_result
>>>
>>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
>>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
>>> net = Net()
>>> result = net(bboxes, img_metas)
[True   False   False]
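One plausible reading of the check, sketched in NumPy: a box is valid when its corners lie inside the image scaled by ratio. The box layout (x1, y1, x2, y2) and the upper bounds width * ratio - 1 and height * ratio - 1 are assumptions made for illustration, and the helper check_valid is hypothetical; under these assumptions the sketch reproduces the example output.

```python
import numpy as np

def check_valid(bboxes, img_metas):
    """Return one bool per box; bboxes has shape (N, 4), img_metas is (height, width, ratio)."""
    height, width, ratio = img_metas
    max_x = width * ratio - 1    # assumed upper bound for x coordinates
    max_y = height * ratio - 1   # assumed upper bound for y coordinates
    x1, y1, x2, y2 = bboxes.T    # assumed (x1, y1, x2, y2) layout
    return (x1 >= 0) & (y1 >= 0) & (x2 <= max_x) & (y2 <= max_y)

bboxes = np.linspace(0, 6, 12).reshape(3, 4).astype(np.float32)
img_metas = np.array([2, 1, 3], dtype=np.float32)
flags = check_valid(bboxes, img_metas)
print(flags)
```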
class mindspore.ops.Concat(*args, **kwargs)[source]

Concatenates tensors along the specified axis.

Concatenates the input tensors along the given axis.

Note

The input data is a tuple of tensors. These tensors have the same rank R. Set the given axis as m, and \(0 \le m < R\). Set the number of input tensors as N. For the \(i\)-th tensor \(t_i\), it has the shape of \((x_1, x_2, ..., x_{mi}, ..., x_R)\). \(x_{mi}\) is the \(m\)-th dimension of the \(i\)-th tensor. Then, the shape of the output tensor is

\[(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)\]
Parameters

axis (int) – The specified axis. Default: 0.

Inputs:
  • input_x (tuple, list) - A tuple or a list of input tensors.

Outputs:

Tensor, the shape is \((x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)\).

Examples

>>> data1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
>>> data2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
>>> op = P.Concat()
>>> output = op((data1, data2))
[[0, 1],
 [2, 1],
 [0, 1],
 [2, 1]]
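The output shape rule from the note can be sketched with a pure-Python helper, with np.concatenate standing in for the operator to show the values. The name concat_output_shape is hypothetical.

```python
import numpy as np

def concat_output_shape(shapes, axis=0):
    """Sizes are summed along `axis` and must match on every other axis."""
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same rank")
        for dim, (a, b) in enumerate(zip(first, shape)):
            if dim != axis and a != b:
                raise ValueError("shapes may differ only along the concatenation axis")
    out = list(first)
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)

data1 = np.array([[0, 1], [2, 1]], dtype=np.int32)
data2 = np.array([[0, 1], [2, 1]], dtype=np.int32)
rows = np.concatenate((data1, data2), axis=0)  # stacks along axis 0
cols = np.concatenate((data1, data2), axis=1)  # stacks along axis 1
print(concat_output_shape([data1.shape, data2.shape], axis=0), rows.shape, cols.shape)
```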
class mindspore.ops.ControlDepend(*args, **kwargs)[source]

Adds control dependency relation between source and destination operation.

In many cases, the execution order of operations needs to be controlled, and ControlDepend is designed for this. ControlDepend instructs the execution engine to run the operations in a specific order: the destination operations depend on the source operations, which means the source operations must be executed before the destination operations.

Note

This operation does not work in PYNATIVE_MODE.

Parameters
  • depend_mode (int) – Use 0 for a normal dependency relation. Use 1 to depend on operations that use a Parameter as their input. Default: 0.

Inputs:
  • src (Any) - The source input. It can be a tuple of operation outputs or a single operation output. The input data itself is not of concern; what matters is the operation that generates it. If depend_mode is 1 and the source input is Parameter, we will try to find the operations that used the parameter as input.

  • dst (Any) - The destination input. It can be a tuple of operation outputs or a single operation output. The input data itself is not of concern; what matters is the operation that generates it. If depend_mode is 1 and the source input is Parameter, we will try to find the operations that used the parameter as input.

Outputs:

Bool. This operation has no actual data output; it is used to set up the order of the related operations.

Examples

>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.control_depend = P.ControlDepend()
>>>         self.softmax = P.Softmax()
>>>
>>>     def construct(self, x, y):
>>>         mul = x * y
>>>         softmax = self.softmax(x)
>>>         ret = self.control_depend(mul, softmax)
>>>         return ret
>>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
>>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
>>> net = Net()
>>> output = net(x, y)
class mindspore.ops.Conv2D(*args, **kwargs)[source]

2D convolution layer.

Applies a 2D convolution over an input tensor which is typically of shape \((N, C_{in}, H_{in}, W_{in})\), where \(N\) is batch size and \(C_{in}\) is channel number. For each batch of shape \((C_{in}, H_{in}, W_{in})\), the formula is defined as:

\[out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,\]

where \(ccor\) is the cross correlation operator, \(C_{in}\) is the input channel number, \(j\) ranges from \(0\) to \(C_{out} - 1\), \(W_{ij}\) corresponds to the \(i\)-th channel of the \(j\)-th filter and \(out_{j}\) corresponds to the \(j\)-th channel of the output. \(W_{ij}\) is a slice of kernel and it has shape \((\text{ks_h}, \text{ks_w})\), where \(\text{ks_h}\) and \(\text{ks_w}\) are the height and width of the convolution kernel. The full kernel has shape \((C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})\), where group is the group number to split the input in the channel dimension.

If the ‘pad_mode’ is set to be “valid”, the output height and width will be \(\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor\) and \(\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor\) respectively.

The first introduction can be found in paper Gradient Based Learning Applied to Document Recognition. More detailed introduction can be found here: http://cs231n.github.io/convolutional-networks/.

Parameters
  • out_channel (int) – The dimension of the output.

  • kernel_size (Union[int, tuple[int]]) – The kernel size of the 2D convolution.

  • mode (int) – Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution, 2: deconvolution, 3: depthwise convolution. Default: 1.

  • pad_mode (str) – Modes to fill padding. It could be “valid”, “same”, or “pad”. Default: “valid”.

  • pad (Union[int, tuple[int]]) – The pad value to be filled. Default: 0. If pad is an integer, the paddings of top, bottom, left and right are the same, equal to pad. If pad is a tuple of four integers, the paddings of top, bottom, left and right are equal to pad[0], pad[1], pad[2], and pad[3] respectively.

  • stride (Union[int, tuple[int]]) – The stride to be applied to the convolution filter. Default: 1.

  • dilation (Union[int, tuple[int]]) – Specifies the space to use between kernel elements. Default: 1.

  • group (int) – Splits input into groups. Default: 1.

Returns

Tensor, the result of applying the 2D convolution.

Inputs:
  • input (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\).

  • weight (Tensor) - Set size of kernel is \((K_1, K_2)\), then the shape is \((C_{out}, C_{in}, K_1, K_2)\).

Outputs:

Tensor of shape \((N, C_{out}, H_{out}, W_{out})\).

Examples

>>> input = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
>>> conv2d = P.Conv2D(out_channel=32, kernel_size=3)
>>> conv2d(input, weight)
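The output size formula given above for the “valid” pad mode can be evaluated with a small pure-Python helper. The function name is hypothetical, and it handles one spatial dimension at a time with scalar stride, padding and dilation.

```python
def conv2d_valid_output_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    """floor(1 + (in + 2*padding - ks - (ks - 1) * (dilation - 1)) / stride)."""
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    # Integer floor division is equivalent to the floor in the formula.
    return 1 + (in_size + 2 * padding - effective_kernel) // stride

# The example above: a 32x32 input with a 3x3 kernel.
out_h = conv2d_valid_output_size(32, 3)
out_w = conv2d_valid_output_size(32, 3)
print(out_h, out_w)
```

With the example tensors this gives an output of shape (10, 32, 30, 30).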
class mindspore.ops.Conv2DBackpropInput(*args, **kwargs)[source]

Computes the gradients of convolution with respect to the input.

Parameters
  • out_channel (int) – The dimensionality of the output space.

  • kernel_size (Union[int, tuple[int]]) – The size of the convolution window.

  • pad_mode (str) – Modes to fill padding. It could be “valid”, “same”, or “pad”. Default: “valid”.

  • pad (Union[int, tuple[int]]) – The pad value to be filled. Default: 0. If pad is an integer, the paddings of top, bottom, left and right are the same, equal to pad. If pad is a tuple of four integers, the paddings of top, bottom, left and right are equal to pad[0], pad[1], pad[2], and pad[3] respectively.

  • mode (int) – Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution, 2: deconvolution, 3: depthwise convolution. Default: 1.

  • stride (Union[int, tuple[int]]) – The stride to be applied to the convolution filter. Default: 1.

  • dilation (Union[int, tuple[int]]) – Specifies the dilation rate to be used for the dilated convolution. Default: 1.

  • group (int) – Splits input into groups. Default: 1.

Returns

Tensor, the gradients of convolution.

Examples

>>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
>>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
>>> x = Tensor(np.ones([10, 32, 32, 32]))
>>> conv2d_backprop_input = P.Conv2DBackpropInput(out_channel=32, kernel_size=3)
>>> conv2d_backprop_input(dout, weight, F.shape(x))
class mindspore.ops.Cos(*args, **kwargs)[source]

Computes cosine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> cos = P.Cos()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = cos(input_x)
class mindspore.ops.Cosh(*args, **kwargs)[source]

Computes hyperbolic cosine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> cosh = P.Cosh()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = cosh(input_x)
[1.0289385 1.364684 1.048436 1.0040528]
class mindspore.ops.CropAndResize(*args, **kwargs)[source]

Extracts crops from the input image tensor and resizes them.

Note

In case that the output shape depends on crop_size, the crop_size must be constant.

Parameters
  • method (str) – An optional string that specifies the sampling method for resizing. It can be “bilinear”, “nearest” or “bilinear_v2”. The option “bilinear” stands for the standard bilinear interpolation algorithm, while “bilinear_v2” may give better results in some cases. Default: “bilinear”.

  • extrapolation_value (float) – An optional float value used for extrapolation, if applicable. Default: 0.

Inputs:
  • x (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth]. Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.

  • boxes (Tensor) - A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor specifies the coordinates of a box in the box_index[i] image, in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to the image coordinate at y * (image_height - 1), so that the [0, 1] interval of normalized image height is mapped to [0, image_height - 1] in image height coordinates. y1 > y2 is allowed, in which case the sampled crop is an up-down flipped version of the original image. The width dimension is treated similarly. Normalized coordinates outside the [0, 1] range are allowed, in which case extrapolation_value is used to extrapolate the input image values. Types allowed: float32.

  • box_index (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). The value of box_index[i] specifies the image that the i-th box refers to. Types allowed: int32.

  • crop_size (Tuple[int]) - A tuple of two int32 elements: (crop_height, crop_width). Only constant value is allowed. All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive.

Outputs:

A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32.

Examples

>>> class CropAndResizeNet(nn.Cell):
>>>     def __init__(self, crop_size):
>>>         super(CropAndResizeNet, self).__init__()
>>>         self.crop_and_resize = P.CropAndResize()
>>>         self.crop_size = crop_size
>>>     @ms_function
>>>     def construct(self, x, boxes, box_index):
>>>         return self.crop_and_resize(x, boxes, box_index, self.crop_size)
>>>
>>> BATCH_SIZE = 1
>>> NUM_BOXES = 5
>>> IMAGE_HEIGHT = 256
>>> IMAGE_WIDTH = 256
>>> CHANNELS = 3
>>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32)
>>> boxes = np.random.uniform(size=[NUM_BOXES, 4]).astype(np.float32)
>>> box_index = np.random.uniform(size=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32)
>>> crop_size = (24, 24)
>>> crop_and_resize = CropAndResizeNet(crop_size=crop_size)
>>> output = crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_index))
>>> print(output.asnumpy())
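The normalized-coordinate mapping described for boxes can be sketched in plain Python. This is only an illustration of the y * (image_height - 1) rule stated above, not the operator's implementation; the helper names box_to_image_coords and needs_extrapolation are hypothetical.

```python
def box_to_image_coords(box, image_height, image_width):
    """Map a normalized box [y1, x1, y2, x2] to image coordinates."""
    y1, x1, y2, x2 = box
    return (y1 * (image_height - 1), x1 * (image_width - 1),
            y2 * (image_height - 1), x2 * (image_width - 1))


def needs_extrapolation(box):
    """Return True if any normalized coordinate lies outside [0, 1]."""
    return any(c < 0.0 or c > 1.0 for c in box)


# The whole image: [0, 1] maps onto [0, image_height - 1].
full_image = box_to_image_coords((0.0, 0.0, 1.0, 1.0), 256, 256)
# y1 > y2 describes an up-down flipped crop.
flipped = box_to_image_coords((1.0, 0.0, 0.0, 1.0), 256, 256)
```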
class mindspore.ops.CumProd(*args, **kwargs)[source]

Computes the cumulative product of the tensor input_x along axis.

Parameters
  • exclusive (bool) – If true, perform exclusive cumulative product. Default: False.

  • reverse (bool) – If true, reverse the result along axis. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (int) - The dimension along which to compute the cumulative product. Only constant value is allowed.

Outputs:

Tensor, has the same shape and dtype as the input_x.

Examples

>>> a, b, c = 2.0, 3.0, 4.0  # any three numbers
>>> input_x = Tensor(np.array([a, b, c]).astype(np.float32))
>>> op0 = P.CumProd()
>>> output = op0(input_x, 0) # output=[a, a * b, a * b * c]
>>> op1 = P.CumProd(exclusive=True)
>>> output = op1(input_x, 0) # output=[1, a, a * b]
>>> op2 = P.CumProd(reverse=True)
>>> output = op2(input_x, 0) # output=[a * b * c, b * c, c]
>>> op3 = P.CumProd(exclusive=True, reverse=True)
>>> output = op3(input_x, 0) # output=[b * c, c, 1]
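The four combinations of exclusive and reverse shown in the comments above can be reproduced with a small pure-Python reference for the 1-D case. It is a sketch of the documented behaviour, not the operator itself; cumprod_1d is a hypothetical name.

```python
def cumprod_1d(values, exclusive=False, reverse=False):
    """Cumulative product of a 1-D sequence with exclusive/reverse options."""
    seq = list(values)[::-1] if reverse else list(values)
    out, running = [], 1
    for v in seq:
        if exclusive:
            out.append(running)  # product of the elements before v
            running *= v
        else:
            running *= v
            out.append(running)  # product up to and including v
    return out[::-1] if reverse else out


a, b, c = 2, 3, 4
plain = cumprod_1d([a, b, c])
exclusive = cumprod_1d([a, b, c], exclusive=True)
reverse = cumprod_1d([a, b, c], reverse=True)
both = cumprod_1d([a, b, c], exclusive=True, reverse=True)
```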
class mindspore.ops.CumSum(*args, **kwargs)[source]

Computes the cumulative sum of input tensor along axis.

Parameters
  • exclusive (bool) – If true, perform exclusive mode. Default: False.

  • reverse (bool) – If true, perform inverse cumulative sum. Default: False.

Inputs:
  • input (Tensor) - The input tensor to accumulate.

  • axis (int) - The axis to accumulate the tensor’s value. Only constant value is allowed. Must be in the range [-rank(input), rank(input)).

Outputs:

Tensor, the shape of the output tensor is consistent with the input tensor’s.

Examples

>>> input = Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float32))
>>> cumsum = P.CumSum()
>>> output = cumsum(input, 1)
[[ 3.  7. 13. 23.]
 [ 1.  7. 14. 23.]
 [ 4.  7. 15. 22.]
 [ 1.  4. 11. 20.]]
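For comparison, a NumPy reference can reproduce the output above and illustrate the exclusive and reverse options. The interpretation of those two options follows the TensorFlow-style convention, which is an assumption here; cumsum_ref is a hypothetical helper, not part of MindSpore.

```python
import numpy as np


def cumsum_ref(x, axis, exclusive=False, reverse=False):
    """Cumulative sum along axis, with assumed exclusive/reverse semantics."""
    x = np.asarray(x)
    if reverse:
        x = np.flip(x, axis=axis)
    out = np.cumsum(x, axis=axis)
    if exclusive:
        out = out - x  # drop the current element from each partial sum
    if reverse:
        out = np.flip(out, axis=axis)
    return out


x = np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]],
             dtype=np.float32)
inclusive = cumsum_ref(x, 1)
exclusive = cumsum_ref(x, 1, exclusive=True)
reverse = cumsum_ref(x, 1, reverse=True)
```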
class mindspore.ops.DType(*args, **kwargs)[source]

Returns the data type of input tensor as mindspore.dtype.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

mindspore.dtype, the data type of a tensor.

Examples

>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> type = P.DType()(input_tensor)
class mindspore.ops.DataFormatDimMap(*args, **kwargs)[source]

Returns the dimension index in the destination data format given the one in the source data format.

Parameters
  • src_format (string) – An optional value for source data format. Default: ‘NHWC’.

  • dst_format (string) – An optional value for destination data format. Default: ‘NCHW’.

Inputs:
  • input_x (Tensor) - A Tensor with each element as a dimension index in source data format. The suggested values are in the range [-4, 4). Its type is int32.

Outputs:

Tensor, has the same type as the input_x.

Examples

>>> x = Tensor([0, 1, 2, 3], mindspore.int32)
>>> dfdm = P.DataFormatDimMap()
>>> dfdm(x)
[0 3 1 2]
class mindspore.ops.DataType[source]

Various combinations of dtype and format.

The current list below may be incomplete. Please add it if necessary.

class mindspore.ops.Depend(*args, **kwargs)[source]

Depend is used for processing side-effect operations.

Inputs:
  • value (Tensor) - the real value to return for depend operator.

  • expr (Expression) - the expression to execute with no outputs.

Outputs:

Tensor, the value passed by last operator.

class mindspore.ops.DepthToSpace(*args, **kwargs)[source]

Rearranges blocks of depth data into spatial dimensions.

This is the reverse operation of SpaceToDepth.

The depth of output tensor is \(input\_depth / (block\_size * block\_size)\).

The output tensor’s height dimension is \(height * block\_size\).

The output tensor’s width dimension is \(width * block\_size\).

The input tensor’s depth must be divisible by block_size * block_size. The data format is “NCHW”.

Parameters

block_size (int) – The block size used to divide depth data. It must be >= 2.

Inputs:
  • x (Tensor) - The target tensor. It must be a 4-D tensor.

Outputs:

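Tensor, has the same dtype as x. Its shape is \((N, C / (block\_size * block\_size), H * block\_size, W * block\_size)\), where \((N, C, H, W)\) is the shape of x.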

Examples

>>> x = Tensor(np.random.rand(1,12,1,1), mindspore.float32)
>>> block_size = 2
>>> op = P.DepthToSpace(block_size)
>>> output = op(x)
>>> output.asnumpy().shape == (1,3,2,2)
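The shape arithmetic above can be checked with a NumPy sketch. The element ordering used here is the depth-column-row layout described in the ONNX DepthToSpace specification; whether MindSpore orders elements the same way is an assumption, so only the shape should be read as grounded in the text. depth_to_space_ref is a hypothetical name.

```python
import numpy as np


def depth_to_space_ref(x, block_size):
    """Rearrange an NCHW array from depth into spatial blocks (assumed DCR order)."""
    n, c, h, w = x.shape
    if c % (block_size * block_size) != 0:
        raise ValueError("depth must be divisible by block_size * block_size")
    c_out = c // (block_size * block_size)
    tmp = x.reshape(n, block_size, block_size, c_out, h, w)
    tmp = tmp.transpose(0, 3, 4, 1, 5, 2)  # -> (n, c_out, h, bs, w, bs)
    return tmp.reshape(n, c_out, h * block_size, w * block_size)


x = np.arange(12, dtype=np.float32).reshape(1, 12, 1, 1)
y = depth_to_space_ref(x, 2)
```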
class mindspore.ops.DepthwiseConv2dNative(*args, **kwargs)[source]

Returns the depth-wise convolution value for the input.

Applies a depthwise 2-D convolution to the input, generating channel_multiplier output channels per input channel. Given an input tensor of shape \((N, C_{in}, H_{in}, W_{in})\) where \(N\) is the batch size and a filter tensor with kernel size \((ks_{h}, ks_{w})\), containing \(C_{in} * \text{channel_multiplier}\) convolutional filters of depth 1, it applies different filters to each input channel (channel_multiplier filters per input channel, with a default value of 1), then concatenates the results. The output has \(\text{in_channels} * \text{channel_multiplier}\) channels.

Parameters
  • channel_multiplier (int) – The multiplier for the original output convolution. Its value must be greater than 0.

  • kernel_size (Union[int, tuple[int]]) – The size of the convolution kernel.

  • mode (int) – Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution, 2: deconvolution, 3: depthwise convolution. Default: 3.

  • pad_mode (str) – Modes to fill padding. It could be “valid”, “same”, or “pad”. Default: “valid”.

  • pad (Union[int, tuple[int]]) – The pad value to be filled. If pad is an integer, the paddings of top, bottom, left and right are the same, equal to pad. If pad is a tuple of four integers, the padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly. Default: 0.

  • stride (Union[int, tuple[int]]) – The stride to be applied to the convolution filter. Default: 1.

  • dilation (Union[int, tuple[int]]) – Specifies the dilation rate to be used for the dilated convolution. Default: 1.

  • group (int) – Splits input into groups. Default: 1.

Inputs:
  • input (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\).

  • weight (Tensor) - If the kernel size is \((K_1, K_2)\), the shape is \((K, C_{in}, K_1, K_2)\), where K must be 1.

Outputs:

Tensor of shape \((N, C_{in} * \text{channel_multiplier}, H_{out}, W_{out})\).

Examples

>>> input = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32)
>>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3))
>>> output = depthwise_conv2d(input, weight)
>>> output.shape == (10, 96, 30, 30)
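The output shape in the example can be derived by hand. The sketch below assumes pad_mode “valid” (no padding) and the usual convolution size formula; depthwise_output_shape is a hypothetical helper, not a MindSpore function.

```python
def depthwise_output_shape(input_shape, kernel_size, channel_multiplier,
                           stride=1, dilation=1):
    """Output shape of a depthwise convolution with no padding (assumed "valid")."""
    n, c_in, h_in, w_in = input_shape
    k_h, k_w = kernel_size
    h_out = (h_in - dilation * (k_h - 1) - 1) // stride + 1
    w_out = (w_in - dilation * (k_w - 1) - 1) // stride + 1
    # Each input channel produces channel_multiplier output channels.
    return (n, c_in * channel_multiplier, h_out, w_out)


shape = depthwise_output_shape((10, 32, 32, 32), (3, 3), channel_multiplier=3)
```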
class mindspore.ops.Diag(*args, **kwargs)[source]

Constructs a diagonal tensor with the given diagonal values.

Assume input_x has dimensions \([D_1,... D_k]\), the output is a tensor of rank 2k with dimensions \([D_1,..., D_k, D_1,..., D_k]\) where: \(output[i_1,..., i_k, i_1,..., i_k] = input\_x[i_1,..., i_k]\) and 0 everywhere else.

Inputs:
  • input_x (Tensor) - The input tensor. It must have fewer than 5 dimensions.

Outputs:

Tensor, has the same dtype as the input_x.

Examples

>>> input_x = Tensor([1, 2, 3, 4])
>>> diag = P.Diag()
>>> diag(input_x)
[[1, 0, 0, 0],
 [0, 2, 0, 0],
 [0, 0, 3, 0],
 [0, 0, 0, 4]]
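The index rule above also covers inputs of rank greater than 1. A NumPy sketch of the definition follows; diag_ref is a hypothetical reference helper, written directly from the formula rather than taken from MindSpore.

```python
import numpy as np


def diag_ref(x):
    """Build a rank-2k array with x on the generalized diagonal."""
    x = np.asarray(x)
    out = np.zeros(x.shape + x.shape, dtype=x.dtype)
    for idx in np.ndindex(*x.shape):
        out[idx + idx] = x[idx]  # output[i_1..i_k, i_1..i_k] = x[i_1..i_k]
    return out


vector_case = diag_ref([1, 2, 3, 4])
matrix_case = diag_ref(np.ones((2, 3)))
```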
class mindspore.ops.DiagPart(*args, **kwargs)[source]

Extracts the diagonal part from given tensor.

Assume input_x has dimensions \([D_1,..., D_k, D_1,..., D_k]\), the output is a tensor of rank k with dimensions \([D_1,..., D_k]\) where: \(output[i_1,..., i_k] = input\_x[i_1,..., i_k, i_1,..., i_k]\).

Inputs:
  • input_x (Tensor) - The input Tensor.

Outputs:

Tensor.

Examples

>>> input_x = Tensor([[1, 0, 0, 0],
>>>                   [0, 2, 0, 0],
>>>                   [0, 0, 3, 0],
>>>                   [0, 0, 0, 4]])
>>> diag_part = P.DiagPart()
>>> diag_part(input_x)
[1, 2, 3, 4]
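The extraction rule can be written out in NumPy as a sketch of the definition above. diag_part_ref is a hypothetical name, and the shape check reflects the stated requirement that the two halves of the shape match.

```python
import numpy as np


def diag_part_ref(x):
    """Extract output[i_1..i_k] = x[i_1..i_k, i_1..i_k] from a rank-2k array."""
    x = np.asarray(x)
    k = x.ndim // 2
    if x.ndim % 2 != 0 or x.shape[:k] != x.shape[k:]:
        raise ValueError("expected shape [D_1..D_k, D_1..D_k]")
    out = np.zeros(x.shape[:k], dtype=x.dtype)
    for idx in np.ndindex(*x.shape[:k]):
        out[idx] = x[idx + idx]
    return out


diagonal = diag_part_ref([[1, 0, 0, 0],
                          [0, 2, 0, 0],
                          [0, 0, 3, 0],
                          [0, 0, 0, 4]])
```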
class mindspore.ops.Div(*args, **kwargs)[source]

Computes the quotient of dividing the first input tensor by the second input tensor element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - When the first input is a tensor, the second input could be a number, a bool, or a tensor whose data type is number or bool. When the first input is a number or a bool, the second input must be a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([-4.0, 5.0, 6.0]), mindspore.float32)
>>> input_y = Tensor(np.array([3.0, 2.0, 3.0]), mindspore.float32)
>>> div = P.Div()
>>> div(input_x, input_y)
[-1.3, 2.5, 2.0]
class mindspore.ops.DivNoNan(*args, **kwargs)[source]

Computes a safe division that returns 0 wherever input_y is zero.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

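  • input_y (Union[Tensor, Number, bool]) - When the first input is a tensor, the second input could be a number, a bool, or a tensor whose data type is number or bool. When the first input is a number or a bool, the second input must be a tensor whose data type is number or bool.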

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([-1.0, 0., 1.0, 5.0, 6.0]), mindspore.float32)
>>> input_y = Tensor(np.array([0., 0., 0., 2.0, 3.0]), mindspore.float32)
>>> div_no_nan = P.DivNoNan()
>>> div_no_nan(input_x, input_y)
[0., 0., 0., 2.5, 2.0]
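The same behaviour can be sketched in NumPy by dividing only where the divisor is nonzero. div_no_nan_ref is a hypothetical reference helper, shown only to make the zero-handling explicit.

```python
import numpy as np


def div_no_nan_ref(x, y):
    """Element-wise x / y, returning 0 wherever y == 0."""
    x, y = np.broadcast_arrays(np.asarray(x, dtype=np.float32),
                               np.asarray(y, dtype=np.float32))
    out = np.zeros(x.shape, dtype=np.float32)
    # Positions where y == 0 are skipped and keep their initial value of 0.
    np.divide(x, y, out=out, where=(y != 0))
    return out


result = div_no_nan_ref([-1.0, 0.0, 1.0, 5.0, 6.0], [0.0, 0.0, 0.0, 2.0, 3.0])
```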
class mindspore.ops.Dropout(*args, **kwargs)[source]

During training, randomly zeroes some of the elements of the input tensor with probability 1 - keep_prob.

Parameters

keep_prob (float) – The keep rate, between 0 and 1, e.g. keep_prob = 0.9, means dropping out 10% of input units.

Inputs:
  • shape (tuple[int]) - The shape of target mask.

Outputs:

Tensor, the value of generated mask for input shape.

Examples

>>> dropout = P.Dropout(keep_prob=0.5)
>>> x = Tensor(np.ones([20, 16, 50, 50]), mindspore.float32)
>>> out = dropout(x)
class mindspore.ops.DropoutDoMask(*args, **kwargs)[source]

Applies dropout mask on the input tensor.

Take the mask output of DropoutGenMask as input, and apply dropout on the input.

Inputs:
  • input_x (Tensor) - The input tensor.

  • mask (Tensor) - The mask to be applied on input_x, which is the output of DropoutGenMask. The shape of input_x must be the same as the value of DropoutGenMask’s input shape. If the wrong mask is supplied, the output of DropoutDoMask is unpredictable.

  • keep_prob (Union[Tensor, float]) - The keep rate, greater than 0 and less than or equal to 1, e.g. keep_prob = 0.9, means dropping out 10% of input units. The value of keep_prob is the same as the input keep_prob of DropoutGenMask.

Outputs:

Tensor, the result of applying dropout to input_x.

Examples

>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> shape = (2, 2, 3)
>>> keep_prob = Tensor(0.5, mindspore.float32)
>>> dropout_gen_mask = P.DropoutGenMask()
>>> dropout_do_mask = P.DropoutDoMask()
>>> mask = dropout_gen_mask(shape, keep_prob)
>>> output = dropout_do_mask(x, mask, keep_prob)
>>> assert output.shape == (2, 2, 3)
[[[2.0, 0.0, 0.0],
  [0.0, 0.0, 0.0]],
 [[0.0, 0.0, 0.0],
  [2.0, 2.0, 2.0]]]
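The surviving values in the output above are 2.0 because kept elements are scaled by 1 / keep_prob. A NumPy sketch of that scaling follows; it uses a plain boolean mask for readability, which is an assumption, since the mask produced by DropoutGenMask appears to be a packed byte tensor. apply_dropout_mask is a hypothetical name.

```python
import numpy as np


def apply_dropout_mask(x, keep_mask, keep_prob):
    """Zero dropped elements and scale kept ones by 1 / keep_prob."""
    return np.where(keep_mask, x / keep_prob, 0.0).astype(x.dtype)


rng = np.random.default_rng(0)
x = np.ones((2, 2, 3), dtype=np.float32)
keep_prob = 0.5
keep_mask = rng.random(x.shape) < keep_prob  # True means "keep this element"
output = apply_dropout_mask(x, keep_mask, keep_prob)
```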
class mindspore.ops.DropoutGenMask(*args, **kwargs)[source]

Generates the mask value for the input shape.

Parameters
  • Seed0 (int) – Seed0 value for random number generation. Default: 0.

  • Seed1 (int) – Seed1 value for random number generation. Default: 0.

Inputs:
  • shape (tuple[int]) - The shape of target mask.

  • keep_prob (Tensor) - The keep rate, greater than 0 and less than or equal to 1, e.g. keep_prob = 0.9, means dropping out 10% of input units.

Outputs:

Tensor, the value of generated mask for input shape.

Examples

>>> dropout_gen_mask = P.DropoutGenMask()
>>> shape = (2, 4, 5)
>>> keep_prob = Tensor(0.5, mindspore.float32)
>>> mask = dropout_gen_mask(shape, keep_prob)
[249, 11, 134, 133, 143, 246, 89, 52, 169, 15, 94, 63, 146, 103, 7, 101]
class mindspore.ops.DynamicRNN(*args, **kwargs)[source]

DynamicRNN Operator.

Parameters
  • cell_type (str) – A string identifying the cell type in the op. Default: ‘LSTM’. Only ‘LSTM’ is currently supported.

  • direction (str) – A string identifying the direction in the op. Default: ‘UNIDIRECTIONAL’. Only ‘UNIDIRECTIONAL’ is currently supported.

  • cell_depth (int) – An integer identifying the cell depth in the op. Default: 1.

  • use_peephole (bool) – A bool identifying whether to use peephole in the op. Default: False.

  • keep_prob (float) – A float identifying the keep prob in the op. Default: 1.0.

  • cell_clip (float) – A float identifying the cell clip in the op. Default: -1.0.

  • num_proj (int) – An integer identifying the num proj in the op. Default: 0.

  • time_major (bool) – A bool identifying the time major in the op. Default: True. Only True is currently supported.

  • activation (str) – A string identifying the type of activation function in the op. Default: ‘tanh’. Only ‘tanh’ is currently supported.

  • forget_bias (float) – A float identifying the forget bias in the op. Default: 0.0.

  • is_training (bool) – A bool identifying whether the op is in training mode. Default: True.

Inputs:
  • x (Tensor) - Current words. Tensor of shape (num_step, batch_size, input_size). The data type must be float16 or float32.

  • w (Tensor) - Weight. Tensor of shape (input_size + hidden_size, 4 x hidden_size). The data type must be float16 or float32.

  • b (Tensor) - Bias. Tensor of shape (4 x hidden_size). The data type must be float16 or float32.

  • seq_length (Tensor) - The length of each batch. Tensor of shape (batch_size). Only None is currently supported.

  • init_h (Tensor) - Hidden state of initial time. Tensor of shape (1, batch_size, hidden_size).

  • init_c (Tensor) - Cell state of initial time. Tensor of shape (1, batch_size, hidden_size).

Outputs:
  • y (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • output_h (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). With data type of float16.

  • output_c (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • i (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • j (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • f (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • o (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

  • tanhct (Tensor) - A Tensor of shape (num_step, batch_size, hidden_size). Has the same type as input b.

Examples

>>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
>>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
>>> b = Tensor(np.random.rand(128).astype(np.float16))
>>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> dynamic_rnn = P.DynamicRNN()
>>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
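The tensor shapes in the example follow from the Inputs and Outputs descriptions: with input_size 64 and hidden_size 32, w is (64 + 32, 4 x 32). The helper below only restates that shape arithmetic in plain Python; dynamic_rnn_shapes is a hypothetical name.

```python
def dynamic_rnn_shapes(num_step, batch_size, input_size, hidden_size):
    """Shapes of DynamicRNN inputs and the y output, per the documented rules."""
    return {
        "x": (num_step, batch_size, input_size),
        "w": (input_size + hidden_size, 4 * hidden_size),
        "b": (4 * hidden_size,),
        "init_h": (1, batch_size, hidden_size),
        "init_c": (1, batch_size, hidden_size),
        "y": (num_step, batch_size, hidden_size),
    }


shapes = dynamic_rnn_shapes(num_step=2, batch_size=16, input_size=64, hidden_size=32)
```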
class mindspore.ops.DynamicShape(*args, **kwargs)[source]

Returns the shape of input tensor.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor[int], 1-dim Tensor of type int32

Examples

>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> shape = P.DynamicShape()
>>> output = shape(input_tensor)
class mindspore.ops.EditDistance(*args, **kwargs)[source]

Computes the Levenshtein Edit Distance. It is used to measure the similarity of two sequences.

Parameters

normalize (bool) – If true, edit distances are normalized by length of truth. Default: True.

Inputs:
  • hypothesis_indices (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type. The shape of tensor is \((N, R)\).

  • hypothesis_values (Tensor) - The values of the hypothesis list SparseTensor. Must be 1-D vector with length of N.

  • hypothesis_shape (Tensor) - The shape of the hypothesis list SparseTensor. Must be R-length vector with int64 data type. Only constant value is allowed.

  • truth_indices (Tensor) - The indices of the truth list SparseTensor. With int64 data type. The shape of tensor is \((M, R)\).

  • truth_values (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.

  • truth_shape (Tensor) - The shape of the truth list SparseTensor. Must be R-length vector with int64 data type. Only constant value is allowed.

Outputs:

Tensor, a dense tensor with rank R-1 and float32 data type.

Examples

>>> import numpy as np
>>> from mindspore import context
>>> from mindspore import Tensor
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class EditDistance(nn.Cell):
>>>     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
>>>         super(EditDistance, self).__init__()
>>>         self.edit_distance = P.EditDistance(normalize)
>>>         self.hypothesis_shape = hypothesis_shape
>>>         self.truth_shape = truth_shape
>>>
>>>     def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
>>>         return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
>>>                                   truth_indices, truth_values, self.truth_shape)
>>>
>>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
>>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
>>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
>>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
>>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
>>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
>>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
[[1.0, 1.0], [1.0, 1.0]]
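For a single pair of dense sequences, the Levenshtein distance and its normalization by the length of truth can be sketched in plain Python. This is the textbook dynamic-programming algorithm, not the sparse-tensor operator; the function names are hypothetical.

```python
def levenshtein(hypothesis, truth):
    """Minimum number of insertions, deletions and substitutions."""
    prev = list(range(len(truth) + 1))
    for i, h in enumerate(hypothesis, start=1):
        curr = [i]
        for j, t in enumerate(truth, start=1):
            cost = 0 if h == t else 1
            curr.append(min(prev[j] + 1,          # delete h
                            curr[j - 1] + 1,      # insert t
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]


def normalized_edit_distance(hypothesis, truth):
    """Edit distance divided by the length of truth (truth must be non-empty)."""
    return levenshtein(hypothesis, truth) / len(truth)


distance = levenshtein("kitten", "sitting")
ratio = normalized_edit_distance([1, 2], [1, 3])
```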
class mindspore.ops.Elu(*args, **kwargs)[source]

Computes exponential linear: alpha * (exp(x) - 1) if x < 0, x otherwise. The data type of input tensor must be float.

Parameters

alpha (float) – The coefficient of the negative factor, whose type is float. Only 1.0 is currently supported. Default: 1.0.

Inputs:
  • input_x (Tensor) - The input tensor whose data type must be float.

Outputs:

Tensor, has the same shape and data type as input_x.

Examples

>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> elu = P.Elu()
>>> result = elu(input_x)
Tensor([[-0.632  4.0   -0.999]
        [2.0    -0.993  9.0  ]], shape=(2, 3), dtype=mindspore.float32)
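The formula alpha * (exp(x) - 1) for x < 0 can be checked against the example with a NumPy reference. elu_ref is a hypothetical helper; it uses np.expm1, which computes exp(x) - 1.

```python
import numpy as np


def elu_ref(x, alpha=1.0):
    """alpha * (exp(x) - 1) where x < 0, x otherwise."""
    x = np.asarray(x, dtype=np.float32)
    return np.where(x < 0, alpha * np.expm1(x), x)


result = elu_ref([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])
```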
class mindspore.ops.EmbeddingLookup(*args, **kwargs)[source]

Returns a slice of input tensor based on the specified indices.

This Primitive has functionality similar to GatherV2 operating on axis = 0, but takes one more input: offset.

Inputs:
  • input_params (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). This represents a Tensor slice, instead of the entire Tensor. Currently, the dimension is restricted to be 2.

  • input_indices (Tensor) - The shape of tensor is \((y_1, y_2, ..., y_S)\). Specifies the indices of elements of the original Tensor. Values can be out of range of input_params, and the exceeding part will be filled with 0 in the output.

  • offset (int) - Specifies the offset value of this input_params slice. Thus the real indices are equal to input_indices minus offset.

Outputs:

Tensor, the shape of tensor is \((z_1, z_2, ..., z_N)\).

Examples

>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
>>> offset = 4
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
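The role of offset and the zero-filling of out-of-range indices can be traced with a NumPy sketch: the real index is input_indices minus offset, and any result outside the slice yields a row of zeros. embedding_lookup_ref is a hypothetical reference helper.

```python
import numpy as np


def embedding_lookup_ref(params, indices, offset):
    """Gather rows of params at (indices - offset); out-of-range rows are 0."""
    params = np.asarray(params)
    local = np.asarray(indices) - offset
    in_range = (local >= 0) & (local < params.shape[0])
    out = params[np.where(in_range, local, 0)]  # gather with a safe index
    out[~in_range] = 0                          # zero-fill the invalid rows
    return out


params = np.array([[8, 9], [10, 11], [12, 13], [14, 15]], dtype=np.float32)
indices = np.array([[5, 2], [8, 5]], dtype=np.int32)
out = embedding_lookup_ref(params, indices, offset=4)
```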
class mindspore.ops.Eps(*args, **kwargs)[source]

Creates a tensor filled with the minimum value of the dtype of input_x.

Inputs:
  • input_x (Tensor) - Input tensor. The data type must be float16 or float32.

Outputs:

Tensor, has the same type and shape as input_x, but filled with the minimum value of the dtype of input_x.

Examples

>>> input_x = Tensor([4, 1, 2, 3], mindspore.float32)
>>> out = P.Eps()(input_x)
[1.52587891e-05, 1.52587891e-05, 1.52587891e-05, 1.52587891e-05]
class mindspore.ops.Equal(*args, **kwargs)[source]

Computes the equivalence between two tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number]) - The first input is a number or a tensor whose data type is number.

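  • input_y (Union[Tensor, Number]) - When the first input is a tensor, the second input could be a number or a tensor whose data type is number. When the first input is a number, the second input must be a tensor whose data type is number.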

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> equal = P.Equal()
>>> equal(input_x, 2.0)
[False, True, False]
>>>
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 2, 4]), mindspore.int32)
>>> equal = P.Equal()
>>> equal(input_x, input_y)
[True, True, False]
class mindspore.ops.EqualCount(*args, **kwargs)[source]

Computes the number of the same elements of two tensors.

The two input tensors must have the same data type and shape.

Inputs:
  • input_x (Tensor) - The first input tensor.

  • input_y (Tensor) - The second input tensor.

Outputs:

Tensor, with the type same as input tensor and size as (1,).

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 2, 4]), mindspore.int32)
>>> equal_count = P.EqualCount()
>>> equal_count(input_x, input_y)
[2]
class mindspore.ops.Erf(*args, **kwargs)[source]

Computes the Gauss error function of input_x element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. The data type must be float16 or float32.

Outputs:

Tensor, has the same shape and dtype as the input_x.

Examples

>>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
>>> erf = P.Erf()
>>> erf(input_x)
[-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765]
class mindspore.ops.Erfc(*args, **kwargs)[source]

Computes the complementary error function of input_x element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. The data type must be float16 or float32.

Outputs:

Tensor, has the same shape and dtype as the input_x.

Examples

>>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
>>> erfc = P.Erfc()
>>> erfc(input_x)
[1.8427168, 1., 0.1572832, 0.00469124, 0.00002235]
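The complementary error function satisfies erfc(x) = 1 - erf(x), so erfc(0) = 1. A quick check with Python's standard math module is shown below; the last digits can differ from the operator's float32 output.

```python
import math

inputs = [-1.0, 0.0, 1.0, 2.0, 3.0]
erf_values = [math.erf(v) for v in inputs]
erfc_values = [math.erfc(v) for v in inputs]

# erf(x) + erfc(x) should equal 1 for every x, up to rounding.
residuals = [abs(e + c - 1.0) for e, c in zip(erf_values, erfc_values)]
```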
class mindspore.ops.Exp(*args, **kwargs)[source]

Returns exponential of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. The data type must be float16 or float32.

Outputs:

Tensor, has the same shape and dtype as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> exp = P.Exp()
>>> exp(input_x)
[ 2.71828183,  7.3890561 , 54.59815003]
class mindspore.ops.ExpandDims(*args, **kwargs)[source]

Adds an additional dimension at the given axis.

Note

If the specified axis is a negative number, the index is counted backward from the end and starts at 1.

Raises

ValueError – If axis is not an integer or not in the valid range.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

  • axis (int) - Specifies the dimension index at which to expand the shape of input_x. The value of axis must be in the range [-input_x.dim()-1, input_x.dim()]. Only constant value is allowed.

Outputs:

Tensor, the shape of tensor is \((1, x_1, x_2, ..., x_R)\) if the value of axis is 0.

Examples

>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> expand_dims = P.ExpandDims()
>>> output = expand_dims(input_tensor, 0)
[[[2.0, 2.0],
  [2.0, 2.0]]]
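The valid axis range [-input_x.dim()-1, input_x.dim()] and the effect of negative values can be seen with NumPy's expand_dims, used here as a stand-in for the operator on the assumption that both follow the same axis convention.

```python
import numpy as np

x = np.full((2, 2), 2.0, dtype=np.float32)

# For a 2-D input the valid axes are -3 .. 2.
shapes = {axis: np.expand_dims(x, axis).shape for axis in (-3, -1, 0, 1, 2)}
```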
class mindspore.ops.Expm1(*args, **kwargs)[source]

Returns the exponential minus 1, exp(x) - 1, of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. With float16 or float32 data type.

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([0.0, 1.0, 2.0, 4.0]), mindspore.float32)
>>> expm1 = P.Expm1()
>>> expm1(input_x)
[ 0.,  1.71828183,  6.3890561 , 53.59815003]
class mindspore.ops.Eye(*args, **kwargs)[source]

Creates a tensor with ones on the diagonal and zeros elsewhere.

Inputs:
  • n (int) - The number of rows of the returned tensor.

  • m (int) - The number of columns of the returned tensor.

  • t (mindspore.dtype) - MindSpore’s dtype. The data type of the returned tensor.

Outputs:

Tensor, a tensor with ones on the diagonal and the rest of elements are zero.

Examples

>>> eye = P.Eye()
>>> out_tensor = eye(2, 2, mindspore.int32)
[[1, 0],
 [0, 1]]
class mindspore.ops.Fill(*args, **kwargs)[source]

Creates a tensor filled with a scalar value.

Creates a tensor with the data type given by the first argument and the shape given by the second argument, and fills it with the value given in the third argument.

Inputs:
  • type (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.

  • shape (tuple) - The specified shape of output tensor. Only constant value is allowed.

  • value (scalar) - Value to fill the returned tensor. Only constant value is allowed.

Outputs:

Tensor, has the type and shape specified by the inputs type and shape, with every element set to value.

Examples

>>> fill = P.Fill()
>>> fill(mindspore.float32, (2, 2), 1)
[[1.0, 1.0],
 [1.0, 1.0]]
class mindspore.ops.Flatten(*args, **kwargs)[source]

Flattens a tensor without changing its batch size on the 0-th axis.

Inputs:
  • input_x (Tensor) - Tensor of shape \((N, \ldots)\) to be flattened.

Outputs:

Tensor, the shape of the output tensor is \((N, X)\), where \(X\) is the product of the remaining dimensions.

Examples

>>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
>>> flatten = P.Flatten()
>>> output = flatten(input_tensor)
>>> assert output.shape == (1, 24)
class mindspore.ops.FloatStatus(*args, **kwargs)[source]

Determines whether the elements contain Not a Number (NaN), positive infinity or negative infinity. 0 for normal, 1 for overflow.

Inputs:
  • input_x (Tensor) - The input tensor. The data type must be float16 or float32.

Outputs:

Tensor, has the shape of (1,), and has the same dtype as the input: mindspore.dtype.float32 or mindspore.dtype.float16.

Examples

>>> float_status = P.FloatStatus()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = float_status(input_x)
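The 0 / 1 status described above can be mimicked in NumPy by testing whether every element is finite. float_status_ref is a hypothetical reference helper, not the operator itself.

```python
import numpy as np


def float_status_ref(x):
    """Return [0.] if all elements are finite, [1.] if any is NaN or +/-inf."""
    x = np.asarray(x, dtype=np.float32)
    flag = 0.0 if np.all(np.isfinite(x)) else 1.0
    return np.array([flag], dtype=x.dtype)


overflow = float_status_ref([np.nan, 1.0, -np.inf])
normal = float_status_ref([1.0, 2.0, 3.0])
```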
class mindspore.ops.Floor(*args, **kwargs)[source]

Rounds a tensor down to the closest integer element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. Its element data type must be float.

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> input_x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32)
>>> floor = P.Floor()
>>> floor(input_x)
[1.0, 2.0, -2.0]
class mindspore.ops.FloorDiv(*args, **kwargs)[source]

Divides the first input tensor by the second input tensor element-wise and rounds down to the closest integer.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

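  • input_y (Union[Tensor, Number, bool]) - When the first input is a tensor, the second input could be a number, a bool, or a tensor whose data type is number or bool. When the first input is a number or a bool, the second input must be a tensor whose data type is number or bool.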

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
>>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
>>> floor_div = P.FloorDiv()
>>> floor_div(input_x, input_y)
[0, 1, -1]
class mindspore.ops.FloorMod(*args, **kwargs)[source]

Computes the remainder of division element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

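  • input_y (Union[Tensor, Number, bool]) - When the first input is a tensor, the second input could be a number, a bool, or a tensor whose data type is number or bool. When the first input is a number or a bool, the second input must be a tensor whose data type is number or bool.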

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
>>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
>>> floor_mod = P.FloorMod()
>>> floor_mod(input_x, input_y)
[2, 1, 2]
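FloorDiv and FloorMod are linked by the identity x = y * floor_div(x, y) + floor_mod(x, y), which explains why -1 mod 3 is 2 rather than -1. Python's own // and % operators use the same floor convention, so the example values can be checked in plain Python.

```python
xs = [2, 4, -1]
ys = [3, 3, 3]

floor_div = [x // y for x, y in zip(xs, ys)]  # rounds toward negative infinity
floor_mod = [x % y for x, y in zip(xs, ys)]   # result takes the sign of y

# Reconstruct each x from its quotient and remainder.
rebuilt = [y * q + r for y, q, r in zip(ys, floor_div, floor_mod)]
```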
class mindspore.ops.FusedBatchNorm(*args, **kwargs)[source]

FusedBatchNorm is a BatchNorm in which the moving mean and moving variance are computed rather than loaded.

Batch Normalization is widely used in convolutional networks. This operation applies Batch Normalization over input to avoid internal covariate shift as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the feature using a mini-batch of data and the learned parameters which can be described in the following formula.

\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

where \(\gamma\) is scale, \(\beta\) is bias, \(\epsilon\) is epsilon.

Parameters
  • mode (int) – Mode of batch normalization, value is 0 or 1. Default: 0.

  • epsilon (float) – A small value added for numerical stability. Default: 1e-5.

  • momentum (float) – The hyper parameter to compute moving average for running_mean and running_var (e.g. \(new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean\)). Momentum value must be in the range [0, 1]. Default: 0.9.

Inputs:
  • input_x (Tensor) - Tensor of shape \((N, C)\).

  • scale (Tensor) - Tensor of shape \((C,)\).

  • bias (Tensor) - Tensor of shape \((C,)\).

  • mean (Tensor) - Tensor of shape \((C,)\).

  • variance (Tensor) - Tensor of shape \((C,)\).

Outputs:

Tuple of 5 Tensors, the normalized input and the updated parameters.

  • output_x (Tensor) - The same type and shape as the input_x.

  • updated_scale (Tensor) - Tensor of shape \((C,)\).

  • updated_bias (Tensor) - Tensor of shape \((C,)\).

  • updated_moving_mean (Tensor) - Tensor of shape \((C,)\).

  • updated_moving_variance (Tensor) - Tensor of shape \((C,)\).

Examples

>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> bias = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> op = P.FusedBatchNorm()
>>> output = op(input_x, scale, bias, mean, variance)
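
The normalization formula and the moving-average rule can be sketched in plain Python for an input of shape (N, C). This is only an illustration under stated assumptions: the batch variance is taken as the population variance (divided by N), and the function name is hypothetical. It does not describe how the fused kernel is implemented.

```python
import math

def fused_batch_norm_sketch(x, scale, bias, moving_mean, moving_var,
                            epsilon=1e-5, momentum=0.9):
    """x is a list of N rows with C channels each; the rest are length-C lists."""
    n = len(x)
    c = len(scale)
    # Per-channel statistics of the current mini-batch (population variance assumed).
    mean = [sum(row[j] for row in x) / n for j in range(c)]
    var = [sum((row[j] - mean[j]) ** 2 for row in x) / n for j in range(c)]
    # y = (x - mean) / sqrt(variance + epsilon) * gamma + beta
    y = [[(row[j] - mean[j]) / math.sqrt(var[j] + epsilon) * scale[j] + bias[j]
          for j in range(c)] for row in x]
    # new_running = momentum * running + (1 - momentum) * current
    new_mean = [momentum * moving_mean[j] + (1 - momentum) * mean[j] for j in range(c)]
    new_var = [momentum * moving_var[j] + (1 - momentum) * var[j] for j in range(c)]
    return y, new_mean, new_var

y, new_mean, new_var = fused_batch_norm_sketch(
    [[1.0, 2.0], [3.0, 6.0]], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0],
    epsilon=0.0)
```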
class mindspore.ops.FusedBatchNormEx(*args, **kwargs)[source]

FusedBatchNormEx is an extension of FusedBatchNorm. It has one more output (reserve) than FusedBatchNorm, which is used in the backpropagation phase. FusedBatchNorm is a BatchNorm in which the moving mean and moving variance are computed instead of being loaded.

Batch Normalization is widely used in convolutional networks. This operation applies Batch Normalization over input to avoid internal covariate shift as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It rescales and recenters the feature using a mini-batch of data and the learned parameters which can be described in the following formula.

\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

where \(\gamma\) is scale, \(\beta\) is bias, \(\epsilon\) is epsilon.

Parameters
  • mode (int) – Mode of batch normalization, value is 0 or 1. Default: 0.

  • epsilon (float) – A small value added for numerical stability. Default: 1e-5.

  • momentum (float) – The hyper parameter to compute moving average for running_mean and running_var (e.g. \(new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean\)). Momentum value must be in the range [0, 1]. Default: 0.9.

Inputs:
  • input_x (Tensor) - The input of FusedBatchNormEx, Tensor of shape \((N, C)\), data type: float16 or float32.

  • scale (Tensor) - Parameter scale, the same as \(\gamma\) above, Tensor of shape \((C,)\), data type: float32.

  • bias (Tensor) - Parameter bias, the same as \(\beta\) above, Tensor of shape \((C,)\), data type: float32.

  • mean (Tensor) - mean value, Tensor of shape \((C,)\), data type: float32.

  • variance (Tensor) - variance value, Tensor of shape \((C,)\), data type: float32.

Outputs:

Tuple of 6 Tensors, the normalized input, the updated parameters and reserve.

  • output_x (Tensor) - The output of FusedBatchNormEx, same type and shape as the input_x.

  • updated_scale (Tensor) - Updated parameter scale, Tensor of shape \((C,)\), data type: float32.

  • updated_bias (Tensor) - Updated parameter bias, Tensor of shape \((C,)\), data type: float32.

  • updated_moving_mean (Tensor) - Updated mean value, Tensor of shape \((C,)\), data type: float32.

  • updated_moving_variance (Tensor) - Updated variance value, Tensor of shape \((C,)\), data type: float32.

  • reserve (Tensor) - reserve space, Tensor of shape \((C,)\), data type: float32.

Examples

>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> bias = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> op = P.FusedBatchNormEx()
>>> output = op(input_x, scale, bias, mean, variance)
class mindspore.ops.FusedSparseAdam(*args, **kwargs)[source]

Merges the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam) algorithm. This operator is used when the gradient is sparse.

The Adam algorithm is proposed in Adam: A Method for Stochastic Optimization.

The updating formulas are as follows,

\[\begin{split}\begin{array}{ll} \\ m = \beta_1 * m + (1 - \beta_1) * g \\ v = \beta_2 * v + (1 - \beta_2) * g * g \\ l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ w = w - l * \frac{m}{\sqrt{v} + \epsilon} \end{array}\end{split}\]

\(m\) represents the 1st moment vector, \(v\) represents the 2nd moment vector, \(g\) represents gradient, \(l\) represents scaling factor lr, \(\beta_1, \beta_2\) represent beta1 and beta2, \(t\) represents updating step while \(\beta_1^t\) and \(\beta_2^t\) represent beta1_power and beta2_power, \(\alpha\) represents learning_rate, \(w\) represents var, \(\epsilon\) represents epsilon.

All of inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters
  • use_locking (bool) – Whether to enable a lock to protect variable tensors from being updated. If true, updates of the var, m, and v tensors will be protected by a lock. If false, the result is unpredictable. Default: False.

  • use_nesterov (bool) – Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. If true, update the gradients using NAG. If false, update the gradients without using NAG. Default: False.

Inputs:
  • var (Parameter) - Parameters to be updated with float32 data type.

  • m (Parameter) - The 1st moment vector in the updating formula, has the same type as var with float32 data type.

  • v (Parameter) - The 2nd moment vector in the updating formula. Mean square gradients, has the same type as var with float32 data type.

  • beta1_power (Tensor) - \(\beta_1^t\) in the updating formula with float32 data type.

  • beta2_power (Tensor) - \(\beta_2^t\) in the updating formula with float32 data type.

  • lr (Tensor) - \(l\) in the updating formula. With float32 data type.

  • beta1 (Tensor) - The exponential decay rate for the 1st moment estimations with float32 data type.

  • beta2 (Tensor) - The exponential decay rate for the 2nd moment estimations with float32 data type.

  • epsilon (Tensor) - Term added to the denominator to improve numerical stability with float32 data type.

  • gradient (Tensor) - Gradient value with float32 data type.

  • indices (Tensor) - Gradient indices with int32 data type.

Outputs:

Tuple of 3 Tensors. This operator updates the input parameters in place, so the outputs carry no useful values.

  • var (Tensor) - A Tensor with shape (1,).

  • m (Tensor) - A Tensor with shape (1,).

  • v (Tensor) - A Tensor with shape (1,).

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_adam = P.FusedSparseAdam()
>>>         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
>>>         self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
>>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
>>>         out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
>>>                                      epsilon, grad, indices)
>>>         return out
>>> net = Net()
>>> beta1_power = Tensor(0.9, mstype.float32)
>>> beta2_power = Tensor(0.999, mstype.float32)
>>> lr = Tensor(0.001, mstype.float32)
>>> beta1 = Tensor(0.9, mstype.float32)
>>> beta2 = Tensor(0.999, mstype.float32)
>>> epsilon = Tensor(1e-8, mstype.float32)
>>> gradient = Tensor(np.random.rand(2, 1, 2), mstype.float32)
>>> indices = Tensor([0, 1], mstype.int32)
>>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
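
The updating formulas can be traced for a single scalar weight in plain Python. This is a sketch of the formulas only: it takes \(\alpha\) as the learning-rate argument, ignores sparsity, duplicate merging and use_nesterov, and uses a hypothetical function name.

```python
import math

def adam_step(w, m, v, g, alpha, beta1, beta2, beta1_power, beta2_power, epsilon):
    """One Adam update for a scalar weight w with gradient g."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # l is the scaled step size; beta1_power and beta2_power stand for beta^t.
    l = alpha * math.sqrt(1 - beta2_power) / (1 - beta1_power)
    w = w - l * m / (math.sqrt(v) + epsilon)
    return w, m, v

# First step (t = 1), so beta1_power == beta1 and beta2_power == beta2.
w, m, v = adam_step(w=1.0, m=0.0, v=0.0, g=1.0, alpha=0.001,
                    beta1=0.9, beta2=0.999,
                    beta1_power=0.9, beta2_power=0.999, epsilon=1e-8)
```

On the first step from zero moments the weight moves by roughly \(\alpha\), independent of the gradient magnitude, which is the usual effect of the bias-correction factor.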
class mindspore.ops.FusedSparseFtrl(*args, **kwargs)[source]

Merges the duplicate value of the gradient and then updates relevant entries according to the FTRL-proximal scheme.

All of inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters
  • lr (float) – The learning rate value, must be positive.

  • l1 (float) – l1 regularization strength, must be greater than or equal to zero.

  • l2 (float) – l2 regularization strength, must be greater than or equal to zero.

  • lr_power (float) – Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.

  • use_locking (bool) – Use locks for updating operation if true. Default: False.

Inputs:
  • var (Parameter) - The variable to be updated. The data type must be float32.

  • accum (Parameter) - The accumulation to be updated, must be same type and shape as var.

  • linear (Parameter) - The linear coefficient to be updated, must be same type and shape as var.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

  • indices (Tensor) - A vector of indices into the first dimension of var and accum. The shape of indices must be the same as grad in first dimension. The type must be int32.

Outputs:

Tuple of 3 Tensors. This operator updates the input parameters in place, so the outputs carry no useful values.

  • var (Tensor) - A Tensor with shape (1,).

  • accum (Tensor) - A Tensor with shape (1,).

  • linear (Tensor) - A Tensor with shape (1,).

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlNet(nn.Cell):
>>>     def __init__(self):
>>>         super(SparseApplyFtrlNet, self).__init__()
>>>         self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
>>>         self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>>         self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
>>>
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
>>>         return out
>>>
>>> net = SparseApplyFtrlNet()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
class mindspore.ops.FusedSparseLazyAdam(*args, **kwargs)[source]

Merges the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam) algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the original Adam algorithm, as only the current indices parameters will be updated.

The Adam algorithm is proposed in Adam: A Method for Stochastic Optimization.

The updating formulas are as follows,

\[\begin{split}\begin{array}{ll} \\ m = \beta_1 * m + (1 - \beta_1) * g \\ v = \beta_2 * v + (1 - \beta_2) * g * g \\ l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ w = w - l * \frac{m}{\sqrt{v} + \epsilon} \end{array}\end{split}\]

\(m\) represents the 1st moment vector, \(v\) represents the 2nd moment vector, \(g\) represents gradient, \(l\) represents scaling factor lr, \(\beta_1, \beta_2\) represent beta1 and beta2, \(t\) represents updating step while \(\beta_1^t\) and \(\beta_2^t\) represent beta1_power and beta2_power, \(\alpha\) represents learning_rate, \(w\) represents var, \(\epsilon\) represents epsilon.

All of inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters
  • use_locking (bool) – Whether to enable a lock to protect variable tensors from being updated. If true, updates of the var, m, and v tensors will be protected by a lock. If false, the result is unpredictable. Default: False.

  • use_nesterov (bool) – Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. If true, update the gradients using NAG. If false, update the gradients without using NAG. Default: False.

Inputs:
  • var (Parameter) - Parameters to be updated with float32 data type.

  • m (Parameter) - The 1st moment vector in the updating formula, has the same type as var with float32 data type.

  • v (Parameter) - The 2nd moment vector in the updating formula. Mean square gradients, has the same type as var with float32 data type.

  • beta1_power (Tensor) - \(\beta_1^t\) in the updating formula with float32 data type.

  • beta2_power (Tensor) - \(\beta_2^t\) in the updating formula with float32 data type.

  • lr (Tensor) - \(l\) in the updating formula with float32 data type.

  • beta1 (Tensor) - The exponential decay rate for the 1st moment estimations with float32 data type.

  • beta2 (Tensor) - The exponential decay rate for the 2nd moment estimations with float32 data type.

  • epsilon (Tensor) - Term added to the denominator to improve numerical stability with float32 data type.

  • gradient (Tensor) - Gradient value with float32 data type.

  • indices (Tensor) - Gradient indices with int32 data type.

Outputs:

Tuple of 3 Tensors. This operator updates the input parameters in place, so the outputs carry no useful values.

  • var (Tensor) - A Tensor with shape (1,).

  • m (Tensor) - A Tensor with shape (1,).

  • v (Tensor) - A Tensor with shape (1,).

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_lazyadam = P.FusedSparseLazyAdam()
>>>         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
>>>         self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
>>>         self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
>>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
>>>         out = self.sparse_apply_lazyadam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1,
>>>                                          beta2, epsilon, grad, indices)
>>>         return out
>>> net = Net()
>>> beta1_power = Tensor(0.9, mstype.float32)
>>> beta2_power = Tensor(0.999, mstype.float32)
>>> lr = Tensor(0.001, mstype.float32)
>>> beta1 = Tensor(0.9, mstype.float32)
>>> beta2 = Tensor(0.999, mstype.float32)
>>> epsilon = Tensor(1e-8, mstype.float32)
>>> gradient = Tensor(np.random.rand(2, 1, 2), mstype.float32)
>>> indices = Tensor([0, 1], mstype.int32)
>>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
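
The "lazy" behaviour, where only the entries named in indices are updated, can be sketched for a 1-D parameter in plain Python. Two assumptions are made here and are not confirmed by the text above: merging duplicates is read as summing the gradient values that share an index, and \(\alpha\) is taken as the learning-rate argument. All names are illustrative.

```python
import math

def merge_duplicates(indices, grads):
    """Sum gradient values that share an index (assumed meaning of 'merge')."""
    merged = {}
    for i, g in zip(indices, grads):
        merged[i] = merged.get(i, 0.0) + g
    return merged

def lazy_adam_step(var, m, v, indices, grads, alpha, beta1, beta2,
                   beta1_power, beta2_power, epsilon):
    """Update only the entries of var, m and v listed in indices, in place."""
    l = alpha * math.sqrt(1 - beta2_power) / (1 - beta1_power)
    for i, g in merge_duplicates(indices, grads).items():
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        var[i] -= l * m[i] / (math.sqrt(v[i]) + epsilon)
    return var, m, v

var, m, v = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
# Index 0 appears twice; indices 1 and 2 receive no gradient and stay untouched.
lazy_adam_step(var, m, v, indices=[0, 0], grads=[0.5, 0.5], alpha=0.001,
               beta1=0.9, beta2=0.999, beta1_power=0.9, beta2_power=0.999,
               epsilon=1e-8)
```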
class mindspore.ops.FusedSparseProximalAdagrad(*args, **kwargs)[source]

Merges the duplicate value of the gradient and then updates relevant entries according to the proximal adagrad algorithm.

\[accum += grad * grad\]
\[\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}\]
\[var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)\]

All of inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, lower priority data type will be converted to relatively highest priority data type. RuntimeError exception will be thrown when the data type conversion of Parameter is required.

Parameters
  • use_locking (bool) – If true, the variable and accumulation tensors will be protected from being updated. Default: False.

Inputs:
  • var (Parameter) - Variable tensor to be updated. The data type must be float32.

  • accum (Parameter) - Variable tensor to be updated, has the same dtype as var.

  • lr (Tensor) - The learning rate value. The data type must be float32.

  • l1 (Tensor) - l1 regularization strength. The data type must be float32.

  • l2 (Tensor) - l2 regularization strength. The data type must be float32.

  • grad (Tensor) - A tensor of the same type as var, for the gradient. The data type must be float32.

  • indices (Tensor) - A vector of indices into the first dimension of var and accum. The data type must be int32.

Outputs:

Tuple of 2 Tensors. This operator updates the input parameters in place, so the outputs carry no useful values.

  • var (Tensor) - A Tensor with shape (1,).

  • accum (Tensor) - A Tensor with shape (1,).

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>>         self.lr = Tensor(0.01, mstype.float32)
>>>         self.l1 = Tensor(0.0, mstype.float32)
>>>         self.l2 = Tensor(0.0, mstype.float32)
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
>>>                                                  self.l2, grad, indices)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
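
The three update formulas can be applied to a single scalar entry in plain Python. This sketch leaves out the sparse indexing and duplicate merging, and the function name is hypothetical.

```python
import math

def proximal_adagrad_step(var, accum, lr, l1, l2, grad):
    """Apply the accum, prox_v and var formulas to one scalar entry."""
    accum = accum + grad * grad
    prox_v = var - lr * grad / math.sqrt(accum)
    sign = (prox_v > 0) - (prox_v < 0)  # -1, 0 or 1
    var = sign / (1 + lr * l2) * max(abs(prox_v) - lr * l1, 0.0)
    return var, accum

# Without regularization the step is a plain Adagrad update.
var_plain, accum_plain = proximal_adagrad_step(1.0, 0.0, lr=0.1, l1=0.0, l2=0.0, grad=2.0)
# A large l1 term shrinks the entry all the way to zero.
var_sparse, _ = proximal_adagrad_step(1.0, 0.0, lr=0.1, l1=10.0, l2=0.0, grad=2.0)
```

The max(..., 0) term is what produces exact zeros when the l1 penalty outweighs the magnitude of prox_v.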
class mindspore.ops.Gamma(*args, **kwargs)[source]

Produces random positive floating-point values x, distributed according to probability density function:

\[\text{P}(x|\alpha,\beta) = \frac{\exp(-x/\beta)}{\beta^{\alpha} \cdot \Gamma(\alpha)} \cdot x^{\alpha-1}\]

Parameters
  • seed (int) – Random seed, must be non-negative. Default: 0.

  • seed2 (int) – Random seed2, must be non-negative. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

  • alpha (Tensor) - The α distribution parameter, also known as the shape parameter. It must be greater than 0, with float32 data type.

  • beta (Tensor) - The β distribution parameter, also known as the scale parameter. It must be greater than 0, with float32 data type.

Outputs:

Tensor. The shape must be the broadcasted shape of Input “shape” and shapes of alpha and beta. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> alpha = Tensor(1.0, mstype.float32)
>>> beta = Tensor(1.0, mstype.float32)
>>> gamma = P.Gamma(seed=3)
>>> output = gamma(shape, alpha, beta)
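
The density above can be evaluated directly with the Python standard library, which also offers a sampler with the same shape/scale parameterisation (random.gammavariate). The sketch below is independent of MindSpore and only illustrates the distribution; gamma_pdf is an illustrative helper name.

```python
import math
import random

def gamma_pdf(x, alpha, beta):
    """exp(-x / beta) * x**(alpha - 1) / (beta**alpha * Gamma(alpha))."""
    return math.exp(-x / beta) * x ** (alpha - 1) / (beta ** alpha * math.gamma(alpha))

# With alpha = beta = 1 the density reduces to the exponential density exp(-x).
density_at_one = gamma_pdf(1.0, 1.0, 1.0)

# Draw samples with shape alpha = 2.0 and scale beta = 1.5.
rng = random.Random(3)
samples = [rng.gammavariate(2.0, 1.5) for _ in range(10000)]
sample_mean = sum(samples) / len(samples)  # the theoretical mean is alpha * beta
```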
class mindspore.ops.GatherNd(*args, **kwargs)[source]

Gathers slices from a tensor by indices.

Using given indices to gather slices from a tensor with a specified shape.

Inputs:
  • input_x (Tensor) - The target tensor to gather values.

  • indices (Tensor) - The index tensor, with int data type.

Outputs:

Tensor, has the same type as input_x and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].

Examples

>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> op = P.GatherNd()
>>> output = op(input_x, indices)
[-0.1, 0.5]
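
The indexing rule can be sketched on nested Python lists: each row of indices is an index tuple that is applied to input_x one dimension at a time. The helper below is illustrative only and handles a 2-D indices argument.

```python
def gather_nd(x, indices):
    """For each index tuple, walk into the nested list x one dimension at a time."""
    out = []
    for idx in indices:
        item = x
        for i in idx:
            item = item[i]
        out.append(item)
    return out

input_x = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
elements = gather_nd(input_x, [[0, 0], [1, 1]])  # full index tuples select scalars
rows = gather_nd(input_x, [[1]])                 # shorter tuples select whole slices
```

When the index tuples are shorter than the rank of input_x, each result is a slice, which matches the output shape rule indices_shape[:-1] + x_shape[indices_shape[-1]:].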
class mindspore.ops.GatherV2(*args, **kwargs)[source]

Returns a slice of input tensor based on the specified indices and axis.

Inputs:
  • input_params (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). The original Tensor.

  • input_indices (Tensor) - The shape of tensor is \((y_1, y_2, ..., y_S)\). Specifies the indices of elements of the original Tensor. Must be in the range [0, input_params.shape[axis]).

  • axis (int) - Specifies the dimension index to gather indices.

Outputs:

Tensor, the shape of tensor is \((z_1, z_2, ..., z_N)\).

Examples

>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = P.GatherV2()(input_params, input_indices, axis)
[[2.0, 7.0],
 [4.0, 54.0],
 [2.0, 55.0]]
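
For a 2-D input, gathering along an axis amounts to picking rows (axis 0) or columns (axis 1). The plain-Python sketch below reproduces the example; gather_2d is an illustrative name and the helper deliberately handles only the 2-D case with 1-D indices.

```python
def gather_2d(params, indices, axis):
    """Pick rows (axis=0) or columns (axis=1) of a 2-D nested list."""
    if axis == 0:
        return [list(params[i]) for i in indices]
    if axis == 1:
        return [[row[i] for i in indices] for row in params]
    raise ValueError("this sketch only handles axis 0 or 1")

input_params = [[1.0, 2.0, 7.0, 42.0], [3.0, 4.0, 54.0, 22.0], [2.0, 2.0, 55.0, 3.0]]
columns = gather_2d(input_params, [1, 2], axis=1)
first_and_last_rows = gather_2d(input_params, [0, 2], axis=0)
```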
class mindspore.ops.GeSwitch(*args, **kwargs)[source]

Adds control switch to data.

Switch data flows into the false or true branch depending on the condition. If the condition is true, the true branch will be activated; otherwise, the false branch will be activated.

Inputs:
  • data (Union[Tensor, Number]) - The data to be used for switch control.

  • pred (Tensor) - It must be a scalar whose type is bool and shape is (). It is used as the condition for switch control.

Outputs:

tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as the input data. The false_output connects with the false_branch and the true_output connects with the true_branch.

Examples

>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.square = P.Square()
>>>         self.add = P.TensorAdd()
>>>         self.value = Tensor(np.full((1), 3), mindspore.float32)
>>>         self.switch = P.GeSwitch()
>>>         self.merge = P.Merge()
>>>         self.less = P.Less()
>>>
>>>     def construct(self, x, y):
>>>         cond = self.less(x, y)
>>>         st1, sf1 = self.switch(x, cond)
>>>         st2, sf2 = self.switch(y, cond)
>>>         add_ret = self.add(st1, st2)
>>>         st3, sf3 = self.switch(self.value, cond)
>>>         sq_ret = self.square(sf3)
>>>         ret = self.merge((add_ret, sq_ret))
>>>         return ret[0]
>>>
>>> x = Tensor(10.0, dtype=mindspore.float32)
>>> y = Tensor(5.0, dtype=mindspore.float32)
>>> net = Net()
>>> output = net(x, y)
class mindspore.ops.Gelu(*args, **kwargs)[source]

Gaussian Error Linear Units activation function.

GeLU is described in the paper Gaussian Error Linear Units (GELUs). See also BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.

Gelu is defined as follows:

\[\text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),\]

where \(erf\) is the “Gauss error function”.

Inputs:
  • input_x (Tensor) - Input to compute the Gelu with data type of float16 or float32.

Outputs:

Tensor, with the same type and shape as input.

Examples

>>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> gelu = P.Gelu()
>>> result = gelu(tensor)
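
The definition can be evaluated for scalars with the standard library's math.erf. This is a reference sketch of the formula only, not of the operator's kernel, and it makes no claim about the precision of the float16 or float32 results.

```python
import math

def gelu(x):
    """0.5 * x * (1 + erf(x / sqrt(2)))"""
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))

# Same input values as the example above.
values = [gelu(x) for x in (1.0, 2.0, 3.0)]
```

For large positive x the result approaches x, and for large negative x it approaches 0, so the function behaves like a smoothed ReLU.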
class mindspore.ops.GetNext(*args, **kwargs)[source]

Returns the next element in the dataset queue.

Note

The GetNext operation needs to be associated with network and it also depends on the init_dataset interface, it can’t be used directly as a single operation. For details, please refer to connect_network_with_dataset source code.

Parameters
  • types (list[mindspore.dtype]) – The type of the outputs.

  • shapes (list[tuple[int]]) – The dimensionality of the outputs.

  • output_num (int) – The output number, length of types and shapes.

  • shared_name (str) – The queue name of init_dataset interface.

Inputs:

No inputs.

Outputs:

tuple[Tensor], the output of Dataset. The shape is described in shapes and the type is described in types.

Examples

>>> get_next = P.GetNext([mindspore.float32, mindspore.int32], [[32, 1, 28, 28], [10]], 2, 'shared_name')
>>> feature, label = get_next()
class mindspore.ops.GradOperation(get_all=False, get_by_list=False, sens_param=False)[source]

A higher-order function which is used to generate the gradient function for the input function.

The gradient function generated by GradOperation higher-order function can be customized by construction arguments.

Given an input function net = Net() that takes x and y as inputs, and has a parameter z, see Net in Examples.

To generate a gradient function that returns gradients with respect to the first input (see GradNetWrtX in Examples).

  1. Construct a GradOperation higher-order function with default arguments: grad_op = GradOperation().

  2. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

  3. Call the gradient function with input function’s inputs to get the gradients with respect to the first input: grad_op(net)(x, y).

To generate a gradient function that returns gradients with respect to all inputs (see GradNetWrtXY in Examples).

  1. Construct a GradOperation higher-order function with get_all=True which indicates getting gradients with respect to all inputs, they are x and y in example function Net(): grad_op = GradOperation(get_all=True).

  2. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

  3. Call the gradient function with input function’s inputs to get the gradients with respect to all inputs: gradient_function(x, y).

To generate a gradient function that returns gradients with respect to given parameters (see GradNetWithWrtParams in Examples).

  1. Construct a GradOperation higher-order function with get_by_list=True: grad_op = GradOperation(get_by_list=True).

  2. Construct a ParameterTuple that will be passed to the input function when constructing GradOperation higher-order function; it will be used as a parameter filter that determines which gradient to return: params = ParameterTuple(net.trainable_params()).

  3. Call it with input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

  4. Call the gradient function with input function’s inputs to get the gradients with respect to given parameters: gradient_function(x, y).

To generate a gradient function that returns gradients with respect to all inputs and given parameters in the format of ((dx, dy), (dz)) (see GradNetWrtInputsAndParams in Examples).

  1. Construct a GradOperation higher-order function with get_all=True and get_by_list=True: grad_op = GradOperation(get_all=True, get_by_list=True).

  2. Construct a ParameterTuple that will be passed along input function when constructing GradOperation higher-order function: params = ParameterTuple(net.trainable_params()).

  3. Call it with input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

  4. Call the gradient function with input function’s inputs to get the gradients with respect to all inputs and given parameters: gradient_function(x, y).

We can configure the sensitivity (gradient with respect to output) by setting sens_param to True and passing an extra sensitivity input to the gradient function. The sensitivity input should have the same shape and type as the input function’s output (see GradNetWrtXYWithSensParam in Examples).

  1. Construct a GradOperation higher-order function with get_all=True and sens_param=True: grad_op = GradOperation(get_all=True, sens_param=True).

  2. Define grad_wrt_output as sens_param which works as the gradient with respect to output: grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32)).

  3. Call it with input function as argument to get the gradient function: gradient_function = grad_op(net).

  4. Call the gradient function with input function’s inputs and sens_param to get the gradients with respect to all inputs: gradient_function(x, y, grad_wrt_output).

Parameters
  • get_all (bool) – If True, get all the gradients with respect to inputs. Default: False.

  • get_by_list (bool) – If True, get all the gradients with respect to Parameter variables. If get_all and get_by_list are both False, get the gradient with respect to first input. If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables at the same time in the form of ((gradients with respect to inputs), (gradients with respect to parameters)). Default: False.

  • sens_param (bool) – Whether to append sensitivity (gradient with respect to output) as input. If sens_param is False, a ‘ones_like(outputs)’ sensitivity will be attached automatically. Default: False.

Returns

The higher-order function which takes a function as argument and returns gradient function for it.

Examples

>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.matmul = P.MatMul()
>>>         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
>>>     def construct(self, x, y):
>>>         x = x * self.z
>>>         out = self.matmul(x, y)
>>>         return out
>>>
>>> class GradNetWrtX(nn.Cell):
>>>     def __init__(self, net):
>>>         super(GradNetWrtX, self).__init__()
>>>         self.net = net
>>>         self.grad_op = GradOperation()
>>>     def construct(self, x, y):
>>>         gradient_function = self.grad_op(self.net)
>>>         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> GradNetWrtX(Net())(x, y)
Tensor(shape=[2, 3], dtype=Float32,
[[1.4100001 1.5999999 6.6      ]
 [1.4100001 1.5999999 6.6      ]])
>>>
>>> class GradNetWrtXY(nn.Cell):
>>>     def __init__(self, net):
>>>         super(GradNetWrtXY, self).__init__()
>>>         self.net = net
>>>         self.grad_op = GradOperation(get_all=True)
>>>     def construct(self, x, y):
>>>         gradient_function = self.grad_op(self.net)
>>>         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXY(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[4.5099998 2.7       3.6000001]
 [4.5099998 2.7       3.6000001]]), Tensor(shape=[3, 3], dtype=Float32,
[[2.6       2.6       2.6      ]
 [1.9       1.9       1.9      ]
 [1.3000001 1.3000001 1.3000001]]))
>>>
>>> class GradNetWrtXYWithSensParam(nn.Cell):
>>>     def __init__(self, net):
>>>         super(GradNetWrtXYWithSensParam, self).__init__()
>>>         self.net = net
>>>         self.grad_op = GradOperation(get_all=True, sens_param=True)
>>>         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
>>>     def construct(self, x, y):
>>>         gradient_function = self.grad_op(self.net)
>>>         return gradient_function(x, y, self.grad_wrt_output)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXYWithSensParam(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[2.211     0.51      1.4900001]
 [5.588     2.68      4.07     ]]), Tensor(shape=[3, 3], dtype=Float32,
[[1.52       2.82       2.14      ]
 [1.1        2.05       1.55      ]
 [0.90000004 1.55       1.25      ]]))
>>>
>>> class GradNetWithWrtParams(nn.Cell):
>>>     def __init__(self, net):
>>>         super(GradNetWithWrtParams, self).__init__()
>>>         self.net = net
>>>         self.params = ParameterTuple(net.trainable_params())
>>>         self.grad_op = GradOperation(get_by_list=True)
>>>     def construct(self, x, y):
>>>         gradient_function = self.grad_op(self.net, self.params)
>>>         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWithWrtParams(Net())(x, y)
(Tensor(shape=[1], dtype=Float32, [21.536]),)
>>>
>>> class GradNetWrtInputsAndParams(nn.Cell):
>>>     def __init__(self, net):
>>>         super(GradNetWrtInputsAndParams, self).__init__()
>>>         self.net = net
>>>         self.params = ParameterTuple(net.trainable_params())
>>>         self.grad_op = GradOperation(get_all=True, get_by_list=True)
>>>     def construct(self, x, y):
>>>         gradient_function = self.grad_op(self.net, self.params)
>>>         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
>>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtInputsAndParams(Net())(x, y)
((Tensor(shape=[2, 3], dtype=Float32,
[[3.52 3.9  2.6 ]
 [3.52 3.9  2.6 ]]), Tensor(shape=[3, 3], dtype=Float32,
[[0.6       0.6       0.6      ]
 [1.9       1.9       1.9      ]
 [1.3000001 1.3000001 1.3000001]])), (Tensor(shape=[1], dtype=Float32, [12.902]),))
class mindspore.ops.Greater(*args, **kwargs)[source]

Computes the boolean value of \(x > y\) element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> greater = P.Greater()
>>> greater(input_x, input_y)
[False, True, False]
class mindspore.ops.GreaterEqual(*args, **kwargs)[source]

Computes the boolean value of \(x >= y\) element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> greater_equal = P.GreaterEqual()
>>> greater_equal(input_x, input_y)
[True, True, False]
class mindspore.ops.HSigmoid(*args, **kwargs)[source]

Hard sigmoid activation function.

Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.

Hard sigmoid is defined as:

\[\text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})),\]

where \(x_{i}\) is the \(i\)-th slice in the given dimension of the input Tensor.

Inputs:
  • input_data (Tensor) - The input of HSigmoid, data type must be float16 or float32.

Outputs:

Tensor, with the same type and shape as the input_data.

Examples

>>> hsigmoid = P.HSigmoid()
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
>>> result = hsigmoid(input_x)
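
The hard sigmoid formula above can be checked against a small pure-Python sketch. This is an illustration of the formula only, not the operator's implementation; the helper name hsigmoid_ref is hypothetical.

```python
def hsigmoid_ref(x):
    """Hard sigmoid of one scalar: max(0, min(1, (x + 3) / 6))."""
    return max(0.0, min(1.0, (x + 3.0) / 6.0))


# Inputs at or below -3 saturate to 0, inputs at or above 3 saturate to 1.
hsigmoid_out = [hsigmoid_ref(v) for v in [-4.0, -3.0, 0.0, 3.0, 4.0]]
print(hsigmoid_out)  # [0.0, 0.0, 0.5, 1.0, 1.0]
```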
class mindspore.ops.HSwish(*args, **kwargs)[source]

Hard swish activation function.

Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.

Hard swish is defined as:

\[\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},\]

where \(x_{i}\) is the \(i\)-th slice in the given dimension of the input Tensor.

Inputs:
  • input_data (Tensor) - The input of HSwish, data type must be float16 or float32.

Outputs:

Tensor, with the same type and shape as the input_data.

Examples

>>> hswish = P.HSwish()
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
>>> result = hswish(input_x)
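
As a rough illustration of the hard swish formula above, the following pure-Python sketch evaluates it for single scalars. The helper names relu6_ref and hswish_ref are hypothetical and are not part of mindspore.ops.

```python
def relu6_ref(x):
    """ReLU6 of one scalar: clamp x to the interval [0, 6]."""
    return min(max(x, 0.0), 6.0)


def hswish_ref(x):
    """Hard swish of one scalar: x * ReLU6(x + 3) / 6."""
    return x * relu6_ref(x + 3.0) / 6.0


# Inputs at or below -3 map to zero; inputs at or above 3 pass through unchanged.
hswish_out = [hswish_ref(v) for v in [-4.0, 1.0, 3.0, 6.0]]
```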
class mindspore.ops.HistogramFixedWidth(*args, **kwargs)[source]

Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal width and determined by the arguments range and nbins.

Parameters
  • dtype (str) – An optional attribute. Must be one of the following types: “int32”, “int64”. Default: “int32”.

  • nbins (int) – The number of histogram bins, the type is a positive integer.

Inputs:
  • x (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16.

  • range (Tensor) - Must have the same data type as x, and the shape is [2]. x <= range[0] will be mapped to hist[0], x >= range[1] will be mapped to hist[-1].

Outputs:

Tensor, the type is int32.

Examples

>>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16)
>>> range = Tensor([0.0, 5.0], mindspore.float16)
>>> hist = P.HistogramFixedWidth(5)
>>> hist(x, range)
[2 1 1 0 2]
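
The binning rule described above (equal-width bins, with out-of-range values clamped into the first or last bin) can be sketched in pure Python. This is an assumed reference of the documented behaviour, not the operator's code; the function name histogram_fixed_width_ref is hypothetical.

```python
def histogram_fixed_width_ref(values, value_range, nbins):
    """Count values into nbins equal-width bins spanning value_range."""
    lo, hi = value_range
    width = (hi - lo) / nbins
    hist = [0] * nbins
    for v in values:
        idx = int((v - lo) // width)
        # Values below range[0] go to the first bin, values at or above
        # range[1] go to the last bin.
        idx = min(max(idx, 0), nbins - 1)
        hist[idx] += 1
    return hist


hist_ref = histogram_fixed_width_ref([-1.0, 0.0, 1.5, 2.0, 5.0, 15.0], (0.0, 5.0), 5)
print(hist_ref)  # [2, 1, 1, 0, 2]
```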
class mindspore.ops.HistogramSummary(*args, **kwargs)[source]

Outputs tensor to protocol buffer through histogram summary operator.

Inputs:
  • name (str) - The name of the input variable.

  • value (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.

Examples

>>> class SummaryDemo(nn.Cell):
>>>     def __init__(self,):
>>>         super(SummaryDemo, self).__init__()
>>>         self.summary = P.HistogramSummary()
>>>         self.add = P.TensorAdd()
>>>
>>>     def construct(self, x, y):
>>>         x = self.add(x, y)
>>>         name = "x"
>>>         self.summary(name, x)
>>>         return x
class mindspore.ops.HookBackward(hook_fn, cell_id='')[source]

This operation is used as a tag to hook gradient in intermediate variables. Note that this function is only supported in Pynative Mode.

Note

The hook function must be defined like hook_fn(grad) -> Tensor or None, where grad is the gradient passed to the primitive; the gradient may be modified and passed to the next primitive. The difference between a hook function and the callback of InsertGradientOf is that a hook function is executed in the Python environment, while the callback is parsed and added to the graph.

Parameters

hook_fn (Function) – Python function used as the hook.

Inputs:
  • inputs (Tensor) - The variable to hook.

Examples

>>> def hook_fn(grad_out):
>>>     print(grad_out)
>>>
>>> grad_all = GradOperation(get_all=True)
>>> hook = P.HookBackward(hook_fn)
>>>
>>> def hook_test(x, y):
>>>     z = x * y
>>>     z = hook(z)
>>>     z = z * y
>>>     return z
>>>
>>> def backward(x, y):
>>>     return grad_all(hook_test)(x, y)
>>>
>>> backward(1, 2)
class mindspore.ops.HyperMap(ops=None)[source]

Hypermap will apply the set operation to input sequences.

Applies the operation to every element of the sequence or nested sequence. Unlike Map, HyperMap also supports nested structures.

Parameters

ops (Union[MultitypeFuncGraph, None]) – ops is the operation to apply. If ops is None, the operations should be put in the first input of the instance.

Inputs:
  • args (Tuple[sequence]) - If ops is not None, all the inputs should be sequences with the same length, and each row of the sequences will be the inputs of the operation.

    If ops is None, the first input is the operation, and the others are inputs.

Outputs:

Sequence or nested sequence, the sequence of output after applying the function. e.g. operation(args[0][i], args[1][i]).

Examples

>>> from mindspore import dtype as mstype
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
>>> # square all the tensors in the nested list
>>>
>>> square = MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
...     return F.square(x)
>>>
>>> common_map = HyperMap()
>>> common_map(square, nest_tensor_list)
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
>>> square_map = HyperMap(square)
>>> square_map(nest_tensor_list)
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
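
The idea of applying an operation through a nested sequence can be sketched with plain Python tuples and numbers. This is only an analogy under simplifying assumptions (no MultitypeFuncGraph dispatch, no Tensors); the function name hyper_map_ref is hypothetical.

```python
def hyper_map_ref(fn, *seqs):
    """Apply fn to corresponding leaves of equally shaped nested sequences."""
    if isinstance(seqs[0], (tuple, list)):
        # Recurse into each "row": the i-th elements of all input sequences.
        return tuple(hyper_map_ref(fn, *row) for row in zip(*seqs))
    # Leaves reached: call the operation on them.
    return fn(*seqs)


nested = ((1.0, 2.0), (3.0, 4.0))
squared = hyper_map_ref(lambda v: v * v, nested)
print(squared)  # ((1.0, 4.0), (9.0, 16.0))

# With several inputs, the operation receives one leaf from each of them.
summed = hyper_map_ref(lambda a, b: a + b, nested, nested)
print(summed)  # ((2.0, 4.0), (6.0, 8.0))
```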
class mindspore.ops.IFMR(*args, **kwargs)[source]

The IFMR (Input Feature Map Reconstruction).

Parameters
  • min_percentile (float) – Min init percentile.

  • max_percentile (float) – Max init percentile.

  • search_range (Union[list(float), tuple(float)]) – Range of searching.

  • search_step (float) – Step size of searching.

  • with_offset (bool) – Whether using offset.

Inputs:
  • data (Tensor) - A Tensor of feature map. With float16 or float32 data type.

  • data_min (Tensor) - A Tensor of min value of feature map, the shape is \((1)\). With float16 or float32 data type.

  • data_max (Tensor) - A Tensor of max value of feature map, the shape is \((1)\). With float16 or float32 data type.

  • cumsum (Tensor) - A 1-D Tensor of cumsum bin of data. With int32 data type.

Outputs:
  • scale (Tensor) - A tensor of optimal scale, the shape is \((1)\). Data dtype is float32.

  • offset (Tensor) - A tensor of optimal offset, the shape is \((1)\). Data dtype is float32.

Examples

>>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
>>> data_min = Tensor([0.1], mstype.float32)
>>> data_max = Tensor([0.5], mstype.float32)
>>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
>>> ifmr = P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
...               search_step=1.0, with_offset=False)
>>> output = ifmr(data, data_min, data_max, cumsum)
class mindspore.ops.IOU(*args, **kwargs)[source]

Calculates intersection over union for boxes.

Computes the intersection over union (IOU) or the intersection over foreground (IOF) based on the ground-truth and predicted regions.

\[ \begin{align}\begin{aligned}\text{IOU} = \frac{\text{Area of Overlap}}{\text{Area of Union}}\\\text{IOF} = \frac{\text{Area of Overlap}}{\text{Area of Ground Truth}}\end{aligned}\end{align} \]
Parameters

mode (string) – The mode is used to specify the calculation method, now supporting ‘iou’ (intersection over union) or ‘iof’ (intersection over foreground) mode. Default: ‘iou’.

Inputs:
  • anchor_boxes (Tensor) - Anchor boxes, tensor of shape (N, 4). “N” indicates the number of anchor boxes, and the value “4” refers to “x0”, “y0”, “x1”, and “y1”. Data type must be float16 or float32.

  • gt_boxes (Tensor) - Ground truth boxes, tensor of shape (M, 4). “M” indicates the number of ground truth boxes, and the value “4” refers to “x0”, “y0”, “x1”, and “y1”. Data type must be float16 or float32.

Outputs:

Tensor, the ‘iou’ values, tensor of shape (M, N), with the same data type as anchor_boxes.

Raises

KeyError – When mode is not ‘iou’ or ‘iof’.

Examples

>>> iou = P.IOU()
>>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> iou(anchor_boxes, gt_boxes)
[[0.0, 65504, 65504],
 [0.0, 0.0, 0.0],
 [0.22253, 0.0, 0.0]]
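
The two ratios defined above can be illustrated with a pure-Python sketch over (x0, y0, x1, y1) boxes. It follows the formulas as written; the operator itself may use different boundary conventions, so values are not guaranteed to match. All helper names here are hypothetical.

```python
def box_area(box):
    """Area of a box given as (x0, y0, x1, y1); degenerate boxes have area 0."""
    x0, y0, x1, y1 = box
    return max(0.0, x1 - x0) * max(0.0, y1 - y0)


def overlap_ratio(anchor, gt, mode="iou"):
    """IOU or IOF of one anchor box against one ground-truth box."""
    if mode not in ("iou", "iof"):
        raise KeyError(mode)
    # Intersection rectangle (area 0 if the boxes do not overlap).
    inter = box_area((max(anchor[0], gt[0]), max(anchor[1], gt[1]),
                      min(anchor[2], gt[2]), min(anchor[3], gt[3])))
    if mode == "iou":
        denom = box_area(anchor) + box_area(gt) - inter  # area of union
    else:
        denom = box_area(gt)  # area of ground truth
    return inter / denom if denom > 0 else 0.0


def pairwise_overlap(anchor_boxes, gt_boxes, mode="iou"):
    """Shape (M, N): one row per ground-truth box, one column per anchor box."""
    return [[overlap_ratio(a, g, mode) for a in anchor_boxes] for g in gt_boxes]


anchors = [(0.0, 0.0, 2.0, 2.0), (5.0, 5.0, 6.0, 6.0)]
gts = [(1.0, 1.0, 3.0, 3.0)]
iou_matrix = pairwise_overlap(anchors, gts)         # overlap 1, union 7 for the first pair
iof_matrix = pairwise_overlap(anchors, gts, "iof")  # overlap 1, ground truth 4
```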
class mindspore.ops.ImageSummary(*args, **kwargs)[source]

Outputs image tensor to protocol buffer through image summary operator.

Inputs:
  • name (str) - The name of the input variable, it must not be an empty string.

  • value (Tensor) - The value of image, the rank of tensor must be 4.

Examples

>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.summary = P.ImageSummary()
>>>
>>>     def construct(self, x):
>>>         name = "image"
>>>         out = self.summary(name, x)
>>>         return out
class mindspore.ops.InTopK(*args, **kwargs)[source]

Whether the targets are in the top k predictions.

Parameters

k (int) – Specifies the number of top elements to be used for computing precision.

Inputs:
  • x1 (Tensor) - A 2D Tensor defines the predictions of a batch of samples with float16 or float32 data type.

  • x2 (Tensor) - A 1D Tensor defines the labels of a batch of samples with int32 data type. The size of x2 must be equal to x1’s first dimension. The values of x2 cannot be negative and must not exceed the largest index of x1’s second dimension.

Outputs:

Tensor has 1 dimension of type bool and the same shape as x2. For sample i, the value is True if the label x2[i] is among the top k predictions in x1[i], otherwise False.

Examples

>>> x1 = Tensor(np.array([[1, 8, 5, 2, 7], [4, 9, 1, 3, 5]]), mindspore.float32)
>>> x2 = Tensor(np.array([1, 3]), mindspore.int32)
>>> in_top_k = P.InTopK(3)
>>> result = in_top_k(x1, x2)
[True  False]
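
The check can be sketched in pure Python as follows. This is an assumed reference, not the operator's code: in particular, counting only strictly higher scores is one possible way to treat ties, and the name in_top_k_ref is hypothetical.

```python
def in_top_k_ref(predictions, targets, k):
    """For each sample, report whether its target class is among the k highest scores."""
    result = []
    for row, target in zip(predictions, targets):
        # Number of classes scoring strictly higher than the target class.
        higher = sum(1 for score in row if score > row[target])
        result.append(higher < k)
    return result


in_top_k_out = in_top_k_ref([[1, 8, 5, 2, 7], [4, 9, 1, 3, 5]], [1, 3], 3)
print(in_top_k_out)  # [True, False]
```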
class mindspore.ops.InplaceAdd(*args, **kwargs)[source]

Adds v into specified rows of x. Computes y = x; y[i, :] += v.

Parameters

indices (Union[int, tuple]) – Indices into the left-most dimension of x, and determines which rows of x to add with v. It is an integer or a tuple, whose value is in [0, the first dimension size of x).

Inputs:
  • input_x (Tensor) - The first input is a tensor whose data type is float16, float32 or int32.

  • input_v (Tensor) - The second input is a tensor that has the same dimension sizes as input_x except the first dimension, which must be the same as the size of indices. It has the same data type as input_x.

Outputs:

Tensor, has the same shape and dtype as input_x.

Examples

>>> indices = (0, 1)
>>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplaceAdd = P.InplaceAdd(indices)
>>> inplaceAdd(input_x, input_v)
[[1.5 3.]
 [4. 5.5]
 [5. 6.]]
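
The row-wise update y = x; y[i, :] += v can be sketched with nested Python lists. This is an illustration only; the helper name inplace_add_ref is hypothetical and, unlike the operator, it returns a new list instead of reusing storage.

```python
def inplace_add_ref(x, v, indices):
    """Return a copy of x in which row indices[j] has row v[j] added to it."""
    if isinstance(indices, int):
        indices = (indices,)
    y = [list(row) for row in x]  # copy so the caller's data is untouched
    for row_of_v, i in zip(v, indices):
        y[i] = [a + b for a, b in zip(y[i], row_of_v)]
    return y


inplace_add_out = inplace_add_ref([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                                  [[0.5, 1.0], [1.0, 1.5]],
                                  (0, 1))
print(inplace_add_out)  # [[1.5, 3.0], [4.0, 5.5], [5.0, 6.0]]
```

Replacing the addition with a subtraction or a plain assignment gives the analogous sketches for InplaceSub and InplaceUpdate.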
class mindspore.ops.InplaceSub(*args, **kwargs)[source]

Subtracts v from specified rows of x. Computes y = x; y[i, :] -= v.

Parameters

indices (Union[int, tuple]) – Indices into the left-most dimension of x, and determines which rows of x to subtract with v. It is an int or tuple, whose value is in [0, the first dimension size of x).

Inputs:
  • input_x (Tensor) - The first input is a tensor whose data type is float16, float32 or int32.

  • input_v (Tensor) - The second input is a tensor that has the same dimension sizes as input_x except the first dimension, which must be the same as the size of indices. It has the same data type as input_x.

Outputs:

Tensor, has the same shape and dtype as input_x.

Examples

>>> indices = (0, 1)
>>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplaceSub = P.InplaceSub(indices)
>>> inplaceSub(input_x, input_v)
[[0.5 1.]
 [2. 2.5]
 [5. 6.]]
class mindspore.ops.InplaceUpdate(*args, **kwargs)[source]

Updates specified rows with values in v.

Parameters

indices (Union[int, tuple]) – Indices into the left-most dimension of x, and determines which rows of x to update with v. It is an int or tuple, whose value is in [0, the first dimension size of x).

Inputs:
  • x (Tensor) - The tensor to be updated in place. It can be one of the following data types: float32, float16 and int32.

  • v (Tensor) - A tensor with the same type as x and the same dimension size as x except the first dimension, which must be the same as the size of indices.

Outputs:

Tensor, with the same type and shape as the input x.

Examples

>>> indices = (0, 1)
>>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplace_update = P.InplaceUpdate(indices)
>>> result = inplace_update(x, v)
[[0.5, 1.0],
 [1.0, 1.5],
 [5.0, 6.0]]
class mindspore.ops.InsertGradientOf(*args, **kwargs)[source]

Attaches callback to graph node that will be invoked on the node’s gradient.

Parameters

f (Function) – MindSpore’s Function. Callback function.

Inputs:
  • input_x (Any) - The graph node to attach to.

Outputs:

Tensor, returns input_x directly. InsertGradientOf does not affect the forward result.

Examples

>>> def clip_gradient(dx):
>>>     ret = dx
>>>     if ret > 1.0:
>>>         ret = 1.0
>>>
>>>     if ret < 0.2:
>>>         ret = 0.2
>>>
>>>     return ret
>>>
>>> clip = P.InsertGradientOf(clip_gradient)
>>> grad_all = C.GradOperation(get_all=True)
>>> def InsertGradientOfClipDemo():
>>>     def clip_test(x, y):
>>>         x = clip(x)
>>>         y = clip(y)
>>>         c = x * y
>>>         return c
>>>
>>>     @ms_function
>>>     def f(x, y):
>>>         return clip_test(x, y)
>>>
>>>     def fd(x, y):
>>>         return grad_all(clip_test)(x, y)
>>>
>>>     print("forward: ", f(1.1, 0.1))
>>>     print("clip_gradient:", fd(1.1, 0.1))
class mindspore.ops.Inv(*args, **kwargs)[source]

Computes the reciprocal (Inv) of the input tensor element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). Must be one of the following types: float16, float32, int32.

Outputs:

Tensor, has the same shape and data type as input_x.

Examples

>>> inv = P.Inv()
>>> input_x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
>>> output = inv(input_x)
[4., 2.5, 3.2258065, 1.923077]
class mindspore.ops.Invert(*args, **kwargs)[source]

Flips all bits of input tensor element-wise.

Inputs:
  • input_x (Tensor[int16], Tensor[uint16]) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> invert = P.Invert()
>>> input_x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
>>> output = invert(input_x)
[-26, -5, -14, -10]
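
For signed integers, flipping all bits is the same as computing -x - 1, which Python's ~ operator also does. The sketch below reproduces the int16 output above; the masked variant shows what the same bit flip looks like when the result is read as a 16-bit unsigned value (an illustration, not output taken from the operator).

```python
values = [25, 4, 13, 9]

# Signed view: ~x == -x - 1.
inverted_signed = [~v for v in values]
print(inverted_signed)  # [-26, -5, -14, -10]

# Unsigned 16-bit view: keep only the low 16 bits of the flipped value.
inverted_uint16 = [~v & 0xFFFF for v in values]
print(inverted_uint16)  # [65510, 65531, 65522, 65526]
```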
class mindspore.ops.InvertPermutation(*args, **kwargs)[source]

Computes the inverse of an index permutation.

This operation calculates the inverse of an index permutation. It requires a 1-dimensional tuple x whose values are indices starting at zero, and swaps each value with its index position. In other words, for the output tuple y and the input tuple x, this operation calculates the following: \(y[x[i]] = i, \quad i \in [0, 1, \ldots, \text{len}(x)-1]\).

Note

These values must include 0. There must be no duplicate values and the values can not be negative.

Inputs:
  • input_x (Union(tuple[int], list[int])) - The input is constructed by multiple integers, i.e., \((y_1, y_2, ..., y_S)\) representing the indices. The values must include 0. There can be no duplicate values or negative values. Only constant value is allowed. The maximum value must be equal to the length of input_x minus 1.

Outputs:

tuple[int]. It has the same length as the input.

Examples

>>> invert = P.InvertPermutation()
>>> input_data = (3, 4, 0, 2, 1)
>>> output = invert(input_data)
>>> output == (2, 4, 3, 0, 1)
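
The rule y[x[i]] = i can be written directly in pure Python. This is a minimal sketch with a hypothetical helper name, invert_permutation_ref; the validity check mirrors the constraints listed in the note above.

```python
def invert_permutation_ref(perm):
    """Return y such that y[perm[i]] == i for every position i."""
    # A valid input contains each of 0 .. len(perm) - 1 exactly once.
    if sorted(perm) != list(range(len(perm))):
        raise ValueError("input must be a permutation of 0 .. len(input) - 1")
    inverse = [0] * len(perm)
    for i, p in enumerate(perm):
        inverse[p] = i
    return tuple(inverse)


inverse_out = invert_permutation_ref((3, 4, 0, 2, 1))
print(inverse_out)  # (2, 4, 3, 0, 1)
```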
class mindspore.ops.IsFinite(*args, **kwargs)[source]

Determines, position by position, which elements are finite.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape as the input, and the dtype is bool.

Examples

>>> is_finite = P.IsFinite()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_finite(input_x)
[False   True   False]
class mindspore.ops.IsInf(*args, **kwargs)[source]

Determines, position by position, which elements are inf or -inf.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape as the input, and the dtype is bool.

Examples

>>> is_inf = P.IsInf()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_inf(input_x)
class mindspore.ops.IsInstance(*args, **kwargs)[source]

Checks whether an object is an instance of a target type.

Inputs:
  • inst (Any Object) - The instance to be checked. Only constant value is allowed.

  • type_ (mindspore.dtype) - The target type. Only constant value is allowed.

Outputs:

bool, the check result.

Examples

>>> a = 1
>>> result = P.IsInstance()(a, mindspore.int32)
True
class mindspore.ops.IsNan(*args, **kwargs)[source]

Determines, position by position, which elements are nan.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape as the input, and the dtype is bool.

Examples

>>> is_nan = P.IsNan()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_nan(input_x)
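
The input used in the IsFinite, IsInf and IsNan examples is [nan, 1, -inf], since np.log(-1) is nan and np.log(0) is -inf. The following sketch uses Python's math module on the same values to show how the three predicates differ; it illustrates the semantics only and does not call the operators.

```python
import math

values = [float("nan"), 1.0, float("-inf")]

finite_mask = [math.isfinite(v) for v in values]
inf_mask = [math.isinf(v) for v in values]
nan_mask = [math.isnan(v) for v in values]

print(finite_mask)  # [False, True, False]
print(inf_mask)     # [False, False, True]
print(nan_mask)     # [True, False, False]
```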
class mindspore.ops.IsSubClass(*args, **kwargs)[source]

Checks whether one type is a subclass of another type.

Inputs:
  • sub_type (mindspore.dtype) - The type to be checked. Only constant value is allowed.

  • type_ (mindspore.dtype) - The target type. Only constant value is allowed.

Outputs:

bool, the check result.

Examples

>>> result = P.IsSubClass()(mindspore.int32, mindspore.intc)
True
class mindspore.ops.KLDivLoss(*args, **kwargs)[source]

Computes the Kullback-Leibler divergence between the target and the output.

Note

Sets input as \(x\), input label as \(y\), output as \(\ell(x, y)\). Let,

\[L = \{l_1,\dots,l_N\}^\top, \quad l_n = y_n \cdot (\log y_n - x_n)\]

Then,

\[\begin{split}\ell(x, y) = \begin{cases} L, & \text{if reduction} = \text{`none';}\\ \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}\end{split}\]
Parameters

reduction (str) – Specifies the reduction to be applied to the output. Its value must be one of ‘none’, ‘mean’, ‘sum’. Default: ‘mean’.

Inputs:
  • input_x (Tensor) - The input Tensor. The data type must be float32.

  • input_y (Tensor) - The label Tensor which has the same shape as input_x. The data type must be float32.

Outputs:

Tensor or Scalar, if reduction is ‘none’, then output is a tensor and has the same shape as input_x. Otherwise it is a scalar.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.kldiv_loss = P.KLDivLoss()
>>>     def construct(self, x, y):
>>>         result = self.kldiv_loss(x, y)
>>>         return result
>>>
>>> net = Net()
>>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
>>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
>>> result = net(input_x, input_y)
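
The formula in the note above can be evaluated by hand with a small pure-Python sketch. Two points are assumptions of this sketch rather than statements about the operator: terms with y_n = 0 are treated as 0, and the function name kl_div_loss_ref is hypothetical.

```python
import math


def kl_div_loss_ref(x, y, reduction="mean"):
    """l_n = y_n * (log(y_n) - x_n), followed by the chosen reduction."""
    # Assumption: a zero label contributes nothing (0 * log 0 taken as 0).
    terms = [yn * (math.log(yn) - xn) if yn > 0 else 0.0 for xn, yn in zip(x, y)]
    if reduction == "none":
        return terms
    if reduction == "sum":
        return sum(terms)
    if reduction == "mean":
        return sum(terms) / len(terms)
    raise ValueError("reduction must be 'none', 'mean' or 'sum'")


x = [0.2, 0.7, 0.1]
y = [0.0, 1.0, 0.0]
kl_none = kl_div_loss_ref(x, y, "none")  # only the middle term is non-zero
kl_sum = kl_div_loss_ref(x, y, "sum")
kl_mean = kl_div_loss_ref(x, y)
```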
class mindspore.ops.L2Loss(*args, **kwargs)[source]

Calculates half of the L2 norm of a tensor without using the sqrt.

Set input_x as x and output as loss.

\[loss = sum(x ** 2) / 2\]

Inputs:
  • input_x (Tensor) - A input Tensor. Data type must be float16 or float32.

Outputs:

Tensor, has the same dtype as input_x. The output tensor is the value of loss which is a scalar tensor.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
>>> l2_loss = P.L2Loss()
>>> l2_loss(input_x)
7.0
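
Reading the description as half of the sum of squares (the squared L2 norm divided by 2) reproduces the example output above. The sketch below is a pure-Python check of that reading; the function name l2_loss_ref is hypothetical.

```python
def l2_loss_ref(x):
    """Half of the sum of squared elements."""
    return sum(v * v for v in x) / 2


l2_out = l2_loss_ref([1.0, 2.0, 3.0])
print(l2_out)  # 7.0
```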
class mindspore.ops.L2Normalize(*args, **kwargs)[source]

L2 normalization Operator.

This operator will normalize the input using the given axis. The function is shown as follows:

\[\text{output} = \frac{x}{\sqrt{\text{max}(\text{sum} (\text{input_x}^2), \epsilon)}},\]

where \(\epsilon\) is epsilon.

Parameters
  • axis (int) – The starting axis for the input to apply the L2 normalization. Default: 0.

  • epsilon (float) – A small value added for numerical stability. Default: 1e-4.

Inputs:
  • input_x (Tensor) - Input to compute the normalization. Data type must be float16 or float32.

Outputs:

Tensor, with the same type and shape as the input.

Examples

>>> l2_normalize = P.L2Normalize()
>>> input_x = Tensor(np.random.randint(-256, 256, (2, 3, 4)), mindspore.float32)
>>> result = l2_normalize(input_x)
[[[-0.47247353   -0.30934513   -0.4991462   0.8185567 ]
  [-0.08070751   -0.9961299    -0.5741758   0.09262337]
  [-0.9916556    -0.3049123     0.5730487  -0.40579924]]
 [[-0.88134485    0.9509498    -0.86651784  0.57442576]
  [ 0.99673784    0.08789381   -0.8187321   0.9957012 ]
  [ 0.12891524   -0.9523804    -0.81952125  0.91396334]]]
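
For a single 1-D slice along the normalization axis, the formula above reduces to dividing by the vector's L2 norm, with epsilon guarding against division by zero. A minimal pure-Python sketch, with the hypothetical name l2_normalize_ref:

```python
import math


def l2_normalize_ref(x, epsilon=1e-4):
    """Divide a 1-D list by sqrt(max(sum of squares, epsilon))."""
    norm = math.sqrt(max(sum(v * v for v in x), epsilon))
    return [v / norm for v in x]


l2n_out = l2_normalize_ref([3.0, 4.0])
print(l2n_out)  # [0.6, 0.8]

# An all-zero input stays zero instead of producing a division by zero.
l2n_zero = l2_normalize_ref([0.0, 0.0])
```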
class mindspore.ops.LARSUpdate(*args, **kwargs)[source]

Conducts LARS (layer-wise adaptive rate scaling) update on the sum of squares of gradient.

Parameters
  • epsilon (float) – Term added to the denominator to improve numerical stability. Default: 1e-05.

  • hyperpara (float) – Trust coefficient for calculating the local learning rate. Default: 0.001.

  • use_clip (bool) – Whether to use clip operation for calculating the local learning rate. Default: False.

Inputs:
  • weight (Tensor) - The weight to be updated.

  • gradient (Tensor) - The gradient of weight, which has the same shape and dtype with weight.

  • norm_weight (Tensor) - A scalar tensor, representing the sum of squares of weight.

  • norm_gradient (Tensor) - A scalar tensor, representing the sum of squares of gradient.

  • weight_decay (Union[Number, Tensor]) - Weight decay. It must be a scalar tensor or number.

  • learning_rate (Union[Number, Tensor]) - Learning rate. It must be a scalar tensor or number.

Outputs:

Tensor, represents the new gradient.

Examples

>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.lars = P.LARSUpdate()
>>>         self.reduce = P.ReduceSum()
>>>     def construct(self, weight, gradient):
>>>         w_square_sum = self.reduce(F.square(weight))
>>>         grad_square_sum = self.reduce(F.square(gradient))
>>>         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
>>>         return grad_t
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
>>> net = Net()
>>> ms_output = net(Tensor(weight), Tensor(gradient))
class mindspore.ops.LRN(*args, **kwargs)[source]

Local Response Normalization.

Parameters
  • depth_radius (int) – Half-width of the 1-D normalization window with the shape of 0-D.

  • bias (float) – An offset (usually positive to avoid dividing by 0).

  • alpha (float) – A scale factor, usually positive.

  • beta (float) – An exponent.

  • norm_region (str) – Specifies normalization region. Options: “ACROSS_CHANNELS”. Default: “ACROSS_CHANNELS”.

Inputs:
  • x (Tensor) - A 4D Tensor with float16 or float32 data type.

Outputs:

Tensor, with the same shape and data type as the input tensor.

Examples

>>> x = Tensor(np.random.rand(1, 10, 4, 4), mindspore.float32)
>>> lrn = P.LRN()
>>> lrn(x)
class mindspore.ops.LSTM(*args, **kwargs)[source]

Performs the long short-term memory (LSTM) on the input.

For detailed information, please refer to nn.LSTM.

class mindspore.ops.LayerNorm(*args, **kwargs)[source]

Applies the Layer Normalization to the input tensor.

This operator will normalize the input tensor on given axis. LayerNorm is described in the paper Layer Normalization.

\[y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta\]

where \(\gamma\) is scale, \(\beta\) is bias, \(\epsilon\) is epsilon.

Parameters
  • begin_norm_axis (int) – The begin axis of the input_x to apply LayerNorm, the value must be in [-1, rank(input)). Default: 1.

  • begin_params_axis (int) – The begin axis of the parameter input (gamma, beta) to apply LayerNorm, the value must be in [-1, rank(input)). Default: 1.

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-7.

Inputs:
  • input_x (Tensor) - Tensor of shape \((N, \ldots)\). The input of LayerNorm.

  • gamma (Tensor) - Tensor of shape \((P_0, \ldots, P_\text{begin_params_axis})\). The learnable parameter gamma as the scale on norm.

  • beta (Tensor) - Tensor of shape \((P_0, \ldots, P_\text{begin_params_axis})\). The learnable parameter beta as the offset on norm.

Outputs:

tuple[Tensor], tuple of 3 tensors, the normalized input and the updated parameters.

  • output_x (Tensor) - The normalized input, has the same type and shape as the input_x. The shape is \((N, C)\).

  • mean (Tensor) - Tensor of shape \((C,)\).

  • variance (Tensor) - Tensor of shape \((C,)\).

Examples

>>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
>>> gamma = Tensor(np.ones([3]), mindspore.float32)
>>> beta = Tensor(np.ones([3]), mindspore.float32)
>>> layer_norm = P.LayerNorm()
>>> output = layer_norm(input_x, gamma, beta)
([[-0.22474492, 1., 2.2247488], [-0.22474492, 1., 2.2247488]],
 [[2.], [2.]], [[0.6666667], [0.6666667]])
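
The numbers in the example above can be reproduced for one row with a pure-Python sketch of the formula. It assumes normalization over a single 1-D row and the documented default epsilon; the function name layer_norm_row_ref is hypothetical.

```python
import math


def layer_norm_row_ref(row, gamma, beta, epsilon=1e-7):
    """Normalize one row: (x - mean) / sqrt(variance + epsilon) * gamma + beta."""
    mean = sum(row) / len(row)
    variance = sum((v - mean) ** 2 for v in row) / len(row)
    scale = 1.0 / math.sqrt(variance + epsilon)
    out = [(v - mean) * scale * g + b for v, g, b in zip(row, gamma, beta)]
    return out, mean, variance


ln_out, ln_mean, ln_var = layer_norm_row_ref([1.0, 2.0, 3.0],
                                             [1.0, 1.0, 1.0],
                                             [1.0, 1.0, 1.0])
# ln_out is approximately [-0.2247, 1.0, 2.2247], ln_mean is 2.0,
# and ln_var is approximately 0.6667.
```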
class mindspore.ops.Less(*args, **kwargs)[source]

Computes the boolean value of \(x < y\) element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> less = P.Less()
>>> less(input_x, input_y)
[False, False, True]
class mindspore.ops.LessEqual(*args, **kwargs)[source]

Computes the boolean value of \(x <= y\) element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> less_equal = P.LessEqual()
>>> less_equal(input_x, input_y)
[True, False, True]
class mindspore.ops.Log(*args, **kwargs)[source]

Returns the natural logarithm of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. With float16 or float32 data type. The value must be greater than 0.

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> log = P.Log()
>>> log(input_x)
[0.0, 0.69314718, 1.38629436]
class mindspore.ops.Log1p(*args, **kwargs)[source]

Returns the natural logarithm of one plus the input tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor. With float16 or float32 data type. The value must be greater than -1.

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> log1p = P.Log1p()
>>> log1p(input_x)
[0.6931472, 1.0986123, 1.609438]
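A common reason for having a dedicated log1p operation is accuracy for inputs close to zero. The sketch below uses Python's `math` module as a stand-in, not MindSpore itself, to compare the two ways of evaluating \(\log(1 + x)\).

```python
import math

x = 1e-10
naive = math.log(1.0 + x)   # 1.0 + x is rounded to a float before the log is taken
stable = math.log1p(x)      # evaluates log(1 + x) without forming 1 + x first

print(naive, stable)
```

For such a small x the true value is very close to x itself, and the `log1p` result stays closer to it than the naive form does.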
class mindspore.ops.LogSoftmax(*args, **kwargs)[source]

Log Softmax activation function.

Applies the Log Softmax function to the input tensor on the specified axis. Given a slice \(x\) along the given axis, the Log Softmax function for each element \(x_i\) is:

\[\text{output}(x_i) = \log \left(\frac{\exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),\]

where \(N\) is the length of the slice.

Parameters

axis (int) – The axis to perform the Log softmax operation. Default: -1.

Inputs:
  • logits (Tensor) - The input of Log Softmax, with float16 or float32 data type.

Outputs:

Tensor, with the same type and shape as the logits.

Examples

>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> log_softmax = P.LogSoftmax()
>>> log_softmax(input_x)
[-4.4519143, -3.4519143, -2.4519143, -1.4519144, -0.4519144]
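The formula above can be checked with a few lines of plain Python. The function `log_softmax` below is an illustrative stand-in for a single 1-D slice, not the MindSpore kernel; subtracting the maximum before exponentiating is a standard stability trick and an assumption about how one would implement it.

```python
import math


def log_softmax(values):
    """Log Softmax of a 1-D list, following the formula above."""
    # Shifting by the maximum leaves the result unchanged but avoids overflow in exp.
    shift = max(values)
    log_sum = shift + math.log(sum(math.exp(v - shift) for v in values))
    return [v - log_sum for v in values]


result = log_softmax([1.0, 2.0, 3.0, 4.0, 5.0])
print([round(v, 4) for v in result])
```

The values agree with the example output above to within float32 rounding.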
class mindspore.ops.LogicalAnd(*args, **kwargs)[source]

Computes the “logical AND” of two tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one bool. When the inputs are two tensors, the shapes of them could be broadcast, and the data types of them must be bool. When the inputs are one tensor and one bool, the bool object could only be a constant, and the data type of the tensor must be bool.

Inputs:
  • input_x (Union[Tensor, bool]) - The first input is a bool or a tensor whose data type is bool.

  • input_y (Union[Tensor, bool]) - The second input is a bool when the first input is a tensor or a tensor whose data type is bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
>>> input_y = Tensor(np.array([True, True, False]), mindspore.bool_)
>>> logical_and = P.LogicalAnd()
>>> logical_and(input_x, input_y)
[True, False, False]
class mindspore.ops.LogicalNot(*args, **kwargs)[source]

Computes the “logical NOT” of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor whose dtype is bool.

Outputs:

Tensor, the shape is the same as the input_x, and the dtype is bool.

Examples

>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
>>> logical_not = P.LogicalNot()
>>> logical_not(input_x)
[False, True, False]
class mindspore.ops.LogicalOr(*args, **kwargs)[source]

Computes the “logical OR” of two tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one bool. When the inputs are two tensors, the shapes of them could be broadcast, and the data types of them must be bool. When the inputs are one tensor and one bool, the bool object could only be a constant, and the data type of the tensor must be bool.

Inputs:
  • input_x (Union[Tensor, bool]) - The first input is a bool or a tensor whose data type is bool.

  • input_y (Union[Tensor, bool]) - The second input is a bool when the first input is a tensor or a tensor whose data type is bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
>>> input_y = Tensor(np.array([True, True, False]), mindspore.bool_)
>>> logical_or = P.LogicalOr()
>>> logical_or(input_x, input_y)
[True, True, True]
class mindspore.ops.MakeRefKey(*args, **kwargs)[source]

Makes a RefKey instance from a string. A RefKey stores the name of a Parameter, can be passed through functions, and can be used as the target of Assign.

Parameters

tag (str) – Parameter name to make the RefKey.

Inputs:

No inputs.

Outputs:

RefKeyType, made from the Parameter name.

Examples

>>> from mindspore.ops import functional as F
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.y = mindspore.Parameter(Tensor(np.ones([6, 8, 10]), mindspore.int32), name="y")
>>>         self.make_ref_key = P.MakeRefKey("y")
>>>
>>>     def construct(self, x):
>>>         key = self.make_ref_key()
>>>         ref = F.make_ref(key, x, self.y)
>>>         return ref * x
>>>
>>> x = Tensor(np.ones([3, 4, 5]), mindspore.int32)
>>> net = Net()
>>> net(x)
class mindspore.ops.MatMul(*args, **kwargs)[source]

Multiplies matrix a by matrix b.

The rank of input tensors must be 2.

Parameters
  • transpose_a (bool) – If true, a is transposed before multiplication. Default: False.

  • transpose_b (bool) – If true, b is transposed before multiplication. Default: False.

Inputs:
  • input_x (Tensor) - The first tensor to be multiplied. The shape of the tensor is \((N, C)\). If transpose_a is True, its shape must be \((N, C)\) after transposing.

  • input_y (Tensor) - The second tensor to be multiplied. The shape of the tensor is \((C, M)\). If transpose_b is True, its shape must be \((C, M)\) after transposing.

Outputs:

Tensor, the shape of the output tensor is \((N, M)\).

Examples

>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
>>> matmul = P.MatMul()
>>> output = matmul(input_x, input_y)
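To make the role of transpose_a and transpose_b concrete, here is a minimal pure-Python sketch on nested lists. The functions `transpose` and `matmul` are hypothetical helpers written for illustration and do not call MindSpore.

```python
def transpose(matrix):
    """Swap the rows and columns of a rank-2 nested list."""
    return [list(row) for row in zip(*matrix)]


def matmul(a, b, transpose_a=False, transpose_b=False):
    """Multiply two rank-2 nested lists, optionally transposing each one first."""
    if transpose_a:
        a = transpose(a)
    if transpose_b:
        b = transpose(b)
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions must match after transposing")
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


ones_1x3 = [[1.0, 1.0, 1.0]]
ones_3x4 = [[1.0] * 4 for _ in range(3)]

out = matmul(ones_1x3, ones_3x4)                                  # (1, 3) x (3, 4)
out_t = matmul(ones_1x3, transpose(ones_3x4), transpose_b=True)   # b stored as (4, 3)
print(out, out_t)
```

Both calls yield a result of shape (1, 4), because the second operand has shape (3, 4) once the transpose flag has been applied.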
class mindspore.ops.MaxPool(*args, **kwargs)[source]

Max pooling operation.

Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes.

Typically the input is of shape \((N_{in}, C_{in}, H_{in}, W_{in})\), MaxPool outputs regional maximum in the \((H_{in}, W_{in})\)-dimension. Given kernel size \(ks = (h_{ker}, w_{ker})\) and stride \(s = (s_0, s_1)\), the operation is as follows.

\[\text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)\]
Parameters
  • ksize (Union[int, tuple[int]]) – The size of the kernel used to take the maximum value. It is either an int, in which case the height and width are both ksize, or a tuple of two ints giving the height and width respectively. Default: 1.

  • strides (Union[int, tuple[int]]) – The distance the kernel moves. It is either an int, in which case the strides along the height and the width are both strides, or a tuple of two ints giving the strides along the height and the width respectively. Default: 1.

  • padding (str) –

    The optional value for pad mode, is “same” or “valid”, not case sensitive. Default: “valid”.

    • same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the last extra padding will be done from the bottom and the right side.

    • valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded.

Inputs:
  • input (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\).

Outputs:

Tensor, with shape \((N, C_{out}, H_{out}, W_{out})\).

Examples

>>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
>>> maxpool_op = P.MaxPool(padding="VALID", ksize=2, strides=1)
>>> output_tensor = maxpool_op(input_tensor)
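The pooling formula above can be traced on a single 2D plane with plain Python. The function `max_pool_2d_valid` is a hypothetical sketch of the "valid" mode only; the output size formula `(size - kernel) // stride + 1` is the usual convention for discarding incomplete windows and is an assumption here.

```python
def max_pool_2d_valid(plane, ksize, strides):
    """Max pooling over one 2-D plane (nested lists) with "valid" padding."""
    k_h, k_w = (ksize, ksize) if isinstance(ksize, int) else ksize
    s_h, s_w = (strides, strides) if isinstance(strides, int) else strides
    height, width = len(plane), len(plane[0])
    # Assumed output size: only windows that fit entirely inside the plane.
    out_h = (height - k_h) // s_h + 1
    out_w = (width - k_w) // s_w + 1
    return [
        [
            max(plane[s_h * h + m][s_w * w + n] for m in range(k_h) for n in range(k_w))
            for w in range(out_w)
        ]
        for h in range(out_h)
    ]


# First channel of np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)) from the example above.
plane = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
pooled = max_pool_2d_valid(plane, ksize=2, strides=1)
print(pooled)
```

With a 2 x 2 kernel and stride 1, the 3 x 4 plane is reduced to 2 x 3, and each output element is the maximum of the window whose top-left corner is at the same position.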
class mindspore.ops.MaxPoolWithArgmax(ksize=1, strides=1, padding='valid')[source]

Perform max pooling on the input Tensor and return both max values and indices.

Typically the input is of shape \((N_{in}, C_{in}, H_{in}, W_{in})\), MaxPool outputs regional maximum in the \((H_{in}, W_{in})\)-dimension. Given kernel size \(ks = (h_{ker}, w_{ker})\) and stride \(s = (s_0, s_1)\), the operation is as follows.

\[\text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)\]
Parameters
  • ksize (Union[int, tuple[int]]) – The size of the kernel used to take the maximum value and its index. It is either an int, in which case the height and width are both ksize, or a tuple of two ints giving the height and width respectively. Default: 1.

  • strides (Union[int, tuple[int]]) – The distance the kernel moves. It is either an int, in which case the strides along the height and the width are both strides, or a tuple of two ints giving the strides along the height and the width respectively. Default: 1.

  • padding (str) –

    The optional value for pad mode, is “same” or “valid”, not case sensitive. Default: “valid”.

    • same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the last extra padding will be done from the bottom and the right side.

    • valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded.

Inputs:
  • input (Tensor) - Tensor of shape \((N, C_{in}, H_{in}, W_{in})\). Data type must be float16 or float32.

Outputs:

Tuple of 2 Tensors, representing the maxpool result and where the max values are generated.

  • output (Tensor) - Maxpooling result, with shape \((N, C_{out}, H_{out}, W_{out})\).

  • mask (Tensor) - Max values’ index represented by the mask.

Examples

>>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
>>> maxpool_arg_op = P.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1)
>>> output_tensor, argmax = maxpool_arg_op(input_tensor)
class mindspore.ops.Maximum(*args, **kwargs)[source]

Computes the maximum of input tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32)
>>> input_y = Tensor(np.array([4.0, 2.0, 6.0]), mindspore.float32)
>>> maximum = P.Maximum()
>>> maximum(input_x, input_y)
[4.0, 5.0, 6.0]
class mindspore.ops.Merge(*args, **kwargs)[source]

Merges all input data to one.

One and only one of the inputs must be selected as the output.

Inputs:
  • inputs (Union(Tuple, List)) - The data to be merged. All tuple elements must have the same data type.

Outputs:

tuple. Output is tuple(data, output_index). The data has the same shape as each element of inputs.

Examples

>>> merge = P.Merge()
>>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
>>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
>>> result = merge((input_x, input_y))
class mindspore.ops.Minimum(*args, **kwargs)[source]

Computes the minimum of input tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32)
>>> input_y = Tensor(np.array([4.0, 2.0, 6.0]), mindspore.float32)
>>> minimum = P.Minimum()
>>> minimum(input_x, input_y)
[1.0, 2.0, 3.0]
class mindspore.ops.MirrorPad(*args, **kwargs)[source]

Pads the input tensor according to the paddings and mode.

Parameters

mode (str) – Specifies the padding mode. The optional values are “REFLECT” and “SYMMETRIC”. Default: “REFLECT”.

Inputs:
  • input_x (Tensor) - The input tensor.

  • paddings (Tensor) - The paddings tensor. Its value is a matrix (list) of shape (N, 2), where N is the rank of the input data. All elements of paddings are of int type. For dimension D of the input, paddings[D, 0] is the number of elements added before the input tensor in that dimension, and paddings[D, 1] is the number of elements added after it.

Outputs:

Tensor, the tensor after padding.

  • If mode is “REFLECT”, the padding is filled by mirroring the input about its edge, without repeating the edge values. If the input_x is [[1,2,3],[4,5,6],[7,8,9]] and paddings is [[1,1],[2,2]], then the output is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].

  • If mode is “SYMMETRIC”, the filling method is similar to “REFLECT”. The input is also mirrored about its edge, except that the edge values themselves are included in the copy. If the input_x is [[1,2,3],[4,5,6],[7,8,9]] and paddings is [[1,1],[2,2]], then the output is [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].

Examples

>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.pad = P.MirrorPad(mode="REFLECT")
>>>     def construct(self, x, paddings):
>>>         return self.pad(x, paddings)
>>> x = np.random.random(size=(2, 3)).astype(np.float32)
>>> paddings = Tensor([[1,1],[2,2]])
>>> pad = Net()
>>> ms_output = pad(Tensor(x), paddings)
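The two padding modes can be reproduced on nested lists without MindSpore. The helpers `mirror_indices` and `mirror_pad_2d` below are hypothetical names used only for this sketch; they compute, for every padded position, which source index is copied.

```python
def mirror_indices(size, before, after, mode):
    """Source indices for one dimension padded in "REFLECT" or "SYMMETRIC" mode."""
    if mode == "REFLECT":
        # Mirror about the edge element itself, so the edge is not repeated.
        head = [-p for p in range(-before, 0)]
        tail = [2 * (size - 1) - p for p in range(size, size + after)]
    elif mode == "SYMMETRIC":
        # Mirror about the boundary, so the edge element is repeated.
        head = [-p - 1 for p in range(-before, 0)]
        tail = [2 * size - 1 - p for p in range(size, size + after)]
    else:
        raise ValueError(f"unknown mode: {mode}")
    return head + list(range(size)) + tail


def mirror_pad_2d(matrix, paddings, mode="REFLECT"):
    """Pad a rank-2 nested list; paddings is [[top, bottom], [left, right]]."""
    rows = mirror_indices(len(matrix), paddings[0][0], paddings[0][1], mode)
    cols = mirror_indices(len(matrix[0]), paddings[1][0], paddings[1][1], mode)
    return [[matrix[r][c] for c in cols] for r in rows]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
reflect = mirror_pad_2d(x, [[1, 1], [2, 2]], "REFLECT")
symmetric = mirror_pad_2d(x, [[1, 1], [2, 2]], "SYMMETRIC")
print(reflect)
print(symmetric)
```

Running the sketch on the 3 x 3 input from the description reproduces the two outputs listed under Outputs.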
class mindspore.ops.Mod(*args, **kwargs)[source]

Computes the remainder of dividing the first input tensor by the second input tensor element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, both dtypes cannot be bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number]) - The first input is a number or a tensor whose data type is number.

  • input_y (Union[Tensor, Number]) - When the first input is a tensor, the second input could be a number or a tensor whose data type is number. When the first input is a number, the second input must be a tensor whose data type is number.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Raises

ValueError – When input_x and input_y are not the same dtype.

Examples

>>> input_x = Tensor(np.array([-4.0, 5.0, 6.0]), mindspore.float32)
>>> input_y = Tensor(np.array([3.0, 2.0, 3.0]), mindspore.float32)
>>> mod = P.Mod()
>>> mod(input_x, input_y)
class mindspore.ops.Mul(*args, **kwargs)[source]

Multiplies two tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> input_y = Tensor(np.array([4.0, 5.0, 6.0]), mindspore.float32)
>>> mul = P.Mul()
>>> mul(input_x, input_y)
[4, 10, 18]
class mindspore.ops.Multinomial(*args, **kwargs)[source]

Returns a tensor sampled from the multinomial probability distribution located in the corresponding row of tensor input.

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.

Parameters

seed (int) – Seed data is used as entropy source for Random number engines to generate pseudo-random numbers. Must be non-negative. Default: 0.

Inputs:
  • input (Tensor[float32]) - The input tensor containing the cumsum of probabilities. It must have 1 or 2 dimensions.

  • num_samples (int32) - Number of samples to draw.

Outputs:

Tensor with the same rows as input, each row has num_samples sampled indices.

Examples

>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = P.Multinomial(seed=10)
>>> output = multinomial(input, 2)
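As a rough illustration of sampling indices from unnormalized weights, the sketch below uses Python's `random.choices`. The function `multinomial` is a hypothetical stand-in: it treats the input as weights, as the note above describes, and draws with replacement, which is an assumption about the operator. Its random stream is unrelated to MindSpore's, so the sampled values will differ.

```python
import random


def multinomial(weights, num_samples, seed=0):
    """Draw num_samples indices, each with probability proportional to its weight."""
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a non-zero sum")
    rng = random.Random(seed)  # a seeded generator makes the draw reproducible
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


samples = multinomial([0.0, 9.0, 4.0, 0.0], 1000, seed=10)
print(samples.count(1), samples.count(2))
```

Indices 0 and 3 have zero weight and are never drawn, while index 1 is drawn more often than index 2 because its weight is larger.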
class mindspore.ops.MultitypeFuncGraph(name, read_value=False)[source]

Generate overloaded functions.

MultitypeFuncGraph is a class used to generate overloaded functions that accept different input types. Initialize a MultitypeFuncGraph object with name, and use register with the input types as the decorator for each function to be registered. The object can then be called with different types of inputs, and works with HyperMap and Map.

Parameters
  • name (str) – Operator name.

  • read_value (bool) – If the registered functions do not need to set values on Parameter and all inputs are passed by value, set read_value to True. Default: False.

Raises

ValueError – If no matching function is found for the given arguments.

Examples

>>> # `add` is a metagraph object which will add two objects according to
>>> # input type using ".register" decorator.
>>> from mindspore import Tensor
>>> from mindspore.ops import MultitypeFuncGraph, Primitive, operations as P
>>> from mindspore import dtype as mstype
>>>
>>> scala_add = Primitive('scala_add')
>>> tensor_add = P.TensorAdd()
>>>
>>> add = MultitypeFuncGraph('add')
>>> @add.register("Number", "Number")
... def add_scala(x, y):
...     return scala_add(x, y)
>>> @add.register("Tensor", "Tensor")
... def add_tensor(x, y):
...     return tensor_add(x, y)
>>> add(1, 2)
3
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
Tensor(shape=[], dtype=Float32, 3)
register(*type_names)[source]

Register a function for the given type string.

Parameters

type_names (Union[str, mindspore.dtype]) – Inputs type names or types list.

Returns

decorator, a decorator to register the function to run, when called under the types described in type_names.
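The register-then-dispatch pattern can be illustrated without MindSpore. The class `TypeDispatcher` below is a hypothetical, much simplified analogue: it keys functions on Python types rather than on MindSpore type names, and it only mimics the calling convention, not graph compilation.

```python
class TypeDispatcher:
    """Hypothetical stand-in that dispatches on the Python types of the arguments."""

    def __init__(self, name):
        self.name = name
        self._registry = {}

    def register(self, *types):
        """Return a decorator that registers a function for the given types."""
        def decorator(fn):
            self._registry[types] = fn
            return fn
        return decorator

    def __call__(self, *args):
        # Pick the first registered signature that matches every argument.
        for types, fn in self._registry.items():
            if len(types) == len(args) and all(isinstance(a, t) for a, t in zip(args, types)):
                return fn(*args)
        names = [type(a).__name__ for a in args]
        raise ValueError(f"{self.name}: no function registered for {names}")


add = TypeDispatcher("add")


@add.register(int, int)
def add_int(x, y):
    return x + y


@add.register(str, str)
def add_str(x, y):
    return x + " " + y


print(add(1, 2), add("multi", "type"))
```

Calling the object with an unregistered combination of types, such as `add(1, "a")`, raises ValueError in this sketch, mirroring the Raises entry above.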

class mindspore.ops.NMSWithMask(*args, **kwargs)[source]

Select some bounding boxes in descending order of score.

Parameters

iou_threshold (float) – Specifies the threshold of overlap boxes with respect to IOU. Default: 0.5.

Raises

ValueError – If the iou_threshold is not a float number, or if the first dimension of input Tensor is less than or equal to 0, or if the data type of the input Tensor is not float16 or float32.

Inputs:
  • bboxes (Tensor) - The shape of tensor is \((N, 5)\). Input bounding boxes. N is the number of input bounding boxes. Every bounding box contains 5 values, the first 4 values are the coordinates of bounding box, and the last value is the score of this bounding box. The data type must be float16 or float32.

Outputs:

tuple[Tensor], tuple of three tensors, they are selected_boxes, selected_idx and selected_mask.

  • selected_boxes (Tensor) - The shape of tensor is \((N, 5)\). The list of bounding boxes after non-max suppression calculation.

  • selected_idx (Tensor) - The shape of tensor is \((N,)\). The indexes list of valid input bounding boxes.

  • selected_mask (Tensor) - The shape of tensor is \((N,)\). A mask list of valid output bounding boxes.

Examples

>>> bbox = np.random.rand(128, 5)
>>> bbox[:, 2] += bbox[:, 0]
>>> bbox[:, 3] += bbox[:, 1]
>>> inputs = Tensor(bbox, mindspore.float32)
>>> nms = P.NMSWithMask(0.5)
>>> output_boxes, indices, mask = nms(inputs)
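For readers unfamiliar with non-max suppression, the following pure-Python sketch shows the usual greedy procedure. It assumes the four coordinates are corners (x1, y1, x2, y2) and that a box is suppressed when its IOU with a higher-scoring kept box exceeds the threshold; `iou` and `nms_with_mask` are hypothetical helpers, and the exact layout of the operator's outputs may differ.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2, score)."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms_with_mask(bboxes, iou_threshold=0.5):
    """Return boxes sorted by descending score, their input indices, and a keep mask."""
    order = sorted(range(len(bboxes)), key=lambda i: bboxes[i][4], reverse=True)
    sorted_boxes = [bboxes[i] for i in order]
    mask = [True] * len(sorted_boxes)
    for i, box in enumerate(sorted_boxes):
        if not mask[i]:
            continue  # a suppressed box does not suppress others
        for j in range(i + 1, len(sorted_boxes)):
            if mask[j] and iou(box, sorted_boxes[j]) > iou_threshold:
                mask[j] = False
    return sorted_boxes, order, mask


boxes = [
    [1.0, 1.0, 11.0, 11.0, 0.8],    # overlaps heavily with the 0.9 box
    [20.0, 20.0, 30.0, 30.0, 0.7],  # far away from the others
    [0.0, 0.0, 10.0, 10.0, 0.9],
]
selected_boxes, selected_idx, selected_mask = nms_with_mask(boxes, 0.5)
print(selected_idx, selected_mask)
```

The 0.8 box overlaps the 0.9 box with an IOU of about 0.68, so it is masked out, while the distant 0.7 box is kept.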
class mindspore.ops.NPUAllocFloatStatus(*args, **kwargs)[source]

Allocates a flag to store the overflow status.

The flag is a tensor whose shape is (8,) and data type is mindspore.dtype.float32.

Note

Examples: see NPUGetFloatStatus.

Outputs:

Tensor, has the shape of (8,).

Examples

>>> alloc_status = P.NPUAllocFloatStatus()
>>> init = alloc_status()
Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=(8,), dtype=mindspore.float32)
class mindspore.ops.NPUClearFloatStatus(*args, **kwargs)[source]

Clear the flag which stores the overflow status.

Note

The flag is in the register on the Ascend device. It will be reset and cannot be reused after NPUClearFloatStatus is called.

Examples: see NPUGetFloatStatus.

Inputs:
  • input_x (Tensor) - The output tensor of NPUAllocFloatStatus. The data type must be float16 or float32.

Outputs:

Tensor, has the same shape as input_x. All the elements in the tensor will be zero.

Examples

>>> alloc_status = P.NPUAllocFloatStatus()
>>> get_status = P.NPUGetFloatStatus()
>>> clear_status = P.NPUClearFloatStatus()
>>> init = alloc_status()
>>> flag = get_status(init)
>>> clear = clear_status(init)
Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=(8,), dtype=mindspore.float32)
class mindspore.ops.NPUGetFloatStatus(*args, **kwargs)[source]

Updates the flag which is the output tensor of NPUAllocFloatStatus with latest overflow status.

The flag is a tensor whose shape is (8,) and data type is mindspore.dtype.float32. If the sum of the flag equals 0, no overflow has happened. If the sum of the flag is greater than 0, an overflow has happened.

Inputs:
  • input_x (Tensor) - The output tensor of NPUAllocFloatStatus. The data type must be float16 or float32.

Outputs:

Tensor, has the same shape as input_x. All the elements in the tensor will be zero.

Examples

>>> alloc_status = P.NPUAllocFloatStatus()
>>> get_status = P.NPUGetFloatStatus()
>>> init = alloc_status()
>>> flag = get_status(init)
Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=(8,), dtype=mindspore.float32)
class mindspore.ops.Neg(*args, **kwargs)[source]

Returns a tensor with negative values of the input tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor whose dtype is number.

Outputs:

Tensor, has the same shape and dtype as input.

Examples

>>> neg = P.Neg()
>>> input_x = Tensor(np.array([1, 2, -1, 2, 0, -3.5]), mindspore.float32)
>>> result = neg(input_x)
[-1.  -2.   1.  -2.   0.   3.5]
class mindspore.ops.NotEqual(*args, **kwargs)[source]

Computes the non-equivalence of two tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> not_equal = P.NotEqual()
>>> not_equal(input_x, 2.0)
[True, False, True]
>>>
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([1, 2, 4]), mindspore.int32)
>>> not_equal = P.NotEqual()
>>> not_equal(input_x, input_y)
[False, False, True]
class mindspore.ops.OneHot(*args, **kwargs)[source]

Computes a one-hot tensor.

Makes a new tensor, whose locations represented by indices in indices take value on_value, while all other locations take value off_value.

Note

If the input indices is rank N, the output will have rank N+1. The new axis is created at dimension axis.

Parameters

axis (int) – Position to insert the value. For example, if the shape of indices is [n, c] and axis is -1, the output shape will be [n, c, depth]; if axis is 0, the output shape will be [depth, n, c]. Default: -1.

Inputs:
  • indices (Tensor) - A tensor of indices. Tensor of shape \((X_0, \ldots, X_n)\). Data type must be int32.

  • depth (int) - A scalar defining the depth of the one hot dimension.

  • on_value (Tensor) - A value to fill in output when indices[j] = i. With data type of float16 or float32.

  • off_value (Tensor) - A value to fill in output when indices[j] != i. Has the same data type as on_value.

Outputs:

Tensor, one-hot tensor. Tensor of shape \((X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)\).

Examples

>>> indices = Tensor(np.array([0, 1, 2]), mindspore.int32)
>>> depth, on_value, off_value = 3, Tensor(1.0, mindspore.float32), Tensor(0.0, mindspore.float32)
>>> onehot = P.OneHot()
>>> result = onehot(indices, depth, on_value, off_value)
[[1, 0, 0], [0, 1, 0], [0, 0, 1]]
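The effect of axis can be seen with a small pure-Python sketch for a 1-D list of indices. The function `one_hot` is a hypothetical helper that implements only the default axis=-1; the axis=0 layout is then obtained by transposing.

```python
def one_hot(indices, depth, on_value=1.0, off_value=0.0):
    """One-hot encode a 1-D list of indices along a new last axis (axis=-1)."""
    return [[on_value if i == index else off_value for i in range(depth)] for index in indices]


encoded = one_hot([0, 2], depth=3)                 # shape [n, depth] = [2, 3]
encoded_axis0 = [list(col) for col in zip(*encoded)]  # shape [depth, n] = [3, 2]
print(encoded)
print(encoded_axis0)
```

For indices of shape [n], the new depth dimension is last with axis=-1 and first with axis=0, matching the shape rule given for the axis parameter.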
class mindspore.ops.OnesLike(*args, **kwargs)[source]

Creates a new tensor. The values of all elements are 1.

Returns a tensor of ones with the same shape and type as the input.

Inputs:
  • input_x (Tensor) - Input tensor.

Outputs:

Tensor, has the same shape and type as input_x but filled with ones.

Examples

>>> oneslike = P.OnesLike()
>>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
>>> output = oneslike(x)
[[1, 1],
 [1, 1]]
class mindspore.ops.PReLU(*args, **kwargs)[source]

Parametric Rectified Linear Unit activation function.

PReLU is described in the paper Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. Defined as follows:

\[prelu(x_i)= \max(0, x_i) + \min(0, w * x_i),\]

where \(x_i\) is an element of a channel of the input.

Note

1-dimensional input_x is not supported.

Inputs:
  • input_x (Tensor) - Float tensor, representing the output of the previous layer. With data type of float16 or float32.

  • weight (Tensor) - Float tensor, w > 0. Only two shapes are legitimate: 1 or the number of channels of the input. With data type of float16 or float32.

Outputs:

Tensor, with the same type as input_x.

For detailed information, please refer to nn.PReLU.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.prelu = P.PReLU()
>>>     def construct(self, input_x, weight):
>>>         result = self.prelu(input_x, weight)
>>>         return result
>>>
>>> input_x = Tensor(np.random.randint(-3, 3, (2, 3, 2)), mindspore.float32)
>>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32)
>>> net = Net()
>>> result = net(input_x, weight)
[[[-0.1, 1.0],
  [0.0, 2.0],
  [0.0, 0.0]],
 [[-0.2, -0.1],
  [2.0, -1.8],
  [0.6, 0.6]]]
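The formula above can be applied channel by channel in plain Python. The helpers `prelu` and `prelu_channels` are hypothetical; the sketch assumes positive weights, as stated for weight, and treats one sample as a list of channels, with either a single shared weight or one weight per channel.

```python
def prelu(x, w):
    """PReLU for a single element, following the formula above (assumes w > 0)."""
    return max(0.0, x) + min(0.0, w * x)


def prelu_channels(sample, weights):
    """Apply PReLU to one sample of shape (C, L) given 1 or C weights."""
    if len(weights) == 1:
        weights = weights * len(sample)  # share the single weight across channels
    return [[prelu(x, w) for x in channel] for channel, w in zip(sample, weights)]


sample = [[-1.0, 1.0], [0.0, 2.0], [-2.0, 3.0]]
print(prelu_channels(sample, [0.1, 0.6, 0.5]))
print(prelu_channels(sample, [0.25]))
```

Positive elements pass through unchanged, while negative elements are scaled by the weight of their channel.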
class mindspore.ops.Pack(*args, **kwargs)[source]

Packs a list of tensors in specified axis.

Packs the list of input tensors with the same rank R, output is a tensor of rank (R+1).

Given input tensors of shape \((x_1, x_2, ..., x_R)\). Set the number of input tensors as N. If \(0 \le axis\), the shape of the output tensor is \((x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)\).

Parameters

axis (int) – Dimension to pack. Default: 0. Negative values wrap around. The range is [-(R+1), R+1).

Inputs:
  • input_x (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.

Outputs:

Tensor. A packed Tensor with the same type as input_x.

Raises
  • TypeError – If the data types of elements in input_x are not the same.

  • ValueError – If the length of input_x is not greater than 1; or if axis is out of the range [-(R+1), R+1); or if the shapes of elements in input_x are not the same.

Examples

>>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
>>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
>>> pack = P.Pack()
>>> output = pack([data1, data2])
[[0, 1], [2, 3]]
class mindspore.ops.Pad(*args, **kwargs)[source]

Pads input tensor according to the paddings.

Parameters

paddings (tuple) – The shape of parameter paddings is (N, 2), where N is the rank of the input data. All elements of paddings are of int type. For dimension D of the input, paddings[D, 0] is the number of elements added before the input tensor in that dimension, and paddings[D, 1] is the number of elements added after it.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, the tensor after padding.

Examples

>>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> pad_op = P.Pad(((1, 2), (2, 1)))
>>> output_tensor = pad_op(input_tensor)
>>> assert output_tensor == Tensor(np.array([[ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
>>>                                          [ 0. ,  0. , -0.1,  0.3,  3.6,  0. ],
>>>                                          [ 0. ,  0. ,  0.4,  0.5, -3.2,  0. ],
>>>                                          [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
>>>                                          [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ]]), mindspore.float32)
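The meaning of each pair in paddings can be checked with a short pure-Python sketch for a rank-2 input. The function `pad_2d` is a hypothetical helper that fills with zeros, as in the example above.

```python
def pad_2d(matrix, paddings, fill=0.0):
    """Pad a rank-2 nested list; paddings is ((top, bottom), (left, right))."""
    (top, bottom), (left, right) = paddings
    width = left + len(matrix[0]) + right
    padded = [[fill] * width for _ in range(top)]
    padded += [[fill] * left + list(row) + [fill] * right for row in matrix]
    padded += [[fill] * width for _ in range(bottom)]
    return padded


x = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
for row in pad_2d(x, ((1, 2), (2, 1))):
    print(row)
```

With paddings ((1, 2), (2, 1)), the 2 x 3 input gains one row above, two rows below, two columns on the left and one on the right, giving the 5 x 6 result shown in the example.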
class mindspore.ops.Padding(*args, **kwargs)[source]

Extends the last dimension of input tensor from 1 to pad_dim_size, by filling with 0.

Parameters

pad_dim_size (int) – The value of the last dimension of x to be extended, which must be positive.

Inputs:
  • x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). The rank of x must be at least 2. The last dimension of x must be 1.

Outputs:

Tensor, the shape of tensor is \((z_1, z_2, ..., z_N)\).

Examples

>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
>>> pad_dim_size = 4
>>> out = P.Padding(pad_dim_size)(x)
[[8, 0, 0, 0], [10, 0, 0, 0]]
class mindspore.ops.ParallelConcat(*args, **kwargs)[source]

Concatenates tensors in the first dimension.

Concatenates the input tensors along the first dimension.

Note

The input tensors are all required to have size 1 in the first dimension.

Inputs:
  • values (tuple, list) - A tuple or a list of input tensors. The data type and shape of these tensors must be the same.

Outputs:

Tensor, data type is the same as values.

Examples

>>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
>>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
>>> op = P.ParallelConcat()
>>> output = op((data1, data2))
[[0, 1], [2, 1]]
class mindspore.ops.Partial(*args, **kwargs)[source]

Makes a partial function instance, used for pynative mode.

Inputs:
  • args (Union[FunctionType, Tensor]) - The function and bind arguments.

Outputs:

FunctionType, partial function bound with arguments.

class mindspore.ops.Poisson(*args, **kwargs)[source]

Produces random non-negative integer values i, distributed according to the discrete probability function:

\[\text{P}(i|\mu) = \frac{\exp(-\mu)\mu^{i}}{i!}\]
Parameters
  • seed (int) – Random seed, must be non-negative. Default: 0.

  • seed2 (int) – Random seed2, must be non-negative. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

  • mean (Tensor) - The μ parameter of the distribution, which defines the mean number of occurrences of the event. It must be greater than 0, with float32 data type.

Outputs:

Tensor. Its shape must be the broadcasted shape of shape and the shape of mean. The dtype is int32.

Examples

>>> shape = (4, 16)
>>> mean = Tensor(5.0, mindspore.float32)
>>> poisson = P.Poisson(seed=5)
>>> output = poisson(shape, mean)
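To make the probability function concrete, here is a small sketch that evaluates it directly with the standard library and checks numerically that the probabilities sum to about 1 and have mean about μ. It does not reproduce how the operator draws samples; `poisson_pmf` is a hypothetical helper name.

```python
import math


def poisson_pmf(i, mu):
    """P(i | mu) = exp(-mu) * mu**i / i! for a non-negative integer i."""
    return math.exp(-mu) * mu ** i / math.factorial(i)


mu = 5.0
# Truncate the infinite support at 60; the remaining tail is negligible for mu = 5.
probs = [poisson_pmf(i, mu) for i in range(60)]
total = sum(probs)
mean = sum(i * p for i, p in enumerate(probs))
print(total, mean)
```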
class mindspore.ops.PopulationCount(*args, **kwargs)[source]

Calculates the population count of each element, i.e., the number of 1 bits in its binary representation.

Inputs:
  • input (Tensor) - The data type must be int16 or uint16.

Outputs:

Tensor, with the same shape as the input.

Examples

>>> population_count = P.PopulationCount()
>>> x_input = Tensor([0, 1, 3], mindspore.int16)
>>> population_count(x_input)
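The example above shows no output. As a hedged reference, the population count of a 16-bit value can be sketched in plain Python as below. Treating negative int16 values through their two's-complement bit pattern is an assumption of this sketch, and `population_count_16` is a hypothetical helper name.

```python
def population_count_16(value):
    """Count the 1 bits in the 16-bit two's-complement pattern of value."""
    return bin(value & 0xFFFF).count("1")


counts = [population_count_16(v) for v in [0, 1, 3]]
print(counts)  # [0, 1, 2]
```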
class mindspore.ops.Pow(*args, **kwargs)[source]

Computes a tensor to the power of the second input.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> input_y = 3.0
>>> pow = P.Pow()
>>> pow(input_x, input_y)
[1.0, 8.0, 64.0]
>>>
>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> input_y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)
>>> pow = P.Pow()
>>> pow(input_x, input_y)
[1.0, 16.0, 64.0]
class mindspore.ops.Primitive(name)[source]

Primitive is the base class of primitives in python.

Parameters

name (str) – Name for the current Primitive.

Examples

>>> add = Primitive('add')
>>>
>>> # or work with prim_attr_register:
>>> # init a Primitive class with attr1 and attr2
>>> class Add(Primitive):
>>>     @prim_attr_register
>>>     def __init__(self, attr1, attr2):
>>>         # check attr1 and attr2 or do some initializations
>>>         pass
>>> # init a Primitive obj with attr1=1 and attr2=2
>>> add = Add(attr1=1, attr2=2)
add_prim_attr(name, value)[source]

Adds primitive attribute.

Parameters
  • name (str) – Attribute Name.

  • value (Any) – Attribute value.

check_elim(*args)[source]

Checks whether certain inputs should go to the backend. Subclasses that need this behavior should override this method.

Parameters

args (Primitive args) – Same as arguments of current Primitive.

Returns

A tuple consisting of two elements. The first element indicates whether we should filter out current arguments; the second element is the output if we need to filter out the arguments.

init_prim_io_names(inputs, outputs)[source]

Initializes the names of inputs and outputs of Tensor or attributes.

Parameters
  • inputs (list[str]) – list of inputs names.

  • outputs (list[str]) – list of outputs names.

set_prim_instance_name(instance_name)[source]

Set instance name to primitive operator.

Note

It will be called by default when user defines primitive operator.

Parameters

instance_name (str) – Instance name of primitive operator set by user.

shard(strategy)[source]

Add strategies to primitive attribute.

Note

It is valid only in semi auto parallel or auto parallel mode. In other parallel modes, strategies set here will be ignored.

Parameters

strategy (tuple) – Strategy describes the distributed parallel mode of the current primitive.

property update_parameter

Whether the primitive will update the value of parameter.

class mindspore.ops.PrimitiveWithInfer(name)[source]

PrimitiveWithInfer is the base class of primitives in Python and defines functions for tracking inference in Python.

Four methods can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it takes priority over the other methods. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer logic of the shape and type. infer_value() is used for constant propagation.

Parameters

name (str) – Name of the current Primitive.

Examples

>>> # init a Primitive class with infer
>>> class Add(PrimitiveWithInfer):
>>>     @prim_attr_register
>>>     def __init__(self):
>>>         pass
>>>
>>>     def infer_shape(self, x, y):
>>>         return x # output shape same as first input 'x'
>>>
>>>     def infer_dtype(self, x, y):
>>>         return x # output type same as first input 'x'
>>>
>>> # init a Primitive obj
>>> add = Add()
infer_dtype(*args)[source]

Infer output dtype based on input dtype.

Parameters

args (mindspore.dtype) – data type of inputs.

Returns

mindspore.dtype, data type of outputs.

infer_shape(*args)[source]

Infer output shape based on input shape.

Note

The shape of scalar is an empty tuple.

Parameters

args (tuple(int)) – shapes of input tensors.

Returns

tuple(int), shapes of output tensors.

infer_value(*args)[source]

Infer output value based on input value at compile time.

Parameters

args (Any) – value of inputs.

Returns

Value of outputs. Returns None if the value cannot be inferred at compile time.

class mindspore.ops.Print(*args, **kwargs)[source]

Outputs tensor or string to stdout.

Note

In pynative mode, please use python print function.

Inputs:
  • input_x (Union[Tensor, str]) - The graph node to attach to. The input supports multiple strings and tensors which are separated by ‘,’.

Examples

>>> class PrintDemo(nn.Cell):
>>>     def __init__(self):
>>>         super(PrintDemo, self).__init__()
>>>         self.print = P.Print()
>>>
>>>     def construct(self, x, y):
>>>         self.print('Print Tensor x and Tensor y:', x, y)
>>>         return x
class mindspore.ops.Pull(*args, **kwargs)[source]

Pulls weight from parameter server.

Inputs:
  • key (Tensor) - The key of the weight.

  • weight (Tensor) - The weight to be updated.

Outputs:

None.

class mindspore.ops.Push(*args, **kwargs)[source]

Pushes the inputs of the corresponding optimizer to parameter server.

Parameters
  • optim_type (string) – The optimizer type. Default: ‘ApplyMomentum’.

  • only_shape_indices (list) – The indices of input of which only shape will be pushed to parameter server. Default: None.

Inputs:
  • optim_inputs (tuple) - The inputs for this kind of optimizer.

  • optim_input_shapes (tuple) - The shapes of the inputs.

Outputs:

Tensor, the key of the weight which needs to be updated.

class mindspore.ops.RNNTLoss(*args, **kwargs)[source]

Computes the RNNTLoss and its gradient with respect to the softmax outputs.

Parameters

blank_label (int) – blank label. Default: 0.

Inputs:
  • acts (Tensor) - Tensor of shape \((B, T, U, V)\). Data type must be float16 or float32.

  • labels (Tensor[int32]) - Tensor of shape \((B, U-1)\).

  • input_lengths (Tensor[int32]) - Tensor of shape \((B,)\).

  • label_lengths (Tensor[int32]) - Tensor of shape \((B,)\).

Outputs:
  • costs (Tensor[int32]) - Tensor of shape \((B,)\).

  • grads (Tensor[int32]) - Has the same shape as acts.

Examples

>>> B, T, U, V = 1, 2, 3, 5
>>> acts = np.random.random((B, T, U, V)).astype(np.float32)
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = P.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
class mindspore.ops.ROIAlign(*args, **kwargs)[source]

Computes Region of Interest (RoI) Align operator.

The operator computes the value of each sampling point by bilinear interpolation from the nearby grid points on the feature map. No quantization is performed on any coordinates involved in the RoI, its bins, or the sampling points. The details of (RoI) Align operator are described in Mask R-CNN.

Parameters
  • pooled_height (int) – The output features’ height.

  • pooled_width (int) – The output features’ width.

  • spatial_scale (float) – A scaling factor that maps the raw image coordinates to the input feature map coordinates. Suppose the height of a RoI is ori_h in the raw image and fea_h in the input feature map, the spatial_scale must be fea_h / ori_h.

  • sample_num (int) – Number of sampling points. Default: 2.

  • roi_end_mode (int) – Number must be 0 or 1. Default: 1.

Inputs:
  • features (Tensor) - The input features, whose shape must be (N, C, H, W).

  • rois (Tensor) - The shape is (rois_n, 5). With data type of float16 or float32. rois_n represents the number of RoI. The size of the second dimension must be 5 and the 5 columns are (image_index, top_left_x, top_left_y, bottom_right_x, bottom_right_y). image_index represents the index of image. top_left_x and top_left_y represent the x, y coordinates of the top left corner of corresponding RoI, respectively. bottom_right_x and bottom_right_y represent the x, y coordinates of the bottom right corner of corresponding RoI, respectively.

Outputs:

Tensor, the shape is (rois_n, C, pooled_height, pooled_width).

Examples

>>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
>>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
>>> roi_align = P.ROIAlign(2, 2, 0.5, 2)
>>> output_tensor = roi_align(input_tensor, rois)
>>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32)
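The description says each sampling point is computed by bilinear interpolation from the nearby grid points. The sketch below shows only that interpolation step on a single 2-D feature plane; it is not the full ROIAlign computation (no RoI binning, no averaging over sample_num points). Clamping coordinates to the plane's border is an assumption of this sketch, and `bilinear_sample` is a hypothetical helper name.

```python
import math


def bilinear_sample(feature, y, x):
    """Bilinearly interpolate a 2-D nested list at the fractional point (y, x)."""
    height, width = len(feature), len(feature[0])
    # Assumption: points outside the plane are clamped to its border.
    y = min(max(y, 0.0), height - 1)
    x = min(max(x, 0.0), width - 1)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, height - 1), min(x0 + 1, width - 1)
    wy, wx = y - y0, x - x0
    # Interpolate along x on the two neighboring rows, then along y.
    top = feature[y0][x0] * (1 - wx) + feature[y0][x1] * wx
    bottom = feature[y1][x0] * (1 - wx) + feature[y1][x1] * wx
    return top * (1 - wy) + bottom * wy


feature_map = [[1.0, 2.0], [3.0, 4.0]]
center = bilinear_sample(feature_map, 0.5, 0.5)
print(center)  # 2.5
```

At a grid point the function returns the stored value unchanged, and at the center of the four values it returns their average.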
class mindspore.ops.RandomCategorical(*args, **kwargs)[source]

Generates random samples from a given categorical distribution tensor.

Parameters

dtype (mindspore.dtype) – The type of output. Its value must be one of mindspore.int16, mindspore.int32 and mindspore.int64. Default: mindspore.int64.

Inputs:
  • logits (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].

  • num_sample (int) - Number of samples to be drawn. Only constant values are allowed.

  • seed (int) - Random seed. Default: 0. Only constant values are allowed.

Outputs:
  • output (Tensor) - The output Tensor with shape [batch_size, num_sample].

Examples

>>> class Net(nn.Cell):
>>>   def __init__(self, num_sample):
>>>     super(Net, self).__init__()
>>>     self.random_categorical = P.RandomCategorical(mindspore.int64)
>>>     self.num_sample = num_sample
>>>   def construct(self, logits, seed=0):
>>>     return self.random_categorical(logits, self.num_sample, seed)
>>>
>>> x = np.random.random((10, 5)).astype(np.float32)
>>> net = Net(8)
>>> output = net(Tensor(x))
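For intuition about the output shape, the sketch below draws class indices per row using only the standard library. It assumes the logits are unnormalized log-probabilities that are turned into probabilities with a softmax, which is the usual meaning of "logits" but is not stated explicitly above. `sample_categorical` is a hypothetical helper, and its random stream is unrelated to the operator's.

```python
import math
import random


def sample_categorical(logits, num_sample, seed=0):
    """Draw num_sample class indices per row from softmax(logits)."""
    rng = random.Random(seed)
    samples = []
    for row in logits:
        peak = max(row)
        # Unnormalized softmax weights; subtracting the max avoids overflow.
        weights = [math.exp(v - peak) for v in row]
        samples.append(rng.choices(range(len(row)), weights=weights, k=num_sample))
    return samples


logits = [[0.0, 1.0, 2.0], [5.0, 0.0, 0.0]]
draws = sample_categorical(logits, num_sample=8)
print(len(draws), len(draws[0]))  # 2 8
```

The result has one row per batch entry and num_sample indices per row, each in the range [0, num_classes).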
class mindspore.ops.RandomChoiceWithMask(*args, **kwargs)[source]

Generates a random sample as index tensor with a mask tensor from a given tensor.

The input must be a tensor of rank not less than 1. If its rank is greater than or equal to 2, the first dimension specifies the number of samples. The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero sample, while the mask tensor denotes which elements in the index tensor are valid.

Parameters
  • count (int) – Number of items expected to get and the number must be greater than 0. Default: 256.

  • seed (int) – Random seed. Default: 0.

  • seed2 (int) – Random seed2. Default: 0.

Inputs:
  • input_x (Tensor[bool]) - The input tensor.

    The input tensor rank must be greater than or equal to 1 and less than or equal to 5.

Outputs:

Two tensors, the first one is the index tensor and the other one is the mask tensor.

  • index (Tensor) - The output shape is 2-D.

  • mask (Tensor) - The output shape is 1-D.

Examples

>>> rnd_choice_mask = P.RandomChoiceWithMask()
>>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool_))
>>> output_y, output_mask = rnd_choice_mask(input_x)
class mindspore.ops.Rank(*args, **kwargs)[source]

Returns the rank of a tensor.

Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor is the number of indices required to uniquely select each element of the tensor.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor. 0-D int32 Tensor representing the rank of input, i.e., \(R\).

Examples

>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> rank = P.Rank()
>>> rank(input_tensor)
class mindspore.ops.ReLU(*args, **kwargs)[source]

Computes ReLU (Rectified Linear Unit) of input tensor element-wise.

It returns \(\max(x,\ 0)\) element-wise.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> relu = P.ReLU()
>>> result = relu(input_x)
[[0.0, 4.0, 0.0], [2.0, 0.0, 9.0]]
class mindspore.ops.ReLU6(*args, **kwargs)[source]

Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensor element-wise.

It returns \(\min(\max(0,x), 6)\) element-wise.

Inputs:
  • input_x (Tensor) - The input tensor, with float16 or float32 data type.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> relu6 = P.ReLU6()
>>> result = relu6(input_x)
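The example above omits its result. Applying the stated formula min(max(0, x), 6) to the same values in plain Python gives a reference for what to expect; this is a sketch of the formula only, not the operator itself.

```python
def relu6(value):
    """Clamp value to the range [0, 6]: min(max(0, x), 6)."""
    return min(max(0.0, value), 6.0)


data = [[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]
clamped = [[relu6(v) for v in row] for row in data]
print(clamped)  # [[0.0, 4.0, 0.0], [2.0, 0.0, 6.0]]
```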
class mindspore.ops.ReLUV2(*args, **kwargs)[source]

Computes ReLU (Rectified Linear Unit) of input tensor element-wise.

It returns \(\max(x,\ 0)\) element-wise.

Inputs:
  • input_x (Tensor) - The input tensor must be a 4-D tensor.

Outputs:
  • output (Tensor) - Has the same type and shape as the input_x.

  • mask (Tensor) - A tensor whose data type must be uint8.

Examples

>>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32)
>>> relu_v2 = P.ReLUV2()
>>> output = relu_v2(input_x)
([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]],
 [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]])
class mindspore.ops.RealDiv(*args, **kwargs)[source]

Divide the first input tensor by the second input tensor in floating-point type element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> input_y = Tensor(np.array([4.0, 5.0, 6.0]), mindspore.float32)
>>> realdiv = P.RealDiv()
>>> realdiv(input_x, input_y)
[0.25, 0.4, 0.5]
class mindspore.ops.Reciprocal(*args, **kwargs)[source]

Returns reciprocal of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> reciprocal = P.Reciprocal()
>>> reciprocal(input_x)
[1.0, 0.5, 0.25]
class mindspore.ops.ReduceAll(*args, **kwargs)[source]

Reduce a dimension of a tensor by the “logical and” of all elements in the dimension.

The dtype of the tensor to be reduced is bool.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[bool]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, the dtype is bool.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the “logical and” of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.array([[True, False], [True, True]]))
>>> op = P.ReduceAll(keep_dims=True)
>>> output = op(input_x, 1)
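The example above does not show its result. A plain-Python sketch of a "logical and" reduction over axis 1 of the same 2x2 input is given below as a reference; `reduce_all_axis1` is a hypothetical helper that handles only this 2-D case.

```python
def reduce_all_axis1(rows, keep_dims=False):
    """Logical AND over axis 1 of a 2-D nested list of bools."""
    reduced = [all(row) for row in rows]
    # keep_dims=True keeps the reduced axis with length 1.
    return [[v] for v in reduced] if keep_dims else reduced


data = [[True, False], [True, True]]
kept = reduce_all_axis1(data, keep_dims=True)
dropped = reduce_all_axis1(data)
print(kept)     # [[False], [True]]
print(dropped)  # [False, True]
```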
class mindspore.ops.ReduceAny(*args, **kwargs)[source]

Reduce a dimension of a tensor by the “logical OR” of all elements in the dimension.

The dtype of the tensor to be reduced is bool.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[bool]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, the dtype is bool.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the “logical or” of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.array([[True, False], [True, True]]))
>>> op = P.ReduceAny(keep_dims=True)
>>> output = op(input_x, 1)
[[True],
 [True]]
class mindspore.ops.ReduceMax(*args, **kwargs)[source]

Reduce a dimension of a tensor by the maximum value in this dimension.

The dtype of the tensor to be reduced is number.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, has the same dtype as the input_x.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the maximum of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = P.ReduceMax(keep_dims=True)
>>> output = op(input_x, 1)
class mindspore.ops.ReduceMean(*args, **kwargs)[source]

Reduce a dimension of a tensor by averaging all elements in the dimension.

The dtype of the tensor to be reduced is number.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, has the same dtype as the input_x.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the mean of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = P.ReduceMean(keep_dims=True)
>>> output = op(input_x, 1)
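The output-shape rules shared by the Reduce* operators can be summarized in a short sketch. It assumes 0-based axes, consistent with the call `op(input_x, 1)` in the example, and also accepts negative axes, which is an assumption of this sketch rather than something stated above. `reduced_shape` is a hypothetical helper.

```python
def reduced_shape(shape, axis=(), keep_dims=False):
    """Return the output shape of a reduction over `axis` (0-based)."""
    rank = len(shape)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    if not axes:
        axes = tuple(range(rank))  # () means reduce every dimension
    axes = {a % rank for a in axes}  # normalize negative axes
    if keep_dims:
        return tuple(1 if i in axes else d for i, d in enumerate(shape))
    return tuple(d for i, d in enumerate(shape) if i not in axes)


print(reduced_shape((3, 4, 5, 6), 1, keep_dims=True))  # (3, 1, 5, 6)
print(reduced_shape((3, 4, 5, 6), 1))                  # (3, 5, 6)
print(reduced_shape((3, 4, 5, 6), (1, 2)))             # (3, 6)
print(reduced_shape((3, 4, 5, 6)))                     # ()
```

For the (3, 4, 5, 6) input used in the examples, reducing axis 1 with keep_dims=True therefore gives an output of shape (3, 1, 5, 6).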
class mindspore.ops.ReduceMin(*args, **kwargs)[source]

Reduce a dimension of a tensor by the minimum value in the dimension.

The dtype of the tensor to be reduced is number.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, has the same dtype as the input_x.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the minimum of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = P.ReduceMin(keep_dims=True)
>>> output = op(input_x, 1)
class mindspore.ops.ReduceOp[source]

Operation options for reduce tensors.

There are four kinds of operation options, “SUM”, “MAX”, “MIN”, and “PROD”.

  • SUM: Take the sum.

  • MAX: Take the maximum.

  • MIN: Take the minimum.

  • PROD: Take the product.

class mindspore.ops.ReduceProd(*args, **kwargs)[source]

Reduce a dimension of a tensor by multiplying all elements in the dimension.

The dtype of the tensor to be reduced is number.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, has the same dtype as the input_x.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the product of all elements in the input tensor.

  • If axis is an int, e.g. 1 (axes are 0-based), and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is a tuple(int), e.g. (1, 2), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = P.ReduceProd(keep_dims=True)
>>> output = op(input_x, 1)
class mindspore.ops.ReduceScatter(*args, **kwargs)[source]

Reduces and scatters tensors from the specified communication group.

Note

The back propagation of the op is not supported yet. Stay tuned for more. The tensors must have the same shape and format in all processes of the collection.

Parameters
  • op (str) – Specifies an operation used for element-wise reductions, such as SUM or MAX (see ReduceOp). Default: ReduceOp.SUM.

  • group (str) – The communication group to work on. Default: “hccl_world_group”.

Raises
  • TypeError – If any of operation and group is not a string.

  • ValueError – If the first dimension of the input cannot be divided by the rank size.

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> from mindspore.ops.operations.comm_ops import ReduceOp
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init()
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.reducescatter = P.ReduceScatter(ReduceOp.SUM, group="nccl_world_group")
>>>
>>>     def construct(self, x):
>>>         return self.reducescatter(x)
>>>
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
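Because the example needs an initialized communication group, its result is not easy to check locally. The sketch below simulates a SUM reduce-scatter in a single process under the usual definition: inputs from all ranks are summed element-wise, then rank r receives the r-th slice along the first dimension. The rank count of 2 and the helper `reduce_scatter_sum` are assumptions made for illustration.

```python
def reduce_scatter_sum(per_rank_inputs):
    """Simulate a SUM reduce-scatter; per_rank_inputs[r] is rank r's 2-D input."""
    rank_size = len(per_rank_inputs)
    rows = len(per_rank_inputs[0])
    if rows % rank_size != 0:
        raise ValueError("first dimension must be divisible by the rank size")
    # Element-wise sum of the inputs from all ranks.
    summed = [
        [sum(values) for values in zip(*row_group)]
        for row_group in zip(*per_rank_inputs)
    ]
    chunk = rows // rank_size
    # Rank r receives the r-th slice along the first dimension.
    return [summed[r * chunk:(r + 1) * chunk] for r in range(rank_size)]


inputs = [[[1.0] * 8 for _ in range(8)] for _ in range(2)]  # 2 ranks, 8x8 ones
outputs = reduce_scatter_sum(inputs)
print(len(outputs[0]), len(outputs[0][0]), outputs[0][0][0])  # 4 8 2.0
```

Under these assumptions each of the two ranks ends up with a 4x8 block filled with 2.0, which also shows why the first dimension must be divisible by the rank size.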
class mindspore.ops.ReduceSum(*args, **kwargs)[source]

Reduce a dimension of a tensor by summing all elements in the dimension.

The dtype of the tensor to be reduced is number.

Parameters

keep_dims (bool) – If true, keep these reduced dimensions and the length is 1. If false, don’t keep these dimensions. Default: False.

Inputs:
  • input_x (Tensor[Number]) - The input tensor.

  • axis (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant value is allowed.

Outputs:

Tensor, has the same dtype as the input_x.

  • If axis is (), and keep_dims is False, the output is a 0-D tensor representing the sum of all elements in the input tensor.

  • If axis is int, set as 2, and keep_dims is False, the shape of output is \((x_1, x_3, ..., x_R)\).

  • If axis is tuple(int), set as (2, 3), and keep_dims is False, the shape of output is \((x_1, x_4, ..., x_R)\).

Examples

>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = P.ReduceSum(keep_dims=True)
>>> output = op(input_x, 1)
class mindspore.ops.Reshape(*args, **kwargs)[source]

Reshapes input tensor with the same values based on a given shape tuple.

Raises

ValueError – If the given shape tuple contains more than one -1; or if the product of its elements is less than or equal to 0 or cannot be divided by the product of the input tensor shape; or if it does not match the input’s array size.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

  • input_shape (tuple[int]) - The input tuple is constructed by multiple integers, i.e., \((y_1, y_2, ..., y_S)\). Only constant value is allowed.

Outputs:

Tensor, the shape of tensor is \((y_1, y_2, ..., y_S)\).

Examples

>>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> reshape = P.Reshape()
>>> output = reshape(input_tensor, (3, 2))
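The conditions under which Reshape raises ValueError can be read as a shape-resolution rule. The sketch below assumes the common convention that a single -1 stands for "whatever size makes the element counts match"; `resolve_shape` is a hypothetical helper and may not reproduce MindSpore's exact checks or error messages.

```python
from math import prod


def resolve_shape(input_shape, target_shape):
    """Resolve a single -1 in target_shape so the element counts match."""
    total = prod(input_shape)
    if target_shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in target_shape if d != -1)
    if -1 in target_shape:
        if known <= 0 or total % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        return tuple(total // known if d == -1 else d for d in target_shape)
    if known != total:
        raise ValueError("target shape does not match the input size")
    return tuple(target_shape)


print(resolve_shape((2, 3), (3, 2)))   # (3, 2)
print(resolve_shape((2, 3), (-1, 2)))  # (3, 2)
```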
class mindspore.ops.ResizeBilinear(*args, **kwargs)[source]

Resizes the image to certain size using bilinear interpolation.

The resizing only affects the last two dimensions, which represent the height and width. The input images can be represented by different data types, but the data types of output images are always float32.

Parameters
  • size (tuple[int]) – A tuple of 2 int elements (new_height, new_width), the new size of the images.

  • align_corners (bool) – If true, rescale input by (new_height - 1) / (height - 1), which exactly aligns the 4 corners of images and resized images. If false, rescale by new_height / height. Default: False.

Inputs:
  • input (Tensor) - Image to be resized. Input images must be a 4-D tensor with shape \((batch, channels, height, width)\), with data type of float32 or float16.

Outputs:

Tensor, resized image. 4-D with shape [batch, channels, new_height, new_width] in float32.

Examples

>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> resize_bilinear = P.ResizeBilinear((5, 5))
>>> result = resize_bilinear(tensor)
>>> assert result.shape == (1, 1, 5, 5)
class mindspore.ops.ResizeNearestNeighbor(*args, **kwargs)[source]

Resizes the input tensor by using nearest neighbor algorithm.

Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest neighbor algorithm selects the value of the nearest point and does not consider the values of neighboring points at all, yielding a piecewise-constant interpolant.

Parameters
  • size (Union[tuple, list]) – The target size. The dimension of size must be 2.

  • align_corners (bool) – Whether the centers of the 4 corner pixels of the input and output tensors are aligned. Default: False.

Inputs:
  • input_x (Tensor) - The input tensor. The shape of the tensor is \((N, C, H, W)\).

Outputs:

Tensor, the shape of the output tensor is \((N, C, NEW\_H, NEW\_W)\).

Examples

>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
>>> resize = P.ResizeNearestNeighbor((2, 2))
>>> output = resize(input_tensor)
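To illustrate the piecewise-constant behavior, the sketch below resizes a single 2x3 plane (the values from the example) to 2x2. It assumes align_corners=False and maps each output index to the input index floor(dst * in_size / out_size), which is one common convention and may differ in detail from the operator's. `resize_nearest` is a hypothetical helper.

```python
def resize_nearest(image, new_h, new_w):
    """Nearest-neighbor resize of a 2-D nested list (align_corners=False)."""
    height, width = len(image), len(image[0])
    # Integer floor division implements floor(dst * in_size / out_size).
    return [
        [image[(i * height) // new_h][(j * width) // new_w] for j in range(new_w)]
        for i in range(new_h)
    ]


plane = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
resized = resize_nearest(plane, 2, 2)
print(resized)  # [[-0.1, 0.3], [0.4, 0.5]]
```

Every output value is copied from exactly one input element; no values are blended.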
class mindspore.ops.ReverseSequence(*args, **kwargs)[source]

Reverses variable length slices.

Parameters
  • seq_dim (int) – The dimension where reversal is performed. Required.

  • batch_dim (int) – The input is sliced in this dimension. Default: 0.

Inputs:
  • x (Tensor) - The input to reverse, supporting all number types including bool.

  • seq_lengths (Tensor) - Must be a 1-D vector with int32 or int64 types.

Outputs:

Reversed tensor with the same shape and data type as input.

Examples

>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> reverse_sequence = P.ReverseSequence(seq_dim=1)
>>> output = reverse_sequence(x, seq_lengths)
[[1 2 3]
 [5 4 6]
 [9 8 7]]
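The result above follows from reversing only the first seq_lengths[i] elements of row i. A plain-Python sketch for the batch_dim=0, seq_dim=1 case (2-D input only) is shown below; `reverse_sequence` here is a hypothetical stand-in, not the operator.

```python
def reverse_sequence(rows, seq_lengths):
    """Reverse the first seq_lengths[i] items of row i (batch_dim=0, seq_dim=1)."""
    return [
        list(reversed(row[:length])) + list(row[length:])
        for row, length in zip(rows, seq_lengths)
    ]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
reversed_rows = reverse_sequence(x, [1, 2, 3])
print(reversed_rows)  # [[1, 2, 3], [5, 4, 6], [9, 8, 7]]
```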
class mindspore.ops.ReverseV2(*args, **kwargs)[source]

Reverses specific dimensions of a tensor.

Parameters

axis (Union[tuple(int), list(int)]) – The indices of the dimensions to reverse.

Inputs:
  • input_x (Tensor) - The target tensor.

Outputs:

Tensor, has the same shape and type as input_x.

Examples

>>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
>>> op = P.ReverseV2(axis=[1])
>>> output = op(input_x)
[[4, 3, 2, 1], [8, 7, 6, 5]]
class mindspore.ops.Rint(*args, **kwargs)[source]

Returns element-wise integer closest to x.

Inputs:
  • input_x (Tensor) - The target tensor, which must be one of the following types: float16, float32.

Outputs:

Tensor, has the same shape and type as input_x.

Examples

>>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
>>> op = P.Rint()
>>> output = op(input_x)
[-2., 0., 2., 2.]
class mindspore.ops.Round(*args, **kwargs)[source]

Rounds a tensor element-wise to the nearest integer, with ties rounded half to even.

Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape and type as the input_x.

Examples

>>> input_x = Tensor(np.array([0.8, 1.5, 2.3, 2.5, -4.5]), mindspore.float32)
>>> round = P.Round()
>>> round(input_x)
[1.0, 2.0, 2.0, 2.0, -4.0]
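Half-to-even rounding sends ties such as 1.5 and 2.5 to the nearest even integer, which is why both map to 2.0 above. Python's built-in round() follows the same tie-breaking rule, so the example values can be reproduced with a short sketch:

```python
values = [0.8, 1.5, 2.3, 2.5, -4.5]
# Python's built-in round() also rounds ties to the nearest even integer.
rounded = [float(round(v)) for v in values]
print(rounded)  # [1.0, 2.0, 2.0, 2.0, -4.0]
```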
class mindspore.ops.Rsqrt(*args, **kwargs)[source]

Computes reciprocal of square root of input tensor element-wise.

Inputs:
  • input_x (Tensor) - The input of Rsqrt. Each element must be a non-negative number.

Outputs:

Tensor, has the same type and shape as input_x.

Examples

>>> input_tensor = Tensor([[4, 4], [9, 9]], mindspore.float32)
>>> rsqrt = P.Rsqrt()
>>> rsqrt(input_tensor)
[[0.5, 0.5], [0.333333, 0.333333]]
class mindspore.ops.SGD(*args, **kwargs)[source]

Computes stochastic gradient descent (optionally with momentum).

Nesterov momentum is based on the formula from On the importance of initialization and momentum in deep learning.

Note

For details, please refer to nn.SGD source code.

Parameters
  • dampening (float) – The dampening for momentum. Default: 0.0.

  • weight_decay (float) – Weight decay (L2 penalty). Default: 0.0.

  • nesterov (bool) – Enable Nesterov momentum. Default: False.

Inputs:
  • parameters (Tensor) - Parameters to be updated. With float16 or float32 data type.

  • gradient (Tensor) - Gradient, with float16 or float32 data type.

  • learning_rate (Tensor) - Learning rate, a scalar tensor with float16 or float32 data type. e.g. Tensor(0.1, mindspore.float32)

  • accum (Tensor) - Accum(velocity) to be updated. With float16 or float32 data type.

  • momentum (Tensor) - Momentum, a scalar tensor with float16 or float32 data type. e.g. Tensor(0.1, mindspore.float32).

  • stat (Tensor) - States to be updated with the same shape as gradient, with float16 or float32 data type.

Outputs:

Tensor, the updated parameters.

Examples

>>> sgd = P.SGD()
>>> parameters = Tensor(np.array([2, -0.5, 1.7, 4]), mindspore.float32)
>>> gradient = Tensor(np.array([1, -1, 0.5, 2]), mindspore.float32)
>>> learning_rate = Tensor(0.01, mindspore.float32)
>>> accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
>>> momentum = Tensor(0.1, mindspore.float32)
>>> stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)
>>> result = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
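The role of dampening, weight_decay and nesterov can be sketched in plain Python. The sketch below assumes the momentum formulation commonly used by SGD optimizers and does not model how the stat input affects the first step; sgd_step is a hypothetical helper, not part of MindSpore, and its results are not claimed to match P.SGD bit for bit.

```python
def sgd_step(param, grad, lr, accum, momentum,
             dampening=0.0, weight_decay=0.0, nesterov=False):
    """One element-wise SGD step (assumed formulation); returns (new_param, new_accum)."""
    new_param, new_accum = [], []
    for p, g, v in zip(param, grad, accum):
        g = g + weight_decay * p                    # L2 penalty folded into the gradient
        v = momentum * v + (1.0 - dampening) * g    # velocity (accum) update
        step = g + momentum * v if nesterov else v  # Nesterov adds a look-ahead term
        new_param.append(p - lr * step)
        new_accum.append(v)
    return new_param, new_accum


params, velocity = sgd_step([2.0, -0.5], [1.0, -1.0], lr=0.01,
                            accum=[0.1, 0.3], momentum=0.1)
print(params, velocity)
```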
class mindspore.ops.SameTypeShape(*args, **kwargs)[source]

Checks whether data type and shape of two tensors are the same.

Raises
  • TypeError – If the data types of two tensors are not the same.

  • ValueError – If the shapes of two tensors are not the same.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

  • input_y (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_S)\).

Outputs:

Tensor, the shape of tensor is \((x_1, x_2, ..., x_R)\), if data type and shape of input_x and input_y are the same.

Examples

>>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> out = P.SameTypeShape()(input_x, input_y)
class mindspore.ops.ScalarCast(*args, **kwargs)[source]

Casts the input scalar to another type.

Inputs:
  • input_x (scalar) - The input scalar. Only constant value is allowed.

  • input_y (mindspore.dtype) - The type to be cast. Only constant value is allowed.

Outputs:

Scalar. The type is the same as the python type corresponding to input_y.

Examples

>>> scalar_cast = P.ScalarCast()
>>> output = scalar_cast(255.0, mindspore.int32)
class mindspore.ops.ScalarSummary(*args, **kwargs)[source]

Outputs a scalar to a protocol buffer through a scalar summary operator.

Inputs:
  • name (str) - The name of the input variable, it must not be an empty string.

  • value (Tensor) - The value of scalar, and the shape of value must be [] or [1].

Examples

>>> class SummaryDemo(nn.Cell):
>>>     def __init__(self,):
>>>         super(SummaryDemo, self).__init__()
>>>         self.summary = P.ScalarSummary()
>>>         self.add = P.TensorAdd()
>>>
>>>     def construct(self, x, y):
>>>         name = "x"
>>>         self.summary(name, x)
>>>         x = self.add(x, y)
>>>         return x
class mindspore.ops.ScalarToArray(*args, **kwargs)[source]

Converts a scalar to a Tensor.

Inputs:
  • input_x (Union[int, float]) - The input is a scalar. Only constant value is allowed.

Outputs:

Tensor. 0-D Tensor and the content is the input.

Examples

>>> op = P.ScalarToArray()
>>> data = 1.0
>>> output = op(data)
class mindspore.ops.ScalarToTensor(*args, **kwargs)[source]

Converts a scalar to a Tensor and converts the data type to the specified type.

Inputs:
  • input_x (Union[int, float]) - The input is a scalar. Only constant value is allowed.

  • dtype (mindspore.dtype) - The target data type. Default: mindspore.float32. Only constant value is allowed.

Outputs:

Tensor. 0-D Tensor and the content is the input.

Examples

>>> op = P.ScalarToTensor()
>>> data = 1
>>> output = op(data, mindspore.float32)
class mindspore.ops.ScatterAdd(*args, **kwargs)[source]

Updates the value of the input tensor through the add operation.

Using given values to update tensor value through the add operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do add operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the add operation with input_x, the data type is the same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
>>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> scatter_add = P.ScatterAdd()
>>> output = scatter_add(input_x, indices, updates)
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
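The shape rule indices_shape + x_shape[1:] means that every entry of indices selects one row of input_x and is paired with one row-shaped slice of updates. A minimal plain-Python sketch of that pairing, assuming a 2-D target stored as nested lists (scatter_add here is a hypothetical helper, not the MindSpore operator):

```python
def scatter_add(x, indices, updates):
    """Add slices of `updates` into rows of the 2-D list `x`, in place."""
    if isinstance(indices, int):
        # `updates` is a single row here; add it to the selected row of x.
        x[indices] = [a + b for a, b in zip(x[indices], updates)]
    else:
        # Walk indices and updates in lockstep, one nesting level at a time.
        for idx, upd in zip(indices, updates):
            scatter_add(x, idx, upd)
    return x


x = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
indices = [[0, 1], [1, 1]]
updates = [[[1.0] * 3, [1.0] * 3], [[1.0] * 3, [1.0] * 3]]
result = scatter_add(x, indices, updates)
print(result)
```

Row 0 is selected once and row 1 three times, which is why the rows end up as 1.0 and 3.0.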
class mindspore.ops.ScatterDiv(*args, **kwargs)[source]

Updates the value of the input tensor through the div operation.

Using given values to update tensor value through the div operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do div operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the div operation with input_x, the data type is the same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_div = P.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
[[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
class mindspore.ops.ScatterMax(*args, **kwargs)[source]

Updates the value of the input tensor through the max operation.

Using given values to update tensor value through the max operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: True.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do max operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the maximum operation with input_x, the data type is the same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
>>> scatter_max = P.ScatterMax()
>>> output = scatter_max(input_x, indices, update)
[[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
class mindspore.ops.ScatterMin(*args, **kwargs)[source]

Updates the value of the input tensor through the min operation.

Using given values to update tensor value through the min operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do min operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor doing the min operation with input_x, the data type is same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> scatter_min = P.ScatterMin()
>>> output = scatter_min(input_x, indices, update)
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
class mindspore.ops.ScatterMul(*args, **kwargs)[source]

Updates the value of the input tensor through the mul operation.

Using given values to update tensor value through the mul operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do mul operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor doing the mul operation with input_x, the data type is same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_mul = P.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
class mindspore.ops.ScatterNd(*args, **kwargs)[source]

Scatters a tensor into a new tensor depending on the specified indices.

Creates an empty tensor and sets its values by scattering the update tensor according to indices.

Inputs:
  • indices (Tensor) - The index of scattering in the new tensor with int32 data type.

  • update (Tensor) - The source Tensor to be scattered.

  • shape (tuple[int]) - Defines the shape of the output tensor; has the same type as indices.

Outputs:

Tensor, the new tensor, has the same type as update and the same shape as shape.

Examples

>>> op = P.ScatterNd()
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([3.2, 1.1]), mindspore.float32)
>>> shape = (3, 3)
>>> output = op(indices, update, shape)
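To make the roles of indices, update and shape concrete, here is a plain-Python sketch restricted to the 2-D case. It assumes that positions not named in indices are zero and does not cover duplicate indices; scatter_nd is a hypothetical helper, not the MindSpore operator.

```python
def scatter_nd(indices, update, shape):
    """Build a zero-filled 2-D list of `shape` and write `update` at `indices`."""
    rows, cols = shape
    out = [[0.0] * cols for _ in range(rows)]
    for (r, c), value in zip(indices, update):
        out[r][c] = value  # each index pair addresses exactly one element
    return out


output = scatter_nd([[0, 1], [1, 1]], [3.2, 1.1], (3, 3))
print(output)
```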
class mindspore.ops.ScatterNdAdd(*args, **kwargs)[source]

Applies sparse addition to individual values or slices in a Tensor.

Using given values to update tensor value through the add operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to do add operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor doing the add operation with input_x, the data type is same as input_x, the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_nd_add = P.ScatterNdAdd()
>>> output = scatter_nd_add(input_x, indices, updates)
[1, 10, 9, 4, 12, 6, 7, 17]
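When the last dimension of indices equals the rank of input_x, each index addresses a single element. A plain-Python sketch of that point-wise case for a 1-D target (scatter_nd_add is a hypothetical helper used only for illustration):

```python
def scatter_nd_add(x, indices, updates):
    """Point-wise sparse add on a flat list; each index is a 1-element list."""
    for (i,), u in zip(indices, updates):
        x[i] += u
    return x


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
result = scatter_nd_add(x, [[2], [4], [1], [7]], [6.0, 7.0, 8.0, 9.0])
print(result)
```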
class mindspore.ops.ScatterNdSub(*args, **kwargs)[source]

Applies sparse subtraction to individual values or slices in a Tensor.

Using given values to update tensor value through the subtraction operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index at which to perform the subtraction operation, whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the subtraction operation with input_x, the data type is the same as input_x, the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_nd_sub = P.ScatterNdSub()
>>> output = scatter_nd_sub(input_x, indices, updates)
[1, -6, -3, 4, -2, 6, 7, -1]
class mindspore.ops.ScatterNdUpdate(*args, **kwargs)[source]

Updates tensor value by using input indices and value.

Using given values to update tensor value, along with the input indices.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: True.

Inputs:
  • input_x (Parameter) - The target tensor, with data type of Parameter.

  • indices (Tensor) - The index of input tensor, with int32 data type.

  • update (Tensor) - The tensor used to update the input tensor, has the same type as input_x.

Outputs:

Tensor, has the same shape and type as input_x.

Examples

>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
class mindspore.ops.ScatterNonAliasingAdd(*args, **kwargs)[source]

Applies sparse addition to input using individual values or slices.

Using given values to update tensor value through the add operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Inputs:
  • input_x (Parameter) - The target parameter. The data type must be float16, float32 or int32.

  • indices (Tensor) - The index to perform the addition operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the addition operation with input_x, the data type is the same as input_x, the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_non_aliasing_add = P.ScatterNonAliasingAdd()
>>> output = scatter_non_aliasing_add(input_x, indices, updates)
[1, 10, 9, 4, 12, 6, 7, 17]
class mindspore.ops.ScatterSub(*args, **kwargs)[source]

Updates the value of the input tensor through the subtraction operation.

Using given values to update tensor value through the subtraction operation, along with the input indices. This operation outputs the input_x after the update is done, which makes it convenient to use the updated value.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: False.

Inputs:
  • input_x (Parameter) - The target parameter.

  • indices (Tensor) - The index to perform the subtraction operation whose data type must be mindspore.int32.

  • updates (Tensor) - The tensor that performs the subtraction operation with input_x, the data type is the same as input_x, the shape is indices_shape + x_shape[1:].

Outputs:

Parameter, the updated input_x.

Examples

>>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 1]]), mindspore.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]), mindspore.float32)
>>> scatter_sub = P.ScatterSub()
>>> output = scatter_sub(input_x, indices, updates)
[[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
class mindspore.ops.ScatterUpdate(*args, **kwargs)[source]

Updates tensor value by using input indices and value.

Using given values to update tensor value, along with the input indices.

The inputs input_x and updates follow the implicit type conversion rules so that their data types are consistent. If they have different data types, the lower-priority data type is converted to the higher-priority data type. A RuntimeError is raised when data type conversion of the Parameter is required.

Parameters

use_locking (bool) – Whether to protect the assignment with a lock. Default: True.

Inputs:
  • input_x (Parameter) - The target tensor, with data type of Parameter.

  • indices (Tensor) - The index of input tensor. With int32 data type.

  • updates (Tensor) - The tensor to update the input tensor, has the same type as input, and updates.shape = indices.shape + input_x.shape[1:].

Outputs:

Tensor, has the same shape and type as input_x.

Examples

>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> updates = Tensor(np_updates, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, updates)
[[2.0, 1.2, 1.0],
 [3.0, 1.2, 1.0]]
class mindspore.ops.Select(*args, **kwargs)[source]

Returns the selected elements, either from input \(x\) or input \(y\), depending on the condition.

If both \(x\) and \(y\) are None, the operation returns the coordinates of the true elements in the condition. The coordinates are returned as a two-dimensional tensor, where the first dimension (rows) represents the number of true elements and the second dimension (columns) represents the coordinates of the true elements. Keep in mind that the shape of the output tensor can vary depending on how many true values are in the input. Indexes are output in row-first order.

If neither is None, \(x\) and \(y\) must have the same shape. If \(x\) and \(y\) are scalars, the conditional tensor must be a scalar. If \(x\) and \(y\) are higher-dimensional vectors, the condition must be a vector whose size matches the first dimension of \(x\), or must have the same shape as \(y\).

The conditional tensor acts as an optional mask, which determines, based on the value of each element, whether the corresponding element or row in the output is selected from \(x\) (if true) or \(y\) (if false).

If condition is a vector and \(x\) and \(y\) are higher-dimensional matrices, it chooses which rows (outer dimension) to copy from \(x\) and \(y\). If condition has the same shape as \(x\) and \(y\), it chooses which elements to copy from \(x\) and \(y\).

Inputs:
  • input_x (Tensor[bool]) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The condition tensor, decides which element is chosen.

  • input_y (Tensor) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The first input tensor.

  • input_z (Tensor) - The shape is \((x_1, x_2, ..., x_N, ..., x_R)\). The second input tensor.

Outputs:

Tensor, has the same shape as input_y. The shape is \((x_1, x_2, ..., x_N, ..., x_R)\).

Examples

>>> select = P.Select()
>>> input_x = Tensor([True, False])
>>> input_y = Tensor([2,3], mindspore.float32)
>>> input_z = Tensor([1,2], mindspore.float32)
>>> select(input_x, input_y, input_z)
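For the common case where the condition has the same shape as both inputs, the selection is element-wise. A plain-Python sketch of that case on flat lists (select here is a hypothetical helper, not the MindSpore operator):

```python
def select(condition, x, y):
    """Pick x[i] where condition[i] is true, otherwise y[i]."""
    return [a if c else b for c, a, b in zip(condition, x, y)]


output = select([True, False], [2.0, 3.0], [1.0, 2.0])
print(output)
```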
class mindspore.ops.Shape(*args, **kwargs)[source]

Returns the shape of input tensor.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

tuple[int], the output tuple is constructed by multiple integers, \((x_1, x_2, ..., x_R)\).

Examples

>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> shape = P.Shape()
>>> output = shape(input_tensor)
class mindspore.ops.Sigmoid(*args, **kwargs)[source]

Sigmoid activation function.

Computes Sigmoid of input element-wise. The Sigmoid function is defined as:

\[\text{sigmoid}(x_i) = \frac{1}{1 + \exp(-x_i)},\]

where \(x_i\) is the element of the input.

Inputs:
  • input_x (Tensor) - The input of Sigmoid, data type must be float16 or float32.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> sigmoid = P.Sigmoid()
>>> sigmoid(input_x)
[0.73105866, 0.880797, 0.9525742, 0.98201376, 0.9933071]
class mindspore.ops.SigmoidCrossEntropyWithLogits(*args, **kwargs)[source]

Uses the given logits to compute sigmoid cross entropy.

Note

Sets input logits as X, input label as Y, output as loss. Then,

\[p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}\]
\[loss_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]\]
Inputs:
  • logits (Tensor) - Input logits.

  • label (Tensor) - Ground truth label.

Outputs:

Tensor, with the same shape and type as input logits.

Examples

>>> logits = Tensor(np.random.randn(2, 3).astype(np.float16))
>>> labels = Tensor(np.random.randn(2, 3).astype(np.float16))
>>> sigmoid = P.SigmoidCrossEntropyWithLogits()
>>> sigmoid(logits, labels)
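The two formulas in the note can be evaluated directly for a single element. The sketch below is a literal plain-Python transcription; it is not numerically hardened (very large positive or negative logits make the logarithm fail), and sigmoid_cross_entropy is a hypothetical helper name.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Literal transcription of the per-element loss formula."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


# With logit 0 the predicted probability is 0.5, so the loss is ln(2)
# for either label.
loss_pos = sigmoid_cross_entropy(0.0, 1.0)
loss_neg = sigmoid_cross_entropy(0.0, 0.0)
print(loss_pos, loss_neg)
```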
class mindspore.ops.Sign(*args, **kwargs)[source]

Performs \(sign\) on a tensor element-wise.

Note

\[sign(x) = \begin{cases} -1, &if\ x < 0 \cr 0, &if\ x = 0 \cr 1, &if\ x > 0\end{cases}\]
Inputs:
  • input_x (Tensor) - The input tensor.

Outputs:

Tensor, has the same shape and type as the input_x.

Examples

>>> input_x = Tensor(np.array([[2.0, 0.0, -1.0]]), mindspore.float32)
>>> sign = P.Sign()
>>> output = sign(input_x)
[[1.0, 0.0, -1.0]]
class mindspore.ops.Sin(*args, **kwargs)[source]

Computes sine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> sin = P.Sin()
>>> input_x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
>>> output = sin(input_x)
class mindspore.ops.Sinh(*args, **kwargs)[source]

Computes hyperbolic sine of input element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> sinh = P.Sinh()
>>> input_x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
>>> output = sinh(input_x)
[0.6604918 0.28367308 0.44337422 0.6604918]
class mindspore.ops.Size(*args, **kwargs)[source]

Returns the number of elements in a tensor.

Returns an int scalar representing the total number of elements in the input tensor.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

int, a scalar representing the total number of elements in input_x, \(size=x_1*x_2*...*x_R\).

Examples

>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> size = P.Size()
>>> output = size(input_tensor)
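The size is the product of the entries of the shape tuple. A quick plain-Python check for the 2 x 2 tensor used above:

```python
import math

shape = (2, 2)            # shape of the example tensor
size = math.prod(shape)   # x_1 * x_2 * ... * x_R
print(size)
```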
class mindspore.ops.Slice(*args, **kwargs)[source]

Slices a tensor in the specified shape.

Parameters
  • x (Tensor) – The target tensor.

  • begin (tuple) – The beginning of the slice. Only constant value is allowed.

  • size (tuple) – The size of the slice. Only constant value is allowed.

Returns

Tensor.

Examples

>>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
>>>                         [[3, 3, 3], [4, 4, 4]],
>>>                         [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
>>> output = P.Slice()(data, (1, 0, 0), (1, 1, 3))
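Each entry of begin is the starting offset along one axis and each entry of size is the number of elements taken from that offset. A plain-Python sketch for a 3-D nested list (slice_3d is a hypothetical helper, not the MindSpore operator):

```python
def slice_3d(data, begin, size):
    """Take size[k] elements starting at begin[k] along each of the 3 axes."""
    b0, b1, b2 = begin
    s0, s1, s2 = size
    return [[row[b2:b2 + s2] for row in plane[b1:b1 + s1]]
            for plane in data[b0:b0 + s0]]


data = [[[1, 1, 1], [2, 2, 2]],
        [[3, 3, 3], [4, 4, 4]],
        [[5, 5, 5], [6, 6, 6]]]
sliced = slice_3d(data, (1, 0, 0), (1, 1, 3))
print(sliced)
```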
class mindspore.ops.SmoothL1Loss(*args, **kwargs)[source]

Computes smooth L1 loss, a robust L1 loss.

SmoothL1Loss is a loss similar to MSELoss but less sensitive to outliers, as described in Fast R-CNN by Ross Girshick.

Note

Sets input prediction as X, input target as Y, output as loss. Then,

\[\text{SmoothL1Loss} = \begin{cases} \frac{0.5 x^{2}}{\text{beta}}, &if \left |x \right | < \text{beta} \cr \left |x \right|-0.5 \text{beta}, &\text{otherwise}\end{cases}\]
Parameters

beta (float) – A parameter used to control the point where the function will change from quadratic to linear. Default: 1.0.

Inputs:
  • prediction (Tensor) - Predict data. Data type must be float16 or float32.

  • target (Tensor) - Ground truth data, with the same type and shape as prediction.

Outputs:

Tensor, with the same type and shape as prediction.

Examples

>>> loss = P.SmoothL1Loss()
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> loss(input_data, target_data)
[0, 0, 0.5]
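The piecewise formula can be checked against the example output with a plain-Python sketch. It assumes that \(x\) in the formula is the element-wise difference between prediction and target, which is consistent with the output shown above; smooth_l1 is a hypothetical helper.

```python
def smooth_l1(prediction, target, beta=1.0):
    """Element-wise smooth L1 loss, assuming x = prediction - target."""
    out = []
    for p, t in zip(prediction, target):
        x = abs(p - t)
        if x < beta:
            out.append(0.5 * x * x / beta)   # quadratic region
        else:
            out.append(x - 0.5 * beta)       # linear region
    return out


loss = smooth_l1([1.0, 2.0, 3.0], [1.0, 2.0, 2.0])
print(loss)
```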
class mindspore.ops.Softmax(*args, **kwargs)[source]

Softmax operation.

Applies the Softmax operation to the input tensor on the specified axis. Suppose a slice along the given axis is \(x\); then for each element \(x_i\), the Softmax function is as follows:

\[\text{output}(x_i) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},\]

where \(N\) is the length of the tensor.

Parameters

axis (Union[int, tuple]) – The axis to perform the Softmax operation. Default: -1.

Inputs:
  • logits (Tensor) - The input of Softmax, with float16 or float32 data type.

Outputs:

Tensor, with the same type and shape as the logits.

Examples

>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> softmax = P.Softmax()
>>> softmax(input_x)
[0.01165623, 0.03168492, 0.08612854, 0.23412167, 0.6364086]
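The formula can be reproduced in plain Python for a 1-D input. Subtracting the maximum before exponentiating is a common numerical-stability trick and does not change the result; whether MindSpore's kernel does the same is not stated here, so treat this as an illustrative sketch only.

```python
import math


def softmax(xs):
    """Softmax over a flat list, shifted by max(xs) for numerical stability."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0, 4.0, 5.0])
print(probs)
```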
class mindspore.ops.SoftmaxCrossEntropyWithLogits(*args, **kwargs)[source]

Gets the softmax cross-entropy value between logits and labels with one-hot encoding.

Note

Sets input logits as X, input label as Y, output as loss. Then,

\[p_{ij} = softmax(X_{ij}) = \frac{\exp(X_{ij})}{\sum_{k = 0}^{C-1}\exp(X_{ik})}\]
\[loss_{i} = -\sum_j{Y_{ij} * ln(p_{ij})}\]
Inputs:
  • logits (Tensor) - Input logits, with shape \((N, C)\). Data type must be float16 or float32.

  • labels (Tensor) - Ground truth labels, with shape \((N, C)\), has the same data type with logits.

Outputs:

Tuple of 2 tensors, the loss shape is (N,), and the dlogits with the same shape as logits.

Examples

>>> logits = Tensor([[2, 4, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
>>> labels = Tensor([[0, 0, 0, 0, 1], [0, 0, 0, 1, 0]], mindspore.float32)
>>> softmax_cross = P.SoftmaxCrossEntropyWithLogits()
>>> loss, backprop = softmax_cross(logits, labels)
([0.5899297, 0.52374405], [[0.02760027, 0.20393994, 0.01015357, 0.20393994, -0.44563377],
[0.08015892, 0.02948882, 0.08015892, -0.4077012, 0.21789455]])
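Both outputs in the example can be reproduced in plain Python. The sketch assumes that the second output (dlogits) equals softmax(logits) minus labels, which agrees with the numbers printed above; softmax_cross_entropy is a hypothetical helper, not the MindSpore operator.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Return (per-row loss, per-element gradient w.r.t. logits)."""
    losses, dlogits = [], []
    for row, target in zip(logits, labels):
        m = max(row)                                # shift for stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        losses.append(-sum(y * math.log(p) for y, p in zip(target, probs)))
        dlogits.append([p - y for p, y in zip(probs, target)])
    return losses, dlogits


loss, backprop = softmax_cross_entropy(
    [[2.0, 4.0, 1.0, 4.0, 5.0], [2.0, 1.0, 2.0, 4.0, 3.0]],
    [[0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 0.0]],
)
print(loss)
print(backprop)
```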
class mindspore.ops.Softplus(*args, **kwargs)[source]

Softplus activation function.

Softplus is a smooth approximation to the ReLU function. The function is shown as follows:

\[\text{output} = \log(1 + \exp(\text{input_x})),\]
Inputs:
  • input_x (Tensor) - The input tensor whose data type must be float.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> softplus = P.Softplus()
>>> softplus(input_x)
[1.3132615, 2.126928, 3.0485873, 4.01815, 5.0067153]
class mindspore.ops.Softsign(*args, **kwargs)[source]

Softsign activation function.

The function is shown as follows:

\[\text{output} = \frac{\text{input_x}}{1 + \left| \text{input_x} \right|},\]
Inputs:
  • input_x (Tensor) - The input tensor whose data type must be float16 or float32.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
>>> softsign = P.Softsign()
>>> softsign(input_x)
[0. -0.5 0.6666667 0.9677419 -0.9677419]
class mindspore.ops.Sort(*args, **kwargs)[source]

Sorts the elements of the input tensor along a given dimension in ascending order by value.

Parameters
  • axis (int) – The dimension to sort along. Default: -1.

  • descending (bool) – Controls the sorting order. If descending is True then the elements are sorted in descending order by value. Default: False.

Inputs:
  • x (Tensor) - The input to sort, with float16 or float32 data type.

Outputs:
  • y1 (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.

  • y2 (Tensor) - The indices of the elements in the original input tensor. Data type is int32.

Examples

>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
>>> sort = P.Sort()
>>> sort(x)
([[1.0, 2.0, 8.0], [3.0, 5.0, 9.0], [4.0, 6.0, 7.0]],
 [[2, 1, 0], [2, 0, 1], [0, 1, 2]])
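The relationship between the two outputs (sorted values and the positions they came from) can be sketched in plain Python for the default axis=-1 on a 2-D input. sort_with_indices is a hypothetical helper; how MindSpore orders equal elements is not covered.

```python
def sort_with_indices(row, descending=False):
    """Return (sorted values, original positions) for one row."""
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=descending)
    return [row[i] for i in order], order


x = [[8.0, 2.0, 1.0], [5.0, 9.0, 3.0], [4.0, 6.0, 7.0]]
pairs = [sort_with_indices(r) for r in x]
values = [v for v, _ in pairs]
indices = [i for _, i in pairs]
print(values)
print(indices)
```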
class mindspore.ops.SpaceToBatch(*args, **kwargs)[source]

Divides spatial dimensions into blocks and combines the block size with the original batch.

This operation will divide spatial dimensions (H, W) into blocks with block_size, the output tensor’s H and W dimension is the corresponding number of blocks after division. The output tensor’s batch dimension is the product of the original batch and the square of block_size. Before division, the spatial dimensions of the input are zero padded according to paddings if necessary.

Parameters
  • block_size (int) – The block size used to divide the spatial dimensions. It must be greater than or equal to 2.

  • paddings (list) – The padding values for the H and W dimensions, containing 2 sublists. Each sublist contains 2 integer values. All values must be greater than or equal to 0. paddings[i] specifies the paddings for the spatial dimension i, which corresponds to the input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_size.

Inputs:
  • input_x (Tensor) - The input tensor. It must be a 4-D tensor.

Outputs:

Tensor, the output tensor with the same data type as input. Assume input shape is \((n, c, h, w)\) with \(block\_size\) and \(paddings\). The shape of the output tensor will be \((n', c', h', w')\), where

\(n' = n*(block\_size*block\_size)\)

\(c' = c\)

\(h' = (h+paddings[0][0]+paddings[0][1])//block\_size\)

\(w' = (w+paddings[1][0]+paddings[1][1])//block\_size\)

Examples

>>> block_size = 2
>>> paddings = [[0, 0], [0, 0]]
>>> space_to_batch = P.SpaceToBatch(block_size, paddings)
>>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
>>> space_to_batch(input_x)
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
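The padding, block division and batch expansion described above can be sketched in NumPy. This is an illustration, not MindSpore's implementation: the helper name space_to_batch is hypothetical, and the ordering of the new batch entries (block offsets first, then the original batch) is an assumption chosen to agree with the example output above.

```python
import numpy as np


def space_to_batch(x, block_size, paddings):
    """NumPy sketch for an NCHW array; the batch ordering is an assumption."""
    # Zero-pad only the spatial dimensions H and W.
    pad_width = [(0, 0), (0, 0), tuple(paddings[0]), tuple(paddings[1])]
    x = np.pad(x, pad_width, mode="constant")
    n, c, h, w = x.shape
    if h % block_size or w % block_size:
        raise ValueError("padded H and W must be divisible by block_size")
    # Split H and W into (number of blocks, offset inside the block).
    x = x.reshape(n, c, h // block_size, block_size, w // block_size, block_size)
    # Move the two in-block offsets in front of the batch dimension.
    x = x.transpose(3, 5, 0, 1, 2, 4)
    return x.reshape(n * block_size * block_size, c, h // block_size, w // block_size)


x = np.array([[[[1, 2], [3, 4]]]], dtype=np.float32)
out = space_to_batch(x, 2, [[0, 0], [0, 0]])
print(out.shape)  # shape follows n' = n * block_size**2, h' = h // block_size
```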
class mindspore.ops.SpaceToBatchND(*args, **kwargs)[source]

Divides spatial dimensions into blocks and combines the block size with the original batch.

This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor’s H and W dimension is the corresponding number of blocks after division. The output tensor’s batch dimension is the product of the original batch and the product of block_shape. Before division, the spatial dimensions of the input are zero padded according to paddings if necessary.

Parameters
  • block_shape (Union[list(int), tuple(int)]) – The block shape used to divide the spatial dimensions, with all values greater than 1. The length of block_shape is M, corresponding to the number of spatial dimensions. M must be 2.

  • paddings (list) – The padding values for the H and W dimensions, containing 2 sublists. Each sublist contains 2 integer values. All values must be greater than or equal to 0. paddings[i] specifies the paddings for the spatial dimension i, which corresponds to the input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_shape[i].

Inputs:
  • input_x (Tensor) - The input tensor. It must be a 4-D tensor.

Outputs:

Tensor, the output tensor with the same data type as input. Assume input shape is \((n, c, h, w)\) with \(block\_shape\) and \(paddings\). The shape of the output tensor will be \((n', c', h', w')\), where

\(n' = n*(block\_shape[0]*block\_shape[1])\)

\(c' = c\)

\(h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]\)

\(w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]\)

Examples

>>> block_shape = [2, 2]
>>> paddings = [[0, 0], [0, 0]]
>>> space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings)
>>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
>>> space_to_batch_nd(input_x)
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
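The output-shape formulas above can be written as a small pure-Python helper. The function name space_to_batch_nd_shape is hypothetical and only mirrors the arithmetic in this entry; it does not move any data.

```python
def space_to_batch_nd_shape(input_shape, block_shape, paddings):
    """Return (n', c', h', w') following the formulas in this entry."""
    n, c, h, w = input_shape
    padded_h = h + paddings[0][0] + paddings[0][1]
    padded_w = w + paddings[1][0] + paddings[1][1]
    # The padded spatial sizes must divide evenly into blocks.
    if padded_h % block_shape[0] or padded_w % block_shape[1]:
        raise ValueError("padded H and W must be divisible by block_shape")
    return (
        n * block_shape[0] * block_shape[1],
        c,
        padded_h // block_shape[0],
        padded_w // block_shape[1],
    )


print(space_to_batch_nd_shape((1, 1, 2, 2), [2, 2], [[0, 0], [0, 0]]))
print(space_to_batch_nd_shape((1, 3, 5, 4), [2, 2], [[1, 0], [0, 0]]))
```

In the second call, one row of padding makes the height 6, so it becomes divisible by the block height of 2.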
class mindspore.ops.SpaceToDepth(*args, **kwargs)[source]

Rearranges blocks of spatial data into depth.

The output tensor’s height dimension is \(height / block\_size\).

The output tensor’s width dimension is \(width / block\_size\).

The depth of output tensor is \(block\_size * block\_size * input\_depth\).

The input tensor’s height and width must be divisible by block_size. The data format is “NCHW”.

Parameters

block_size (int) – The block size used to divide spatial data. It must be >= 2.

Inputs:
  • x (Tensor) - The target tensor.

Outputs:

Tensor, the same data type as x. It must be a 4-D tensor.

Examples

>>> x = Tensor(np.random.rand(1,3,2,2), mindspore.float32)
>>> block_size = 2
>>> op = P.SpaceToDepth(block_size)
>>> output = op(x)
>>> output.asnumpy().shape == (1,12,1,1)
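The shape arithmetic above can be reproduced with a NumPy reshape and transpose. This is a sketch under assumptions: the helper name space_to_depth is hypothetical, and the order in which block offsets are interleaved into the depth dimension is chosen for illustration and may not match the operator's layout.

```python
import numpy as np


def space_to_depth(x, block_size):
    """NumPy sketch for an NCHW array; the channel ordering is an assumption."""
    n, c, h, w = x.shape
    if h % block_size or w % block_size:
        raise ValueError("H and W must be divisible by block_size")
    # Split H and W into (number of blocks, offset inside the block).
    x = x.reshape(n, c, h // block_size, block_size, w // block_size, block_size)
    # Fold the two in-block offsets into the channel (depth) dimension.
    x = x.transpose(0, 3, 5, 1, 2, 4)
    return x.reshape(n, c * block_size * block_size, h // block_size, w // block_size)


x = np.arange(12, dtype=np.float32).reshape(1, 3, 2, 2)
out = space_to_depth(x, 2)
print(out.shape)
```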
class mindspore.ops.SparseApplyAdagrad(*args, **kwargs)[source]

Updates relevant entries according to the adagrad scheme.

\[accum += grad * grad\]
\[var -= lr * grad * (1 / sqrt(accum))\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the relatively highest-priority data type. A RuntimeError exception will be thrown when data type conversion of a Parameter is required.

Parameters
  • lr (float) – Learning rate.

  • update_slots (bool) – If True, accum will be updated. Default: True.

  • use_locking (bool) – If true, the var and accumulation tensors will be protected from being updated. Default: False.

Inputs:
  • var (Parameter) - Variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - Accumulation to be updated. The shape and data type must be the same as var.

  • grad (Tensor) - Gradient. The shape must be the same as var’s shape except the first dimension. The gradient has the same data type as var.

  • indices (Tensor) - A vector of indices into the first dimension of var and accum. The shape of indices must be the same as grad in first dimension, the type must be int32.

Outputs:

Tuple of 2 tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_adagrad = P.SparseApplyAdagrad(lr=1e-8)
>>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> indices = Tensor([0, 1, 2], mstype.int32)
>>> result = net(grad, indices)
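The update rule above, applied only to the rows named by indices, can be sketched in NumPy. This is an illustration under assumptions, not MindSpore's kernel: the helper name sparse_apply_adagrad is hypothetical, arrays are updated in place, and duplicate indices are simply applied one after another.

```python
import numpy as np


def sparse_apply_adagrad(var, accum, grad, indices, lr, update_slots=True):
    """Apply the adagrad rule to the rows of var/accum selected by indices."""
    for row, idx in enumerate(indices):
        if update_slots:
            accum[idx] += grad[row] * grad[row]
        var[idx] -= lr * grad[row] / np.sqrt(accum[idx])
    return var, accum


var = np.ones((3, 2), dtype=np.float32)
accum = np.ones((3, 2), dtype=np.float32)
grad = np.full((2, 2), 2.0, dtype=np.float32)  # one gradient row per index
indices = np.array([0, 2], dtype=np.int32)
sparse_apply_adagrad(var, accum, grad, indices, lr=0.1)
print(var)
print(accum)
```

Row 1 is not listed in indices, so it keeps its initial values in both var and accum.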
class mindspore.ops.SparseApplyAdagradV2(*args, **kwargs)[source]

Updates relevant entries according to the adagrad scheme.

\[accum += grad * grad\]
\[var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the relatively highest-priority data type. A RuntimeError exception will be thrown when data type conversion of a Parameter is required.

Parameters
  • lr (float) – Learning rate.

  • epsilon (float) – A small value added for numerical stability.

  • use_locking (bool) – If True, the var and accum tensors will be protected from being updated. Default: False.

  • update_slots (bool) – If True, the computation logic will be different to False. Default: True.

Inputs:
  • var (Parameter) - Variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - Accumulation to be updated. The shape and data type must be the same as var.

  • grad (Tensor) - Gradient. The shape must be the same as var’s shape except the first dimension. The gradient has the same data type as var.

  • indices (Tensor) - A vector of indices into the first dimension of var and accum. The shape of indices must be the same as grad in first dimension, the type must be int32.

Outputs:

Tuple of 2 tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
>>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
>>>
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> indices = Tensor([0, 1, 2], mstype.int32)
>>> result = net(grad, indices)
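The only difference between the V2 formula and the SparseApplyAdagrad formula is the epsilon term in the denominator. The scalar sketch below compares one step of each rule for a very small gradient; the function names are hypothetical and the values are chosen only for illustration.

```python
import math


def adagrad_step(var, accum, grad, lr):
    """One scalar step of: accum += grad^2; var -= lr * grad / sqrt(accum)."""
    accum += grad * grad
    var -= lr * grad * (1.0 / math.sqrt(accum))
    return var, accum


def adagrad_v2_step(var, accum, grad, lr, epsilon):
    """Same step with epsilon added to the denominator."""
    accum += grad * grad
    var -= lr * grad / (math.sqrt(accum) + epsilon)
    return var, accum


v1, _ = adagrad_step(1.0, 0.0, 1e-4, lr=0.1)
v2, _ = adagrad_v2_step(1.0, 0.0, 1e-4, lr=0.1, epsilon=1e-2)
print(v1, v2)
```

With accum starting at zero, the first rule takes a full step of size lr regardless of how small the gradient is, while epsilon damps the V2 step.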
class mindspore.ops.SparseApplyFtrl(*args, **kwargs)[source]

Updates relevant entries according to the FTRL-proximal scheme.

All inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the relatively highest-priority data type. A RuntimeError exception will be thrown when data type conversion of a Parameter is required.

Parameters
  • lr (float) – The learning rate value, must be positive.

  • l1 (float) – l1 regularization strength, must be greater than or equal to zero.

  • l2 (float) – l2 regularization strength, must be greater than or equal to zero.

  • lr_power (float) – Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.

  • use_locking (bool) – Use locks for the update operation if True. Default: False.

Inputs:
  • var (Parameter) - The variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - The accumulation to be updated, must be same data type and shape as var.

  • linear (Parameter) - The linear coefficient to be updated, must be the same data type and shape as var.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

  • indices (Tensor) - A vector of indices in the first dimension of var and accum. The shape of indices must be the same as grad in the first dimension. The type must be int32.

Outputs:
  • var (Tensor) - Tensor, has the same shape and data type as var.

  • accum (Tensor) - Tensor, has the same shape and data type as accum.

  • linear (Tensor) - Tensor, has the same shape and data type as linear.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlNet(nn.Cell):
>>>     def __init__(self):
>>>         super(SparseApplyFtrlNet, self).__init__()
>>>         self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
>>>
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
>>>         return out
>>>
>>> net = SparseApplyFtrlNet()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones([3]), mindspore.int32)
>>> output = net(grad, indices)
class mindspore.ops.SparseApplyFtrlV2(*args, **kwargs)[source]

Updates relevant entries according to the FTRL-proximal scheme.

All inputs except indices comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the relatively highest-priority data type. A RuntimeError exception will be thrown when data type conversion of a Parameter is required.

Parameters
  • lr (float) – The learning rate value, must be positive.

  • l1 (float) – l1 regularization strength, must be greater than or equal to zero.

  • l2 (float) – l2 regularization strength, must be greater than or equal to zero.

  • l2_shrinkage (float) – L2 shrinkage regularization.

  • lr_power (float) – Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.

  • use_locking (bool) – If True, the var and accumulation tensors will be protected from being updated. Default: False.

Inputs:
  • var (Parameter) - The variable to be updated. The data type must be float16 or float32.

  • accum (Parameter) - The accumulation to be updated, must be same data type and shape as var.

  • linear (Parameter) - The linear coefficient to be updated, must be the same data type and shape as var.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

  • indices (Tensor) - A vector of indices in the first dimension of var and accum. The shape of indices must be the same as grad in the first dimension. The type must be int32.

Outputs:

Tuple of 3 Tensor, the updated parameters.

  • var (Tensor) - Tensor, has the same shape and data type as var.

  • accum (Tensor) - Tensor, has the same shape and data type as accum.

  • linear (Tensor) - Tensor, has the same shape and data type as linear.

Examples

>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlV2Net(nn.Cell):
>>>     def __init__(self):
>>>         super(SparseApplyFtrlV2Net, self).__init__()
>>>         self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0,
                                                            l2_shrinkage=0.0, lr_power=-0.5)
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
>>>
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
>>>         return out
>>>
>>> net = SparseApplyFtrlV2Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones([3]), mindspore.int32)
>>> output = net(grad, indices)
class mindspore.ops.SparseApplyProximalAdagrad(*args, **kwargs)[source]

Updates relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad, an additional index tensor is input.

\[accum += grad * grad\]
\[\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}\]
\[var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)\]

Inputs of var, accum and grad comply with the implicit type conversion rules to make the data types consistent. If they have different data types, the lower-priority data type will be converted to the relatively highest-priority data type. A RuntimeError exception will be thrown when data type conversion of a Parameter is required.

Parameters

use_locking (bool) – If true, the var and accum tensors will be protected from being updated. Default: False.

Inputs:
  • var (Parameter) - Variable tensor to be updated. The data type must be float16 or float32.

  • accum (Parameter) - Variable tensor to be updated, has the same dtype as var.

  • lr (Union[Number, Tensor]) - The learning rate value, must be a float number or a scalar tensor with float16 or float32 data type.

  • l1 (Union[Number, Tensor]) - l1 regularization strength, must be a float number or a scalar tensor with float16 or float32 data type.

  • l2 (Union[Number, Tensor]) - l2 regularization strength, must be a float number or a scalar tensor with float16 or float32 data type.

  • grad (Tensor) - A tensor of the same type as var, for the gradient.

  • indices (Tensor) - A vector of indices in the first dimension of var and accum.

Outputs:

Tuple of 2 tensors, the updated parameters.

  • var (Tensor) - The same shape and data type as var.

  • accum (Tensor) - The same shape and data type as accum.

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
>>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>>         self.lr = 0.01
>>>         self.l1 = 0.0
>>>         self.l2 = 0.0
>>>     def construct(self, grad, indices):
>>>         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
                                                     self.l2, grad, indices)
>>>         return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones((3,), np.int32))
>>> output = net(grad, indices)
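The three formulas above can be followed step by step in NumPy. This is a sketch under assumptions rather than the operator's implementation: the helper name is hypothetical, arrays are updated in place, and duplicate indices are applied sequentially.

```python
import numpy as np


def sparse_apply_proximal_adagrad(var, accum, lr, l1, l2, grad, indices):
    """Apply the proximal adagrad formulas to the rows selected by indices."""
    for row, idx in enumerate(indices):
        accum[idx] += grad[row] * grad[row]
        prox_v = var[idx] - lr * grad[row] / np.sqrt(accum[idx])
        # Soft-threshold by lr * l1, then shrink by 1 / (1 + lr * l2).
        shrunk = np.maximum(np.abs(prox_v) - lr * l1, 0.0)
        var[idx] = np.sign(prox_v) / (1.0 + lr * l2) * shrunk
    return var, accum


grad = np.ones((1, 2))
indices = np.array([1])
plain_var, plain_accum = sparse_apply_proximal_adagrad(
    np.ones((2, 2)), np.ones((2, 2)), 0.1, 0.0, 0.0, grad, indices)
shrunk_var, _ = sparse_apply_proximal_adagrad(
    np.ones((2, 2)), np.ones((2, 2)), 0.1, 100.0, 0.0, grad, indices)
print(plain_var)
print(shrunk_var)
```

With l1 = l2 = 0 the rule reduces to a plain adagrad step; with a large l1 the soft threshold drives the selected row to zero.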
class mindspore.ops.SparseGatherV2(*args, **kwargs)[source]

Returns a slice of input tensor based on the specified indices and axis.

Inputs:
  • input_params (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). The original Tensor.

  • input_indices (Tensor) - The shape of tensor is \((y_1, y_2, ..., y_S)\). Specifies the indices of elements of the original Tensor, must be in the range [0, input_params.shape[axis]).

  • axis (int) - Specifies the dimension index to gather indices.

Outputs:

Tensor, the shape of tensor is \((z_1, z_2, ..., z_N)\).

Examples

>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = P.SparseGatherV2()(input_params, input_indices, axis)
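The gather semantics described above correspond to NumPy's take along an axis. The sketch below reuses the example inputs to show which elements are selected; it illustrates the indexing only and says nothing about how the operator handles gradients.

```python
import numpy as np

input_params = np.array(
    [[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]], dtype=np.float32)
input_indices = np.array([1, 2], dtype=np.int32)

# Pick columns 1 and 2 from every row (axis=1).
out = np.take(input_params, input_indices, axis=1)
print(out)
```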
class mindspore.ops.SparseSoftmaxCrossEntropyWithLogits(*args, **kwargs)[source]

Computes the softmax cross-entropy value between logits and sparse encoding labels.

Note

Sets input logits as X, input label as Y, output as loss. Then,

\[p_{ij} = softmax(X_{ij}) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}\]
\[loss_{ij} = \begin{cases} -\ln(p_{ij}), &j = y_i \cr -\ln(1 - p_{ij}), & j \neq y_i \end{cases}\]
\[loss = \sum_{ij} loss_{ij}\]
Parameters

is_grad (bool) – If true, this operation returns the computed gradient. Default: False.

Inputs:
  • logits (Tensor) - Input logits, with shape \((N, C)\). Data type must be float16 or float32.

  • labels (Tensor) - Ground truth labels, with shape \((N)\). Data type must be int32 or int64.

Outputs:

Tensor, if is_grad is False, the output tensor is the value of loss which is a scalar tensor; if is_grad is True, the output tensor is the gradient of input with the same shape as logits.

Examples

Please refer to the usage in nn.SoftmaxCrossEntropyWithLogits source code.

class mindspore.ops.SparseToDense(*args, **kwargs)[source]

Converts a sparse representation into a dense tensor.

Inputs:
  • indices (Tensor) - The indices of sparse representation.

  • values (Tensor) - Values corresponding to each row of indices.

  • dense_shape (tuple) - An int tuple which specifies the shape of dense tensor.

Returns

Tensor, the shape of tensor is dense_shape.

Examples

>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=mindspore.float32)
>>> dense_shape = (3, 4)
>>> out = P.SparseToDense()(indices, values, dense_shape)
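The conversion can be pictured as scattering each value to the coordinate given by the matching row of indices. The NumPy sketch below assumes that positions not listed in indices are filled with zeros; the helper name sparse_to_dense is hypothetical.

```python
import numpy as np


def sparse_to_dense(indices, values, dense_shape):
    """Scatter values into a zero-filled array of shape dense_shape."""
    dense = np.zeros(dense_shape, dtype=values.dtype)
    # Each row of indices is one coordinate; transpose to index per dimension.
    dense[tuple(indices.T)] = values
    return dense


indices = np.array([[0, 1], [1, 2]])
values = np.array([1, 2], dtype=np.float32)
out = sparse_to_dense(indices, values, (3, 4))
print(out)
```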
class mindspore.ops.Split(*args, **kwargs)[source]

Splits the input tensor into output_num tensors along the given axis.

Parameters
  • axis (int) – Index of the split position. Default: 0.

  • output_num (int) – The number of output tensors. Default: 1.

Raises

ValueError – If axis is out of the range [-len(input_x.shape), len(input_x.shape)), or if output_num is less than or equal to 0, or if the dimension to split cannot be evenly divided by output_num.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

tuple[Tensor], the shape of each output tensor is the same, which is \((y_1, y_2, ..., y_S)\).

Examples

>>> split = P.Split(1, 2)
>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
>>> output = split(x)
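The example above splits a (2, 4) tensor into two (2, 2) pieces along axis 1. The equivalent NumPy call is shown below as an illustration of the expected pieces; it is not a statement about how the operator is implemented.

```python
import numpy as np

x = np.array([[1, 1, 1, 1], [2, 2, 2, 2]])

# Two equal pieces along axis 1, mirroring Split(1, 2).
output = np.split(x, 2, axis=1)
for piece in output:
    print(piece.shape)
```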
class mindspore.ops.Sqrt(*args, **kwargs)[source]

Returns square root of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor whose dtype is number.

Outputs:

Tensor, has the same shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 4.0, 9.0]), mindspore.float32)
>>> sqrt = P.Sqrt()
>>> sqrt(input_x)
[1.0, 2.0, 3.0]
class mindspore.ops.Square(*args, **kwargs)[source]

Returns square of a tensor element-wise.

Inputs:
  • input_x (Tensor) - The input tensor whose dtype is number.

Outputs:

Tensor, has the same shape and dtype as the input_x.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> square = P.Square()
>>> square(input_x)
[1.0, 4.0, 9.0]
class mindspore.ops.SquareSumAll(*args, **kwargs)[source]

Returns the sum of the squares of all elements of each input tensor.

Inputs:
  • input_x1 (Tensor) - The input tensor. The data type must be float16 or float32.

  • input_x2 (Tensor) - The input tensor has the same type and shape as the input_x1.

Note

SquareSumAll only supports float16 and float32 data type.

Outputs:
  • output_y1 (Tensor) - The same type as the input_x1.

  • output_y2 (Tensor) - The same type as the input_x1.

Examples

>>> input_x1 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> input_x2 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> square_sum_all = P.SquareSumAll()
>>> square_sum_all(input_x1, input_x2)
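The computation can be sketched in NumPy as follows. This assumes that each output is a single value obtained by squaring every element of the corresponding input and summing the results; the helper name square_sum_all is illustrative.

```python
import numpy as np


def square_sum_all(x1, x2):
    """Return the sum of squared elements for each of two same-shape arrays."""
    if x1.shape != x2.shape:
        raise ValueError("inputs must have the same shape")
    return np.sum(np.square(x1)), np.sum(np.square(x2))


y1, y2 = square_sum_all(
    np.array([1.0, 2.0, 3.0], dtype=np.float32),
    np.array([1.0, 1.0, 1.0], dtype=np.float32),
)
print(y1, y2)
```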
class mindspore.ops.SquaredDifference(*args, **kwargs)[source]

Subtracts the second input tensor from the first input tensor element-wise and returns square of it.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is float16, float32, int32 or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is float16, float32, int32 or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> input_y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32)
>>> squared_difference = P.SquaredDifference()
>>> squared_difference(input_x, input_y)
[1.0, 4.0, 9.0]
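The element-wise rule and the broadcasting behaviour described above can be illustrated in NumPy. This is a sketch; NumPy's type promotion is not guaranteed to match MindSpore's implicit conversion rules in every case.

```python
import numpy as np


def squared_difference(x, y):
    """Compute (x - y)^2 element-wise with NumPy broadcasting."""
    diff = np.subtract(x, y)
    return diff * diff


out = squared_difference(
    np.array([1.0, 2.0, 3.0], dtype=np.float32),
    np.array([2.0, 4.0, 6.0], dtype=np.float32),
)
# A tensor combined with a scalar constant broadcasts the scalar.
broadcast = squared_difference(np.array([[1.0], [2.0]]), 3.0)
print(out)
print(broadcast)
```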
class mindspore.ops.Squeeze(*args, **kwargs)[source]

Returns a tensor with the same type as the input, with dimensions of size 1 removed according to axis.

Note

The dimension index starts at 0 and must be in the range [-input.dim(), input.dim()).

Raises

ValueError – If the corresponding dimension of the specified axis is not equal to 1.

Parameters

axis (Union[int, tuple(int)]) – Specifies the dimension indexes of the shape to be removed. If axis is not specified, all dimensions that are equal to 1 are removed. If specified, it must be int32 or int64. Default: (), an empty tuple.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor, the shape of tensor is \((x_1, x_2, ..., x_S)\).

Examples

>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> squeeze = P.Squeeze(2)
>>> output = squeeze(input_tensor)
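NumPy's squeeze follows the same contract and can be used to preview the resulting shape, including the error raised when the selected dimension is not 1. This is an analogy for illustration; the exact error message of the MindSpore operator will differ.

```python
import numpy as np

input_tensor = np.ones((3, 2, 1), dtype=np.float32)

# Dimension 2 has size 1, so it can be removed.
output = np.squeeze(input_tensor, axis=2)
print(output.shape)

# Dimension 0 has size 3, so squeezing it is rejected.
try:
    np.squeeze(input_tensor, axis=0)
except ValueError as err:
    message = str(err)
    print("ValueError:", message)
```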
class mindspore.ops.StandardLaplace(*args, **kwargs)[source]

Generates random numbers according to the Laplace random number distribution (mean=0, lambda=1). It is defined as:

\[\text{f}(x;0,1) = \frac{1}{2}\exp(-|x|),\]
Parameters
  • seed (int) – Random seed. Default: 0.

  • seed2 (int) – Random seed2. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

Outputs:

Tensor. The shape that the input ‘shape’ denotes. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> stdlaplace = P.StandardLaplace(seed=2)
>>> output = stdlaplace(shape)
class mindspore.ops.StandardNormal(*args, **kwargs)[source]

Generates random numbers according to the standard Normal (or Gaussian) random number distribution.

Parameters
  • seed (int) – Random seed, must be non-negative. Default: 0.

  • seed2 (int) – Random seed2, must be non-negative. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

Outputs:

Tensor. The shape is the same as the input shape. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> stdnormal = P.StandardNormal(seed=2)
>>> output = stdnormal(shape)
class mindspore.ops.StridedSlice(*args, **kwargs)[source]

Extracts a strided slice of a tensor.

Given an input tensor, this operation extracts a fragment of size (end-begin)/stride from it. Starting from the beginning position, the fragment continues adding stride to the index until all dimensions are not less than the ending position.

Note

The stride may be negative value, which causes reverse slicing. The shape of begin, end and strides must be the same.

Parameters
  • begin_mask (int) – Starting index of the slice. Default: 0.

  • end_mask (int) – Ending index of the slice. Default: 0.

  • ellipsis_mask (int) – An int mask. Default: 0.

  • new_axis_mask (int) – An int mask. Default: 0.

  • shrink_axis_mask (int) – An int mask. Default: 0.

Inputs:
  • input_x (Tensor) - The input Tensor.

  • begin (tuple[int]) - A tuple which represents the location where to start. Only constant value is allowed.

  • end (tuple[int]) - A tuple which represents the maximum location where to end. Only constant value is allowed.

  • strides (tuple[int]) - A tuple which represents the stride that is continuously added before reaching the maximum location. Only constant value is allowed.

Outputs:

Tensor. The output is explained by following example.

  • In the 0th dimension, begin is 1, end is 2, and strides is 1, because \(1+1=2\geq2\), the interval is \([1,2)\). Thus, return the element with \(index = 1\) in 0th dimension, i.e., [[3, 3, 3], [4, 4, 4]].

  • In the 1st dimension, similarly, the interval is \([0,1)\). Based on the return value of the 0th dimension, return the element with \(index = 0\), i.e., [3, 3, 3].

  • In the 2nd dimension, similarly, the interval is \([0,3)\). Based on the return value of the 1st dimension, return the element with \(index = 0,1,2\), i.e., [3, 3, 3].

  • Finally, the output is [[[3, 3, 3]]] with shape (1, 1, 3), because each sliced dimension is kept rather than dropped.

Examples

>>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
>>>                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
>>> slice = P.StridedSlice()
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
>>> output.shape
(1, 1, 3)
>>> output
[[[3, 3, 3]]]
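With all masks left at 0, the begin, end and strides tuples behave like one Python slice per dimension. The NumPy sketch below reproduces the example on that assumption; the helper name strided_slice is hypothetical and the mask parameters are not modelled.

```python
import numpy as np


def strided_slice(x, begin, end, strides):
    """Build one slice(begin, end, stride) per dimension and apply them."""
    index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
    return x[index]


input_x = np.array(
    [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]],
    dtype=np.float32,
)
output = strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
# A negative stride walks the dimension in reverse.
reverse = strided_slice(input_x, (2, 0, 0), (0, 1, 3), (-1, 1, 1))
print(output.shape)
print(reverse)
```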
class mindspore.ops.Sub(*args, **kwargs)[source]

Subtracts the second input tensor from the first input tensor element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32)
>>> sub = P.Sub()
>>> sub(input_x, input_y)
[-3, -3, -3]
class mindspore.ops.TBERegOp(op_name)[source]

Class for TBE op info register.

async_flag(async_flag)[source]

Define the calculation efficiency of the operator, whether the asynchronous calculation is supported.

Parameters

async_flag (bool) – Value of async flag. Default: False.

attr(name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs)[source]

Register TBE op attribute information.

Parameters
  • name (str) – Name of the attribute. Default: None.

  • param_type (str) – Param type of the attribute. Default: None.

  • value_type (str) – Type of the attribute. Default: None.

  • value (str) – Value of the attribute. Default: None.

  • default_value (str) – Default value of attribute. Default: None.

  • kwargs (dict) – Other information of the attribute.

binfile_name(binfile_name)[source]

Set the binary file name of the operator, it is optional.

Parameters

binfile_name (str) – The binary file name of the operator.

compute_cost(compute_cost)[source]

Define the calculation efficiency of operator, which refers to the value of the cost model in the tiling module.

Parameters

compute_cost (int) – Value of compute cost. Default: 10.

dynamic_format(dynamic_format)[source]

Whether the operator supports dynamic selection of format and dtype or not.

Parameters

dynamic_format (bool) – Value of dynamic format. Default: False.

input(index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs)[source]

Register TBE op input information.

Parameters
  • index (int) – Order of the input. Default: None.

  • name (str) – Name of the input. Default: None.

  • need_compile (bool) – Whether the input needs to be compiled or not. Default: None.

  • param_type (str) – Type of the input. Default: None.

  • shape (str) – Shape of the input. Default: None.

  • kwargs (dict) – Other information of the input.

kernel_name(kernel_name)[source]

The name of operator kernel.

Parameters

kernel_name (str) – Name of operator kernel.

op_pattern(pattern=None)[source]

The behavior type of operator, such as broadcast, reduce and so on.

Parameters

pattern (str) – Value of op pattern.

output(index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs)[source]

Register TBE op output information.

Parameters
  • index (int) – Order of the output. Default: None.

  • name (str) – Name of the output. Default: None.

  • need_compile (bool) – Whether the output needs to be compiled or not. Default: None.

  • param_type (str) – Type of the output. Default: None.

  • shape (str) – Shape of the output. Default: None.

  • kwargs (dict) – Other information of the output.

partial_flag(partial_flag)[source]

Define the calculation efficiency of operator, whether the partial calculation is supported.

Parameters

partial_flag (bool) – Value of partial flag. Default: True.

reshape_type(reshape_type)[source]

Reshape type of operator.

Parameters

reshape_type (str) – Value of reshape type.

class mindspore.ops.Tan(*args, **kwargs)[source]

Computes tangent of input_x element-wise.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). Data type must be float16, float32 or int32.

Outputs:

Tensor, has the same shape as input_x.

Examples

>>> tan = P.Tan()
>>> input_x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32)
>>> output = tan(input_x)
class mindspore.ops.Tanh(*args, **kwargs)[source]

Tanh activation function.

Computes hyperbolic tangent of input element-wise. The Tanh function is defined as:

\[tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},\]

where \(x_i\) is an element of the input Tensor.

Inputs:
  • input_x (Tensor) - The input of Tanh.

Outputs:

Tensor, with the same type and shape as the input_x.

Examples

>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> tanh = P.Tanh()
>>> tanh(input_x)
[0.7615941, 0.9640276, 0.9950548, 0.9993293, 0.99990916]
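
As a quick check of the definition above, the following standard-library sketch evaluates both forms of the formula and compares them with math.tanh. The helper names tanh_from_exp and tanh_from_exp2 are illustrative only and are not part of MindSpore.

```python
import math


def tanh_from_exp(x):
    # First form: (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    return (math.exp(x) - math.exp(-x)) / (math.exp(x) + math.exp(-x))


def tanh_from_exp2(x):
    # Second form: (exp(2x) - 1) / (exp(2x) + 1)
    return (math.exp(2 * x) - 1) / (math.exp(2 * x) + 1)


values = [1.0, 2.0, 3.0, 4.0, 5.0]
first_form = [tanh_from_exp(v) for v in values]
second_form = [tanh_from_exp2(v) for v in values]
reference = [math.tanh(v) for v in values]
```

Both forms should agree with the library function up to floating-point rounding.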
class mindspore.ops.TensorAdd(*args, **kwargs)[source]

Adds two input tensors element-wise.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> add = P.TensorAdd()
>>> input_x = Tensor(np.array([1,2,3]).astype(np.float32))
>>> input_y = Tensor(np.array([4,5,6]).astype(np.float32))
>>> add(input_x, input_y)
[5,7,9]
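
The output shape described above depends on broadcasting. The sketch below computes a broadcast shape in plain Python, assuming NumPy-style rules (shapes are aligned from the right, and a dimension of 1 stretches to match the other input). The function name broadcast_shape is hypothetical and is not a MindSpore API.

```python
def broadcast_shape(shape_x, shape_y):
    """Return the broadcast shape of two shapes under NumPy-style rules (an assumption)."""
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    rank = max(len(shape_x), len(shape_y))
    padded_x = (1,) * (rank - len(shape_x)) + tuple(shape_x)
    padded_y = (1,) * (rank - len(shape_y)) + tuple(shape_y)

    result = []
    for dim_x, dim_y in zip(padded_x, padded_y):
        if dim_x == dim_y or dim_y == 1:
            result.append(dim_x)
        elif dim_x == 1:
            result.append(dim_y)
        else:
            raise ValueError(f"shapes {shape_x} and {shape_y} cannot be broadcast")
    return tuple(result)


same_shape = broadcast_shape((3,), (3,))          # the case in the example above
matrix_plus_row = broadcast_shape((2, 3), (3,))
column_plus_row = broadcast_shape((4, 1), (1, 5))
```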
class mindspore.ops.TensorScatterUpdate(*args, **kwargs)[source]

Updates tensor value using given values, along with the input indices.

Inputs:
  • input_x (Tensor) - The target tensor. The dimension of input_x must be equal to indices.shape[-1].

  • indices (Tensor) - The index of input tensor whose data type is int32.

  • update (Tensor) - The tensor to update the input tensor, has the same type as input, and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

Outputs:

Tensor, has the same shape and type as input_x.

Examples

>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
[[1.0, 0.3, 3.6],
 [0.4, 2.2, -3.2]]
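
The update rule can be sketched in plain Python for the simplest case, where indices.shape[-1] equals the rank of input_x so that each index row addresses a single element. This is an illustration on nested lists under that assumption, not the operator's implementation; the function name scatter_update_elements is hypothetical.

```python
import copy


def scatter_update_elements(input_x, indices, update):
    """Return a copy of the nested list input_x with one element replaced per index row."""
    output = copy.deepcopy(input_x)  # the sketch leaves its input untouched
    for index, value in zip(indices, update):
        target = output
        # Walk down to the innermost list addressed by the index.
        for position in index[:-1]:
            target = target[position]
        target[index[-1]] = value
    return output


input_x = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
indices = [[0, 0], [1, 1]]
update = [1.0, 2.2]
output = scatter_update_elements(input_x, indices, update)
```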
class mindspore.ops.TensorSummary(*args, **kwargs)[source]

Outputs a tensor to a protocol buffer through a tensor summary operator.

Inputs:
  • name (str) - The name of the input variable.

  • value (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.

Examples

>>> class SummaryDemo(nn.Cell):
>>>     def __init__(self,):
>>>         super(SummaryDemo, self).__init__()
>>>         self.summary = P.TensorSummary()
>>>         self.add = P.TensorAdd()
>>>
>>>     def construct(self, x, y):
>>>         x = self.add(x, y)
>>>         name = "x"
>>>         self.summary(name, x)
>>>         return x
class mindspore.ops.Tile(*args, **kwargs)[source]

Replicates a tensor the number of times given by multiples.

Creates a new tensor by replicating input multiples times. The dimension of output tensor is the larger of the input tensor dimension and the length of multiples.

Inputs:
  • input_x (Tensor) - 1-D or higher Tensor. Set the shape of input tensor as \((x_1, x_2, ..., x_S)\).

  • multiples (tuple[int]) - The input tuple is constructed by multiple integers, i.e., \((y_1, y_2, ..., y_S)\). The length of multiples cannot be smaller than the length of the shape of input_x. Only constant value is allowed.

Outputs:

Tensor, has the same data type as the input_x.

  • If the length of multiples is the same as the rank of input_x, the corresponding dimensions are multiplied, and the output shape is \((x_1*y_1, x_2*y_2, ..., x_S*y_S)\).

  • If the length of multiples is larger than the rank of input_x, the shape of input_x is padded with leading 1s until the lengths match. For example, if the padded shape of input_x is \((1, ..., x_1, x_2, ..., x_S)\) and multiples is \((y_1, y_2, ..., y_R)\) with R > S, the corresponding dimensions are multiplied and the output shape is \((1*y_1, ..., x_S*y_R)\).

Examples

>>> tile = P.Tile()
>>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
>>> multiples = (2, 3)
>>> result = tile(input_x, multiples)
[[1.  2.  1.  2.  1.  2.]
 [3.  4.  3.  4.  3.  4.]
 [1.  2.  1.  2.  1.  2.]
 [3.  4.  3.  4.  3.  4.]]
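
The shape rule above can be sketched in plain Python. The helpers tile_output_shape and tile_2d are hypothetical names used only for illustration; tile_2d handles just the two-dimensional case on nested lists.

```python
def tile_output_shape(input_shape, multiples):
    """Apply the shape rule: pad the input shape with leading 1s, then multiply."""
    if len(multiples) < len(input_shape):
        raise ValueError("multiples must not be shorter than the input shape")
    padded = (1,) * (len(multiples) - len(input_shape)) + tuple(input_shape)
    return tuple(dim * rep for dim, rep in zip(padded, multiples))


def tile_2d(rows, multiples):
    """Replicate a 2-D nested list row_rep times vertically and col_rep times horizontally."""
    row_rep, col_rep = multiples
    return [list(row) * col_rep for _ in range(row_rep) for row in rows]


same_rank_shape = tile_output_shape((2, 2), (2, 3))
padded_rank_shape = tile_output_shape((2,), (3, 2))
tiled = tile_2d([[1, 2], [3, 4]], (2, 3))
```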
class mindspore.ops.TopK(*args, **kwargs)[source]

Finds values and indices of the k largest entries along the last dimension.

Parameters

sorted (bool) – If true, the obtained elements will be sorted by the values in descending order. Default: False.

Inputs:
  • input_x (Tensor) - Input to be computed, data type must be float16, float32 or int32.

  • k (int) - The number of top elements to be computed along the last dimension, constant input is needed.

Outputs:

Tuple of 2 tensors, the values and the indices.

  • values (Tensor) - The k largest elements in each slice of the last dimension.

  • indices (Tensor) - The indices of values within the last dimension of input.

Examples

>>> topk = P.TopK(sorted=True)
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = topk(input_x, k)
>>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
>>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
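
For a 1-D input, the behavior with sorted=True can be sketched in plain Python as follows. The function top_k_sorted is a hypothetical helper; how the real operator orders equal values is not stated here, so the tie handling below (earlier index first) is an assumption.

```python
def top_k_sorted(values, k):
    """Return the k largest values in descending order, together with their indices."""
    # Pair each value with its index, then order the pairs by value, largest first.
    ranked = sorted(enumerate(values), key=lambda pair: pair[1], reverse=True)[:k]
    top_values = [value for _, value in ranked]
    top_indices = [index for index, _ in ranked]
    return top_values, top_indices


values, indices = top_k_sorted([1, 2, 3, 4, 5], 3)
```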
class mindspore.ops.Transpose(*args, **kwargs)[source]

Permutes the dimensions of input tensor according to input permutation.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

  • input_perm (tuple[int]) - The permutation to be applied. The tuple is made up of dimension indexes, and its length must equal the number of dimensions of input_x. Only constant value is allowed.

Outputs:

Tensor, the type of output tensor is the same as input_x and the shape of output tensor is decided by the shape of input_x and the value of input_perm.

Examples

>>> input_tensor = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
>>> perm = (0, 2, 1)
>>> transpose = P.Transpose()
>>> output = transpose(input_tensor, perm)
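
The example above does not show its result. The sketch below works it out in plain Python, assuming the common convention that output dimension i takes its size from input dimension input_perm[i]. The helper names are hypothetical, and the second helper only covers perm = (0, 2, 1) on nested lists.

```python
def transpose_shape(input_shape, perm):
    """Output shape under the assumed convention output_shape[i] = input_shape[perm[i]]."""
    if sorted(perm) != list(range(len(input_shape))):
        raise ValueError("perm must be a permutation of the input dimensions")
    return tuple(input_shape[axis] for axis in perm)


def swap_last_two_axes(tensor):
    """Apply perm = (0, 2, 1) to a 3-D nested list by transposing each inner matrix."""
    return [[list(column) for column in zip(*matrix)] for matrix in tensor]


input_tensor = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
output_shape = transpose_shape((2, 2, 3), (0, 2, 1))
output = swap_last_two_axes(input_tensor)
```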
class mindspore.ops.TruncateDiv(*args, **kwargs)[source]

Divides the first input tensor by the second input tensor element-wise for integer types. Fractional results are rounded towards zero, including for negative numbers.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
>>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
>>> truncate_div = P.TruncateDiv()
>>> truncate_div(input_x, input_y)
[0, 1, 0]
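
Rounding towards zero differs from Python's floor division for negative operands. The plain-Python sketch below contrasts the two on the example inputs; truncate_div here is an illustrative scalar helper, not the MindSpore operator.

```python
def truncate_div(x, y):
    """Integer division that rounds the quotient towards zero."""
    quotient = abs(x) // abs(y)
    # The result is positive when both operands have the same sign.
    return quotient if (x >= 0) == (y > 0) else -quotient


inputs = [2, 4, -1]
truncated = [truncate_div(x, 3) for x in inputs]
floored = [x // 3 for x in inputs]  # Python's // rounds towards negative infinity
```

Only the last element differs: truncation gives 0 for -1 / 3, while floor division gives -1.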
class mindspore.ops.TruncateMod(*args, **kwargs)[source]

Returns element-wise remainder of division.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is number or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is number or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
>>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
>>> truncate_mod = P.TruncateMod()
>>> truncate_mod(input_x, input_y)
[2, 1, -1]
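
The example output suggests that the remainder keeps the sign of the dividend, that is, x - y * trunc(x / y). The plain-Python sketch below assumes that convention and compares it with Python's % operator, whose result takes the sign of the divisor. truncate_mod is an illustrative scalar helper.

```python
import math


def truncate_mod(x, y):
    """Remainder of a division whose quotient is rounded towards zero (assumed convention)."""
    quotient = math.trunc(x / y)
    return x - y * quotient


inputs = [2, 4, -1]
truncated_remainder = [truncate_mod(x, 3) for x in inputs]
python_remainder = [x % 3 for x in inputs]  # follows the sign of the divisor
```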
class mindspore.ops.TruncatedNormal(*args, **kwargs)[source]

Returns a tensor of the specified shape filled with truncated normal values.

The generated values follow a normal distribution.

Parameters
  • seed (int) – An integer used to create the random seed. Default: 0.

  • dtype (mindspore.dtype) – Data type. Default: mindspore.float32.

Inputs:
  • shape (tuple[int]) - The shape of the output tensor, is a tuple of positive integer.

Outputs:

Tensor, the data type of output tensor is the same as attribute dtype.

Examples

>>> shape = (1, 2, 3)
>>> truncated_normal = P.TruncatedNormal()
>>> output = truncated_normal(shape)
class mindspore.ops.TupleToArray(*args, **kwargs)[source]

Converts a tuple to a tensor.

If the type of the first number in the tuple is integer, the data type of the output tensor is int. Otherwise, the data type of the output tensor is float.

Inputs:
  • input_x (tuple) - A tuple of numbers. These numbers have the same type. Only constant value is allowed.

Outputs:

Tensor, if the input tuple contains N numbers, then the shape of the output tensor is (N,).

Examples

>>> type = P.TupleToArray()((1,2,3))
class mindspore.ops.UniformInt(*args, **kwargs)[source]

Produces random integer values i, uniformly distributed on the interval [minval, maxval), that is, distributed according to the discrete probability function:

\[\text{P}(i|a,b) = \frac{1}{b-a+1},\]

Note

The number in tensor minval must be strictly less than maxval at any position after broadcasting.

Parameters
  • seed (int) – Random seed, must be non-negative. Default: 0.

  • seed2 (int) – Random seed2, must be non-negative. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

  • minval (Tensor) - The distribution parameter, a. It defines the minimum possibly generated value, with int32 data type. Only one number is supported.

  • maxval (Tensor) - The distribution parameter, b. It defines the maximum possibly generated value, with int32 data type. Only one number is supported.

Outputs:

Tensor. The shape is the same as the input ‘shape’, and the data type is int32.

Examples

>>> shape = (4, 16)
>>> minval = Tensor(1, mstype.int32)
>>> maxval = Tensor(5, mstype.int32)
>>> uniform_int = P.UniformInt(seed=10)
>>> output = uniform_int(shape, minval, maxval)
class mindspore.ops.UniformReal(*args, **kwargs)[source]

Produces random floating-point values i, uniformly distributed on the interval [0, 1).

Parameters
  • seed (int) – Random seed, must be non-negative. Default: 0.

  • seed2 (int) – Random seed2, must be non-negative. Default: 0.

Inputs:
  • shape (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

Outputs:

Tensor. The shape is the same as the input ‘shape’, and the data type is float32.

Examples

>>> shape = (4, 16)
>>> uniformreal = P.UniformReal(seed=2)
>>> output = uniformreal(shape)
class mindspore.ops.Unique(*args, **kwargs)[source]

Returns the unique elements of the input tensor, together with a tensor that gives, for each element of the input, its index in the output tensor of unique elements.

Inputs:
  • x (Tensor) - The input tensor.

Outputs:

Tuple, containing Tensor objects (y, idx). y is a tensor that has the same type as x, and idx is a tensor containing the indices of the elements of the input in the output tensor y.

Examples

>>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
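
The relationship between y and idx can be sketched in plain Python. The helper unique_with_index is hypothetical; it keeps unique values in order of first appearance, which matches the example above but is otherwise an assumption about ordering.

```python
def unique_with_index(values):
    """Return (y, idx) such that values[j] == y[idx[j]] for every position j."""
    position_of = {}  # maps a value to its position in y
    y = []
    idx = []
    for value in values:
        if value not in position_of:
            position_of[value] = len(y)
            y.append(value)
        idx.append(position_of[value])
    return y, idx


x = [1, 2, 5, 2]
y, idx = unique_with_index(x)
reconstructed = [y[i] for i in idx]  # indexing y with idx recovers the input
```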
class mindspore.ops.Unpack(*args, **kwargs)[source]

Unpacks tensor in specified axis.

Unpacks a tensor of rank R along axis dimension, output tensors will have rank (R-1).

Given a tensor of shape \((x_1, x_2, ..., x_R)\). If \(0 \le axis\), the shape of tensor in output is \((x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)\).

This is the opposite of pack.

Parameters

axis (int) – Dimension along which to unpack. Default: 0. Negative values wrap around. The range is [-R, R).

Inputs:
  • input_x (Tensor) - The shape is \((x_1, x_2, ..., x_R)\). A tensor to be unpacked and the rank of the tensor must be greater than 0.

Outputs:

A tuple of tensors, all of which have the same shape.

Raises

ValueError – If axis is out of the range [-len(input_x.shape), len(input_x.shape)).

Examples

>>> unpack = P.Unpack()
>>> input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
>>> output = unpack(input_x)
([1, 1, 1, 1], [2, 2, 2, 2])
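
The shape rule and the effect of axis can be sketched in plain Python. Both helpers below are hypothetical; unpack_2d only handles 2-D nested lists and is meant as an illustration of the semantics described above.

```python
def unpack_shape(input_shape, axis=0):
    """Return (number of output tensors, shape of each output tensor)."""
    rank = len(input_shape)
    if not -rank <= axis < rank:
        raise ValueError("axis is out of the range [-R, R)")
    axis %= rank  # negative values wrap around
    piece_shape = tuple(input_shape[:axis]) + tuple(input_shape[axis + 1:])
    return input_shape[axis], piece_shape


def unpack_2d(matrix, axis=0):
    """Split a 2-D nested list into rows (axis 0) or columns (axis 1)."""
    if axis in (0, -2):
        return [list(row) for row in matrix]
    if axis in (1, -1):
        return [list(column) for column in zip(*matrix)]
    raise ValueError("axis is out of the range [-2, 2)")


input_x = [[1, 1, 1, 1], [2, 2, 2, 2]]
rows = unpack_2d(input_x)             # default axis 0, as in the example above
columns = unpack_2d(input_x, axis=1)
```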
class mindspore.ops.UnsortedSegmentMin(*args, **kwargs)[source]

Computes the minimum along segments of a tensor.

Inputs:
  • input_x (Tensor) - The shape is \((x_1, x_2, ..., x_R)\). The data type must be float16, float32 or int32.

  • segment_ids (Tensor) - A 1-D tensor whose shape is \((x_1)\), the value must be >= 0. The data type must be int32.

  • num_segments (int) - The value specifies the number of distinct segment_ids.

Outputs:

Tensor, set the number of num_segments as N, the shape is \((N, x_2, ..., x_R)\).

Examples

>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
>>> num_segments = 2
>>> unsorted_segment_min = P.UnsortedSegmentMin()
>>> unsorted_segment_min(input_x, segment_ids, num_segments)
[[1., 2., 3.], [4., 2., 1.]]
class mindspore.ops.UnsortedSegmentProd(*args, **kwargs)[source]

Computes the product along segments of a tensor.

Inputs:
  • input_x (Tensor) - The shape is \((x_1, x_2, ..., x_R)\). With float16, float32 or int32 data type.

  • segment_ids (Tensor) - A 1-D tensor whose shape is \((x_1)\), the value must be >= 0. Data type must be int32.

  • num_segments (int) - The value specifies the number of distinct segment_ids, must be greater than 0.

Outputs:

Tensor, set the number of num_segments as N, the shape is \((N, x_2, ..., x_R)\).

Examples

>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32))
>>> num_segments = 2
>>> unsorted_segment_prod = P.UnsortedSegmentProd()
>>> unsorted_segment_prod(input_x, segment_ids, num_segments)
[[4., 4., 3.], [4., 5., 6.]]
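
The example output can be reproduced with a plain-Python sketch that multiplies together the rows sharing a segment id. The function segment_prod_rows is hypothetical, and starting each segment at 1 (so that an empty segment would stay at 1) is an assumption that the text above does not confirm.

```python
def segment_prod_rows(rows, segment_ids, num_segments):
    """Multiply, element by element, all rows that share the same segment id."""
    width = len(rows[0])
    # Assumption: every segment starts from the multiplicative identity 1.
    output = [[1.0] * width for _ in range(num_segments)]
    for row, segment in zip(rows, segment_ids):
        output[segment] = [acc * value for acc, value in zip(output[segment], row)]
    return output


input_x = [[1, 2, 3], [4, 5, 6], [4, 2, 1]]
segment_ids = [0, 1, 0]
result = segment_prod_rows(input_x, segment_ids, 2)
```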
class mindspore.ops.UnsortedSegmentSum(*args, **kwargs)[source]

Computes the sum along segments of a tensor.

Calculates a tensor such that \(\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]\), where \(j\) is a tuple describing the index of element in data. segment_ids selects which elements in data to sum up. segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value range.

If no element of segment_ids equals \(i\), then \(\text{output}[i] = 0\). Negative values in segment_ids are ignored. ‘num_segments’ must be equal to the number of different segment_ids.

Inputs:
  • input_x (Tensor) - The shape is \((x_1, x_2, ..., x_R)\).

  • segment_ids (Tensor) - Set the shape as \((x_1, x_2, ..., x_N)\), where 0 < N <= R. Type must be int.

  • num_segments (int) - Set \(z\) as num_segments.

Outputs:

Tensor, the shape is \((z, x_{N+1}, ..., x_R)\).

Examples

>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
[3, 3, 4, 0]
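
For 1-D data, the summation rule can be sketched in plain Python, including the two special cases described above: a segment that receives no element stays at 0, and negative ids are skipped. The function segment_sum_1d is a hypothetical helper, not the operator itself.

```python
def segment_sum_1d(data, segment_ids, num_segments):
    """output[i] is the sum of all data[j] whose segment_ids[j] equals i."""
    output = [0] * num_segments  # segments that receive nothing stay at 0
    for value, segment in zip(data, segment_ids):
        if segment < 0:
            continue  # negative ids are ignored
        output[segment] += value
    return output


example = segment_sum_1d([1, 2, 3, 4], [0, 0, 1, 2], 4)
with_negative_id = segment_sum_1d([1, 2, 3], [0, -1, 0], 2)
```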
class mindspore.ops.Xdivy(*args, **kwargs)[source]

Divide the first input tensor by the second input tensor element-wise. Returns zero when x is zero.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor whose data type is float16, float32 or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input is a tensor, or a tensor whose data type is float16, float32 or bool.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.float32)
>>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
>>> xdivy = P.Xdivy()
>>> xdivy(input_x, input_y)
[1.0, 2.0, -0.5]
class mindspore.ops.Xlogy(*args, **kwargs)[source]

Computes first input tensor multiplied by the logarithm of second input tensor element-wise. Returns zero when x is zero.

Inputs of input_x and input_y comply with the implicit type conversion rules to make the data types consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors, dtypes of them cannot be both bool, and the shapes of them could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be a constant.

Inputs:
  • input_x (Union[Tensor, Number, bool]) - The first input is a number or a bool or a tensor whose data type is float16, float32 or bool.

  • input_y (Union[Tensor, Number, bool]) - The second input is a number or a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool. The value must be positive.

Outputs:

Tensor, the shape is the same as the one after broadcasting, and the data type is the one with higher precision or higher digits among the two inputs.

Examples

>>> input_x = Tensor(np.array([-5, 0, 4]), mindspore.float32)
>>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
>>> xlogy = P.Xlogy()
>>> xlogy(input_x, input_y)
[-3.465736, 0.0, 2.7725887]
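
The element-wise rule can be sketched for scalars in plain Python: return zero when x is zero, and otherwise multiply x by the natural logarithm of y. The helper xlogy_scalar is illustrative only, and y is taken to be positive, as the Inputs section requires.

```python
import math


def xlogy_scalar(x, y):
    """Compute x * log(y), returning zero when x is zero."""
    if x == 0:
        return 0.0
    return x * math.log(y)


input_x = [-5.0, 0.0, 4.0]
input_y = [2.0, 2.0, 2.0]
result = [xlogy_scalar(x, y) for x, y in zip(input_x, input_y)]
```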
class mindspore.ops.ZerosLike(*args, **kwargs)[source]

Creates a new tensor whose elements are all 0.

Returns a tensor of zeros with the same shape and data type as the input tensor.

Inputs:
  • input_x (Tensor) - Input tensor.

Outputs:

Tensor, has the same shape and data type as input_x but filled with zeros.

Examples

>>> zeroslike = P.ZerosLike()
>>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
>>> output = zeroslike(x)
[[0.0, 0.0],
 [0.0, 0.0]]
mindspore.ops.add_flags(fn=None, **flags)[source]

A decorator that adds a flag to the function.

Note

Only supports bool value.

Parameters
  • fn (Function) – Function or cell to add flag. Default: None.

  • flags (dict) – Flags use kwargs. Default: None.

Returns

Function, the function with added flags.

Examples

>>> add_flags(net, predict=True)
mindspore.ops.clip_by_value(x, clip_value_min, clip_value_max)[source]

Clips tensor values to a specified min and max.

Limits the value of \(x\) to a range, whose lower limit is ‘clip_value_min’ and upper limit is ‘clip_value_max’.

Note

‘clip_value_min’ needs to be less than or equal to ‘clip_value_max’.

Parameters
  • x (Tensor) – Input data.

  • clip_value_min (Tensor) – The minimum value.

  • clip_value_max (Tensor) – The maximum value.

Returns

Tensor, a clipped Tensor.
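
The clipping rule can be sketched in plain Python on a list of scalars. The function clip_values is a hypothetical helper; raising ValueError when the limits are in the wrong order is an assumption made for the sketch, since the note above only states the requirement.

```python
def clip_values(values, clip_value_min, clip_value_max):
    """Limit every value to the range [clip_value_min, clip_value_max]."""
    if clip_value_min > clip_value_max:
        # Assumption: the sketch rejects limits that are in the wrong order.
        raise ValueError("clip_value_min must not exceed clip_value_max")
    return [max(clip_value_min, min(value, clip_value_max)) for value in values]


clipped = clip_values([-2.0, 0.5, 3.0, 7.5], 0.0, 5.0)
```

Values inside the range pass through unchanged; values outside it are replaced by the nearer limit.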

mindspore.ops.constexpr(fn=None, get_instance=True, name=None)[source]

Make a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function to compute constant value using the constants in the constructor.

Parameters
  • fn (function) – A function used as the infer_value of the output operator.

  • get_instance (bool) – If true, return the instance of operator, otherwise return the operator class.

  • name (str) – Defines the operator name. If name is None, use the function name as op name.

Examples

>>> a = (1, 2)
>>> # make an operator to calculate tuple len
>>> @constexpr
>>> def tuple_len(x):
>>>     return len(x)
>>> assert tuple_len(a) == 2
>>>
>>> # make an operator class to calculate tuple len
>>> @constexpr(get_instance=False, name="TupleLen")
>>> def tuple_len_class(x):
>>>     return len(x)
>>> assert tuple_len_class()(a) == 2
mindspore.ops.core(fn=None, **flags)[source]

A decorator that adds a flag to the function.

By default, the function is marked with the core flag set to True, so this decorator can be used to set the flag on a graph.

Parameters
  • fn (Function) – Function to add flag. Default: None.

  • flags (dict) – The flags to set. The core flag indicates that this is a core function; other flags can also be set. Default: None.

mindspore.ops.gamma(shape, alpha, beta, seed=None)[source]

Generates random numbers according to the Gamma random number distribution.

Parameters
  • shape (tuple) – The shape of random tensor to be generated.

  • alpha (Tensor) – The alpha α distribution parameter. It should be greater than 0 with float32 data type.

  • beta (Tensor) – The beta β distribution parameter. It should be greater than 0 with float32 data type.

  • seed (int) – Seed is used as entropy source for the random number engines to generate pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.

Returns

Tensor. The shape should be equal to the broadcasted shape between the input “shape” and shapes of alpha and beta. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> alpha = Tensor(1.0, mstype.float32)
>>> beta = Tensor(1.0, mstype.float32)
>>> output = C.gamma(shape, alpha, beta, seed=5)
mindspore.ops.get_vm_impl_fn(prim)[source]

Get the virtual implementation function by a primitive object or primitive name.

Parameters

prim (Union[Primitive, str]) – primitive object or name for operator register.

Returns

function, vm function

mindspore.ops.laplace(shape, mean, lambda_param, seed=None)[source]

Generates random numbers according to the Laplace random number distribution. It is defined as:

\[\text{f}(x;\mu,\lambda) = \frac{1}{2\lambda}\exp\left(-\frac{|x-\mu|}{\lambda}\right),\]

Parameters
  • shape (tuple) – The shape of random tensor to be generated.

  • mean (Tensor) – The mean μ distribution parameter, which specifies the location of the peak. With float32 data type.

  • lambda_param (Tensor) – The parameter used for controlling the variance of this random distribution. The variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type.

  • seed (int) – Seed is used as entropy source for Random number engines generating pseudo-random numbers. Default: None, which will be treated as 0.

Returns

Tensor. The shape should be the broadcasted shape of Input “shape” and shapes of mean and lambda_param. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> lambda_param = Tensor(1.0, mstype.float32)
>>> output = C.laplace(shape, mean, lambda_param, seed=5)
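
The statement that the variance equals twice the square of lambda_param can be checked numerically with the standard library alone. The sketch below is not how MindSpore generates samples: it draws Laplace values as the scaled difference of two independent unit exponential variables, a standard construction chosen here for illustration. The names laplace_pdf and sample_laplace are hypothetical.

```python
import math
import random


def laplace_pdf(x, mean, lambda_param):
    """Density f(x; mean, lambda) = exp(-|x - mean| / lambda) / (2 * lambda)."""
    return math.exp(-abs(x - mean) / lambda_param) / (2.0 * lambda_param)


def sample_laplace(count, mean, lambda_param, seed=0):
    """Draw samples as mean + lambda * (E1 - E2) with E1, E2 independent Exp(1) variables."""
    rng = random.Random(seed)
    return [
        mean + lambda_param * (rng.expovariate(1.0) - rng.expovariate(1.0))
        for _ in range(count)
    ]


mean, lambda_param = 1.0, 2.0
samples = sample_laplace(50000, mean, lambda_param, seed=5)
sample_mean = sum(samples) / len(samples)
sample_variance = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
expected_variance = 2.0 * lambda_param ** 2  # twice the square of lambda_param
```

With this many samples, sample_mean should be close to mean and sample_variance close to expected_variance, up to sampling noise.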
mindspore.ops.multinomial(inputs, num_sample, replacement=True, seed=0)[source]

Returns a tensor sampled from the multinomial probability distribution located in the corresponding row of the input tensor.

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.

Parameters
  • inputs (Tensor) – The input tensor containing probabilities, must be 1 or 2 dimensions, with float32 data type.

  • num_sample (int) – Number of samples to draw.

  • replacement (bool, optional) – Whether to draw with replacement or not, default True.

  • seed (int, optional) – Seed is used as entropy source for the random number engines to generate pseudo-random numbers, must be non-negative. Default: 0.

Outputs:

Tensor, has the same number of rows as the input. The number of sampled indices of each row is num_sample. The dtype is float32.

Examples

>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = C.multinomial(input, 2, True)
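
The note about unnormalized weights can be sketched with the standard library for a single row, sampling with replacement. The function sample_indices is a hypothetical helper built on random.choices; it illustrates the semantics only, and it returns Python integers rather than a tensor.

```python
import random


def sample_indices(weights, num_sample, seed=0):
    """Draw num_sample indices with replacement, with probability proportional to weights."""
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative and have a non-zero sum")
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_sample)


weights = [0.0, 9.0, 4.0, 0.0]  # the row does not need to sum to one
probabilities = [w / sum(weights) for w in weights]
draws = sample_indices(weights, 1000, seed=5)
share_of_index_1 = draws.count(1) / len(draws)
```

Indices with zero weight are never drawn, and index 1 should appear in roughly 9/13 of the draws.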
mindspore.ops.normal(shape, mean, stddev, seed=None)[source]

Generates random numbers according to the Normal (or Gaussian) random number distribution.

Parameters
  • shape (tuple) – The shape of random tensor to be generated.

  • mean (Tensor) – The mean μ distribution parameter, which specifies the location of the peak. With float32 data type.

  • stddev (Tensor) – The deviation σ distribution parameter. It should be greater than 0. With float32 data type.

  • seed (int) – Seed is used as entropy source for the random number engines to generate pseudo-random numbers. Must be non-negative. Default: None, which will be treated as 0.

Returns

Tensor. The shape should be equal to the broadcasted shape between the input shape and shapes of mean and stddev. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> output = C.normal(shape, mean, stddev, seed=5)
mindspore.ops.op_info_register(op_info)[source]

A decorator which is used to register an operator.

Note

‘op_info’ should represent the operator information by string with json format. The ‘op_info’ will be added into oplib.

Parameters

op_info (str or dict) – operator information in json format.

Returns

Function, returns a decorator for op info register.

mindspore.ops.poisson(shape, mean, seed=None)[source]

Generates random numbers according to the Poisson random number distribution.

Parameters
  • shape (tuple) – The shape of random tensor to be generated.

  • mean (Tensor) – The mean μ distribution parameter. It should be greater than 0 with float32 data type.

  • seed (int) – Seed is used as entropy source for the random number engines to generate pseudo-random numbers and must be non-negative. Default: None, which will be treated as 0.

Returns

Tensor. The shape should be equal to the broadcasted shape between the input “shape” and shapes of mean. The dtype is float32.

Examples

>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> output = C.poisson(shape, mean, seed=5)
mindspore.ops.prim_attr_register(fn)[source]

Primitive attributes register.

A decorator for the ‘__init__’ function of built-in operator primitives. It adds all the parameters of ‘__init__’ as operator attributes.

Parameters

fn (function) – __init__ function of primitive.

Returns

function, original function.

mindspore.ops.uniform(shape, minval, maxval, seed=None, dtype=mindspore.float32)[source]

Generates random numbers according to the Uniform random number distribution.

Note

The number in tensor minval should be strictly less than maxval at any position after broadcasting.

Parameters
  • shape (tuple) – The shape of random tensor to be generated.

  • minval (Tensor) – The distribution parameter a. It defines the minimum possible generated value, with int32 or float32 data type. If dtype is int32, only one number is allowed.

  • maxval (Tensor) – The distribution parameter b. It defines the maximum possible generated value, with int32 or float32 data type. If dtype is int32, only one number is allowed.

  • seed (int) – Seed is used as entropy source for the random number engines to generate pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.

  • dtype (mindspore.dtype) – type of the Uniform distribution. If it is int32, it generates numbers from discrete uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only supports these two data types. Default: mstype.float32.

Returns

Tensor. The shape should be equal to the broadcasted shape between the input shape and shapes of minval and maxval. The dtype is designated as the input dtype.

Examples

>>> # For discrete uniform distribution, only one number is allowed for both minval and maxval:
>>> shape = (4, 2)
>>> minval = Tensor(1, mstype.int32)
>>> maxval = Tensor(2, mstype.int32)
>>> output = C.uniform(shape, minval, maxval, seed=5, dtype=mstype.int32)
>>>
>>> # For continuous uniform distribution, minval and maxval can be multi-dimensional:
>>> shape = (4, 2)
>>> minval = Tensor([1.0, 2.0], mstype.float32)
>>> maxval = Tensor([4.0, 5.0], mstype.float32)
>>> output = C.uniform(shape, minval, maxval, seed=5)