Skip to content

Commit 572c983

Browse files
mehdiataeiCopilot
andauthored
Immersed Boundary Method Implementation with Examples and USD Support (#115)
* IBM * Added windtunnel example * Added profiling * Fixed IBM force and added visualization * Added sphere with accurate drag calc * Separated the usd modules into utils * fused two kernels now ibm is only 3(or 4) kernels only * Improved convergence computation for ibm * renamed wind_turbine * Added moving airfoil example * final cleanups * Update xlb/operator/postprocess/grid_to_point.py Co-authored-by: Copilot <[email protected]> * Update xlb/operator/stepper/ibm_stepper.py Co-authored-by: Copilot <[email protected]> * Update examples/ibm/windtunnel_ibm.py Co-authored-by: Copilot <[email protected]> * Update examples/ibm/sphere_ibm.py Co-authored-by: Copilot <[email protected]> * ruff format * removed max_unroll * removed redundant folder * Added docstring * added ref and comments --------- Co-authored-by: Copilot <[email protected]>
1 parent f7e1b9d commit 572c983

File tree

18 files changed

+3414
-4
lines changed

18 files changed

+3414
-4
lines changed

.gitignore

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,12 @@ checkpoints/*
154154
# Ignore Python packaging build directories
155155
dist/
156156
build/
157-
*.egg-info/
157+
*.egg-info/
158+
159+
# USD files
160+
*.usd
161+
*.usda
162+
*.usdc
163+
*.usd.gz
164+
*.usd.zip
165+
*.usd.bz2

examples/ibm/airfoil_ibm.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
import numpy as np
2+
import trimesh
3+
import jax.numpy as jnp
4+
import matplotlib.pyplot as plt
5+
import warp as wp
6+
import xlb
7+
from xlb.compute_backend import ComputeBackend
8+
from xlb.precision_policy import PrecisionPolicy
9+
from xlb.operator.stepper import IBMStepper
10+
from xlb.operator.boundary_condition import FullwayBounceBackBC, RegularizedBC, ExtrapolationOutflowBC
11+
from xlb.operator.macroscopic import Macroscopic
12+
from xlb.helper.ibm_helper import prepare_immersed_boundary
13+
from xlb.grid import grid_factory
14+
from xlb.utils import save_image
15+
16+
17+
def generate_naca_profile(chord_length, thickness_ratio, n_points=400):
18+
x = np.linspace(0.0, chord_length, n_points)
19+
x_c = x / chord_length
20+
coeffs = np.array([0.2969, -0.1260, -0.3516, 0.2843, -0.1015], dtype=np.float64)
21+
powers = np.array([0.5, 1.0, 2.0, 3.0, 4.0], dtype=np.float64)
22+
terms = np.stack([x_c**p for p in powers], axis=0)
23+
thickness = 5.0 * thickness_ratio * chord_length * np.tensordot(coeffs, terms, axes=1)
24+
upper = np.stack([x, thickness], axis=1)
25+
lower = np.stack([x[::-1], -thickness[::-1]], axis=1)
26+
profile = np.vstack([upper, lower[1:-1]])
27+
profile[:, 0] -= chord_length * 0.5
28+
return profile
29+
30+
31+
def extrude_profile_to_mesh(profile, span_length):
32+
lower_z = -0.5 * span_length
33+
upper_z = 0.5 * span_length
34+
lower = np.concatenate([profile, np.full((profile.shape[0], 1), lower_z)], axis=1)
35+
upper = np.concatenate([profile, np.full((profile.shape[0], 1), upper_z)], axis=1)
36+
vertices = np.vstack([lower, upper])
37+
faces = []
38+
n = profile.shape[0]
39+
for i in range(1, n - 1):
40+
faces.append([0, i + 1, i])
41+
top_offset = n
42+
for i in range(1, n - 1):
43+
faces.append([top_offset, top_offset + i, top_offset + i + 1])
44+
for i in range(n):
45+
j = (i + 1) % n
46+
faces.append([i, j, top_offset + j])
47+
faces.append([i, top_offset + j, top_offset + i])
48+
return trimesh.Trimesh(vertices=vertices, faces=np.array(faces, dtype=np.int64), process=False)
49+
50+
51+
def create_airfoil_mesh(chord_length, thickness_ratio, span_length, n_points=400):
52+
profile = generate_naca_profile(chord_length, thickness_ratio, n_points)
53+
mesh = extrude_profile_to_mesh(profile, span_length)
54+
return mesh
55+
56+
57+
def define_boundary_indices(grid, velocity_set):
58+
box = grid.bounding_box_indices()
59+
box_no_edge = grid.bounding_box_indices(remove_edges=True)
60+
inlet = box_no_edge["left"]
61+
outlet = box_no_edge["right"]
62+
walls = [box["front"][i] + box["back"][i] + box["top"][i] + box["bottom"][i] for i in range(velocity_set.d)]
63+
walls = np.unique(np.array(walls), axis=-1).tolist()
64+
return inlet, outlet, walls
65+
66+
67+
def bc_profile(precision_policy, grid_shape, u_max):
68+
dtype = precision_policy.store_precision.wp_dtype
69+
u_max_d = dtype(u_max)
70+
71+
@wp.func
72+
def bc_profile_warp(index: wp.vec3i):
73+
return wp.vec(dtype(u_max_d), length=1)
74+
75+
return bc_profile_warp
76+
77+
78+
def setup_boundary_conditions(grid, velocity_set, precision_policy, grid_shape, u_max):
79+
inlet, outlet, walls = define_boundary_indices(grid, velocity_set)
80+
bc_inlet = RegularizedBC("velocity", indices=inlet, profile=bc_profile(precision_policy, grid_shape, u_max))
81+
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
82+
bc_walls = FullwayBounceBackBC(indices=walls)
83+
return [bc_walls, bc_inlet, bc_outlet]
84+
85+
86+
def setup_stepper(grid, boundary_conditions, ibm_max_iterations=2, ibm_tolerance=1e-5, ibm_relaxation=1.0):
87+
return IBMStepper(
88+
grid=grid,
89+
boundary_conditions=boundary_conditions,
90+
collision_type="KBC",
91+
ibm_max_iterations=ibm_max_iterations,
92+
ibm_tolerance=ibm_tolerance,
93+
ibm_relaxation=ibm_relaxation,
94+
)
95+
96+
97+
def calculate_force_coefficients(lag_forces, areas_np, reference_velocity, reference_area):
98+
forces_np = lag_forces.numpy()
99+
weighted = forces_np * areas_np[:, None]
100+
total_force = -np.sum(weighted, axis=0)
101+
dynamic_pressure = 0.5 * reference_velocity**2
102+
denom = dynamic_pressure * reference_area if dynamic_pressure * reference_area != 0.0 else 1.0
103+
cd = total_force[0] / denom
104+
cl = total_force[1] / denom
105+
return cd, cl, total_force
106+
107+
108+
def post_process(
109+
step,
110+
post_process_interval,
111+
f_current,
112+
precision_policy,
113+
grid_shape,
114+
lag_forces,
115+
cd_values,
116+
cl_values,
117+
reference_velocity,
118+
reference_area,
119+
areas_np,
120+
):
121+
if not isinstance(f_current, jnp.ndarray):
122+
f_jax = wp.to_jax(f_current)
123+
else:
124+
f_jax = f_current
125+
macro_jax = Macroscopic(
126+
compute_backend=ComputeBackend.JAX,
127+
precision_policy=precision_policy,
128+
velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
129+
)
130+
rho, u = macro_jax(f_jax)
131+
u = u[:, 1:-1, 1:-1, 1:-1]
132+
fields = {
133+
"u_magnitude": np.sqrt(u[0] ** 2.0 + u[1] ** 2.0 + u[2] ** 2.0),
134+
"u_x": u[0],
135+
"u_y": u[1],
136+
"u_z": u[2],
137+
}
138+
slice_idz = grid_shape[2] // 2
139+
save_image(fields["u_magnitude"][:, :, slice_idz], timestep=step)
140+
cd, cl, total_force = calculate_force_coefficients(lag_forces, areas_np, reference_velocity, reference_area)
141+
cd_values.append((step, float(cd)))
142+
cl_values.append((step, float(cl)))
143+
if step % post_process_interval == 0:
144+
window = 10
145+
if len(cd_values) >= window:
146+
avg_cd = float(np.mean([v for _, v in cd_values[-window:]]))
147+
avg_cl = float(np.mean([v for _, v in cl_values[-window:]]))
148+
else:
149+
avg_cd = float(np.mean([v for _, v in cd_values]))
150+
avg_cl = float(np.mean([v for _, v in cl_values]))
151+
print(
152+
f"Step {step}: Cd = {cd:.6f}, Cl = {cl:.6f}, Cd(avg{window}) = {avg_cd:.6f}, Cl(avg{window}) = {avg_cl:.6f}, "
153+
f"Fx = {total_force[0]:.6f}, Fy = {total_force[1]:.6f}"
154+
)
155+
156+
157+
def save_force_coefficients(cd_values, cl_values, filename):
158+
with open(filename, "w") as f:
159+
f.write("timestep,cd,cl\n")
160+
for (timestep_cd, cd), (_, cl) in zip(cd_values, cl_values):
161+
f.write(f"{timestep_cd},{cd},{cl}\n")
162+
timesteps = [t for t, _ in cd_values]
163+
cds = [cd for _, cd in cd_values]
164+
cls = [cl for _, cl in cl_values]
165+
plt.figure(figsize=(10, 6))
166+
plt.plot(timesteps, cds, "r-", label="Cd")
167+
plt.plot(timesteps, cls, "b-", label="Cl")
168+
plt.grid(True, linestyle="--", alpha=0.7)
169+
plt.xlabel("Timestep")
170+
plt.ylabel("Coefficient")
171+
plt.title("Airfoil Force Coefficients")
172+
plt.legend()
173+
plt.tight_layout()
174+
plt.savefig("airfoil_force_coefficients.png", dpi=150)
175+
plt.close()
176+
177+
178+
@wp.kernel
179+
def update_airfoil_pose(
180+
step: int,
181+
total_steps: int,
182+
start_angle: float,
183+
total_rotation: float,
184+
origin: wp.vec3,
185+
base_vertices: wp.array(dtype=wp.vec3),
186+
vertices: wp.array(dtype=wp.vec3),
187+
velocities: wp.array(dtype=wp.vec3),
188+
):
189+
idx = wp.tid()
190+
total_span = wp.float32(total_steps - 1)
191+
progress = wp.float32(0.0)
192+
if total_span > 0.0:
193+
progress = wp.float32(step) / total_span
194+
if progress > 1.0:
195+
progress = wp.float32(1.0)
196+
start_angle_f = wp.float32(start_angle)
197+
total_rotation_f = wp.float32(total_rotation)
198+
angle = start_angle_f + total_rotation_f * progress
199+
c = wp.cos(angle)
200+
s = wp.sin(angle)
201+
base = base_vertices[idx] - origin
202+
rotated = wp.vec3(
203+
c * base[0] - s * base[1],
204+
s * base[0] + c * base[1],
205+
base[2],
206+
)
207+
vertices[idx] = rotated + origin
208+
angular_rate = wp.float32(0.0)
209+
if total_span > 0.0:
210+
angular_rate = total_rotation_f / total_span
211+
velocities[idx] = wp.vec3(
212+
-angular_rate * rotated[1],
213+
angular_rate * rotated[0],
214+
0.0,
215+
)
216+
217+
218+
if __name__ == "__main__":
219+
chord_length = 60.0 * 1.3
220+
span_length = 50.0 * 1.3
221+
thickness_ratio = 0.12
222+
upstream = int(2 * chord_length)
223+
downstream = int(4 * chord_length)
224+
ly = int(3.0 * chord_length)
225+
lz = int(2.0 * span_length)
226+
lx = upstream + downstream + int(chord_length)
227+
grid_shape = (lx, ly, lz)
228+
u_max = 0.05
229+
Re = 20000
230+
start_angle_deg = 0.0
231+
total_rotation_deg = -45.0
232+
start_angle_rad = np.deg2rad(start_angle_deg)
233+
total_rotation_rad = np.deg2rad(total_rotation_deg)
234+
num_steps = 30000
235+
post_process_interval = 100
236+
print_interval = 100
237+
ibm_max_iterations = 1
238+
ibm_tolerance = 1e-5
239+
ibm_relaxation = 0.5
240+
compute_backend = ComputeBackend.WARP
241+
precision_policy = PrecisionPolicy.FP32FP32
242+
velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
243+
xlb.init(velocity_set=velocity_set, default_backend=compute_backend, default_precision_policy=precision_policy)
244+
grid = grid_factory(grid_shape, compute_backend=compute_backend)
245+
print("Airfoil IBM Simulation Configuration:")
246+
print(f" Grid size: {grid_shape}")
247+
print(f" Chord length: {chord_length}")
248+
print(f" Span length: {span_length}")
249+
print(f" Thickness ratio: {thickness_ratio}")
250+
print(f" Inlet velocity: {u_max}")
251+
print(f" Reynolds number: {Re}")
252+
print(f" Start angle: {start_angle_deg}")
253+
print(f" Total rotation: {total_rotation_deg}")
254+
print(f" Max steps: {num_steps}")
255+
print(f" IBM max iterations: {ibm_max_iterations}")
256+
print(f" IBM tolerance: {ibm_tolerance}")
257+
print(f" IBM relaxation: {ibm_relaxation}")
258+
airfoil_mesh = create_airfoil_mesh(chord_length, thickness_ratio, span_length)
259+
airfoil_center = np.array([float(upstream + 0.6 * chord_length), grid_shape[1] * 0.5, grid_shape[2] * 0.5], dtype=np.float64)
260+
translation = airfoil_center - airfoil_mesh.centroid
261+
airfoil_mesh.apply_translation(translation)
262+
vertices_wp, areas_wp, faces_np = prepare_immersed_boundary(airfoil_mesh, max_lbm_length=max(chord_length, span_length))
263+
vertices_np = vertices_wp.numpy()
264+
base_vertices_wp = wp.array(vertices_np, dtype=wp.vec3)
265+
vertices_wp = wp.array(vertices_np, dtype=wp.vec3)
266+
areas_np = areas_wp.numpy()
267+
leading_edge_x = float(np.min(vertices_np[:, 0]))
268+
rotation_center_y = float(np.mean(vertices_np[:, 1]))
269+
rotation_center_z = float(np.mean(vertices_np[:, 2]))
270+
rotation_origin = np.array(
271+
[
272+
leading_edge_x + 0.1 * chord_length,
273+
rotation_center_y,
274+
rotation_center_z,
275+
],
276+
dtype=np.float64,
277+
)
278+
origin_wp = wp.vec3(float(rotation_origin[0]), float(rotation_origin[1]), float(rotation_origin[2]))
279+
reference_area = chord_length * span_length
280+
bc_list = setup_boundary_conditions(grid, velocity_set, precision_policy, grid_shape, u_max)
281+
stepper = setup_stepper(grid, bc_list, ibm_max_iterations, ibm_tolerance, ibm_relaxation)
282+
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
283+
velocities_wp = wp.zeros(shape=vertices_wp.shape[0], dtype=wp.vec3)
284+
device = vertices_wp.device
285+
wp.launch(
286+
kernel=update_airfoil_pose,
287+
dim=vertices_wp.shape[0],
288+
inputs=[
289+
0,
290+
num_steps,
291+
start_angle_rad,
292+
total_rotation_rad,
293+
origin_wp,
294+
base_vertices_wp,
295+
vertices_wp,
296+
velocities_wp,
297+
],
298+
device=device,
299+
)
300+
cd_values = []
301+
cl_values = []
302+
visc = u_max * chord_length / Re
303+
omega = 1.0 / (3.0 * visc + 0.5)
304+
print(f" Omega: {omega}")
305+
try:
306+
for i in range(num_steps):
307+
f_0, f_1, lag_forces = stepper(
308+
f_0,
309+
f_1,
310+
vertices_wp,
311+
areas_wp,
312+
velocities_wp,
313+
bc_mask,
314+
missing_mask,
315+
omega,
316+
i,
317+
)
318+
f_0, f_1 = f_1, f_0
319+
if print_interval > 0 and i % print_interval == 0:
320+
print(f"Step {i}/{num_steps} completed")
321+
if i % post_process_interval == 0 or i == num_steps - 1:
322+
post_process(
323+
i,
324+
post_process_interval,
325+
f_0,
326+
precision_policy,
327+
grid_shape,
328+
lag_forces,
329+
cd_values,
330+
cl_values,
331+
u_max,
332+
reference_area,
333+
areas_np,
334+
)
335+
next_step = i + 1
336+
if next_step < num_steps:
337+
wp.launch(
338+
kernel=update_airfoil_pose,
339+
dim=vertices_wp.shape[0],
340+
inputs=[
341+
next_step,
342+
num_steps,
343+
start_angle_rad,
344+
total_rotation_rad,
345+
origin_wp,
346+
base_vertices_wp,
347+
vertices_wp,
348+
velocities_wp,
349+
],
350+
device=device,
351+
)
352+
except KeyboardInterrupt:
353+
print("Simulation interrupted by user.")
354+
if cd_values and cl_values:
355+
save_force_coefficients(cd_values, cl_values, "airfoil_force_coefficients.csv")
356+
print("Force coefficient data saved to airfoil_force_coefficients.csv")
357+
raise
358+
if cd_values and cl_values:
359+
save_force_coefficients(cd_values, cl_values, "airfoil_force_coefficients.csv")
360+
print("Force coefficient data saved to airfoil_force_coefficients.csv")
361+
print(f"Final Cd (avg last 10): {np.mean([cd for _, cd in cd_values[-10:]]):.6f}")
362+
print(f"Final Cl (avg last 10): {np.mean([cl for _, cl in cl_values[-10:]]):.6f}")
363+
print("Simulation finished.")

0 commit comments

Comments
 (0)