mindspore.ops.UniformCandidateSampler

class mindspore.ops.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)

Uniform candidate sampler.

This operator samples a set of classes (sampled_candidates) from the range [0, range_max-1] according to a uniform distribution.

Refer to mindspore.ops.uniform_candidate_sampler() for more details.
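The sampling described above can be sketched in plain Python. This is a minimal, hypothetical reference (the function name `uniform_sample` is illustrative and not part of MindSpore). It assumes that unique=True means drawing without replacement and unique=False means drawing with replacement, and it applies the documented parameter constraints. The operator's actual random stream will differ from this sketch even for the same seed.

```python
import random


def uniform_sample(num_sampled, range_max, unique, seed=0):
    """Hypothetical sketch: draw num_sampled class ids from [0, range_max - 1]."""
    if range_max < 0 or seed < 0:
        raise ValueError("range_max and seed must be non-negative")
    if unique and num_sampled > range_max:
        raise ValueError("unique=True requires num_sampled <= range_max")
    # Mirror the documented seed rule: 0 means "use a randomly generated seed".
    rng = random.Random(seed if seed != 0 else None)
    if unique:
        # Without replacement: every class id appears at most once.
        return rng.sample(range(range_max), num_sampled)
    # With replacement: each draw is independent and uniform over the range.
    return [rng.randrange(range_max) for _ in range(num_sampled)]


print(uniform_sample(3, 4, unique=True, seed=1))
```

With a fixed non-zero seed the sketch is reproducible; with seed=0 each call uses a fresh seed.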

Parameters
  • num_true (int) – The number of target classes in each training example.

  • num_sampled (int) – The number of classes to randomly sample. The output sampled_candidates has shape \((num\_sampled, )\). If unique=True, num_sampled must be less than or equal to range_max.

  • unique (bool) – Whether all sampled classes in a batch are unique.

  • range_max (int) – The number of possible classes. Must be non-negative.

  • seed (int, optional) – Seed used for random number generation. Must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0 .

  • remove_accidental_hits (bool, optional) – Whether to remove accidental hits. An accidental hit occurs when one of the sampled classes matches one of the true classes. Set to True to remove cases where a true class is accidentally sampled as a candidate. Default: False .
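To make the notion of an accidental hit concrete, the following hypothetical helper (`find_accidental_hits` is not a MindSpore API) lists every (row, sample index) pair where a sampled candidate equals one of that row's true classes. It only detects hits; it makes no claim about how the operator removes them when remove_accidental_hits=True.

```python
def find_accidental_hits(true_classes, sampled_candidates):
    """Hypothetical sketch: return (row, sample_index) pairs for accidental hits.

    true_classes: list of rows, each a list of num_true class ids.
    sampled_candidates: flat list of num_sampled class ids.
    """
    hits = []
    for row, targets in enumerate(true_classes):
        target_set = set(targets)
        for idx, cand in enumerate(sampled_candidates):
            # A hit: this sampled class is one of the row's true classes.
            if cand in target_set:
                hits.append((row, idx))
    return hits


true_classes = [[1], [3], [0]]
sampled = [3, 2, 1]
print(find_accidental_hits(true_classes, sampled))  # [(0, 2), (1, 0)]
```

Row 0 (true class 1) collides with the candidate at index 2, and row 1 (true class 3) collides with the candidate at index 0; row 2 has no hit.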

Inputs:
  • true_classes (Tensor) - The target classes, with shape \((batch\_size, num\_true)\).

Outputs:
  • sampled_candidates (Tensor) - The sampled classes, which are independent of the true classes. Shape: \((num\_sampled, )\).

  • true_expected_count (Tensor) - The expected counts under the sampling distribution of each of true_classes. Shape: \((batch\_size, num\_true)\).

  • sampled_expected_count (Tensor) - The expected counts under the sampling distribution of each of sampled_candidates. Shape: \((num\_sampled, )\).
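Under a uniform distribution every class has probability 1/range_max per draw, so with replacement each class is expected to appear num_sampled / range_max times; a uniformly random subset of size num_sampled also contains any given class with probability num_sampled / range_max. The sketch below (the name `uniform_expected_counts` is hypothetical, and plain Python lists stand in for tensors) builds both expected-count outputs with the documented shapes from this exact value. It is an assumption that the operator reports exactly these numbers; an implementation that accounts for rejection-sampling retries when unique=True may return slightly different values.

```python
def uniform_expected_counts(true_classes, sampled_candidates, range_max):
    """Hypothetical sketch of the two expected-count outputs for uniform sampling."""
    num_sampled = len(sampled_candidates)
    # Identical for every class because the distribution is uniform.
    expected = num_sampled / range_max
    # Shape (batch_size, num_true), aligned element-wise with true_classes.
    true_expected_count = [[expected for _ in row] for row in true_classes]
    # Shape (num_sampled,), aligned element-wise with sampled_candidates.
    sampled_expected_count = [expected for _ in sampled_candidates]
    return true_expected_count, sampled_expected_count


t, s = uniform_expected_counts([[1], [3], [2], [0], [3]], [0, 2, 3], range_max=4)
print(len(t), len(t[0]), len(s))  # 5 1 3
print(s)  # [0.75, 0.75, 0.75]
```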

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
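>>> sampler = ops.UniformCandidateSampler(num_true=1, num_sampled=3, unique=False, range_max=4, seed=1)
>>> # True class ids lie in [0, range_max - 1]; shape is (batch_size=5, num_true=1).
>>> true_classes = Tensor(np.array([[1], [3], [2], [0], [3]], dtype=np.int64))
>>> output1, output2, output3 = sampler(true_classes)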
>>> print(output1.shape)
(3,)
>>> print(output2.shape)
(5, 1)
>>> print(output3.shape)
(3,)