mindspore.dataset.transforms.Mask

class mindspore.dataset.transforms.Mask(operator, constant, dtype=mstype.bool_)

Mask the content of the input tensor with the given predicate. Each element of the tensor that satisfies the predicate evaluates to True, and every other element evaluates to False. The output has the same shape as the input, with data type dtype.
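To illustrate the element-wise semantics, the following pure-Python sketch applies a predicate to every scalar of a nested list, which stands in for a tensor. It is not MindSpore's implementation, and the helpers `elementwise_mask` and `is_two` are hypothetical. The result keeps the nesting (the shape) of the input and holds True only where the predicate is satisfied.

```python
import operator


def elementwise_mask(values, predicate):
    """Hypothetical helper: apply `predicate` to every scalar in a
    (possibly nested) list, preserving the nesting, i.e. the shape."""
    if isinstance(values, list):
        return [elementwise_mask(v, predicate) for v in values]
    return predicate(values)


def is_two(x):
    """Predicate 'element == 2', the comparison Mask(Relational.EQ, 2) performs."""
    return operator.eq(x, 2)


print(elementwise_mask([1, 2, 3], is_two))         # [False, True, False]
print(elementwise_mask([[1, 2], [2, 3]], is_two))  # [[False, True], [True, False]]
```

The first call matches the pipeline example on this page, where `[1, 2, 3]` masked with `Relational.EQ` and `2` yields a boolean tensor of shape `(3,)`.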

Parameters
  • operator (Relational) – Relational operator used for the comparison. It can be Relational.EQ (equal), Relational.NE (not equal), Relational.LT (less than), Relational.GT (greater than), Relational.LE (less than or equal) or Relational.GE (greater than or equal).

  • constant (Union[str, int, float, bool]) – Constant that each element of the tensor is compared against.

  • dtype (mindspore.dtype, optional) – Type of the generated mask. Default: mstype.bool_.
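The sketch below shows how the three parameters might combine. It assumes each element is compared as `element <op> constant` and that the boolean result is then converted to `dtype`. The mapping `RELATIONAL_OPS`, the function `mask_reference`, and the use of Python types (`bool`, `int`, `float`) in place of `mindspore.dtype` values are all hypothetical stand-ins, not part of the MindSpore API.

```python
import operator

# Hypothetical stand-in: maps each Relational member name to the Python
# comparison it is assumed to perform as `element <op> constant`.
RELATIONAL_OPS = {
    "EQ": operator.eq,  # equal
    "NE": operator.ne,  # not equal
    "LT": operator.lt,  # less than
    "GT": operator.gt,  # greater than
    "LE": operator.le,  # less than or equal
    "GE": operator.ge,  # greater than or equal
}


def mask_reference(values, op_name, constant, dtype=bool):
    """Hypothetical reference for Mask on a flat list.

    Each element is compared as `element <op> constant`; the boolean result
    is converted with `dtype`, a Python type standing in for the
    mindspore.dtype argument (e.g. bool, int or float).
    """
    compare = RELATIONAL_OPS[op_name]
    return [dtype(compare(v, constant)) for v in values]


print(mask_reference([1, 2, 3], "LT", 2))             # [True, False, False]
print(mask_reference([1, 2, 3], "GE", 2))             # [False, True, True]
print(mask_reference([1, 2, 3], "EQ", 2, dtype=int))  # [0, 1, 0]
print(mask_reference(["a", "b"], "EQ", "b"))          # [False, True]
```

Passing a non-boolean `dtype` keeps the same True/False pattern but encodes it as numbers, which can be convenient when the mask is later multiplied with other numeric data.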

Raises
  • TypeError – If operator is not of type Relational.

  • TypeError – If constant is not of type str, int, float or bool.

  • TypeError – If dtype is not of type mindspore.dtype.

Supported Platforms:

CPU

Examples

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms as transforms
>>> from mindspore.dataset.transforms import Relational
>>>
>>> # Use the transform in dataset pipeline mode
>>> # Data before
>>> # |   col   |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------+
>>> data = [[1, 2, 3]]
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["col"])
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms.Mask(Relational.EQ, 2))
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col"].shape, item["col"].dtype)
(3,) bool
>>> # Data after
>>> # |        col         |
>>> # +--------------------+
>>> # | [False,True,False] |
>>> # +--------------------+
>>>
>>> # Use the transform in eager mode
>>> data = [1, 2, 3]
>>> output = transforms.Mask(Relational.EQ, 2)(data)
>>> print(output.shape, output.dtype)
(3,) bool