"""Operators for random."""

from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register

class RandomChoiceWithMask(PrimitiveWithInfer):
    """
    Generates a random sample as an index tensor with a mask tensor from a given tensor.

    The input must be a tensor of rank >= 1. If its rank is >= 2, the first dimension
    specifies the number of samples. The index tensor and the mask tensor share the
    same, fixed first dimension (`count`). The index tensor denotes the indices of
    the nonzero samples, while the mask tensor denotes which elements in the index
    tensor are valid.

    Args:
        count (int): Number of items expected to get. Must be greater than 0. Default: 256.
        seed (int): Random seed. Default: 0.
        seed2 (int): Random seed2. Default: 0.

    Inputs:
        - **input_x** (Tensor) - The input tensor. Its dtype must be bool and its rank
          must be at least 1.

    Outputs:
        Tuple of two tensors:

        - **index** (Tensor) - The index tensor, of shape (count, rank of `input_x`)
          and dtype int32.
        - **mask** (Tensor) - The mask tensor, of shape (count,) and dtype bool.

    Examples:
        >>> rnd_choice_mask = RandomChoiceWithMask()
        >>> input_x = Tensor(np.ones(shape=[240000, 4]), ms.bool_)
        >>> output_y, output_mask = rnd_choice_mask(input_x)
    """

    @prim_attr_register
    def __init__(self, count=256, seed=0, seed2=0):
        """Init RandomChoiceWithMask"""
        # Arguments are only validated here. `infer_shape` reads `self.count` without
        # it being assigned in this method, so `prim_attr_register` is presumably what
        # stores the constructor arguments as attributes -- NOTE(review): confirm.
        validator.check_type("count", count, [int])
        validator.check_integer("count", count, 0, Rel.GT)
        validator.check_type('seed', seed, [int])
        validator.check_type('seed2', seed2, [int])

    def infer_shape(self, x_shape):
        """Return the (index, mask) output shapes for an input of shape `x_shape`.

        The index tensor holds `count` rows, each with one coordinate per input
        dimension; the mask tensor holds one flag per row.
        """
        # Only rank >= 1 is enforced here (not rank >= 2).
        validator.check_shape_length("input_x shape", len(x_shape), 1, Rel.GE)
        return ([self.count, len(x_shape)], [self.count])

    def infer_dtype(self, x_dtype):
        """Require a bool input and return the (int32 index, bool mask) output dtypes."""
        validator.check_typename('x_dtype', x_dtype, [mstype.bool_])
        return (mstype.int32, mstype.bool_)