# Source code for mindflow.pde.flow_with_loss

# Copyright 2023 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""flow with loss"""
from mindspore import nn, ops, jit_class
from ..core import get_loss_metric
from ..utils.check_func import check_param_type

class FlowWithLoss:
    """
    Base class of user-defined data-driven flow prediction problems.

    Subclasses must override :meth:`get_loss`.

    Args:
        model (mindspore.nn.Cell): A training or test model.
        loss_fn (Union[str, Cell]): Loss function. A string (e.g. ``"mse"``) is
            resolved into a loss Cell via ``get_loss_metric``; a Cell is used as-is.
            Default: ``"mse"``.

    Raises:
        TypeError: If `model` or `loss_fn` is not mindspore.nn.Cell.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, model, loss_fn='mse'):
        self.model = model
        # Loss names are looked up through get_loss_metric; Cell instances pass through unchanged.
        self.loss_fn = get_loss_metric(loss_fn) if isinstance(loss_fn, str) else loss_fn
        # exclude_type=bool keeps the flags from being accepted through subclassing quirks.
        check_param_type(model, "model", data_type=nn.Cell, exclude_type=bool)
        check_param_type(self.loss_fn, "loss_fn", data_type=nn.Cell, exclude_type=bool)

    def get_loss(self, inputs, labels):
        """
        Compute the loss of the model.

        Args:
            inputs (Tensor): The input data of model.
            labels (Tensor): True values of the samples.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.
        """
        raise NotImplementedError
@jit_class
class SteadyFlowWithLoss(FlowWithLoss):
    """
    Base class of user-defined steady data-driven problems.

    The loss is obtained by running `model` once on the inputs and scoring the
    prediction with `loss_fn`.

    Args:
        model (mindspore.nn.Cell): A training or test model.
        loss_fn (Union[str, Cell]): Loss function. Default: ``"mse"``.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, nn
        >>> import mindspore
        >>> from mindflow.pde import SteadyFlowWithLoss
        >>> from mindflow.core import RelativeRMSELoss
        ...
        >>> class Net(nn.Cell):
        ...     def __init__(self, num_class=10, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        ...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
        ...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
        ...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
        ...         self.relu = nn.ReLU()
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
        ...         x = self.flatten(x)
        ...         x = self.relu(self.fc1(x))
        ...         x = self.relu(self.fc2(x))
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> model = Net()
        >>> problem = SteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss())
        ...
        >>> inputs = Tensor(np.random.randn(32, 1, 32, 32), mindspore.float32)
        >>> label = Tensor(np.random.randn(32, 10), mindspore.float32)
        >>> loss = problem.get_loss(inputs, label)
        >>> print(loss)
        680855.1
    """

    def get_loss(self, inputs, labels):
        """
        Evaluate the model on `inputs` and score the prediction against `labels`.

        Args:
            inputs (Tensor): The input data of model.
            labels (Tensor): True values of the samples.

        Returns:
            The output of ``self.loss_fn(self.model(inputs), labels)``, i.e. the loss value.
        """
        return self.loss_fn(self.model(inputs), labels)
@jit_class
class UnsteadyFlowWithLoss(FlowWithLoss):
    """
    Base class of unsteady user-defined data-driven problems.

    Args:
        model (mindspore.nn.Cell): A training or test model.
        t_in (int): Initial time steps. Must be at least 1. Default: ``1``.
        t_out (int): Output time steps. Must be at least 1. Default: ``1``.
        loss_fn (Union[str, Cell]): Loss function. Default: ``"mse"``.
        data_format (str): Data format, either ``"NTCHW"`` or ``"NHWTC"``. Default: ``"NTCHW"``.

    Raises:
        TypeError: If `t_in` or `t_out` is not an int.
        ValueError: If `t_in` or `t_out` is less than 1, or if `data_format` is
            neither ``"NTCHW"`` nor ``"NHWTC"``.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> import mindspore
        >>> from mindflow.pde import UnsteadyFlowWithLoss
        >>> from mindflow.cell import FNO2D
        >>> from mindflow.core import RelativeRMSELoss
        ...
        >>> model = FNO2D(in_channels=1, out_channels=1, resolution=64, modes=12)
        >>> problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format='NHWTC')
        >>> inputs = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
        >>> label = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
        >>> loss = problem.get_loss(inputs, label)
        >>> print(loss)
        31.999998
    """

    def __init__(self, model, t_in=1, t_out=1, loss_fn='mse', data_format='NTCHW'):
        super(UnsteadyFlowWithLoss, self).__init__(model, loss_fn)
        self.t_in = t_in
        self.t_out = t_out
        check_param_type(t_in, "t_in", data_type=int, exclude_type=bool)
        check_param_type(t_out, "t_out", data_type=int, exclude_type=bool)
        if t_in < 1:
            raise ValueError(f"t_in must be a positive integer, but got {t_in}.")
        # t_out == 0 would leave nothing to concatenate in step().
        if t_out < 1:
            raise ValueError(f"t_out must be a positive integer, but got {t_out}.")
        # step/get_loss/_flatten only implement these two layouts. Any other value would
        # silently skip every branch and fail later with an obscure error.
        if data_format not in ('NTCHW', 'NHWTC'):
            raise ValueError(f"data_format must be 'NTCHW' or 'NHWTC', but got {data_format!r}.")
        self.data_format = data_format

    def step(self, inputs):
        """
        Roll the model forward `t_out` times, feeding each prediction back as input.

        Supports single or multiple time steps training. When `t_in` > 1, the
        input window slides forward by one step per iteration: the oldest step is
        dropped and the newest prediction is appended.

        Args:
            inputs (Tensor): Input dataset in the "NTCHW" or "NHWTC" data format.

        Returns:
            Tensor, the `t_out` predictions concatenated along the time axis
            (axis 1 for "NTCHW", axis -2 for "NHWTC").
        """
        # change inputs dimension: [bs, t_in, c, x1, x2, ...] -> [bs, t_out, c, x1, x2, ...]
        pred_list = []
        for _ in range(self.t_out):
            # Merge the time and channel axes so the model receives a single channel axis.
            inp = self._flatten(inputs)
            pred = self.model(inp)
            if self.data_format == 'NTCHW':
                # Re-insert a length-1 time axis at position 1.
                pred = pred.expand_dims(axis=1)
                pred_list.append(pred)
                if self.t_in > 1:
                    inputs = ops.concat([inputs[:, 1:, ...], pred], axis=1)
                else:
                    inputs = pred
            elif self.data_format == 'NHWTC':
                # Re-insert a length-1 time axis just before the channel axis.
                pred = pred.expand_dims(axis=-2)
                pred_list.append(pred)
                if self.t_in > 1:
                    inputs = ops.concat([inputs[..., 1:, :], pred], axis=-2)
                else:
                    inputs = pred
        if self.data_format == 'NTCHW':
            pred_list = ops.concat(pred_list, axis=1)
        elif self.data_format == 'NHWTC':
            pred_list = ops.concat(pred_list, axis=-2)
        return pred_list

    def get_loss(self, inputs, labels):
        """
        Compute the loss of training or test model.

        Only the final predicted time step is compared against the final time
        step of `labels`; earlier steps of the rollout are not scored.

        Args:
            inputs (Tensor): Dataset with data format is "NTCHW" or "NHWTC".
            labels (Tensor): True values of the samples.

        Returns:
            The output of ``self.loss_fn`` on the last time step, i.e. the loss value.
        """
        # the dimension of inputs: [bs, t_in, c, x1, x2, ...]
        # the dimension of labels [bs, t_out, c, x1, x2, ...]
        pred = self.step(inputs)
        if self.data_format == 'NTCHW':
            pred = pred[:, -1, ...]
            labels = labels[:, -1, ...]
        elif self.data_format == 'NHWTC':
            pred = pred[..., -1, :]
            labels = labels[..., -1, :]
        return self.loss_fn(pred, labels)

    def _flatten(self, inputs):
        """
        Merge the time and channel axes of `inputs` into a single channel axis.

        "NTCHW": [bs, t, c, x1, x2, ...] -> [bs, t*c, x1, x2, ...]
        "NHWTC": [bs, x1, x2, ..., t, c] -> [bs, x1, x2, ..., t*c]
        """
        # Number of spatial axes: total rank minus batch, time and channel.
        dim = len(inputs.shape) - 3
        if self.data_format == 'NTCHW':
            # Move time and channel to the end: [bs, x1, ..., t, c]
            inputs = ops.transpose(inputs, tuple([0] + list(range(3, dim + 3)) + [1, 2]))
        inp_shape = list(inputs.shape)
        inp_shape = inp_shape[:-2]
        inp_shape.append(-1)
        inputs = ops.reshape(inputs, tuple(inp_shape))
        if self.data_format == 'NTCHW':
            # Move the merged t*c axis back to position 1: [bs, t*c, x1, ...]
            inputs = ops.transpose(inputs, tuple([0, dim + 1] + list(range(1, dim + 1))))
        return inputs