Skip to content

Commit 587455c

Browse files
committed
add groot open-loop eval
1 parent 0ec48f6 commit 587455c

6 files changed

Lines changed: 1214 additions & 478 deletions

File tree

examples/eval_groot_openloop.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#!/usr/bin/env python3
2+
"""
3+
VLA-Lab Open-Loop Evaluation for Isaac GR00T
4+
5+
使用 VLA-Lab 的统一评估框架对 GR00T 模型进行开环评估。
6+
通过 GR00TAdapter 将 GR00T Policy 适配到 VLA-Lab 接口,
7+
并使用 LeRobotDatasetLoader 加载 GR00T 格式的 LeRobot v2 数据集。
8+
9+
用法:
10+
# 方式 1: 直接运行(使用下方默认配置)
11+
python examples/eval_groot_openloop.py
12+
13+
# 方式 2: 通过命令行参数
14+
python examples/eval_groot_openloop.py \
15+
--model_path /path/to/checkpoint \
16+
--dataset_path /path/to/dataset \
17+
--embodiment_tag NEW_EMBODIMENT \
18+
--traj_ids 0 1 2 \
19+
--action_horizon 16 \
20+
--max_steps 300 \
21+
--save_plots_dir outputs/groot_eval/
22+
23+
24+
python ../VLA-Lab/examples/eval_groot_openloop.py \
25+
--model_path ckpts/GR00T-N1.6-3B_assembly_things/checkpoint-100000 \
26+
--dataset_path /data1/vla-data/processed/GR00T/000204_assembly_things \
27+
--embodiment_tag FRANKA \
28+
--modality_config_path examples/assembly_things/assembly_things_config.py \
29+
--traj_ids 10 310 670 \
30+
--action_horizon 8 \
31+
--max_steps 300 \
32+
--save_plots_dir outputs/vlalab_eval/ \
33+
--device cuda:0
34+
"""
35+
36+
import argparse
37+
import logging
38+
import sys
39+
from pathlib import Path
40+
41+
import numpy as np
42+
43+
# =============================================================================
# Default configuration - adjust these for your environment
# =============================================================================

# Dataset path (LeRobot v2 format)
DEFAULT_DATASET_PATH = "demo_data/cube_to_bowl_5"

# Model checkpoint path; None means connect to a remote inference
# server through PolicyClient instead of loading a local checkpoint.
DEFAULT_MODEL_PATH = None

# Embodiment tag
DEFAULT_EMBODIMENT_TAG = "NEW_EMBODIMENT"

# Evaluation parameters
DEFAULT_TRAJ_IDS = [0]
DEFAULT_ACTION_HORIZON = 16
DEFAULT_MAX_STEPS = 200

