mindspore.dataset.vision.py_transforms.LinearTransformation

class mindspore.dataset.vision.py_transforms.LinearTransformation(transformation_matrix, mean_vector)

Transform the input numpy.ndarray image with a given square transformation matrix and a mean vector. The operation first flattens the input image and subtracts the mean vector from it, then computes the dot product of the result with the transformation matrix, and finally reshapes the result back to the original image shape.
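The four steps can be sketched in plain NumPy as follows. This is an illustrative sketch of the description above, not MindSpore's source: the helper name `linear_transform_sketch` is hypothetical, and it assumes the flattened, centered image is treated as a row vector multiplied on the left of the matrix, i.e. (x - mean) · M. For a non-symmetric matrix the multiplication order matters.

```python
import numpy as np


def linear_transform_sketch(img, transformation_matrix, mean_vector):
    """Hypothetical helper: flatten, subtract mean, multiply, reshape."""
    flat = img.reshape(-1)                                  # (D,), D = C * H * W
    centered = flat - mean_vector                           # subtract the mean vector
    transformed = np.dot(centered, transformation_matrix)   # (D,) . (D, D) -> (D,)
    return transformed.reshape(img.shape)                   # restore the original shape


# Tiny example: a 1 x 2 x 2 "image", so D = 4.
img = np.arange(4, dtype=np.float64).reshape(1, 2, 2)  # pixels 0, 1, 2, 3
mean = np.full(4, 1.5)                                  # center around 1.5
matrix = 2.0 * np.eye(4)                                # scale each centered pixel by 2
out = linear_transform_sketch(img, matrix, mean)        # pixels -3, -1, 1, 3
```

With a diagonal matrix each output pixel depends only on its own input pixel; a dense matrix (such as a ZCA whitening matrix) mixes all D values.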

Parameters
  • transformation_matrix (numpy.ndarray) – A square transformation matrix in shape of (D, D), where \(D = C \times H \times W\).

  • mean_vector (numpy.ndarray) – A mean vector in shape of (D,), where \(D = C \times H \times W\).
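Both parameters are sized from the flattened image, so they must be built for the exact (C, H, W) shape the transform will see. The sketch below uses a hypothetical helper, `shapes_match`, to check this, and also measures how much memory the (D, D) matrix needs. The image shape (3, 32, 32) is an assumption matching the Examples section below (RGB, resized to 32 x 32, then converted to CHW).

```python
import numpy as np


def shapes_match(image_shape, transformation_matrix, mean_vector):
    """Hypothetical helper: True if both arguments fit an image of this shape."""
    d = int(np.prod(image_shape))  # D = C * H * W
    return transformation_matrix.shape == (d, d) and mean_vector.shape == (d,)


image_shape = (3, 32, 32)       # assumed CHW shape after Resize((32, 32)) + ToTensor()
dim = 3 * 32 * 32               # D = 3072
matrix = np.eye(dim)            # float64, shape (3072, 3072)
mean = np.zeros(dim)            # shape (3072,)

ok = shapes_match(image_shape, matrix, mean)            # sizes agree
bad = shapes_match((3, 64, 64), matrix, mean)           # D would be 12288, not 3072
matrix_mib = matrix.nbytes / 2**20                      # memory held by the matrix
```

Because the matrix has D² entries, even a 32 x 32 RGB image already requires 72 MiB for a float64 matrix, and doubling the image height and width multiplies that by 16.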

Raises
  • TypeError – If transformation_matrix is not of type numpy.ndarray.

  • TypeError – If mean_vector is not of type numpy.ndarray.
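The type checks listed above can be mimicked with `isinstance`. This is a hypothetical sketch: the helper names and error messages are illustrative and are not MindSpore's validator. It shows that a nested Python list is rejected even when its values are valid, so such input should be converted with `np.asarray` first.

```python
import numpy as np


def validate_linear_transformation_args(transformation_matrix, mean_vector):
    """Hypothetical validator: raise TypeError for non-ndarray arguments."""
    if not isinstance(transformation_matrix, np.ndarray):
        raise TypeError("transformation_matrix should be a numpy.ndarray, got "
                        + type(transformation_matrix).__name__)
    if not isinstance(mean_vector, np.ndarray):
        raise TypeError("mean_vector should be a numpy.ndarray, got "
                        + type(mean_vector).__name__)


def raises_type_error(matrix, mean):
    """Return True if the hypothetical validator rejects the arguments."""
    try:
        validate_linear_transformation_args(matrix, mean)
    except TypeError:
        return True
    return False


list_rejected = raises_type_error([[1.0]], np.zeros(1))          # plain list
array_accepted = not raises_type_error(np.eye(1), np.zeros(1))   # ndarrays
converted_accepted = not raises_type_error(np.asarray([[1.0]]), np.zeros(1))
```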

Supported Platforms:

CPU

Examples

>>> import numpy as np
>>> import mindspore.dataset.vision.py_transforms as py_vision
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> height, width = 32, 32
>>> # D must match the CHW output of ToTensor: 3 channels x height x width
>>> dim = 3 * height * width
>>> transformation_matrix = np.ones([dim, dim])
>>> mean_vector = np.zeros(dim)
>>> transforms_list = Compose([py_vision.Decode(),
...                            py_vision.Resize((height, width)),
...                            py_vision.ToTensor(),
...                            py_vision.LinearTransformation(transformation_matrix, mean_vector)])
>>> # apply the transform to the dataset through the map function
>>> # (image_folder_dataset is assumed to be an existing dataset with an "image" column)
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
...                                                 input_columns="image")