
Commit 7e9713d

budzianowski and WT-MM authored

prepare reference motion for amp (#63)

* g1 obs noise scales
* n
* update randomizers
* setup works again
* fixed elbows
* save debugging setup
* fix noise, change back to accel
* more noise
* clean qpos logic
* add orientation logic
* add terrain setup
* push updated rnn
* format
* run env works again
* training works as well
* format
* add back gyro
* wip, remove old reference logic
* mypy lint
* update jax
* debug velocity reward
* proper input shapes
* add rnn version
* add rnn setup
* save functional setup
* add literal typing

Co-authored-by: Wesley Maa <[email protected]>

1 parent ce12fd9 commit 7e9713d

File tree

9 files changed: +1380 -178 lines changed


ksim_kbot/common.py

Lines changed: 48 additions & 0 deletions

@@ -19,6 +19,10 @@
     slice_update,
     update_data_field,
 )
+from ksim.utils.priors import (
+    MotionReferenceData,
+    get_local_xpos,
+)
 from mujoco import mjx
 
 
@@ -491,3 +495,47 @@ def _apply_random_angular_velocity_push(
     def get_initial_event_state(self, rng: PRNGKeyArray) -> Array:
         minval, maxval = self.interval_range
         return jax.random.uniform(rng, (), minval=minval, maxval=maxval)
+
+
+@attrs.define(frozen=True, kw_only=True)
+class ReferenceQposObservation(ksim.Observation):
+    """Observation for the reference joint positions."""
+
+    reference_motion_data: MotionReferenceData
+    speed: float = attrs.field(default=1.0)
+
+    def observe(self, state: ksim.ObservationInput, curriculum_level: Array, rng: PRNGKeyArray) -> Array:
+        physics_state = state.physics_state
+        effective_time = physics_state.data.time * self.speed
+        reference_qpos_at_time = self.reference_motion_data.get_qpos_at_time(effective_time)
+        return reference_qpos_at_time[..., 7:]
+
+
+@attrs.define(frozen=True, kw_only=True)
+class ReferenceLocalXposObservation(ksim.Observation):
+    """Observation for the reference local cartesian positions of tracked bodies."""
+
+    reference_motion_data: MotionReferenceData
+    tracked_body_ids: tuple[int, ...]
+
+    def observe(self, state: ksim.ObservationInput, curriculum_level: Array, rng: PRNGKeyArray) -> Array:
+        physics_state = state.physics_state
+        target_pos_dict = self.reference_motion_data.get_cartesian_pose_at_time(physics_state.data.time)
+        target_pos_list = [target_pos_dict[body_id] for body_id in self.tracked_body_ids]
+        return jnp.concatenate(target_pos_list, axis=-1)
+
+
+@attrs.define(frozen=True, kw_only=True)
+class TrackedLocalXposObservation(ksim.Observation):
+    """Observation for the current local cartesian positions of tracked bodies."""
+
+    tracked_body_ids: tuple[int, ...]
+    mj_base_id: int
+
+    def observe(self, state: ksim.ObservationInput, curriculum_level: Array, rng: PRNGKeyArray) -> Array:
+        physics_state = state.physics_state
+        tracked_positions_list: list[Array] = []
+        for body_id in self.tracked_body_ids:
+            body_pos = get_local_xpos(physics_state.data.xpos, body_id, self.mj_base_id)
+            tracked_positions_list.append(jnp.array(body_pos))
+        return jnp.concatenate(tracked_positions_list, axis=-1)
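
The three new observation classes pair a reference motion stream (the AMP prior) with the robot's current body positions, so downstream code can compare the two. Below is a minimal sketch of how they might be registered in a task; the get_observations hook follows the framework's usual pattern, but the self.reference_motion, self.tracked_body_ids, and self.mj_base_id attributes are hypothetical stand-ins for however the task loads its reference data.

# Sketch only: attribute names on `self` are hypothetical.
def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:
    return [
        common.ReferenceQposObservation(
            reference_motion_data=self.reference_motion,  # MotionReferenceData, loaded elsewhere
            speed=1.0,  # 1.0 replays the reference motion in real time
        ),
        common.ReferenceLocalXposObservation(
            reference_motion_data=self.reference_motion,
            tracked_body_ids=self.tracked_body_ids,  # e.g. feet / hand body ids
        ),
        common.TrackedLocalXposObservation(
            tracked_body_ids=self.tracked_body_ids,
            mj_base_id=self.mj_base_id,  # base body id defining the local frame
        ),
    ]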

ksim_kbot/rewards.py

Lines changed: 38 additions & 1 deletion

@@ -3,7 +3,7 @@
 If some logic will become more general, we can move it to ksim or xax.
 """
 
-from typing import Self
+from typing import Literal, Self
 
 import attrs
 import jax.numpy as jnp

@@ -434,3 +434,40 @@ def gait_phase(
     stance = xax.cubic_bezier_interpolation(jnp.array(0), swing_height, 2 * x)
     swing = xax.cubic_bezier_interpolation(swing_height, jnp.array(0), 2 * x - 1)
     return jnp.where(x <= 0.5, stance, swing)
+
+
+@attrs.define(frozen=True)
+class TargetLinearVelocityReward(ksim.Reward):
+    """Reward for forward motion."""
+
+    index: Literal["x", "y", "z"] = attrs.field(default="x")
+    target_vel: float = attrs.field(default=0.0)
+    norm: xax.NormType = attrs.field(default="l1")
+    monotonic_fn: Literal["exp", "inv"] = attrs.field(default="inv")
+    temp: float = attrs.field(default=1.0)
+
+    def get_reward(self, trajectory: ksim.Trajectory) -> Array:
+        vel = trajectory.qvel[..., ksim.cartesian_index_to_dim(self.index)]
+        error = xax.get_norm(vel - self.target_vel, self.norm)
+        return ksim.norm_to_reward(error, temp=self.temp, monotonic_fn=self.monotonic_fn)
+
+    def get_name(self) -> str:
+        return f"{self.index}_{super().get_name()}"
+
+
+@attrs.define(frozen=True, kw_only=True)
+class TargetHeightReward(ksim.Reward):
+    """Reward for reaching a target height."""
+
+    target_height: float = attrs.field(default=1.0)
+    norm: xax.NormType = attrs.field(default="l1")
+    temp: float = attrs.field(default=1.0)
+    monotonic_fn: Literal["exp", "inv"] = attrs.field(default="inv")
+
+    def get_reward(self, trajectory: ksim.Trajectory) -> Array:
+        qpos = trajectory.qpos
+        error = qpos[..., 2] - self.target_height
+        reward_value = ksim.norm_to_reward(
+            xax.get_norm(error, self.norm), temp=self.temp, monotonic_fn=self.monotonic_fn
+        )
+        return reward_value
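
Both rewards follow the same recipe: compute an error, take a norm with xax.get_norm, then squash it into a reward with ksim.norm_to_reward, where temp controls how sharply the reward falls off and monotonic_fn picks the squashing shape. A hypothetical configuration follows; the numeric values are placeholders, and `scale` is assumed to be the inherited ksim.Reward weighting field rather than something this commit defines.

# Sketch only: values are placeholders, `scale` is an assumed base-class field.
rewards = [
    TargetLinearVelocityReward(index="x", target_vel=1.0, temp=0.5, scale=1.0),
    TargetHeightReward(target_height=1.0, temp=0.1, scale=0.5),
]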

ksim_kbot/walking/walking.py

Lines changed: 1 addition & 1 deletion

@@ -530,7 +530,7 @@ def sample_action(
 # To run training, use the following command:
 # python -m ksim_kbot.walking.walking
 # To visualize the environment, use the following command:
-# python -m ksim_kbot.walking.walking run_environment=True
+# python -m ksim_kbot.walking.walking run_model_viewer=True
 # On MacOS or other devices with less memory, you can change the number
 # of environments and batch size to reduce memory usage. Here's an example
 # from the command line:

ksim_kbot/walking/walking_joystick.py

Lines changed: 28 additions & 8 deletions

@@ -65,13 +65,14 @@ def __init__(
         self,
         key: PRNGKeyArray,
         *,
+        num_inputs: int,
         min_std: float,
         max_std: float,
         var_scale: float,
         mean_scale: float,
     ) -> None:
         self.mlp = eqx.nn.MLP(
-            in_size=NUM_INPUTS,
+            in_size=num_inputs,
             out_size=NUM_OUTPUTS * 2,
             width_size=256,
             depth=5,

@@ -133,9 +134,9 @@ class KbotCritic(eqx.Module):
 
     mlp: eqx.nn.MLP
 
-    def __init__(self, key: PRNGKeyArray) -> None:
+    def __init__(self, key: PRNGKeyArray, *, num_inputs: int) -> None:
         self.mlp = eqx.nn.MLP(
-            in_size=NUM_CRITIC_INPUTS,
+            in_size=num_inputs,
             out_size=1,  # Always output a single critic value.
             width_size=256,
             depth=5,

@@ -193,16 +194,30 @@ def forward(
 class KbotModel(eqx.Module):
     actor: KbotActor
     critic: KbotCritic
+    num_inputs: int = eqx.static_field()
+    num_critic_inputs: int = eqx.static_field()
 
-    def __init__(self, key: PRNGKeyArray) -> None:
+    def __init__(
+        self,
+        key: PRNGKeyArray,
+        *,
+        num_inputs: int,
+        num_critic_inputs: int,
+    ) -> None:
+        self.num_inputs = num_inputs
+        self.num_critic_inputs = num_critic_inputs
         self.actor = KbotActor(
             key,
+            num_inputs=num_inputs,
             min_std=0.01,
             max_std=1.0,
             var_scale=1.0,
             mean_scale=1.0,
         )
-        self.critic = KbotCritic(key)
+        self.critic = KbotCritic(
+            key,
+            num_inputs=num_critic_inputs,
+        )

@@ -559,7 +574,11 @@ def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termin
         return [common.GVecTermination.create(physics_model, sensor_name="upvector_origin")]
 
     def get_model(self, key: PRNGKeyArray) -> KbotModel:
-        return KbotModel(key)
+        return KbotModel(
+            key,
+            num_inputs=NUM_INPUTS,
+            num_critic_inputs=NUM_CRITIC_INPUTS,
+        )
 
     def get_initial_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
         return None, None

@@ -724,11 +743,13 @@ def on_after_checkpoint_save(self, ckpt_path: Path, state: xax.State) -> xax.Sta
         if self.config.only_save_most_recent
         else ckpt_path.parent / f"tf_model_{state.num_steps}"
     )
+
     export(
         model_fn,
         input_shapes,  # type: ignore [arg-type]
         tf_path,
     )
+
     return state

@@ -737,7 +758,7 @@ def on_after_checkpoint_save(self, ckpt_path: Path, state: xax.State) -> xax.Sta
 # To run training, use the following command:
 # python -m ksim_kbot.walking.walking_joystick disable_multiprocessing=True
 # To visualize the environment, use the following command:
-# python -m ksim_kbot.walking.walking_joystick run_environment=True \
+# python -m ksim_kbot.walking.walking_joystick run_model_viewer=True \
 #     run_environment_num_seconds=1 \
 #     run_environment_save_path=videos/test.mp4
 KbotWalkingTask.launch(

@@ -772,6 +793,5 @@
             gait_freq_upper=1.5,
             reward_clip_min=0.0,
             reward_clip_max=1000.0,
-            stand_still_threshold=0.0,  # no stand still reward
         ),
     )
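
The main change here is that the actor and critic MLP widths are now constructor arguments instead of the module-level NUM_INPUTS / NUM_CRITIC_INPUTS constants, which lets the critic consume a wider, privileged observation stack than the actor while reusing the same classes. A self-contained equinox sketch of the same pattern; the sizes below are placeholders, not the repository's real constants.

import equinox as eqx
import jax
import jax.numpy as jnp

# Placeholder widths; the real constants live in walking_joystick.py.
num_inputs, num_critic_inputs, num_outputs = 48, 446, 20

actor_key, critic_key = jax.random.split(jax.random.PRNGKey(0))
# The actor emits a mean and a std per output, hence out_size = 2 * num_outputs.
actor_mlp = eqx.nn.MLP(num_inputs, num_outputs * 2, width_size=256, depth=5, key=actor_key)
critic_mlp = eqx.nn.MLP(num_critic_inputs, 1, width_size=256, depth=5, key=critic_key)

print(actor_mlp(jnp.zeros(num_inputs)).shape)           # (40,)
print(critic_mlp(jnp.zeros(num_critic_inputs)).shape)   # (1,)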

ksim_kbot/walking/walking_joystick_rnn.py

Lines changed: 7 additions & 6 deletions

@@ -14,6 +14,7 @@
 import mujoco
 import xax
 from jaxtyping import Array, PRNGKeyArray
+from mujoco import mjx
 from mujoco_scenes.mjcf import load_mjmodel
 from xax.nn.export import export

@@ -300,11 +301,11 @@ def get_mujoco_model(self) -> mujoco.MjModel:
         mj_model = load_mjmodel(mjcf_path, scene=self.config.terrain_type)
 
         # NOTE: test the difference
-        # mj_model.opt.timestep = jnp.array(self.config.dt)
-        # mj_model.opt.iterations = 6
-        # mj_model.opt.ls_iterations = 6
-        # mj_model.opt.disableflags = mjx.DisableBit.EULERDAMP
-        # mj_model.opt.solver = mjx.SolverType.CG
+        mj_model.opt.timestep = jnp.array(self.config.dt)
+        mj_model.opt.iterations = 6
+        mj_model.opt.ls_iterations = 6
+        mj_model.opt.disableflags = mjx.DisableBit.EULERDAMP
+        mj_model.opt.solver = mjx.SolverType.CG
 
         return mj_model

@@ -534,7 +535,7 @@ def on_after_checkpoint_save(self, ckpt_path: Path, state: xax.State) -> xax.Sta
 # To run training, use the following command:
 # python -m ksim_kbot.walking.walking_joystick_rnn
 # To visualize the environment, use the following command:
-# python -m ksim_kbot.walking.walking_joystick_rnn run_environment=True
+# python -m ksim_kbot.walking.walking_joystick_rnn run_model_viewer=True
 KbotWalkingJoystickRNNTask.launch(
     KbotWalkingJoystickRNNTaskConfig(
         num_envs=4096,
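
The previously commented-out solver block is now active: it pins the simulator timestep to the task's dt, caps the solver and line-search iterations, disables implicit joint damping in the Euler integrator (EULERDAMP), and switches to the conjugate-gradient solver. These options matter because they are carried over when the model is moved onto the device for MJX. A standalone sketch with a trivial inline model standing in for the K-Bot MJCF; the timestep value is a placeholder.

import mujoco
from mujoco import mjx

# Trivial stand-in for the K-Bot model: one free-floating sphere.
mj_model = mujoco.MjModel.from_xml_string(
    "<mujoco><worldbody><body><freejoint/><geom size='0.1'/></body></worldbody></mujoco>"
)

mj_model.opt.timestep = 0.004  # placeholder for self.config.dt
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_model.opt.disableflags = mjx.DisableBit.EULERDAMP
mj_model.opt.solver = mjx.SolverType.CG

mjx_model = mjx.put_model(mj_model)  # the options carry over to the device model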
