|
32 | 32 | "metadata": {}, |
33 | 33 | "outputs": [], |
34 | 34 | "source": [ |
35 | | - "from pathlib import Path\n", |
36 | | - "import sys\n", |
37 | 35 | "import numpy as np\n", |
38 | 36 | "import pandas as pd\n", |
39 | 37 | "import matplotlib.pyplot as plt\n", |
40 | 38 | "\n", |
41 | | - "\n", |
42 | | - "def find_repo_root(start: Path) -> Path:\n", |
43 | | - " for candidate in [start, *start.parents]:\n", |
44 | | - " if (candidate / 'CMakeLists.txt').exists() and (candidate / 'examples' / 'example_a1_walking').exists():\n", |
45 | | - " return candidate\n", |
46 | | - " raise RuntimeError('Could not locate GTDynamics repository root from current directory.')\n", |
47 | | - "\n", |
48 | | - "\n", |
49 | | - "repo_root = find_repo_root(Path.cwd().resolve())\n", |
50 | | - "gtd_build_python = repo_root / 'build' / 'python'\n", |
51 | | - "gtsam_build_python = repo_root.parent / 'gtsam' / 'build' / 'python'\n", |
52 | | - "\n", |
53 | | - "if gtsam_build_python.exists() and str(gtsam_build_python) not in sys.path:\n", |
54 | | - " sys.path.insert(0, str(gtsam_build_python))\n", |
55 | | - "if str(gtd_build_python) not in sys.path:\n", |
56 | | - " sys.path.insert(0, str(gtd_build_python))\n", |
57 | | - "\n", |
58 | 39 | "import gtsam\n", |
59 | 40 | "import gtdynamics as gtd\n", |
| 41 | + "import roboplot # type: ignore\n", |
60 | 42 | "\n", |
61 | | - "print('repo_root :', repo_root)\n", |
62 | | - "print('gtsam module :', gtsam.__file__)\n", |
63 | | - "print('gtdynamics module:', gtd.__file__)\n", |
64 | | - "\n", |
65 | | - "if gtsam_build_python.exists() and str(gtsam_build_python) not in gtsam.__file__:\n", |
66 | | - " raise RuntimeError(\n", |
67 | | - " 'Loaded a non-local gtsam package. Put '\n", |
68 | | - " f\"{gtsam_build_python} first on PYTHONPATH before running this notebook.\"\n", |
69 | | - " )" |
| 43 | + "CONTACT_IN_COM = gtsam.Point3(0.0, 0.0, -0.11)\n", |
| 44 | + "GROUND_HEIGHT = 1.6" |
70 | 45 | ] |
71 | 46 | }, |
72 | 47 | { |
|
91 | 66 | "source": [ |
92 | 67 | "def build_walk_trajectory(robot: gtd.Robot, repeat: int = 1) -> gtd.Trajectory:\n", |
93 | 68 | " # Match getTrajectory() from C++: RRFL -> stationary -> RLFR -> stationary.\n", |
94 | | - " rlfr = [robot.link('RL_lower'), robot.link('FR_lower')]\n", |
95 | | - " rrfl = [robot.link('RR_lower'), robot.link('FL_lower')]\n", |
| 69 | + " rlfr = [robot.link(\"RL_lower\"), robot.link(\"FR_lower\")]\n", |
| 70 | + " rrfl = [robot.link(\"RR_lower\"), robot.link(\"FL_lower\")]\n", |
96 | 71 | " all_feet = rlfr + rrfl\n", |
97 | 72 | "\n", |
98 | | - " contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)\n", |
99 | | - " stationary = gtd.FootContactConstraintSpec(all_feet, contact_in_com)\n", |
100 | | - " rlfr_state = gtd.FootContactConstraintSpec(rlfr, contact_in_com)\n", |
101 | | - " rrfl_state = gtd.FootContactConstraintSpec(rrfl, contact_in_com)\n", |
| 73 | + " stationary = gtd.FootContactConstraintSpec(all_feet, CONTACT_IN_COM)\n", |
| 74 | + " rlfr_state = gtd.FootContactConstraintSpec(rlfr, CONTACT_IN_COM)\n", |
| 75 | + " rrfl_state = gtd.FootContactConstraintSpec(rrfl, CONTACT_IN_COM)\n", |
102 | 76 | "\n", |
103 | | - " states = [rrfl_state, stationary, rlfr_state, stationary]\n", |
104 | | - " phase_lengths = [25, 5, 25, 5]\n", |
| 77 | + " N = 5\n", |
| 78 | + " states = [rrfl_state, stationary, rlfr_state, stationary] * N\n", |
| 79 | + " phase_lengths = [25, 5, 25, 5] * N\n", |
105 | 80 | "\n", |
106 | 81 | " walk_cycle = gtd.WalkCycle(states, phase_lengths)\n", |
107 | 82 | " return gtd.Trajectory(walk_cycle, repeat)" |
|
134 | 109 | "outputs": [], |
135 | 110 | "source": [ |
136 | 111 | "robot = gtd.CreateRobotFromFile(gtd.URDF_PATH + \"/a1/a1.urdf\", \"a1\")\n", |
| 112 | + "# robot.print()" |
| 113 | + ] |
| 114 | + }, |
| 115 | + { |
| 116 | + "cell_type": "code", |
| 117 | + "execution_count": null, |
| 118 | + "id": "4d012b64", |
| 119 | + "metadata": {}, |
| 120 | + "outputs": [], |
| 121 | + "source": [ |
137 | 122 | "\n", |
138 | 123 | "gravity = np.array([0.0, 0.0, -9.8])\n", |
139 | 124 | "mu = 1.0\n", |
|
144 | 129 | "trajectory = build_walk_trajectory(robot, repeat=1)\n", |
145 | 130 | "collocation = gtd.CollocationScheme.Euler\n", |
146 | 131 | "\n", |
147 | | - "graph = trajectory.multiPhaseFactorGraph(robot, graph_builder, collocation, mu)\n", |
148 | | - "print('Graph has', graph.size(), 'factors and', graph.keys().size(), 'variables.')" |
149 | | - ] |
150 | | - }, |
151 | | - { |
152 | | - "cell_type": "code", |
153 | | - "execution_count": null, |
154 | | - "id": "60f38e64", |
155 | | - "metadata": {}, |
156 | | - "outputs": [], |
157 | | - "source": [ |
158 | | - "graph.print(\"Graph:\\n\", gtd.GTDKeyFormatter)" |
| 132 | + "graph = trajectory.multiPhaseFactorGraph(robot, graph_builder, collocation, mu, GROUND_HEIGHT)\n", |
| 133 | + "print('Graph has', graph.size(), 'factors and', graph.keys().size(), 'variables.')\n", |
| 134 | + "# graph.print(\"Graph:\\n\", gtd.GTDKeyFormatter)" |
159 | 135 | ] |
160 | 136 | }, |
161 | 137 | { |
|
165 | 141 | "metadata": {}, |
166 | 142 | "outputs": [], |
167 | 143 | "source": [ |
168 | | - "\n", |
169 | | - "ground_height = 1.0\n", |
170 | 144 | "step = gtsam.Point3(0.25, 0.0, 0.0)\n", |
171 | 145 | "objectives = trajectory.contactPointObjectives(\n", |
172 | | - " robot,\n", |
173 | | - " gtsam.noiseModel.Isotropic.Sigma(3, 1e-6),\n", |
174 | | - " step,\n", |
175 | | - " ground_height,\n", |
| 146 | + " robot, gtsam.noiseModel.Isotropic.Sigma(3, 1e-6), step, GROUND_HEIGHT\n", |
176 | 147 | ")\n", |
177 | 148 | "\n", |
178 | 149 | "K = trajectory.getEndTimeStep(trajectory.numPhases() - 1)\n", |
179 | 150 | "\n", |
180 | | - "for link in robot.links():\n", |
181 | | - " i = link.id()\n", |
182 | | - " if i == 0:\n", |
183 | | - " objectives.push_back(\n", |
184 | | - " gtd.LinkObjectives(i, 0)\n", |
185 | | - " .pose(link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3))\n", |
186 | | - " .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)),\n", |
187 | | - " )\n", |
188 | | - " if i in (3, 6, 9, 12):\n", |
189 | | - " objectives.push_back(\n", |
190 | | - " gtd.LinkObjectives(i, 0).pose(\n", |
191 | | - " link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)\n", |
192 | | - " ),\n", |
193 | | - " )\n", |
| 151 | + "# for link in robot.links():\n", |
| 152 | + "# i = link.id()\n", |
| 153 | + "# if i == 0:\n", |
| 154 | + "# objectives.push_back(\n", |
| 155 | + "# gtd.LinkObjectives(i, 0)\n", |
| 156 | + "# .pose(link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3))\n", |
| 157 | + "# .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)),\n", |
| 158 | + "# )\n", |
| 159 | + "# if i in (3, 6, 9, 12):\n", |
| 160 | + "# objectives.push_back(\n", |
| 161 | + "# gtd.LinkObjectives(i, 0).pose(\n", |
| 162 | + "# link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)\n", |
| 163 | + "# ),\n", |
| 164 | + "# )\n", |
194 | 165 | "\n", |
195 | 166 | "rest_model = gtsam.noiseModel.Isotropic.Sigma(1, 1e-3)\n", |
196 | 167 | "objectives.push_back(gtd.JointsAtRestObjectives(robot, rest_model, rest_model, 0))\n", |
197 | 168 | "objectives.push_back(gtd.JointsAtRestObjectives(robot, rest_model, rest_model, K))\n", |
198 | 169 | "\n", |
199 | | - "trunk = robot.link(\"trunk\")\n", |
200 | | - "for k in range(K + 1):\n", |
201 | | - " objectives.push_back(\n", |
202 | | - " gtd.LinkObjectives(trunk.id(), k)\n", |
203 | | - " .pose(\n", |
204 | | - " gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(0.0, 0.0, 0.4)),\n", |
205 | | - " gtsam.noiseModel.Isotropic.Sigma(6, 1e-2),\n", |
206 | | - " )\n", |
207 | | - " .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 5e-2)),\n", |
208 | | - " )\n", |
| 170 | + "# trunk = robot.link(\"trunk\")\n", |
| 171 | + "# for k in range(K + 1):\n", |
| 172 | + "# objectives.push_back(\n", |
| 173 | + "# gtd.LinkObjectives(trunk.id(), k)\n", |
| 174 | + "# .pose(\n", |
| 175 | + "# gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(0.0, 0.0, GROUND_HEIGHT + 0.4)),\n", |
| 176 | + "# gtsam.noiseModel.Isotropic.Sigma(6, 1e-3),\n", |
| 177 | + "# )\n", |
| 178 | + "# .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 5e-2)),\n", |
| 179 | + "# )\n", |
209 | 180 | "\n", |
210 | 181 | "desired_dt = 1.0 / 20.0\n", |
211 | 182 | "trajectory.addIntegrationTimeFactors(objectives, desired_dt, 1e-30)\n", |
|
260 | 231 | "init_values = trajectory.multiPhaseInitialValues(robot, initializer, gaussian_noise, desired_dt)\n", |
261 | 232 | "\n", |
262 | 233 | "print('Initial values:', init_values.size())\n", |
| 234 | + "print('Final objective error:', graph.error(init_values))\n", |
263 | 235 | "\n", |
264 | 236 | "params = gtsam.LevenbergMarquardtParams()\n", |
265 | 237 | "params.setlambdaInitial(1e10)\n", |
|
309 | 281 | " j = joint.id()\n", |
310 | 282 | " name = joint.name()\n", |
311 | 283 | " row[name] = read_or_nan(result, gtd.JointAngleKey(j, k))\n", |
312 | | - " row[f'{name}.1'] = read_or_nan(result, gtd.JointVelKey(j, k))\n", |
313 | | - " row[f'{name}.2'] = read_or_nan(result, gtd.JointAccelKey(j, k))\n", |
314 | | - " row[f'{name}.3'] = read_or_nan(result, gtd.TorqueKey(j, k))\n", |
| 284 | + " # row[f'{name}.1'] = read_or_nan(result, gtd.JointVelKey(j, k))\n", |
| 285 | + " # row[f'{name}.2'] = read_or_nan(result, gtd.JointAccelKey(j, k))\n", |
| 286 | + " # row[f'{name}.3'] = read_or_nan(result, gtd.TorqueKey(j, k))\n", |
315 | 287 | "\n", |
316 | 288 | " rows.append(row)\n", |
317 | 289 | "\n", |
|
339 | 311 | "metadata": {}, |
340 | 312 | "outputs": [], |
341 | 313 | "source": [ |
342 | | - "out_dir = repo_root / 'build' / 'examples' / 'example_a1_walking'\n", |
| 314 | + "from pathlib import Path\n", |
| 315 | + "out_dir = Path(gtd.DATA_PATH) / 'example_a1_walking'\n", |
343 | 316 | "out_dir.mkdir(parents=True, exist_ok=True)\n", |
344 | 317 | "\n", |
345 | | - "csv_table = out_dir / 'a1_traj_chain_dynamics_graph_python.csv'\n", |
| 318 | + "csv_table = out_dir / 'traj_df.csv'\n", |
346 | 319 | "traj_df.to_csv(csv_table, index=False)\n", |
347 | 320 | "print('Wrote table CSV:', csv_table)\n", |
348 | 321 | "\n", |
|
362 | 335 | " if not export_values.exists(pk):\n", |
363 | 336 | " export_values.insert(pk, desired_dt)\n", |
364 | 337 | "\n", |
365 | | - "native_name = 'a1_traj_CDG_python.csv'\n", |
366 | | - "trajectory.writeToFile(robot, native_name, export_values)\n", |
367 | | - "print('Wrote native GTDynamics CSV in current working directory:', native_name)\n" |
| 338 | + "native_name = out_dir / 'a1_traj_CDG.csv'\n", |
| 339 | + "trajectory.writeToFile(robot, str(native_name), export_values)\n", |
| 340 | + "print('Wrote native GTDynamics CSV to:', native_name)\n" |
368 | 341 | ] |
369 | 342 | }, |
370 | 343 | { |
371 | 344 | "cell_type": "markdown", |
372 | 345 | "id": "7ff53c69", |
373 | 346 | "metadata": {}, |
374 | 347 | "source": [ |
375 | | - "## 7) Plot joint trajectories and contact-point heights\n", |
| 348 | + "## 7) Plot base pose, joint trajectories, and contact-point heights\n", |
| 349 | + "\n", |
| 350 | + "We use `roboplot` to visualize the trunk/base pose trajectory (position + orientation frames) over time.\n", |
376 | 351 | "\n", |
377 | 352 | "The second figure shows foot contact-point heights (world Z) for all four feet, which helps check alternating stance/swing behavior." |
378 | 353 | ] |
379 | 354 | }, |
| 355 | + { |
| 356 | + "cell_type": "code", |
| 357 | + "execution_count": null, |
| 358 | + "id": "6e0f8d47", |
| 359 | + "metadata": {}, |
| 360 | + "outputs": [], |
| 361 | + "source": [ |
| 362 | + "trunk_id = robot.link('trunk').id()\n", |
| 363 | + "base_poses = gtsam.Values()\n", |
| 364 | + "for k in range(K + 1):\n", |
| 365 | + " base_poses.insert(k, gtd.Pose(result, trunk_id, k))\n", |
| 366 | + "roboplot.plot_trajectory(base_poses, fignum=1, scale=0.08, show=True);" |
| 367 | + ] |
| 368 | + }, |
| 369 | + { |
| 370 | + "cell_type": "code", |
| 371 | + "execution_count": null, |
| 372 | + "id": "a9429bc0", |
| 373 | + "metadata": {}, |
| 374 | + "outputs": [], |
| 375 | + "source": [ |
| 376 | + "fig, ax = plt.subplots(nrows=3,figsize=(12, 4))\n", |
| 377 | + "times = traj_df['t'].to_numpy()\n", |
| 378 | + "ax[0].plot(times, traj_df['trunk_x'], label='Base', linewidth=1.5)\n", |
| 379 | + "ax[1].plot(times, traj_df['trunk_y'], label='Base', linewidth=1.5)\n", |
| 380 | + "ax[2].plot(times, traj_df['trunk_z'], label='Base', linewidth=1.5)\n", |
| 381 | + "plt.tight_layout()\n", |
| 382 | + "plt.show()" |
| 383 | + ] |
| 384 | + }, |
380 | 385 | { |
381 | 386 | "cell_type": "code", |
382 | 387 | "execution_count": null, |
|
396 | 401 | " for joint in joint_order:\n", |
397 | 402 | " col = f'{leg}_{joint}_joint{suffix}'\n", |
398 | 403 | " if col in traj_df.columns:\n", |
399 | | - " ax.plot(traj_df['t'], traj_df[col], label=joint, color=colors[joint], linewidth=1.4)\n", |
| 404 | + " ax.plot(times, traj_df[col], label=joint, color=colors[joint], linewidth=1.4)\n", |
400 | 405 | " ax.grid(alpha=0.3)\n", |
401 | 406 | " ax.set_ylabel(f'{leg} {y_label}')\n", |
402 | 407 | " ax.legend(loc='upper right', ncol=3)\n", |
|
407 | 412 | "\n", |
408 | 413 | "\n", |
409 | 414 | "plot_joint_group('', 'Joint Angles (ChainDynamicsGraph walk cycle)', 'q')\n", |
410 | | - "plot_joint_group('.1', 'Joint Velocities (ChainDynamicsGraph walk cycle)', 'qdot')\n", |
| 415 | + "# plot_joint_group('.3', 'Joint Torques (ChainDynamicsGraph walk cycle)', 'tau')\n" |
| 416 | + ] |
| 417 | + }, |
| 418 | + { |
| 419 | + "cell_type": "code", |
| 420 | + "execution_count": null, |
| 421 | + "id": "d2638128", |
| 422 | + "metadata": {}, |
| 423 | + "outputs": [], |
| 424 | + "source": [ |
| 425 | + "foot_links = ['RL_lower', 'RR_lower']\n", |
411 | 426 | "\n", |
412 | | - "contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)\n", |
413 | | - "foot_links = ['FL_lower', 'FR_lower', 'RL_lower', 'RR_lower']\n", |
| 427 | + "foot_z = {}\n", |
| 428 | + "for name in foot_links:\n", |
| 429 | + " cp = gtd.PointOnLink(robot.link(name).shared(), CONTACT_IN_COM)\n", |
| 430 | + " foot_z[name] = np.array([cp.predict(result, k)[2] for k in range(K + 1)])\n", |
| 431 | + "\n", |
| 432 | + "fig, ax = plt.subplots(figsize=(12, 4))\n", |
| 433 | + "ax.plot(times, np.array(traj_df['trunk_z']), label='Base', linewidth=1.5)\n", |
| 434 | + "for name in foot_links:\n", |
| 435 | + " ax.plot(times, foot_z[name], label=name, linewidth=1.5)\n", |
| 436 | + "ax.grid(alpha=0.3)\n", |
| 437 | + "ax.set_xlabel('time [s]')\n", |
| 438 | + "ax.set_ylabel('contact point z [m]')\n", |
| 439 | + "ax.set_title('REAR feet contact-point heights in world frame')\n", |
| 440 | + "ax.legend(ncol=4)\n", |
| 441 | + "plt.tight_layout()\n", |
| 442 | + "plt.show()" |
| 443 | + ] |
| 444 | + }, |
| 445 | + { |
| 446 | + "cell_type": "code", |
| 447 | + "execution_count": null, |
| 448 | + "id": "6896ee55", |
| 449 | + "metadata": {}, |
| 450 | + "outputs": [], |
| 451 | + "source": [ |
| 452 | + "foot_links = ['FL_lower', 'FR_lower']\n", |
414 | 453 | "\n", |
415 | | - "foot_z = {'t': traj_df['t'].to_numpy()}\n", |
| 454 | + "foot_z = {}\n", |
416 | 455 | "for name in foot_links:\n", |
417 | | - " cp = gtd.PointOnLink(robot.link(name).shared(), contact_in_com)\n", |
| 456 | + " cp = gtd.PointOnLink(robot.link(name).shared(), CONTACT_IN_COM)\n", |
418 | 457 | " foot_z[name] = np.array([cp.predict(result, k)[2] for k in range(K + 1)])\n", |
419 | 458 | "\n", |
420 | 459 | "fig, ax = plt.subplots(figsize=(12, 4))\n", |
| 460 | + "ax.plot(times, np.array(traj_df['trunk_z']), label='Base', linewidth=1.5)\n", |
421 | 461 | "for name in foot_links:\n", |
422 | | - " ax.plot(foot_z['t'], foot_z[name], label=name, linewidth=1.5)\n", |
| 462 | + " ax.plot(times, foot_z[name], label=name, linewidth=1.5)\n", |
423 | 463 | "ax.grid(alpha=0.3)\n", |
424 | 464 | "ax.set_xlabel('time [s]')\n", |
425 | 465 | "ax.set_ylabel('contact point z [m]')\n", |
426 | | - "ax.set_title('Foot contact-point heights in world frame')\n", |
| 466 | + "ax.set_title('FRONT feet contact-point heights in world frame')\n", |
427 | 467 | "ax.legend(ncol=4)\n", |
428 | 468 | "plt.tight_layout()\n", |
429 | 469 | "plt.show()" |
|
0 commit comments