11from typing import Literal , TYPE_CHECKING
2+ from dataclasses import dataclass
23
34import numpy as np
45import torch
1314from genesis .engine .entities import AvatarEntity , DroneEntity , RigidEntity
1415from genesis .engine .states .solvers import RigidSolverState
1516from genesis .styles import colors , formats
17+ import genesis .engine .solvers .rigid .array_class as array_class
1618
1719from ..base_solver import Solver
1820from .collider_decomp import Collider
@@ -65,6 +67,11 @@ class RigidSolver(Solver):
6567 # ------------------------------------------------------------------------------------
6668 # --------------------------------- Initialization -----------------------------------
6769 # ------------------------------------------------------------------------------------
70+ @dataclass (frozen = True )
71+ class StaticRigidSimConfig :
72+ # store static arguments here
73+ para_level : int = 0
74+ use_hibernation : bool = False
6875
6976 def __init__ (self , scene : "Scene" , sim : "Simulator" , options : RigidOptions ) -> None :
7077 super ().__init__ (scene , sim , options )
@@ -213,6 +220,19 @@ def build(self):
213220
214221 self .n_equalities_candidate = max (1 , self .n_equalities + self ._options .max_dynamic_constraints )
215222
223+ self ._static_rigid_sim_config = self .StaticRigidSimConfig (
224+ para_level = self .sim ._para_level ,
225+ use_hibernation = getattr (self , "_use_hibernation" , False ),
226+ )
227+ # when the migration is finished, we will remove the about two lines
228+ # and initizlize the awake_dofs and n_awake_dofs in _rigid_global_info directly
229+ self ._rigid_global_info = array_class .RigidGlobalInfo (
230+ n_dofs = self .n_dofs_ ,
231+ n_entities = self .n_entities_ ,
232+ n_geoms = self .n_geoms_ ,
233+ f_batch = self ._batch_shape ,
234+ )
235+
216236 if self .is_active ():
217237 self ._init_mass_mat ()
218238 self ._init_dof_fields ()
@@ -408,8 +428,11 @@ def _init_mass_mat(self):
408428
409429 def _init_dof_fields (self ):
410430 if self ._use_hibernation :
411- self .n_awake_dofs = ti .field (dtype = gs .ti_int , shape = self ._B )
412- self .awake_dofs = ti .field (dtype = gs .ti_int , shape = self ._batch_shape (self .n_dofs_ ))
431+ # we are going to move n_awake_dofs and awake_dofs to _rigid_global_info completely after migration.
432+ # But right now, other kernels are still using self.n_awake_dofs and self.awake_dofs
433+ # so we need to keep them in self for now.
434+ self .n_awake_dofs = self ._rigid_global_info .n_awake_dofs
435+ self .awake_dofs = self ._rigid_global_info .awake_dofs
413436
414437 struct_dof_info = ti .types .struct (
415438 stiffness = gs .ti_float ,
@@ -472,14 +495,19 @@ def _init_dof_fields(self):
472495 dofs_kp = np .concatenate ([joint .dofs_kp for joint in joints ], dtype = gs .np_float ),
473496 dofs_kv = np .concatenate ([joint .dofs_kv for joint in joints ], dtype = gs .np_float ),
474497 dofs_force_range = np .concatenate ([joint .dofs_force_range for joint in joints ], dtype = gs .np_float ),
498+ dofs_info = self .dofs_info ,
499+ dofs_state = self .dofs_state ,
500+ rigid_global_info = self ._rigid_global_info ,
501+ static_rigid_sim_config = self ._static_rigid_sim_config ,
475502 )
476503
477504 # just in case
478505 self .dofs_state .force .fill (0 )
479506
480507 @ti .kernel
481508 def _kernel_init_dof_fields (
482- self ,
509+ self_unused ,
510+ # input np array
483511 dofs_motion_ang : ti .types .ndarray (),
484512 dofs_motion_vel : ti .types .ndarray (),
485513 dofs_limit : ti .types .ndarray (),
@@ -490,38 +518,46 @@ def _kernel_init_dof_fields(
490518 dofs_kp : ti .types .ndarray (),
491519 dofs_kv : ti .types .ndarray (),
492520 dofs_force_range : ti .types .ndarray (),
521+ # taichi variables
522+ dofs_info : array_class .DofsInfo ,
523+ dofs_state : array_class .DofsState ,
524+ # we will use RigidGlobalInfo as typing after Hugh adds array_struct feature to taichi
525+ rigid_global_info : ti .template (),
526+ static_rigid_sim_config : ti .template (),
493527 ):
494- for I in ti .grouped (self .dofs_info ):
528+ n_dofs = dofs_state .shape [0 ]
529+ _B = dofs_state .shape [1 ]
530+ for I in ti .grouped (dofs_info ):
495531 i = I [0 ] # batching (if any) will be the second dim
496532
497533 for j in ti .static (range (3 )):
498- self . dofs_info [I ].motion_ang [j ] = dofs_motion_ang [i , j ]
499- self . dofs_info [I ].motion_vel [j ] = dofs_motion_vel [i , j ]
534+ dofs_info [I ].motion_ang [j ] = dofs_motion_ang [i , j ]
535+ dofs_info [I ].motion_vel [j ] = dofs_motion_vel [i , j ]
500536
501537 for j in ti .static (range (2 )):
502- self . dofs_info [I ].limit [j ] = dofs_limit [i , j ]
503- self . dofs_info [I ].force_range [j ] = dofs_force_range [i , j ]
504-
505- self . dofs_info [I ].armature = dofs_armature [i ]
506- self . dofs_info [I ].invweight = dofs_invweight [i ]
507- self . dofs_info [I ].stiffness = dofs_stiffness [i ]
508- self . dofs_info [I ].damping = dofs_damping [i ]
509- self . dofs_info [I ].kp = dofs_kp [i ]
510- self . dofs_info [I ].kv = dofs_kv [i ]
511-
512- ti .loop_config (serialize = self . _para_level < gs .PARA_LEVEL .PARTIAL )
513- for i , b in ti .ndrange (self . n_dofs , self . _B ):
514- self . dofs_state [i , b ].ctrl_mode = gs .CTRL_MODE .FORCE
515-
516- if ti .static (self . _use_hibernation ):
517- ti .loop_config (serialize = self . _para_level < gs .PARA_LEVEL .PARTIAL )
518- for i , b in ti .ndrange (self . n_dofs , self . _B ):
519- self . dofs_state [i , b ].hibernated = False
520- self .awake_dofs [i , b ] = i
521-
522- ti .loop_config (serialize = self . _para_level < gs .PARA_LEVEL .PARTIAL )
523- for b in range (self . _B ):
524- self .n_awake_dofs [b ] = self . n_dofs
538+ dofs_info [I ].limit [j ] = dofs_limit [i , j ]
539+ dofs_info [I ].force_range [j ] = dofs_force_range [i , j ]
540+
541+ dofs_info [I ].armature = dofs_armature [i ]
542+ dofs_info [I ].invweight = dofs_invweight [i ]
543+ dofs_info [I ].stiffness = dofs_stiffness [i ]
544+ dofs_info [I ].damping = dofs_damping [i ]
545+ dofs_info [I ].kp = dofs_kp [i ]
546+ dofs_info [I ].kv = dofs_kv [i ]
547+
548+ ti .loop_config (serialize = static_rigid_sim_config . para_level < gs .PARA_LEVEL .PARTIAL )
549+ for i , b in ti .ndrange (n_dofs , _B ):
550+ dofs_state [i , b ].ctrl_mode = gs .CTRL_MODE .FORCE
551+
552+ if ti .static (static_rigid_sim_config . use_hibernation ):
553+ ti .loop_config (serialize = static_rigid_sim_config . para_level < gs .PARA_LEVEL .PARTIAL )
554+ for i , b in ti .ndrange (n_dofs , _B ):
555+ dofs_state [i , b ].hibernated = False
556+ rigid_global_info .awake_dofs [i , b ] = i
557+
558+ ti .loop_config (serialize = static_rigid_sim_config . para_level < gs .PARA_LEVEL .PARTIAL )
559+ for b in range (_B ):
560+ rigid_global_info .n_awake_dofs [b ] = n_dofs
525561
526562 def _init_link_fields (self ):
527563 if self ._use_hibernation :
0 commit comments