Skip to content

Commit d42afe5

Browse files
authored
[FEATURE] Add Fem fixed constraint for implicit solver (Genesis-Embodied-AI#1562)
1 parent 5e0ce67 commit d42afe5

File tree

5 files changed

+317
-117
lines changed

5 files changed

+317
-117
lines changed

genesis/engine/couplers/sap_coupler.py

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,12 @@ def compute_fem_surface_tet_aabb(self, i_step: ti.i32):
498498
aabbs = ti.static(self.fem_surface_tet_aabb.aabbs)
499499
for i_b, i_se in ti.ndrange(self.fem_solver._B, self.fem_solver.n_surface_elements):
500500
i_e = self.fem_solver.surface_elements[i_se]
501-
i_v = self.fem_solver.elements_i[i_e].el2v
501+
i_vs = self.fem_solver.elements_i[i_e].el2v
502502

503503
aabbs[i_b, i_se].min.fill(np.inf)
504504
aabbs[i_b, i_se].max.fill(-np.inf)
505505
for i in ti.static(range(4)):
506-
pos_v = self.fem_solver.elements_v[i_step, i_v[i], i_b].pos
506+
pos_v = self.fem_solver.elements_v[i_step, i_vs[i], i_b].pos
507507
aabbs[i_b, i_se].min = ti.min(aabbs[i_b, i_se].min, pos_v)
508508
aabbs[i_b, i_se].max = ti.max(aabbs[i_b, i_se].max, pos_v)
509509

@@ -729,18 +729,16 @@ def compute_rigid_pcg_matrix_vector_product(self):
729729
)
730730

