-
Notifications
You must be signed in to change notification settings - Fork 93
Description
Hello,
I’m trying to apply periodic boundary conditions in a setup where the velocity and pressure fields are defined on different triangular meshes with different polynomial orders - specifically:
- Velocity: quadratic (P2) Lagrange elements
- Pressure: linear (P1) Lagrange elements
The setup is based on the standard DOLFIN Stokes (Taylor–Hood) mixed finite element example, but in my case the domain is periodic (a square with opposite sides identified) rather than having Dirichlet boundaries.
I’ve referred to the Poisson periodic boundary condition example as a guide. However, since the velocity and pressure live on different meshes (mesh_u and mesh_p), I’m not sure how to correctly implement the periodic boundary conditions in this context.
I attempted to construct separate periodic constraint matrices (P) for the velocity and pressure spaces and then combine them, but this approach fails because the matrices have different shapes.
Could you please advise on the best way to implement periodic boundary conditions for this mixed-order, multi-mesh setup?
Thank you very much for your help and for maintaining this great project.
Python script
"""This example is similar to
https://fenicsproject.org/olddocs/dolfin/1.5.0/python/demo/documented/stokes-taylor-hood/python/documentation.html
but using periodic boundary conditions on the edges
Also, see jax-fem/applications/stokes/fenics.py
"""
import os

import jax
import jax.flatten_util
import jax.numpy as np
import meshio
import numpy as onp
import scipy.sparse