# Output path
DEFAULT_SAVE_PLOTS_DIR = "outputs/groot_eval/"
63+
64+
65+
def parse_args():
    """Build and parse the CLI options for the GR00T open-loop evaluation.

    Returns:
        argparse.Namespace with model/dataset paths, embodiment tag,
        remote-server coordinates, and evaluation parameters.
    """
    parser = argparse.ArgumentParser(
        description="VLA-Lab Open-Loop Evaluation for Isaac GR00T"
    )

    # (flag, kwargs) table keeps the option list compact and scannable.
    option_table = [
        ("--model_path",
         dict(type=str, default=DEFAULT_MODEL_PATH,
              help="模型 checkpoint 路径 (None=使用 PolicyClient)")),
        ("--dataset_path",
         dict(type=str, default=DEFAULT_DATASET_PATH,
              help="LeRobot v2 格式数据集路径")),
        ("--embodiment_tag",
         dict(type=str, default=DEFAULT_EMBODIMENT_TAG,
              help="Embodiment tag (如 NEW_EMBODIMENT, FRANKA, GR1 等)")),
        ("--modality_config_path",
         dict(type=str, default=None,
              help="自定义 modality 配置 Python 文件路径 (可选)")),
        ("--host",
         dict(type=str, default="127.0.0.1",
              help="PolicyClient 远程推理服务地址")),
        ("--port",
         dict(type=int, default=5555,
              help="PolicyClient 远程推理服务端口")),
        ("--traj_ids",
         dict(type=int, nargs="+", default=DEFAULT_TRAJ_IDS,
              help="要评估的轨迹 ID 列表")),
        ("--action_horizon",
         dict(type=int, default=DEFAULT_ACTION_HORIZON,
              help="动作预测步长")),
        ("--max_steps",
         dict(type=int, default=DEFAULT_MAX_STEPS,
              help="每条轨迹最大评估步数")),
        ("--save_plots_dir",
         dict(type=str, default=DEFAULT_SAVE_PLOTS_DIR,
              help="保存评估图表的目录")),
        ("--device",
         dict(type=str, default="cuda:0",
              help="推理设备")),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
114+
115+
116+
def _resolve_embodiment_tag(tag_str: str):
    """Resolve *tag_str* into an ``EmbodimentTag`` enum member.

    Accepts the enum member name (e.g. ``FRANKA``), the enum value
    (e.g. ``franka``), or a lowercase spelling of the member name.

    Raises:
        ValueError: if no lookup strategy matches.
    """
    from gr00t.data.embodiment_tags import EmbodimentTag

    # Lookup strategies, in priority order: member name, enum value,
    # upper-cased member name.  Enum name lookup raises KeyError and
    # value lookup raises ValueError, so both are tolerated per attempt.
    strategies = (
        lambda s: EmbodimentTag[s],
        lambda s: EmbodimentTag(s),
        lambda s: EmbodimentTag[s.upper()],
    )
    for strategy in strategies:
        try:
            return strategy(tag_str)
        except (KeyError, ValueError):
            continue

    raise ValueError(
        f"'{tag_str}' 不是有效的 EmbodimentTag。"
        f"可用值: {[e.name for e in EmbodimentTag]}"
    )
138+
139+
140+
def load_groot_policy(args):
    """Construct the GR00T policy from the CLI arguments.

    When ``args.model_path`` is set, a local ``Gr00tPolicy`` is loaded
    from the checkpoint; otherwise a ``PolicyClient`` is connected to a
    remote inference server at ``args.host:args.port``.

    Returns:
        A ``(policy, embodiment_tag)`` tuple.

    Raises:
        RuntimeError: if the remote inference server is unreachable.
    """
    embodiment_tag = _resolve_embodiment_tag(args.embodiment_tag)

    if args.model_path is None:
        # Remote inference path: connect and verify the server responds.
        from gr00t.policy.server_client import PolicyClient

        logging.info(f"连接远程 GR00T Policy: {args.host}:{args.port}")
        client = PolicyClient(host=args.host, port=args.port)
        if hasattr(client, "ping") and not client.ping():
            raise RuntimeError(
                f"无法连接到推理服务 {args.host}:{args.port}\n"
                "请先启动服务: uv run python gr00t/eval/run_gr00t_server.py ..."
            )
        return client, embodiment_tag

    # Local checkpoint path: fall back to CPU when CUDA is unavailable.
    import torch
    from gr00t.policy.gr00t_policy import Gr00tPolicy

    device = args.device if torch.cuda.is_available() else "cpu"
    logging.info(f"加载本地 GR00T Policy: {args.model_path}")
    logging.info(f" Embodiment: {embodiment_tag}")
    logging.info(f" Device: {device}")
    local_policy = Gr00tPolicy(
        embodiment_tag=embodiment_tag,
        model_path=args.model_path,
        device=device,
    )
    return local_policy, embodiment_tag
171+
172+
173+
def load_modality_config(args, embodiment_tag):
    """Execute a user-supplied modality-config module, if one was given.

    The file at ``args.modality_config_path`` is loaded and executed as a
    Python module; registration is expected to happen as an import side
    effect of that file.  A falsy path makes this a no-op.

    Args:
        args: Parsed CLI namespace; only ``modality_config_path`` is read.
        embodiment_tag: Currently unused; kept for interface stability.

    Raises:
        ImportError: if the path cannot be loaded as a Python module.
    """
    if not args.modality_config_path:
        return

    logging.info(f"加载自定义 modality 配置: {args.modality_config_path}")
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "modality_config", args.modality_config_path
    )
    # spec_from_file_location returns None (or a spec without a loader) for
    # paths it cannot handle; fail with a clear message instead of letting
    # the exec_module call below crash with an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load modality config module from "
            f"{args.modality_config_path!r}"
        )

    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    logging.info("自定义 modality 配置已注册")
