mindspore.dataset.vision.AutoAugment
- class mindspore.dataset.vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET, interpolation=Inter.NEAREST, fill_value=0)
Apply the AutoAugment data augmentation method, based on the paper "AutoAugment: Learning Augmentation Strategies from Data". This operation works only with 3-channel RGB images.
- Parameters
policy (AutoAugmentPolicy, optional) – AutoAugment policy learned on a specific dataset. Default: AutoAugmentPolicy.IMAGENET. It can be AutoAugmentPolicy.IMAGENET, AutoAugmentPolicy.CIFAR10 or AutoAugmentPolicy.SVHN. Each policy randomly applies 2 operations from a candidate set. See AutoAugmentPolicy for details.
  - AutoAugmentPolicy.IMAGENET: apply the AutoAugment policy learned on the ImageNet dataset.
  - AutoAugmentPolicy.CIFAR10: apply the AutoAugment policy learned on the CIFAR-10 dataset.
  - AutoAugmentPolicy.SVHN: apply the AutoAugment policy learned on the SVHN dataset.
interpolation (Inter, optional) – Image interpolation method defined by Inter. Default: Inter.NEAREST.
fill_value (Union[int, tuple[int]], optional) – Pixel fill value for the area outside the transformed image. It can be an int or a 3-tuple. If it is a 3-tuple, its elements fill the R, G and B channels respectively. If it is an int, it is used for all RGB channels. Every value must be in the range [0, 255]. Default: 0.
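The phrase "randomly apply 2 operations from a candidate set" can be illustrated with a small pure-Python sketch. It follows the general scheme described in the AutoAugment paper (pick one sub-policy at random, then apply each of its two operations with its own probability) and is not MindSpore's implementation. The names `TOY_POLICY`, `apply_op` and `auto_augment_sketch`, the two toy operations, and all probabilities and magnitudes are hypothetical and chosen only for illustration.

```python
import random

# Hypothetical toy policy, NOT the values used by MindSpore.
# Each sub-policy holds two steps of (operation name, probability, magnitude).
TOY_POLICY = [
    (("invert", 1.0, None), ("add", 1.0, 10)),
    (("add", 0.5, 20), ("invert", 0.2, None)),
]


def apply_op(pixels, name, magnitude):
    """Apply one toy operation to a flat list of 8-bit pixel values."""
    if name == "invert":
        return [255 - p for p in pixels]
    if name == "add":
        # Brighten and clamp to the valid 8-bit range.
        return [min(255, p + magnitude) for p in pixels]
    raise ValueError(f"unknown operation: {name}")


def auto_augment_sketch(pixels, policy, rng):
    """Pick one sub-policy at random and apply its two operations in order."""
    sub_policy = rng.choice(policy)
    for name, probability, magnitude in sub_policy:
        # Each operation is applied only with its own probability.
        if rng.random() < probability:
            pixels = apply_op(pixels, name, magnitude)
    return pixels


rng = random.Random(0)  # seeded so repeated runs give the same choices
augmented = auto_augment_sketch([0, 100, 255], TOY_POLICY, rng)
print(augmented)
```

Because the sub-policy and each operation are sampled per call, the same input can produce different outputs across calls, which is why the documented examples check only the output shape and dtype.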
- Raises
TypeError – If policy is not of type mindspore.dataset.vision.AutoAugmentPolicy.
TypeError – If fill_value is not an integer or a tuple of length 3.
RuntimeError – If given tensor shape is not <H, W, C>.
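The fill_value rules above (an int applies to all channels, a 3-tuple maps to R, G, B, and every value lies in [0, 255]) can be sketched as a standalone check. This is a minimal sketch under stated assumptions, not the library's own validation code: `normalize_fill_value` is a hypothetical helper, and raising ValueError for out-of-range values is an assumption, since the Raises list above does not name an exception for that case.

```python
def normalize_fill_value(fill_value):
    """Hypothetical helper: expand fill_value to an (R, G, B) tuple."""
    if isinstance(fill_value, int):
        # A single integer is used for all three RGB channels.
        channels = (fill_value, fill_value, fill_value)
    elif isinstance(fill_value, tuple):
        if len(fill_value) != 3 or not all(isinstance(v, int) for v in fill_value):
            raise TypeError("fill_value must be an int or a tuple of 3 ints")
        channels = fill_value
    else:
        raise TypeError("fill_value must be an int or a tuple of 3 ints")

    # Assumption: out-of-range values are reported as ValueError.
    if any(v < 0 or v > 255 for v in channels):
        raise ValueError("fill_value values must be in range [0, 255]")
    return channels


print(normalize_fill_value(0))               # same value for R, G and B
print(normalize_fill_value((255, 0, 128)))   # separate value per channel
```

Running a check like this before constructing the transform can surface a bad fill_value early, with a message that names the offending argument.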
- Supported Platforms:
CPU
Examples
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> from mindspore.dataset.vision import AutoAugmentPolicy, Inter
>>>
>>> # Use the transform in dataset pipeline mode
>>> transforms_list = [vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET,
...                                       interpolation=Inter.NEAREST,
...                                       fill_value=0)]
>>> data = np.random.randint(0, 255, size=(1, 100, 100, 3)).astype(np.uint8)
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data, ["image"])
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms_list, input_columns=["image"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["image"].shape, item["image"].dtype)
...     break
(100, 100, 3) uint8
>>>
>>> # Use the transform in eager mode
>>> data = np.random.randint(0, 255, size=(100, 100, 3)).astype(np.uint8)
>>> output = vision.AutoAugment()(data)
>>> print(output.shape, output.dtype)
(100, 100, 3) uint8
- Tutorial Examples: