|
176 | 176 | " params[\"site_xpos\"] = True\n",
|
177 | 177 | " out_batched = mjx.Data(**params)\n",
|
178 | 178 | "\n",
|
179 |
| - " qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = (\n", |
180 |
| - " jax_mjwarp_step(d.ctrl, d.qpos, d.qvel, d.qacc_warmstart)\n", |
| 179 | + " qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = jax_mjwarp_step(\n", |
| 180 | + " d.ctrl, d.qpos, d.qvel, d.qacc_warmstart\n", |
181 | 181 | " )\n",
|
182 | 182 | " d = d.replace(\n",
|
183 | 183 | " qpos=qpos,\n",
|
|
355 | 355 | " obs = self._get_obs(data, state.info, contact)\n",
|
356 | 356 | " done = self._get_termination(data)\n",
|
357 | 357 | "\n",
|
358 |
| - " rewards = self._get_reward(\n", |
359 |
| - " data, action, state.info, state.metrics, done, first_contact, contact\n", |
360 |
| - " )\n", |
| 358 | + " rewards = self._get_reward(data, action, state.info, state.metrics, done, first_contact, contact)\n", |
361 | 359 | " rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}\n",
|
362 | 360 | " reward = sum(rewards.values()) * self.dt\n",
|
363 | 361 | "\n",
|
|
390 | 388 | " fall_termination = data.xpos[self._head_body_id, 2] < 1.0\n",
|
391 | 389 | " return fall_termination\n",
|
392 | 390 | "\n",
|
393 |
| - " def _get_obs(\n", |
394 |
| - " self, data: mjx.Data, info: dict[str, Any], contact: jax.Array\n", |
395 |
| - " ) -> mjx_env.Observation:\n", |
| 391 | + " def _get_obs(self, data: mjx.Data, info: dict[str, Any], contact: jax.Array) -> mjx_env.Observation:\n", |
396 | 392 | " cos = jp.cos(info[\"phase\"])\n",
|
397 | 393 | " sin = jp.sin(info[\"phase\"])\n",
|
398 | 394 | " phase = jp.concatenate([cos, sin])\n",
|
|
424 | 420 | " del metrics # Unused.\n",
|
425 | 421 | " return {\n",
|
426 | 422 | " # Tracking rewards.\n",
|
427 |
| - " \"tracking_lin_vel\": self._reward_tracking_lin_vel(\n", |
428 |
| - " info[\"command\"], self._get_global_linvel(data, self._torso_id)\n", |
429 |
| - " ),\n", |
430 |
| - " \"tracking_ang_vel\": self._reward_tracking_ang_vel(\n", |
431 |
| - " info[\"command\"], self._get_global_angvel(data, self._torso_id)\n", |
432 |
| - " ),\n", |
| 423 | + " \"tracking_lin_vel\": self._reward_tracking_lin_vel(info[\"command\"], self._get_global_linvel(data, self._torso_id)),\n", |
| 424 | + " \"tracking_ang_vel\": self._reward_tracking_ang_vel(info[\"command\"], self._get_global_angvel(data, self._torso_id)),\n", |
433 | 425 | " # Base-related rewards.\n",
|
434 |
| - " \"ang_vel_xy\": self._cost_ang_vel_xy(\n", |
435 |
| - " self._get_global_angvel(data, self._torso_id)\n", |
436 |
| - " ),\n", |
| 426 | + " \"ang_vel_xy\": self._cost_ang_vel_xy(self._get_global_angvel(data, self._torso_id)),\n", |
437 | 427 | " \"orientation\": self._cost_orientation(self._get_z_frame(data, self._torso_id)),\n",
|
438 | 428 | " # Energy related rewards.\n",
|
439 |
| - " \"action_rate\": self._cost_action_rate(\n", |
440 |
| - " action, info[\"last_act\"], info[\"last_last_act\"]\n", |
441 |
| - " ),\n", |
| 429 | + " \"action_rate\": self._cost_action_rate(action, info[\"last_act\"], info[\"last_last_act\"]),\n", |
442 | 430 | " # Feet related rewards.\n",
|
443 | 431 | " \"feet_slip\": self._cost_feet_slip(data, contact, info),\n",
|
444 |
| - " \"feet_air_time\": self._reward_feet_air_time(\n", |
445 |
| - " info[\"feet_air_time\"], first_contact, info[\"command\"]\n", |
446 |
| - " ),\n", |
| 432 | + " \"feet_air_time\": self._reward_feet_air_time(info[\"feet_air_time\"], first_contact, info[\"command\"]),\n", |
447 | 433 | " \"feet_phase\": self._reward_feet_phase(\n",
|
448 | 434 | " data,\n",
|
449 | 435 | " info[\"phase\"],\n",
|
450 | 436 | " self._config.reward_config.max_foot_height,\n",
|
451 | 437 | " info[\"command\"],\n",
|
452 | 438 | " ),\n",
|
453 | 439 | " # Pose related rewards.\n",
|
454 |
| - " \"joint_deviation_hip\": self._cost_joint_deviation_hip(\n", |
455 |
| - " data.qpos[7:], info[\"command\"]\n", |
456 |
| - " ),\n", |
| 440 | + " \"joint_deviation_hip\": self._cost_joint_deviation_hip(data.qpos[7:], info[\"command\"]),\n", |
457 | 441 | " \"joint_deviation_knee\": self._cost_joint_deviation_knee(data.qpos[7:]),\n",
|
458 | 442 | " \"dof_pos_limits\": self._cost_joint_pos_limits(data.qpos[7:]),\n",
|
459 | 443 | " \"pose\": self._cost_pose(data.qpos[7:]),\n",
|
|
504 | 488 | "\n",
|
505 | 489 | " # Energy related rewards.\n",
|
506 | 490 | "\n",
|
507 |
| - " def _cost_action_rate(\n", |
508 |
| - " self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array\n", |
509 |
| - " ) -> jax.Array:\n", |
| 491 | + " def _cost_action_rate(self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array) -> jax.Array:\n", |
510 | 492 | " del last_last_act # Unused.\n",
|
511 | 493 | " return jp.sum(jp.square(act - last_act))\n",
|
512 | 494 | "\n",
|
513 | 495 | " # Feet related rewards.\n",
|
514 | 496 | "\n",
|
515 |
| - " def _cost_feet_slip(\n", |
516 |
| - " self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]\n", |
517 |
| - " ) -> jax.Array:\n", |
| 497 | + " def _cost_feet_slip(self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]) -> jax.Array:\n", |
518 | 498 | " del info # Unused.\n",
|
519 | 499 | " body_vel = self._get_global_linvel(data, self._torso_id)[:2]\n",
|
520 | 500 | " reward = jp.sum(jp.linalg.norm(body_vel, axis=-1) * contact)\n",
|
|
534 | 514 | " reward = jp.sum(air_time)\n",
|
535 | 515 | " return reward\n",
|
536 | 516 | "\n",
|
537 |
| - " def get_rz(\n", |
538 |
| - " phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08\n", |
539 |
| - " ) -> jax.Array:\n", |
| 517 | + " def get_rz(phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08) -> jax.Array:\n", |
540 | 518 | " def cubic_bezier_interpolation(y_start, y_end, x):\n",
|
541 | 519 | " y_diff = y_end - y_start\n",
|
542 | 520 | " bezier = x**3 + 3 * (x**2 * (1 - x))\n",
|
|
556 | 534 | " ) -> jax.Array:\n",
|
557 | 535 | " # Reward for tracking the desired foot height.\n",
|
558 | 536 | " foot_pos = data.site_xpos[self._feet_site_id]\n",
|
559 |
| - " foot_pos = jp.array(\n", |
560 |
| - " [jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)]\n", |
561 |
| - " )\n", |
| 537 | + " foot_pos = jp.array([jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)])\n", |
562 | 538 | " foot_z = foot_pos[..., -1]\n",
|
563 | 539 | " rz = Joystick.get_rz(phase, swing_height=foot_height)\n",
|
564 | 540 | " error = jp.sum(jp.square(foot_z - rz))\n",
|
|
606 | 582 | " def sample_command(self, rng: jax.Array) -> jax.Array:\n",
|
607 | 583 | " rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)\n",
|
608 | 584 | "\n",
|
609 |
| - " lin_vel_x = jax.random.uniform(\n", |
610 |
| - " rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1]\n", |
611 |
| - " )\n", |
612 |
| - " lin_vel_y = jax.random.uniform(\n", |
613 |
| - " rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1]\n", |
614 |
| - " )\n", |
| 585 | + " lin_vel_x = jax.random.uniform(rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1])\n", |
| 586 | + " lin_vel_y = jax.random.uniform(rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1])\n", |
615 | 587 | " ang_vel_yaw = jax.random.uniform(\n",
|
616 | 588 | " rng3,\n",
|
617 | 589 | " minval=self._config.ang_vel_yaw[0],\n",
|
|
704 | 676 | "network_factory = ppo_networks.make_ppo_networks\n",
|
705 | 677 | "if \"network_factory\" in ppo_params:\n",
|
706 | 678 | " del ppo_training_params[\"network_factory\"]\n",
|
707 |
| - " network_factory = functools.partial(\n", |
708 |
| - " ppo_networks.make_ppo_networks, **ppo_params.network_factory\n", |
709 |
| - " )\n", |
| 679 | + " network_factory = functools.partial(ppo_networks.make_ppo_networks, **ppo_params.network_factory)\n", |
710 | 680 | "\n",
|
711 | 681 | "train_fn = functools.partial(\n",
|
712 | 682 | " ppo.train,\n",
|
|
0 commit comments