diff --git a/eval_math_harness.py b/eval_math_harness.py new file mode 100644 index 0000000..1b1a332 --- /dev/null +++ b/eval_math_harness.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +"""Evaluate Qwen/Qwen2.5-1.5B-Instruct on Hendrycks MATH using lm-evaluation-harness. + +Metrics Computed: +----------------- +1. Overall exact_match accuracy (primary) +2. Per-subject accuracy (7 subjects: algebra, counting_and_prob, + geometry, intermediate_algebra, num_theory, prealgebra, precalc) +3. Aggregated by difficulty level (1-5) via post-processing + +Best Practices Used: +-------------------- +1. minerva_math task - Better prompts with \boxed{} extraction (standard for MATH) +2. 4-shot prompting - Standard for MATH benchmark +3. Chain-of-thought - Enabled via task's native format +4. Greedy decoding - temperature=0 for reproducibility +5. max_gen_toks=1024 - Sufficient for reasoning chains +6. vLLM backend - High-throughput inference with tensor parallelism +7. BF16 precision - Optimal for modern GPUs +8. Tensor parallelism - Utilize all available GPUs + +Usage: +------ + python eval_math_harness.py + python eval_math_harness.py --model Qwen/Qwen2.5-7B-Instruct + python eval_math_harness.py --num-fewshot 0 # zero-shot + python eval_math_harness.py --tensor-parallel-size 8 # use 8 GPUs +""" + +import argparse +import json +import logging +import sys +from datetime import datetime +from pathlib import Path + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + + +def get_gpu_count() -> int: + """Get the number of available CUDA GPUs.""" + try: + import torch + + return torch.cuda.device_count() + except ImportError: + # Fallback to nvidia-smi + import subprocess + + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, + text=True, + ) + return len(result.stdout.strip().split("\n")) if result.returncode == 0 else 1 + + +def get_model_num_attention_heads(model_name: str) -> int: + """Get the number of attention heads for a model.""" + try: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + return getattr(config, "num_attention_heads", 32) + except Exception: + return 32 # Default fallback + + +def get_optimal_tensor_parallel_size(model_name: str, max_gpus: int) -> int: + """Calculate optimal tensor parallel size based on model architecture. + + Tensor parallelism requires num_attention_heads to be divisible by TP size. + Returns the largest valid TP size <= max_gpus. 
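+    Example: a model with 32 attention heads on a host with 6 GPUs gets TP size 4,
+    the largest divisor of 32 that is <= 6.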
+ """ + num_heads = get_model_num_attention_heads(model_name) + + # Find the largest divisor of num_heads that is <= max_gpus + for tp_size in range(min(max_gpus, num_heads), 0, -1): + if num_heads % tp_size == 0: + return tp_size + return 1 + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Evaluate model on Hendrycks MATH using lm-eval-harness with vLLM" + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2.5-1.5B-Instruct", + help="Model name or path", + ) + parser.add_argument( + "--num-fewshot", + type=int, + default=4, + help="Number of few-shot examples (default: 4)", + ) + parser.add_argument( + "--batch-size", + type=str, + default="auto", + help="Batch size (default: auto)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("eval_results"), + help="Output directory for results", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use (default: cuda)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of samples per task (for debugging)", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=None, + help="Tensor parallel size for vLLM (default: auto-detect all GPUs)", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=4096, + help="Maximum model context length for vLLM (default: 4096)", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.9, + help="GPU memory utilization for vLLM (default: 0.9)", + ) + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf"], + default="vllm", + help="Inference backend: vllm (faster) or hf (HuggingFace)", + ) + return parser.parse_args() + + +def run_evaluation(args: argparse.Namespace) -> dict: + """Run lm-evaluation-harness on MATH dataset using vLLM or HuggingFace backend.""" + try: + from lm_eval import simple_evaluate + except ImportError: + logger.error("lm-eval not installed. Run: pip install lm-eval") + sys.exit(1) + + # Auto-detect tensor parallel size if not specified + available_gpus = get_gpu_count() + if args.tensor_parallel_size is None: + args.tensor_parallel_size = get_optimal_tensor_parallel_size(args.model, available_gpus) + + logger.info(f"Model: {args.model}") + logger.info(f"Backend: {args.backend}") + logger.info(f"Few-shot: {args.num_fewshot}") + logger.info(f"Batch size: {args.batch_size}") + logger.info(f"Available GPUs: {available_gpus}") + logger.info(f"Tensor parallel size: {args.tensor_parallel_size}") + + # MATH subtasks (all 7 subjects) + # Using hendrycks_math tasks with proper \boxed{} answer extraction + tasks = [ + "hendrycks_math_algebra", + "hendrycks_math_counting_and_prob", + "hendrycks_math_geometry", + "hendrycks_math_intermediate_algebra", + "hendrycks_math_num_theory", + "hendrycks_math_prealgebra", + "hendrycks_math_precalc", + ] + + logger.info(f"Tasks: {tasks}") + + if args.backend == "vllm": + # Use vLLM backend for maximum efficiency with tensor parallelism + try: + from lm_eval.models.vllm_causallms import VLLM + except ImportError: + logger.error("vLLM not installed. 
Run: pip install vllm") + sys.exit(1) + + logger.info(f"GPU memory utilization: {args.gpu_memory_utilization}") + logger.info(f"Max model length: {args.max_model_len}") + + # vLLM model configuration for maximum throughput + model = VLLM( + pretrained=args.model, + tensor_parallel_size=args.tensor_parallel_size, + dtype="bfloat16", + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + trust_remote_code=True, + # Enable prefix caching for faster few-shot evaluation + enable_prefix_caching=True, + ) + + # Run evaluation with vLLM + logger.info("Starting evaluation with vLLM backend...") + results = simple_evaluate( + model=model, + tasks=tasks, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + limit=args.limit, + # Greedy decoding for reproducibility + gen_kwargs="temperature=0,do_sample=False", + log_samples=True, + ) + else: + # Fall back to HuggingFace backend + from lm_eval.models.huggingface import HFLM + + model_kwargs = { + "pretrained": args.model, + "dtype": "bfloat16", + "device_map": "auto", + "trust_remote_code": True, + "attn_implementation": "flash_attention_2", + } + + try: + model = HFLM(**model_kwargs) + except Exception as e: + logger.warning(f"Flash attention failed ({e}), using default attention") + model_kwargs.pop("attn_implementation", None) + model = HFLM(**model_kwargs) + + logger.info("Starting evaluation with HuggingFace backend...") + results = simple_evaluate( + model=model, + tasks=tasks, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + device=args.device, + limit=args.limit, + gen_kwargs="temperature=0,do_sample=False", + log_samples=True, + ) + + return results + + +def print_results(results: dict, args: argparse.Namespace) -> None: + """Print formatted results.""" + print("\n" + "=" * 70) + print(f"MATH Evaluation Results - {args.model}") + print("=" * 70) + + # Extract per-subject results + subject_results = {} + + for task_name, task_results in results.get("results", {}).items(): + # Extract subject from task name and expand abbreviations + subject = task_name.replace("hendrycks_math_", "") + # Expand abbreviated names for readability + subject = subject.replace("counting_and_prob", "counting_and_probability") + subject = subject.replace("num_theory", "number_theory") + subject = subject.replace("precalc", "precalculus") + + # Get accuracy metric (exact_match or acc) + acc = task_results.get("exact_match,none", task_results.get("acc,none", 0)) + stderr = task_results.get("exact_match_stderr,none", task_results.get("acc_stderr,none", 0)) + + subject_results[subject] = { + "accuracy": acc, + "stderr": stderr, + } + + # Print per-subject results + print("\nPer-Subject Accuracy:") + print("-" * 50) + for subject, data in sorted(subject_results.items()): + acc_pct = data["accuracy"] * 100 + stderr_pct = data["stderr"] * 100 + print(f" {subject:35s} {acc_pct:5.2f}% ± {stderr_pct:.2f}%") + + # Print aggregate (use weighted average based on number of samples per subject) + accs = [d["accuracy"] for d in subject_results.values()] + overall_acc = sum(accs) / len(accs) * 100 if accs else 0 + + print("-" * 50) + print(f" {'OVERALL':35s} {overall_acc:5.2f}%") + print("=" * 70) + + +def save_results(results: dict, args: argparse.Namespace) -> Path: + """Save results to JSON file.""" + args.output_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_name = args.model.replace("/", "_") + output_file = args.output_dir / f"math_{model_name}_{timestamp}.json" 
+ + # Add metadata + results["metadata"] = { + "model": args.model, + "num_fewshot": args.num_fewshot, + "batch_size": args.batch_size, + "timestamp": timestamp, + } + + with open(output_file, "w") as f: + json.dump(results, f, indent=2, default=str) + + logger.info(f"Results saved to: {output_file}") + return output_file + + +def main() -> None: + """Main entry point.""" + args = parse_args() + + logger.info("=" * 50) + logger.info("Hendrycks MATH Evaluation") + logger.info("=" * 50) + + # Run evaluation + results = run_evaluation(args) + + # Print results + print_results(results, args) + + # Save results + save_results(results, args) + + logger.info("Evaluation complete!") + + +if __name__ == "__main__": + main() diff --git a/eval_pass_at_k.py b/eval_pass_at_k.py new file mode 100644 index 0000000..f898c64 --- /dev/null +++ b/eval_pass_at_k.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""Compute pass@1, pass@5, pass@10 on MATH using vLLM with multiple samples.""" + +import argparse +import json +import re +from datetime import datetime +from pathlib import Path + +import numpy as np +from datasets import load_dataset +from tqdm import tqdm + + +def pass_at_k(n: int, c: int, k: int) -> float: + """Unbiased pass@k estimator from Codex paper. + + Args: + n: total number of samples + c: number of correct samples + k: k in pass@k + """ + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + +def extract_boxed_answer(text: str) -> str | None: + """Extract answer from \\boxed{...} format.""" + # Find the last \boxed{...} + matches = re.findall(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", text) + if matches: + return matches[-1].strip() + return None + + +def normalize_answer(answer: str) -> str: + """Normalize answer for comparison.""" + if answer is None: + return "" + # Remove whitespace and common LaTeX formatting + answer = answer.strip() + answer = answer.replace(" ", "") + answer = answer.replace("\\,", "") + answer = answer.replace("\\!", "") + return answer + + +def is_correct(pred: str, target: str) -> bool: + """Check if prediction matches target.""" + pred_norm = normalize_answer(extract_boxed_answer(pred) or pred) + target_norm = normalize_answer(extract_boxed_answer(target) or target) + return pred_norm == target_norm + + +def build_prompt(problem: str, few_shot_examples: list[dict] = None) -> str: + """Build prompt for MATH problem (zero-shot or few-shot).""" + prompt = "" + if few_shot_examples: + for ex in few_shot_examples: + prompt += f"Problem: {ex['problem']}\nSolution: {ex['solution']}\n\n" + prompt += f"Problem: {problem}\nSolution:" + return prompt + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct") + parser.add_argument("--n-samples", type=int, default=5, help="Samples per problem (>= max k)") + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--max-tokens", type=int, default=2048) + parser.add_argument("--num-fewshot", type=int, default=0) + parser.add_argument("--limit", type=int, default=None, help="Limit problems (for testing)") + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--output-dir", type=Path, default=Path("eval_results")) + args = parser.parse_args() + + # Import vLLM + from vllm import LLM, SamplingParams + + print(f"Loading model: {args.model}") + llm = LLM( + model=args.model, + tensor_parallel_size=args.tensor_parallel_size, + dtype="bfloat16", + 
trust_remote_code=True, + gpu_memory_utilization=0.95, + max_model_len=4096, + max_num_seqs=512, # More concurrent sequences + enable_prefix_caching=True, + ) + + # Load MATH dataset (all 7 subjects) + print("Loading MATH dataset...") + subjects = [ + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", + ] + + # Load all subjects + test_data = [] + for subject in subjects: + test_data.extend(load_dataset("EleutherAI/hendrycks_math", subject, split="test")) + + dataset = test_data[: args.limit] if args.limit else test_data + print(f"Loaded {len(dataset)} test problems ({args.num_fewshot}-shot)") + + # Load few-shot examples if needed + few_shot_examples = [] + if args.num_fewshot > 0: + train_data = load_dataset("EleutherAI/hendrycks_math", "algebra", split="train") + few_shot_examples = [train_data[i] for i in range(args.num_fewshot)] + + # Build prompts + prompts = [build_prompt(ex["problem"], few_shot_examples) for ex in dataset] + targets = [ex["solution"] for ex in dataset] + + # Sampling params for multiple samples + sampling_params = SamplingParams( + temperature=args.temperature, + max_tokens=args.max_tokens, + n=args.n_samples, # Generate n samples per prompt + ) + + print(f"Generating {args.n_samples} samples per problem for {len(prompts)} problems...") + outputs = llm.generate(prompts, sampling_params) + + # Evaluate + results = [] + for _idx, (output, target) in enumerate( + tqdm(zip(outputs, targets, strict=False), total=len(outputs)) + ): + # Check each sample + correct_count = sum( + 1 for completion in output.outputs if is_correct(completion.text, target) + ) + results.append( + { + "n_samples": args.n_samples, + "n_correct": correct_count, + } + ) + + # Compute pass@k for k=1,5 + k_values = [1, 5] + pass_at_k_results = {} + + for k in k_values: + if k <= args.n_samples: + scores = [pass_at_k(r["n_samples"], r["n_correct"], k) for r in results] + pass_at_k_results[f"pass@{k}"] = np.mean(scores) * 100 + + # Print results + print("\n" + "=" * 50) + print(f"MATH pass@k Results - {args.model}") + print(f"Temperature: {args.temperature}, Samples: {args.n_samples}") + print("=" * 50) + for k, score in pass_at_k_results.items(): + print(f" {k}: {score:.2f}%") + print("=" * 50) + + # Save results + args.output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = args.output_dir / f"pass_at_k_{timestamp}.json" + + with open(output_file, "w") as f: + json.dump( + { + "model": args.model, + "temperature": args.temperature, + "n_samples": args.n_samples, + "num_problems": len(results), + "results": pass_at_k_results, + "per_problem": results, + }, + f, + indent=2, + ) + + print(f"Results saved to: {output_file}") + + +if __name__ == "__main__": + main() diff --git a/grail/cli/__init__.py b/grail/cli/__init__.py index fe90648..0c91f5d 100644 --- a/grail/cli/__init__.py +++ b/grail/cli/__init__.py @@ -329,6 +329,7 @@ def _register_subcommands() -> None: "grail.cli.mine", "grail.cli.validate", "grail.cli.train", + "grail.cli.parallel_miner", ): module = importlib.import_module(mod_name) register: Callable[[typer.Typer], None] | None = getattr(module, "register", None) diff --git a/grail/cli/mine.py b/grail/cli/mine.py index c71e265..ddb9a01 100644 --- a/grail/cli/mine.py +++ b/grail/cli/mine.py @@ -431,7 +431,10 @@ def package_rollout_data( Returns: Signed dictionary ready to upload for validation """ - rollout_nonce = base_nonce * 10 + rollout_idx + # 
CRITICAL: Use ROLLOUTS_PER_PROBLEM (16) as multiplier to avoid nonce collisions + # Old formula (base_nonce * 10) caused duplicates when rollout_idx >= 10 + # e.g., problem 14 rollout 10 = 150, problem 15 rollout 0 = 150 (collision!) + rollout_nonce = base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx # Sign commit binding (tokens, randomness, model, layer, commitments) from ..protocol.signatures import sign_commit_binding @@ -539,6 +542,9 @@ async def generate_rollouts_for_window( monitor: Any | None, use_drand: bool, checkpoint_window: int, + *, + problem_offset: int = 0, + max_problems: int = 0, ) -> list[dict]: """Generate as many GRPO rollouts as safely possible within a window. @@ -559,11 +565,31 @@ async def generate_rollouts_for_window( timers: EMA-based timing estimates for safety. monitor: Optional monitoring client for metrics. use_drand: Whether drand was used in randomness generation. - checkpoint_window: The checkpoint window used for this generation + checkpoint_window: The checkpoint window used for this generation. + problem_offset: Starting problem index for this worker (default: 0). + Used in parallel mining to assign non-overlapping problem ranges. + max_problems: Maximum number of problems to generate (default: 0 = unlimited). + When 0, generates until time runs out. Used in parallel mining. Returns: List of signed rollout data ready for upload. """ + # Read problem offset/max from environment (worker mode support) + # Environment variables take precedence over function args for subprocess isolation + env_problem_offset = int(os.getenv("GRAIL_PROBLEM_OFFSET", str(problem_offset))) + env_max_problems = int(os.getenv("GRAIL_MAX_PROBLEMS", str(max_problems))) + + # Use env values if set, otherwise use function args + effective_offset = env_problem_offset + effective_max = env_max_problems + + if effective_offset > 0 or effective_max > 0: + logger.info( + "Worker mode: problem_offset=%d, max_problems=%s", + effective_offset, + effective_max if effective_max > 0 else "unlimited", + ) + # Window generation state and metrics inferences: list[dict] = [] start_time = time.time() @@ -598,6 +624,14 @@ async def generate_rollouts_for_window( logger.info("Window %s has ended, moving to next window", window_start) break + # Check max_problems limit (for worker mode) + if effective_max > 0 and problem_count >= effective_max: + logger.info( + "Stopping generation: reached max_problems limit (%d)", + effective_max, + ) + break + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block needed_blocks = timers.blocks_needed_for_next_gen() if blocks_remaining <= needed_blocks: @@ -619,9 +653,13 @@ async def generate_rollouts_for_window( problem_count += 1 inference_count += 1 + # Apply problem offset for parallel mining coordination + # Each GPU worker gets a unique range: GPU0=[0-11], GPU1=[12-23], etc. 
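+        # Example: a worker launched with effective_offset=12 maps its first problem
+        # (problem_count=1) to problem_index=12, its second to 13, and so on.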
+ problem_index = effective_offset + (problem_count - 1) + logger.info( "⚡ Generating GRPO rollouts for problem %s (block %s/%s)...", - problem_count, + problem_index, current_block, window_start + WINDOW_LENGTH - 1, ) @@ -635,7 +673,6 @@ async def generate_rollouts_for_window( ) # Deterministically derive environment seed from miner+window+index - problem_index = max(0, problem_count - 1) seed_int = derive_env_seed(wallet.hotkey.ss58_address, window_block_hash, problem_index) # Use deterministic problem index as rollout_group identifier base_nonce = problem_index diff --git a/grail/cli/multi_miner_aggregator.py b/grail/cli/multi_miner_aggregator.py new file mode 100644 index 0000000..c998af3 --- /dev/null +++ b/grail/cli/multi_miner_aggregator.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +""" +Multi-Miner Aggregator for GRAIL + +Coordinates multiple miners running on the same machine and aggregates +their results into a single window upload to R2. + +Usage: + python -m grail.cli.multi_miner_aggregator \ + --hotkeys miner_1 miner_2 miner_3 miner_4 \ + --aggregation-hotkey aggregator_hotkey \ + --mode watch # or 'batch' +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import bittensor as bt +import typer + +from ..infrastructure.comms import ( + upload_file_chunked, +) +from ..infrastructure.credentials import load_r2_credentials +from ..shared.constants import WINDOW_LENGTH +from . import console + +logger = logging.getLogger("grail.aggregator") + + +# --------------------------------------------------------------------------- # +# Configuration & State # +# --------------------------------------------------------------------------- # + + +@dataclass +class AggregatorConfig: + """Configuration for multi-miner aggregation.""" + + hotkeys: list[str] + aggregation_hotkey: str + cold_wallet: str = "default" + results_dir: Path = Path("/tmp/grail_miner_results") + poll_interval: float = 5.0 # seconds between polls + window_timeout: float = 300.0 # seconds to wait for all miners per window + credentials: Any | None = None + + +class WindowAggregator: + """Aggregates results from multiple miners for a specific window.""" + + def __init__(self, window_start: int, config: AggregatorConfig): + self.window_start = window_start + self.config = config + self.results: dict[str, list[dict]] = {hotkey: [] for hotkey in config.hotkeys} + self.collected_hotkeys: set[str] = set() + self.start_time = time.time() + + async def collect_results(self) -> dict[str, list[dict]]: + """Poll for results from all miners until timeout or all collected.""" + logger.info( + f"🔍 Collecting results for window {self.window_start} " + f"from {len(self.config.hotkeys)} miners..." 
+ ) + + while time.time() - self.start_time < self.config.window_timeout: + # Check each miner's result directory + for hotkey in self.config.hotkeys: + if hotkey in self.collected_hotkeys: + continue # Already collected + + result_file = self._get_result_path(hotkey) + if result_file.exists(): + try: + inferences = await self._load_and_parse(result_file, hotkey) + self.results[hotkey] = inferences + self.collected_hotkeys.add(hotkey) + logger.info(f" ✓ {hotkey}: {len(inferences)} inferences collected") + except Exception as e: + logger.warning(f" ✗ {hotkey}: Failed to load results - {e}") + + # Check if we have all results + if len(self.collected_hotkeys) == len(self.config.hotkeys): + logger.info( + f"✅ All {len(self.config.hotkeys)} miners reported for window " + f"{self.window_start}" + ) + break + + # Log progress + elapsed = time.time() - self.start_time + remaining = self.config.window_timeout - elapsed + pending = len(self.config.hotkeys) - len(self.collected_hotkeys) + if pending > 0: + logger.debug(f" ⏳ Waiting for {pending} miners ({remaining:.0f}s remaining)...") + + await asyncio.sleep(self.config.poll_interval) + + # Log final status + if len(self.collected_hotkeys) < len(self.config.hotkeys): + missing = set(self.config.hotkeys) - self.collected_hotkeys + logger.warning( + f"⚠️ Timeout: Missing results from {missing}. " + f"Uploading partial results ({len(self.collected_hotkeys)}/{len(self.config.hotkeys)})" + ) + + return self.results + + async def aggregate_and_upload(self, wallet: bt.wallet) -> bool: + """Aggregate all collected results and upload to R2.""" + # Flatten all inferences + all_inferences: list[dict] = [] + for inferences in self.results.values(): + all_inferences.extend(inferences) + + if not all_inferences: + logger.warning(f"No inferences to upload for window {self.window_start}") + return False + + # Create window data with aggregation metadata + window_data = { + "wallet": wallet.hotkey.ss58_address, + "window_start": self.window_start, + "window_length": WINDOW_LENGTH, + "inference_count": len(all_inferences), + "inferences": all_inferences, + "timestamp": time.time(), + "aggregated": True, + "miner_count": len(self.collected_hotkeys), + "miner_hotkeys": list(self.collected_hotkeys), + "collection_time_seconds": time.time() - self.start_time, + } + + # Upload to R2 + key = ( + f"grail/windows/aggregated/{wallet.hotkey.ss58_address}-window-{self.window_start}.json" + ) + body = json.dumps(window_data).encode() + + logger.info( + f"📤 Uploading aggregated window {self.window_start} " + f"({len(all_inferences)} inferences from {len(self.collected_hotkeys)} miners)..." 
+ ) + + success = await upload_file_chunked( + key, + body, + credentials=self.config.credentials, + use_write=True, + ) + + if success: + logger.info(f"✅ Successfully uploaded aggregated window {self.window_start} to R2") + # Clean up local result files + await self._cleanup_results() + else: + logger.error(f"❌ Failed to upload aggregated window {self.window_start}") + + return success + + def _get_result_path(self, hotkey: str) -> Path: + """Get path where miner should write results.""" + return self.config.results_dir / f"{hotkey}-window-{self.window_start}.json" + + async def _load_and_parse(self, result_file: Path, hotkey: str) -> list[dict]: + """Load and parse inferences from result file.""" + try: + with open(result_file) as f: + data = json.load(f) + inferences = data.get("inferences", []) + if not isinstance(inferences, list): + raise ValueError(f"Expected list of inferences, got {type(inferences)}") + return inferences + except Exception as e: + logger.debug(f"Failed to parse {result_file}: {e}") + raise + + async def _cleanup_results(self) -> None: + """Remove processed result files.""" + for hotkey in self.collected_hotkeys: + result_file = self._get_result_path(hotkey) + try: + if result_file.exists(): + result_file.unlink() + logger.debug(f"Cleaned up {result_file}") + except Exception as e: + logger.warning(f"Failed to cleanup {result_file}: {e}") + + +class MultiMinerAggregatorService: + """Main service for coordinating multi-miner aggregation.""" + + def __init__(self, config: AggregatorConfig): + self.config = config + self.config.results_dir.mkdir(parents=True, exist_ok=True) + self.stop_event = asyncio.Event() + + async def watch_and_aggregate(self) -> None: + """Watch for window completions and aggregate results.""" + logger.info(f"🚀 Starting multi-miner aggregator for {len(self.config.hotkeys)} miners") + logger.info(f" Miners: {', '.join(self.config.hotkeys)}") + logger.info(f" Results directory: {self.config.results_dir}") + logger.info(f" Poll interval: {self.config.poll_interval}s") + logger.info(f" Window timeout: {self.config.window_timeout}s") + + wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey) + last_window = -1 + + try: + while not self.stop_event.is_set(): + # Get current window + subtensor = bt.subtensor() + current_block = await asyncio.to_thread(subtensor.get_current_block) + current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH + + # New window detected + if current_window > last_window: + logger.info(f"📍 New window detected: {current_window} (block {current_block})") + last_window = current_window + + # Process previous window if we have results + if current_window > WINDOW_LENGTH: + prev_window = current_window - WINDOW_LENGTH + await self._process_window(wallet, prev_window) + + await asyncio.sleep(self.config.poll_interval) + + except KeyboardInterrupt: + logger.info("Stopping aggregator...") + except Exception as e: + logger.error(f"Error in aggregator: {e}", exc_info=True) + raise + + async def _process_window(self, wallet: bt.wallet, window_start: int) -> None: + """Process and upload a specific window.""" + aggregator = WindowAggregator(window_start, self.config) + results = await aggregator.collect_results() + + # Check if we have any results + total_inferences = sum(len(inf) for inf in results.values()) + if total_inferences == 0: + logger.info(f"⊘ No results for window {window_start}, skipping") + return + + # Upload aggregated results + await aggregator.aggregate_and_upload(wallet) + + 
async def batch_process_window(self, window_start: int) -> bool: + """Process a single window in batch mode.""" + wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey) + await self._process_window(wallet, window_start) + return True + + +# --------------------------------------------------------------------------- # +# CLI Interface # +# --------------------------------------------------------------------------- # + + +def register(app: typer.Typer) -> None: + """Register aggregator command with CLI.""" + app.command("aggregate")(aggregate) + + +def aggregate( + hotkeys: list[str] = typer.Option( + ..., + "--hotkey", + help="Miner hotkeys to aggregate (can specify multiple times)", + ), + aggregation_hotkey: str = typer.Option( + ..., + "--aggregation-hotkey", + help="Hotkey to use for uploading aggregated results", + ), + cold_wallet: str = typer.Option( + "default", + "--cold-wallet", + help="Cold wallet name", + ), + results_dir: str = typer.Option( + "/tmp/grail_miner_results", + "--results-dir", + help="Directory where miners write results", + ), + poll_interval: float = typer.Option( + 5.0, + "--poll-interval", + help="Seconds between polls for new results", + ), + window_timeout: float = typer.Option( + 300.0, + "--window-timeout", + help="Seconds to wait for all miners per window", + ), + mode: str = typer.Option( + "watch", + "--mode", + help="'watch' for continuous monitoring or 'batch' for single window", + ), + window: int | None = typer.Option( + None, + "--window", + help="Window to process (required for batch mode)", + ), +) -> None: + """Aggregate results from multiple miners and upload to R2. + + Example: + python -m grail.cli.multi_miner_aggregator \ + --hotkey miner_1 --hotkey miner_2 --hotkey miner_3 \ + --aggregation-hotkey aggregator \ + --mode watch + + python -m grail.cli.multi_miner_aggregator \ + --hotkey miner_1 --hotkey miner_2 \ + --aggregation-hotkey aggregator \ + --mode batch --window 12345 + """ + try: + # Validate inputs + if not hotkeys: + console.print("[red]Error: At least one --hotkey must be specified[/red]") + raise typer.Exit(code=1) + + if mode not in ("watch", "batch"): + console.print(f"[red]Error: mode must be 'watch' or 'batch', got {mode}[/red]") + raise typer.Exit(code=1) + + if mode == "batch" and window is None: + console.print("[red]Error: --window required for batch mode[/red]") + raise typer.Exit(code=1) + + # Load credentials + try: + credentials = load_r2_credentials() + except Exception as e: + console.print(f"[red]Failed to load R2 credentials: {e}[/red]") + raise typer.Exit(code=1) from None + + # Create config + config = AggregatorConfig( + hotkeys=hotkeys, + aggregation_hotkey=aggregation_hotkey, + cold_wallet=cold_wallet, + results_dir=Path(results_dir), + poll_interval=poll_interval, + window_timeout=window_timeout, + credentials=credentials, + ) + + # Run aggregator + service = MultiMinerAggregatorService(config) + + if mode == "watch": + console.print("[bold green]Starting multi-miner aggregator in watch mode[/bold green]") + asyncio.run(service.watch_and_aggregate()) + else: # batch + console.print(f"[bold green]Processing window {window}[/bold green]") + asyncio.run(service.batch_process_window(window)) + + except KeyboardInterrupt: + console.print("[yellow]Aggregator stopped by user[/yellow]") + raise typer.Exit(code=0) from None + except Exception as e: + logger.error(f"Fatal error: {e}", exc_info=True) + console.print(f"[red]Fatal error: {e}[/red]") + raise typer.Exit(code=1) from None + 
+ +# --------------------------------------------------------------------------- # +# Main Entry Point # +# --------------------------------------------------------------------------- # + + +def main() -> None: + """Main entry point for aggregator CLI.""" + + app = typer.Typer() + register(app) + app() + + +if __name__ == "__main__": + main() diff --git a/grail/cli/multi_miner_config.py b/grail/cli/multi_miner_config.py new file mode 100644 index 0000000..0ead808 --- /dev/null +++ b/grail/cli/multi_miner_config.py @@ -0,0 +1,264 @@ +""" +Multi-Miner Configuration and Helper Utilities + +Provides common configurations and helper functions for running multiple +miners on the same machine with window-based result aggregation. +""" + +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class MinerConfig: + """Configuration for a single miner instance.""" + + hotkey: str + gpu_index: int | None = None + batch_size: int = 2 + safety_blocks: int = 3 + use_drand: bool = True + extra_env: dict[str, str] | None = None + + +@dataclass +class MultiMinerSetup: + """Complete setup for multiple miners.""" + + miners: list[MinerConfig] + cold_wallet: str = "default" + use_aggregator: bool = True + aggregator_hotkey: str | None = None + results_directory: Path = Path("/tmp/grail_miner_results") + poll_interval_seconds: float = 5.0 + window_timeout_seconds: float = 300.0 + + def __post_init__(self) -> None: + if not self.miners: + raise ValueError("At least one miner config is required") + if self.use_aggregator and not self.aggregator_hotkey: + import time + + self.aggregator_hotkey = f"aggregator_{int(time.time())}" + + +class MultiMinerBuilder: + """Builder for creating multi-miner configurations.""" + + @staticmethod + def from_hotkeys( + hotkeys: list[str], + gpus: list[int] | None = None, + batch_size: int = 2, + use_aggregator: bool = True, + ) -> MultiMinerSetup: + """Create multi-miner setup from list of hotkeys and optional GPU assignments. + + Args: + hotkeys: List of miner hotkeys + gpus: Optional list of GPU indices (cycles if fewer than hotkeys) + batch_size: Generation batch size per miner + use_aggregator: Whether to enable result aggregation + + Returns: + MultiMinerSetup ready to launch + """ + if not hotkeys: + raise ValueError("At least one hotkey required") + + # Build miner configs + miners = [] + for i, hotkey in enumerate(hotkeys): + gpu = gpus[i % len(gpus)] if gpus else None + miners.append( + MinerConfig( + hotkey=hotkey, + gpu_index=gpu, + batch_size=batch_size, + ) + ) + + return MultiMinerSetup( + miners=miners, + use_aggregator=use_aggregator, + ) + + @staticmethod + def from_environment() -> MultiMinerSetup: + """Create multi-miner setup from environment variables. 
+ + Environment variables: + GRAIL_MINERS: Comma-separated hotkey list (e.g., "miner_1,miner_2,miner_3") + GRAIL_GPUS: Comma-separated GPU indices (optional, e.g., "0,1,2") + GRAIL_BATCH_SIZE: Generation batch size (default: 2) + GRAIL_USE_AGGREGATOR: "true" or "false" (default: true) + GRAIL_AGGREGATOR_HOTKEY: Aggregator identity (auto-generated if not set) + GRAIL_RESULTS_DIR: Results directory (default: /tmp/grail_miner_results) + + Returns: + MultiMinerSetup from environment configuration + """ + # Parse miners + miners_str = os.getenv("GRAIL_MINERS", "miner_1") + hotkeys = [h.strip() for h in miners_str.split(",") if h.strip()] + + if not hotkeys: + raise ValueError("GRAIL_MINERS environment variable is empty") + + # Parse GPUs (optional) + gpus_str = os.getenv("GRAIL_GPUS", "") + gpus = None + if gpus_str: + gpus = [int(g.strip()) for g in gpus_str.split(",") if g.strip()] + + # Other settings + batch_size = int(os.getenv("GRAIL_BATCH_SIZE", "2")) + use_aggregator = os.getenv("GRAIL_USE_AGGREGATOR", "true").lower() in ( + "true", + "1", + "yes", + ) + aggregator_hotkey = os.getenv("GRAIL_AGGREGATOR_HOTKEY", None) + results_dir = Path(os.getenv("GRAIL_RESULTS_DIR", "/tmp/grail_miner_results")) + + setup = MultiMinerBuilder.from_hotkeys( + hotkeys=hotkeys, + gpus=gpus, + batch_size=batch_size, + use_aggregator=use_aggregator, + ) + + if aggregator_hotkey: + setup.aggregator_hotkey = aggregator_hotkey + + setup.results_directory = results_dir + + return setup + + +class MinerLauncher: + """Helper for launching miner processes with proper environment.""" + + @staticmethod + def get_env_for_miner(config: MinerConfig, cold_wallet: str = "default") -> dict[str, str]: + """Get environment variables for a miner process. + + Args: + config: MinerConfig for this miner + cold_wallet: Cold wallet name + + Returns: + Dictionary of environment variables to set + """ + env = os.environ.copy() + + # Set wallet + env["BT_WALLET_COLD"] = cold_wallet + env["BT_WALLET_HOT"] = config.hotkey + + # Set GPU if specified + if config.gpu_index is not None: + env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_index) + else: + # Remove GPU constraint if not specified + env.pop("CUDA_VISIBLE_DEVICES", None) + + # Set generation parameters + env["GRAIL_GENERATION_BATCH_SIZE"] = str(config.batch_size) + env["GRAIL_MINER_SAFETY_BLOCKS"] = str(config.safety_blocks) + + # Add any extra environment variables + if config.extra_env: + env.update(config.extra_env) + + return env + + @staticmethod + def get_command_for_miner(config: MinerConfig) -> list[str]: + """Get command to launch a miner. + + Args: + config: MinerConfig for this miner + + Returns: + Command as list of strings (suitable for subprocess) + """ + return [ + "python", + "-m", + "grail.cli.mine", + "--use-drand" if config.use_drand else "--no-drand", + ] + + +class AggregatorLauncher: + """Helper for launching aggregator with proper arguments.""" + + @staticmethod + def get_command_for_aggregator(setup: MultiMinerSetup, mode: str = "watch") -> list[str]: + """Get command to launch aggregator. 
+ + Args: + setup: MultiMinerSetup configuration + mode: "watch" or "batch" + + Returns: + Command as list of strings + """ + if not setup.aggregator_hotkey: + raise ValueError("aggregator_hotkey not set") + + cmd = [ + "python", + "-m", + "grail.cli.multi_miner_aggregator", + ] + + # Add miner hotkeys + for miner in setup.miners: + cmd.extend(["--hotkey", miner.hotkey]) + + # Add aggregator settings + cmd.extend( + [ + "--aggregation-hotkey", + setup.aggregator_hotkey, + "--cold-wallet", + setup.cold_wallet, + "--results-dir", + str(setup.results_directory), + "--poll-interval", + str(setup.poll_interval_seconds), + "--window-timeout", + str(setup.window_timeout_seconds), + "--mode", + mode, + ] + ) + + return cmd + + +def print_setup_summary(setup: MultiMinerSetup) -> None: + """Pretty-print the multi-miner setup configuration.""" + print("\n" + "=" * 60) + print("Multi-Miner Setup Configuration") + print("=" * 60) + + print(f"\n📊 Miners: {len(setup.miners)}") + for i, miner in enumerate(setup.miners, 1): + gpu_info = f"GPU {miner.gpu_index}" if miner.gpu_index is not None else "Any GPU" + print(f" {i}. {miner.hotkey:20s} [{gpu_info}] batch_size={miner.batch_size}") + + print(f"\n💼 Wallet: {setup.cold_wallet}") + + if setup.use_aggregator: + print(f"\n🔄 Aggregator: {setup.aggregator_hotkey}") + print(f" Poll interval: {setup.poll_interval_seconds}s") + print(f" Window timeout: {setup.window_timeout_seconds}s") + else: + print("\n🔄 Aggregator: Disabled") + + print(f"\n📁 Results directory: {setup.results_directory}") + print("\n" + "=" * 60 + "\n") diff --git a/grail/cli/parallel_miner.py b/grail/cli/parallel_miner.py new file mode 100644 index 0000000..e424d05 --- /dev/null +++ b/grail/cli/parallel_miner.py @@ -0,0 +1,908 @@ +#!/usr/bin/env python3 +""" +Parallel Multi-GPU Miner for GRAIL + +Coordinates multiple GPU workers to generate rollouts in parallel, with each GPU +handling a distinct range of problem IDs. All results are gathered before a +single upload to maximize throughput while maintaining submission integrity. + +Architecture: + ┌─────────────────────────────────────────────────────────────┐ + │ Coordinator Process │ + │ - Assigns problem ranges: GPU0=[0-11], GPU1=[12-23], ... │ + │ - Spawns N worker processes │ + │ - Gathers results via temp files │ + │ - Single sink_window_inferences() call │ + └──────────────────────────┬──────────────────────────────────┘ + │ + ┌─────────┬─────────┬─┴─────────┬─────────┐ + ▼ ▼ ▼ ▼ ▼ + ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ + │GPU 0 │ │GPU 1 │ │GPU 2 │ ... │GPU N │ │GPU N │ + │P:0-11│ │P:12-23│ │P:24-35│ │ │ │ │ + └──────┘ └──────┘ └──────┘ └──────┘ └──────┘ + +Usage: + python -m grail.cli.parallel_miner --num-gpus 8 --problems-per-gpu 12 +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import multiprocessing as mp +import os +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path +from queue import Empty +from typing import Any + +import bittensor as bt +import torch +import typer + +from ..infrastructure.credentials import load_r2_credentials +from ..shared.constants import WINDOW_LENGTH +from . 
import console + +logger = logging.getLogger("grail.parallel_miner") + + +# --------------------------------------------------------------------------- # +# Configuration # +# --------------------------------------------------------------------------- # + + +@dataclass +class GPUWorkerConfig: + """Configuration for a single GPU worker process.""" + + gpu_id: int + problem_offset: int + max_problems: int + results_dir: Path + window_start: int + window_block_hash: str + combined_randomness: str + use_drand: bool + checkpoint_path: str | None + # Wallet names read from environment in worker for subprocess isolation + batch_size: int = 16 # Match single miner's default for optimal performance + safety_blocks: int = 3 + + +@dataclass +class ParallelMinerConfig: + """Configuration for parallel multi-GPU mining.""" + + num_gpus: int = 8 + problems_per_gpu: int = 12 + batch_size: int = 16 # Match single miner's default for optimal performance + safety_blocks: int = 3 + use_drand: bool = True + results_dir: Path = field( + default_factory=lambda: Path(tempfile.mkdtemp(prefix="grail_parallel_")) + ) + worker_timeout: float = 600.0 # 10 minutes max per window + gpu_ids: list[int] | None = None # Specific GPU IDs to use, None = [0, 1, ..., num_gpus-1] + + def get_gpu_ids(self) -> list[int]: + """Return list of GPU IDs to use.""" + if self.gpu_ids is not None: + return self.gpu_ids + return list(range(self.num_gpus)) + + +# --------------------------------------------------------------------------- # +# GPU Worker Process # +# --------------------------------------------------------------------------- # + + +def _gpu_worker_main( + config: GPUWorkerConfig, + result_queue: mp.Queue, +) -> None: + """Main function for GPU worker process. + + This runs in a separate process with CUDA_VISIBLE_DEVICES set to the + assigned GPU. It generates rollouts for a specific problem range and + writes results to a temp file. 
+ + Args: + config: Worker configuration with GPU assignment and problem range + result_queue: Queue to signal completion status back to coordinator + """ + worker_id = f"GPU-{config.gpu_id}" + start_time = time.time() + + # Configure logging for worker process + import logging + + logging.basicConfig( + level=logging.INFO, + format=f"%(asctime)s [{worker_id}] %(levelname)s: %(message)s", + datefmt="%H:%M:%S", + ) + worker_logger = logging.getLogger(f"grail.worker.{config.gpu_id}") + + try: + # Set GPU visibility BEFORE any CUDA operations + os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id) + + # Import heavy modules after setting CUDA_VISIBLE_DEVICES + from ..cli.mine import ( + MiningTimers, + package_rollout_data, + ) + from ..environments.factory import create_env + from ..environments.loop import AgentEnvLoop + from ..grail import derive_env_seed + from ..model.provider import get_model, get_tokenizer + from ..shared.constants import ROLLOUTS_PER_PROBLEM + + worker_logger.info( + "Starting worker: problems %d-%d on GPU %d", + config.problem_offset, + config.problem_offset + config.max_problems - 1, + config.gpu_id, + ) + + # Load wallet from environment (same env as coordinator) + coldkey = os.getenv("BT_WALLET_COLD", "default") + hotkey = os.getenv("BT_WALLET_HOT", "default") + wallet = bt.wallet(name=coldkey, hotkey=hotkey) + + # Load model and tokenizer + if config.checkpoint_path: + model = get_model(config.checkpoint_path, device="cuda", eval_mode=True) + tokenizer = get_tokenizer(config.checkpoint_path) + else: + raise RuntimeError("checkpoint_path is required for parallel mining") + + device = model.device + loop = AgentEnvLoop(model, tokenizer, str(device)) + + # Generate rollouts for assigned problem range + inferences: list[dict] = [] + timers = MiningTimers() + + for local_idx in range(config.max_problems): + problem_index = config.problem_offset + local_idx + gen_start = time.time() + + # Derive deterministic seed for this problem + seed_int = derive_env_seed( + wallet.hotkey.ss58_address, + config.window_block_hash, + problem_index, + ) + + worker_logger.debug( + "Generating problem %d (seed=%d)", + problem_index, + seed_int, + ) + + # Generate GRPO rollouts + def _env_factory(): + return create_env() + + grpo_rollouts = loop.run_grpo_group( + _env_factory, + ROLLOUTS_PER_PROBLEM, + config.combined_randomness, + wallet, + batch_size=config.batch_size, + seed=seed_int, + ) + + # Package rollouts with signatures + base_nonce = problem_index + for rollout_idx, rollout in enumerate(grpo_rollouts): + rollout_data = package_rollout_data( + model, + wallet, + rollout, + base_nonce, + rollout_idx, + len(grpo_rollouts), + config.window_start, + config.window_start, # current_block = window_start for parallel + config.window_block_hash, + config.combined_randomness, + config.use_drand, + ) + inferences.append(rollout_data) + + gen_duration = time.time() - gen_start + timers.update_gen_time_ema(gen_duration) + + worker_logger.info( + "Problem %d: %d rollouts in %.2fs", + problem_index, + len(grpo_rollouts), + gen_duration, + ) + + # Write results to temp file + result_file = config.results_dir / f"gpu_{config.gpu_id}_results.json" + result_data = { + "gpu_id": config.gpu_id, + "problem_offset": config.problem_offset, + "max_problems": config.max_problems, + "inference_count": len(inferences), + "inferences": inferences, + "duration_seconds": time.time() - start_time, + } + + with open(result_file, "w") as f: + json.dump(result_data, f) + + worker_logger.info( + "Completed: %d 
rollouts from %d problems in %.2fs", + len(inferences), + config.max_problems, + time.time() - start_time, + ) + + # Signal success + result_queue.put( + { + "gpu_id": config.gpu_id, + "status": "success", + "inference_count": len(inferences), + "result_file": str(result_file), + "duration": time.time() - start_time, + } + ) + + except Exception as e: + worker_logger.exception("Worker failed: %s", e) + result_queue.put( + { + "gpu_id": config.gpu_id, + "status": "error", + "error": str(e), + "duration": time.time() - start_time, + } + ) + + +# --------------------------------------------------------------------------- # +# Parallel Mining Coordinator # +# --------------------------------------------------------------------------- # + + +class ParallelMiningCoordinator: + """Coordinates parallel rollout generation across multiple GPUs. + + Responsibilities: + - Spawn GPU worker processes with non-overlapping problem ranges + - Monitor worker progress and handle failures + - Gather all results and perform single aggregated upload + - Clean up temp files after successful upload + """ + + def __init__( + self, + config: ParallelMinerConfig, + wallet: bt.wallet, + credentials: Any, + ) -> None: + self.config = config + self.wallet = wallet + self.credentials = credentials + self._workers: list[mp.Process] = [] + # Use spawn context for CUDA-safe queue + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass # Already set + self._ctx = mp.get_context("spawn") + self._result_queue: mp.Queue = self._ctx.Queue() + self._shutdown_requested = False + + async def mine_window( + self, + window_start: int, + window_block_hash: str, + combined_randomness: str, + checkpoint_path: str | None, + ) -> list[dict]: + """Generate rollouts for a window using all GPUs in parallel. 
+ + Args: + window_start: Start block of the mining window + window_block_hash: Block hash at window start + combined_randomness: Combined randomness for proof generation + checkpoint_path: Path to model checkpoint + + Returns: + Combined list of all rollout inferences from all GPUs + """ + gpu_ids = self.config.get_gpu_ids() + total_problems = self.config.problems_per_gpu * len(gpu_ids) + + logger.info( + "🚀 Starting parallel mining: %d GPUs × %d problems = %d total problems", + len(gpu_ids), + self.config.problems_per_gpu, + total_problems, + ) + + # Ensure results directory exists + self.config.results_dir.mkdir(parents=True, exist_ok=True) + + # Create worker configs with non-overlapping problem ranges + worker_configs: list[GPUWorkerConfig] = [] + for idx, gpu_id in enumerate(gpu_ids): + problem_offset = idx * self.config.problems_per_gpu + worker_config = GPUWorkerConfig( + gpu_id=gpu_id, + problem_offset=problem_offset, + max_problems=self.config.problems_per_gpu, + results_dir=self.config.results_dir, + window_start=window_start, + window_block_hash=window_block_hash, + combined_randomness=combined_randomness, + use_drand=self.config.use_drand, + checkpoint_path=checkpoint_path, + batch_size=self.config.batch_size, + safety_blocks=self.config.safety_blocks, + ) + worker_configs.append(worker_config) + + # Spawn worker processes using 'spawn' method for CUDA compatibility + # This ensures each worker gets a fresh CUDA context without conflicts + start_time = time.time() + self._workers = [] + + for worker_config in worker_configs: + # Use spawn context to avoid CUDA context issues + proc = self._ctx.Process( + target=_gpu_worker_main, + args=(worker_config, self._result_queue), + daemon=True, + ) + proc.start() + self._workers.append(proc) + logger.info( + " Started worker PID %d for GPU %d (problems %d-%d)", + proc.pid, + worker_config.gpu_id, + worker_config.problem_offset, + worker_config.problem_offset + worker_config.max_problems - 1, + ) + + # Wait for all workers to complete + results = await self._wait_for_workers(len(gpu_ids)) + + # Gather and combine results - ALL workers must succeed + all_inferences, all_succeeded = await self._gather_results(results, len(gpu_ids)) + + elapsed = time.time() - start_time + + if not all_succeeded: + logger.error("❌ Parallel mining FAILED: Not all GPUs completed successfully") + logger.error( + "Returning empty results to prevent partial upload that would fail validation" + ) + return [] # Return empty to prevent upload + + # Verify expected rollout count + from ..shared.constants import ROLLOUTS_PER_PROBLEM + + expected_rollouts = len(gpu_ids) * self.config.problems_per_gpu * ROLLOUTS_PER_PROBLEM + if len(all_inferences) != expected_rollouts: + logger.error( + "❌ Rollout count mismatch: got %d, expected %d (%d GPUs × %d problems × %d rollouts)", + len(all_inferences), + expected_rollouts, + len(gpu_ids), + self.config.problems_per_gpu, + ROLLOUTS_PER_PROBLEM, + ) + logger.error("Returning empty results to prevent validation failure") + return [] + + logger.info( + "✅ Parallel mining complete: %d rollouts in %.2fs (%.1f rollouts/sec)", + len(all_inferences), + elapsed, + len(all_inferences) / elapsed if elapsed > 0 else 0, + ) + + return all_inferences + + async def _wait_for_workers(self, expected_count: int) -> list[dict]: + """Wait for all worker processes to complete. 
+ + Args: + expected_count: Number of workers expected to complete + + Returns: + List of result dictionaries from each worker + """ + results: list[dict] = [] + deadline = time.time() + self.config.worker_timeout + + while len(results) < expected_count and time.time() < deadline: + try: + # Non-blocking check with timeout + result = await asyncio.to_thread( + self._result_queue.get, + timeout=5.0, + ) + results.append(result) + + if result["status"] == "success": + logger.info( + " GPU %d completed: %d rollouts in %.2fs", + result["gpu_id"], + result["inference_count"], + result["duration"], + ) + else: + logger.error( + " GPU %d failed: %s", + result["gpu_id"], + result.get("error", "unknown error"), + ) + + except Empty: + # Check if any workers have crashed + alive_count = sum(1 for w in self._workers if w.is_alive()) + if alive_count == 0 and len(results) < expected_count: + logger.error("All workers have exited but not all reported results") + break + continue + + # Terminate any remaining workers + for worker in self._workers: + if worker.is_alive(): + logger.warning("Terminating hung worker PID %d", worker.pid) + worker.terminate() + worker.join(timeout=5.0) + + return results + + async def _gather_results( + self, worker_results: list[dict], expected_gpu_count: int + ) -> tuple[list[dict], bool]: + """Gather and combine results from all workers. + + CRITICAL: All workers must succeed for upload to proceed. + Missing any problem ID will cause validator proof failure. + + Args: + worker_results: List of worker result status dictionaries + expected_gpu_count: Number of GPUs that must succeed + + Returns: + Tuple of (combined inferences, all_succeeded) + """ + all_inferences: list[dict] = [] + successful_gpus = 0 + failed_gpus: list[int] = [] + + for result in worker_results: + if result["status"] != "success": + failed_gpus.append(result["gpu_id"]) + logger.error( + "GPU %d FAILED: %s - Cannot upload partial results!", + result["gpu_id"], + result.get("error", "unknown error"), + ) + continue + + result_file = Path(result["result_file"]) + if not result_file.exists(): + failed_gpus.append(result["gpu_id"]) + logger.error( + "GPU %d result file missing: %s - Cannot upload partial results!", + result["gpu_id"], + result_file, + ) + continue + + try: + with open(result_file) as f: + data = json.load(f) + inferences = data.get("inferences", []) + all_inferences.extend(inferences) + successful_gpus += 1 + logger.info( + " GPU %d: %d rollouts collected", + result["gpu_id"], + len(inferences), + ) + + # Clean up temp file + result_file.unlink() + + except Exception as e: + failed_gpus.append(result["gpu_id"]) + logger.error("Failed to read results from GPU %d: %s", result["gpu_id"], e) + + # Check if ALL workers succeeded + all_succeeded = (successful_gpus == expected_gpu_count) and len(failed_gpus) == 0 + + if all_succeeded: + # CRITICAL: Sort inferences by rollout_group (problem_index) then rollout_index + # The validator uses file-order to derive seed: first group in file = group_index 0 + # If we don't sort, a GPU that finishes first could put problem 24 before problem 0, + # causing the validator to derive wrong seeds and fail validation! 
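+            # Example ordering after the sort: all rollouts of problem 0 (rollout_index 0..15),
+            # then problem 1, and so on, matching the file order a single-GPU miner would produce.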
+ all_inferences.sort( + key=lambda x: ( + int(x.get("rollout_group", 0)), # Primary: problem index + int(x.get("rollout_index", 0)), # Secondary: rollout within problem + ) + ) + logger.info( + "✅ All %d GPUs succeeded: %d total rollouts ready for upload (sorted by problem ID)", + successful_gpus, + len(all_inferences), + ) + else: + logger.error( + "❌ INCOMPLETE: Only %d/%d GPUs succeeded. Failed GPUs: %s", + successful_gpus, + expected_gpu_count, + failed_gpus, + ) + logger.error( + "Cannot upload partial results - validator would reject due to missing problem IDs!" + ) + + return all_inferences, all_succeeded + + def cleanup(self) -> None: + """Clean up resources and temp files.""" + # Terminate any remaining workers + for worker in self._workers: + if worker.is_alive(): + worker.terminate() + worker.join(timeout=2.0) + + # Clean up results directory + try: + if self.config.results_dir.exists(): + for f in self.config.results_dir.iterdir(): + f.unlink() + self.config.results_dir.rmdir() + except Exception as e: + logger.debug("Cleanup error (non-fatal): %s", e) + + +# --------------------------------------------------------------------------- # +# CLI Interface # +# --------------------------------------------------------------------------- # + + +async def run_parallel_miner( + config: ParallelMinerConfig, + use_drand: bool = True, +) -> None: + """Main entry point for parallel multi-GPU mining. + + Args: + config: Parallel mining configuration + use_drand: Whether to use drand for randomness + """ + from types import SimpleNamespace + + from ..cli.mine import ( + MiningTimers, + calculate_window_start, + get_conf, + get_window_randomness, + upload_inferences_with_metrics, + ) + from ..infrastructure.chain import GrailChainManager + from ..infrastructure.checkpoints import CheckpointManager, default_checkpoint_cache_root + from ..shared.constants import TRAINER_UID + + # Load configuration + coldkey = get_conf("BT_WALLET_COLD", "default") + hotkey = get_conf("BT_WALLET_HOT", "default") + wallet = bt.wallet(name=coldkey, hotkey=hotkey) + + logger.info("🔑 Parallel Miner hotkey: %s", wallet.hotkey.ss58_address) + logger.info(" GPUs: %d, Problems/GPU: %d", config.num_gpus, config.problems_per_gpu) + + # Load credentials + credentials = load_r2_credentials() + logger.info("✅ Loaded R2 credentials") + + # Initialize async subtensor (grail uses async bittensor wrapper) + from ..infrastructure.network import create_subtensor + + subtensor = await create_subtensor() + netuid = int(get_conf("BT_NETUID", get_conf("NETUID", 200))) + + # Get metagraph using async subtensor + metagraph = await subtensor.metagraph(netuid) + + # Initialize chain manager for credential commitments + chain_config = SimpleNamespace(netuid=netuid) + chain_manager = GrailChainManager(chain_config, wallet, metagraph, subtensor, credentials) + await chain_manager.initialize() + logger.info("✅ Initialized chain manager") + + # Get trainer credentials for checkpoints + trainer_bucket = chain_manager.get_bucket(TRAINER_UID) + checkpoint_credentials = trainer_bucket if trainer_bucket else credentials + + checkpoint_manager = CheckpointManager( + cache_root=default_checkpoint_cache_root(), + credentials=checkpoint_credentials, + keep_limit=2, + ) + + # Create coordinator + coordinator = ParallelMiningCoordinator(config, wallet, credentials) + + # Main mining loop + last_window_start = -1 + timers = MiningTimers() + current_checkpoint_window: int | None = None + checkpoint_path: str | None = None + + try: + while True: + 
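+            # Each iteration polls the chain once; mining only starts when a new
+            # window begins and a checkpoint for the previous window is available.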
current_block = await subtensor.get_current_block() + window_start = calculate_window_start(current_block) + checkpoint_window = window_start - WINDOW_LENGTH + + if window_start <= last_window_start: + await asyncio.sleep(5) + continue + + # Load checkpoint if needed + if checkpoint_window >= 0 and current_checkpoint_window != checkpoint_window: + logger.info("🔁 Loading checkpoint for window %s", checkpoint_window) + checkpoint_path_obj = await checkpoint_manager.get_checkpoint(checkpoint_window) + if checkpoint_path_obj: + checkpoint_path = str(checkpoint_path_obj) + current_checkpoint_window = checkpoint_window + else: + logger.error("No checkpoint available for window %s", checkpoint_window) + await asyncio.sleep(30) + continue + + if not checkpoint_path: + logger.error("No checkpoint loaded, cannot mine") + await asyncio.sleep(30) + continue + + # Check time budget BEFORE starting parallel mining + # Parallel mode needs more time since all GPUs must complete before upload + current_block = await subtensor.get_current_block() + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block + + # Estimate time needed: rough estimate based on problems and safety margin + # Each GPU needs ~30-60s per problem with batch_size=16, plus upload time + estimated_time_per_problem = 45 # seconds, conservative estimate + estimated_upload_time = 30 # seconds + total_estimated_seconds = ( + config.problems_per_gpu * estimated_time_per_problem + estimated_upload_time + ) + # Convert to blocks (12 seconds per block) + estimated_blocks_needed = (total_estimated_seconds // 12) + config.safety_blocks + + if blocks_remaining < estimated_blocks_needed: + logger.warning( + "⏰ Skipping window %d: only %d blocks remaining, need ~%d blocks for parallel mining", + window_start, + blocks_remaining, + estimated_blocks_needed, + ) + await asyncio.sleep(10) + continue + + # Get window randomness + window_block_hash, combined_randomness = await get_window_randomness( + subtensor, + window_start, + use_drand, + ) + + logger.info( + "🔥 Starting parallel mining for window %d-%d", + window_start, + window_start + WINDOW_LENGTH - 1, + ) + + # Run parallel mining + inferences = await coordinator.mine_window( + window_start, + window_block_hash, + combined_randomness, + checkpoint_path, + ) + + # Upload aggregated results - but first check we have time! 
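+            # With 12-second blocks, the default safety_blocks=3 leaves roughly 36 seconds of margin for the upload.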
+ if inferences: + # CRITICAL: Check blocks remaining before upload + current_block = await subtensor.get_current_block() + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block + + if blocks_remaining < config.safety_blocks: + logger.error( + "❌ SKIPPING UPLOAD: Only %d blocks remaining (need %d safety blocks)", + blocks_remaining, + config.safety_blocks, + ) + logger.error( + "Window %d will be missed - workers took too long", + window_start, + ) + # Don't upload late - validator would reject anyway + last_window_start = window_start + continue + + logger.info( + "📤 Uploading %d aggregated rollouts for window %d (%d blocks remaining)", + len(inferences), + window_start, + blocks_remaining, + ) + upload_duration = await upload_inferences_with_metrics( + wallet, + window_start, + inferences, + credentials, + None, # monitor + ) + timers.update_upload_time_ema(upload_duration) + logger.info("✅ Successfully uploaded window %d", window_start) + else: + logger.warning("No inferences generated for window %d", window_start) + + last_window_start = window_start + await checkpoint_manager.cleanup_local(window_start) + + except KeyboardInterrupt: + logger.info("Shutting down parallel miner...") + finally: + coordinator.cleanup() + chain_manager.stop() + + +def register(app: typer.Typer) -> None: + """Register parallel-mine command with CLI.""" + app.command("parallel-mine")(parallel_mine) + + +def parallel_mine( + num_gpus: int = typer.Option( + 8, + "--num-gpus", + "-g", + help="Number of GPUs to use for parallel mining", + ), + problems_per_gpu: int = typer.Option( + 12, + "--problems-per-gpu", + "-p", + help="Minimum number of problems each GPU should generate", + ), + batch_size: int = typer.Option( + 16, + "--batch-size", + "-b", + help="Rollout batch size within each problem (default 16 for optimal A100 performance)", + ), + safety_blocks: int = typer.Option( + 3, + "--safety-blocks", + help="Safety margin blocks before window end", + ), + use_drand: bool = typer.Option( + True, + "--use-drand/--no-drand", + help="Use drand for randomness", + ), + gpu_ids: str = typer.Option( + None, + "--gpu-ids", + help="Comma-separated GPU IDs to use (e.g., '0,1,2,3'). Default: 0 to num_gpus-1", + ), + worker_timeout: float = typer.Option( + 600.0, + "--worker-timeout", + help="Maximum seconds to wait for workers per window", + ), +) -> None: + """Run parallel multi-GPU miner for maximum throughput. + + Spawns multiple worker processes, each on a dedicated GPU, generating + rollouts for non-overlapping problem ranges. Results are aggregated + and uploaded as a single submission per window. + + Example: + grail parallel-mine --num-gpus 8 --problems-per-gpu 12 + + This generates 8 × 12 = 96 problems per window (1,536+ rollouts). 
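+
+    To pin workers to specific devices, pass explicit GPU IDs, e.g.:
+
+        grail parallel-mine --num-gpus 4 --gpu-ids 0,1,2,3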
+ """ + # Validate inputs + if num_gpus < 1: + console.print("[red]Error: --num-gpus must be at least 1[/red]") + raise typer.Exit(code=1) + + if problems_per_gpu < 1: + console.print("[red]Error: --problems-per-gpu must be at least 1[/red]") + raise typer.Exit(code=1) + + if batch_size < 1 or batch_size > 16: + console.print("[red]Error: --batch-size must be between 1 and 16[/red]") + raise typer.Exit(code=1) + + # Parse GPU IDs if provided + parsed_gpu_ids = None + if gpu_ids: + try: + parsed_gpu_ids = [int(x.strip()) for x in gpu_ids.split(",")] + if len(parsed_gpu_ids) != num_gpus: + console.print( + f"[red]Error: --gpu-ids has {len(parsed_gpu_ids)} IDs " + f"but --num-gpus is {num_gpus}[/red]" + ) + raise typer.Exit(code=1) + except ValueError as err: + console.print("[red]Error: --gpu-ids must be comma-separated integers[/red]") + raise typer.Exit(code=1) from err + + # Check GPU availability + available_gpus = torch.cuda.device_count() + if available_gpus < num_gpus: + console.print( + f"[yellow]Warning: Only {available_gpus} GPUs available, " + f"but {num_gpus} requested[/yellow]" + ) + + config = ParallelMinerConfig( + num_gpus=num_gpus, + problems_per_gpu=problems_per_gpu, + batch_size=batch_size, + safety_blocks=safety_blocks, + use_drand=use_drand, + worker_timeout=worker_timeout, + gpu_ids=parsed_gpu_ids, + ) + + total_problems = num_gpus * problems_per_gpu + console.print("[bold green]Starting Parallel Miner[/bold green]") + console.print(f" GPUs: {num_gpus}") + console.print(f" Problems/GPU: {problems_per_gpu}") + console.print(f" Total problems/window: {total_problems}") + console.print(f" Expected rollouts/window: {total_problems * 16}") + + try: + asyncio.run(run_parallel_miner(config, use_drand)) + except KeyboardInterrupt: + console.print("[yellow]Parallel miner stopped by user[/yellow]") + raise typer.Exit(code=0) from None + except Exception as e: + logger.exception("Fatal error in parallel miner") + console.print(f"[red]Fatal error: {e}[/red]") + raise typer.Exit(code=1) from None + + +# --------------------------------------------------------------------------- # +# Main Entry Point # +# --------------------------------------------------------------------------- # + + +def main() -> None: + """Main entry point for parallel miner CLI.""" + app = typer.Typer() + register(app) + app() + + +if __name__ == "__main__": + main() diff --git a/grail/environments/providers.py b/grail/environments/providers.py index 1636a2b..6fc0e40 100644 --- a/grail/environments/providers.py +++ b/grail/environments/providers.py @@ -215,6 +215,44 @@ def _extract_boxed_answer(solution: str) -> str: _MATH_VAL_SEED = 42 +def _extract_boxed_answer(solution: str) -> str: + """Extract answer from \\boxed{...} in solution, handling nested braces.""" + import re + + match = re.search(r"\\boxed\{", solution) + if not match: + return "" + + start = match.end() + depth = 1 + i = start + while i < len(solution) and depth > 0: + if solution[i] == "{": + depth += 1 + elif solution[i] == "}": + depth -= 1 + i += 1 + + return solution[start : i - 1] if depth == 0 else "" + + +# Subsets in EleutherAI/hendrycks_math dataset +_MATH_SUBSETS = ( + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", +) + +# Fixed validation set size (stratified across problem types) +_MATH_VAL_SIZE = 500 +# Seed for deterministic stratified sampling of validation set +_MATH_VAL_SEED = 42 + + class MATHTaskSource(TaskSource): """HF datasets-backed 
Hendrycks MATH provider with stratified train/val split. diff --git a/grail/model/provider.py b/grail/model/provider.py index c5aa784..2a5def2 100644 --- a/grail/model/provider.py +++ b/grail/model/provider.py @@ -59,6 +59,7 @@ def get_model( use_safetensors: bool = True, eval_mode: bool = True, use_flash_attention: bool = False, + use_sdpa: bool = True, checkpoint_window: int | None = None, ) -> Any: """Load model with consistent configuration. @@ -69,7 +70,9 @@ def get_model( use_safetensors: Whether to prefer safetensors format eval_mode: Whether to set model to eval() mode use_flash_attention: Whether to use Flash Attention 2 (requires flash-attn package). - Only enabled for training, not for evaluation/inference. + Takes priority over SDPA if both are enabled. + use_sdpa: Whether to use PyTorch SDPA (Scaled Dot-Product Attention). + Built into PyTorch 2.0+, provides 10-30% speedup. Default: True. checkpoint_window: Optional checkpoint window number. If not provided, will be extracted from metadata.json or parsed from the path. @@ -111,19 +114,27 @@ def get_model( except (ValueError, IndexError): pass - # Configure attention implementation + # Configure attention implementation (priority: Flash Attention 2 > SDPA > default) attn_implementation = None - if use_flash_attention and device == "cuda": - try: - import flash_attn # noqa: F401 - - attn_implementation = "flash_attention_2" - logger.info("Using Flash Attention 2 for model loading") - except ImportError: - logger.warning( - "flash-attn not installed; falling back to default attention. " - "Install with: uv pip install flash-attn" - ) + if device == "cuda": + if use_flash_attention: + try: + import flash_attn # noqa: F401 + + attn_implementation = "flash_attention_2" + logger.info("Using Flash Attention 2 for model loading") + except ImportError: + logger.warning( + "flash-attn not installed; falling back to SDPA. " + "Install with: uv pip install flash-attn" + ) + if use_sdpa: + attn_implementation = "sdpa" + logger.info("Using PyTorch SDPA (Scaled Dot-Product Attention)") + elif use_sdpa: + # SDPA is built into PyTorch 2.0+ and provides good speedup + attn_implementation = "sdpa" + logger.info("Using PyTorch SDPA (Scaled Dot-Product Attention)") # Load model with optimized attention if available model = AutoModelForCausalLM.from_pretrained( diff --git a/grail/trainer/config.py b/grail/trainer/config.py index 54ac1e4..5e08333 100644 --- a/grail/trainer/config.py +++ b/grail/trainer/config.py @@ -96,7 +96,7 @@ class EvalConfig: enabled: bool = True window_interval: int = 20 - split: str = "val" # dataset-backed envs (e.g., GSM8K) #TODO: should be specified per env + split: str = "val" # Use validation split subset_size: int | None = None # generative envs or capped dataset eval seed_base: int = 2025 batch_size: int = 32 # Conservative for vLLM server: 8 tasks × 5 reps = 40 prompts/batch (prevent queue timeout) diff --git a/research/eval/README.md b/research/eval/README.md new file mode 100644 index 0000000..8325faa --- /dev/null +++ b/research/eval/README.md @@ -0,0 +1,379 @@ +# MATH Benchmark Evaluation + +Evaluate language models on the [Hendrycks MATH](https://github.com/hendrycks/math) dataset using [EleutherAI's lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with vLLM backend. + +## Overview + +This directory contains custom evaluation tasks for reasoning models that use the GRAIL format: +- `` ... `` for chain-of-thought reasoning +- `` ... 
`` for final answers + +## Prerequisites + +```bash +# Activate the vLLM environment +source /root/grail/tools/vllm-server/.venv/bin/activate +``` + +## Quick Start + +### 1. Base Model (Standard Evaluation) + +Standard 4-shot evaluation without reasoning format: + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096" \ + --tasks hendrycks_math \ + --num_fewshot 4 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --output_path ./results/base_4shot \ + --log_samples +``` + +### 2. Reasoning Model (Custom Template) + +For models trained with the GRAIL reasoning format, use the custom task with chat template: + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/path/to/checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/reasoning_4shot +``` + +## Evaluation Configurations + +### Base Model Configurations + +| Config | Command Flags | Use Case | +|--------|--------------|----------| +| 0-shot | `--num_fewshot 0` | Zero-shot baseline | +| 4-shot | `--num_fewshot 4` | Standard MATH benchmark | + +### Reasoning Model Configurations + +| Config | Command Flags | Use Case | +|--------|--------------|----------| +| 0-shot | `--num_fewshot 0 --apply_chat_template` | Zero-shot with reasoning template | +| 4-shot multiturn | `--num_fewshot 4 --apply_chat_template --fewshot_as_multiturn` | **Recommended** - Few-shot as conversation | + +## Key Arguments + +| Argument | Description | +|----------|-------------| +| `--tasks hendrycks_math` | Standard MATH evaluation (7 subjects) | +| `--tasks hendrycks_math_grail` | Custom GRAIL reasoning format | +| `--include_path` | Path to custom task definitions | +| `--apply_chat_template` | Apply model's chat template | +| `--fewshot_as_multiturn` | Format few-shot examples as multi-turn conversation | +| `--think_end_token` | Token marking end of reasoning (extracts answer after this) | +| `--max_model_len` | Context length (use 8192+ for 4-shot) | +| `--log_samples` | Save per-sample outputs for analysis | + +## Example Commands + +### Evaluate GRAIL Checkpoint (Recommended) + +```bash +cd /root/grail && source tools/vllm-server/.venv/bin/activate + +CUDA_VISIBLE_DEVICES=0 python -m lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path /root/grail/eval_results/grail_checkpoint +``` + +### Evaluate Base Model with Reasoning Template + +First, prepare the base model with custom chat template: + +```bash +# Download and patch the model (one-time setup) +python -c " +from huggingface_hub import snapshot_download +import json + +# Download model +snapshot_download('Qwen/Qwen2.5-1.5B-Instruct', local_dir='./models/Qwen2.5-1.5B-Instruct-reasoning') + +# Patch tokenizer config with reasoning template +with open('./models/Qwen2.5-1.5B-Instruct-reasoning/tokenizer_config.json', 'r') as f: + config = json.load(f) 
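+# The template assigned below is the GRAIL reasoning prompt (see the "Reasoning Format" section).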
+ +config['chat_template'] = \"\"\"{% if messages[0]['role'] == 'system' %}{{ messages[0]['content'] + eos_token }}{% set loop_messages = messages[1:] %}{% else %}{{ 'You are given a problem. +Think about the problem and provide your working out. +Place it between and . +Then, provide your solution between .' + eos_token }}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if message['role'] == 'user' %}{{ message['content'] }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '' }}{% endif %}\"\"\" + +with open('./models/Qwen2.5-1.5B-Instruct-reasoning/tokenizer_config.json', 'w') as f: + json.dump(config, f, indent=2) +" +``` + +Then evaluate: + +```bash +CUDA_VISIBLE_DEVICES=0 python -m lm_eval \ + --model vllm \ + --model_args "pretrained=./models/Qwen2.5-1.5B-Instruct-reasoning,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/base_reasoning_4shot +``` + +### Run in Background + +```bash +CUDA_VISIBLE_DEVICES=0 nohup python -m lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/grail_checkpoint \ + > eval.log 2>&1 & +echo "Started. PID: $!" +``` + +## Benchmark Results + +| Model | Config | Accuracy | +|-------|--------|----------| +| Qwen2.5-1.5B-Instruct | 0-shot standard | 1.90% | +| Qwen2.5-1.5B-Instruct | 4-shot standard | 12.66% | +| Qwen2.5-1.5B-Instruct + reasoning template | 4-shot multiturn | 28.00% | +| grail_final_checkpoint | 4-shot multiturn | **30.34%** | + +## Task Structure + +``` +tasks/hendrycks_math_grail/ +├── _default_template.yaml # Base config with reasoning format +├── hendrycks_math_grail.yaml # Task group definition +├── hendrycks_math_grail_algebra.yaml +├── hendrycks_math_grail_counting_and_prob.yaml +├── hendrycks_math_grail_geometry.yaml +├── hendrycks_math_grail_intermediate_algebra.yaml +├── hendrycks_math_grail_num_theory.yaml +├── hendrycks_math_grail_prealgebra.yaml +├── hendrycks_math_grail_precalc.yaml +└── utils.py # Answer extraction and comparison +``` + +## Reasoning Format + +The custom chat template instructs the model to: + +``` +You are given a problem. +Think about the problem and provide your working out. +Place it between and . +Then, provide your solution between . +``` + +Example output: +``` + +Let me solve this step by step... +The answer is 42. + +42 +``` + +The `think_end_token=` argument tells the evaluator to extract the answer from text **after** this token, effectively using only the `` content for scoring. + +## AIME 2024 Benchmark + +AIME (American Invitational Mathematics Examination) is an extremely challenging competition math benchmark. The dataset contains 30 problems from AIME 2024. 
+ +### Running AIME Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096,max_gen_toks=2048" \ + --tasks aime24 \ + --num_fewshot 0 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False,max_gen_toks=2048" \ + --log_samples \ + --output_path ./results/aime24_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks aime24_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/aime24_grail +``` + +**Note**: AIME is extremely difficult - even 70B+ models typically achieve only 3-10% on AIME. Small models (1.5B) are expected to score near 0%. + +## Pass@k Evaluation + +For sampling-based evaluation with pass@k metrics: + +### Best Practices + +| Parameter | Recommended Value | Notes | +|-----------|------------------|-------| +| `repeats` | 10 (for pass@5), 100 (for pass@100) | Number of samples per problem | +| `temperature` | 0.6 - 0.8 | Higher = more diversity | +| `top_p` | 0.95 | Nucleus sampling | +| `do_sample` | true | Required for sampling | + +### Formula + +pass@k = 1 - C(n-c, k) / C(n, k) + +Where: +- n = total samples generated +- c = number of correct samples +- k = number of samples to consider + +### Example: Pass@5 on MATH + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_pass_at_5 \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --batch_size auto \ + --log_samples \ + --output_path ./results/math_pass_at_5 +``` + +### Key Differences from Greedy Evaluation + +| Greedy (pass@1) | Sampling (pass@k) | +|-----------------|-------------------| +| `temperature=0` | `temperature=0.7` | +| `do_sample=false` | `do_sample=true` | +| `repeats=1` | `repeats=10+` | +| Single deterministic output | Multiple diverse outputs | + +### Custom Pass@k Tasks + +Create a task YAML with: +```yaml +repeats: 10 # Generate 10 samples per problem +generation_kwargs: + do_sample: true + temperature: 0.7 + top_p: 0.95 +metric_list: + - metric: !function utils.aggregate_pass_at_5 + aggregation: mean + higher_is_better: true +``` + +## AMC 2023 Benchmark + +AMC (American Mathematics Competition) is a high school math competition. The AMC 2023 dataset contains 40 problems. 
+ +### Running AMC Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096,max_gen_toks=2048" \ + --tasks amc2023 \ + --include_path /root/grail/research/eval/tasks \ + --num_fewshot 0 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --log_samples \ + --output_path ./results/amc2023_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks amc2023_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/amc2023_grail +``` + +### AMC 2023 Results + +| Model | Config | Accuracy | +|-------|--------|----------| +| Qwen2.5-1.5B-Instruct | 0-shot | 17.5% | +| Qwen2.5-1.5B-Instruct | 4-shot | 17.5% | +| grail_final_checkpoint | reasoning template | 17.5% | + +## GSM8K Benchmark + +GSM8K (Grade School Math 8K) is a dataset of 8.5K high-quality linguistically diverse grade school math word problems. The test set contains 1319 problems. + +### Running GSM8K Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096" \ + --tasks gsm8k \ + --num_fewshot 4 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --log_samples \ + --output_path ./results/gsm8k_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks gsm8k_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/gsm8k_grail +``` diff --git a/research/eval/tasks/_common.py b/research/eval/tasks/_common.py new file mode 100644 index 0000000..bb3cf24 --- /dev/null +++ b/research/eval/tasks/_common.py @@ -0,0 +1,523 @@ +"""Shared utilities for thinking model evaluation tasks. + +This module contains common functions for answer extraction and comparison +used across multiple evaluation tasks (AIME, AMC, GSM8K, MATH, etc.). + +Following DRY principles - extract once, reuse everywhere. +""" + +import re +from collections.abc import Callable + +# ============================================================================= +# Answer Extraction Functions +# ============================================================================= + + +def extract_solution_tag(text: str) -> str | None: + """Extract content from ... tags. + + Args: + text: Model output text + + Returns: + Content inside SOLUTION tags, or None if not found + """ + match = re.search(r"(.*?)", text, re.DOTALL) + if match: + return match.group(1).strip() + return None + + +def extract_dollar_sign_answer(text: str) -> str | None: + """Extract answer from $...$ format (last pair). 
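+
+    Illustrative: for "so the answer is $42$", this returns "42".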
+ + Args: + text: Model output text + + Returns: + Content between last pair of dollar signs, or None if not found + """ + indices = [pos for pos, char in enumerate(text) if char == "$"] + if len(indices) >= 2: + return text[indices[-2] + 1 : indices[-1]] + return None + + +def remove_boxed(s: str) -> str | None: + """Remove \\boxed{} wrapper from string. + + Args: + s: String potentially wrapped in \\boxed{} + + Returns: + Unwrapped content, or original string if no valid wrapper found + """ + if s is None: + return None + + # Handle "\\boxed " format (space after boxed) + if "\\boxed " in s: + left = "\\boxed " + if s[: len(left)] == left: + return s[len(left) :] + + # Handle "\\boxed{...}" format + left = "\\boxed{" + if s[: len(left)] == left and s.endswith("}"): + return s[len(left) : -1] + + return s + + +def last_boxed_only_string(string: str) -> str | None: + """Extract the last \\boxed{} or \\fbox{} content from a string. + + Handles nested braces correctly. + + Args: + string: Text containing potential boxed content + + Returns: + The last boxed expression (including \\boxed{} wrapper), or None + """ + if not string: + return None + + idx = string.rfind("\\boxed") + + # Handle "\\boxed " format + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + # Find matching closing brace + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + + return string[idx : right_brace_idx + 1] + + +def extract_answer_cascade( + text: str, + try_solution_tag: bool = True, + try_boxed: bool = True, + try_dollar: bool = True, +) -> str: + """Extract answer using cascade of methods. + + Tries extraction methods in order until one succeeds: + 1. ... tags (optional) + 2. \\boxed{...} (optional) + 3. $...$ format (optional) + 4. Original text (fallback) + + Args: + text: Model output text + try_solution_tag: Whether to try SOLUTION tag extraction + try_boxed: Whether to try boxed extraction + try_dollar: Whether to try dollar sign extraction + + Returns: + Extracted answer string + """ + if not text: + return "" + + # Try SOLUTION tags (reasoning models) + if try_solution_tag: + result = extract_solution_tag(text) + if result: + return result + + # Try boxed format + if try_boxed: + boxed = last_boxed_only_string(text) + if boxed: + unboxed = remove_boxed(boxed) + if unboxed: + return unboxed + + # Try dollar sign format + if try_dollar: + result = extract_dollar_sign_answer(text) + if result: + return result + + return text.strip() + + +# ============================================================================= +# String Normalization Functions +# ============================================================================= + + +def strip_string_basic(string: str) -> str: + """Basic string normalization for comparison. 
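+
+    Illustrative: "$\\left( 3 \\right)$" normalizes to "(3)".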
+ + Removes common formatting that doesn't affect mathematical meaning: + - Linebreaks, spaces + - LaTeX commands: \\!, \\left, \\right + - Dollar signs + + Args: + string: String to normalize + + Returns: + Normalized string + """ + if string is None: + return "" + + string = string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("$", "") + string = string.replace("\\$", "") + string = string.replace(" ", "") + + return string + + +def fix_fracs(string: str) -> str: + """Fix fraction formatting (\\frac12 -> \\frac{1}{2}).""" + substrs = string.split("\\frac") + new_str = substrs[0] + + if len(substrs) > 1: + for substr in substrs[1:]: + new_str += "\\frac" + if not substr or substr[0] == "{": + new_str += substr + else: + if len(substr) < 2: + return string + a = substr[0] + b = substr[1] + if b != "{": + post_substr = substr[2:] if len(substr) > 2 else "" + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + post_substr = substr[2:] if len(substr) > 2 else "" + new_str += "{" + a + "}" + b + post_substr + + return new_str + + +def fix_sqrt(string: str) -> str: + """Fix sqrt formatting (\\sqrt2 -> \\sqrt{2}).""" + if "\\sqrt" not in string: + return string + + splits = string.split("\\sqrt") + new_string = splits[0] + + for split in splits[1:]: + if split and split[0] != "{": + new_substr = "\\sqrt{" + split[0] + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + + return new_string + + +def fix_a_slash_b(string: str) -> str: + """Convert simple fractions a/b to \\frac{a}{b}.""" + if len(string.split("/")) != 2: + return string + + a_str, b_str = string.split("/") + try: + a = int(a_str) + b = int(b_str) + if string == f"{a}/{b}": + return "\\frac{" + str(a) + "}{" + str(b) + "}" + except ValueError: + pass + + return string + + +def remove_right_units(string: str) -> str: + """Remove units on the right side (e.g., '5 \\text{ meters}').""" + if "\\text{ " in string: + splits = string.split("\\text{ ") + if len(splits) == 2: + return splits[0] + return string + + +def strip_string_math(string: str) -> str: + """Full math string normalization for MATH benchmark. 
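+
+    Illustrative: "\\dfrac12" and "0.5" both normalize to "\\frac{1}{2}".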
+ + Includes all basic normalization plus: + - tfrac/dfrac -> frac + - Degrees removal + - Units removal + - Fraction normalization + - Leading decimal fixes + + Args: + string: String to normalize + + Returns: + Normalized string + """ + if string is None: + return "" + + # Basic cleanup + string = string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove degrees + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # Remove dollar signs + string = string.replace("\\$", "") + + # Remove units + string = remove_right_units(string) + + # Remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # Fix leading decimals + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # Handle "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # Fix sqrt formatting + string = fix_sqrt(string) + + # Remove spaces + string = string.replace(" ", "") + + # Fix fractions + string = fix_fracs(string) + + # Special case: 0.5 -> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # Convert a/b to \frac{a}{b} + string = fix_a_slash_b(string) + + return string + + +# ============================================================================= +# Number Extraction Functions +# ============================================================================= + + +def extract_integer(s: str) -> int | None: + """Extract integer from string, handling common formats. + + Args: + s: String potentially containing an integer + + Returns: + Extracted integer, or None if not found + """ + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return int(s) + except ValueError: + pass + + # Find integers in the string (return last one) + matches = re.findall(r"-?\d+", s) + if matches: + return int(matches[-1]) + + return None + + +def extract_float(s: str) -> float | None: + """Extract float from string, handling common formats. + + Args: + s: String potentially containing a number + + Returns: + Extracted float, or None if not found + """ + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return float(s) + except ValueError: + pass + + # Find numbers in the string (return last one) + matches = re.findall(r"-?\d+\.?\d*", s) + if matches: + return float(matches[-1]) + + return None + + +# ============================================================================= +# Equivalence Checking Functions +# ============================================================================= + + +def is_equiv_string( + str1: str, + str2: str, + normalizer: Callable[[str], str] = strip_string_basic, +) -> bool: + """Check if two strings are equivalent after normalization. 
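+
+    Illustrative: is_equiv_string("0.5", "\\frac{1}{2}", strip_string_math) returns True.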
+ + Args: + str1: First string + str2: Second string + normalizer: Function to normalize strings before comparison + + Returns: + True if equivalent, False otherwise + """ + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = normalizer(str1) + ss2 = normalizer(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def is_equiv_numeric( + str1: str, + str2: str, + tolerance: float = 0.01, + integer_only: bool = False, +) -> bool: + """Check if two strings represent equivalent numbers. + + Args: + str1: First string + str2: Second string + tolerance: Absolute tolerance for float comparison + integer_only: If True, only compare as integers + + Returns: + True if equivalent, False otherwise + """ + if str1 is None or str2 is None: + return str1 is None and str2 is None + + if integer_only: + int1 = extract_integer(str1) + int2 = extract_integer(str2) + if int1 is not None and int2 is not None: + return int1 == int2 + else: + num1 = extract_float(str1) + num2 = extract_float(str2) + if num1 is not None and num2 is not None: + return abs(num1 - num2) < tolerance + + return False + + +def is_equiv_combined( + str1: str, + str2: str, + normalizer: Callable[[str], str] = strip_string_basic, + try_numeric: bool = True, + tolerance: float = 0.01, + integer_only: bool = False, +) -> bool: + """Check equivalence using both string and numeric comparison. + + Args: + str1: First string + str2: Second string + normalizer: Function to normalize strings + try_numeric: Whether to try numeric comparison + tolerance: Tolerance for float comparison + integer_only: If True, only do integer comparison + + Returns: + True if equivalent by any method + """ + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + # Try string comparison first + ss1 = normalizer(str1) + ss2 = normalizer(str2) + if ss1 == ss2: + return True + + # Try numeric comparison + if try_numeric: + return is_equiv_numeric(ss1, ss2, tolerance, integer_only) + + return False + except Exception: + return str1 == str2 diff --git a/research/eval/tasks/aime24_thinking/aime24_thinking.yaml b/research/eval/tasks/aime24_thinking/aime24_thinking.yaml new file mode 100644 index 0000000..3d42239 --- /dev/null +++ b/research/eval/tasks/aime24_thinking/aime24_thinking.yaml @@ -0,0 +1,26 @@ +tag: + - math_word_problems +task: aime24_thinking +dataset_path: Maxwell-Jia/AIME_2024 +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "{{Problem}}" +doc_to_target: "\n{{Solution}}\n\n{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 4096 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/aime24_thinking/utils.py b/research/eval/tasks/aime24_thinking/utils.py new file mode 100644 index 0000000..1425d0a --- /dev/null +++ b/research/eval/tasks/aime24_thinking/utils.py @@ -0,0 +1,62 @@ +"""AIME 2024 evaluation utilities for thinking models. + +Extracts answers from ... tags and uses robust +integer comparison for AIME answers (which are always 0-999). 
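+
+Illustrative: "042" and "42" compare equal after leading-zero stripping.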
+""" + +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer. + + AIME answers are always integers from 000-999. + """ + response = results[0] + + # Extract answer using cascade (SOLUTION tag -> boxed -> dollar -> raw) + answer = extract_answer_cascade( + response, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + + # Get target answer + answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None) + if answer_key is None: + return {"exact_match": 0} + + target = str(doc[answer_key]) + + # AIME answers are integers 0-999, use integer comparison + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_aime, + try_numeric=True, + integer_only=True, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _strip_string_aime(string: str) -> str: + """Normalize string for AIME comparison. + + Extends basic normalization with leading zero removal. + """ + string = strip_string_basic(string) + # Remove leading zeros for integer comparison (but keep "0") + string = string.lstrip("0") or "0" + return string diff --git a/research/eval/tasks/amc2023/amc2023.yaml b/research/eval/tasks/amc2023/amc2023.yaml new file mode 100644 index 0000000..6dc6ba2 --- /dev/null +++ b/research/eval/tasks/amc2023/amc2023.yaml @@ -0,0 +1,28 @@ +tag: + - math_word_problems +task: amc2023 +dataset_path: sparkle-reasoning/amc2023 +output_type: generate_until +test_split: test +fewshot_split: test +doc_to_text: "Problem: {{question}}\n\nAnswer: The answer is" +doc_to_target: " ${{answer|int}}$" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Problem:" + - "\n\n" + - "" + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 512 +repeats: 1 +metadata: + version: 1.0 + diff --git a/research/eval/tasks/amc2023/utils.py b/research/eval/tasks/amc2023/utils.py new file mode 100644 index 0000000..2398653 --- /dev/null +++ b/research/eval/tasks/amc2023/utils.py @@ -0,0 +1,62 @@ +"""AMC 2023 evaluation utilities. + +AMC answers are integers (multiple choice A-E corresponds to numeric answers). +Extracts answers from $...$ format, \\boxed{}, or plain numbers. +""" + +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + response = results[0] + + # Extract answer (no SOLUTION tags for non-thinking model) + answer = extract_answer_cascade( + response, + try_solution_tag=False, + try_boxed=True, + try_dollar=True, + ) + + # Get target answer + target = str(doc.get("answer", "")) + + # Compare with numeric fallback + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_amc, + try_numeric=True, + tolerance=0.01, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _strip_string_amc(string: str) -> str: + """Normalize string for AMC comparison. 
+ + Extends basic normalization with float->int conversion. + """ + string = strip_string_basic(string) + + # Handle float formatting (e.g., "27.0" -> "27") + try: + num = float(string) + if num == int(num): + string = str(int(num)) + except ValueError: + pass + + return string diff --git a/research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml b/research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml new file mode 100644 index 0000000..5e93452 --- /dev/null +++ b/research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml @@ -0,0 +1,25 @@ +tag: + - math_word_problems +task: amc2023_thinking +dataset_path: sparkle-reasoning/amc2023 +output_type: generate_until +test_split: test +fewshot_split: test +doc_to_text: "{{question}}" +doc_to_target: "\n{{solution}}\n\n{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 4096 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/amc2023_thinking/utils.py b/research/eval/tasks/amc2023_thinking/utils.py new file mode 100644 index 0000000..e060fb0 --- /dev/null +++ b/research/eval/tasks/amc2023_thinking/utils.py @@ -0,0 +1,62 @@ +"""AMC 2023 evaluation utilities for thinking models. + +Extracts answers from ... tags and uses robust +numeric comparison for AMC answers. +""" + +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + response = results[0] + + # Extract answer using cascade (SOLUTION tag first for thinking models) + answer = extract_answer_cascade( + response, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + + # Get target answer + target = str(doc.get("answer", "")) + + # Compare with numeric fallback + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_amc, + try_numeric=True, + tolerance=0.01, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _strip_string_amc(string: str) -> str: + """Normalize string for AMC comparison. + + Extends basic normalization with float->int conversion. 
+ """ + string = strip_string_basic(string) + + # Handle float formatting (e.g., "27.0" -> "27") + try: + num = float(string) + if num == int(num): + string = str(int(num)) + except ValueError: + pass + + return string diff --git a/research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml b/research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml new file mode 100644 index 0000000..dabcb68 --- /dev/null +++ b/research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml @@ -0,0 +1,26 @@ +tag: + - math_word_problems +task: gsm8k_thinking +dataset_path: gsm8k +dataset_name: main +output_type: generate_until +test_split: test +fewshot_split: train +doc_to_text: "{{question}}" +doc_to_target: !function utils.doc_to_target +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 1024 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/gsm8k_thinking/utils.py b/research/eval/tasks/gsm8k_thinking/utils.py new file mode 100644 index 0000000..7ff1ce8 --- /dev/null +++ b/research/eval/tasks/gsm8k_thinking/utils.py @@ -0,0 +1,95 @@ +"""GSM8K evaluation utilities for thinking models. + +Extracts answers from ... tags and compares with +the ground truth answer (after ####). +""" + +import re +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_solution_tag, + is_equiv_combined, +) + + +def doc_to_target(doc: dict) -> str: + """Convert document to target format for thinking.""" + answer = doc["answer"] + + # Extract final answer after #### + if "####" in answer: + final_answer = answer.split("####")[-1].strip() + reasoning = answer.split("####")[0].strip() + else: + final_answer = answer.strip() + reasoning = "" + + return ( + f"\n{reasoning}\n\n{final_answer}" + ) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + response = results[0] + + # Extract answer from tags first (reasoning models) + answer = extract_solution_tag(response) + if answer is None: + # Fallback: try to extract number from the end + answer = _extract_last_number(response) + + # Get target answer from document + target_answer = doc["answer"] + if "####" in target_answer: + target = target_answer.split("####")[-1].strip() + else: + target = target_answer.strip() + + # Compare answers + is_correct = is_equiv_combined( + answer, + target, + normalizer=_clean_answer, + try_numeric=True, + tolerance=0.001, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _clean_answer(s: str) -> str: + """Clean answer string for GSM8K comparison.""" + if s is None: + return "" + + s = s.strip() + # Remove dollar signs, commas, and common formatting + s = s.replace("$", "").replace(",", "").replace(" ", "") + # Remove trailing period + s = s.rstrip(".") + + return s + + +def _extract_last_number(s: str) -> str: + """Extract the last number from a string.""" + if s is None: + return "" + + # Look for #### pattern first (GSM8K format) + if "####" in s: + return s.split("####")[-1].strip() + + # Find all numbers + matches = re.findall(r"-?\d+(?:,\d{3})*(?:\.\d+)?", s) + if matches: + # Return last number, removing commas + return matches[-1].replace(",", "") + + return s.strip() diff --git 
a/research/eval/tasks/hendrycks_math_thinking/_default_template.yaml b/research/eval/tasks/hendrycks_math_thinking/_default_template.yaml new file mode 100644 index 0000000..4e94637 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/_default_template.yaml @@ -0,0 +1,21 @@ +dataset_path: EleutherAI/hendrycks_math +process_docs: !function utils.process_docs +output_type: generate_until +training_split: train +test_split: test +doc_to_text: "{{problem}}" +doc_to_target: "\n{{solution}}\n\n{{answer}}" +process_results: !function utils.process_results +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0 + max_gen_toks: 2048 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml new file mode 100644 index 0000000..f5d0509 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml @@ -0,0 +1,15 @@ +group: hendrycks_math_thinking +task: + - hendrycks_math_thinking_algebra + - hendrycks_math_thinking_counting_and_prob + - hendrycks_math_thinking_geometry + - hendrycks_math_thinking_intermediate_algebra + - hendrycks_math_thinking_num_theory + - hendrycks_math_thinking_prealgebra + - hendrycks_math_thinking_precalc +aggregate_metric_list: + - metric: exact_match + aggregation: mean + weight_by_size: true +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml new file mode 100644 index 0000000..5a87439 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_algebra +dataset_name: algebra diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml new file mode 100644 index 0000000..9f54e17 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_counting_and_prob +dataset_name: counting_and_probability diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml new file mode 100644 index 0000000..293f55a --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_geometry +dataset_name: geometry diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml new file mode 100644 index 0000000..7ee5914 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_intermediate_algebra +dataset_name: intermediate_algebra diff --git 
a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml new file mode 100644 index 0000000..b668341 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_num_theory +dataset_name: number_theory diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml new file mode 100644 index 0000000..3c9aebc --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_prealgebra +dataset_name: prealgebra diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml new file mode 100644 index 0000000..827992b --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_thinking_precalc +dataset_name: precalculus diff --git a/research/eval/tasks/hendrycks_math_thinking/utils.py b/research/eval/tasks/hendrycks_math_thinking/utils.py new file mode 100644 index 0000000..d58371b --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/utils.py @@ -0,0 +1,54 @@ +"""Custom utils for thinking model evaluation on MATH. + +Extracts answers from ... tags instead of \\boxed{}. 
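+
+Falls back to comparing the raw model output when no solution tags are present.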
+""" + +import sys +from pathlib import Path + +import datasets + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_solution_tag, + is_equiv_string, + last_boxed_only_string, + remove_boxed, + strip_string_math, +) + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + """Process dataset docs - extract ground truth answer from \\boxed{}.""" + + def _process_doc(doc: dict) -> dict: + boxed = last_boxed_only_string(doc["solution"]) + answer = remove_boxed(boxed) if boxed else "" + return { + "problem": doc["problem"], + "solution": doc["solution"], + "answer": answer, + } + + return dataset.map(_process_doc) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process results - extract answer from tags and compare.""" + # Extract from tags + model_answer = extract_solution_tag(results[0]) + if model_answer is None: + model_answer = results[0].strip() + + # Get ground truth (already extracted from \boxed{} in process_docs) + ground_truth = doc.get("answer") + if ground_truth is None: + boxed = last_boxed_only_string(doc["solution"]) + ground_truth = remove_boxed(boxed) if boxed else "" + + # Compare using full math normalization + is_correct = is_equiv_string(model_answer, ground_truth, strip_string_math) + + return {"exact_match": 1 if is_correct else 0} diff --git a/research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml new file mode 100644 index 0000000..ffcaad7 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml @@ -0,0 +1,33 @@ +# MATH pass@5 evaluation task +# Generates 10 samples per problem with temperature sampling +# Computes pass@1, pass@5 + +group: hendrycks_math_thinking_pass_at_k +task: hendrycks_math_thinking_pass_at_5 +dataset_path: EleutherAI/hendrycks_math +dataset_name: algebra +output_type: generate_until +training_split: train +test_split: test +doc_to_text: "{{problem}}" +doc_to_target: "{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: !function utils.aggregate_pass_at_1 + aggregation: mean + higher_is_better: true + - metric: !function utils.aggregate_pass_at_5 + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: true + temperature: 0.7 + top_p: 0.95 + max_gen_toks: 2048 +repeats: 10 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py new file mode 100644 index 0000000..36dec40 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py @@ -0,0 +1,104 @@ +"""Pass@k utilities for MATH benchmark evaluation. + +Implements the standard pass@k metric: given n samples, compute the probability +that at least one of k random samples is correct. 
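+
+For example (using the formula below): n=10 samples, c=3 correct gives pass@5 = 1 - C(7, 5)/C(10, 5) = 1 - 21/252 ≈ 0.917.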
+ +Formula: pass@k = 1 - C(n-c, k) / C(n, k) +where n = total samples, c = correct samples, k = samples to consider +""" + +import sys +from math import comb +from pathlib import Path +from typing import Any + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_string, + strip_string_math, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, Any]: + """Process results for a single document (used by lm-eval). + + This function is called for EACH sample. The pass@k aggregation + happens in the aggregation function. + """ + answer = str(doc.get("answer", "")) + + # Check each result + correct_list = [] + for result in results: + extracted = extract_answer_cascade( + result, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + is_correct = 1 if is_equiv_string(extracted, answer, strip_string_math) else 0 + correct_list.append(is_correct) + + return { + "pass": correct_list, + "num_correct": sum(correct_list), + "num_samples": len(correct_list), + } + + +def _pass_at_k(n: int, c: int, k: int) -> float: + """Compute pass@k for a single problem. + + Args: + n: Total number of samples + c: Number of correct samples + k: Number of samples to consider + + Returns: + Probability that at least one of k samples is correct + """ + if n - c < k: + return 1.0 + return 1.0 - comb(n - c, k) / comb(n, k) + + +def aggregate_pass_at_k(results: list[dict], k: int = 5) -> float: + """Aggregate pass@k across all documents. + + Args: + results: List of result dicts from process_results + k: Number of samples to consider for pass@k + + Returns: + Average pass@k score + """ + scores = [] + for r in results: + n = r["num_samples"] + c = r["num_correct"] + if n < k: + score = 1.0 if c > 0 else 0.0 + else: + score = _pass_at_k(n, c, k) + scores.append(score) + + return sum(scores) / len(scores) if scores else 0.0 + + +# Convenience aggregation functions for different k values +def aggregate_pass_at_1(results: list[dict]) -> float: + """Aggregate pass@1 score.""" + return aggregate_pass_at_k(results, k=1) + + +def aggregate_pass_at_5(results: list[dict]) -> float: + """Aggregate pass@5 score.""" + return aggregate_pass_at_k(results, k=5) + + +def aggregate_pass_at_10(results: list[dict]) -> float: + """Aggregate pass@10 score.""" + return aggregate_pass_at_k(results, k=10) diff --git a/research/trl/train_trl_grpo.py b/research/trl/train_trl_grpo.py new file mode 100644 index 0000000..d3f0d04 --- /dev/null +++ b/research/trl/train_trl_grpo.py @@ -0,0 +1,1238 @@ +#!/usr/bin/env python3 +"""TRL GRPO training script with factory pattern for GSM8K and MATH datasets. 
+ +Supports both datasets with exact parity to GRAIL environment implementations: +- GSM8K: Grade school math (7,473 train / 1,319 test) +- MATH: Hendrycks MATH benchmark (7,000 train / 500 val / 5,000 test) + +Usage: + python train_trl_grpo.py --dataset gsm8k + python train_trl_grpo.py --dataset math +""" + +from __future__ import annotations + +import abc +import argparse +import asyncio +import os +import re +import sys +from dataclasses import dataclass +from typing import Any + +import torch +from datasets import Dataset +from dotenv import load_dotenv +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedTokenizer, + TrainerCallback, +) +from trl import GRPOConfig, GRPOTrainer + +# Force unbuffered output for better logging in nohup mode +sys.stdout = open(sys.stdout.fileno(), mode="w", buffering=1) +sys.stderr = open(sys.stderr.fileno(), mode="w", buffering=1) + +# Load environment from .env for WandB +load_dotenv("/root/grail/.env") + +sys.path.append("/root/grail") + +# GRAIL imports - reuse task sources and validation logic (after sys.path.append) +from grail.environments.math_hendrycks_env import _math_answers_equal # noqa: E402 +from grail.environments.providers import GSM8KTaskSource, MATHTaskSource # noqa: E402 +from grail.shared.chat_templates import build_qwen_chat_template # noqa: E402 +from grail.trainer.metrics import KMetricsAggregator, TaskReplicateResult # noqa: E402 + + +# ════════════════════════════════════════════════════════════════════════════ +# HYPERPARAMETERS (from .env GRAIL config - exactly matching grail/trainer/algorithms/grpo.py) +# ════════════════════════════════════════════════════════════════════════════ +@dataclass +class Config: + # ──────────────────────────────────────────────────────────────────────── + # Model Configuration (from GRAIL_TRAIN_MODEL_ID) + # ──────────────────────────────────────────────────────────────────────── + model_id: str = "Qwen/Qwen2.5-1.5B-Instruct" + + # ──────────────────────────────────────────────────────────────────────── + # Training Hyperparameters (from grail/shared/constants.py + env vars) + # These match GRAIL's GRPOAlgorithm config exactly + # ──────────────────────────────────────────────────────────────────────── + # Learning rate (GRAIL_TRAINER_LR, constants.py default: 1e-6) + lr: float = 3e-6 + # Epochs per training iteration (GRAIL_TRAINER_EPOCHS, constants.py default: 1) + epochs: int = 1 + # Batch size per device (GRAIL_TRAINER_BATCH_SIZE, constants.py default: 16) + batch_size: int = 4 + # Gradient accumulation steps (GRAIL_TRAINER_GRAD_ACCUM_STEPS, constants.py default: 8) + # Effective batch = batch_size × grad_accum_steps = 4 × 128 = 512 + grad_accum_steps: int = 128 + # Max sequence length (GRAIL_TRAINER_MAX_LENGTH, constants.py default: 2048) + max_length: int = 2048 + # Gradient clipping threshold (GRAIL_TRAINER_GRAD_CLIP, constants.py default: 0.5) + grad_clip: float = 1.0 + # Warmup steps for LR scheduler (GRAIL_TRAINER_WARMUP_STEPS, constants.py default: 10) + warmup_steps: int = 50 + # Total training windows (GRAIL_TRAINER_TOTAL_WINDOWS) - controls iteration count + # Each optimizer step = 32 groups × 16 rollouts = 512 samples + # total_optimizer_steps calculated below based on total_windows + total_steps: int = 400 + + # ──────────────────────────────────────────────────────────────────────── + # GRPO Loss Configuration (from grail/trainer/algorithms/grpo.py) + # ──────────────────────────────────────────────────────────────────────── + # KL divergence coefficient 
(GRAIL_TRAINER_KL_COEF, constants.py default: 0.02) + kl_coef: float = 0.0 + # Entropy coefficient for exploration (GRAIL_TRAINER_ENTROPY_COEF, constants.py default: 0.001) + # Note: TRL may not support entropy regularization directly + entropy_coef: float = 0.0005 + # PPO clip epsilon lower bound (TRAINER_PPO_CLIP_EPS, constants.py default: 0.2) + ppo_clip_eps: float = 0.2 + # PPO clip epsilon upper bound - DAPO-style asymmetric clipping + # (TRAINER_PPO_CLIP_EPS_UPPER, constants.py default: 0.28) + ppo_clip_eps_upper: float = 0.28 + # Importance sampling ratio ceiling (GRAIL_TRAINER_IS_RATIO_MAX, constants.py default: 10.0) + # Prevents training instability from extreme ratios + is_ratio_max: float = 2.5 + # Log-ratio clamp for numerical stability (GRAIL_TRAINER_LOGRATIO_CLAMP, constants.py default: 5.0) + # ln(2.5) ≈ 0.916 → aligned with IS_RATIO_MAX + logratio_clamp: float = 0.92 + # Advantage clipping percentile (GRAIL_TRAINER_ADV_CLIP_PERCENTILE, constants.py default: 99.0) + # Note: TRL handles advantage normalization differently + adv_clip_percentile: float = 99.0 + # Group advantage sum tolerance (GRAIL_TRAINER_GROUP_ADV_SUM_TOL, constants.py default: 0.01) + # Note: TRL doesn't use group validation, but kept for reference + group_adv_sum_tol: float = 0.01 + # GRPO loss variant (GRAIL_GRPO_VARIANT, constants.py default: "dapo") + # Options: 'grpo', 'bnpo', 'dapo', 'dr_grpo' + grpo_variant: str = "dapo" + # Importance sampling level (GRAIL_IMPORTANCE_SAMPLING_LEVEL, constants.py default: "sequence") + # Options: 'sequence' (one ratio per sequence), 'token' (per-token ratios) + # Note: TRL uses token-level IS by default when using vLLM + importance_sampling_level: str = "sequence" + + # ──────────────────────────────────────────────────────────────────────── + # GRPO Data Configuration (from grail/shared/constants.py) + # ──────────────────────────────────────────────────────────────────────── + # Groups per optimizer step = effective_batch / rollouts_per_problem = 512 / 16 = 32 + max_groups: int = 32 + # Max completion tokens (GRPO_MAX_COMPLETION_TOKENS, constants.py default: 1024) + max_new_tokens: int = 1024 + # Rollouts per problem (ROLLOUTS_PER_PROBLEM, constants.py: 16) + rollouts_per_problem: int = 16 + + # ──────────────────────────────────────────────────────────────────────── + # Dataset Sampling + # ──────────────────────────────────────────────────────────────────────── + num_train_samples: int | None = None # None = use all training samples + num_eval_samples: int | None = None # None = use all test samples + + # ──────────────────────────────────────────────────────────────────────── + # Generation Parameters + # ──────────────────────────────────────────────────────────────────────── + temperature: float = 0.7 + top_p: float = 0.95 + top_k: int = 50 + + # ──────────────────────────────────────────────────────────────────────── + # Evaluation Configuration + # ──────────────────────────────────────────────────────────────────────── + eval_replicates: int = 5 + report_ks: tuple[int, ...] 
= (1, 5, 10)
+    eval_batch_size: int = 128
+    eval_num_workers: int = 4
+
+
+cfg = Config()
+
+# ════════════════════════════════════════════════════════════════════════════
+# SYSTEM PROMPT & TAGS (shared across datasets)
+# ════════════════════════════════════════════════════════════════════════════
+REASONING_START_TOKEN = "start_working_out"
+REASONING_END_TOKEN = "end_working_out"
+SOLUTION_START_TOKEN = "SOLUTION"
+SOLUTION_END_TOKEN = "SOLUTION"
+
+REASONING_START = f"<{REASONING_START_TOKEN}>"
+REASONING_END = f"</{REASONING_END_TOKEN}>"
+SOLUTION_START = f"<{SOLUTION_START_TOKEN}>"
+SOLUTION_END = f"</{SOLUTION_END_TOKEN}>"
+
+SYSTEM_PROMPT = (
+    "You are given a problem.\n"
+    "Think about the problem and provide your working out.\n"
+    f"Place it between {REASONING_START} and {REASONING_END}.\n"
+    f"Then, provide your solution between {SOLUTION_START}{SOLUTION_END}."
+)
+
+QWEN_CHAT_TEMPLATE = build_qwen_chat_template(
+    system_prompt=SYSTEM_PROMPT, reasoning_start=REASONING_START
+)
+
+
+# ════════════════════════════════════════════════════════════════════════════
+# DATASET ADAPTER (Abstract Base + Concrete Implementations)
+# ════════════════════════════════════════════════════════════════════════════
+class DatasetAdapter(abc.ABC):
+    """Abstract base class for dataset adapters.
+
+    Provides unified interface for:
+    - Loading train/eval datasets
+    - Parsing gold answers
+    - Computing rewards
+    - Determining success threshold
+    """
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """Dataset name for logging."""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def question_field(self) -> str:
+        """Field name for question/problem text."""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def answer_field(self) -> str:
+        """Field name for gold answer."""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def correctness_weight(self) -> float:
+        """Weight for correctness component in reward."""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def success_threshold(self) -> float:
+        """Reward threshold for success (correctness weight)."""
+        ...
+
+    @abc.abstractmethod
+    def load_train_data(self) -> list[dict[str, Any]]:
+        """Load training data as list of dicts."""
+        ...
+
+    @abc.abstractmethod
+    def load_eval_data(self) -> list[dict[str, Any]]:
+        """Load evaluation data as list of dicts."""
+        ...
+
+    @abc.abstractmethod
+    def parse_gold_answer(self, raw_answer: str) -> str:
+        """Extract gold answer from dataset format."""
+        ...
+
+    @abc.abstractmethod
+    def validate_answer(self, predicted: str, gold: str) -> bool:
+        """Check if predicted answer matches gold."""
+        ...
+
+    @abc.abstractmethod
+    def compute_reward(self, completion: str, gold_answer: str) -> float:
+        """Compute total reward for completion."""
+        ...
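Between the abstract base class above and the concrete adapters that follow, it may help to see how small a new adapter can be. The following is a minimal, hypothetical sketch and not part of this PR: `InMemoryMathAdapter` and its constructor arguments are invented for illustration, and the code assumes the imports (`re`, `Any`) and the `SOLUTION_*` tag constants defined earlier in this module.

```python
import re
from typing import Any


class InMemoryMathAdapter(DatasetAdapter):
    """Hypothetical adapter over pre-loaded {'question': ..., 'answer': ...} rows."""

    def __init__(self, train_rows: list[dict[str, Any]], eval_rows: list[dict[str, Any]]) -> None:
        self._train_rows = train_rows
        self._eval_rows = eval_rows

    @property
    def name(self) -> str:
        return "in_memory_math"

    @property
    def question_field(self) -> str:
        return "question"

    @property
    def answer_field(self) -> str:
        return "answer"

    @property
    def correctness_weight(self) -> float:
        return 0.7

    @property
    def success_threshold(self) -> float:
        return 0.7

    def load_train_data(self) -> list[dict[str, Any]]:
        return self._train_rows

    def load_eval_data(self) -> list[dict[str, Any]]:
        return self._eval_rows

    def parse_gold_answer(self, raw_answer: str) -> str:
        return raw_answer.strip()

    def validate_answer(self, predicted: str, gold: str) -> bool:
        return predicted.strip() == gold.strip()

    def compute_reward(self, completion: str, gold_answer: str) -> float:
        # Correctness only; the real adapters below layer format bonuses on top.
        match = re.search(
            rf"<{SOLUTION_START_TOKEN}>\s*(.+?)\s*</{SOLUTION_END_TOKEN}>",
            completion,
            re.DOTALL | re.IGNORECASE,
        )
        predicted = match.group(1).strip() if match else ""
        return self.correctness_weight if self.validate_answer(predicted, gold_answer) else 0.0
```

Registering such an adapter would only require adding it to the `adapters` dict in `get_dataset_adapter` further down.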
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# GSM8K Adapter
+# ────────────────────────────────────────────────────────────────────────────
+class GSM8KAdapter(DatasetAdapter):
+    """GSM8K dataset adapter using GRAIL's GSM8KTaskSource."""
+
+    # Regex patterns (from gsm8k_env.py)
+    _HASH_PATTERN = re.compile(r"####\s*(?P<ans>.+)")
+    _NUMBER_PATTERN = re.compile(r"[-+]?\d+(?:[\.,]\d+)?")
+    _NUMERIC_ONLY_PATTERN = re.compile(r"^[-+]?[\d.,]+$")
+
+    def __init__(self) -> None:
+        self._train_source = GSM8KTaskSource(split="train")
+        self._eval_source = GSM8KTaskSource(split="test")
+
+    @property
+    def name(self) -> str:
+        return "gsm8k"
+
+    @property
+    def question_field(self) -> str:
+        return "question"
+
+    @property
+    def answer_field(self) -> str:
+        return "answer"
+
+    @property
+    def correctness_weight(self) -> float:
+        return 0.6  # GSM8K uses 0.6 for correctness
+
+    @property
+    def success_threshold(self) -> float:
+        return 0.6  # Success if correctness achieved
+
+    def load_train_data(self) -> list[dict[str, Any]]:
+        """Load GSM8K training data via task source."""
+        self._train_source._ensure_dataset()
+        assert self._train_source._ds is not None
+        data = []
+        for i in range(len(self._train_source._ds)):
+            sample = self._train_source._ds[i]
+            data.append(
+                {
+                    "question": sample["question"],
+                    "answer": sample["answer"],
+                }
+            )
+        return data
+
+    def load_eval_data(self) -> list[dict[str, Any]]:
+        """Load GSM8K test data via task source."""
+        self._eval_source._ensure_dataset()
+        assert self._eval_source._ds is not None
+        data = []
+        for i in range(len(self._eval_source._ds)):
+            sample = self._eval_source._ds[i]
+            data.append(
+                {
+                    "question": sample["question"],
+                    "answer": sample["answer"],
+                }
+            )
+        return data
+
+    def parse_gold_answer(self, raw_answer: str) -> str:
+        """Parse GSM8K gold answer from #### format."""
+        match = None
+        for m in self._HASH_PATTERN.finditer(raw_answer or ""):
+            match = m
+        if match is not None:
+            return match.group("ans").strip()
+        nums = list(self._NUMBER_PATTERN.finditer(raw_answer or ""))
+        if nums:
+            return nums[-1].group(0).replace(",", "").strip()
+        return ""
+
+    def validate_answer(self, predicted: str, gold: str) -> bool:
+        """Validate GSM8K answer (numeric exact match)."""
+        pred_norm = re.sub(r"[\s\.]+$", "", predicted.strip().lower())
+        gold_norm = re.sub(r"[\s\.]+$", "", gold.strip().lower())
+        return pred_norm == gold_norm
+
+    def _parse_completion(self, text: str) -> dict[str, Any]:
+        """Parse completion for thinking/answer tags."""
+        flags = re.DOTALL | re.IGNORECASE
+        has_thinking = bool(
+            re.search(rf"<{REASONING_START_TOKEN}>.*?</{REASONING_END_TOKEN}>", text, flags)
+        )
+        answer_match = re.search(
+            rf"<{SOLUTION_START_TOKEN}>\s*(.+?)\s*</{SOLUTION_END_TOKEN}>", text, flags
+        )
+
+        answer_text = ""
+        has_answer = bool(answer_match)
+        is_numeric_only = False
+        trailing = 0
+
+        if answer_match:
+            inside = answer_match.group(1).strip()
+            num_match = self._NUMBER_PATTERN.search(inside)
+            if num_match:
+                answer_text = num_match.group(0).replace(",", "").strip()
+            is_numeric_only = bool(self._NUMERIC_ONLY_PATTERN.match(inside.replace(" ", "")))
+            trailing = len(text) - answer_match.end()
+
+        return {
+            "answer_text": answer_text,
+            "has_thinking": has_thinking,
+            "has_answer": has_answer,
+            "is_numeric_only": is_numeric_only,
+            "trailing": trailing,
+        }
+
+    def compute_reward(self, completion: str, gold_answer: str) -> float:
+        """Compute GSM8K reward (matching GSM8KEnv weights).
+ + Components: + - Correctness (0.6): exact match + - Strict format (0.15): numeric-only + no trailing + - Thinking (0.1): has thinking block + - Answer (0.1): has answer block + - No trailing (0.05): penalty for trailing text + """ + parsed = self._parse_completion(completion) + gold_parsed = self.parse_gold_answer(gold_answer) + + # Correctness + correctness = 0.6 if self.validate_answer(parsed["answer_text"], gold_parsed) else 0.0 + + # Strict format + strict_format = ( + 0.15 + if (parsed["has_answer"] and parsed["is_numeric_only"] and parsed["trailing"] == 0) + else 0.0 + ) + + # Thinking format + thinking = 0.1 if parsed["has_thinking"] else 0.0 + + # Answer format + answer = 0.1 if parsed["has_answer"] else 0.0 + + # No trailing + no_trailing = 0.05 if parsed["trailing"] == 0 else 0.0 + + return correctness + strict_format + thinking + answer + no_trailing + + +# ──────────────────────────────────────────────────────────────────────────── +# MATH (Hendrycks) Adapter +# ──────────────────────────────────────────────────────────────────────────── +class MATHAdapter(DatasetAdapter): + """MATH dataset adapter using GRAIL's MATHTaskSource. + + Uses exact same validation logic as MATHEnv: + - Multi-strategy comparison (exact, symbolic via sympy, numeric) + - LaTeX normalization + - Stratified train/val split (500 val samples) + """ + + def __init__(self) -> None: + self._train_source = MATHTaskSource(split="train") + self._eval_source = MATHTaskSource(split="val") # Use stratified val split + + @property + def name(self) -> str: + return "math" + + @property + def question_field(self) -> str: + return "question" # Normalized to 'question' for consistency + + @property + def answer_field(self) -> str: + return "answer" + + @property + def correctness_weight(self) -> float: + return 0.7 # MATH uses 0.7 for correctness + + @property + def success_threshold(self) -> float: + return 0.7 # Success if correctness achieved + + def load_train_data(self) -> list[dict[str, Any]]: + """Load MATH training data via task source (7000 samples).""" + self._train_source._ensure_dataset() + assert self._train_source._data is not None + data = [] + for sample in self._train_source._data: + data.append( + { + "question": sample["problem"], # Normalize field name + "answer": sample["answer"], # Pre-extracted from \boxed{} + "solution": sample["solution"], + "level": sample["level"], + "subject": sample["subject"], + } + ) + return data + + def load_eval_data(self) -> list[dict[str, Any]]: + """Load MATH validation data via task source (500 samples, stratified).""" + self._eval_source._ensure_dataset() + assert self._eval_source._data is not None + data = [] + for sample in self._eval_source._data: + data.append( + { + "question": sample["problem"], + "answer": sample["answer"], + "solution": sample["solution"], + "level": sample["level"], + "subject": sample["subject"], + } + ) + return data + + def parse_gold_answer(self, raw_answer: str) -> str: + """For MATH, answer is already extracted from \\boxed{} by TaskSource.""" + return raw_answer + + def validate_answer(self, predicted: str, gold: str) -> bool: + """Validate MATH answer using multi-strategy comparison. + + Uses GRAIL's _math_answers_equal which tries: + 1. Exact match (after LaTeX normalization) + 2. Symbolic equivalence (via sympy) + 3. 
Numeric comparison (floats)
+        """
+        return _math_answers_equal(predicted, gold)
+
+    def _parse_completion(self, text: str) -> dict[str, Any]:
+        """Parse completion for thinking/answer tags (MATH-specific)."""
+        flags = re.DOTALL | re.IGNORECASE
+        has_thinking = bool(
+            re.search(rf"<{REASONING_START_TOKEN}>.*?</{REASONING_END_TOKEN}>", text, flags)
+        )
+        answer_match = re.search(
+            rf"<{SOLUTION_START_TOKEN}>\s*(.+?)\s*</{SOLUTION_END_TOKEN}>", text, flags
+        )
+
+        answer_text = ""
+        has_answer = bool(answer_match)
+        trailing = 0
+
+        if answer_match:
+            answer_text = answer_match.group(1).strip()
+            trailing = len(text) - answer_match.end()
+
+        return {
+            "answer_text": answer_text,
+            "has_thinking": has_thinking,
+            "has_answer": has_answer,
+            "trailing": trailing,
+        }
+
+    def compute_reward(self, completion: str, gold_answer: str) -> float:
+        """Compute MATH reward (matching MATHEnv weights).
+
+        Components:
+        - Correctness (0.7): Multi-strategy validation
+        - Answer format (0.15): Has answer + minimal trailing
+        - Thinking (0.1): Has thinking block
+        - No trailing (0.05): Penalty for excessive trailing
+        """
+        parsed = self._parse_completion(completion)
+
+        # Correctness (using multi-strategy validation)
+        correctness = 0.7 if self.validate_answer(parsed["answer_text"], gold_answer) else 0.0
+
+        # Answer format (has answer + trailing < 50)
+        answer_format = 0.15 if (parsed["has_answer"] and parsed["trailing"] < 50) else 0.0
+
+        # Thinking format
+        thinking = 0.1 if parsed["has_thinking"] else 0.0
+
+        # No trailing (stricter check)
+        no_trailing = 0.05 if parsed["trailing"] == 0 else 0.0
+
+        return correctness + answer_format + thinking + no_trailing
+
+
+# ════════════════════════════════════════════════════════════════════════════
+# FACTORY FUNCTION
+# ════════════════════════════════════════════════════════════════════════════
+def get_dataset_adapter(dataset_name: str) -> DatasetAdapter:
+    """Factory function to get dataset adapter by name.
+
+    Args:
+        dataset_name: 'gsm8k' or 'math'
+
+    Returns:
+        DatasetAdapter instance
+
+    Raises:
+        ValueError: If dataset_name is not supported
+    """
+    adapters: dict[str, type[DatasetAdapter]] = {
+        "gsm8k": GSM8KAdapter,
+        "math": MATHAdapter,
+    }
+
+    if dataset_name.lower() not in adapters:
+        raise ValueError(f"Unknown dataset: {dataset_name}. Supported: {list(adapters.keys())}")
+
+    return adapters[dataset_name.lower()]()
+
+
+# ════════════════════════════════════════════════════════════════════════════
+# TRAINING PASS@K TRACKER
+# ════════════════════════════════════════════════════════════════════════════
+class TrainingPassAtKTracker:
+    """Computes and logs pass@k metrics during GRPO training.
+
+    This class wraps the reward computation and tracks pass@k metrics
+    by grouping completions by their prompts. Uses the same unbiased pass@k
+    formula as evaluation (KMetricsAggregator from grail.trainer.metrics).
+
+    Usage:
+        tracker = TrainingPassAtKTracker(adapter, prompt_to_answer)
+        trainer = GRPOTrainer(..., reward_funcs=tracker, ...)
+    """
+
+    # Required by TRL GRPOTrainer for reward function naming
+    __name__ = "reward_with_pass_at_k"
+
+    def __init__(
+        self,
+        adapter: DatasetAdapter,
+        prompt_to_answer: dict[str, str],
+        report_ks: tuple[int, ...] = (1, 5, 10),
+    ) -> None:
+        """Initialize the tracker.
+ + Args: + adapter: Dataset adapter for reward computation and success threshold + prompt_to_answer: Mapping from prompt text to gold answer + report_ks: Tuple of k values for pass@k metrics + """ + self._adapter = adapter + self._prompt_to_answer = prompt_to_answer + self._report_ks = report_ks + self._step_count = 0 + + def __call__( + self, + completions: list[str], + prompts: list[str], + **kwargs: Any, + ) -> list[float]: + """Compute rewards and log pass@k metrics. + + This method is called by GRPOTrainer for each batch of completions. + + Args: + completions: List of model completions + prompts: List of corresponding prompts + **kwargs: Additional arguments (gold_answer, metadatas, etc.) + + Returns: + List of reward values for each completion + """ + gold_answers = self._extract_gold_answers(prompts, kwargs) + rewards = self._compute_rewards(completions, gold_answers) + metrics = self._compute_pass_at_k_metrics(prompts, rewards) + self._log_to_wandb(metrics) + self._step_count += 1 + return rewards + + def _extract_gold_answers( + self, + prompts: list[str], + kwargs: dict[str, Any], + ) -> list[str]: + """Extract gold answers from kwargs or prompt mapping.""" + if "gold_answer" in kwargs and kwargs["gold_answer"]: + return kwargs["gold_answer"] + if "metadatas" in kwargs and kwargs["metadatas"]: + return [m.get("gold_answer", "") for m in kwargs["metadatas"]] + return [self._prompt_to_answer.get(p, "") for p in prompts] + + def _compute_rewards( + self, + completions: list[str], + gold_answers: list[str], + ) -> list[float]: + """Compute reward for each completion.""" + return [ + self._adapter.compute_reward(c, g) + for c, g in zip(completions, gold_answers, strict=False) + ] + + def _compute_pass_at_k_metrics( + self, + prompts: list[str], + rewards: list[float], + ) -> dict[str, float]: + """Compute all metrics using KMetricsAggregator (unbiased pass@k formula).""" + from collections import defaultdict + + # Group rewards by prompt + prompt_groups: dict[str, list[float]] = defaultdict(list) + for prompt, reward in zip(prompts, rewards, strict=False): + prompt_groups[prompt].append(reward) + + group_count = len(prompt_groups) + expected_groups = cfg.max_groups + step_index = self._step_count + 1 + print( + "[TrainingPassAtKTracker] " + f"Step {step_index}: grouped {group_count} prompts " + f"(max_groups={expected_groups})" + ) + if group_count != expected_groups: + print( + "[TrainingPassAtKTracker] ⚠️ " + f"group_count ({group_count}) != max_groups ({expected_groups})" + ) + + # Use KMetricsAggregator for metrics computation + aggregator = KMetricsAggregator(report_ks=self._report_ks) + threshold = self._adapter.success_threshold + + for task_id, group_rewards in enumerate(prompt_groups.values()): + successes = [r >= threshold for r in group_rewards] + aggregator.add_group( + task_id=str(task_id), + rewards=group_rewards, + successes=successes, + ) + + return aggregator.summarize() + + def _log_to_wandb(self, metrics: dict[str, float]) -> None: + """Log metrics to WandB.""" + try: + import wandb + + if wandb.run is not None and metrics: + wandb_data = {f"train/{k}": v for k, v in metrics.items()} + wandb.log(wandb_data) + except Exception: + pass # Silently ignore WandB errors + + +# ════════════════════════════════════════════════════════════════════════════ +# DATA PREPARATION +# ════════════════════════════════════════════════════════════════════════════ +def prepare_train_dataset(adapter: DatasetAdapter, tokenizer: PreTrainedTokenizer) -> Dataset: + """Load and format 
training dataset for TRL GRPO. + + Args: + adapter: Dataset adapter instance + tokenizer: Tokenizer for chat template formatting + + Returns: + HuggingFace Dataset with 'prompt' and 'gold_answer' columns + """ + raw_data = adapter.load_train_data() + + if cfg.num_train_samples is not None: + raw_data = raw_data[: cfg.num_train_samples] + + formatted = [] + for sample in raw_data: + question = sample[adapter.question_field] + prompt = tokenizer.apply_chat_template( + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ], + tokenize=False, + add_generation_prompt=True, + ) + formatted.append( + { + "prompt": prompt, + "gold_answer": sample[adapter.answer_field], + } + ) + + print(f" Training dataset ({adapter.name}): {len(formatted)} samples") + return Dataset.from_list(formatted) + + +def prepare_eval_dataset(adapter: DatasetAdapter) -> tuple[Dataset, list[dict[str, Any]]]: + """Load evaluation dataset. + + Args: + adapter: Dataset adapter instance + + Returns: + Tuple of (HuggingFace Dataset, raw data list for reward computation) + """ + raw_data = adapter.load_eval_data() + + if cfg.num_eval_samples is not None: + raw_data = raw_data[: cfg.num_eval_samples] + + print(f" Eval dataset ({adapter.name}): {len(raw_data)} samples") + return Dataset.from_list(raw_data), raw_data + + +# ════════════════════════════════════════════════════════════════════════════ +# VLLM EVALUATION CALLBACK +# ════════════════════════════════════════════════════════════════════════════ +class VLLMEvalCallback(TrainerCallback): + """Evaluation callback using TRL vLLM server with dataset adapter.""" + + def __init__( + self, + adapter: DatasetAdapter, + eval_data: list[dict[str, Any]], + tokenizer: PreTrainedTokenizer, + vllm_base_url: str, + eval_every_n_steps: int = 40, + ) -> None: + self.adapter = adapter + self.eval_data = eval_data + self.tokenizer = tokenizer + self.eval_every_n = eval_every_n_steps + self.base_url = vllm_base_url.rstrip("/") + self._wandb_configured = False + + print( + f"✓ VLLMEvalCallback initialized: dataset={adapter.name}, " + f"url={vllm_base_url}, eval_every={eval_every_n_steps}" + ) + + def run_and_log(self, step: int, label: str = "VLLM EVAL") -> dict[str, float]: + """Run evaluation and log to WandB.""" + print(f"\n{'=' * 80}") + print(f"[{label}] Step {step}: Starting {self.adapter.name.upper()} evaluation...") + print(f"{'=' * 80}") + + metrics = asyncio.run(self._run_eval()) + + try: + import wandb + + if wandb.run is not None: + # Configure step metric for eval on first call + if not self._wandb_configured: + wandb.define_metric("eval_step") + wandb.define_metric("eval/*", step_metric="eval_step") + self._wandb_configured = True + + # Log eval metrics with 'eval/' prefix and custom step + wandb_data = {"eval_step": step} + wandb_data.update({f"eval/{k}": v for k, v in metrics.items()}) + wandb.log(wandb_data) + except Exception as e: + print(f"⚠️ WandB logging failed: {e}") + + print(f"[{label}] Results: {metrics}") + print(f"{'=' * 80}\n") + return metrics + + def on_step_end(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None: + """Run evaluation every N steps.""" + if state.global_step >= self.eval_every_n and state.global_step % self.eval_every_n == 0: + self.run_and_log(state.global_step) + + async def _run_eval(self) -> dict[str, float]: + """Run evaluation using vLLM chat completions API.""" + import time + + from tqdm import tqdm + + start_time = time.time() + aggregator = KMetricsAggregator(report_ks=cfg.report_ks) 
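+        # Eval flow: each task below is expanded into cfg.eval_replicates
+        # completions, every completion is scored with adapter.compute_reward,
+        # and each score is recorded as a TaskReplicateResult so the
+        # aggregator can summarize pass@k / mean@k / best@k over report_ks.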
+ + total_tasks = len(self.eval_data) + batch_size = cfg.eval_batch_size + + with tqdm(total=total_tasks, desc=f"Eval ({self.adapter.name})", unit="task") as pbar: + for batch_start in range(0, total_tasks, batch_size): + batch_end = min(batch_start + batch_size, total_tasks) + batch = self.eval_data[batch_start:batch_end] + + # Get questions using adapter's field name + batch_questions = [s[self.adapter.question_field] for s in batch] + batch_golds = [s[self.adapter.answer_field] for s in batch] + + # Expand: each question gets N replicates + tasks_to_generate = [] + task_metadata = [] + + for idx, question in enumerate(batch_questions): + task_id = f"q{batch_start + idx}" + for rep_idx in range(cfg.eval_replicates): + tasks_to_generate.append(question) + task_metadata.append( + { + "task_id": task_id, + "task_idx": idx, + "replicate_idx": rep_idx, + } + ) + + # Generate completions + completions = await self._generate_batch(tasks_to_generate) + + # Log sample completions + if batch_start == 0: + print("\n ━━━ Sample Completions ━━━") + for i in range(min(3, len(completions))): + question = tasks_to_generate[i] + completion = completions[i] + metadata = task_metadata[i] + gold = batch_golds[metadata["task_idx"]] + reward = self.adapter.compute_reward(completion, gold) + + q_display = question[:150] + "..." if len(question) > 150 else question + c_display = ( + completion[:300] + "..." if len(completion) > 300 else completion + ) + print(f"\n Sample {i + 1}:") + print(f" Question: {q_display}") + print(f" Completion: {c_display}") + print(f" Reward: {reward:.3f} | Gold: {gold[:50]}...") + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━\n") + + # Compute rewards and aggregate + for completion_text, metadata in zip(completions, task_metadata, strict=False): + task_id = metadata["task_id"] + task_idx = metadata["task_idx"] + replicate_idx = metadata["replicate_idx"] + gold = batch_golds[task_idx] + + reward = self.adapter.compute_reward(completion_text, gold) + success = reward >= self.adapter.success_threshold + + aggregator.add( + TaskReplicateResult( + task_id=task_id, + replicate_idx=replicate_idx, + reward=reward, + success=success, + ) + ) + + pbar.update(len(batch_questions)) + + metrics = aggregator.summarize() + elapsed = time.time() - start_time + throughput = (total_tasks * cfg.eval_replicates) / elapsed if elapsed > 0 else 0 + + print( + f" ✓ Evaluated {total_tasks} tasks × {cfg.eval_replicates} reps in {elapsed:.2f}s " + f"({throughput:.1f} completions/sec)" + ) + + return metrics + + async def _generate_batch(self, questions: list[str]) -> list[str]: + """Generate completions using TRL /chat/ endpoint with batching.""" + import asyncio + + import aiohttp + + vllm_batch_size = 64 + total = len(questions) + num_requests = (total + vllm_batch_size - 1) // vllm_batch_size + print(f" Generating {total} completions via {num_requests} batched requests") + + async def generate_batch_request( + session: aiohttp.ClientSession, batch_questions: list[str], start_idx: int + ) -> tuple[int, list[list[int]]]: + max_retries = 3 + base_backoff = 1.0 + + messages = [ + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": q}, + ] + for q in batch_questions + ] + + payload = { + "messages": messages, + "max_tokens": cfg.max_new_tokens, + "temperature": cfg.temperature, + "top_p": cfg.top_p, + "top_k": cfg.top_k, + "repetition_penalty": 1.1, + "n": 1, + } + + for attempt in range(max_retries): + try: + async with session.post( + f"{self.base_url}/chat/", + json=payload, + 
timeout=aiohttp.ClientTimeout(total=300.0), + ) as response: + if response.status == 200: + data = await response.json() + return (start_idx, data["completion_ids"]) + else: + error_text = await response.text() + raise Exception(f"HTTP {response.status}: {error_text}") + except Exception as e: + if attempt < max_retries - 1: + backoff = base_backoff * (2**attempt) + await asyncio.sleep(backoff) + else: + print(f" ⚠️ Batch {start_idx} failed: {type(e).__name__}") + return (start_idx, [[] for _ in batch_questions]) + return (start_idx, [[] for _ in batch_questions]) + + async with aiohttp.ClientSession() as session: + tasks = [] + for batch_start in range(0, total, vllm_batch_size): + batch_end = min(batch_start + vllm_batch_size, total) + batch_questions = questions[batch_start:batch_end] + tasks.append(generate_batch_request(session, batch_questions, batch_start)) + + results = await asyncio.gather(*tasks, return_exceptions=False) + + all_completion_ids: list[list[int]] = [[] for _ in range(total)] + for start_idx, completion_ids_batch in results: + for offset, comp_ids in enumerate(completion_ids_batch): + all_completion_ids[start_idx + offset] = comp_ids + + completions = [] + for comp_ids in all_completion_ids: + if comp_ids: + completion_text = self.tokenizer.decode(comp_ids, skip_special_tokens=True) + completions.append(completion_text) + else: + completions.append("") + + return completions + + +# ════════════════════════════════════════════════════════════════════════════ +# MAIN TRAINING +# ════════════════════════════════════════════════════════════════════════════ +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="TRL GRPO training with GSM8K or MATH dataset") + parser.add_argument( + "--dataset", + type=str, + default="gsm8k", + choices=["gsm8k", "math"], + help="Dataset to use for training (default: gsm8k)", + ) + parser.add_argument( + "--eval-every", + type=int, + default=40, + help="Run evaluation every N steps (default: 30)", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + print(f"🚀 Starting TRL GRPO training with {args.dataset.upper()} dataset") + print("=" * 80) + + # Print hyperparameter alignment summary + print("\n📋 GRAIL Hyperparameter Alignment Summary:") + print("─" * 80) + print(f" {'Parameter':<40} {'Value':<15} {'GRAIL Env Var'}") + print("─" * 80) + print(f" {'Model ID':<40} {cfg.model_id:<15} GRAIL_TRAIN_MODEL_ID") + print(f" {'Learning Rate':<40} {cfg.lr:<15} GRAIL_TRAINER_LR") + print(f" {'Epochs (per window)':<40} {cfg.epochs:<15} GRAIL_TRAINER_EPOCHS") + print(f" {'Batch Size':<40} {cfg.batch_size:<15} GRAIL_TRAINER_BATCH_SIZE") + print( + f" {'Gradient Accum Steps':<40} {cfg.grad_accum_steps:<15} GRAIL_TRAINER_GRAD_ACCUM_STEPS" + ) + print(f" {'Max Length':<40} {cfg.max_length:<15} GRAIL_TRAINER_MAX_LENGTH") + print(f" {'Max Completion Tokens':<40} {cfg.max_new_tokens:<15} GRPO_MAX_COMPLETION_TOKENS") + print(f" {'Gradient Clip':<40} {cfg.grad_clip:<15} GRAIL_TRAINER_GRAD_CLIP") + print(f" {'Warmup Steps':<40} {cfg.warmup_steps:<15} GRAIL_TRAINER_WARMUP_STEPS") + print(f" {'Total Steps':<40} {cfg.total_steps:<15} GRAIL_TRAINER_TOTAL_STEPS") + print(f" {'KL Coefficient':<40} {cfg.kl_coef:<15} GRAIL_TRAINER_KL_COEF") + print(f" {'Entropy Coefficient':<40} {cfg.entropy_coef:<15} GRAIL_TRAINER_ENTROPY_COEF") + print(f" {'PPO Clip Epsilon':<40} {cfg.ppo_clip_eps:<15} TRAINER_PPO_CLIP_EPS") + print( + f" {'PPO Clip Epsilon Upper':<40} 
{cfg.ppo_clip_eps_upper:<15} TRAINER_PPO_CLIP_EPS_UPPER" + ) + print(f" {'IS Ratio Max':<40} {cfg.is_ratio_max:<15} GRAIL_TRAINER_IS_RATIO_MAX") + print(f" {'Log-Ratio Clamp':<40} {cfg.logratio_clamp:<15} GRAIL_TRAINER_LOGRATIO_CLAMP") + print(f" {'GRPO Variant':<40} {cfg.grpo_variant:<15} GRAIL_GRPO_VARIANT") + print(f" {'IS Level':<40} {cfg.importance_sampling_level:<15} GRAIL_IMPORTANCE_SAMPLING_LEVEL") + print(f" {'Max Groups':<40} {cfg.max_groups:<15} GRPO_MAX_GROUPS") + print(f" {'Rollouts per Problem':<40} {cfg.rollouts_per_problem:<15} ROLLOUTS_PER_PROBLEM") + print("─" * 80) + + # Get dataset adapter + adapter = get_dataset_adapter(args.dataset) + print("\n📚 Dataset Configuration:") + print(f" Dataset: {adapter.name}") + print(f" Correctness weight: {adapter.correctness_weight}") + print(f" Success threshold: {adapter.success_threshold}") + + # Load model and tokenizer + print("\n📦 Loading model and tokenizer...") + try: + model = AutoModelForCausalLM.from_pretrained( + cfg.model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ) + except (ImportError, RuntimeError) as e: + print(f"⚠️ Flash Attention 2 unavailable ({type(e).__name__}), using default") + model = AutoModelForCausalLM.from_pretrained( + cfg.model_id, + torch_dtype=torch.bfloat16, + ) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + tokenizer.chat_template = QWEN_CHAT_TEMPLATE + + # Prepare datasets + print("\n📊 Preparing datasets...") + train_ds = prepare_train_dataset(adapter, tokenizer) + eval_ds, eval_data = prepare_eval_dataset(adapter) + prompt_to_answer = {row["prompt"]: row["gold_answer"] for row in train_ds} + + # WandB setup + print("\n⚙️ Configuring GRPO trainer...") + import wandb + + wandb_api_key = os.getenv("WANDB_API_KEY") + if wandb_api_key: + wandb.login(key=wandb_api_key) + print(f" ✓ WandB logged in (project: {os.getenv('WANDB_PROJECT', 'grail')})") + + # Calculate max_prompt_length (GRAIL_TRAINER_MAX_LENGTH - GRPO_MAX_COMPLETION_TOKENS) + max_prompt_length = cfg.max_length - cfg.max_new_tokens + + # Calculate training schedule + # Each optimizer step = generation_batch_size = effective_batch = 512 samples + # = 32 groups × 16 rollouts + effective_batch = cfg.batch_size * cfg.grad_accum_steps # 4 × 128 = 512 + groups_per_step = effective_batch // cfg.rollouts_per_problem # 512 / 16 = 32 + total_optimizer_steps = cfg.total_steps # Fixed: maintains original training duration + + print("\n📊 Training Schedule:") + print(f" • Effective batch size: {effective_batch} samples") + print(f" • Groups per optimizer step: {groups_per_step}") + print(f" • Rollouts per group: {cfg.rollouts_per_problem}") + print(f" • Total optimizer steps: {total_optimizer_steps}") + + grpo_config = GRPOConfig( + output_dir=f"./outputs/trl_{adapter.name}_final", + # ───────────────────────────────────────────────────────────────────── + # Learning Rate & Schedule (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + learning_rate=cfg.lr, # GRAIL_TRAINER_LR + warmup_steps=cfg.warmup_steps, # GRAIL_TRAINER_WARMUP_STEPS + lr_scheduler_type="cosine", # Cosine annealing (matches grail/neurons/trainer.py) + # Use max_steps to control iterations (matching GRAIL_TRAINER_TOTAL_WINDOWS) + # num_train_epochs is ignored when max_steps is set + num_train_epochs=cfg.epochs, + max_steps=total_optimizer_steps, # Calculated from total_windows + # 
───────────────────────────────────────────────────────────────────── + # Batch Size & Gradient Accumulation (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + per_device_train_batch_size=cfg.batch_size, # GRAIL_TRAINER_BATCH_SIZE + gradient_accumulation_steps=cfg.grad_accum_steps, # GRAIL_TRAINER_GRAD_ACCUM_STEPS + max_grad_norm=cfg.grad_clip, # GRAIL_TRAINER_GRAD_CLIP + # ───────────────────────────────────────────────────────────────────── + # GRPO Loss Configuration (matching grail/trainer/algorithms/grpo.py) + # ───────────────────────────────────────────────────────────────────── + beta=cfg.kl_coef, # GRAIL_TRAINER_KL_COEF (KL divergence coefficient) + epsilon=cfg.ppo_clip_eps, # TRAINER_PPO_CLIP_EPS (lower clip bound) + epsilon_high=cfg.ppo_clip_eps_upper, # TRAINER_PPO_CLIP_EPS_UPPER (DAPO asymmetric) + loss_type=cfg.grpo_variant, # GRAIL_GRPO_VARIANT ("dapo") + # ───────────────────────────────────────────────────────────────────── + # Sequence Length (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + max_prompt_length=max_prompt_length, # max_length - max_completion_tokens + max_completion_length=cfg.max_new_tokens, # GRPO_MAX_COMPLETION_TOKENS + # ───────────────────────────────────────────────────────────────────── + # Importance Sampling Level + # ───────────────────────────────────────────────────────────────────── + importance_sampling_level=cfg.importance_sampling_level, # GRAIL_IMPORTANCE_SAMPLING_LEVEL + # ───────────────────────────────────────────────────────────────────── + # Generation Parameters + # ───────────────────────────────────────────────────────────────────── + temperature=cfg.temperature, + top_p=cfg.top_p, + top_k=cfg.top_k, + repetition_penalty=1.1, + num_generations=cfg.rollouts_per_problem, # ROLLOUTS_PER_PROBLEM + # generation_batch_size must equal effective_batch to ensure: + # - One generation per optimizer step (no stale advantages) + # - 32 groups × 16 rollouts = 512 samples per optimizer update + generation_batch_size=cfg.batch_size * cfg.grad_accum_steps, # 4 × 128 = 512 + # ───────────────────────────────────────────────────────────────────── + # Logging & Checkpointing + # ───────────────────────────────────────────────────────────────────── + logging_steps=1, + log_completions=True, + num_completions_to_print=1, + wandb_log_unique_prompts=True, + save_strategy="steps", + save_steps=40, + bf16=True, + report_to=["wandb"], + eval_strategy="no", + run_name=f"trl_{adapter.name}_grpo_qwen15b_grail_matched_final", + # ───────────────────────────────────────────────────────────────────── + # vLLM Configuration + # ───────────────────────────────────────────────────────────────────── + use_vllm=True, + vllm_mode="server", + vllm_server_base_url="http://127.0.0.1:8000", + vllm_importance_sampling_correction=False, + vllm_importance_sampling_cap=cfg.is_ratio_max, # GRAIL_TRAINER_IS_RATIO_MAX + ) + + # Create reward tracker with pass@k logging + reward_tracker = TrainingPassAtKTracker( + adapter=adapter, + prompt_to_answer=prompt_to_answer, + report_ks=cfg.report_ks, + ) + print(f" ✓ TrainingPassAtKTracker initialized (report_ks={cfg.report_ks})") + + print(f"\n🏋️ Training with GRPO on {adapter.name.upper()}...") + + # Initialize evaluation callback + vllm_eval_callback = VLLMEvalCallback( + adapter=adapter, + eval_data=eval_data, + tokenizer=tokenizer, + vllm_base_url=grpo_config.vllm_server_base_url, + eval_every_n_steps=args.eval_every, + 
) + + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_tracker, + args=grpo_config, + train_dataset=train_ds, + processing_class=tokenizer, + callbacks=[vllm_eval_callback], + ) + + # Initialize WandB explicitly before baseline eval (GRPOTrainer does it lazily in .train()) + import wandb + + if wandb.run is None and grpo_config.report_to and "wandb" in grpo_config.report_to: + wandb.init( + project=os.getenv("WANDB_PROJECT", "grail"), + name=grpo_config.run_name, + config=grpo_config.to_dict(), + ) + print(" ✓ WandB initialized explicitly for baseline eval") + + # Baseline evaluation + vllm_eval_callback.run_and_log(step=0, label="BASELINE EVAL") + + # Train + trainer.train() + + # Final evaluation + final_step = trainer.state.global_step if hasattr(trainer, "state") else 9999 + final_metrics = vllm_eval_callback.run_and_log(step=final_step, label="FINAL EVAL") + + # Print summary + print("\n" + "=" * 60) + print(f"FINAL RESULTS SUMMARY ({adapter.name.upper()})") + print("=" * 60) + for k in cfg.report_ks: + if k > cfg.eval_replicates: + continue + print(f"\nMetrics @ k={k}:") + print(f" pass@{k}: {final_metrics[f'pass@{k}']:.3f}") + print(f" pass_ordered@{k}: {final_metrics[f'pass_ordered@{k}']:.3f}") + print(f" mean@{k}: {final_metrics[f'mean@{k}']:.3f}") + print(f" best@{k}: {final_metrics[f'best@{k}']:.3f}") + print("\nGlobal metrics:") + print(f" reward_mean_all: {final_metrics['reward_mean_all']:.3f}") + print(f" success_rate_all: {final_metrics['success_rate_all']:.3f}") + + +if __name__ == "__main__": + main() diff --git a/research/trl/train_trl_grpo_README.md b/research/trl/train_trl_grpo_README.md new file mode 100644 index 0000000..12d4b53 --- /dev/null +++ b/research/trl/train_trl_grpo_README.md @@ -0,0 +1,136 @@ +# TRL GRPO Training Script + +Unified TRL GRPO training script supporting both GSM8K and MATH (Hendrycks) datasets with exact parity to GRAIL environment implementations. + +## Quickstart + +### 1. Launch the vLLM Server (Generation GPUs) + +The vLLM server handles rollout generation on separate GPUs while the trainer runs on its own GPU. + +```bash +# Activate vLLM environment +source tools/vllm-server/.venv/bin/activate + +# Launch vLLM server on GPUs 1-4 (4-way tensor parallel) +CUDA_VISIBLE_DEVICES=1,2,3,4 nohup trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --tensor-parallel-size 4 \ + --host 127.0.0.1 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + > vllm_server.log 2>&1 & + +# Wait for server to be ready (check logs) +tail -f vllm_server.log +``` + +### 2. Start GRPO Training (Training GPU) + +```bash +# Train on GSM8K (default) +CUDA_VISIBLE_DEVICES=0 nohup python research/trl/train_trl_grpo.py \ + --dataset gsm8k \ + > research/trl/train_gsm8k.log 2>&1 & + +# Train on MATH (Hendrycks) +CUDA_VISIBLE_DEVICES=0 nohup python research/trl/train_trl_grpo.py \ + --dataset math \ + > research/trl/train_math.log 2>&1 & + +# Custom eval frequency +CUDA_VISIBLE_DEVICES=0 python research/trl/train_trl_grpo.py \ + --dataset math \ + --eval-every 50 +``` + +Training logs stream to the respective log files. 
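For quick monitoring, the commands below are a convenience sketch based on the log files and WandB settings already used above (the project name is read from `WANDB_PROJECT` in `.env`, defaulting to `grail`):

```bash
# Follow the trainer (GSM8K run from step 2)
tail -f research/trl/train_gsm8k.log

# Follow rollout generation on the vLLM server (step 1)
tail -f vllm_server.log
```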
+ +## Features + +- **Factory Pattern**: Easy switching between datasets via `--dataset` CLI flag +- **GRAIL Parity**: Uses exact same task sources, validation logic, and reward weights +- **Multi-Strategy Validation** (MATH): Exact match → Symbolic (sympy) → Numeric +- **Stratified Splits** (MATH): 7,000 train / 500 val (stratified by subject) +- **vLLM Evaluation**: Async batched evaluation with KMetrics aggregation + +## Dataset Comparison + +| Aspect | GSM8K | MATH | +|--------|-------|------| +| **Train Size** | 7,473 | 7,000 | +| **Eval Size** | 1,319 (test) | 500 (stratified val) | +| **Gold Format** | `#### answer` | `\boxed{answer}` | +| **Validation** | Numeric exact | Multi-strategy (exact/sympy/numeric) | +| **Correctness Weight** | 0.6 | 0.7 | +| **Success Threshold** | ≥0.6 | ≥0.7 | + +## Reward Components + +### GSM8K (Total: 1.0) +| Component | Weight | Description | +|-----------|--------|-------------| +| Correctness | 0.6 | Exact numeric match | +| Strict format | 0.15 | Numeric-only + no trailing | +| Thinking | 0.1 | Has reasoning block | +| Answer | 0.1 | Has solution tags | +| No trailing | 0.05 | No text after answer | + +### MATH (Total: 1.0) +| Component | Weight | Description | +|-----------|--------|-------------| +| Correctness | 0.7 | Multi-strategy validation | +| Answer format | 0.15 | Has answer + trailing < 50 chars | +| Thinking | 0.1 | Has reasoning block | +| No trailing | 0.05 | No text after answer | + +## Hyperparameters (from .env) + +| Parameter | Value | Source | +|-----------|-------|--------| +| Learning rate | 3e-6 | `GRAIL_TRAINER_LR` | +| Epochs | 1 | `GRAIL_TRAINER_EPOCHS` | +| Batch size | 4 | `GRAIL_TRAINER_BATCH_SIZE` | +| Grad accum | 128 | `GRAIL_TRAINER_GRAD_ACCUM_STEPS` | +| Max length | 2048 | `GRAIL_TRAINER_MAX_LENGTH` | +| Max completion | 1024 | `GRPO_MAX_COMPLETION_TOKENS` | +| Loss type | dapo | `GRAIL_GRPO_VARIANT` | + +## Architecture + +``` +train_trl_grpo.py +├── DatasetAdapter (ABC) +│ ├── GSM8KAdapter # Uses GSM8KTaskSource from GRAIL +│ └── MATHAdapter # Uses MATHTaskSource from GRAIL +├── get_dataset_adapter() # Factory function +├── VLLMEvalCallback # Dataset-agnostic evaluation +└── main() # CLI entry point +``` + +## GPU Layout (Example: 8x A100) + +``` +┌─────────────────────────────────────────────────────────────┐ +│ GPU 0: Training (GRPO backward pass) │ +├─────────────────────────────────────────────────────────────┤ +│ GPUs 1-4: vLLM Server (4-way tensor parallel generation) │ +├─────────────────────────────────────────────────────────────┤ +│ GPUs 5-7: Available for other tasks │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Requirements + +- TRL with vLLM support (`pip install trl[vllm]`) +- GRAIL codebase (for task sources and validation logic) +- vLLM server running on port 8000 +- Flash Attention 2 (optional, for faster training) + +## Files + +| File | Description | +|------|-------------| +| `train_trl_grpo.py` | Main training script (unified GSM8K + MATH) | +| `train_trl_grpo_README.md` | This documentation | +| `train_trl_gsm8k.py` | Legacy GSM8K-only script (deprecated) | diff --git a/research/trl/train_trl_gsm8k.py b/research/trl/train_trl_gsm8k.py index 7582671..5a8b401 100644 --- a/research/trl/train_trl_gsm8k.py +++ b/research/trl/train_trl_gsm8k.py @@ -32,35 +32,55 @@ sys.path.append("/root/grail") -# ──────────────── HYPERPARAMETERS (from GRAIL config) ──────────────── +# ──────────────── HYPERPARAMETERS (from .env GRAIL config) ──────────────── @dataclass class 
Config: + # Model (from GRAIL_TRAIN_MODEL_ID) model_id: str = "Qwen/Qwen2.5-1.5B-Instruct" - lr: float = 2e-6 - epochs: int = 2 - batch_size: int = 8 # 16 groups (prompts) per step - grad_accum_steps: int = 32 - max_length: int = 1536 + # Learning rate (from GRAIL_TRAINER_LR) + lr: float = 3e-6 + # Epochs per window (from GRAIL_TRAINER_EPOCHS) + epochs: int = 1 + # Batch size (from GRAIL_TRAINER_BATCH_SIZE) + batch_size: int = 4 + # Gradient accumulation (from GRAIL_TRAINER_GRAD_ACCUM_STEPS) + grad_accum_steps: int = 128 + # Max sequence length (from GRAIL_TRAINER_MAX_LENGTH) + max_length: int = 2048 + # Gradient clipping (from GRAIL_TRAINER_GRAD_CLIP) grad_clip: float = 1.0 + # Warmup steps (from GRAIL_TRAINER_WARMUP_STEPS) warmup_steps: int = 50 + # KL coefficient (from GRAIL_TRAINER_KL_COEF) kl_coef: float = 0.0 + # Entropy coefficient (from GRAIL_TRAINER_ENTROPY_COEF) entropy_coef: float = 0.0005 + # PPO clip epsilon (standard GRAIL values) ppo_clip_eps: float = 0.2 ppo_clip_eps_upper: float = 0.28 + # Importance sampling ratio max (from GRAIL_TRAINER_IS_RATIO_MAX) is_ratio_max: float = 2.5 + # Log-ratio clamp (from GRAIL_TRAINER_LOGRATIO_CLAMP) logratio_clamp: float = 0.92 + # Dataset sampling num_train_samples: int | None = None # None = use all training samples num_eval_samples: int | None = None # None = use all test samples + # Rollouts per problem (matches GRAIL default) rollouts_per_problem: int = 16 + # Generation parameters temperature: float = 0.7 top_p: float = 0.95 top_k: int = 50 - max_new_tokens: int = 512 + # Max completion tokens (from GRPO_MAX_COMPLETION_TOKENS) + max_new_tokens: int = 1024 + # Evaluation config eval_replicates: int = 5 report_ks: tuple = (1, 5, 10) # Evaluation optimization (for multi-GPU with 8 A100s) eval_batch_size: int = 128 # Large batch for parallel generation eval_num_workers: int = 4 # Dataloader workers + # Max groups for GRPO (from GRPO_MAX_GROUPS) + max_groups: int = 128 cfg = Config() @@ -511,29 +531,43 @@ def main() -> None: wandb.login(key=wandb_api_key) print(f" ✓ WandB logged in (project: {os.getenv('WANDB_PROJECT', 'grail')})") + # Calculate max_prompt_length: total max_length minus max_completion_tokens + max_prompt_length = cfg.max_length - cfg.max_new_tokens # 2048 - 1024 = 1024 + grpo_config = GRPOConfig( output_dir="./outputs/trl_gsm8k", + # Learning rate (GRAIL_TRAINER_LR=3e-6) learning_rate=cfg.lr, + # Epochs (GRAIL_TRAINER_EPOCHS=1) num_train_epochs=cfg.epochs, + # Batch size (GRAIL_TRAINER_BATCH_SIZE=4) per_device_train_batch_size=cfg.batch_size, + # Gradient accumulation (GRAIL_TRAINER_GRAD_ACCUM_STEPS=128) gradient_accumulation_steps=cfg.grad_accum_steps, + # Gradient clipping (GRAIL_TRAINER_GRAD_CLIP=1.0) max_grad_norm=cfg.grad_clip, + # Warmup steps (GRAIL_TRAINER_WARMUP_STEPS=50) warmup_steps=cfg.warmup_steps, - beta=cfg.kl_coef, # Beta is KL coefficient in GRPO - epsilon=cfg.ppo_clip_eps, # PPO epsilon - epsilon_high=cfg.ppo_clip_eps_upper, # Upper PPO epsilon - max_prompt_length=512, # Reasonable prompt limit - max_completion_length=cfg.max_new_tokens, # Max new tokens + # KL coefficient (GRAIL_TRAINER_KL_COEF=0.0) + beta=cfg.kl_coef, + # PPO clip epsilon + epsilon=cfg.ppo_clip_eps, + epsilon_high=cfg.ppo_clip_eps_upper, + # Max prompt length (derived from GRAIL_TRAINER_MAX_LENGTH - GRPO_MAX_COMPLETION_TOKENS) + max_prompt_length=max_prompt_length, + # Max completion tokens (GRPO_MAX_COMPLETION_TOKENS=1024) + max_completion_length=cfg.max_new_tokens, + # Generation parameters temperature=cfg.temperature, 
top_p=cfg.top_p, - top_k=cfg.top_k, # Match loop.py: 50 highest probability tokens - repetition_penalty=1.1, # Match loop.py: penalize repeating tokens - num_generations=16, # group size: 16 completions per prompt - generation_batch_size=16, # 64 prompts per generation batch + top_k=cfg.top_k, + repetition_penalty=1.1, + # Group size: 16 completions per prompt (rollouts_per_problem) + num_generations=cfg.rollouts_per_problem, + generation_batch_size=16, steps_per_generation=None, logging_steps=1, - # Enable logging a small sample of (prompt, completion) pairs each logging step. - # Prints to console if `rich` is installed and logs a WandB table named "completions". + # Enable logging a sample of (prompt, completion) pairs each logging step log_completions=True, num_completions_to_print=1, wandb_log_unique_prompts=True, @@ -541,14 +575,16 @@ def main() -> None: bf16=True, report_to=["wandb"], eval_strategy="no", # Disable TRL's internal eval (using VLLMEvalCallback instead) - run_name="trl_gsm8k_grpo_qwen15b_g16x16_vllm", - loss_type="dapo", # Match config.py GRPO_VARIANT + run_name="trl_gsm8k_grpo_qwen15b_env_matched", + # Loss type (GRAIL_GRPO_VARIANT=dapo) + loss_type="dapo", # vLLM configuration for offloading generation to separate GPUs use_vllm=True, vllm_mode="server", vllm_server_base_url="http://127.0.0.1:8000", - vllm_importance_sampling_correction=False, # Correct for vLLM/training distribution mismatch - vllm_importance_sampling_cap=2.0, # Cap importance sampling ratio for stability + # Importance sampling (GRAIL_TRAINER_IS_RATIO_MAX=2.5) + vllm_importance_sampling_correction=False, + vllm_importance_sampling_cap=cfg.is_ratio_max, ) # Reward function wrapper diff --git a/tools/vllm-server/pyproject.toml b/tools/vllm-server/pyproject.toml index 01dc784..d5ec666 100644 --- a/tools/vllm-server/pyproject.toml +++ b/tools/vllm-server/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "openai>=1.0.0", "trl>=0.22.0", "wandb>=0.22.3", + # Evaluation harness with vLLM support + "lm-eval>=0.4.0", ] [tool.ruff]
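With `lm-eval` added to the vLLM server environment above, the custom pass@k task from `research/eval/tasks/hendrycks_math_thinking_pass_at_k/` could be exercised along these lines. This is a sketch, not a command from this PR: it assumes lm-eval ≥ 0.4's `--include_path` mechanism for external task YAMLs and its vLLM backend flags, so check `lm_eval --help` for the exact names.

```bash
source tools/vllm-server/.venv/bin/activate

# Hypothetical invocation: register the external task directory, then run the
# pass@k task defined in hendrycks_math_thinking_pass_at_5.yaml.
lm_eval \
  --model vllm \
  --model_args pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,tensor_parallel_size=4 \
  --include_path research/eval/tasks/hendrycks_math_thinking_pass_at_k \
  --tasks hendrycks_math_thinking_pass_at_5 \
  --batch_size auto \
  --output_path eval_results
```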