from jax_fem.generate_mesh import Mesh, get_meshio_cell_type, rectangle_mesh
from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol, modify_vtu_file
"""
Implement boundary conditions as per: https://github.com/deepmodeling/jax-fem/blob/main/applications/periodic_bc/example.py
"""
def periodic_boundary_conditions(periodic_bc_info, mesh, vec, eps=1e-5):
    """Construct the prolongation matrix ``P`` that enforces periodic constraints.

    The full DOF vector is recovered from the independent DOFs as
    ``u_full = P @ u_reduced``; DOFs tied together by the periodic mappings
    share a single column of ``P``.
    Reference: https://fenics2021.com/slides/dokken.pdf
    Adapted from the jax-fem example
    https://github.com/deepmodeling/jax-fem/blob/main/applications/periodic_bc/example.py

    DOFs are numbered ``node_index * vec + component``. For several variables
    (e.g. Stokes velocity and pressure on different meshes), build one ``P`` per
    variable and combine them with ``scipy.sparse.block_diag``.

    Args:
        periodic_bc_info: ``[location_fns_A, location_fns_B, mappings, vecs]``.
            Entry ``i`` of each list describes one constraint:
            ``location_fns_A[i]`` / ``location_fns_B[i]`` take a point and return
            a boolean (they are applied with ``jax.vmap``) and select the nodes of
            the independent side A and the dependent side B; ``mappings[i]`` maps
            a point on A to its partner on B; ``vecs[i]`` is the constrained
            solution component.
        mesh: object whose ``points`` attribute is a ``(num_nodes, dim)`` array.
        vec: number of solution components per node.
        eps: distance tolerance used to pair a mapped A point with a B point.

    Returns:
        ``scipy.sparse.csr_array`` of shape
        ``(num_nodes * vec, num_independent_dofs)`` with exactly one ``1.0`` per row.

    Raises:
        ValueError: if the constraint lists differ in length, or a node on side A
            does not map onto exactly one node on side B.
    """
    location_fns_A, location_fns_B, mappings, vecs = periodic_bc_info
    if not (len(location_fns_A) == len(location_fns_B) == len(mappings) == len(vecs)):
        raise ValueError("periodic_bc_info lists must all have the same length")

    # Plain numpy throughout: the index bookkeeping below relies on in-place
    # assignment, which immutable jax arrays do not support.
    points = onp.asarray(mesh.points)
    num_total_dofs = len(points) * vec

    # Collect (A dof, B dof) pairs over all constraints.
    dof_pairs = []
    for location_fn_A, location_fn_B, mapping, comp in zip(location_fns_A, location_fns_B, mappings, vecs):
        node_inds_A = onp.flatnonzero(onp.asarray(jax.vmap(location_fn_A)(points)))
        node_inds_B = onp.flatnonzero(onp.asarray(jax.vmap(location_fn_B)(points)))
        if len(node_inds_A) == 0:
            continue
        mapped_A = onp.stack([onp.asarray(mapping(points[ind])) for ind in node_inds_A])
        # (num_A, num_B) distances from every mapped A point to every B point.
        dist = onp.linalg.norm(mapped_A[:, None, :] - points[node_inds_B][None, :, :], axis=-1)
        matches = dist < eps
        num_matches = matches.sum(axis=1)
        if onp.any(num_matches != 1):
            bad = node_inds_A[num_matches != 1].tolist()
            raise ValueError(
                f"Periodic mapping failed: nodes {bad} on side A do not map onto "
                f"exactly one node on side B (eps={eps})")
        node_inds_B_ordered = node_inds_B[onp.argmax(matches, axis=1)]
        dof_pairs.extend(zip((node_inds_A * vec + comp).tolist(),
                             (node_inds_B_ordered * vec + comp).tolist()))

    # Union-find over DOFs. Merging equivalence classes (rather than a plain
    # "B -> A" map) keeps DOFs constrained more than once consistent: on a doubly
    # periodic square the (1, 1) corner is the image of (0, 1) under the
    # x-mapping and of (1, 0) under the y-mapping, and both of those are images
    # of (0, 0).
    parent = onp.arange(num_total_dofs)

    def find(dof):
        root = dof
        while parent[root] != root:
            root = parent[root]
        while parent[dof] != root:  # path compression
            nxt = parent[dof]
            parent[dof] = root
            dof = nxt
        return root

    for dof_A, dof_B in dof_pairs:
        root_A, root_B = find(dof_A), find(dof_B)
        if root_A != root_B:
            # Attach B's class under A's so that, without repeated constraints,
            # the A side stays independent (the classic construction).
            parent[root_B] = root_A

    # Flatten the (acyclic) forest by pointer jumping: roots[i] is the
    # independent DOF that DOF i is tied to.
    roots = parent
    while True:
        next_roots = roots[roots]
        if onp.array_equal(next_roots, roots):
            break
        roots = next_roots

    # Independent DOFs get consecutive columns in ascending DOF order.
    independent = onp.flatnonzero(roots == onp.arange(num_total_dofs))
    column_of = onp.full(num_total_dofs, -1, dtype=onp.int64)
    column_of[independent] = onp.arange(len(independent))

    rows = onp.arange(num_total_dofs)
    cols = column_of[roots]
    vals = onp.ones(num_total_dofs)
    return scipy.sparse.csr_array((vals, (rows, cols)), shape=(num_total_dofs, len(independent)))
