mindspore.ops.embedding

mindspore.ops.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False)

Retrieve the word embeddings in weight using indices specified in input.

Warning

On Ascend, the behavior is unpredictable when input contains invalid values.
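
Because invalid indices are not guaranteed to raise an error, it can help to validate them on the host before calling the operator. The following is a minimal sketch in plain Python; `check_indices` is a hypothetical helper and not part of MindSpore, and it assumes the indices are available as a nested Python list.

```python
def check_indices(indices, num_rows):
    """Raise ValueError if any index in a (possibly nested) list is outside [0, num_rows)."""
    for item in indices:
        if isinstance(item, (list, tuple)):
            check_indices(item, num_rows)
        elif not 0 <= item < num_rows:
            raise ValueError(f"index {item} is outside [0, {num_rows})")


# Hypothetical data: indices that a weight with 3 rows would accept.
check_indices([[1, 0, 1, 1], [0, 0, 1, 0]], num_rows=3)  # passes silently
```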

Parameters:
  • input (Tensor) – The indices of the rows to look up in weight. The data type must be mindspore.int32 or mindspore.int64, and every value must be in the range [0, weight.shape[0]).

  • weight (Union[Parameter, Tensor]) – The embedding matrix to look up from. It must be 2-D.

  • padding_idx (int, optional) – If not None, the row of weight at this index is not updated during training. The value must be in the range [-weight.shape[0], weight.shape[0]). Default None.

  • max_norm (float, optional) – If not None, first compute the p-norm of each row of weight selected by input, where p is given by norm_type. If the result is larger than max_norm, the row is rescaled in-place by the factor \(\frac{max\_norm}{result + 10^{-7}}\). Default None.

  • norm_type (float, optional) – Indicates the value of p in p-norm. Default 2.0.

  • scale_grad_by_freq (bool, optional) – If True, the gradients will be scaled by the inverse of frequency of the index in input. Default False.
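
The renormalization described for max_norm and norm_type can be sketched in plain Python as follows. This only illustrates the documented formula and is not the operator's implementation; `renorm_rows` is a hypothetical name, and the weight is modeled as a list of rows.

```python
def renorm_rows(weight, flat_indices, max_norm, norm_type=2.0):
    """Rescale, in place, each looked-up row whose p-norm exceeds max_norm."""
    for i in set(flat_indices):
        row = weight[i]
        norm = sum(abs(x) ** norm_type for x in row) ** (1.0 / norm_type)
        if norm > max_norm:
            scale = max_norm / (norm + 1e-7)
            weight[i] = [x * scale for x in row]


# Hypothetical 3 x 2 weight; row 2 is never looked up, so it is left alone.
weight = [[3.0, 4.0], [0.3, 0.4], [6.0, 8.0]]
renorm_rows(weight, flat_indices=[0, 1, 0], max_norm=1.0)
print(weight)
```

Here row 0 has a 2-norm of 5.0 and is scaled down to a norm of about 1.0, row 1 has a norm of 0.5 and is unchanged, and row 2 keeps its values because no index refers to it.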

Returns:

Tensor, with the same data type as weight and shape \((*input.shape, weight.shape[1])\).
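
To illustrate the output shape, the lookup can be modeled in plain Python with nested lists. This is a sketch only; `embedding_lookup` and `shape_of` are hypothetical helpers, and the weight values are made up.

```python
def embedding_lookup(indices, weight):
    """Replace every integer in a nested index list with the matching row of weight."""
    if isinstance(indices, list):
        return [embedding_lookup(item, weight) for item in indices]
    return list(weight[indices])


def shape_of(nested):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


weight = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]  # shape (3, 3)
indices = [[1, 0, 1, 1], [0, 0, 1, 0]]                        # shape (2, 4)
output = embedding_lookup(indices, weight)
print(shape_of(output))  # (2, 4, 3)
```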

Raises:
  • ValueError – If padding_idx is out of valid range.

  • ValueError – If the shape of weight is invalid.
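
The two checks above can be mirrored on the caller's side. The sketch below is illustrative only; both helpers are hypothetical. It assumes that a negative padding_idx counts back from the last row, which the documented range suggests but this page does not state explicitly.

```python
def check_weight_shape(shape):
    """Raise ValueError unless the weight shape is 2-D."""
    if len(shape) != 2:
        raise ValueError(f"weight must be 2-D, got shape {tuple(shape)}")


def normalize_padding_idx(padding_idx, num_rows):
    """Validate padding_idx against [-num_rows, num_rows) and map it to a row number."""
    if padding_idx is None:
        return None
    if not -num_rows <= padding_idx < num_rows:
        raise ValueError(
            f"padding_idx {padding_idx} is outside [{-num_rows}, {num_rows})"
        )
    # Assumption: negative values wrap around, as with Python indexing.
    return padding_idx % num_rows


check_weight_shape((3, 3))
print(normalize_padding_idx(-1, num_rows=3))  # 2
```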

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, ops
>>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
>>> weight = Parameter(np.random.randn(3, 3).astype(np.float32))
>>> output = ops.embedding(input, weight, max_norm=0.4)
>>> print(output)
[[[ 5.49015924  3.47811311 -1.89771220],
  [ 2.09307984 -2.24846993  3.40124398],
  [ 5.49015924  3.47811311 -1.89771220],
  [ 5.49015924  3.47811311 -1.89771220]],
 [[ 2.09307984 -2.24846993  3.40124398],
  [ 2.09307984 -2.24846993  3.40124398],
  [ 5.49015924  3.47811311 -1.89771220],
  [ 2.09307984 -2.24846993  3.40124398]]]
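
The training-time effect of padding_idx and scale_grad_by_freq can be sketched in plain Python. This is a simplified model of how the gradient with respect to weight could be accumulated, based only on the parameter descriptions above; `weight_grad` is a hypothetical function, not a MindSpore API, and the upstream gradients are made up.

```python
from collections import Counter


def weight_grad(flat_indices, grad_rows, num_rows, padding_idx=None,
                scale_grad_by_freq=False):
    """Accumulate per-lookup output gradients into a gradient for each weight row."""
    dim = len(grad_rows[0])
    grad = [[0.0] * dim for _ in range(num_rows)]
    counts = Counter(flat_indices)  # how often each index occurs in input
    for idx, g in zip(flat_indices, grad_rows):
        if idx == padding_idx:
            continue  # the padding row receives no gradient, so it is not updated
        scale = 1.0 / counts[idx] if scale_grad_by_freq else 1.0
        for j, value in enumerate(g):
            grad[idx][j] += value * scale
    return grad


# Index 1 occurs three times; index 0 is the padding row.
indices = [1, 0, 1, 1]
upstream = [[1.0, 1.0] for _ in indices]
plain = weight_grad(indices, upstream, num_rows=3)
scaled = weight_grad(indices, upstream, num_rows=3, padding_idx=0,
                     scale_grad_by_freq=True)
print(plain)
print(scaled)
```

Without scaling, row 1 accumulates three contributions. With scale_grad_by_freq=True, each contribution is divided by the count of 3, and the padding row stays at zero.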