mindscience.common.WaveletTransformLoss
- class mindscience.common.WaveletTransformLoss(wave_level=2, regroup=False)[源代码]
多层小波变换损失函数。
- 参数:
wave_level (int, 可选) - 小波变换级数,应为正整数。默认值:
2。regroup (bool, 可选) - 小波变换损失中误差的重新组合方式。默认值:
False。
- 输入:
input (tuple(Tensor, Tensor)) - Tensor 构成的元组。Tensor 的形状为 \((B*H*W/(P*P), P*P*C)\),其中 B 表示批次大小,H、W 分别表示图像的高度和宽度,P 表示 patch 大小,C 表示特征通道。
- 输出:
output (Tensor) - 小波变换损失函数输出。
样例:
>>> import numpy as np >>> from mindscience.common import WaveletTransformLoss >>> import mindspore >>> from mindspore import Tensor >>> net = WaveletTransformLoss(wave_level=2) >>> input1 = Tensor(np.ones((32, 288, 768)), mstype.float32) >>> input2 = Tensor(np.ones((32, 288, 768)), mstype.float32) >>> output = net((input1, input2)) >>> print(output) 2.0794415