class StokesFlow(Problem):
    """Stokes flow problem with velocity on ``self.fes[0]`` and pressure on ``self.fes[1]``.

    The weak form assembled by the kernel is
    ``inner(grad(u), grad(v))*dx + div(v)*p*dx + q*div(u)*dx`` with no body force.
    """

    def custom_init(self):
        # Convenience aliases for the velocity and pressure finite elements.
        self.fe_u = self.fes[0]
        self.fe_p = self.fes[1]

    def get_universal_kernel(self):
        """Return the per-cell kernel that evaluates the flattened Stokes residual."""

        def universal_kernel(cell_sol_flat, x, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars):
            # cell_sol_flat: (num_nodes*vec + ...,)
            # cell_sol_list: [(num_nodes, vec), ...]
            # x: (num_quads, dim)
            # cell_shape_grads: (num_quads, num_nodes + ..., dim)
            # cell_JxW: (num_vars, num_quads)
            # cell_v_grads_JxW: (num_quads, num_nodes + ..., 1, dim)
            cell_sol_list = self.unflatten_fn_dof(cell_sol_flat)
            # cell_sol_u: (num_nodes_u, vec), cell_sol_p: (num_nodes, vec)
            cell_sol_u, cell_sol_p = cell_sol_list
            cell_shape_grads_list = [cell_shape_grads[:, self.num_nodes_cumsum[i]: self.num_nodes_cumsum[i+1], :]
                                     for i in range(self.num_vars)]
            cell_shape_grads_u, cell_shape_grads_p = cell_shape_grads_list
            cell_v_grads_JxW_list = [cell_v_grads_JxW[:, self.num_nodes_cumsum[i]: self.num_nodes_cumsum[i+1], :, :]
                                     for i in range(self.num_vars)]
            cell_v_grads_JxW_u, cell_v_grads_JxW_p = cell_v_grads_JxW_list
            cell_JxW_u, cell_JxW_p = cell_JxW[0], cell_JxW[1]

            # Handles the term `inner(grad(u), grad(v))*dx`
            # (1, num_nodes_u, vec_u, 1) * (num_quads, num_nodes_u, 1, dim) -> (num_quads, num_nodes_u, vec_u, dim)
            u_grads = cell_sol_u[None, :, :, None] * cell_shape_grads_u[:, :, None, :]
            u_grads = np.sum(u_grads, axis=1)  # (num_quads, vec_u, dim)
            # (num_quads, num_nodes_u, vec_u, dim) -> (num_nodes_u, vec_u)
            val1 = np.sum(u_grads[:, None, :, :] * cell_v_grads_JxW_u, axis=(0, -1))

            # Handles the term `div(v)*p*dx`
            # (1, num_nodes_p, vec_p) * (num_quads, num_nodes_p, 1) -> (num_quads, num_nodes_p, vec_p) -> (num_quads, vec_p)
            p = np.sum(cell_sol_p[None, :, :] * self.fe_p.shape_vals[:, :, None], axis=1)[:, 0]
            # Be careful about this step to find divergence!
            # (num_quads, num_nodes_u, 1, dim) -> (num_quads, num_nodes_u, vec_u)
            div_v = cell_v_grads_JxW_u[:, :, 0, :]
            # (num_quads, 1, 1) * (num_quads, num_nodes_u, vec_u) -> (num_nodes_u, vec_u)
            val2 = np.sum(p[:, None, None] * div_v, axis=0)

            # Handles the term `q*div(u)*dx`
            # (num_quads, vec_u, dim) -> (num_quads, )
            div_u = u_grads[:, 0, 0] + u_grads[:, 1, 1]
            # (num_quads, 1) * (num_quads, num_nodes_p) * (num_quads, 1) -> (num_nodes_p,) -> (num_nodes_p, vec_p)
            val3 = np.sum(div_u[:, None] * self.fe_p.shape_vals * cell_JxW_p[:, None], axis=0)[:, None]

            weak_form = [val1 + val2, val3]  # [(num_nodes, vec), ...]
            return jax.flatten_util.ravel_pytree(weak_form)[0]

        return universal_kernel

    def configure_Dirichlet_BC_for_dolphin(self):
        """Append zero-velocity Dirichlet conditions on boundary nodes of the velocity mesh.

        The FEniCS dolfin example has interior boundaries that can't be directly
        imported, so boundary edges are found from the mesh: a node referenced by
        exactly one cell is treated as a mid-edge node of a boundary edge, and that
        node plus the edge's two end vertices are constrained. Nodes with
        x <= 1e-5 or x >= 1 - 1e-5 are dropped; both velocity components of the
        remaining nodes are set to 0 via ``self.fes[0]``'s Dirichlet lists.

        NOTE(review): only the x coordinate is used for exclusion, so any
        boundary edges at y = 0 and y = 1 are still constrained. That conflicts
        with periodicity in y — confirm this is intended.
        """
        cells_u = self.fe_u.cells
        points_u = self.fe_u.points
        node_ids, counts = onp.unique(cells_u, return_counts=True)
        # Set for O(1) membership tests inside the double loop below.
        boundary_mid_nodes = set(node_ids[counts == 1].tolist())

        def ind_map(ind):
            # Local mid-edge position -> local positions of the edge's two end vertices.
            assert ind == 3 or ind == 4 or ind == 5, f"Wrong face ind!"
            if ind == 3:
                return 0, 1
            if ind == 4:
                return 1, 2
            if ind == 5:
                return 2, 0

        boundary_nodes = set()
        for cell in cells_u:
            for pos_ind, node in enumerate(cell):
                if node in boundary_mid_nodes:
                    node1, node2 = ind_map(pos_ind)
                    boundary_nodes.update((int(node), int(cell[node1]), int(cell[node2])))

        # Sorted for a deterministic order; the explicit integer dtype keeps an
        # empty result usable as an index array.
        boundary_inds = onp.array(sorted(boundary_nodes), dtype=onp.int64)
        x_coords = onp.asarray(points_u)[boundary_inds][:, 0]
        valid_inds = onp.argwhere((x_coords < 1 - 1e-5) & (x_coords > 1e-5)).reshape(-1)
        boundary_inds = boundary_inds[valid_inds]

        vec_inds_1 = onp.zeros_like(boundary_inds, dtype=onp.int32)
        vec_inds_2 = onp.ones_like(boundary_inds, dtype=onp.int32)
        values = onp.zeros_like(boundary_inds, dtype=onp.float32)
        self.fes[0].node_inds_list += [boundary_inds]*2
        self.fes[0].vec_inds_list += [vec_inds_1, vec_inds_2]
        self.fes[0].vals_list += [values]*2
