mindspore.dataset.vision.RandAugment

class mindspore.dataset.vision.RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31, interpolation=Inter.NEAREST, fill_value=0)

Apply the RandAugment data augmentation method to the input image.

Refer to RandAugment: Practical Automated Data Augmentation with a Reduced Search Space.

Only 3-channel RGB images are supported.

Parameters
  • num_ops (int, optional) – Number of augmentation transformations to apply sequentially. Default: 2.

  • magnitude (int, optional) – Magnitude shared by all the transformations; must be smaller than num_magnitude_bins. Default: 9.

  • num_magnitude_bins (int, optional) – The number of different magnitude values; must be at least 2. Default: 31.

  • interpolation (Inter, optional) – Image interpolation method defined by Inter. Default: Inter.NEAREST.

  • fill_value (Union[int, tuple[int, int, int]], optional) – Pixel fill value for the area outside the transformed image; must be in the range [0, 255]. Default: 0. If an int is provided, all RGB channels are padded with this value. If a tuple[int, int, int] is provided, the R, G and B channels are padded with the respective values.
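The way num_ops, magnitude and num_magnitude_bins interact can be illustrated with a small pure-NumPy sketch. Each of the num_ops steps picks one operation uniformly at random. The magnitude value then indexes a per-operation table of num_magnitude_bins strengths, which is why it must be smaller than num_magnitude_bins. This is not MindSpore's implementation. The four photometric operations, their strength ranges and the random sign flip for brightness are assumptions modeled on torchvision's RandAugment. The helper names (`rand_augment_sketch`, `magnitude_tables` and the underscore-prefixed ops) are hypothetical. Geometric operations, where interpolation and fill_value would matter, are left out.

```python
import numpy as np

# Hypothetical, simplified RandAugment policy in pure NumPy. Only four
# photometric operations are included, and the strength tables assume a
# torchvision-style linear binning; MindSpore's real operation list and
# ranges may differ.

def _identity(img, s):
    # Leave the image unchanged; the strength is ignored.
    return img

def _brightness(img, s):
    # Scale intensities by (1 + s), then clip back into the uint8 range.
    return np.clip(img.astype(np.float32) * (1.0 + s), 0, 255).astype(np.uint8)

def _solarize(img, s):
    # Invert every pixel whose value is at or above the threshold s.
    return np.where(img >= s, 255 - img, img).astype(np.uint8)

def _posterize(img, s):
    # Keep only the s most significant bits of each channel value.
    shift = 8 - int(s)
    return ((img >> shift) << shift).astype(np.uint8)

def magnitude_tables(num_magnitude_bins):
    # For every operation: (function, one strength per bin, signed?).
    # `magnitude` indexes into these arrays, so it must be < num_magnitude_bins.
    bins = np.arange(num_magnitude_bins)
    return {
        "identity": (_identity, np.zeros(num_magnitude_bins), False),
        "brightness": (_brightness, np.linspace(0.0, 0.9, num_magnitude_bins), True),
        "solarize": (_solarize, np.linspace(255.0, 0.0, num_magnitude_bins), False),
        "posterize": (_posterize,
                      8 - np.round(bins / ((num_magnitude_bins - 1) / 4)).astype(int),
                      False),
    }

def rand_augment_sketch(img, num_ops=2, magnitude=9, num_magnitude_bins=31, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    table = magnitude_tables(num_magnitude_bins)
    names = list(table)
    for _ in range(num_ops):
        # Pick one operation uniformly at random (with replacement).
        op, strengths, signed = table[names[rng.integers(len(names))]]
        s = strengths[magnitude]
        # Signed operations are randomly mirrored (darken vs. brighten).
        if signed and rng.random() < 0.5:
            s = -s
        # Operations are applied sequentially, each on the previous output.
        img = op(img, s)
    return img

image = np.random.default_rng(0).integers(0, 256, size=(100, 100, 3), dtype=np.uint8)
out = rand_augment_sketch(image, rng=np.random.default_rng(42))
print(out.shape, out.dtype)  # (100, 100, 3) uint8
```

Because every operation shares the same magnitude index, the whole policy is controlled by just two integers. That small search space is the central idea of RandAugment.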

Raises
  • TypeError – If num_ops is not of type int.

  • ValueError – If num_ops is negative.

  • TypeError – If magnitude is not of type int.

  • ValueError – If magnitude is not positive.

  • TypeError – If num_magnitude_bins is not of type int.

  • ValueError – If num_magnitude_bins is less than 2.

  • TypeError – If interpolation is not of type Inter.

  • TypeError – If fill_value is not of type int or tuple[int, int, int].

  • ValueError – If fill_value is not in the range [0, 255].

  • RuntimeError – If the shape of the input image is not <H, W, C>.
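The checks in the Raises list can be mirrored by a small pure-Python helper, which makes it easy to see which arguments fail. `check_rand_augment_args` and `check_image` are hypothetical names, not MindSpore functions. The sketch makes several assumptions:

- It skips the interpolation type check.
- It treats magnitude as valid when 0 <= magnitude < num_magnitude_bins. This combines the parameter description with the Raises entry, which only says "not positive".
- It normalizes an int fill_value to an (R, G, B) tuple, as the fill_value description specifies.

```python
import numpy as np

# Hypothetical validator mirroring the documented Raises list. It is not
# MindSpore's validator: the `interpolation` check is omitted, and bool is
# accepted wherever int is, because bool subclasses int in Python.

def check_rand_augment_args(num_ops=2, magnitude=9, num_magnitude_bins=31, fill_value=0):
    if not isinstance(num_ops, int):
        raise TypeError("num_ops must be of type int")
    if num_ops < 0:
        raise ValueError("num_ops must not be negative")
    if not isinstance(magnitude, int):
        raise TypeError("magnitude must be of type int")
    if not isinstance(num_magnitude_bins, int):
        raise TypeError("num_magnitude_bins must be of type int")
    if num_magnitude_bins < 2:
        raise ValueError("num_magnitude_bins must be at least 2")
    # Assumption: valid magnitudes satisfy 0 <= magnitude < num_magnitude_bins.
    if not 0 <= magnitude < num_magnitude_bins:
        raise ValueError("magnitude must be in [0, num_magnitude_bins)")
    # An int fill value is applied to all three RGB channels.
    if isinstance(fill_value, int):
        fill_value = (fill_value,) * 3
    if not (isinstance(fill_value, tuple) and len(fill_value) == 3
            and all(isinstance(v, int) for v in fill_value)):
        raise TypeError("fill_value must be int or tuple[int, int, int]")
    if not all(0 <= v <= 255 for v in fill_value):
        raise ValueError("fill_value must be in the range [0, 255]")
    # Return the fill value normalized to one entry per R, G, B channel.
    return fill_value

def check_image(img):
    # The input must be an <H, W, C> array with exactly 3 (RGB) channels.
    if img.ndim != 3 or img.shape[2] != 3:
        raise RuntimeError("input image must have shape <H, W, 3>")

print(check_rand_augment_args(fill_value=255))  # (255, 255, 255)
```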

Supported Platforms:

CPU

Examples

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> from mindspore.dataset.vision import Inter
>>>
>>> # Use the transform in dataset pipeline mode
>>> data = np.random.randint(0, 255, size=(1, 100, 100, 3)).astype(np.uint8)
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["image"])
>>> transforms_list = [vision.RandAugment()]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms_list, input_columns=["image"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["image"].shape, item["image"].dtype)
...     break
(100, 100, 3) uint8
>>>
>>> # Use the transform in eager mode
>>> data = np.random.randint(0, 255, size=(100, 100, 3)).astype(np.uint8)
>>> output = vision.RandAugment(interpolation=Inter.BILINEAR, fill_value=255)(data)
>>> print(output.shape, output.dtype)
(100, 100, 3) uint8