@@ -2296,53 +2296,48 @@ def update_verts_for_geom(self, i_g):
22962296 self .fixed_verts_state ,
22972297 )
22982298
2299- @gs .assert_built
23002299 def get_weld_constraints (self , as_tensor : bool = True , to_torch : bool = True ):
2301- n_welds = tuple (self .constraint_solver .ti_n_equalities .to_numpy ())
2302- n_envs = len (n_welds )
2303- n_welds_max = max (n_welds ) if n_welds else 0
2304- out_size = n_welds_max * n_envs
2300+ n_eqs = tuple (self .constraint_solver .constraint_state .ti_n_equalities .to_numpy ())
2301+ n_envs = len (n_eqs )
2302+ n_max = max (n_eqs ) if n_eqs else 0
2303+
2304+ if as_tensor :
2305+ out_size = n_envs * n_max
2306+ splits = None
2307+ else :
2308+ cumsum = np .cumsum (n_eqs , dtype = np .int32 )
2309+ splits = list (cumsum [:- 1 ])
2310+ out_size = int (cumsum [- 1 ]) if n_envs else 0
23052311
23062312 if to_torch :
23072313 buf = torch .full ((out_size , 3 ), - 1 , dtype = gs .tc_int , device = gs .device )
23082314 else :
23092315 buf = np .full ((out_size , 3 ), - 1 , dtype = np .int32 )
23102316
2311- if n_welds_max > 0 :
2312- self ._kernel_collect_welds (buf )
2313-
2314- if to_torch :
2315- buf_view = buf .view (n_envs , n_welds_max , 3 )
2316- else :
2317- buf_view = buf .reshape (n_envs , n_welds_max , 3 )
2318- env_idx = buf_view [..., 0 ]
2319- obj_a = buf_view [..., 1 ]
2320- obj_b = buf_view [..., 2 ]
2317+ if n_max > 0 :
2318+ kernel_collect_welds (
2319+ as_tensor ,
2320+ buf ,
2321+ self .constraint_solver .constraint_state ,
2322+ self .equalities_info ,
2323+ self ._static_rigid_sim_config ,
2324+ )
23212325
23222326 if as_tensor :
2327+ if n_envs > 0 :
2328+ buf = buf .view (n_envs , n_max , 3 ) if to_torch else buf .reshape (n_envs , n_max , 3 )
2329+ env_idx , obj_a , obj_b = buf [..., 0 ], buf [..., 1 ], buf [..., 2 ]
23232330 return {"env" : env_idx , "obj_a" : obj_a , "obj_b" : obj_b }
2324- result_a = []
2325- result_b = []
2326- for e , count in enumerate (n_welds ):
2327- result_a .append (obj_a [e , :count ].copy ())
2328- result_b .append (obj_b [e , :count ].copy ())
2331+
2332+ if n_envs == 0 :
2333+ return {"obj_a" : [], "obj_b" : []}
2334+
2335+ parts = torch .split (buf , n_eqs ) if to_torch else np .split (buf , splits )
2336+ a = [p [:, 1 ] for p in parts ]
2337+ b = [p [:, 2 ] for p in parts ]
23292338 if n_envs == 1 :
2330- return {"obj_a" : result_a [0 ], "obj_b" : result_b [0 ]}
2331- return {"obj_a" : result_a , "obj_b" : result_b }
2332-
2333- @ti .kernel
2334- def _kernel_collect_welds (self , buf : ti .types .ndarray ()):
2335- for env in range (self .n_envs ):
2336- base = env * self .n_equalities_candidate
2337- out = 0
2338- n_eq = self .constraint_solver .ti_n_equalities [env ]
2339- for j in range (n_eq ):
2340- rec = self .equalities_info [j , env ]
2341- if rec .eq_type == gs .EQUALITY_TYPE .WELD and out < self .n_equalities_candidate :
2342- buf [base + out , 0 ] = env
2343- buf [base + out , 1 ] = rec .eq_obj1id
2344- buf [base + out , 2 ] = rec .eq_obj2id
2345- out += 1
2339+ return {"obj_a" : a [0 ], "obj_b" : b [0 ]}
2340+ return {"obj_a" : a , "obj_b" : b }
23462341
23472342 # ------------------------------------------------------------------------------------
23482343 # ----------------------------------- properties -------------------------------------
@@ -6765,3 +6760,38 @@ def kernel_delete_weld_constraint(
67656760 constraint_state .ti_n_equalities [i_b ] - 1 , i_b
67666761 ]
67676762 constraint_state .ti_n_equalities [i_b ] = constraint_state .ti_n_equalities [i_b ] - 1
6763+
6764+
6765+ @ti .kernel
6766+ def kernel_collect_welds (
6767+ is_padded : ti .template (),
6768+ buf : ti .types .ndarray (),
6769+ constraint_state : array_class .ConstraintState ,
6770+ equalities_info : array_class .EqualitiesInfo ,
6771+ static_rigid_sim_config : ti .template (),
6772+ ):
6773+ B = constraint_state .ti_n_equalities .shape [0 ]
6774+ max_eq = 0
6775+ for e in range (B ):
6776+ n = constraint_state .ti_n_equalities [e ]
6777+ if n > max_eq :
6778+ max_eq = n
6779+
6780+ ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
6781+ for e in range (B ):
6782+ base = 0
6783+ if ti .static (is_padded ):
6784+ base = e * max_eq
6785+ else :
6786+ for pe in range (e ):
6787+ base += constraint_state .ti_n_equalities [pe ]
6788+
6789+ out = 0
6790+ n = constraint_state .ti_n_equalities [e ]
6791+ for i in range (n ):
6792+ if equalities_info .eq_type [i , e ] == gs .EQUALITY_TYPE .WELD and out < max_eq :
6793+ idx = base + out
6794+ buf [idx , 0 ] = e
6795+ buf [idx , 1 ] = equalities_info .eq_obj1id [i , e ]
6796+ buf [idx , 2 ] = equalities_info .eq_obj2id [i , e ]
6797+ out += 1
0 commit comments