mindscience.common.WaveletTransformLoss

class mindscience.common.WaveletTransformLoss(wave_level=2, regroup=False)

Computes the multi-level wavelet transformation loss between two tensors.

Parameters
  • wave_level (int, optional) – Number of wavelet transformation levels. Must be a positive integer. Default: 2.

  • regroup (bool, optional) – Whether to use the regrouped error-combination form of the wavelet transformation losses. Default: False.
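This page does not state which wavelet the loss uses or how the per-subband errors are combined. As a rough illustration of what `wave_level` controls, the sketch below assumes an orthonormal 2-D Haar transform (a swapped-in choice, not necessarily what `WaveletTransformLoss` uses internally): each level splits the current approximation into one coarser approximation and three detail subbands, so `wave_level` levels yield `3 * wave_level + 1` subbands. The helpers `haar_dwt2` and `wavelet_decompose` are hypothetical and are not part of `mindscience`.

```python
import numpy as np


def haar_dwt2(x):
    """One level of an orthonormal 2-D Haar transform over the last two axes.

    Returns the approximation subband and a tuple of three detail subbands,
    each with half the height and width of ``x``.
    """
    a = x[..., 0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[..., 0::2, 1::2]  # top-right
    c = x[..., 1::2, 0::2]  # bottom-left
    d = x[..., 1::2, 1::2]  # bottom-right
    approx = (a + b + c + d) / 2.0
    detail_1 = (a - b + c - d) / 2.0  # left-right differences
    detail_2 = (a + b - c - d) / 2.0  # top-bottom differences
    detail_3 = (a - b - c + d) / 2.0  # diagonal differences
    return approx, (detail_1, detail_2, detail_3)


def wavelet_decompose(x, wave_level=2):
    """Hypothetical helper: apply haar_dwt2 ``wave_level`` times to the approximation.

    Returns [approximation, level-1 details..., level-2 details..., ...],
    i.e. 3 * wave_level + 1 arrays in total.
    """
    if not isinstance(wave_level, int) or wave_level < 1:
        raise ValueError("wave_level must be a positive integer")
    approx, details = x, []
    for _ in range(wave_level):
        if approx.shape[-2] % 2 or approx.shape[-1] % 2:
            raise ValueError("height and width must be even at every level")
        approx, bands = haar_dwt2(approx)
        details.extend(bands)
    return [approx] + details


x = np.random.default_rng(0).standard_normal((2, 3, 32, 32))
bands = wavelet_decompose(x, wave_level=2)
print(len(bands))      # 7 == 3 * 2 + 1
print(bands[0].shape)  # (2, 3, 8, 8)
```

Because the Haar transform in this sketch is orthonormal, the total squared magnitude of all subbands equals that of the input, so a loss built on the subbands loses no information; how the real class weights or regroups the subbands is not covered here.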

Inputs:
  • input (tuple(Tensor, Tensor)) - A tuple of two Tensors, each of shape \((B*H*W/(P*P), P*P*C)\), where B is the batch size, H and W are the height and width of the image, P is the patch size, and C is the number of feature channels.
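To show where a shape of the form \((B*H*W/(P*P), P*P*C)\) comes from, the sketch below flattens a batch of images into non-overlapping P*P patches, one row per patch. It assumes channels-last input of shape (B, H, W, C) and row-major patch ordering; both are assumptions, and `patchify` is a hypothetical helper rather than a `mindscience` function.

```python
import numpy as np


def patchify(images, patch_size):
    """Hypothetical helper: (B, H, W, C) -> (B*H*W/(P*P), P*P*C).

    Assumes channels-last input and row-major patch order.
    """
    b, h, w, c = images.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size P")
    # Split H and W into (patches along that axis, P) pairs.
    x = images.reshape(b, h // p, p, w // p, p, c)
    # Move both patch-grid axes next to the batch axis: (B, H/P, W/P, P, P, C).
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # One row per patch; each row holds the P*P*C values of that patch.
    return x.reshape(b * (h // p) * (w // p), p * p * c)


images = np.arange(2 * 8 * 12 * 3, dtype=np.float32).reshape(2, 8, 12, 3)
tokens = patchify(images, patch_size=4)
print(tokens.shape)  # (12, 48): 2*8*12/(4*4) = 12 patches, 4*4*3 = 48 values each
```

Each row of `tokens` is one P*P patch with its C channels interleaved, which is why the second dimension is P*P*C.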

Outputs:
  • output (Tensor) - The multi-level wavelet transformation loss.

Examples

>>> import numpy as np
>>> from mindscience.common import WaveletTransformLoss
>>> import mindspore
>>> from mindspore import Tensor
>>> net = WaveletTransformLoss(wave_level=2)
>>> input1 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> input2 = Tensor(np.ones((32, 288, 768)), mindspore.float32)
>>> output = net((input1, input2))
>>> print(output)
2.0794415