diff --git a/scripts/check_dataset_integrity.py b/scripts/check_dataset_integrity.py deleted file mode 100644 index 581d5e57..00000000 --- a/scripts/check_dataset_integrity.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2025 The VLA-Arena Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A script to check if any demonstration dataset does not have the exact number of demonstration trajectories""" - -from pathlib import Path - -import h5py -import numpy as np - -from vla_arena.vla_arena import get_vla_arena_path - - -error_datasets = [] -for demo_file_name in Path(get_vla_arena_path('datasets')).rglob('*hdf5'): - - demo_file = h5py.File(demo_file_name) - - count = 0 - for key in demo_file['data'].keys(): - if 'demo' in key: - count += 1 - - if count == 50: - traj_lengths = [] - action_min = np.inf - action_max = -np.inf - for demo_name in demo_file['data'].keys(): - traj_lengths.append( - demo_file[f'data/{demo_name}/actions'].shape[0] - ) - traj_lengths = np.array(traj_lengths) - print( - f'[info] dataset {demo_file_name} is in tact, test passed \u2714' - ) - print(np.mean(traj_lengths), ' +- ', np.std(traj_lengths)) - if demo_file['data'].attrs['tag'] == 'vla_arena-v1': - print('Version correct') - - print('=========================================') - - else: - print('[error] !!!') - error_datasets.append(demo_file_name) - -if len(error_datasets) > 0: - print('[error] The following datasets are corrupted:') - for dataset in error_datasets: - print(dataset) diff --git a/scripts/evaluate_policy.py b/scripts/evaluate_policy.py deleted file mode 100644 index fd05ca5d..00000000 --- a/scripts/evaluate_policy.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright 2025 The VLA-Arena Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import json -import os - -from vla_arena.evaluation.evaluator import Evaluator - -# from vla_arena.evaluation.policy import OpenVLAOFT -# from vla_arena.evaluation.policy import OpenPI -# from vla_arena.evaluation.policy import SmolVLA -from vla_arena.evaluation.policy import PolicyRegistry - - -os.environ['MUJOCO_GL'] = 'egl' - - -def parse_levels(levels_str): - """ - Parse level string to support various formats: - - Single level: "0" -> [0] - - Range: "0-2" -> [0, 1, 2] - - List: "0,2" -> [0, 2] - """ - if levels_str is None: - return None - - levels = [] - parts = levels_str.split(',') - - for part in parts: - part = part.strip() - if '-' in part: - # Handle range - start, end = part.split('-') - start, end = int(start.strip()), int(end.strip()) - levels.extend(list(range(start, end + 1))) - else: - # Handle single level - levels.append(int(part)) - - # Remove duplicates and sort - levels = sorted(list(set(levels))) - return levels - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--task_suite', - default=None, - type=str, - choices=[ - 'safety_dynamic_obstacles', - 'safety_hazard_avoidance', - 'safety_object_state_preservation', - 'safety_risk_aware_grasping', - 'safety_static_obstacles', - 'robustness_dynamic_distractors', - 'robustness_static_distractors', - 'robustness_visual_variations', - 'generalization_language_variations', - 'generalization_object_preposition_combinations', - 'generalization_task_workflows', - 'generalization_unseen_objects', - 'long_horizon', - # "libero_10", - # "libero_90", - # "libero_spatial", - # "libero_object", - # "libero_goal", - ], - help='The evaluation track to run', - ) - - # Modified: Support both single level and multiple levels - parser.add_argument( - '--task_level', - default='0', - type=str, - help='Task level(s) to evaluate. 
Supports: single (0), range (0-2), list (0,2,4), or mixed (0-2,5)', - ) - - parser.add_argument( - '--n-episode', - default=1, - type=int, - help='The number of episodes to evaluate for each task', - ) - parser.add_argument( - '--policy', - default='openvla', - type=str, - choices=PolicyRegistry.list_policies(), - help='The policy to evaluate', - ) - parser.add_argument( - '--model_ckpt', default=None, help='The base model checkpoint path' - ) - parser.add_argument( - '--save-dir', - default='logs', - help='The directory to save the evaluation results', - ) - parser.add_argument( - '--visualization', - action='store_true', - default=False, - help='Whether to visualize the episodes', - ) - parser.add_argument( - '--metrics', - nargs='+', - default=['success_rate', 'cumulative_cost', 'safe_success_rate'], - choices=[ - 'success_rate', - 'cumulative_cost', - 'safe_success_rate', - 'episode_length', - ], - help='The metrics to evaluate', - ) - parser.add_argument( - '--host', - default='localhost', - type=str, - help='The host to the remote server', - ) - parser.add_argument( - '--port', default=5555, type=int, help='The port to the remote server' - ) - parser.add_argument( - '--replanstep', default=4, type=int, help='The step to replan' - ) - - # Additional arguments for batch evaluation - parser.add_argument( - '--parallel', - action='store_true', - default=False, - help='Whether to run level evaluations in parallel (experimental)', - ) - parser.add_argument( - '--episode_config', - default=None, - type=str, - help='Path to episode configuration file', - ) - - args = parser.parse_args() - return args - - -def print_evaluation_plan(args, task_levels): - """Print the evaluation plan before starting""" - print('\n' + '=' * 70) - print('EVALUATION PLAN') - print('=' * 70) - print(f'Task Suite: {args.task_suite}') - print(f'Levels to evaluate: {task_levels}') - print(f'Episodes per task: {args.n_episode}') - print(f'Policy: {args.policy}') - print(f'Metrics: {args.metrics}') - print(f'Visualization: {args.visualization}') - print(f'Save directory: {args.save_dir}') - print('=' * 70 + '\n') - - # Calculate total evaluation scope - num_levels = len(task_levels) - # This is approximate - actual number depends on the suite - estimated_tasks_per_level = 10 # You might want to get this from the suite - total_episodes = num_levels * estimated_tasks_per_level * args.n_episode - - print(f'Estimated total episodes: ~{total_episodes}') - print('Press Ctrl+C to cancel, or wait to continue...\n') - - import time - - time.sleep(3) # Give user time to cancel if needed - - -def evaluate(args): - """Main evaluation function with multi-level support""" - - # Parse task levels - task_levels = parse_levels(args.task_level) - if not task_levels: - raise ValueError('No valid task levels specified!') - - # Load episode configuration if provided - episode_config = None - if args.episode_config: - with open(args.episode_config) as f: - episode_config = json.load(f) - - # Set up save directory - if args.task_suite is not None: - args.save_dir = os.path.join(args.save_dir, args.task_suite) - - if not args.task_suite: - raise ValueError('No tasks specified! 
Please provide --task_suite') - - # Print evaluation plan - print_evaluation_plan(args, task_levels) - - print(f'Tasks to evaluate: {args.task_suite}') - print(f'Levels to evaluate: {task_levels}') - print(f'Number of episodes per task: {args.n_episode}') - - # Create evaluator with multiple levels support - evaluator = Evaluator( - task_suite=args.task_suite, - task_levels=task_levels, # Pass list of levels - n_episodes=args.n_episode, - episode_config=episode_config, - max_substeps=1, # repeat step in simulation - save_dir=args.save_dir, - visualization=args.visualization, - metrics=args.metrics, - ) - if args.policy not in PolicyRegistry.list_policies(): - raise ValueError( - f"Policy '{args.policy}' is not registered. Available policies: {PolicyRegistry.list_policies()}", - ) - if args.policy != 'openpi': - policy = PolicyRegistry.get( - args.policy, - model_ckpt=args.model_ckpt if args.model_ckpt else None, - ) - else: - policy = PolicyRegistry.get( - args.policy, host=args.host, port=args.port - ) - - # Run evaluation - results = evaluator.evaluate(policy) - - # Print quick summary of results - print('\n' + '=' * 70) - print('EVALUATION COMPLETED SUCCESSFULLY') - print('=' * 70) - - if isinstance(results, dict): - # If single level, results is a dict of task metrics - if len(task_levels) == 1: - print(f'\nLevel {task_levels[0]} Results:') - for task_name, metrics in results.items(): - print(f' {task_name}:') - if 'success_rate' in metrics: - print(f" Success Rate: {metrics['success_rate']:.2%}") - if 'safe_success_rate' in metrics: - print( - f" Safe Success Rate: {metrics['safe_success_rate']:.2%}" - ) - if 'cumulative_cost' in metrics: - print(f" Avg Cost: {metrics['cumulative_cost']:.2f}") - else: - # Multiple levels, results is dict of level -> task metrics - for level, level_results in results.items(): - print(f'\nLevel {level} Results:') - success_rates = [] - for task_name, metrics in level_results.items(): - if 'success_rate' in metrics: - success_rates.append(metrics['success_rate']) - if success_rates: - avg_success = sum(success_rates) / len(success_rates) - print(f' Average Success Rate: {avg_success:.2%}') - - print(f'\nDetailed results saved to: {evaluator.save_dir}') - - # except KeyboardInterrupt: - # print("\n\nEvaluation interrupted by user.") - # print("Partial results may have been saved.") - # except Exception as e: - # print(f"\n\nEvaluation failed with error: {e}") - # import traceback - # traceback.print_exc() - # raise - - -def main(): - """Entry point with better error handling""" - args = get_args() - - # Validate arguments - if not args.task_suite: - print('Error: --task_suite is required!') - print( - 'Available options: static_obstacles, preposition_generalization' - ) - return 1 - - try: - evaluate(args) - return 0 - except Exception: - import traceback - - traceback.print_exc() - return 1 - - -if __name__ == '__main__': - import sys - - sys.exit(main()) diff --git a/scripts/evaluate_policy.sh b/scripts/evaluate_policy.sh deleted file mode 100644 index 8a4d19c3..00000000 --- a/scripts/evaluate_policy.sh +++ /dev/null @@ -1,201 +0,0 @@ -#!/bin/bash -# ============================================================================ -# VLA-Arena Unified Evaluation Script -# ============================================================================ -# Instructions: -# 1. Copy this script: cp scripts/evaluate_policy.sh my_evaluation.sh -# 2. Edit the configuration section below -# 3. 
Run: bash my_evaluation.sh -# ============================================================================ - -# ================================ -# CONFIGURATION SECTION - Edit these variables for your evaluation -# ================================ - -# Model Configuration -export CUDA_VISIBLE_DEVICES=0 - -POLICY="openvla" # Options: openvla, random (more coming soon) -MODEL_CKPT="path/to/model/checkpoint" # Path to model checkpoint - -# Task Configuration -TASK_SUITE="safety_static_obstacles" # Options: -TASK_LEVEL=0 # Difficulty level: 0 (easy), 1 (medium), 2 (hard) -N_EPISODES=1 # Number of episodes per task - -# Evaluation Settings -VISUALIZATION=true # Set to true to save evaluation videos -METRICS="success_rate" # Metrics to compute - -# Output Configuration -SAVE_DIR="logs/evaluation_$(date +%Y%m%d_%H%M%S)" # Output directory (auto-timestamped) - -# ================================ -# END OF CONFIGURATION SECTION -# ================================ - -# Color codes for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# Function to print colored output -print_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -print_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -print_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -print_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -# Validation -validate_config() { - local valid=true - - if [ "$valid" = false ]; then - print_error "Configuration validation failed. Please check your settings." - exit 1 - fi -} - -# Print configuration summary -print_config() { - echo "" - echo "╔══════════════════════════════════════════════════════════════╗" - echo "║ VLA-Arena Evaluation Configuration ║" - echo "╠══════════════════════════════════════════════════════════════╣" - printf "║ %-20s : %-39s ║\n" "Policy" "$POLICY" - if [[ "$POLICY" != "random" ]]; then - # Truncate long paths for display - local display_model=$(basename "$MODEL_CKPT") - printf "║ %-20s : %-39s ║\n" "Model" "...$display_model" - fi - printf "║ %-20s : %-39s ║\n" "Task Suite" "$TASK_SUITE" - printf "║ %-20s : %-39s ║\n" "Task Level" "Level $TASK_LEVEL" - printf "║ %-20s : %-39s ║\n" "Episodes per Task" "$N_EPISODES" - printf "║ %-20s : %-39s ║\n" "Device" "$DEVICE" - printf "║ %-20s : %-39s ║\n" "Visualization" "$VISUALIZATION" - printf "║ %-20s : %-39s ║\n" "Save Directory" "$(basename $SAVE_DIR)" - echo "╚══════════════════════════════════════════════════════════════╝" - echo "" -} - -# Main execution -main() { - # Validate configuration - validate_config - - # Print configuration - print_config - - # Ask for confirmation - read -p "Do you want to proceed with this configuration? 
[(y)/n]: " -n 1 -r - echo - if [[ $REPLY =~ ^[Nn]$ ]]; then - print_warning "Evaluation cancelled by user" - exit 0 - fi - - # Build command - CMD="python scripts/evaluate_policy.py" - CMD="$CMD --task_suite $TASK_SUITE" - CMD="$CMD --task_level $TASK_LEVEL" - CMD="$CMD --n-episode $N_EPISODES" - CMD="$CMD --policy $POLICY" - CMD="$CMD --save-dir $SAVE_DIR" - CMD="$CMD --metrics $METRICS" - - # Add model checkpoint if not random policy - if [[ "$POLICY" != "random" ]]; then - CMD="$CMD --model_ckpt $MODEL_CKPT" - fi - - # Add visualization flag if enabled - if [[ "$VISUALIZATION" == "true" ]]; then - CMD="$CMD --visualization" - fi - - # Create save directory - mkdir -p "$SAVE_DIR" - - # Save configuration to file - cat > "$SAVE_DIR/evaluation_config.txt" < 5 else ''}", - ) - - # Create output directory - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # Create output file and copy data - with h5py.File(output_file, 'w') as f_out: - # Create data group - data_group_out = f_out.create_group('data') - - # Copy all attributes from data group - for key, value in data_group.attrs.items(): - data_group_out.attrs[key] = value - - # Copy selected demos - total_samples = 0 - for i, demo_name in enumerate(selected_demos): - # Create new demo group (renumbered) - new_demo_name = f'demo_{i}' - demo_group_out = data_group_out.create_group(new_demo_name) - - # Copy all data from demo group - demo_group_in = data_group[demo_name] - copy_hdf5_group(demo_group_in, demo_group_out) - - # Accumulate sample count - if 'num_samples' in demo_group_in.attrs: - total_samples += demo_group_in.attrs['num_samples'] - elif 'obs' in demo_group_in: - # If no num_samples attribute, try to infer from obs - obs_group = demo_group_in['obs'] - # Find any dataset to infer length - for key in obs_group.keys(): - if isinstance(obs_group[key], h5py.Dataset): - total_samples += len(obs_group[key]) - break - - # Update statistics - if 'num_demos' in data_group_out.attrs: - data_group_out.attrs['num_demos'] = num_samples - if 'total' in data_group_out.attrs: - data_group_out.attrs['total'] = total_samples - - print(f' Output file: {output_file}') - print(f' Retained demos: {num_samples}') - print(f' Total samples: {total_samples}') - - return True - - except Exception as e: - print(f'Error processing file {input_file}: {e}') - import traceback - - traceback.print_exc() - return False - - -def main(): - parser = argparse.ArgumentParser( - description='Randomly sample a certain proportion of data from HDF5 files and create new HDF5 files', - ) - parser.add_argument('--input-file', type=str, help='Input HDF5 file path') - parser.add_argument( - '--output-file', - type=str, - default=None, - help='Output HDF5 file path (default: add _sampled suffix to input filename)', - ) - parser.add_argument( - '--ratio', - type=float, - required=True, - help='Sampling ratio (0.0 - 1.0), e.g., 0.5 means sample 50%%', - ) - parser.add_argument( - '--seed', - type=int, - default=None, - help='Random seed for reproducibility', - ) - parser.add_argument( - '--input-dir', - type=str, - default=None, - help='Input directory, batch process all HDF5 files in the directory', - ) - parser.add_argument( - '--output-dir', - type=str, - default=None, - help='Output directory, used together with --input-dir', - ) - parser.add_argument( - '--pattern', - type=str, - default='*.hdf5', - help='Filename pattern (default: *.hdf5)', - ) - parser.add_argument( - '--not-recursive', - action='store_true', - help='Do not 
recursively search subdirectories', - ) - - args = parser.parse_args() - - # Validate sampling ratio - if args.ratio < 0.0 or args.ratio > 1.0: - print('Error: Sampling ratio must be between 0.0 and 1.0') - return - - # Batch processing mode - if args.input_dir: - if not args.output_dir: - print( - 'Error: --output-dir must be specified when using --input-dir' - ) - return - - input_dir = Path(args.input_dir) - output_dir = Path(args.output_dir) - - # Find all HDF5 files - if args.not_recursive: - demo_files = list(input_dir.glob(args.pattern)) - else: - demo_files = list(input_dir.rglob(args.pattern)) - - if not demo_files: - print( - f'No files matching {args.pattern} found in {args.input_dir}' - ) - return - - print(f'Found {len(demo_files)} files to process\n') - - success_count = 0 - for demo_file in demo_files: - # Generate output file path - relative_path = demo_file.relative_to(input_dir) - output_file = output_dir / relative_path - - # If output filename is same as input, add suffix - if output_file == demo_file: - output_file = ( - output_file.parent - / f'{output_file.stem}_sampled{output_file.suffix}' - ) - - output_file.parent.mkdir(parents=True, exist_ok=True) - - if sample_hdf5_file( - str(demo_file), str(output_file), args.ratio, args.seed - ): - success_count += 1 - print() - - print( - f'Processing complete: {success_count}/{len(demo_files)} files succeeded' - ) - - # Single file processing mode - else: - if not args.input_file: - print('Error: Must specify --input-file or --input-dir') - return - - # Determine output file path - if args.output_file: - output_file = args.output_file - else: - input_path = Path(args.input_file) - output_file = str( - input_path.parent - / f'{input_path.stem}_sampled{input_path.suffix}' - ) - - success = sample_hdf5_file( - args.input_file, output_file, args.ratio, args.seed - ) - if success: - print('\nProcessing complete!') - else: - print('\nProcessing failed!') - - -if __name__ == '__main__': - main() diff --git a/scripts/replace_prismatic_imports.py b/scripts/replace_prismatic_imports.py deleted file mode 100644 index 0aecb29a..00000000 --- a/scripts/replace_prismatic_imports.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 The VLA-Arena Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility to rewrite `prismatic.*` imports under an OpenVLA-aware namespace.""" - -from __future__ import annotations - -import argparse -import pathlib -import textwrap -from collections.abc import Iterable - - -OLD_PREFIX = 'prismatic.' -NEW_PREFIX = 'vla_arena.models.univla.prismatic.' 
- - -def find_files(base_dir: pathlib.Path) -> Iterable[pathlib.Path]: - for path in base_dir.rglob('*'): - if path.is_file(): - yield path - - -def rewrite_file(path: pathlib.Path, dry_run: bool) -> bool: - try: - data = path.read_text(encoding='utf-8') - except UnicodeDecodeError: - return False - - updated = data.replace(OLD_PREFIX, NEW_PREFIX) - if updated == data: - return False - - if dry_run: - print(f'[dry-run] would rewrite {path}') - return True - - path.write_text(updated, encoding='utf-8') - print(f'rewrote {path}') - return True - - -def main() -> None: - parser = argparse.ArgumentParser( - description=textwrap.dedent( - """ - Walks a directory tree and rewrites occurrences of `prismatic.` to - `vla_arena.models.openvla.prismatic.` so import statements stay correct. - """, - ), - ) - parser.add_argument('path', type=pathlib.Path, help='Folder to process') - parser.add_argument( - '--dry-run', - action='store_true', - help='Only print files that would be changed', - ) - args = parser.parse_args() - - processed = 0 - for file_path in find_files(args.path): - if rewrite_file(file_path, dry_run=args.dry_run): - processed += 1 - - print( - ( - f'{processed} files updated' - if not args.dry_run - else f'{processed} files would be updated' - ), - ) - - -if __name__ == '__main__': - main() diff --git a/vla_arena/models/openpi/src/openpi/training/config.py b/vla_arena/models/openpi/src/openpi/training/config.py index 2b7eda0d..e6267a28 100644 --- a/vla_arena/models/openpi/src/openpi/training/config.py +++ b/vla_arena/models/openpi/src/openpi/training/config.py @@ -923,6 +923,61 @@ def __post_init__(self) -> None: pytorch_weight_path='/path/to/your/pytorch_weight_path', num_train_steps=30_000, ), + TrainConfig( + name='pi05_vla_arena', + model=pi0_config.Pi0Config( + pi05=True, action_horizon=10, discrete_state_input=False + ), + data=LeRobotLiberoDataConfig( + repo_id='VLA_Arena_L0_L_lerobot_openpi/VLA_Arena', + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=256, + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=0.999, + weight_loader=weight_loaders.CheckpointWeightLoader( + 'gs://openpi-assets/checkpoints/pi05_base/params' + ), + pytorch_weight_path='/path/to/your/pytorch_weight_path', + num_train_steps=30_000, + ), + TrainConfig( + name='pi05_vla_arena_low_mem_finetune', + model=pi0_config.Pi0Config( + pi05=True, + action_horizon=10, + discrete_state_input=False, + paligemma_variant='gemma_2b_lora', + action_expert_variant='gemma_300m_lora', + ), + data=LeRobotLiberoDataConfig( + repo_id='VLA_Arena_L0_L_lerobot_openpi/VLA_Arena', + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + weight_loader=weight_loaders.CheckpointWeightLoader( + os.getenv( + 'OPENPI_VLA_ARENA_CHECKPOINT_PATH', + 'gs://openpi-assets/checkpoints/pi05_base/params', + ) + ), + num_train_steps=60_000, + freeze_filter=pi0_config.Pi0Config( + pi05=True, + action_horizon=10, + discrete_state_input=False, + paligemma_variant='gemma_2b_lora', + action_expert_variant='gemma_300m_lora', + ).get_freeze_filter(), + ema_decay=None, + ), # # Fine-tuning Aloha configs. #
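Note on the two TrainConfig entries added above: they register the config names `pi05_vla_arena` and `pi05_vla_arena_low_mem_finetune` in this module's config list, so they should be resolvable by name like the existing configs. Below is a minimal sketch for sanity-checking the new entries; it assumes openpi's standard `get_config` lookup helper and uses only the field names visible in the diff. The script name and everything outside the diff is illustrative, not part of this change.

    # check_new_vla_arena_configs.py -- illustrative sketch only, not part of the diff
    from openpi.training import config as _config

    for name in ('pi05_vla_arena', 'pi05_vla_arena_low_mem_finetune'):
        cfg = _config.get_config(name)  # looks the TrainConfig up by name
        print(name, cfg.model.action_horizon, cfg.num_train_steps)

    # The low-memory variant uses LoRA variants of the PaliGemma backbone and
    # the action expert, freezes the corresponding base weights via
    # freeze_filter, and disables EMA (ema_decay=None), which is what reduces
    # its memory footprint relative to pi05_vla_arena.
    #
    # Training would then typically be launched through openpi's train entry
    # point (e.g. scripts/train.py <config_name> --exp-name=<run_name>); the
    # exact invocation is assumed from openpi's documented workflow, so check
    # the repository README before relying on it.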