731731
@ti.func
732-
def compute_elastic_products(self, i_b, i_e, B, s, i_v0, i_v1, i_v2, i_v3, src):
732+
def compute_elastic_products(self, i_b, i_e, S, i_vs, src):
733733
p9 = ti.Vector.zero(gs.ti_float, 9)
734-
for i in ti.static(range(3)):
735-
p9[i * 3 : i * 3 + 3] = (
736-
B[0, i] * src[i_b, i_v0] + B[1, i] * src[i_b, i_v1] + B[2, i] * src[i_b, i_v2] + s[i] * src[i_b, i_v3]
737-
)
734+
for i, j in ti.static(ti.ndrange(3, 4)):
735+
p9[i * 3 : i * 3 + 3] = p9[i * 3 : i * 3 + 3] + (S[j, i] * src[i_b, i_vs[j]])
736+
738737
H9_p9 = ti.Vector.zero(gs.ti_float, 9)
739-
for i in ti.static(range(3)):
740-
H9_p9[i * 3 : i * 3 + 3] = (
741-
self.fem_solver.elements_el_hessian[i_b, i, 0, i_e] @ p9[0:3]
742-
+ self.fem_solver.elements_el_hessian[i_b, i, 1, i_e] @ p9[3:6]
743-
+ self.fem_solver.elements_el_hessian[i_b, i, 2, i_e] @ p9[6:9]
738+
739+
for i, j in ti.static(ti.ndrange(3, 3)):
740+
H9_p9[i * 3 : i * 3 + 3] = H9_p9[i * 3 : i * 3 + 3] + (
741+
self.fem_solver.elements_el_hessian[i_b, i, j, i_e] @ p9[j * 3 : j * 3 + 3]
744742
)
745743
return p9, H9_p9
746744

@@ -767,16 +765,21 @@ def compute_fem_matrix_vector_product(self, src, dst, active):
767765
continue
768766
V_dt2 = self.fem_solver.elements_i[i_e].V * dt2
769767
B = self.fem_solver.elements_i[i_e].B
770-
s = -B[0, :] - B[1, :] - B[2, :] # s is the negative sum of B rows
771-
i_v0, i_v1, i_v2, i_v3 = self.fem_solver.elements_i[i_e].el2v
768+
S = ti.Matrix.zero(gs.ti_float, 4, 3)
769+
S[:3, :] = B
770+
S[3, :] = -B[0, :] - B[1, :] - B[2, :]
771+
i_vs = self.fem_solver.elements_i[i_e].el2v
772+
773+
if ti.static(self.fem_solver._enable_vertex_constraints):
774+
for i in ti.static(range(4)):
775+
if self.fem_solver.vertex_constraints.is_constrained[i_vs[i], i_b]:
776+
S[i, :] = ti.Vector.zero(gs.ti_float, 3)
772777

773-
_, new_p9 = self.compute_elastic_products(i_b, i_e, B, s, i_v0, i_v1, i_v2, i_v3, src)
778+
_, new_p9 = self.compute_elastic_products(i_b, i_e, S, i_vs, src)
774779
# atomic
775780
scale = V_dt2 * damping_beta_factor
776-
dst[i_b, i_v0] += (B[0, 0] * new_p9[0:3] + B[0, 1] * new_p9[3:6] + B[0, 2] * new_p9[6:9]) * scale
777-
dst[i_b, i_v1] += (B[1, 0] * new_p9[0:3] + B[1, 1] * new_p9[3:6] + B[1, 2] * new_p9[6:9]) * scale
778-
dst[i_b, i_v2] += (B[2, 0] * new_p9[0:3] + B[2, 1] * new_p9[3:6] + B[2, 2] * new_p9[6:9]) * scale
779-
dst[i_b, i_v3] += (s[0] * new_p9[0:3] + s[1] * new_p9[3:6] + s[2] * new_p9[6:9]) * scale
781+
for i in ti.static(range(4)):
782+
dst[i_b, i_vs[i]] += (S[i, 0] * new_p9[0:3] + S[i, 1] * new_p9[3:6] + S[i, 2] * new_p9[6:9]) * scale
780783

781784
def init_pcg_solve(self):
782785
self.init_pcg_state()
@@ -1101,10 +1104,17 @@ def compute_fem_energy(self, i_step: ti.i32, energy: ti.template()):
11011104

11021105
V_dt2 = self.fem_solver.elements_i[i_e].V * dt2
11031106
B = self.fem_solver.elements_i[i_e].B
1104-
s = -B[0, :] - B[1, :] - B[2, :] # s is the negative sum of B rows
1105-
i_v0, i_v1, i_v2, i_v3 = self.fem_solver.elements_i[i_e].el2v
1107+
S = ti.Matrix.zero(gs.ti_float, 4, 3)
1108+
S[:3, :] = B
1109+
S[3, :] = -B[0, :] - B[1, :] - B[2, :]
1110+
i_vs = self.fem_solver.elements_i[i_e].el2v
1111+
1112+
if ti.static(self.fem_solver._enable_vertex_constraints):
1113+
for i in ti.static(range(4)):
1114+
if self.fem_solver.vertex_constraints.is_constrained[i_vs[i], i_b]:
1115+
S[i, :] = ti.Vector.zero(gs.ti_float, 3)
11061116

1107-
p9, H9_p9 = self.compute_elastic_products(i_b, i_e, B, s, i_v0, i_v1, i_v2, i_v3, self.fem_state_v.v_diff)
1117+
p9, H9_p9 = self.compute_elastic_products(i_b, i_e, S, i_vs, self.fem_state_v.v_diff)
11081118
energy[i_b] += 0.5 * p9.dot(H9_p9) * damping_beta_factor * V_dt2
11091119

11101120
@ti.func
@@ -1991,15 +2001,23 @@ def add_Jt_x(self, y, i_p, x):
19912001
i_g = self.contact_pairs[i_p].geom_idx
19922002
for i in ti.static(range(4)):
19932003
i_v = self.fem_solver.elements_i[i_g].el2v[i]
1994-
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] * x
2004+
if ti.static(self.fem_solver._enable_vertex_constraints):
2005+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2006+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] * x
2007+
else:
2008+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] * x
19952009

19962010
@ti.func
19972011
def add_Jt_A_J_diag3x3(self, y, i_p, A):
19982012
i_b = self.contact_pairs[i_p].batch_idx
19992013
i_g = self.contact_pairs[i_p].geom_idx
20002014
for i in ti.static(range(4)):
20012015
i_v = self.fem_solver.elements_i[i_g].el2v[i]
2002-
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] ** 2 * A
2016+
if ti.static(self.fem_solver._enable_vertex_constraints):
2017+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2018+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] ** 2 * A
2019+
else:
2020+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric[i] ** 2 * A
20032021