# A small helper to find the orientation of 3 ordered points.
# Copied from https://www.geeksforgeeks.org/orientation-3-ordered-points/
class Point:
    """Plain container holding the ``x`` and ``y`` coordinates of a 2-D point."""

    def __init__(self, x, y):
        self.x, self.y = x, y
def orientation(p1, p2, p3):
    """Classify the turn made by the ordered triplet (p1, p2, p3).

    Each argument needs ``x`` and ``y`` attributes.

    Returns:
        1 if the points are ordered clockwise, 2 if counter-clockwise, and 0
        otherwise (collinear points).
    """
    rise_12 = float(p2.y - p1.y)
    run_12 = float(p2.x - p1.x)
    val = (rise_12 * (p3.x - p2.x)) - (run_12 * (p3.y - p2.y))
    if val > 0:
        return 1  # clockwise
    if val < 0:
        return 2  # counter-clockwise
    return 0  # collinear
def transform_cells(cells, points, ele_type):
    """Reorder triangle connectivity so that every cell is counter-clockwise.

    FEniCS triangular meshes are not always counter-clockwise. Orientation is
    decided from the first three (vertex) nodes of each cell, using the same
    sign convention as ``orientation``: a positive test value means clockwise.
    Clockwise cells are re-ordered: ``[0, 2, 1]`` for 'TRI3' and
    ``[0, 2, 1, 5, 4, 3]`` for 'TRI6' (mid-edge nodes follow their edges).

    Args:
        cells: integer array of shape (num_cells, nodes_per_cell).
        points: array of node coordinates; columns 0 and 1 are x and y.
        ele_type: 'TRI3' or 'TRI6'; only checked when a cell must be flipped.

    Returns:
        A new cell array with the same shape and dtype as ``cells``.

    Raises:
        ValueError: for degenerate (collinear) cells, or when a clockwise cell
            must be flipped but ``ele_type`` is not supported.
    """
    cells = onp.asarray(cells)
    new_cells = cells.copy()
    if len(cells) == 0:
        return new_cells

    pts = onp.asarray(points, dtype=onp.float64)
    p1, p2, p3 = (pts[cells[:, k]] for k in range(3))
    val = ((p2[:, 1] - p1[:, 1]) * (p3[:, 0] - p2[:, 0])
           - (p2[:, 0] - p1[:, 0]) * (p3[:, 1] - p2[:, 1]))

    # Neither strictly positive nor strictly negative (also catches NaN).
    degenerate = ~((val > 0) | (val < 0))
    if onp.any(degenerate):
        bad = onp.flatnonzero(degenerate).tolist()
        raise ValueError(f"Can't be linear, something wrong! Collinear cells at indices {bad}")

    clockwise = val > 0
    if onp.any(clockwise):
        permutations = {'TRI3': [0, 2, 1], 'TRI6': [0, 2, 1, 5, 4, 3]}
        if ele_type not in permutations:
            raise ValueError(f"Wrong element type {ele_type!r}, can't be transformed")
        new_cells[clockwise] = cells[clockwise][:, permutations[ele_type]]
    return new_cells
