Skip to content

Commit 6e116ad

Browse files
committed
Motion planning now works
1 parent 93a958a commit 6e116ad

File tree

1 file changed

+135
-95
lines changed

1 file changed

+135
-95
lines changed

examples/example_a1_walking/chain_dynamics_motion_planning.ipynb

Lines changed: 135 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -32,41 +32,16 @@
3232
"metadata": {},
3333
"outputs": [],
3434
"source": [
35-
"from pathlib import Path\n",
36-
"import sys\n",
3735
"import numpy as np\n",
3836
"import pandas as pd\n",
3937
"import matplotlib.pyplot as plt\n",
4038
"\n",
41-
"\n",
42-
"def find_repo_root(start: Path) -> Path:\n",
43-
" for candidate in [start, *start.parents]:\n",
44-
" if (candidate / 'CMakeLists.txt').exists() and (candidate / 'examples' / 'example_a1_walking').exists():\n",
45-
" return candidate\n",
46-
" raise RuntimeError('Could not locate GTDynamics repository root from current directory.')\n",
47-
"\n",
48-
"\n",
49-
"repo_root = find_repo_root(Path.cwd().resolve())\n",
50-
"gtd_build_python = repo_root / 'build' / 'python'\n",
51-
"gtsam_build_python = repo_root.parent / 'gtsam' / 'build' / 'python'\n",
52-
"\n",
53-
"if gtsam_build_python.exists() and str(gtsam_build_python) not in sys.path:\n",
54-
" sys.path.insert(0, str(gtsam_build_python))\n",
55-
"if str(gtd_build_python) not in sys.path:\n",
56-
" sys.path.insert(0, str(gtd_build_python))\n",
57-
"\n",
5839
"import gtsam\n",
5940
"import gtdynamics as gtd\n",
41+
"import roboplot # type: ignore\n",
6042
"\n",
61-
"print('repo_root :', repo_root)\n",
62-
"print('gtsam module :', gtsam.__file__)\n",
63-
"print('gtdynamics module:', gtd.__file__)\n",
64-
"\n",
65-
"if gtsam_build_python.exists() and str(gtsam_build_python) not in gtsam.__file__:\n",
66-
" raise RuntimeError(\n",
67-
" 'Loaded a non-local gtsam package. Put '\n",
68-
" f\"{gtsam_build_python} first on PYTHONPATH before running this notebook.\"\n",
69-
" )"
43+
"CONTACT_IN_COM = gtsam.Point3(0.0, 0.0, -0.11)\n",
44+
"GROUND_HEIGHT = 1.6"
7045
]
7146
},
7247
{
@@ -91,17 +66,17 @@
9166
"source": [
9267
"def build_walk_trajectory(robot: gtd.Robot, repeat: int = 1) -> gtd.Trajectory:\n",
9368
" # Match getTrajectory() from C++: RRFL -> stationary -> RLFR -> stationary.\n",
94-
" rlfr = [robot.link('RL_lower'), robot.link('FR_lower')]\n",
95-
" rrfl = [robot.link('RR_lower'), robot.link('FL_lower')]\n",
69+
" rlfr = [robot.link(\"RL_lower\"), robot.link(\"FR_lower\")]\n",
70+
" rrfl = [robot.link(\"RR_lower\"), robot.link(\"FL_lower\")]\n",
9671
" all_feet = rlfr + rrfl\n",
9772
"\n",
98-
" contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)\n",
99-
" stationary = gtd.FootContactConstraintSpec(all_feet, contact_in_com)\n",
100-
" rlfr_state = gtd.FootContactConstraintSpec(rlfr, contact_in_com)\n",
101-
" rrfl_state = gtd.FootContactConstraintSpec(rrfl, contact_in_com)\n",
73+
" stationary = gtd.FootContactConstraintSpec(all_feet, CONTACT_IN_COM)\n",
74+
" rlfr_state = gtd.FootContactConstraintSpec(rlfr, CONTACT_IN_COM)\n",
75+
" rrfl_state = gtd.FootContactConstraintSpec(rrfl, CONTACT_IN_COM)\n",
10276
"\n",
103-
" states = [rrfl_state, stationary, rlfr_state, stationary]\n",
104-
" phase_lengths = [25, 5, 25, 5]\n",
77+
" N = 5\n",
78+
" states = [rrfl_state, stationary, rlfr_state, stationary] * N\n",
79+
" phase_lengths = [25, 5, 25, 5] * N\n",
10580
"\n",
10681
" walk_cycle = gtd.WalkCycle(states, phase_lengths)\n",
10782
" return gtd.Trajectory(walk_cycle, repeat)"
@@ -134,6 +109,16 @@
134109
"outputs": [],
135110
"source": [
136111
"robot = gtd.CreateRobotFromFile(gtd.URDF_PATH + \"/a1/a1.urdf\", \"a1\")\n",
112+
"# robot.print()"
113+
]
114+
},
115+
{
116+
"cell_type": "code",
117+
"execution_count": null,
118+
"id": "4d012b64",
119+
"metadata": {},
120+
"outputs": [],
121+
"source": [
137122
"\n",
138123
"gravity = np.array([0.0, 0.0, -9.8])\n",
139124
"mu = 1.0\n",
@@ -144,18 +129,9 @@
144129
"trajectory = build_walk_trajectory(robot, repeat=1)\n",
145130
"collocation = gtd.CollocationScheme.Euler\n",
146131
"\n",
147-
"graph = trajectory.multiPhaseFactorGraph(robot, graph_builder, collocation, mu)\n",
148-
"print('Graph has', graph.size(), 'factors and', graph.keys().size(), 'variables.')"
149-
]
150-
},
151-
{
152-
"cell_type": "code",
153-
"execution_count": null,
154-
"id": "60f38e64",
155-
"metadata": {},
156-
"outputs": [],
157-
"source": [
158-
"graph.print(\"Graph:\\n\", gtd.GTDKeyFormatter)"
132+
"graph = trajectory.multiPhaseFactorGraph(robot, graph_builder, collocation, mu, GROUND_HEIGHT)\n",
133+
"print('Graph has', graph.size(), 'factors and', graph.keys().size(), 'variables.')\n",
134+
"# graph.print(\"Graph:\\n\", gtd.GTDKeyFormatter)"
159135
]
160136
},
161137
{
@@ -165,47 +141,42 @@
165141
"metadata": {},
166142
"outputs": [],
167143
"source": [
168-
"\n",
169-
"ground_height = 1.0\n",
170144
"step = gtsam.Point3(0.25, 0.0, 0.0)\n",
171145
"objectives = trajectory.contactPointObjectives(\n",
172-
" robot,\n",
173-
" gtsam.noiseModel.Isotropic.Sigma(3, 1e-6),\n",
174-
" step,\n",
175-
" ground_height,\n",
146+
" robot, gtsam.noiseModel.Isotropic.Sigma(3, 1e-6), step, GROUND_HEIGHT\n",
176147
")\n",
177148
"\n",
178149
"K = trajectory.getEndTimeStep(trajectory.numPhases() - 1)\n",
179150
"\n",
180-
"for link in robot.links():\n",
181-
" i = link.id()\n",
182-
" if i == 0:\n",
183-
" objectives.push_back(\n",
184-
" gtd.LinkObjectives(i, 0)\n",
185-
" .pose(link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3))\n",
186-
" .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)),\n",
187-
" )\n",
188-
" if i in (3, 6, 9, 12):\n",
189-
" objectives.push_back(\n",
190-
" gtd.LinkObjectives(i, 0).pose(\n",
191-
" link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)\n",
192-
" ),\n",
193-
" )\n",
151+
"# for link in robot.links():\n",
152+
"# i = link.id()\n",
153+
"# if i == 0:\n",
154+
"# objectives.push_back(\n",
155+
"# gtd.LinkObjectives(i, 0)\n",
156+
"# .pose(link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3))\n",
157+
"# .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)),\n",
158+
"# )\n",
159+
"# if i in (3, 6, 9, 12):\n",
160+
"# objectives.push_back(\n",
161+
"# gtd.LinkObjectives(i, 0).pose(\n",
162+
"# link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)\n",
163+
"# ),\n",
164+
"# )\n",
194165
"\n",
195166
"rest_model = gtsam.noiseModel.Isotropic.Sigma(1, 1e-3)\n",
196167
"objectives.push_back(gtd.JointsAtRestObjectives(robot, rest_model, rest_model, 0))\n",
197168
"objectives.push_back(gtd.JointsAtRestObjectives(robot, rest_model, rest_model, K))\n",
198169
"\n",
199-
"trunk = robot.link(\"trunk\")\n",
200-
"for k in range(K + 1):\n",
201-
" objectives.push_back(\n",
202-
" gtd.LinkObjectives(trunk.id(), k)\n",
203-
" .pose(\n",
204-
" gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(0.0, 0.0, 0.4)),\n",
205-
" gtsam.noiseModel.Isotropic.Sigma(6, 1e-2),\n",
206-
" )\n",
207-
" .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 5e-2)),\n",
208-
" )\n",
170+
"# trunk = robot.link(\"trunk\")\n",
171+
"# for k in range(K + 1):\n",
172+
"# objectives.push_back(\n",
173+
"# gtd.LinkObjectives(trunk.id(), k)\n",
174+
"# .pose(\n",
175+
"# gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(0.0, 0.0, GROUND_HEIGHT + 0.4)),\n",
176+
"# gtsam.noiseModel.Isotropic.Sigma(6, 1e-3),\n",
177+
"# )\n",
178+
"# .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 5e-2)),\n",
179+
"# )\n",
209180
"\n",
210181
"desired_dt = 1.0 / 20.0\n",
211182
"trajectory.addIntegrationTimeFactors(objectives, desired_dt, 1e-30)\n",
@@ -260,6 +231,7 @@
260231
"init_values = trajectory.multiPhaseInitialValues(robot, initializer, gaussian_noise, desired_dt)\n",
261232
"\n",
262233
"print('Initial values:', init_values.size())\n",
234+
"print('Final objective error:', graph.error(init_values))\n",
263235
"\n",
264236
"params = gtsam.LevenbergMarquardtParams()\n",
265237
"params.setlambdaInitial(1e10)\n",
@@ -309,9 +281,9 @@
309281
" j = joint.id()\n",
310282
" name = joint.name()\n",
311283
" row[name] = read_or_nan(result, gtd.JointAngleKey(j, k))\n",
312-
" row[f'{name}.1'] = read_or_nan(result, gtd.JointVelKey(j, k))\n",
313-
" row[f'{name}.2'] = read_or_nan(result, gtd.JointAccelKey(j, k))\n",
314-
" row[f'{name}.3'] = read_or_nan(result, gtd.TorqueKey(j, k))\n",
284+
" # row[f'{name}.1'] = read_or_nan(result, gtd.JointVelKey(j, k))\n",
285+
" # row[f'{name}.2'] = read_or_nan(result, gtd.JointAccelKey(j, k))\n",
286+
" # row[f'{name}.3'] = read_or_nan(result, gtd.TorqueKey(j, k))\n",
315287
"\n",
316288
" rows.append(row)\n",
317289
"\n",
@@ -339,10 +311,11 @@
339311
"metadata": {},
340312
"outputs": [],
341313
"source": [
342-
"out_dir = repo_root / 'build' / 'examples' / 'example_a1_walking'\n",
314+
"from pathlib import Path\n",
315+
"out_dir = Path(gtd.DATA_PATH) / 'example_a1_walking'\n",
343316
"out_dir.mkdir(parents=True, exist_ok=True)\n",
344317
"\n",
345-
"csv_table = out_dir / 'a1_traj_chain_dynamics_graph_python.csv'\n",
318+
"csv_table = out_dir / 'traj_df.csv'\n",
346319
"traj_df.to_csv(csv_table, index=False)\n",
347320
"print('Wrote table CSV:', csv_table)\n",
348321
"\n",
@@ -362,21 +335,53 @@
362335
" if not export_values.exists(pk):\n",
363336
" export_values.insert(pk, desired_dt)\n",
364337
"\n",
365-
"native_name = 'a1_traj_CDG_python.csv'\n",
366-
"trajectory.writeToFile(robot, native_name, export_values)\n",
367-
"print('Wrote native GTDynamics CSV in current working directory:', native_name)\n"
338+
"native_name = out_dir / 'a1_traj_CDG.csv'\n",
339+
"trajectory.writeToFile(robot, str(native_name), export_values)\n",
340+
"print('Wrote native GTDynamics CSV to:', native_name)\n"
368341
]
369342
},
370343
{
371344
"cell_type": "markdown",
372345
"id": "7ff53c69",
373346
"metadata": {},
374347
"source": [
375-
"## 7) Plot joint trajectories and contact-point heights\n",
348+
"## 7) Plot base pose, joint trajectories, and contact-point heights\n",
349+
"\n",
350+
"We use `roboplot` to visualize the trunk/base pose trajectory (position + orientation frames) over time.\n",
376351
"\n",
377352
"The second figure shows foot contact-point heights (world Z) for all four feet, which helps check alternating stance/swing behavior."
378353
]
379354
},
355+
{
356+
"cell_type": "code",
357+
"execution_count": null,
358+
"id": "6e0f8d47",
359+
"metadata": {},
360+
"outputs": [],
361+
"source": [
362+
"trunk_id = robot.link('trunk').id()\n",
363+
"base_poses = gtsam.Values()\n",
364+
"for k in range(K + 1):\n",
365+
" base_poses.insert(k, gtd.Pose(result, trunk_id, k))\n",
366+
"roboplot.plot_trajectory(base_poses, fignum=1, scale=0.08, show=True);"
367+
]
368+
},
369+
{
370+
"cell_type": "code",
371+
"execution_count": null,
372+
"id": "a9429bc0",
373+
"metadata": {},
374+
"outputs": [],
375+
"source": [
376+
"fig, ax = plt.subplots(nrows=3,figsize=(12, 4))\n",
377+
"times = traj_df['t'].to_numpy()\n",
378+
"ax[0].plot(times, traj_df['trunk_x'], label='Base', linewidth=1.5)\n",
379+
"ax[1].plot(times, traj_df['trunk_y'], label='Base', linewidth=1.5)\n",
380+
"ax[2].plot(times, traj_df['trunk_z'], label='Base', linewidth=1.5)\n",
381+
"plt.tight_layout()\n",
382+
"plt.show()"
383+
]
384+
},
380385
{
381386
"cell_type": "code",
382387
"execution_count": null,
@@ -396,7 +401,7 @@
396401
" for joint in joint_order:\n",
397402
" col = f'{leg}_{joint}_joint{suffix}'\n",
398403
" if col in traj_df.columns:\n",
399-
" ax.plot(traj_df['t'], traj_df[col], label=joint, color=colors[joint], linewidth=1.4)\n",
404+
" ax.plot(times, traj_df[col], label=joint, color=colors[joint], linewidth=1.4)\n",
400405
" ax.grid(alpha=0.3)\n",
401406
" ax.set_ylabel(f'{leg} {y_label}')\n",
402407
" ax.legend(loc='upper right', ncol=3)\n",
@@ -407,23 +412,58 @@
407412
"\n",
408413
"\n",
409414
"plot_joint_group('', 'Joint Angles (ChainDynamicsGraph walk cycle)', 'q')\n",
410-
"plot_joint_group('.1', 'Joint Velocities (ChainDynamicsGraph walk cycle)', 'qdot')\n",
415+
"# plot_joint_group('.3', 'Joint Torques (ChainDynamicsGraph walk cycle)', 'tau')\n"
416+
]
417+
},
418+
{
419+
"cell_type": "code",
420+
"execution_count": null,
421+
"id": "d2638128",
422+
"metadata": {},
423+
"outputs": [],
424+
"source": [
425+
"foot_links = ['RL_lower', 'RR_lower']\n",
411426
"\n",
412-
"contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)\n",
413-
"foot_links = ['FL_lower', 'FR_lower', 'RL_lower', 'RR_lower']\n",
427+
"foot_z = {}\n",
428+
"for name in foot_links:\n",
429+
" cp = gtd.PointOnLink(robot.link(name).shared(), CONTACT_IN_COM)\n",
430+
" foot_z[name] = np.array([cp.predict(result, k)[2] for k in range(K + 1)])\n",
431+
"\n",
432+
"fig, ax = plt.subplots(figsize=(12, 4))\n",
433+
"ax.plot(times, np.array(traj_df['trunk_z']), label='Base', linewidth=1.5)\n",
434+
"for name in foot_links:\n",
435+
" ax.plot(times, foot_z[name], label=name, linewidth=1.5)\n",
436+
"ax.grid(alpha=0.3)\n",
437+
"ax.set_xlabel('time [s]')\n",
438+
"ax.set_ylabel('contact point z [m]')\n",
439+
"ax.set_title('REAR feet contact-point heights in world frame')\n",
440+
"ax.legend(ncol=4)\n",
441+
"plt.tight_layout()\n",
442+
"plt.show()"
443+
]
444+
},
445+
{
446+
"cell_type": "code",
447+
"execution_count": null,
448+
"id": "6896ee55",
449+
"metadata": {},
450+
"outputs": [],
451+
"source": [
452+
"foot_links = ['FL_lower', 'FR_lower']\n",
414453
"\n",
415-
"foot_z = {'t': traj_df['t'].to_numpy()}\n",
454+
"foot_z = {}\n",
416455
"for name in foot_links:\n",
417-
" cp = gtd.PointOnLink(robot.link(name).shared(), contact_in_com)\n",
456+
" cp = gtd.PointOnLink(robot.link(name).shared(), CONTACT_IN_COM)\n",
418457
" foot_z[name] = np.array([cp.predict(result, k)[2] for k in range(K + 1)])\n",
419458
"\n",
420459
"fig, ax = plt.subplots(figsize=(12, 4))\n",
460+
"ax.plot(times, np.array(traj_df['trunk_z']), label='Base', linewidth=1.5)\n",
421461
"for name in foot_links:\n",
422-
" ax.plot(foot_z['t'], foot_z[name], label=name, linewidth=1.5)\n",
462+
" ax.plot(times, foot_z[name], label=name, linewidth=1.5)\n",
423463
"ax.grid(alpha=0.3)\n",
424464
"ax.set_xlabel('time [s]')\n",
425465
"ax.set_ylabel('contact point z [m]')\n",
426-
"ax.set_title('Foot contact-point heights in world frame')\n",
466+
"ax.set_title('FRONT feet contact-point heights in world frame')\n",
427467
"ax.legend(ncol=4)\n",
428468
"plt.tight_layout()\n",
429469
"plt.show()"

0 commit comments

Comments
 (0)