Definition of Symbolic PDE Based on MindFlow


Partial differential equations (PDEs) play an important role in engineering applications, since the physics of most natural and man-made complex systems is described by PDEs. This tutorial shows how to define PDEs symbolically in MindFlow with sympy, and how to build models based on Physics-Informed Neural Networks (PINNs). With mindflow.pde.PDEWithLoss, one can describe a partial differential equation in symbolic form and easily calculate the loss of all equations. This keeps the equations concise, readable, and easy to extend. To implement a user-defined partial differential equation, inherit from mindflow.pde.PDEWithLoss.
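Here is a minimal, framework-free sketch of what "calculating the loss of an equation" means. It uses only sympy and NumPy, not MindFlow. A symbolic residual is evaluated at randomly sampled points for a candidate solution, and the mean squared residual serves as the loss. The 1D heat equation u_t - u_xx = 0, the helper pde_loss, and both candidate functions are illustrative assumptions. In a PINN the candidate is the neural network, and its derivatives come from automatic differentiation.

```python
import numpy as np
import sympy as sp

x, t = sp.symbols("x t")
u = sp.Function("u")(x, t)

# Symbolic residual of the 1D heat equation u_t - u_xx = 0
residual = u.diff(t) - u.diff(x, 2)

def pde_loss(candidate, xs, ts):
    """Mean squared residual of `candidate` (a sympy expression in x, t) at the sample points."""
    # Replace the unknown function by the candidate and evaluate the derivatives symbolically
    concrete = residual.subs(u, candidate).doit()
    f = sp.lambdify((x, t), concrete, modules="numpy")
    # broadcast_to handles the case where the residual collapses to a constant such as 0
    values = np.broadcast_to(f(xs, ts), xs.shape)
    return float(np.mean(values ** 2))

rng = np.random.default_rng(0)
xs = rng.uniform(0.0, np.pi, 256)
ts = rng.uniform(0.0, 1.0, 256)

exact = sp.exp(-t) * sp.sin(x)      # satisfies u_t = u_xx
wrong = sp.exp(-2 * t) * sp.sin(x)  # does not

print(pde_loss(exact, xs, ts))  # 0.0
print(pde_loss(wrong, xs, ts))  # clearly positive
```

A network is trained by driving this kind of loss toward zero at the sampled points.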

Example of PDE: Navier-Stokes equation

The Navier-Stokes equations, referred to as the N-S equations, are a classical set of partial differential equations in fluid mechanics. For a viscous incompressible flow, the dimensionless N-S equations take the following form:

\[\frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} = 0\]
\[\frac{\partial u} {\partial t} + u \frac{\partial u}{\partial x} + v \frac{\partial u}{\partial y} + \frac{\partial p}{\partial x} - \frac{1} {Re} (\frac{\partial^2u}{\partial x^2} + \frac{\partial^2u}{\partial y^2}) = 0\]
\[\frac{\partial v} {\partial t} + u \frac{\partial v}{\partial x} + v \frac{\partial v}{\partial y} + \frac{\partial p}{\partial y} - \frac{1} {Re} (\frac{\partial^2v}{\partial x^2} + \frac{\partial^2v}{\partial y^2}) = 0\]

where Re is the Reynolds number.
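The same system can be written in vector form with the velocity vector (u, v). Expanding its x and y components recovers the two momentum equations above, which is a quick way to check their signs:

\[\nabla \cdot \mathbf{u} = 0, \qquad \mathbf{u} = (u, v)\]
\[\frac{\partial \mathbf{u}}{\partial t} + (\mathbf{u} \cdot \nabla)\mathbf{u} = -\nabla p + \frac{1}{Re}\nabla^2 \mathbf{u}\]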

Symbol Declaration

[1]:
from sympy import symbols, Function

x, y, t = symbols('x y t')
u = Function('u')(x, y, t)
v = Function('v')(x, y, t)
p = Function('p')(x, y, t)

# independent variables
in_vars = [x, y, t]
print(in_vars)

# dependent variables
out_vars = [u, v, p]
print(out_vars)
[x, y, t]
[u(x, y, t), v(x, y, t), p(x, y, t)]
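Because u, v and p are declared as undefined functions, their derivatives stay unevaluated Derivative objects. This is what allows the equations to be written before any network exists. The small sketch below substitutes a concrete expression and evaluates the derivative with subs and doit. The function sin(x)*y*t is an arbitrary illustration.

```python
from sympy import symbols, Function, Derivative, sin, cos

x, y, t = symbols('x y t')
u = Function('u')(x, y, t)

du_dx = u.diff(x)
print(du_dx)                           # Derivative(u(x, y, t), x)
print(isinstance(du_dx, Derivative))   # True

# Substitute an illustrative concrete function for u, then evaluate the derivative
concrete = du_dx.subs(u, sin(x) * y * t).doit()
print(concrete == t * y * cos(x))      # True
```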

PDE in Symbolic Forms

Define PDEs using the symbol declaration above.

[2]:
import numpy as np
from sympy import diff

Governing Equations

[3]:
# Reynolds number Re = 100
re = 100
number = np.float32(1.0 / re)

# X momentum
momentum_x = u.diff(t) + u * u.diff(x) + v * u.diff(y) + p.diff(x) - number*(diff(u, (x, 2)) + diff(u, (y, 2)))
print(momentum_x)

# Y momentum
momentum_y = v.diff(t) + u * v.diff(x) + v * v.diff(y) + p.diff(y) - number*(diff(v, (x, 2)) + diff(v, (y, 2)))
print(momentum_y)

# continuity
continuity = u.diff(x) + v.diff(y)
print(continuity)
u(x, y, t)*Derivative(u(x, y, t), x) + v(x, y, t)*Derivative(u(x, y, t), y) + Derivative(p(x, y, t), x) + Derivative(u(x, y, t), t) - 0.00999999977648258*Derivative(u(x, y, t), (x, 2)) - 0.00999999977648258*Derivative(u(x, y, t), (y, 2))
u(x, y, t)*Derivative(v(x, y, t), x) + v(x, y, t)*Derivative(v(x, y, t), y) + Derivative(p(x, y, t), y) + Derivative(v(x, y, t), t) - 0.00999999977648258*Derivative(v(x, y, t), (x, 2)) - 0.00999999977648258*Derivative(v(x, y, t), (y, 2))
Derivative(u(x, y, t), x) + Derivative(v(x, y, t), y)
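Residuals written this way can be sanity-checked by substituting a known exact solution and confirming that each residual simplifies to zero. The sketch below uses the classical two-dimensional Taylor-Green vortex, which satisfies the dimensionless incompressible N-S equations above with viscosity 1/Re. It uses an exact rational 1/Re instead of the np.float32 value so that the symbolic simplification is exact.

```python
from sympy import symbols, Function, Rational, diff, sin, cos, exp, expand_trig, simplify

x, y, t = symbols('x y t')
u = Function('u')(x, y, t)
v = Function('v')(x, y, t)
p = Function('p')(x, y, t)
nu = Rational(1, 100)  # 1/Re for Re = 100, kept exact instead of np.float32

# Same residuals as above, with an exact coefficient
momentum_x = u.diff(t) + u * u.diff(x) + v * u.diff(y) + p.diff(x) - nu * (diff(u, (x, 2)) + diff(u, (y, 2)))
momentum_y = v.diff(t) + u * v.diff(x) + v * v.diff(y) + p.diff(y) - nu * (diff(v, (x, 2)) + diff(v, (y, 2)))
continuity = u.diff(x) + v.diff(y)

# Taylor-Green vortex, a known exact solution of these equations
decay = exp(-2 * nu * t)
exact = {
    u: cos(x) * sin(y) * decay,
    v: -sin(x) * cos(y) * decay,
    p: -(cos(2 * x) + cos(2 * y)) * decay**2 / 4,
}

results = {}
for name, residual in [("momentum_x", momentum_x), ("momentum_y", momentum_y), ("continuity", continuity)]:
    # Substitute the exact solution, evaluate the derivatives, then simplify
    results[name] = simplify(expand_trig(residual.subs(exact).doit()))
    print(name, results[name])
```

A sign error in one of the residuals, for example writing -p.diff(x), would leave a nonzero remainder here.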

Boundary Condition

[4]:
bc_u = u
print(bc_u)

bc_v = v
print(bc_v)
u(x, y, t)
v(x, y, t)

Initial Condition

[5]:
ic_u = u
print(ic_u)

ic_v = v
print(ic_v)

ic_p = p
print(ic_p)
u(x, y, t)
v(x, y, t)
p(x, y, t)
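In both kinds of condition, the symbolic expression is simply the network output itself (u, v, and p). The prescribed values are supplied separately as labels (bc_label and ic_label in get_loss below), and the loss is the mean squared difference between the two. The NumPy sketch below uses made-up sample values. It shows that comparing the output u against a label g is the same as driving the residual u - g to zero, so a nonzero target does not need to be built into the symbolic expression. The helper mse and all arrays are hypothetical.

```python
import numpy as np

def mse(a, b):
    """Mean squared error, the role nn.MSELoss plays in the MindFlow examples."""
    return float(np.mean((a - b) ** 2))

rng = np.random.default_rng(1)
u_pred = rng.normal(size=(8, 1))   # hypothetical network outputs u at 8 boundary points
g = np.full((8, 1), 0.5)           # hypothetical prescribed boundary values u = g

loss_with_label = mse(u_pred, g)                  # residual u, label g
loss_shifted = mse(u_pred - g, np.zeros_like(g))  # residual u - g, label 0
print(np.isclose(loss_with_label, loss_shifted))  # True
```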

Example of Problem Modeling

The CylinderFlow class below defines the problem of two-dimensional unsteady flow past a cylinder. It consists of the three parts defined above: the governing equations, the initial conditions, and the boundary conditions.

First, the NavierStokes base class is defined, which declares the input and output variables and the governing equations.

To use other governing equations, write your own base class in the same way: inherit from PDEWithLoss and implement the pde method.
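The pattern is ordinary Python subclassing. A base class fixes the variables, and each subclass implements pde() to return a dict that maps equation names to symbolic residuals. The sketch below uses a hypothetical stand-in base class, SymbolicPDE, rather than MindFlow's PDEWithLoss, whose constructor also takes the network and prepares the expressions for MindSpore. It defines the viscous Burgers equation as an example of another governing equation. As in the NavierStokes class that follows, the subclass creates its symbols before calling the base constructor, because in this sketch the base constructor calls pde().

```python
from sympy import symbols, Function, Rational, diff


class SymbolicPDE:
    """Hypothetical minimal stand-in for a PDE base class (not MindFlow's PDEWithLoss)."""

    def __init__(self, in_vars, out_vars):
        self.in_vars = in_vars
        self.out_vars = out_vars
        self.equations = self.pde()  # subclasses supply the residuals

    def pde(self):
        raise NotImplementedError("subclasses must return a dict of residuals")


class Burgers1D(SymbolicPDE):
    """Viscous Burgers equation u_t + u u_x - nu u_xx = 0."""

    def __init__(self, nu=Rational(1, 100)):
        self.nu = nu
        self.x, self.t = symbols('x t')
        self.u = Function('u')(self.x, self.t)
        super().__init__([self.x, self.t], [self.u])

    def pde(self):
        burgers = self.u.diff(self.t) + self.u * self.u.diff(self.x) - self.nu * diff(self.u, (self.x, 2))
        return {"burgers": burgers}


problem = Burgers1D()
print(list(problem.equations))  # ['burgers']
print(problem.equations["burgers"])
```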

[6]:
from mindspore import nn
from mindflow.pde import PDEWithLoss


class NavierStokes(PDEWithLoss):
    def __init__(self, model, re=100, loss_fn=nn.MSELoss()):
        self.number = np.float32(1.0 / re)
        self.x, self.y, self.t = symbols('x y t')
        self.u = Function('u')(self.x, self.y, self.t)
        self.v = Function('v')(self.x, self.y, self.t)
        self.p = Function('p')(self.x, self.y, self.t)
        self.in_vars = [self.x, self.y, self.t]
        self.out_vars = [self.u, self.v, self.p]
        super(NavierStokes, self).__init__(model, self.in_vars, self.out_vars)
        self.loss_fn = loss_fn

    def pde(self):
        momentum_x = self.u.diff(self.t) + self.u * self.u.diff(self.x) + self.v * self.u.diff(self.y) +\
            self.p.diff(self.x) - self.number * (diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)))
        momentum_y = self.v.diff(self.t) + self.u * self.v.diff(self.x) + self.v * self.v.diff(self.y) +\
            self.p.diff(self.y) - self.number * (diff(self.v, (self.x, 2)) + diff(self.v, (self.y, 2)))
        continuity = self.u.diff(self.x) + self.v.diff(self.y)

        equations = {"momentum_x": momentum_x, "momentum_y": momentum_y, "continuity": continuity}
        return equations
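To evaluate these residuals on a network, every derivative that appears in them has to be computed, typically with automatic differentiation. sympy can list these derivatives directly, which gives a quick view of what a converter such as sympy_to_mindspore must handle for each equation. The sketch below only inspects the expressions. How MindFlow itself organizes its gradient computations is not shown and should not be inferred from it.

```python
from sympy import symbols, Function, Derivative, Rational, diff

x, y, t = symbols('x y t')
u, v, p = (Function(name)(x, y, t) for name in ('u', 'v', 'p'))
nu = Rational(1, 100)  # 1/Re for Re = 100

equations = {
    "momentum_x": u.diff(t) + u * u.diff(x) + v * u.diff(y) + p.diff(x)
    - nu * (diff(u, (x, 2)) + diff(u, (y, 2))),
    "continuity": u.diff(x) + v.diff(y),
}

for name, expr in equations.items():
    derivatives = sorted(expr.atoms(Derivative), key=str)   # every distinct derivative in the residual
    highest = max(d.derivative_count for d in derivatives)  # total order, e.g. 2 for u_xx
    print(name, len(derivatives), "derivatives, highest order", highest)
    for d in derivatives:
        print("   ", d)
```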

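Next, we subclass NavierStokes to define the initial and boundary conditions and the loss function. Note that this cell imports NavierStokes from mindflow.pde, so CylinderFlow builds on MindFlow's built-in class rather than on the one defined in the previous cell.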

[7]:
from mindspore import nn, ops, Tensor
from mindspore import dtype as mstype
from mindflow.pde import PDEWithLoss, NavierStokes, sympy_to_mindspore


class CylinderFlow(NavierStokes):
    def __init__(self, model, re=100, loss_fn=nn.MSELoss()):
        super(CylinderFlow, self).__init__(model, re=re, loss_fn=loss_fn)
        self.ic_nodes = sympy_to_mindspore(self.ic(), self.in_vars, self.out_vars)
        self.bc_nodes = sympy_to_mindspore(self.bc(), self.in_vars, self.out_vars)

    def bc(self):
        bc_u = self.u
        bc_v = self.v
        equations = {"bc_u": bc_u, "bc_v": bc_v}
        return equations

    def ic(self):
        ic_u = self.u
        ic_v = self.v
        ic_p = self.p
        equations = {"ic_u": ic_u, "ic_v": ic_v, "ic_p": ic_p}
        return equations

    def get_loss(self, pde_data, bc_data, bc_label, ic_data, ic_label):
        pde_res = self.parse_node(self.pde_nodes, inputs=pde_data)
        pde_residual = ops.Concat(1)(pde_res)
        pde_loss = self.loss_fn(pde_residual, Tensor(np.array([0.0]).astype(np.float32), mstype.float32))

        ic_res = self.parse_node(self.ic_nodes, inputs=ic_data)
        ic_residual = ops.Concat(1)(ic_res)
        ic_loss = self.loss_fn(ic_residual, ic_label)

        bc_res = self.parse_node(self.bc_nodes, inputs=bc_data)
        bc_residual = ops.Concat(1)(bc_res)
        bc_loss = self.loss_fn(bc_residual, bc_label)

        return pde_loss + ic_loss + bc_loss
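In get_loss, each parsed node yields one column of residuals per sample point. ops.Concat(1) stacks these columns side by side, and the MSE losses against zero (PDE) or against labels (initial and boundary conditions) are summed. The NumPy sketch below reproduces only this arithmetic, using random stand-in arrays. It assumes each residual has shape (N, 1) and uses NumPy broadcasting for the zero target, as the MindFlow code does with its one-element tensor. The sample counts and label values are hypothetical.

```python
import numpy as np

def mse(pred, target):
    """Mean squared error over all elements, with NumPy broadcasting."""
    return float(np.mean((pred - target) ** 2))

rng = np.random.default_rng(2)
n_pde, n_ic, n_bc = 100, 40, 60

# Stand-ins for the residual columns returned by parse_node: one (N, 1) array per equation
pde_res = [rng.normal(scale=0.1, size=(n_pde, 1)) for _ in range(3)]  # momentum_x, momentum_y, continuity
ic_res = [rng.normal(size=(n_ic, 1)) for _ in range(3)]               # ic_u, ic_v, ic_p
bc_res = [rng.normal(size=(n_bc, 1)) for _ in range(2)]               # bc_u, bc_v

ic_label = rng.normal(size=(n_ic, 3))  # hypothetical u, v, p values at t = 0
bc_label = np.zeros((n_bc, 2))         # hypothetical boundary values, e.g. u = v = 0

pde_residual = np.concatenate(pde_res, axis=1)  # (100, 3), like ops.Concat(1)
ic_residual = np.concatenate(ic_res, axis=1)    # (40, 3)
bc_residual = np.concatenate(bc_res, axis=1)    # (60, 2)

total = mse(pde_residual, 0.0) + mse(ic_residual, ic_label) + mse(bc_residual, bc_label)
print(pde_residual.shape, ic_residual.shape, bc_residual.shape)  # (100, 3) (40, 3) (60, 2)
print(total)
```

The label columns presumably have to follow the same order as the keys returned by ic() and bc(), since the residual columns are concatenated in that order.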

For details, see PINNs-based solution for flow past a cylinder.

Example of Neumann Boundary Condition Definition

In mathematics, the Neumann boundary condition, also called the boundary condition of the second kind, is a type of boundary condition for ordinary or partial differential equations. It specifies the value of the derivative of the solution (typically its normal derivative) on the boundary of the domain.
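For a boundary with unit normal n, the normal derivative is the directional derivative du/dn = grad(u) · n. The sketch below computes it with sympy for the illustrative function u = (x² + y²)/4 on a circle centred at the origin. The function and the circle are assumptions chosen only to show what a condition such as du/dn = 0.5 constrains. They are not taken from the Poisson example.

```python
from sympy import symbols, sqrt, simplify, Matrix, cos, sin, Rational

x, y, theta = symbols('x y theta', real=True)

u = (x**2 + y**2) / 4  # illustrative function only, not the solution of the Poisson example

grad_u = Matrix([u.diff(x), u.diff(y)])
r = sqrt(x**2 + y**2)
normal = Matrix([x, y]) / r  # unit normal of a circle centred at the origin, pointing away from it

du_dn = simplify(grad_u.dot(normal))  # directional derivative grad(u) . n, equal to r/2
print(du_dn)

# On the unit circle, (x, y) = (cos(theta), sin(theta)), so du/dn is 1/2 everywhere
on_unit_circle = simplify(du_dn.subs({x: cos(theta), y: sin(theta)}))
print(on_unit_circle == Rational(1, 2))
```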

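The following Poisson2D problem defines a Dirichlet boundary condition (bc_outer) and a Neumann boundary condition (bc_inner). Because both residuals are driven to zero in get_loss, bc_outer enforces u = 0 and bc_inner enforces du/dn = 0.5, where n is the boundary normal passed in as bc_inner_normal.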

[8]:
import sympy
from mindflow.pde import Poisson

class Poisson2D(Poisson):
    def __init__(self, model, loss_fn=nn.MSELoss()):
        super(Poisson2D, self).__init__(model, loss_fn=loss_fn)
        self.bc_outer_nodes = sympy_to_mindspore(self.bc_outer(), self.in_vars, self.out_vars)
        self.bc_inner_nodes = sympy_to_mindspore(self.bc_inner(), self.in_vars, self.out_vars)

    def bc_outer(self):
        bc_outer_eq = self.u
        equations = {"bc_outer": bc_outer_eq}
        return equations

    def bc_inner(self):
        bc_inner_eq = sympy.Derivative(self.u, self.normal) - 0.5
        equations = {"bc_inner": bc_inner_eq}
        return equations

    def get_loss(self, pde_data, bc_outer_data, bc_inner_data, bc_inner_normal):
        pde_res = self.parse_node(self.pde_nodes, inputs=pde_data)
        pde_loss = self.loss_fn(pde_res[0], Tensor(np.array([0.0]), mstype.float32))

        bc_inner_res = self.parse_node(self.bc_inner_nodes, inputs=bc_inner_data, norm=bc_inner_normal)
        bc_inner_loss = self.loss_fn(bc_inner_res[0], Tensor(np.array([0.0]), mstype.float32))

        bc_outer_res = self.parse_node(self.bc_outer_nodes, inputs=bc_outer_data)
        bc_outer_loss = self.loss_fn(bc_outer_res[0], Tensor(np.array([0.0]), mstype.float32))

        return pde_loss + bc_inner_loss + bc_outer_loss
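Numerically, the Neumann residual at each inner-boundary point is the dot product of the gradient of u with the unit normal at that point, minus 0.5. This is presumably why get_loss passes the normals (bc_inner_normal) to parse_node alongside the sample points. The NumPy sketch below assumes the gradients are already available, for example from automatic differentiation. It uses the illustrative function u = (x² + y²)/4 on the unit circle, for which the residual vanishes. The helper neumann_residual and all arrays are hypothetical.

```python
import numpy as np

def neumann_residual(grad_u, normals, target=0.5):
    """Residual du/dn - target, with du/dn = grad(u) . n computed row by row."""
    return np.sum(grad_u * normals, axis=1, keepdims=True) - target

# Hypothetical inner-boundary samples on the unit circle
theta = np.linspace(0.0, 2.0 * np.pi, 16, endpoint=False)
points = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # (16, 2)
normals = points / np.linalg.norm(points, axis=1, keepdims=True)

# Gradient of the illustrative function u = (x**2 + y**2) / 4 is (x/2, y/2)
grad_u = points / 2.0

res = neumann_residual(grad_u, normals)
print(res.shape)                  # (16, 1)
print(np.abs(res).max() < 1e-12)  # True
print(float(np.mean(res ** 2)))   # MSE against zero, as in bc_inner_loss
```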

For details, see Solve Poisson’s Equation on a Ring based on PINNs.