def problem():
    """Solve Stokes flow on the dolphin mesh with periodic outer boundaries.

    Velocity (TRI6) and pressure (TRI3) live on separate meshes loaded from
    ``input/numpy``. A periodic prolongation matrix is built for each field and
    the two are combined block-diagonally. Results are written as VTU files to
    ``output/vtk``.
    """
    input_dir = os.path.join(os.path.dirname(__file__), 'input')
    output_dir = os.path.join(os.path.dirname(__file__), 'output')

    # First run `python -m applications.stokes.fenics` to generate these numpy files
    ele_type_u = 'TRI6'
    points_u = onp.load(os.path.join(input_dir, f'numpy/points_u.npy'))
    cells_u = onp.load(os.path.join(input_dir, f'numpy/cells_u.npy'))
    cells_u = transform_cells(cells_u, points_u, ele_type_u)
    mesh_u = Mesh(points_u, cells_u)

    ele_type_p = 'TRI3'
    points_p = onp.load(os.path.join(input_dir, f'numpy/points_p.npy'))
    cells_p = onp.load(os.path.join(input_dir, f'numpy/cells_p.npy'))
    cells_p = transform_cells(cells_p, points_p, ele_type_p)
    mesh_p = Mesh(points_p, cells_p)

    # ---- Periodic boundary conditions ----
    # Boundary selector functions (use the same convention as fenics)
    def left(point):
        return np.isclose(point[0], 0.0, atol=1e-5)

    def right(point):
        return np.isclose(point[0], 1.0, atol=1e-5)

    def bottom(point):
        return np.isclose(point[1], 0.0, atol=1e-5)

    def top(point):
        return np.isclose(point[1], 1.0, atol=1e-5)

    # Mapping functions: given a point on side A, return its partner on side B.
    def mapping_x(point_A):
        # x = 0 -> x = 1
        return point_A + np.array([1.0, 0.0])

    def mapping_y(point_A):
        # y = 0 -> y = 1
        return point_A + np.array([0.0, 1.0])

    # Velocity (vec = 2): identify both components left->right and bottom->top,
    # so each mapping is listed once per component.
    location_fns_A_u = [left, left, bottom, bottom]
    location_fns_B_u = [right, right, top, top]
    mappings_u = [mapping_x, mapping_x, mapping_y, mapping_y]
    vecs_u = [0, 1, 0, 1]  # constrained velocity component for each entry
    P_mat_u = periodic_boundary_conditions([location_fns_A_u, location_fns_B_u, mappings_u, vecs_u], mesh_u, vec=2)

    # Pressure (vec = 1): one scalar identification per direction.
    location_fns_A_p = [left, bottom]
    location_fns_B_p = [right, top]
    mappings_p = [mapping_x, mapping_y]
    vecs_p = [0, 0]
    P_mat_p = periodic_boundary_conditions([location_fns_A_p, location_fns_B_p, mappings_p, vecs_p], mesh_p, vec=1)

    # The full unknown vector is (all u DOFs, then all p DOFs), so the combined
    # prolongation matrix is block-diagonal. The blocks need not share a shape:
    # P_mat_u is (N_u, M_u) and P_mat_p is (N_p, M_p).
    P_mat_full = scipy.sparse.block_diag((P_mat_u, P_mat_p), format='csr')

    # ---- Dirichlet boundary conditions ----
    # The outer boundary is periodic, so the velocity gets no Dirichlet info here;
    # the dolphin no-slip condition is appended by `configure_Dirichlet_BC_for_dolphin`.
    # NOTE(review): assumes jax-fem accepts `None` for a field without Dirichlet BCs — confirm.
    dirichlet_bc_info1 = None

    # With periodic outer boundaries the pressure is only determined up to a
    # constant, so pin it to 0 at one node strictly inside the unit square
    # (such a node is on no periodic boundary, hence untouched by P_mat_p).
    xy_p = onp.asarray(points_p)[:, :2]
    strictly_inside = onp.all((xy_p > 1e-5) & (xy_p < 1.0 - 1e-5), axis=1)
    if not onp.any(strictly_inside):
        raise ValueError("No pressure node strictly inside the unit square to pin")
    pin_point = np.array(onp.asarray(points_p)[onp.argmax(strictly_inside)])

    def pressure_pin(point):
        return np.all(np.isclose(point, pin_point, atol=1e-5))

    def zero_value(point):
        return 0.

    # [location_fns, vecs, value_fns]
    dirichlet_bc_info2 = [[pressure_pin], [0], [zero_value]]

    # NOTE(review): configure_Dirichlet_BC_for_dolphin excludes nodes only by their
    # x coordinate, so mesh edges at y = 0 and y = 1 also get no-slip conditions,
    # which conflicts with periodicity in y — confirm this is intended.
    # NOTE(review): every Dirichlet value is 0 and the weak form has no body force,
    # so the expected solution is u = 0, p = 0; add a forcing term to drive a flow.
    stokes_problem = StokesFlow([mesh_u, mesh_p], vec=[2, 1], dim=2, ele_type=[ele_type_u, ele_type_p],
                                gauss_order=[2, 2], dirichlet_bc_info=[dirichlet_bc_info1, dirichlet_bc_info2])
    stokes_problem.configure_Dirichlet_BC_for_dolphin()
    # NOTE(review): confirm the installed jax-fem solver reads the prolongation matrix
    # from `P_mat` (some versions may expect it as a constructor argument instead).
    stokes_problem.P_mat = P_mat_full

    # Preconditioning is very important for a problem like this. See discussions:
    # https://fenicsproject.discourse.group/t/steady-stokes-equation-3d-dolfinx/9709/4
    # https://fenicsproject.org/olddocs/dolfin/2019.1.0/python/demos/stokes-iterative/demo_stokes-iterative.py.html
    # Here, we choose 'ksp_type' to be 'tfqmr' and 'pc_type' to be 'lu'
    # But see a variety of other choices in PETSc:
    # https://www.mcs.anl.gov/petsc/petsc4py-current/docs/apiref/index.html
    sol_list = solver(stokes_problem, solver_options={'petsc_solver': {'ksp_type': 'tfqmr', 'pc_type': 'lu'}})
    # Alternatively, you may use the UMFPACK solver:
    # sol_list = solver(stokes_problem, solver_options={'umfpack_solver': {}})

    u, p = sol_list
    print(f"Max u = {onp.max(u)}, Min u = {onp.min(u)}")
    print(f"Max p = {onp.max(p)}, Min p = {onp.min(p)}")

    vtk_path_u = os.path.join(output_dir, f'vtk/jax-fem_velocity.vtu')
    vtk_path_p = os.path.join(output_dir, f'vtk/jax-fem_pressure.vtu')
    # Append a zero third column to the velocity (presumably a z-component for VTK vectors).
    sol_to_save = np.hstack((sol_list[0], np.zeros((len(sol_list[0]), 1))))
    save_sol(stokes_problem.fes[0], sol_to_save, vtk_path_u)
    save_sol(stokes_problem.fes[1], sol_list[1], vtk_path_p)
if __name__ == "__main__":
    problem()

# Please note that you will need the same input folder as for the Stokes example.