Skip to content

Commit 05ef698

Browse files
committed
update export for 16 dim command
1 parent 796218f commit 05ef698

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

  • scripts/reinforcement_learning/rsl_rl

scripts/reinforcement_learning/rsl_rl/export.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# scripts/reinforcement_learning/rsl_rl/export.py
22
#
33
# Export a trained rsl_rl policy as Kinfer binary.
4+
#
5+
# MODIFIED: This export script now accepts 16-dimensional command vectors from the deployment
6+
# platform but only uses the first 3 dimensions (lin_vel_x, lin_vel_y, ang_vel_z) that the
7+
# original model was trained on. The remaining 13 dimensions are ignored.
8+
#
49
# Example usage: python scripts/reinforcement_learning/rsl_rl/export.py --task=Isaac-Velocity-Rough-Kbot-v0 --checkpoint ~/Github/IsaacLab/logs/rsl_rl/kbot_rough/[path_to_checkpoint].pt
510
# Or you can omit the checkpoint arg and it will use the latest checkpoint in the logs/rsl_rl/agent_name/ directory
611
import argparse
@@ -237,7 +242,12 @@ def main():
237242
command_tensor = torch.cat([command_manager.get_command(name) for name in command_term_names], dim=-1)
238243
command_tensor = command_tensor.to("cpu").flatten()
239244

240-
NUM_COMMANDS = command_tensor.shape[0]
245+
# Original model expects 3D commands (lin_vel_x, lin_vel_y, ang_vel_z)
246+
# But deployment platform provides 16D commands - we only use first 3
247+
MODEL_NUM_COMMANDS = command_tensor.shape[0] # This is 3 for the original model
248+
PLATFORM_NUM_COMMANDS = 16 # Platform provides 16D commands
249+
250+
print(f"[INFO] Model expects {MODEL_NUM_COMMANDS} command dimensions, but deployment platform provides {PLATFORM_NUM_COMMANDS} dimensions.")
241251

242252
def construct_obs_rnn(
243253
projected_gravity: torch.Tensor,
@@ -247,12 +257,18 @@ def construct_obs_rnn(
247257
gyroscope: torch.Tensor,
248258
carry: torch.Tensor,
249259
) -> torch.Tensor:
260+
# Extract only the first 3 dimensions from the 16D command vector
261+
# [0] x linear velocity [m/s]
262+
# [1] y linear velocity [m/s]
263+
# [2] z angular velocity [rad/s]
264+
model_command = command[:MODEL_NUM_COMMANDS]
265+
250266
offset_joint_angles = joint_angles - _INIT_JOINT_POS
251267
scaled_projected_gravity = projected_gravity / 9.81
252268
obs = torch.cat(
253269
(
254270
scaled_projected_gravity,
255-
command,
271+
model_command,
256272
offset_joint_angles,
257273
joint_angular_velocities,
258274
gyroscope,
@@ -301,7 +317,7 @@ def _init_fn() -> torch.Tensor:
301317
torch.zeros(3),
302318
torch.zeros(NUM_JOINTS),
303319
torch.zeros(NUM_JOINTS),
304-
torch.zeros(NUM_COMMANDS),
320+
torch.zeros(PLATFORM_NUM_COMMANDS), # Use 16D command vector
305321
torch.zeros(3),
306322
torch.zeros(*CARRY_SHAPE),
307323
)
@@ -312,7 +328,7 @@ def _init_fn() -> torch.Tensor:
312328
joint_names = list(env.unwrapped.scene["robot"].data.joint_names)
313329
metadata = PyModelMetadata(
314330
joint_names=joint_names,
315-
num_commands=NUM_COMMANDS,
331+
num_commands=PLATFORM_NUM_COMMANDS, # Use 16D command vector size
316332
carry_size=list(CARRY_SHAPE),
317333
)
318334

0 commit comments

Comments
 (0)