mindspore.ops.uniform_candidate_sampler

mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)

Uniform candidate sampler.

This function samples a set of classes (sampled_candidates) from the range [0, range_max-1] according to a uniform distribution. If unique=True, candidates are drawn without replacement; if unique=False, they are drawn with replacement.
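The two sampling modes can be illustrated with a small NumPy sketch. `sample_uniform_candidates` is a hypothetical reference helper, not the MindSpore kernel, and `numpy.random.Generator.choice` stands in for the library's random source, so the values it draws will not match MindSpore's output for the same seed. Only the shape and the with/without-replacement behavior correspond to the description above.

```python
import numpy as np


def sample_uniform_candidates(num_sampled, unique, range_max, rng):
    """Draw num_sampled class ids uniformly from [0, range_max - 1].

    Hypothetical NumPy reference for illustration, not the MindSpore kernel.
    """
    # Without replacement there are only range_max distinct classes to draw.
    if unique and num_sampled > range_max:
        raise ValueError("num_sampled must be <= range_max when unique=True")
    # replace=False draws without replacement (unique=True);
    # replace=True draws with replacement (unique=False).
    return rng.choice(range_max, size=num_sampled, replace=not unique)


rng = np.random.default_rng(1)
# With replacement: repeated class ids are allowed.
with_repl = sample_uniform_candidates(3, unique=False, range_max=4, rng=rng)
# Without replacement and num_sampled == range_max: every class appears exactly once.
without_repl = sample_uniform_candidates(4, unique=True, range_max=4, rng=rng)
print(with_repl.shape)                # (3,)
print(sorted(without_repl.tolist()))  # [0, 1, 2, 3]
```

Drawing all range_max classes with unique=True returns a permutation of [0, range_max-1], which is why num_sampled cannot exceed range_max in that mode.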

Parameters
  • true_classes (Tensor) – The target classes, with shape \((batch\_size, num\_true)\) .

  • num_true (int) – The number of target classes in each training example.

  • num_sampled (int) – The number of classes to randomly sample. The output sampled_candidates has shape \((num\_sampled, )\) . If unique=True, num_sampled must be less than or equal to range_max.

  • unique (bool) – Whether all sampled classes are unique, that is, whether they are drawn without replacement.

  • range_max (int) – The number of possible classes, must be positive.

  • seed (int) – The random seed, must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0 .

  • remove_accidental_hits (bool) – Whether to remove accidental hits. An accidental hit occurs when one of the sampled classes matches one of the true classes. Set to True to remove cases where a true class is accidentally sampled as a candidate. Default: False .
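The definition of an accidental hit can be made concrete with NumPy broadcasting. `accidental_hit_mask` is a hypothetical helper, not part of MindSpore: it only flags which sampled classes coincide with an example's true classes and does not reproduce what the library does internally when remove_accidental_hits=True.

```python
import numpy as np


def accidental_hit_mask(true_classes, sampled_candidates):
    """Return a (batch_size, num_sampled) bool mask; True marks an accidental hit.

    Hypothetical helper illustrating the definition only.
    """
    true_classes = np.asarray(true_classes)       # (batch_size, num_true)
    sampled = np.asarray(sampled_candidates)      # (num_sampled,)
    # Compare every sampled class with every true class of every example:
    # (batch_size, num_true, 1) == (1, 1, num_sampled) -> (batch_size, num_true, num_sampled)
    matches = true_classes[:, :, None] == sampled[None, None, :]
    # A sampled class is a hit for an example if it equals any of that example's true classes.
    return matches.any(axis=1)


true_classes = np.array([[1], [3], [4], [6], [3]])
sampled = np.array([3, 0, 1])  # hypothetical sampled_candidates
mask = accidental_hit_mask(true_classes, sampled)
print(mask.astype(int))
```

Here sampled class 3 is a hit for the second and fifth examples and sampled class 1 is a hit for the first; those are the pairs that remove_accidental_hits targets.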

Returns

  • sampled_candidates (Tensor) - The sampled classes, drawn independently of the true classes. Shape: \((num\_sampled, )\) .

  • true_expected_count (Tensor) - The expected count of each class in true_classes under the sampling distribution. Shape: \((batch\_size, num\_true)\) .

  • sampled_expected_count (Tensor) - The expected count of each class in sampled_candidates under the sampling distribution. Shape: \((num\_sampled, )\) .
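For a uniform distribution, every class has probability 1/range_max per draw, so its expected number of occurrences among num_sampled independent draws is num_sampled / range_max. The same value is also the exact probability that a given class is included when num_sampled distinct classes are drawn without replacement. The sketch below fills both output shapes with that value. `uniform_expected_counts` is a hypothetical reference, not MindSpore's implementation: the library may compute the unique=True case differently (for example, with an approximation), and the sketch ignores remove_accidental_hits, so treat it as a sanity check on shapes and magnitudes rather than a specification.

```python
import numpy as np


def uniform_expected_counts(true_classes, sampled_candidates, num_sampled, range_max):
    """Exact expected counts for uniform sampling (hypothetical reference).

    Each class has probability 1 / range_max per draw, so its expected number of
    occurrences among num_sampled draws is num_sampled / range_max. This is also
    the inclusion probability of a class when num_sampled distinct classes are
    drawn without replacement.
    """
    expected = num_sampled / range_max
    true_classes = np.asarray(true_classes)
    sampled_candidates = np.asarray(sampled_candidates)
    # The value does not depend on which class is asked about, only on the output shapes.
    true_expected_count = np.full(true_classes.shape, expected, dtype=np.float32)
    sampled_expected_count = np.full(sampled_candidates.shape, expected, dtype=np.float32)
    return true_expected_count, sampled_expected_count


true_classes = np.array([[1], [3], [4], [6], [3]])
sampled = np.array([3, 0, 1])  # hypothetical sampled_candidates
t, s = uniform_expected_counts(true_classes, sampled, num_sampled=3, range_max=4)
print(t.shape, s.shape)  # (5, 1) (3,)
print(s)                 # [0.75 0.75 0.75]
```

The shapes match the documented return shapes: true_expected_count follows true_classes and sampled_expected_count follows sampled_candidates.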

Raises
  • TypeError – If num_true or num_sampled is not an int.

  • TypeError – If unique or remove_accidental_hits is not a bool.

  • TypeError – If range_max or seed is not an int.

  • TypeError – If true_classes is not a Tensor.
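The documented constraints can be checked before calling the operator. `check_sampler_args` below is a hypothetical pre-flight helper, not a MindSpore API. The TypeError checks mirror the list above; `tensor_type` stands in for `mindspore.Tensor` so the sketch runs without MindSpore. The ValueError checks come from the parameter descriptions (range_max positive, seed non-negative, num_sampled <= range_max when unique=True); the exception type the library raises for those is not documented here, so ValueError is an assumption. Rejecting bool for the int parameters is likewise a choice of this sketch.

```python
import numpy as np


def check_sampler_args(true_classes, num_true, num_sampled, unique, range_max,
                       seed=0, remove_accidental_hits=False, tensor_type=np.ndarray):
    """Hypothetical pre-flight check mirroring the documented constraints.

    tensor_type stands in for mindspore.Tensor so the sketch runs without MindSpore.
    """
    # bool is a subclass of int in Python; this sketch rejects it explicitly.
    for name, value in (("num_true", num_true), ("num_sampled", num_sampled),
                        ("range_max", range_max), ("seed", seed)):
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    for name, value in (("unique", unique),
                        ("remove_accidental_hits", remove_accidental_hits)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool, got {type(value).__name__}")
    if not isinstance(true_classes, tensor_type):
        raise TypeError("true_classes must be a Tensor")
    # Value constraints taken from the parameter descriptions.
    if range_max <= 0:
        raise ValueError("range_max must be positive")
    if seed < 0:
        raise ValueError("seed must be non-negative")
    if unique and num_sampled > range_max:
        raise ValueError("num_sampled must be <= range_max when unique=True")


data = np.array([[1], [3], [4], [6], [3]], dtype=np.int64)
check_sampler_args(data, 1, 3, False, 4, 1)  # same arguments as the example below; passes
try:
    check_sampler_args(data, 1.0, 3, False, 4, 1)
except TypeError as err:
    print(err)  # num_true must be an int, got float
```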

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
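>>> # Five training examples, one true class each: shape (batch_size, num_true) = (5, 1).
>>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))
>>> output1, output2, output3 = ops.uniform_candidate_sampler(
...     data, num_true=1, num_sampled=3, unique=False, range_max=4, seed=1)
>>> print(output1.shape)  # sampled_candidates: (num_sampled,)
(3,)
>>> print(output2.shape)  # true_expected_count: (batch_size, num_true)
(5, 1)
>>> print(output3.shape)  # sampled_expected_count: (num_sampled,)
(3,)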