|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +VLA-Lab Open-Loop Evaluation for Isaac GR00T |
| 4 | +
|
| 5 | +使用 VLA-Lab 的统一评估框架对 GR00T 模型进行开环评估。 |
| 6 | +通过 GR00TAdapter 将 GR00T Policy 适配到 VLA-Lab 接口, |
| 7 | +并使用 LeRobotDatasetLoader 加载 GR00T 格式的 LeRobot v2 数据集。 |
| 8 | +
|
| 9 | +用法: |
| 10 | + # 方式 1: 直接运行(使用下方默认配置) |
| 11 | + python examples/eval_groot_openloop.py |
| 12 | +
|
| 13 | + # 方式 2: 通过命令行参数 |
| 14 | + python examples/eval_groot_openloop.py \ |
| 15 | + --model_path /path/to/checkpoint \ |
| 16 | + --dataset_path /path/to/dataset \ |
| 17 | + --embodiment_tag NEW_EMBODIMENT \ |
| 18 | + --traj_ids 0 1 2 \ |
| 19 | + --action_horizon 16 \ |
| 20 | + --max_steps 300 \ |
| 21 | + --save_plots_dir outputs/groot_eval/ |
| 22 | + |
| 23 | + |
| 24 | + python ../VLA-Lab/examples/eval_groot_openloop.py \ |
| 25 | + --model_path ckpts/GR00T-N1.6-3B_assembly_things/checkpoint-100000 \ |
| 26 | + --dataset_path /data1/vla-data/processed/GR00T/000204_assembly_things \ |
| 27 | + --embodiment_tag FRANKA \ |
| 28 | + --modality_config_path examples/assembly_things/assembly_things_config.py \ |
| 29 | + --traj_ids 10 310 670 \ |
| 30 | + --action_horizon 8 \ |
| 31 | + --max_steps 300 \ |
| 32 | + --save_plots_dir outputs/vlalab_eval/ \ |
| 33 | + --device cuda:0 |
| 34 | +""" |
| 35 | + |
| 36 | +import argparse |
| 37 | +import logging |
| 38 | +import sys |
| 39 | +from pathlib import Path |
| 40 | + |
| 41 | +import numpy as np |
| 42 | + |
# =============================================================================
# Default configuration — adjust for your environment
# =============================================================================

# Dataset path (LeRobot v2 format)
DEFAULT_DATASET_PATH = "demo_data/cube_to_bowl_5"

# Model checkpoint path
DEFAULT_MODEL_PATH = None  # None means: connect to a remote inference service via PolicyClient

# Embodiment tag (enum name or value; resolved by _resolve_embodiment_tag)
DEFAULT_EMBODIMENT_TAG = "NEW_EMBODIMENT"

# Evaluation parameters: trajectories to evaluate, action-chunk length, step cap
DEFAULT_TRAJ_IDS = [0]
DEFAULT_ACTION_HORIZON = 16
DEFAULT_MAX_STEPS = 200

# Output directory for plots and the results JSON
DEFAULT_SAVE_PLOTS_DIR = "outputs/groot_eval/"
| 63 | + |
| 64 | + |
def parse_args():
    """Build the CLI parser and parse evaluation arguments from sys.argv."""
    parser = argparse.ArgumentParser(
        description="VLA-Lab Open-Loop Evaluation for Isaac GR00T"
    )
    arg = parser.add_argument

    # Policy source: a local checkpoint, or (when None) a remote PolicyClient.
    arg("--model_path", type=str, default=DEFAULT_MODEL_PATH,
        help="模型 checkpoint 路径 (None=使用 PolicyClient)")
    arg("--dataset_path", type=str, default=DEFAULT_DATASET_PATH,
        help="LeRobot v2 格式数据集路径")
    arg("--embodiment_tag", type=str, default=DEFAULT_EMBODIMENT_TAG,
        help="Embodiment tag (如 NEW_EMBODIMENT, FRANKA, GR1 等)")
    arg("--modality_config_path", type=str, default=None,
        help="自定义 modality 配置 Python 文件路径 (可选)")

    # Remote inference endpoint (relevant only when --model_path is unset).
    arg("--host", type=str, default="127.0.0.1",
        help="PolicyClient 远程推理服务地址")
    arg("--port", type=int, default=5555,
        help="PolicyClient 远程推理服务端口")

    # Evaluation controls.
    arg("--traj_ids", type=int, nargs="+", default=DEFAULT_TRAJ_IDS,
        help="要评估的轨迹 ID 列表")
    arg("--action_horizon", type=int, default=DEFAULT_ACTION_HORIZON,
        help="动作预测步长")
    arg("--max_steps", type=int, default=DEFAULT_MAX_STEPS,
        help="每条轨迹最大评估步数")
    arg("--save_plots_dir", type=str, default=DEFAULT_SAVE_PLOTS_DIR,
        help="保存评估图表的目录")
    arg("--device", type=str, default="cuda:0",
        help="推理设备")

    return parser.parse_args()
| 114 | + |
| 115 | + |
def _resolve_embodiment_tag(tag_str: str):
    """Map a string to an EmbodimentTag.

    Accepts enum names (e.g. ``FRANKA``), enum values (e.g. ``franka``),
    and case-insensitive names (e.g. ``franka`` -> ``FRANKA``), tried in
    that order.

    Raises:
        ValueError: if no lookup strategy matches.
    """
    from gr00t.data.embodiment_tags import EmbodimentTag

    strategies = (
        lambda s: EmbodimentTag[s],          # by enum name, e.g. FRANKA
        lambda s: EmbodimentTag(s),          # by enum value, e.g. franka
        lambda s: EmbodimentTag[s.upper()],  # case-normalized name
    )
    for strategy in strategies:
        try:
            return strategy(tag_str)
        except (KeyError, ValueError):
            continue

    raise ValueError(
        f"'{tag_str}' 不是有效的 EmbodimentTag。"
        f"可用值: {[e.name for e in EmbodimentTag]}"
    )
| 138 | + |
| 139 | + |
def load_groot_policy(args):
    """Load the GR00T policy and return ``(policy, embodiment_tag)``.

    Uses a local checkpoint when ``args.model_path`` is set; otherwise
    connects to a remote inference service via PolicyClient.
    """
    embodiment_tag = _resolve_embodiment_tag(args.embodiment_tag)

    if args.model_path is None:
        # Remote path: talk to an already-running inference server.
        from gr00t.policy.server_client import PolicyClient

        logging.info(f"连接远程 GR00T Policy: {args.host}:{args.port}")
        policy = PolicyClient(host=args.host, port=args.port)

        # Fail fast with a helpful message if the server is unreachable.
        if hasattr(policy, "ping") and not policy.ping():
            raise RuntimeError(
                f"无法连接到推理服务 {args.host}:{args.port}\n"
                "请先启动服务: uv run python gr00t/eval/run_gr00t_server.py ..."
            )
        return policy, embodiment_tag

    # Local path: load the checkpoint onto the requested device,
    # falling back to CPU when CUDA is unavailable.
    import torch
    from gr00t.policy.gr00t_policy import Gr00tPolicy

    device = args.device if torch.cuda.is_available() else "cpu"
    logging.info(f"加载本地 GR00T Policy: {args.model_path}")
    logging.info(f"  Embodiment: {embodiment_tag}")
    logging.info(f"  Device: {device}")

    policy = Gr00tPolicy(
        embodiment_tag=embodiment_tag,
        model_path=args.model_path,
        device=device,
    )
    return policy, embodiment_tag
| 171 | + |
| 172 | + |
def load_modality_config(args, embodiment_tag):
    """Load and execute a custom modality-config Python file, if one was given.

    Executing the file is expected to register the modality configuration as a
    module-level side effect; this function returns nothing. No-op when
    ``args.modality_config_path`` is empty/None.

    Args:
        args: Parsed CLI namespace; only ``args.modality_config_path`` is read.
        embodiment_tag: Resolved embodiment tag. Currently unused; kept so the
            call signature stays stable for callers.

    Raises:
        ImportError: If the path cannot be loaded as a Python module.
    """
    if not args.modality_config_path:
        return

    logging.info(f"加载自定义 modality 配置: {args.modality_config_path}")
    import importlib.util
    spec = importlib.util.spec_from_file_location("modality_config", args.modality_config_path)
    # spec (or its loader) is None for paths importlib cannot handle; surface
    # that as a clear ImportError instead of an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"无法从 {args.modality_config_path} 加载 modality 配置")
    mod = importlib.util.module_from_spec(spec)
    # Register before exec_module, per the standard importlib recipe, so the
    # module is importable under its own name and stays referenced.
    sys.modules[spec.name] = mod
    spec.loader.exec_module(mod)
    logging.info("自定义 modality 配置已注册")
| 182 | + |
| 183 | + |
def main():
    """Entry point: open-loop evaluation of a GR00T policy via VLA-Lab.

    Pipeline:
      1. Resolve the embodiment tag and load any custom modality config.
      2. Load the GR00T policy (local checkpoint or remote PolicyClient).
      3. Wrap the policy with VLA-Lab's GR00TAdapter.
      4. Build a LeRobot dataset loader over the evaluation dataset.
      5. Evaluate the requested trajectories and report/persist metrics.

    Returns:
        The results dict produced by ``OpenLoopEvaluator.evaluate()``.
    """
    args = parse_args()

    # Configure root logging once for the whole run.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    logging.info("=" * 60)
    logging.info("VLA-Lab Open-Loop Evaluation for Isaac GR00T")
    logging.info("=" * 60)

    # Step 1: load the custom modality config, if one was supplied.
    # NOTE(review): the tag is resolved again inside load_groot_policy below;
    # this first resolution only feeds load_modality_config.
    embodiment_tag = _resolve_embodiment_tag(args.embodiment_tag)
    load_modality_config(args, embodiment_tag)

    # Step 2: load the GR00T policy (local model or remote client).
    policy, embodiment_tag = load_groot_policy(args)
    modality_config = policy.get_modality_config()
    logging.info(f"Policy modality config:\n{modality_config}")

    # Step 3: wrap the policy with VLA-Lab's GR00TAdapter so it exposes the
    # VLA-Lab evaluation interface.
    from vlalab.eval.adapters.groot_adapter import GR00TAdapter
    adapter = GR00TAdapter(policy, embodiment_tag=args.embodiment_tag)
    vlalab_modality = adapter.get_modality_config()

    logging.info(f"VLA-Lab adapter modality config:")
    logging.info(f"  State keys: {vlalab_modality.state_keys}")
    logging.info(f"  Action keys: {vlalab_modality.action_keys}")
    logging.info(f"  Image keys: {vlalab_modality.image_keys}")
    logging.info(f"  Language keys: {vlalab_modality.language_keys}")
    logging.info(f"  Action horizon: {vlalab_modality.action_horizon}")

    # Step 4: build the LeRobot (v2) data loader over the dataset.
    from vlalab.eval.lerobot_loader import LeRobotDatasetLoader
    dataset_loader = LeRobotDatasetLoader(
        dataset_path=args.dataset_path,
        modality_configs=modality_config,
        embodiment_tag=embodiment_tag,
    )
    logging.info(f"数据集轨迹数: {len(dataset_loader)}")

    # Step 5: build the evaluator and run.
    # NOTE(review): EvalConfig, evaluate_trajectory and plot_trajectory_results
    # are imported but unused in this function.
    from vlalab.eval.open_loop_eval import (
        OpenLoopEvaluator,
        EvalConfig,
        evaluate_trajectory,
        plot_trajectory_results,
    )

    evaluator = OpenLoopEvaluator(
        policy=adapter,
        dataset_path=args.dataset_path,
        dataset_format="lerobot",
        dataset_loader=dataset_loader,
    )

    # Task description is taken from the first requested trajectory and is
    # reused for every evaluated trajectory.
    task_desc = dataset_loader.get_task_description(args.traj_ids[0])
    logging.info(f"任务描述: {task_desc}")

    logging.info(f"\n开始评估 {len(args.traj_ids)} 条轨迹...")
    logging.info(f"  Trajectory IDs: {args.traj_ids}")
    logging.info(f"  Action Horizon: {args.action_horizon}")
    logging.info(f"  Max Steps: {args.max_steps}")

    results = evaluator.evaluate(
        traj_ids=args.traj_ids,
        max_steps=args.max_steps,
        action_horizon=args.action_horizon,
        save_plots_dir=args.save_plots_dir,
        task_description=task_desc,
    )

    # Print a summary: per-trajectory and averaged MSE/MAE.
    logging.info("\n" + "=" * 60)
    logging.info("评估结果")
    logging.info("=" * 60)
    logging.info(f"评估轨迹数: {results['num_trajectories']}")

    if "avg_mse" in results:
        logging.info(f"平均 MSE: {results['avg_mse']:.6f}")
        logging.info(f"平均 MAE: {results['avg_mae']:.6f}")

    for r in results["results"]:
        logging.info(
            f"  轨迹 {r['trajectory_id']}: "
            f"MSE={r['mse']:.6f}, MAE={r['mae']:.6f}, "
            f"步数={r['num_steps']}"
        )

    if args.save_plots_dir:
        logging.info(f"\n图表保存到: {args.save_plots_dir}")

        # Persist results as JSON next to the plots.
        # NOTE(review): evaluate_and_save appears to re-run the evaluation —
        # confirm whether this duplicates the evaluator.evaluate() call above.
        results_path = Path(args.save_plots_dir) / "results.json"
        evaluator.evaluate_and_save(
            output_path=str(results_path),
            traj_ids=args.traj_ids,
            max_steps=args.max_steps,
            action_horizon=args.action_horizon,
            task_description=task_desc,
        )
        logging.info(f"结果 JSON 保存到: {results_path}")

    logging.info("\n✅ 评估完成!")
    return results
| 293 | + |
| 294 | + |
| 295 | +if __name__ == "__main__": |
| 296 | + main() |
0 commit comments