mindflow.pde.UnsteadyFlowWithLoss

查看源文件
class mindflow.pde.UnsteadyFlowWithLoss(model, t_in=1, t_out=1, loss_fn='mse', data_format='NTCHW')[源代码]

基于数据驱动的非定常流体问题求解的基类。

参数:
  • model (mindspore.nn.Cell) - 用于训练的网络模型。

  • t_in (int) - 初始步长。默认值: 1

  • t_out (int) - 输出步长。 默认值: 1

  • loss_fn (Union[str, Cell]) - 损失函数。默认值: 'mse'

  • data_format (str) - 数据格式。默认值: 'NTCHW'

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore
>>> from mindflow.pde import UnsteadyFlowWithLoss
>>> from mindflow.cell import FNO2D
>>> from mindflow.core import RelativeRMSELoss
...
>>> model = FNO2D(in_channels=1, out_channels=1, resolution=64, modes=12)
>>> problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format='NHWTC')
>>> inputs = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
>>> label = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
>>> loss = problem.get_loss(inputs, label)
>>> print(loss)
31.999998
get_loss(inputs, labels)[源代码]

计算训练或测试模型的损失。

参数:
  • inputs (Tensor) - 模型输入数据。

  • labels (Tensor) - 样本真实值。

返回:

float,损失值。

step(inputs)[源代码]

支持单步或多步训练。

参数:
  • inputs (Tensor) - 输入数据,数据格式为'NTCHW'或'MHWTC'。

返回:

List(Tensor),格式为'NTCHW'或'MHWTC'的数据。