Skip to content

Commit 9b8ddaf

Browse files
authored
Adding Kovasznay flow doc (#1099)
1 parent 7a96829 commit 9b8ddaf

File tree

3 files changed

+175
-2
lines changed

3 files changed

+175
-2
lines changed

docs/demos/pinn_forward.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ Time-dependent PDEs
4949
pinn_forward/burgers.rar
5050
pinn_forward/allen.cahn
5151
pinn_forward/klein.gordon
52+
pinn_forward/Kovasznay.flow
5253

5354
- `Beltrami flow <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/Beltrami_flow.py>`_
54-
- `Kovasznay flow <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/Kovasznay_flow.py>`_
5555
- `Wave propagation with spatio-temporal multi-scale Fourier feature architecture <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/wave_1d.py>`_
5656
- `Schrodinger equation <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/Schrodinger.ipynb>`_
5757

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
Kovasznay flow
2+
================
3+
4+
Problem setup
5+
--------------
6+
7+
We will solve the Kovasznay flow equation on :math:`\Omega = [0, 1]^2`:
8+
9+
.. math:: u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y}= -\frac{\partial p}{\partial x} + \frac{1}{Re}(\frac{\partial^2u}{\partial x^2} + \frac{\partial^2u}{\partial y^2}),
10+
11+
.. math:: u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y}= -\frac{\partial p}{\partial y} + \frac{1}{Re}(\frac{\partial^2v}{\partial x^2} + \frac{\partial^2v}{\partial y^2}),
12+
13+
with the Dirichlet boundary conditions
14+
15+
.. math:: u(x,y)=0, \qquad (x,y)\in \partial \Omega
16+
17+
The reference solution is :math:`u = 1 - e^{\lambda x} \cos(2\pi y)`, :math:`v = \frac{\lambda}{2\pi}e^{\lambda x} \sin(2\pi x)`, :math:`p =\frac{1}{2}(1 - e^{2\lambda x})`, where :math:`\lambda = \frac{1}{2\nu}-\sqrt{\frac{1}{4\nu^2}+4\pi^2}`.
18+
19+
Implementation
20+
--------------
21+
22+
This description goes through the implementation of a solver for the above described Kovasznay flow step-by-step.
23+
24+
First, the DeepXDE and Numpy modules are imported:
25+
26+
.. code-block:: python
27+
28+
import deepxde as dde
29+
import numpy as np
30+
31+
We begin by defining the parameters of the equation. :math:`\lambda` is defined as l below.
32+
33+
.. code-block:: python
34+
35+
Re = 20
36+
nu = 1 / Re
37+
l = 1 / (2 * nu) - np.sqrt(1 / (4 * nu ** 2) + 4 * np.pi ** 2)
38+
39+
Next, we express the PDE residual of the Kovasznay flow equation in terms of the x-direction, y-direction and continuity equations.
40+
41+
.. code-block:: python
42+
43+
def pde(x, u):
44+
u_vel, v_vel, p = u[:, 0:1], u[:, 1:2], u[:, 2:]
45+
u_vel_x = dde.grad.jacobian(u, x, i=0, j=0)
46+
u_vel_y = dde.grad.jacobian(u, x, i=0, j=1)
47+
u_vel_xx = dde.grad.hessian(u, x, component=0, i=0, j=0)
48+
u_vel_yy = dde.grad.hessian(u, x, component=0, i=1, j=1)
49+
50+
v_vel_x = dde.grad.jacobian(u, x, i=1, j=0)
51+
v_vel_y = dde.grad.jacobian(u, x, i=1, j=1)
52+
v_vel_xx = dde.grad.hessian(u, x, component=1, i=0, j=0)
53+
v_vel_yy = dde.grad.hessian(u, x, component=1, i=1, j=1)
54+
55+
p_x = dde.grad.jacobian(u, x, i=2, j=0)
56+
p_y = dde.grad.jacobian(u, x, i=2, j=1)
57+
58+
momentum_x = (
59+
u_vel * u_vel_x + v_vel * u_vel_y + p_x - 1 / Re * (u_vel_xx + u_vel_yy)
60+
)
61+
momentum_y = (
62+
u_vel * v_vel_x + v_vel * v_vel_y + p_y - 1 / Re * (v_vel_xx + v_vel_yy)
63+
)
64+
continuity = u_vel_x + v_vel_y
65+
66+
return [momentum_x, momentum_y, continuity]
67+
68+
The first argument to ``pde`` is the network input, i.e. the x and y coordinates. The second argument is the network output ``u`` which is comprised of the 3 different output solutions i.e., velocity u, velocity v, and pressure p.
69+
70+
Next, the exact solution of the Kovasznay flow is introduced
71+
72+
.. code-block:: python
73+
74+
def u_func(x):
75+
return 1 - np.exp(l * x[:, 0:1]) * np.cos(2 * np.pi * x[:, 1:2])
76+
77+
def v_func(x):
78+
return l / (2 * np.pi) * np.exp(l * x[:, 0:1]) * np.sin(2 * np.pi * x[:, 1:2])
79+
80+
def p_func(x):
81+
return 1 / 2 * (1 - np.exp(2 * l * x[:, 0:1]))
82+
83+
Next, we consider the boundary condition. ``on_boundary`` is chosen here to use the whole boundary of the computational domain as the boundary condition. We include ``on_boundary`` as the BCs in the ``DirichletBC`` function of DeepXDE.
84+
85+
.. code-block:: python
86+
87+
def boundary_outflow(x, on_boundary):
88+
return on_boundary and np.isclose(x[0], 1)
89+
90+
spatial_domain = dde.geometry.Rectangle(xmin=[-0.5, -0.5], xmax=[1, 1.5])
91+
92+
boundary_condition_u = dde.icbc.DirichletBC(
93+
spatial_domain, u_func, lambda _, on_boundary: on_boundary, component=0
94+
)
95+
boundary_condition_v = dde.icbc.DirichletBC(
96+
spatial_domain, v_func, lambda _, on_boundary: on_boundary, component=1
97+
)
98+
boundary_condition_right_p = dde.icbc.DirichletBC(
99+
spatial_domain, p_func, boundary_outflow, component=2
100+
)
101+
102+
103+
Now, we have specified the geometry, PDE residual, and boundary condition. We then define the ``PDE`` problem as
104+
105+
.. code-block:: python
106+
107+
data = dde.data.PDE(
108+
spatial_domain,
109+
pde,
110+
[boundary_condition_u, boundary_condition_v, boundary_condition_right_p],
111+
num_domain=2601,
112+
num_boundary=400,
113+
num_test=100000,
114+
)
115+
116+
The training residual points imside the domain is 2601, and the number of training points sampled on the boundary is 400. 100000 test points were used in the ``PDE``.
117+
118+
Next, we choose the network. We use a fully connected neural network of 4 hidden layers, 3 outputs and width 50
119+
120+
.. code-block:: python
121+
122+
net = dde.nn.FNN([2] + 4 * [50] + [3], "tanh", "Glorot normal")
123+
124+
The PDE and the network have now been defined. Next, we build a ``Model`` and choose the optimizer and learning rate.
125+
126+
.. code-block:: python
127+
128+
model = dde.Model(data, net)
129+
130+
model.compile("adam", lr=1e-3)
131+
model.train(iterations=30000)
132+
model.compile("L-BFGS")
133+
losshistory, train_state = model.train()
134+
135+
We then train the model for 30000 iterations. After we train the network using ``Adam``, we continue to train the network using L-BFGS to achieve a smaller loss.
136+
137+
The next step is to define a spatial domain with the same number of random points 100000 and use the model created to predict the output.
138+
139+
.. code-block:: python
140+
141+
X = spatial_domain.random_points(100000)
142+
output = model.predict(X)
143+
u_pred = output[:, 0]
144+
v_pred = output[:, 1]
145+
p_pred = output[:, 2]
146+
147+
.. code-block:: python
148+
149+
u_exact = u_func(X).reshape(-1)
150+
v_exact = v_func(X).reshape(-1)
151+
p_exact = p_func(X).reshape(-1)
152+
153+
Next, we compare the predicted output to the exact output and calculate the loss.
154+
155+
.. code-block:: python
156+
157+
f = model.predict(X, operator=pde)
158+
159+
l2_difference_u = dde.metrics.l2_relative_error(u_exact, u_pred)
160+
l2_difference_v = dde.metrics.l2_relative_error(v_exact, v_pred)
161+
l2_difference_p = dde.metrics.l2_relative_error(p_exact, p_pred)
162+
residual = np.mean(np.absolute(f))
163+
164+
print("Mean residual:", residual)
165+
print("L2 relative error in u:", l2_difference_u)
166+
print("L2 relative error in v:", l2_difference_v)
167+
print("L2 relative error in p:", l2_difference_p)
168+
169+
Complete code
170+
--------------
171+
172+
.. literalinclude:: ../../../examples/pinn_forward/Kovasznay_flow.py
173+
:language: python

examples/pinn_forward/Kovasznay_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def boundary_outflow(x, on_boundary):
6262
spatial_domain, p_func, boundary_outflow, component=2
6363
)
6464

65-
data = dde.data.TimePDE(
65+
data = dde.data.PDE(
6666
spatial_domain,
6767
pde,
6868
[boundary_condition_u, boundary_condition_v, boundary_condition_right_p],

0 commit comments

Comments
 (0)