Skip to content

Commit 411c1d9

Browse files
authored
[Migration] Initial template for migration (#1361)
1 parent 0037f76 commit 411c1d9

File tree

2 files changed

+82
-29
lines changed

2 files changed

+82
-29
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from dataclasses import dataclass
2+
from typing import Callable
3+
4+
import taichi as ti
5+
6+
import genesis as gs
7+
8+
# we will use struct for DofsState and DofsInfo after Hugh adds array_struct feature to taichi
9+
DofsState = ti.template()
10+
DofsInfo = ti.template()
11+
12+
13+
@ti.data_oriented
14+
class RigidGlobalInfo:
15+
def __init__(self, n_dofs: int, n_entities: int, n_geoms: int, f_batch: Callable):
16+
self.n_awake_dofs = ti.field(dtype=gs.ti_int, shape=f_batch())
17+
self.awake_dofs = ti.field(dtype=gs.ti_int, shape=f_batch(n_dofs))

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Literal, TYPE_CHECKING
2+
from dataclasses import dataclass
23

34
import numpy as np
45
import torch
@@ -13,6 +14,7 @@
1314
from genesis.engine.entities import AvatarEntity, DroneEntity, RigidEntity
1415
from genesis.engine.states.solvers import RigidSolverState
1516
from genesis.styles import colors, formats
17+
import genesis.engine.solvers.rigid.array_class as array_class
1618

1719
from ..base_solver import Solver
1820
from .collider_decomp import Collider
@@ -65,6 +67,11 @@ class RigidSolver(Solver):
6567
# ------------------------------------------------------------------------------------
6668
# --------------------------------- Initialization -----------------------------------
6769
# ------------------------------------------------------------------------------------
70+
@dataclass(frozen=True)
71+
class StaticRigidSimConfig:
72+
# store static arguments here
73+
para_level: int = 0
74+
use_hibernation: bool = False
6875

6976
def __init__(self, scene: "Scene", sim: "Simulator", options: RigidOptions) -> None:
7077
super().__init__(scene, sim, options)
@@ -213,6 +220,19 @@ def build(self):
213220

214221
self.n_equalities_candidate = max(1, self.n_equalities + self._options.max_dynamic_constraints)
215222

223+
self._static_rigid_sim_config = self.StaticRigidSimConfig(
224+
para_level=self.sim._para_level,
225+
use_hibernation=getattr(self, "_use_hibernation", False),
226+
)
227+
# when the migration is finished, we will remove the about two lines
228+
# and initizlize the awake_dofs and n_awake_dofs in _rigid_global_info directly
229+
self._rigid_global_info = array_class.RigidGlobalInfo(
230+
n_dofs=self.n_dofs_,
231+
n_entities=self.n_entities_,
232+
n_geoms=self.n_geoms_,
233+
f_batch=self._batch_shape,
234+
)
235+
216236
if self.is_active():
217237
self._init_mass_mat()
218238
self._init_dof_fields()
@@ -408,8 +428,11 @@ def _init_mass_mat(self):
408428

409429
def _init_dof_fields(self):
410430
if self._use_hibernation:
411-
self.n_awake_dofs = ti.field(dtype=gs.ti_int, shape=self._B)
412-
self.awake_dofs = ti.field(dtype=gs.ti_int, shape=self._batch_shape(self.n_dofs_))
431+
# we are going to move n_awake_dofs and awake_dofs to _rigid_global_info completely after migration.
432+
# But right now, other kernels are still using self.n_awake_dofs and self.awake_dofs
433+
# so we need to keep them in self for now.
434+
self.n_awake_dofs = self._rigid_global_info.n_awake_dofs
435+
self.awake_dofs = self._rigid_global_info.awake_dofs
413436

414437
struct_dof_info = ti.types.struct(
415438
stiffness=gs.ti_float,
@@ -472,14 +495,19 @@ def _init_dof_fields(self):
472495
dofs_kp=np.concatenate([joint.dofs_kp for joint in joints], dtype=gs.np_float),
473496
dofs_kv=np.concatenate([joint.dofs_kv for joint in joints], dtype=gs.np_float),
474497
dofs_force_range=np.concatenate([joint.dofs_force_range for joint in joints], dtype=gs.np_float),
498+
dofs_info=self.dofs_info,
499+
dofs_state=self.dofs_state,
500+
rigid_global_info=self._rigid_global_info,
501+
static_rigid_sim_config=self._static_rigid_sim_config,
475502
)
476503

477504
# just in case
478505
self.dofs_state.force.fill(0)
479506

480507
@ti.kernel
481508
def _kernel_init_dof_fields(
482-
self,
509+
self_unused,
510+
# input np array
483511
dofs_motion_ang: ti.types.ndarray(),
484512
dofs_motion_vel: ti.types.ndarray(),
485513
dofs_limit: ti.types.ndarray(),
@@ -490,38 +518,46 @@ def _kernel_init_dof_fields(
490518
dofs_kp: ti.types.ndarray(),
491519
dofs_kv: ti.types.ndarray(),
492520
dofs_force_range: ti.types.ndarray(),
521+
# taichi variables
522+
dofs_info: array_class.DofsInfo,
523+
dofs_state: array_class.DofsState,
524+
# we will use RigidGlobalInfo as typing after Hugh adds array_struct feature to taichi
525+
rigid_global_info: ti.template(),
526+
static_rigid_sim_config: ti.template(),
493527
):
494-
for I in ti.grouped(self.dofs_info):
528+
n_dofs = dofs_state.shape[0]
529+
_B = dofs_state.shape[1]
530+
for I in ti.grouped(dofs_info):
495531
i = I[0] # batching (if any) will be the second dim
496532

497533
for j in ti.static(range(3)):
498-
self.dofs_info[I].motion_ang[j] = dofs_motion_ang[i, j]
499-
self.dofs_info[I].motion_vel[j] = dofs_motion_vel[i, j]
534+
dofs_info[I].motion_ang[j] = dofs_motion_ang[i, j]
535+
dofs_info[I].motion_vel[j] = dofs_motion_vel[i, j]
500536

501537
for j in ti.static(range(2)):
502-
self.dofs_info[I].limit[j] = dofs_limit[i, j]
503-
self.dofs_info[I].force_range[j] = dofs_force_range[i, j]
504-
505-
self.dofs_info[I].armature = dofs_armature[i]
506-
self.dofs_info[I].invweight = dofs_invweight[i]
507-
self.dofs_info[I].stiffness = dofs_stiffness[i]
508-
self.dofs_info[I].damping = dofs_damping[i]
509-
self.dofs_info[I].kp = dofs_kp[i]
510-
self.dofs_info[I].kv = dofs_kv[i]
511-
512-
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
513-
for i, b in ti.ndrange(self.n_dofs, self._B):
514-
self.dofs_state[i, b].ctrl_mode = gs.CTRL_MODE.FORCE
515-
516-
if ti.static(self._use_hibernation):
517-
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
518-
for i, b in ti.ndrange(self.n_dofs, self._B):
519-
self.dofs_state[i, b].hibernated = False
520-
self.awake_dofs[i, b] = i
521-
522-
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
523-
for b in range(self._B):
524-
self.n_awake_dofs[b] = self.n_dofs
538+
dofs_info[I].limit[j] = dofs_limit[i, j]
539+
dofs_info[I].force_range[j] = dofs_force_range[i, j]
540+
541+
dofs_info[I].armature = dofs_armature[i]
542+
dofs_info[I].invweight = dofs_invweight[i]
543+
dofs_info[I].stiffness = dofs_stiffness[i]
544+
dofs_info[I].damping = dofs_damping[i]
545+
dofs_info[I].kp = dofs_kp[i]
546+
dofs_info[I].kv = dofs_kv[i]
547+
548+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
549+
for i, b in ti.ndrange(n_dofs, _B):
550+
dofs_state[i, b].ctrl_mode = gs.CTRL_MODE.FORCE
551+
552+
if ti.static(static_rigid_sim_config.use_hibernation):
553+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
554+
for i, b in ti.ndrange(n_dofs, _B):
555+
dofs_state[i, b].hibernated = False
556+
rigid_global_info.awake_dofs[i, b] = i
557+
558+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
559+
for b in range(_B):
560+
rigid_global_info.n_awake_dofs[b] = n_dofs
525561

526562
def _init_link_fields(self):
527563
if self._use_hibernation:

0 commit comments

Comments
 (0)