182+
183+
184+
def main():
    """Run the VLA-Lab open-loop evaluation pipeline for a GR00T policy.

    Steps: parse CLI args, optionally register a custom modality config,
    load the policy (local checkpoint or remote client), wrap it in the
    VLA-Lab GR00TAdapter, build a LeRobot dataset loader, evaluate the
    requested trajectories, log MSE/MAE metrics, and optionally save
    plots plus a JSON results file.

    Returns:
        The results dict produced by ``OpenLoopEvaluator.evaluate``.
    """
    args = parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    logging.info("=" * 60)
    logging.info("VLA-Lab Open-Loop Evaluation for Isaac GR00T")
    logging.info("=" * 60)

    # Step 1: register the custom modality config, if one was provided.
    # NOTE(review): the tag is resolved again inside load_groot_policy, so
    # this first resolution only serves load_modality_config (which
    # currently ignores it).
    embodiment_tag = _resolve_embodiment_tag(args.embodiment_tag)
    load_modality_config(args, embodiment_tag)

    # Step 2: load the GR00T policy (local checkpoint or remote client)
    policy, embodiment_tag = load_groot_policy(args)
    modality_config = policy.get_modality_config()
    logging.info(f"Policy modality config:\n{modality_config}")

    # Step 3: wrap the policy with VLA-Lab's GR00TAdapter
    from vlalab.eval.adapters.groot_adapter import GR00TAdapter
    adapter = GR00TAdapter(policy, embodiment_tag=args.embodiment_tag)
    vlalab_modality = adapter.get_modality_config()

    logging.info(f"VLA-Lab adapter modality config:")
    logging.info(f" State keys: {vlalab_modality.state_keys}")
    logging.info(f" Action keys: {vlalab_modality.action_keys}")
    logging.info(f" Image keys: {vlalab_modality.image_keys}")
    logging.info(f" Language keys: {vlalab_modality.language_keys}")
    logging.info(f" Action horizon: {vlalab_modality.action_horizon}")

    # Step 4: build the LeRobot dataset loader
    from vlalab.eval.lerobot_loader import LeRobotDatasetLoader
    dataset_loader = LeRobotDatasetLoader(
        dataset_path=args.dataset_path,
        modality_configs=modality_config,
        embodiment_tag=embodiment_tag,
    )
    logging.info(f"数据集轨迹数: {len(dataset_loader)}")

    # Step 5: build the evaluator and run
    from vlalab.eval.open_loop_eval import (
        OpenLoopEvaluator,
        EvalConfig,
        evaluate_trajectory,
        plot_trajectory_results,
    )

    evaluator = OpenLoopEvaluator(
        policy=adapter,
        dataset_path=args.dataset_path,
        dataset_format="lerobot",
        dataset_loader=dataset_loader,
    )

    # Fetch the task description from the first requested trajectory;
    # it is reused for every evaluated trajectory below.
    task_desc = dataset_loader.get_task_description(args.traj_ids[0])
    logging.info(f"任务描述: {task_desc}")

    logging.info(f"\n开始评估 {len(args.traj_ids)} 条轨迹...")
    logging.info(f" Trajectory IDs: {args.traj_ids}")
    logging.info(f" Action Horizon: {args.action_horizon}")
    logging.info(f" Max Steps: {args.max_steps}")

    results = evaluator.evaluate(
        traj_ids=args.traj_ids,
        max_steps=args.max_steps,
        action_horizon=args.action_horizon,
        save_plots_dir=args.save_plots_dir,
        task_description=task_desc,
    )

    # Print aggregate and per-trajectory results
    logging.info("\n" + "=" * 60)
    logging.info("评估结果")
    logging.info("=" * 60)
    logging.info(f"评估轨迹数: {results['num_trajectories']}")

    if "avg_mse" in results:
        logging.info(f"平均 MSE: {results['avg_mse']:.6f}")
        logging.info(f"平均 MAE: {results['avg_mae']:.6f}")

    for r in results["results"]:
        logging.info(
            f" 轨迹 {r['trajectory_id']}: "
            f"MSE={r['mse']:.6f}, MAE={r['mae']:.6f}, "
            f"步数={r['num_steps']}"
        )

    if args.save_plots_dir:
        logging.info(f"\n图表保存到: {args.save_plots_dir}")

        # Save JSON results.
        # NOTE(review): evaluate_and_save looks like it re-runs the whole
        # evaluation a second time just to write the JSON — confirm against
        # the OpenLoopEvaluator API whether the results above can be reused.
        results_path = Path(args.save_plots_dir) / "results.json"
        evaluator.evaluate_and_save(
            output_path=str(results_path),
            traj_ids=args.traj_ids,
            max_steps=args.max_steps,
            action_horizon=args.action_horizon,
            task_description=task_desc,
        )
        logging.info(f"结果 JSON 保存到: {results_path}")

    logging.info("\n✅ 评估完成!")
    return results


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)