20042022
@ti.func
20052023
def compute_delassus(self, i_p):
@@ -2302,10 +2320,18 @@ def add_Jt_x(self, y, i_p, x):
23022320
x_ = world @ x
23032321
for i in ti.static(range(4)):
23042322
i_v = self.fem_solver.elements_i[i_g0].el2v[i]
2305-
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] * x_
2323+
if ti.static(self.fem_solver._enable_vertex_constraints):
2324+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2325+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] * x_
2326+
else:
2327+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] * x_
23062328
for i in ti.static(range(4)):
23072329
i_v = self.fem_solver.elements_i[i_g1].el2v[i]
2308-
y[i_b, i_v] -= self.contact_pairs[i_p].barycentric1[i] * x_
2330+
if ti.static(self.fem_solver._enable_vertex_constraints):
2331+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2332+
y[i_b, i_v] -= self.contact_pairs[i_p].barycentric1[i] * x_
2333+
else:
2334+
y[i_b, i_v] -= self.contact_pairs[i_p].barycentric1[i] * x_
23092335

23102336
@ti.func
23112337
def add_Jt_A_J_diag3x3(self, y, i_p, A):
@@ -2318,10 +2344,18 @@ def add_Jt_A_J_diag3x3(self, y, i_p, A):
23182344
B_ = world @ A @ world.transpose()
23192345
for i in ti.static(range(4)):
23202346
i_v = self.fem_solver.elements_i[i_g0].el2v[i]
2321-
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] ** 2 * B_
2347+
if ti.static(self.fem_solver._enable_vertex_constraints):
2348+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2349+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] ** 2 * B_
2350+
else:
2351+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric0[i] ** 2 * B_
23222352
for i in ti.static(range(4)):
23232353
i_v = self.fem_solver.elements_i[i_g1].el2v[i]
2324-
y[i_b, i_v] += self.contact_pairs[i_p].barycentric1[i] ** 2 * B_
2354+
if ti.static(self.fem_solver._enable_vertex_constraints):
2355+
if not self.fem_solver.vertex_constraints.is_constrained[i_v, i_b]:
2356+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric1[i] ** 2 * B_
2357+
else:
2358+
y[i_b, i_v] += self.contact_pairs[i_p].barycentric1[i] ** 2 * B_
23252359

23262360
@ti.func
23272361
def compute_delassus(self, i_p):
@@ -2408,13 +2442,21 @@ def compute_Jx(self, i_p, x):
24082442
def add_Jt_x(self, y, i_p, x):
24092443
i_b = self.contact_pairs[i_p].batch_idx
24102444
i_g = self.contact_pairs[i_p].geom_idx
2411-
y[i_b, i_g] += x
2445+
if ti.static(self.fem_solver._enable_vertex_constraints):
2446+
if not self.fem_solver.vertex_constraints.is_constrained[i_g, i_b]:
2447+
y[i_b, i_g] += x
2448+
else:
2449+
y[i_b, i_g] += x
24122450

24132451
@ti.func
24142452
def add_Jt_A_J_diag3x3(self, y, i_p, A):
24152453
i_b = self.contact_pairs[i_p].batch_idx
24162454
i_g = self.contact_pairs[i_p].geom_idx
2417-
y[i_b, i_g] += A
2455+
if ti.static(self.fem_solver._enable_vertex_constraints):
2456+
if not self.fem_solver.vertex_constraints.is_constrained[i_g, i_b]:
2457+
y[i_b, i_g] += A
2458+
else:
2459+
y[i_b, i_g] += A
24182460

24192461
@ti.func
24202462
def compute_delassus(self, i_p):

genesis/engine/entities/fem_entity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,8 +803,9 @@ def set_vertex_constraints(
803803
List of environment indices to apply the constraints to. If None, applies to all environments.
804804
"""
805805
if self._solver._use_implicit_solver:
806-
gs.logger.warning("Ignoring vertex constraint; unsupported with FEM implicit solver.")
807-
return
806+
if not self._solver._enable_vertex_constraints:
807+
gs.logger.warning("Ignoring vertex constraint; FEM implicit solver needs to enable vertex constraints.")
808+
return
808809

809810
if not self._solver._constraints_initialized:
810811
self._solver.init_constraints()

0 commit comments

Comments
 (0)