@@ -498,12 +498,12 @@ def compute_fem_surface_tet_aabb(self, i_step: ti.i32):
498498 aabbs = ti .static (self .fem_surface_tet_aabb .aabbs )
499499 for i_b , i_se in ti .ndrange (self .fem_solver ._B , self .fem_solver .n_surface_elements ):
500500 i_e = self .fem_solver .surface_elements [i_se ]
501- i_v = self .fem_solver .elements_i [i_e ].el2v
501+ i_vs = self .fem_solver .elements_i [i_e ].el2v
502502
503503 aabbs [i_b , i_se ].min .fill (np .inf )
504504 aabbs [i_b , i_se ].max .fill (- np .inf )
505505 for i in ti .static (range (4 )):
506- pos_v = self .fem_solver .elements_v [i_step , i_v [i ], i_b ].pos
506+ pos_v = self .fem_solver .elements_v [i_step , i_vs [i ], i_b ].pos
507507 aabbs [i_b , i_se ].min = ti .min (aabbs [i_b , i_se ].min , pos_v )
508508 aabbs [i_b , i_se ].max = ti .max (aabbs [i_b , i_se ].max , pos_v )
509509
@@ -729,18 +729,16 @@ def compute_rigid_pcg_matrix_vector_product(self):
729729 )
730730
731731 @ti .func
732- def compute_elastic_products (self , i_b , i_e , B , s , i_v0 , i_v1 , i_v2 , i_v3 , src ):
732+ def compute_elastic_products (self , i_b , i_e , S , i_vs , src ):
733733 p9 = ti .Vector .zero (gs .ti_float , 9 )
734- for i in ti .static (range (3 )):
735- p9 [i * 3 : i * 3 + 3 ] = (
736- B [0 , i ] * src [i_b , i_v0 ] + B [1 , i ] * src [i_b , i_v1 ] + B [2 , i ] * src [i_b , i_v2 ] + s [i ] * src [i_b , i_v3 ]
737- )
734+ for i , j in ti .static (ti .ndrange (3 , 4 )):
735+ p9 [i * 3 : i * 3 + 3 ] = p9 [i * 3 : i * 3 + 3 ] + (S [j , i ] * src [i_b , i_vs [j ]])
736+
738737 H9_p9 = ti .Vector .zero (gs .ti_float , 9 )
739- for i in ti .static (range (3 )):
740- H9_p9 [i * 3 : i * 3 + 3 ] = (
741- self .fem_solver .elements_el_hessian [i_b , i , 0 , i_e ] @ p9 [0 :3 ]
742- + self .fem_solver .elements_el_hessian [i_b , i , 1 , i_e ] @ p9 [3 :6 ]
743- + self .fem_solver .elements_el_hessian [i_b , i , 2 , i_e ] @ p9 [6 :9 ]
738+
739+ for i , j in ti .static (ti .ndrange (3 , 3 )):
740+ H9_p9 [i * 3 : i * 3 + 3 ] = H9_p9 [i * 3 : i * 3 + 3 ] + (
741+ self .fem_solver .elements_el_hessian [i_b , i , j , i_e ] @ p9 [j * 3 : j * 3 + 3 ]
744742 )
745743 return p9 , H9_p9
746744
@@ -767,16 +765,21 @@ def compute_fem_matrix_vector_product(self, src, dst, active):
767765 continue
768766 V_dt2 = self .fem_solver .elements_i [i_e ].V * dt2
769767 B = self .fem_solver .elements_i [i_e ].B
770- s = - B [0 , :] - B [1 , :] - B [2 , :] # s is the negative sum of B rows
771- i_v0 , i_v1 , i_v2 , i_v3 = self .fem_solver .elements_i [i_e ].el2v
768+ S = ti .Matrix .zero (gs .ti_float , 4 , 3 )
769+ S [:3 , :] = B
770+ S [3 , :] = - B [0 , :] - B [1 , :] - B [2 , :]
771+ i_vs = self .fem_solver .elements_i [i_e ].el2v
772+
773+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
774+ for i in ti .static (range (4 )):
775+ if self .fem_solver .vertex_constraints .is_constrained [i_vs [i ], i_b ]:
776+ S [i , :] = ti .Vector .zero (gs .ti_float , 3 )
772777
773- _ , new_p9 = self .compute_elastic_products (i_b , i_e , B , s , i_v0 , i_v1 , i_v2 , i_v3 , src )
778+ _ , new_p9 = self .compute_elastic_products (i_b , i_e , S , i_vs , src )
774779 # atomic
775780 scale = V_dt2 * damping_beta_factor
776- dst [i_b , i_v0 ] += (B [0 , 0 ] * new_p9 [0 :3 ] + B [0 , 1 ] * new_p9 [3 :6 ] + B [0 , 2 ] * new_p9 [6 :9 ]) * scale
777- dst [i_b , i_v1 ] += (B [1 , 0 ] * new_p9 [0 :3 ] + B [1 , 1 ] * new_p9 [3 :6 ] + B [1 , 2 ] * new_p9 [6 :9 ]) * scale
778- dst [i_b , i_v2 ] += (B [2 , 0 ] * new_p9 [0 :3 ] + B [2 , 1 ] * new_p9 [3 :6 ] + B [2 , 2 ] * new_p9 [6 :9 ]) * scale
779- dst [i_b , i_v3 ] += (s [0 ] * new_p9 [0 :3 ] + s [1 ] * new_p9 [3 :6 ] + s [2 ] * new_p9 [6 :9 ]) * scale
781+ for i in ti .static (range (4 )):
782+ dst [i_b , i_vs [i ]] += (S [i , 0 ] * new_p9 [0 :3 ] + S [i , 1 ] * new_p9 [3 :6 ] + S [i , 2 ] * new_p9 [6 :9 ]) * scale
780783
781784 def init_pcg_solve (self ):
782785 self .init_pcg_state ()
@@ -1101,10 +1104,17 @@ def compute_fem_energy(self, i_step: ti.i32, energy: ti.template()):
11011104
11021105 V_dt2 = self .fem_solver .elements_i [i_e ].V * dt2
11031106 B = self .fem_solver .elements_i [i_e ].B
1104- s = - B [0 , :] - B [1 , :] - B [2 , :] # s is the negative sum of B rows
1105- i_v0 , i_v1 , i_v2 , i_v3 = self .fem_solver .elements_i [i_e ].el2v
1107+ S = ti .Matrix .zero (gs .ti_float , 4 , 3 )
1108+ S [:3 , :] = B
1109+ S [3 , :] = - B [0 , :] - B [1 , :] - B [2 , :]
1110+ i_vs = self .fem_solver .elements_i [i_e ].el2v
1111+
1112+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
1113+ for i in ti .static (range (4 )):
1114+ if self .fem_solver .vertex_constraints .is_constrained [i_vs [i ], i_b ]:
1115+ S [i , :] = ti .Vector .zero (gs .ti_float , 3 )
11061116
1107- p9 , H9_p9 = self .compute_elastic_products (i_b , i_e , B , s , i_v0 , i_v1 , i_v2 , i_v3 , self .fem_state_v .v_diff )
1117+ p9 , H9_p9 = self .compute_elastic_products (i_b , i_e , S , i_vs , self .fem_state_v .v_diff )
11081118 energy [i_b ] += 0.5 * p9 .dot (H9_p9 ) * damping_beta_factor * V_dt2
11091119
11101120 @ti .func
@@ -1991,15 +2001,23 @@ def add_Jt_x(self, y, i_p, x):
19912001 i_g = self .contact_pairs [i_p ].geom_idx
19922002 for i in ti .static (range (4 )):
19932003 i_v = self .fem_solver .elements_i [i_g ].el2v [i ]
1994- y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] * x
2004+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2005+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2006+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] * x
2007+ else :
2008+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] * x
19952009
19962010 @ti .func
19972011 def add_Jt_A_J_diag3x3 (self , y , i_p , A ):
19982012 i_b = self .contact_pairs [i_p ].batch_idx
19992013 i_g = self .contact_pairs [i_p ].geom_idx
20002014 for i in ti .static (range (4 )):
20012015 i_v = self .fem_solver .elements_i [i_g ].el2v [i ]
2002- y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] ** 2 * A
2016+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2017+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2018+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] ** 2 * A
2019+ else :
2020+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric [i ] ** 2 * A
20032021
20042022 @ti .func
20052023 def compute_delassus (self , i_p ):
@@ -2302,10 +2320,18 @@ def add_Jt_x(self, y, i_p, x):
23022320 x_ = world @ x
23032321 for i in ti .static (range (4 )):
23042322 i_v = self .fem_solver .elements_i [i_g0 ].el2v [i ]
2305- y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] * x_
2323+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2324+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2325+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] * x_
2326+ else :
2327+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] * x_
23062328 for i in ti .static (range (4 )):
23072329 i_v = self .fem_solver .elements_i [i_g1 ].el2v [i ]
2308- y [i_b , i_v ] -= self .contact_pairs [i_p ].barycentric1 [i ] * x_
2330+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2331+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2332+ y [i_b , i_v ] -= self .contact_pairs [i_p ].barycentric1 [i ] * x_
2333+ else :
2334+ y [i_b , i_v ] -= self .contact_pairs [i_p ].barycentric1 [i ] * x_
23092335
23102336 @ti .func
23112337 def add_Jt_A_J_diag3x3 (self , y , i_p , A ):
@@ -2318,10 +2344,18 @@ def add_Jt_A_J_diag3x3(self, y, i_p, A):
23182344 B_ = world @ A @ world .transpose ()
23192345 for i in ti .static (range (4 )):
23202346 i_v = self .fem_solver .elements_i [i_g0 ].el2v [i ]
2321- y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] ** 2 * B_
2347+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2348+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2349+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] ** 2 * B_
2350+ else :
2351+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric0 [i ] ** 2 * B_
23222352 for i in ti .static (range (4 )):
23232353 i_v = self .fem_solver .elements_i [i_g1 ].el2v [i ]
2324- y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric1 [i ] ** 2 * B_
2354+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2355+ if not self .fem_solver .vertex_constraints .is_constrained [i_v , i_b ]:
2356+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric1 [i ] ** 2 * B_
2357+ else :
2358+ y [i_b , i_v ] += self .contact_pairs [i_p ].barycentric1 [i ] ** 2 * B_
23252359
23262360 @ti .func
23272361 def compute_delassus (self , i_p ):
@@ -2408,13 +2442,21 @@ def compute_Jx(self, i_p, x):
24082442 def add_Jt_x (self , y , i_p , x ):
24092443 i_b = self .contact_pairs [i_p ].batch_idx
24102444 i_g = self .contact_pairs [i_p ].geom_idx
2411- y [i_b , i_g ] += x
2445+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2446+ if not self .fem_solver .vertex_constraints .is_constrained [i_g , i_b ]:
2447+ y [i_b , i_g ] += x
2448+ else :
2449+ y [i_b , i_g ] += x
24122450
24132451 @ti .func
24142452 def add_Jt_A_J_diag3x3 (self , y , i_p , A ):
24152453 i_b = self .contact_pairs [i_p ].batch_idx
24162454 i_g = self .contact_pairs [i_p ].geom_idx
2417- y [i_b , i_g ] += A
2455+ if ti .static (self .fem_solver ._enable_vertex_constraints ):
2456+ if not self .fem_solver .vertex_constraints .is_constrained [i_g , i_b ]:
2457+ y [i_b , i_g ] += A
2458+ else :
2459+ y [i_b , i_g ] += A
24182460
24192461 @ti .func
24202462 def compute_delassus (self , i_p ):
0 commit comments