1+ """Solve Navier-Stokes equations with the Crank-Nicolson method.
2+
3+ Adapted from:
4+ https://github.com/zongyi-li/fourier_neural_operator/blob/master/data_generation/navier_stokes/ns_2d.py
5+ """
6+
17import math
28from enum import Enum
39
1016class Force (str , Enum ):
1117 li = 'li'
1218 random = 'random'
19+ none = 'none'
20+ kolmogorov = 'kolmogorov'
1321
1422
15- # w0: initial vorticity
16- # f: forcing term
17- #visc: viscosity (1/Re)
18- # T: final time
19- # delta_t: internal time-step for solve (descrease if blow-up)
20- # record_steps: number of in-time snapshots to record
2123def solve_navier_stokes_2d (w0 , visc , T , delta_t , record_steps , cycles = None ,
2224 scaling = None , t_scaling = None , force = Force .li ,
2325 varying_force = False ):
26+ """Solve Navier-Stokes equations in 2D using Crank-Nicolson method.
27+
28+ Parameters
29+ ----------
30+ w0 : torch.Tensor
31+ Initial vorticity field.
32+
33+ visc : float
34+ Viscosity (1/Re).
35+
36+ T : float
37+ Final time.
38+
39+ delta_t : float
40+ Internal time-step for solve (descrease if blow-up).
41+
42+ record_steps : int
43+ Number of in-time snapshots to record.
44+
45+ """
2446 seed = np .random .randint (1 , 1000000000 )
2547
2648 # Grid size - must be power of 2
@@ -42,12 +64,21 @@ def solve_navier_stokes_2d(w0, visc, T, delta_t, record_steps, cycles=None,
4264 X , Y = torch .meshgrid (ft , ft , indexing = 'ij' )
4365 f = 0.1 * (torch .sin (2 * math .pi * (X + Y )) +
4466 torch .cos (2 * math .pi * (X + Y )))
67+ elif force == Force .kolmogorov :
68+ ft = torch .linspace (0 , 2 * np .pi , N + 1 , device = w0 .device )
69+ ft = ft [0 :- 1 ]
70+ X , Y = torch .meshgrid (ft , ft , indexing = 'ij' )
71+ f = - 4 * torch .cos (4 * Y )
4572 elif force == Force .random and not varying_force :
4673 f = get_random_force (
4774 w0 .shape [0 ], N , w0 .device , cycles , scaling , 0 , 0 , seed )
75+ else :
76+ f = None
4877
4978 # Forcing to Fourier space
50- if not varying_force :
79+ if force == Force .none :
80+ f_h = 0
81+ elif not varying_force :
5182 f_h = torch .fft .fftn (f , dim = [- 2 , - 1 ], norm = 'backward' )
5283
5384 # If same forcing for the whole batch
@@ -131,7 +162,9 @@ def solve_navier_stokes_2d(w0, visc, T, delta_t, record_steps, cycles=None,
131162 # Dealias
132163 F_h *= dealias
133164
134- if varying_force :
165+ if force == Force .none :
166+ f_h = 0
167+ elif varying_force :
135168 f = get_random_force (w0 .shape [0 ], N , w0 .device , cycles ,
136169 scaling , t , t_scaling , seed )
137170 f_h = torch .fft .fftn (f , dim = [- 2 , - 1 ], norm = 'backward' )
@@ -161,7 +194,10 @@ def solve_navier_stokes_2d(w0, visc, T, delta_t, record_steps, cycles=None,
161194 if varying_force :
162195 f = fs
163196
164- return sol .cpu ().numpy (), f .cpu ().numpy ()
197+ if force != Force .none :
198+ f = f .cpu ().numpy ()
199+
200+ return sol .cpu ().numpy (), f
165201
166202
167203def get_random_force (b , s , device , cycles , scaling , t , t_scaling , seed ):
0 commit comments