mindscience.models.neural_operator.FNOBlocks

class mindscience.models.neural_operator.FNOBlocks(in_channels, out_channels, n_modes, resolutions, act='gelu', add_residual=False, dft_compute_dtype=mstype.float32, fno_compute_dtype=mstype.float16)

The FNOBlock is a part of the Fourier Neural Operator and is usually preceded by a Lifting Layer and followed by a Projection Layer. It contains a Fourier Layer and an FNO Skip Layer. Details can be found in Zongyi Li et al., "Fourier Neural Operator for Parametric Partial Differential Equations".
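A minimal NumPy sketch of what one FNO block computes, restricted to a single spatial dimension, may make the structure above more concrete. It assumes the block applies the activation to the sum of the Fourier Layer output and the FNO Skip Layer output, models the skip layer as a pointwise linear map over channels, and adds the input after the activation when add_residual is true. The function names (spectral_conv_1d, fno_block_1d) are hypothetical and not part of the mindscience API; the real layer's weight layout, initialization, and DFT implementation may differ.

```python
import numpy as np

def spectral_conv_1d(x, weights, n_modes):
    """Fourier Layer sketch: FFT, keep the lowest n_modes, mix channels per mode, inverse FFT.

    x:       real array of shape (batch, in_channels, resolution)
    weights: complex array of shape (in_channels, out_channels, n_modes)
    """
    resolution = x.shape[-1]
    x_ft = np.fft.rfft(x, axis=-1)  # (batch, in_channels, resolution // 2 + 1)
    out_ft = np.zeros((x.shape[0], weights.shape[1], x_ft.shape[-1]), dtype=complex)
    # Per-mode linear transform over channels; modes >= n_modes stay zero (truncated).
    out_ft[:, :, :n_modes] = np.einsum("bim,iom->bom", x_ft[:, :, :n_modes], weights)
    return np.fft.irfft(out_ft, n=resolution, axis=-1)  # (batch, out_channels, resolution)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def fno_block_1d(x, spec_weights, skip_weight, n_modes, add_residual=False):
    """Hypothetical FNO block: act(SpectralConv(x) + Skip(x)), optionally plus x."""
    spectral = spectral_conv_1d(x, spec_weights, n_modes)
    # Skip layer modeled as a pointwise (1x1) linear map over channels.
    skip = np.einsum("io,bir->bor", skip_weight, x)
    out = gelu(spectral + skip)
    if add_residual:
        # Assumption: residual added after the activation; needs in_channels == out_channels.
        out = out + x
    return out

rng = np.random.default_rng(0)
batch, in_ch, out_ch, res, modes = 2, 3, 3, 64, 8
x = rng.standard_normal((batch, in_ch, res))
spec_w = (rng.standard_normal((in_ch, out_ch, modes))
          + 1j * rng.standard_normal((in_ch, out_ch, modes))) / (in_ch * out_ch)
skip_w = rng.standard_normal((in_ch, out_ch)) / in_ch
y = fno_block_1d(x, spec_w, skip_w, modes)
print(x.shape, y.shape)  # (2, 3, 64) (2, 3, 64)
```

In this sketch only the lowest n_modes frequencies reach the output through the spectral path, so that path behaves like a learned low-pass filter; the pointwise skip path is what lets higher-frequency content through.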

Parameters
  • in_channels (int) – The number of channels in the input space.

  • out_channels (int) – The number of channels in the output space.

  • n_modes (Union[int, list(int)]) – The number of Fourier modes retained after the linear transformation in the Fourier Layer.

  • resolutions (Union[int, list(int)]) – The resolutions of the input tensor.

  • act (Union[str, class], optional) – The activation function, given either as a str or as a class. Default: "gelu".

  • add_residual (bool, optional) – Whether to add a residual connection in the FNOBlock. Default: False.

  • dft_compute_dtype (dtype.Number, optional) – The computation dtype of the DFT in SpectralConvDft. Default: mstype.float32.

  • fno_compute_dtype (dtype.Number, optional) – The computation dtype of the MLP in the FNO Skip Layer. Should be mstype.float32 or mstype.float16; mstype.float32 is recommended for the GPU backend and mstype.float16 for the Ascend backend. Default: mstype.float16.
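The act parameter accepts either a string name or a class. The pure-Python sketch below shows one plausible way such an argument could be resolved to a callable. The registry, the resolve_activation helper, and the choice to instantiate a class with no arguments are hypothetical illustrations, not how mindscience actually handles act.

```python
import math
from typing import Callable, Union

def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def relu(x: float) -> float:
    return max(0.0, x)

# Hypothetical registry mapping string names to activation functions.
_ACTIVATIONS = {"gelu": gelu, "relu": relu}

def resolve_activation(act: Union[str, type, Callable]) -> Callable:
    """Hypothetical helper: turn a name or a class into a callable activation."""
    if isinstance(act, str):
        try:
            return _ACTIVATIONS[act.lower()]
        except KeyError:
            raise ValueError(f"unsupported activation: {act!r}") from None
    if isinstance(act, type):
        return act()  # a class: instantiate it with no arguments (assumption)
    if callable(act):
        return act    # already a function or callable instance
    raise TypeError(f"act must be a str or a class, got {type(act).__name__}")

class Square:
    # Example activation class with a no-argument constructor.
    def __call__(self, x: float) -> float:
        return x * x

f = resolve_activation("gelu")
g = resolve_activation(Square)
print(round(f(1.0), 4), g(3.0))  # 0.8413 9.0
```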

Inputs:
  • x (Tensor) - Tensor of shape \((batch\_size, in\_channels, resolution)\), where \(resolution\) stands for one or more spatial dimensions matching resolutions.

Outputs:
  • output (Tensor) - Tensor of shape \((batch\_size, out\_channels, resolution)\).

Raises

ValueError – If the number of dimensions in n_modes does not equal that in resolutions.
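The check behind this ValueError can be sketched as follows, assuming that an int argument is treated as a single spatial dimension. normalize_modes is a hypothetical helper name used only for illustration; it is not part of the library.

```python
from typing import List, Tuple, Union

def normalize_modes(n_modes: Union[int, List[int]],
                    resolutions: Union[int, List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: wrap ints in lists, then require matching dimensions."""
    modes = [n_modes] if isinstance(n_modes, int) else list(n_modes)
    res = [resolutions] if isinstance(resolutions, int) else list(resolutions)
    if len(modes) != len(res):
        raise ValueError(
            f"n_modes has {len(modes)} dimension(s) but resolutions has {len(res)}")
    return modes, res

print(normalize_modes([20, 20], [128, 128]))  # ([20, 20], [128, 128])
```

Under this assumption, n_modes=[20, 20] with resolutions=128 would raise, because a 2D mode specification cannot be paired with a 1D resolution.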

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.common.dtype as mstype
>>> from mindscience.models.neural_operator.fno import FNOBlocks
>>> data = Tensor(np.ones([2, 3, 128, 128]), mstype.float32)
>>> net = FNOBlocks(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128])
>>> out = net(data)
>>> print(data.shape, out.shape)
(2, 3, 128, 128) (2, 3, 128, 128)