mindscience.common.WaveletTransformLoss
- class mindscience.common.WaveletTransformLoss(wave_level=2, regroup=False)
Computes the multi-level wavelet transformation loss between the two input tensors.
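To illustrate what a multi-level wavelet loss computes, here is a minimal NumPy sketch. It assumes an orthonormal Haar wavelet and a mean-absolute-error comparison of every sub-band; the helpers `haar_dwt2` and `multilevel_wavelet_loss` are hypothetical and are not the MindScience implementation. The sketch works on spatial arrays of shape (N, H, W) rather than the patch-flattened input this class takes, and it does not model `regroup`. Its `wave_level` argument only mirrors the class parameter of the same name.

```python
import numpy as np


def haar_dwt2(x):
    """One level of the orthonormal 2-D Haar transform over the last two axes.

    Returns the approximation band and a tuple of three detail bands.
    The last two dimensions (height, width) must both be even.
    """
    a = x[..., 0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[..., 0::2, 1::2]  # top-right pixel
    c = x[..., 1::2, 0::2]  # bottom-left pixel
    d = x[..., 1::2, 1::2]  # bottom-right pixel
    approx = (a + b + c + d) / 2.0      # local average (scaled)
    rows_diff = (a + b - c - d) / 2.0   # top row minus bottom row
    cols_diff = (a - b + c - d) / 2.0   # left column minus right column
    diagonal = (a - b - c + d) / 2.0    # diagonal difference
    return approx, (rows_diff, cols_diff, diagonal)


def multilevel_wavelet_loss(pred, label, wave_level=2):
    """Hypothetical multi-level wavelet loss: mean absolute error per sub-band.

    pred and label have shape (N, H, W); H and W must be divisible by
    2 ** wave_level. The approximation band is decomposed again at each level.
    """
    if pred.shape != label.shape:
        raise ValueError("pred and label must have the same shape")
    total = 0.0
    p, t = pred, label
    for _ in range(wave_level):
        p, p_details = haar_dwt2(p)
        t, t_details = haar_dwt2(t)
        # Compare every detail band produced at this level.
        for p_band, t_band in zip(p_details, t_details):
            total += np.mean(np.abs(p_band - t_band))
    # Compare the coarsest approximation band left after the last level.
    total += np.mean(np.abs(p - t))
    return float(total)


if __name__ == "__main__":
    print(multilevel_wavelet_loss(np.ones((1, 4, 4)), np.zeros((1, 4, 4))))  # 4.0
```

For a 4x4 image of ones compared with zeros, all detail bands vanish and only the final 1x1 approximation band (value 4.0) contributes, so the sketch returns 4.0. Because the Haar transform used here is orthonormal, each level preserves the total energy of its input, which keeps per-level errors on a comparable scale.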
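- Parameters:
wave_level (int) - The number of wavelet transformation levels. Default: 2.
regroup (bool) - Whether to use the regrouped error-combination form of the multi-level wavelet transformation losses. Default: False.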
- Inputs:
input (tuple(Tensor, Tensor)) - A tuple of two Tensors, each of shape \((B*H*W/(P*P), P*P*C)\), where B denotes the batch size, H and W denote the height and width of the image, P denotes the patch size, and C denotes the number of feature channels.
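The flattened input layout can be illustrated with a NumPy sketch that turns a batch of images of shape (B, H, W, C) into rows of shape (B*H*W/(P*P), P*P*C). The channel-last image layout, the row-major order of patches, and the (P, P, C) ordering inside each row are assumptions for illustration; `patchify` and `unpatchify` are hypothetical helpers, not part of mindscience.

```python
import numpy as np


def patchify(images, patch_size):
    """Flatten (B, H, W, C) images into (B*H*W/(P*P), P*P*C) patch rows."""
    b, h, w, c = images.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    # Split H and W into (patch index, offset inside patch).
    x = images.reshape(b, h // p, p, w // p, p, c)
    # Reorder to (B, H/P, W/P, P, P, C) so each patch is contiguous.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b * (h // p) * (w // p), p * p * c)


def unpatchify(patches, batch, height, width, patch_size):
    """Inverse of patchify: rebuild (B, H, W, C) images from patch rows."""
    p = patch_size
    c = patches.shape[1] // (p * p)
    x = patches.reshape(batch, height // p, width // p, p, p, c)
    # Swap the patch-column axis and the in-patch row axis back.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, height, width, c)


if __name__ == "__main__":
    imgs = np.zeros((2, 32, 64, 3), dtype=np.float32)
    print(patchify(imgs, 16).shape)  # (16, 768)
```

With B=2, H=32, W=64, C=3 and P=16 there are 2*2*4 = 16 patches, each flattened to 16*16*3 = 768 values.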
- Outputs:
output (Tensor) - The multi-level wavelet transformation loss.
Examples
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindscience.common import WaveletTransformLoss
>>> net = WaveletTransformLoss(wave_level=2)
>>> input1 = Tensor(np.ones((32, 288, 768)), mstype.float32)
>>> input2 = Tensor(np.ones((32, 288, 768)), mstype.float32)
>>> output = net((input1, input2))
>>> print(output)
2.0794415