Skip to content

Commit 43b5c62

Browse files
committed
fixups
1 parent 166a767 commit 43b5c62

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

train.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,6 @@ def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
559559
ksim.FloatVectorCommand(
560560
ranges=((-0.5, 0.5),),
561561
switch_prob=0.005,
562-
unique_name="target_yaw_rate",
563562
zero_prob=0.2,
564563
),
565564
]
@@ -575,7 +574,6 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
575574
in_robot_frame=True,
576575
scale=0.5,
577576
norm="l1",
578-
unique_identifier="l1_vel",
579577
),
580578
ksim.LinearVelocityTrackingReward(
581579
linvel_obs_name="base_linear_velocity_observation",
@@ -588,14 +586,13 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
588586
),
589587
ksim.AngularVelocityTrackingReward(
590588
index=("z",),
591-
command_name="target_yaw_rate_float_vector_command",
589+
command_name="float_vector_command",
592590
scale=0.5,
593591
norm="l1",
594-
unique_identifier="l1_angvel",
595592
),
596593
ksim.AngularVelocityTrackingReward(
597594
index=("z",),
598-
command_name="target_yaw_rate_float_vector_command",
595+
command_name="float_vector_command",
599596
scale=2.0,
600597
norm="l2",
601598
),
@@ -665,7 +662,7 @@ def run_actor(
665662

666663
# target_velocity_2 = commands["target_velocity_float_vector_command"]
667664
target_velocity_2 = commands["linear_velocity_command"]
668-
target_yaw_rate_1 = commands["target_yaw_rate_float_vector_command"]
665+
target_yaw_rate_1 = commands["float_vector_command"]
669666

670667
obs = [
671668
joint_pos_n, # NUM_JOINTS
@@ -704,7 +701,7 @@ def run_critic(
704701

705702
# target_velocity_2 = commands["target_velocity_float_vector_command"]
706703
target_velocity_2 = commands["linear_velocity_command"]
707-
target_yaw_rate_1 = commands["target_yaw_rate_float_vector_command"]
704+
target_yaw_rate_1 = commands["float_vector_command"]
708705

709706
obs_n = jnp.concatenate(
710707
[

0 commit comments

Comments
 (0)