@erfanMhi erfanMhi commented Dec 24, 2025

This PR introduces significant enhancements to GRAIL's training, evaluation, and data mining capabilities. Key changes include:

  • Advanced TRL GRPO Training: Comprehensive training script with detailed hyperparameter configuration, pass@k metrics tracking, and improved WandB integration for GSM8K and MATH datasets
  • Extended Evaluation Support: New eval tasks for AIME 2024, AMC 2023, and GSM8K with reasoning support, plus vLLM backend integration with tensor parallelism for improved evaluation throughput
  • Parallel Mining Infrastructure: Multi-GPU mining functionality with result aggregation and optimized batch sizing (default 16) for enhanced performance and collision prevention

Summary by CodeRabbit

  • New Features

    • Added comprehensive math benchmark evaluation workflows for MATH, AIME, AMC, and GSM8K datasets with extraction and comparison utilities.
    • Added parallel multi-GPU mining coordination and result aggregation infrastructure.
    • Added pass@k metrics computation for MATH problem evaluation.
    • Added multi-miner orchestration and configuration helpers.
  • Improvements

    • Enhanced model loading with flexible attention mechanism selection (Flash Attention 2 and SDPA support).
    • Expanded training scripts with environment-driven configuration for GRPO training workflows.


coderabbitai bot commented Dec 24, 2025

Walkthrough

This PR introduces comprehensive parallel multi-GPU mining infrastructure, MATH benchmark evaluation frameworks, and training script modernization. New modules enable GPU worker coordination, multi-miner result aggregation, and extensible evaluation task definitions across multiple math datasets. Existing mining logic is enhanced with worker-mode support, and model loading is refactored to prefer Flash Attention 2 with SDPA as the fallback.

Changes

  • Parallel & Multi-Miner Mining: grail/cli/parallel_miner.py, grail/cli/multi_miner_aggregator.py, grail/cli/multi_miner_config.py. New modules for multi-GPU worker coordination, result aggregation across miners, and configuration builders. Includes GPUWorkerConfig, ParallelMiningCoordinator, WindowAggregator, and CLI integration via Typer.
  • Mining Core Changes: grail/cli/mine.py, grail/cli/__init__.py. Nonce calculation refactored to use the ROLLOUTS_PER_PROBLEM multiplier; worker-mode parameters (problem_offset, max_problems) added with environment-driven configuration; parallel_miner subcommand registered.
  • Evaluation Harness: eval_math_harness.py, eval_pass_at_k.py. New standalone evaluation scripts for Hendrycks MATH using lm-evaluation-harness with vLLM/HF backends, and pass@k computation with multi-sample generation.
  • Evaluation Task Infrastructure: research/eval/tasks/_common.py, research/eval/README.md. Comprehensive answer extraction (boxed, SOLUTION tags, dollar format), string normalization utilities, equivalence checks (string/numeric), and evaluation guide documentation.
  • Task Definitions (AIME): research/eval/tasks/aime24_thinking/aime24_thinking.yaml, research/eval/tasks/aime24_thinking/utils.py. YAML config and process_results utility for AIME 2024 thinking tasks with integer-aware answer comparison.
  • Task Definitions (AMC): research/eval/tasks/amc2023/amc2023.yaml, research/eval/tasks/amc2023/utils.py, research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml, research/eval/tasks/amc2023_thinking/utils.py. YAML configs and utilities for AMC 2023 standard and thinking-model evaluations with float-to-integer normalization.
  • Task Definitions (GSM8K): research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml, research/eval/tasks/gsm8k_thinking/utils.py. YAML config and doc_to_target/process_results for GSM8K thinking evaluation with final-answer extraction and numeric comparison.
  • Task Definitions (Hendrycks MATH): research/eval/tasks/hendrycks_math_thinking/_default_template.yaml, research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml, research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_*.yaml (7 subject configs), research/eval/tasks/hendrycks_math_thinking/utils.py. Master template, group config, per-subject YAML configs (algebra, counting_and_prob, geometry, intermediate_algebra, num_theory, prealgebra, precalc), and utilities for boxed-answer extraction and string equivalence checking.
  • Pass@K Evaluation: research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml, research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py. Config for the pass@5 metric with 10 samples per problem and utilities computing pass@1, pass@5, and pass@10 via the unbiased estimator.
  • Model & Environment Changes: grail/model/provider.py, grail/environments/providers.py. Added use_sdpa parameter with Flash Attention 2 priority-chain logic; MATHTaskSource reintroduced with subject/level filtering support and enhanced boxed-answer extraction.
  • Training Script Refactoring: research/trl/train_trl_grpo.py, research/trl/train_trl_grpo_README.md, research/trl/train_trl_gsm8k.py. New unified GRPO training script with DatasetAdapter pattern (GSM8K/MATH), reward tracking, and vLLM evaluation callback; the gsm8k script's Config is expanded with env-driven hyperparameters and increased model capacity (max_new_tokens 512→1024, max_length 1536→2048).
  • Config & Dependencies: grail/trainer/config.py, tools/vllm-server/pyproject.toml. EvalConfig.split comment updated for clarity; lm-eval>=0.4.0 dependency added.

Sequence Diagram(s)

sequenceDiagram
    participant Main as Coordinator
    participant GPU1 as GPU Worker 1
    participant GPU2 as GPU Worker 2
    participant Queue as Result Queue
    participant S3 as R2 Storage

    Main->>GPU1: spawn worker (problem_offset=0, max=N)
    Main->>GPU2: spawn worker (problem_offset=N, max=N)
    
    rect rgb(230, 240, 250)
    note over GPU1,GPU2: Parallel Rollout Generation
    GPU1->>GPU1: load model, run env loop<br/>generate rollouts [0, N)
    GPU2->>GPU2: load model, run env loop<br/>generate rollouts [N, 2N)
    end
    
    GPU1->>GPU1: package results<br/>write to JSON
    GPU2->>GPU2: package results<br/>write to JSON
    
    GPU1->>Queue: signal completion
    GPU2->>Queue: signal completion
    
    Main->>Queue: wait_for_workers()
    Main->>Main: collect & sort results<br/>by problem_index
    
    rect rgb(240, 250, 240)
    note over Main: Result Validation & Upload
    Main->>Main: validate aggregated<br/>inferences
    Main->>S3: upload_file_chunked<br/>(window payload)
    end
    
    S3-->>Main: success
    Main->>GPU1: cleanup temp files
    Main->>GPU2: cleanup temp files

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • distributedstatemachine

Poem

🐰 Whiskers twitch with parallel glee,
Ten-thousand problems split by GPU spree!
Miners gather, aggregate their cheer,
Math benchmarks polished, results crystal clear. 🎯

Pre-merge checks

✅ Passed checks (3 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The pull request title clearly and concisely describes the main changes (extended evaluation support using the EleutherAI eval-harness and improved TRL training), which align directly with the PR objectives.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 82.61%, which meets the required threshold of 80.00%.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
research/trl/train_trl_gsm8k.py (1)

502-513: Fix incorrect parameter name: dtype should be torch_dtype.

AutoModelForCausalLM.from_pretrained() expects torch_dtype, not dtype. The current code at lines 505 and 512 will silently ignore the parameter, causing the model to load with default precision instead of bfloat16.

     try:
         model = AutoModelForCausalLM.from_pretrained(
             cfg.model_id,
-            dtype=torch.bfloat16,
+            torch_dtype=torch.bfloat16,
             attn_implementation="flash_attention_2",
         )
     except (ImportError, RuntimeError) as e:
         print(f"⚠️  Flash Attention 2 unavailable ({type(e).__name__}), using default attention")
         model = AutoModelForCausalLM.from_pretrained(
             cfg.model_id,
-            dtype=torch.bfloat16,
+            torch_dtype=torch.bfloat16,
         )
♻️ Duplicate comments (2)
research/trl/train_trl_grpo.py (2)

36-48: Same hardcoded path issue as in train_trl_gsm8k.py.

Consider parameterizing the paths or using relative imports for better portability across environments.


431-463: Same private member access issue as GSM8KAdapter.

The MATHAdapter also accesses _ensure_dataset and _data private members. The same coupling concern applies.

🧹 Nitpick comments (25)
research/eval/README.md (1)

14-16: Consider using relative paths or environment variables for portability.

Hardcoded paths like /root/grail/ appear throughout this documentation (lines 15, 45, 88, 92-100, 137, 151-159, 234, 273, 330, 371). This reduces portability for users with different installation locations.

Consider using:

  • Relative paths from the repo root
  • Environment variable placeholders like $GRAIL_ROOT
  • A note at the top instructing users to substitute their installation path
eval_pass_at_k.py (4)

56-63: Type hint inconsistency with default value.

The type hint list[dict] doesn't match the default None. Consider using list[dict] | None = None for accuracy.

-def build_prompt(problem: str, few_shot_examples: list[dict] = None) -> str:
+def build_prompt(problem: str, few_shot_examples: list[dict] | None = None) -> str:

28-34: Regex may miss deeply nested boxed content.

The regex r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}" handles one level of nesting but may fail on deeper nesting like \boxed{\frac{1}{\sqrt{2}}}.

There's a more robust _extract_boxed_answer implementation in grail/environments/providers.py (lines 180-198) and research/eval/tasks/_common.py that uses brace-counting for arbitrary nesting depth. Consider using the shared utility or aligning the implementations.


37-53: Duplication with shared utilities.

The normalize_answer and is_correct functions duplicate functionality available in research/eval/tasks/_common.py (strip_string_basic, is_equiv_string, is_equiv_combined). Consider importing from the shared module to maintain DRY principles.


82-91: High GPU memory utilization may cause OOM on smaller GPUs.

gpu_memory_utilization=0.95 leaves minimal headroom. Consider making this configurable via CLI argument or using a more conservative default (e.g., 0.90) for broader hardware compatibility.
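For illustration, exposing it as a flag could look roughly like this (the flag name and default below are suggestions, not part of the current script):

# Sketch only: flag name and default are assumptions, not the script's current interface.
import argparse

parser = argparse.ArgumentParser(description="pass@k evaluation")
parser.add_argument(
    "--gpu-memory-utilization",
    type=float,
    default=0.90,  # more conservative headroom than the current 0.95
    help="Fraction of GPU memory vLLM is allowed to allocate",
)
args = parser.parse_args()
# The value would then be forwarded to vLLM, e.g. LLM(..., gpu_memory_utilization=args.gpu_memory_utilization)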

grail/cli/parallel_miner.py (2)

363-378: Daemon processes may leave resources unclean on unexpected exit.

Using daemon=True (line 368) means workers are killed immediately when the main process exits, potentially leaving temp files. The cleanup in finally block (line 771) should handle this, but consider adding signal handlers for graceful shutdown.
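A minimal sketch of such a handler (cleanup_temp_files is a hypothetical method name; _shutdown_requested mirrors the attribute set in the coordinator's __init__):

import signal

def install_shutdown_handlers(coordinator) -> None:
    """Request a graceful shutdown on SIGINT/SIGTERM instead of relying on daemon teardown."""
    def _handler(signum, frame):
        coordinator._shutdown_requested = True  # attribute initialized in the coordinator's __init__
        coordinator.cleanup_temp_files()        # hypothetical cleanup hook
    signal.signal(signal.SIGINT, _handler)
    signal.signal(signal.SIGTERM, _handler)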


840-842: Batch size validation is overly restrictive.

The validation batch_size > 16 rejects valid batch sizes. If 16 is optimal for A100, larger values might work on H100 or future hardware. Consider warning instead of erroring, or document why 16 is the hard limit.

research/eval/tasks/_common.py (1)

311-311: Invalid escape sequence warning.

Line 311 has "\%" which triggers a W605 warning. The noqa comment suppresses it, but the string should be "\\%" for correctness (though functionally it may work due to Python's lenient handling).

-    string = string.replace("\%", "")  # noqa: W605
+    string = string.replace("\\%", "")
research/eval/tasks/amc2023_thinking/utils.py (3)

10-11: Consider using relative imports instead of sys.path manipulation.

The sys.path.insert() approach is fragile and can cause issues if the file is moved or imported from different contexts. Consider restructuring as a proper package with relative imports:

from .._common import extract_answer_cascade, is_equiv_combined, strip_string_basic

If the module structure doesn't support this yet, adding an __init__.py file in the parent directory would enable cleaner imports.


20-44: Consider adding validation for empty results list.

Line 22 directly accesses results[0] without checking if the list is non-empty. While this may not occur in practice with the evaluation harness, adding a guard would make the code more robust:

 def process_results(doc: dict, results: list[str]) -> dict[str, int]:
     """Process model output and compare with target answer."""
+    if not results:
+        return {"exact_match": 0}
+    
     response = results[0]

54-60: Consider using is_integer() for more robust float-to-int detection.

Line 57's comparison num == int(num) can have floating-point precision issues. Python's is_integer() method is more reliable:

     try:
         num = float(string)
-        if num == int(num):
+        if num.is_integer():
             string = str(int(num))
     except ValueError:
         pass

This avoids edge cases where floating-point representation might cause unexpected behavior.

research/eval/tasks/aime24_thinking/utils.py (2)

10-17: Consider using relative imports instead of sys.path manipulation.

The sys.path.insert(0, ...) pattern is fragile and can cause import conflicts if multiple modules modify sys.path. Since this is part of a package structure, consider using relative imports or restructuring to avoid path manipulation.

# Alternative: relative import if package structure allows
from . import _common
# or ensure the package is properly installed

This pattern appears across multiple evaluation utility files in this PR.


20-38: Consider defensive check for empty results list.

Accessing results[0] on line 25 will raise an IndexError if the list is empty. While the lm-eval framework likely guarantees at least one result, a defensive check would improve robustness:

 def process_results(doc: dict, results: list[str]) -> dict[str, int]:
     """Process model output and compare with target answer.

     AIME answers are always integers from 000-999.
     """
+    if not results:
+        return {"exact_match": 0}
     response = results[0]
research/eval/tasks/amc2023/utils.py (1)

32-33: Inconsistent answer key handling compared to AIME utilities.

This uses doc.get("answer", "") directly, while aime24_thinking/utils.py uses a case-insensitive lookup. If the dataset might have varying key casing (e.g., "Answer" vs "answer"), consider aligning the approach:

-    target = str(doc.get("answer", ""))
+    answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None)
+    target = str(doc[answer_key]) if answer_key else ""

Alternatively, if the AMC dataset consistently uses lowercase "answer", the current approach is acceptable.

research/eval/tasks/hendrycks_math_thinking/utils.py (1)

23-35: process_docs drops additional document fields that may be needed.

The mapping function returns only problem, solution, and answer, which discards other fields like type (subject) and level (difficulty) that might be valuable for per-subject or per-difficulty analysis mentioned in the PR objectives.

Consider preserving original fields:

     def _process_doc(doc: dict) -> dict:
         boxed = last_boxed_only_string(doc["solution"])
         answer = remove_boxed(boxed) if boxed else ""
         return {
+            **doc,  # Preserve original fields
             "problem": doc["problem"],
             "solution": doc["solution"],
             "answer": answer,
         }
research/eval/tasks/gsm8k_thinking/utils.py (1)

20-34: doc_to_target may raise KeyError if answer field is missing.

Unlike process_results which uses doc["answer"] after checking the #### format, doc_to_target accesses doc["answer"] directly on line 22. Consider using .get() with a default for defensive coding:

 def doc_to_target(doc: dict) -> str:
     """Convert document to target format for thinking."""
-    answer = doc["answer"]
+    answer = doc.get("answer", "")

However, if the GSM8K dataset schema guarantees this field, the current approach is acceptable.

eval_math_harness.py (2)

45-61: Subprocess fallback should handle errors more explicitly.

The nvidia-smi fallback silently returns 1 on any failure (including command not found). Consider logging a warning:

         result = subprocess.run(
             ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"],
             capture_output=True,
             text=True,
         )
-        return len(result.stdout.strip().split("\n")) if result.returncode == 0 else 1
+        if result.returncode != 0:
+            logger.warning("nvidia-smi failed, defaulting to 1 GPU")
+            return 1
+        return len(result.stdout.strip().split("\n"))

312-318: save_results mutates the input results dictionary.

Adding "metadata" to the input dict is a side effect that may surprise callers who don't expect their data to be modified. Consider creating a copy:

+    # Create copy to avoid mutating input
+    output = {**results}
     # Add metadata
-    results["metadata"] = {
+    output["metadata"] = {
         "model": args.model,
         "num_fewshot": args.num_fewshot,
         "batch_size": args.batch_size,
         "timestamp": timestamp,
     }

     with open(output_file, "w") as f:
-        json.dump(results, f, indent=2, default=str)
+        json.dump(output, f, indent=2, default=str)
grail/cli/multi_miner_config.py (2)

70-80: Potential issue: GPU assignment when gpus is an empty list.

Line 73's gpus[i % len(gpus)] would raise ZeroDivisionError if gpus were an empty list (as opposed to None). The current if gpus check already evaluates falsy for an empty list, but for safety:

-            gpu = gpus[i % len(gpus)] if gpus else None
+            gpu = gpus[i % len(gpus)] if gpus and len(gpus) > 0 else None

Or ensure the parsing logic in from_environment never produces an empty list (it currently sets gpus = None when empty, which is correct).


37-43: Minor: Consider moving time import to module level.

The import time inside __post_init__ works but is unconventional. Moving it to the module-level imports improves clarity and avoids repeated import overhead if the method is called multiple times.

grail/cli/multi_miner_aggregator.py (3)

170-181: Synchronous file I/O blocks the event loop in async method.

_load_and_parse uses synchronous open() and json.load() inside an async context. While the file sizes are likely small, this blocks the event loop during file operations.

🔎 Proposed fix using asyncio.to_thread
     async def _load_and_parse(self, result_file: Path, hotkey: str) -> list[dict]:
         """Load and parse inferences from result file."""
+        def _read_file() -> list[dict]:
+            with open(result_file) as f:
+                data = json.load(f)
+                inferences = data.get("inferences", [])
+                if not isinstance(inferences, list):
+                    raise ValueError(f"Expected list of inferences, got {type(inferences)}")
+                return inferences
+
         try:
-            with open(result_file) as f:
-                data = json.load(f)
-                inferences = data.get("inferences", [])
-                if not isinstance(inferences, list):
-                    raise ValueError(f"Expected list of inferences, got {type(inferences)}")
-                return inferences
+            return await asyncio.to_thread(_read_file)
         except Exception as e:
             logger.debug(f"Failed to parse {result_file}: {e}")
             raise

214-231: Consider caching the subtensor instance outside the loop.

A new bt.subtensor() is instantiated on every poll iteration (default 5s). This may create unnecessary connection overhead depending on the implementation.

🔎 Proposed fix
         try:
+            subtensor = bt.subtensor()
             while not self.stop_event.is_set():
                 # Get current window
-                subtensor = bt.subtensor()
                 current_block = await asyncio.to_thread(subtensor.get_current_block)
                 current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH

340-344: Consider preserving exception context for debugging.

Using from None suppresses the original exception chain. For credential loading failures, the original exception often contains useful context (e.g., missing file path, permission errors).

🔎 Proposed fix
         try:
             credentials = load_r2_credentials()
         except Exception as e:
             console.print(f"[red]Failed to load R2 credentials: {e}[/red]")
-            raise typer.Exit(code=1) from None
+            raise typer.Exit(code=1) from e
research/trl/train_trl_gsm8k.py (1)

26-32: Hardcoded paths reduce portability.

The hardcoded paths /root/grail/.env and /root/grail assume a specific deployment environment. For a research script this may be acceptable, but consider using environment variables or relative paths for broader compatibility.

🔎 Proposed fix using relative paths
-# Load environment from .env for WandB
-load_dotenv("/root/grail/.env")  # Load WandB API key and project
+# Load environment from .env for WandB (try common locations)
+from pathlib import Path
+_script_dir = Path(__file__).resolve().parent
+_dotenv_candidates = [
+    _script_dir / ".env",
+    _script_dir.parent.parent / ".env",  # repo root
+    Path.home() / "grail" / ".env",
+]
+for _env_path in _dotenv_candidates:
+    if _env_path.exists():
+        load_dotenv(_env_path)
+        break

-sys.path.append("/root/grail")
+sys.path.insert(0, str(_script_dir.parent.parent))  # Add repo root to path
research/trl/train_trl_grpo.py (1)

681-691: Silent exception swallowing may hide logging issues.

The bare pass after catching any exception hides all WandB errors. Consider at least logging at debug level to aid troubleshooting when metrics unexpectedly don't appear.

🔎 Proposed fix
     def _log_to_wandb(self, metrics: dict[str, float]) -> None:
         """Log metrics to WandB."""
         try:
             import wandb

             if wandb.run is not None and metrics:
                 wandb_data = {f"train/{k}": v for k, v in metrics.items()}
                 wandb.log(wandb_data)
-        except Exception:
-            pass  # Silently ignore WandB errors
+        except Exception as e:
+            # Log at debug level to aid troubleshooting without flooding output
+            import logging
+            logging.getLogger(__name__).debug(f"WandB logging failed: {e}")
📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 73e4de5 and 7f700f2.

📒 Files selected for processing (36)
  • eval_math_harness.py
  • eval_pass_at_k.py
  • grail/cli/__init__.py
  • grail/cli/mine.py
  • grail/cli/multi_miner_aggregator.py
  • grail/cli/multi_miner_config.py
  • grail/cli/parallel_miner.py
  • grail/environments/providers.py
  • grail/model/provider.py
  • grail/trainer/config.py
  • research/eval/README.md
  • research/eval/tasks/_common.py
  • research/eval/tasks/aime24_thinking/aime24_thinking.yaml
  • research/eval/tasks/aime24_thinking/utils.py
  • research/eval/tasks/amc2023/amc2023.yaml
  • research/eval/tasks/amc2023/utils.py
  • research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml
  • research/eval/tasks/amc2023_thinking/utils.py
  • research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml
  • research/eval/tasks/gsm8k_thinking/utils.py
  • research/eval/tasks/hendrycks_math_thinking/_default_template.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml
  • research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml
  • research/eval/tasks/hendrycks_math_thinking/utils.py
  • research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml
  • research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py
  • research/trl/train_trl_grpo.py
  • research/trl/train_trl_grpo_README.md
  • research/trl/train_trl_gsm8k.py
  • tools/vllm-server/pyproject.toml
🧰 Additional context used
🧬 Code graph analysis (5)
research/eval/tasks/gsm8k_thinking/utils.py (1)
research/eval/tasks/_common.py (2)
  • extract_solution_tag (17-29)
  • is_equiv_combined (484-523)
research/eval/tasks/aime24_thinking/utils.py (1)
research/eval/tasks/_common.py (3)
  • extract_answer_cascade (119-165)
  • is_equiv_combined (484-523)
  • strip_string_basic (173-199)
research/eval/tasks/amc2023_thinking/utils.py (1)
research/eval/tasks/_common.py (3)
  • extract_answer_cascade (119-165)
  • is_equiv_combined (484-523)
  • strip_string_basic (173-199)
research/eval/tasks/hendrycks_math_thinking/utils.py (1)
research/eval/tasks/_common.py (5)
  • extract_solution_tag (17-29)
  • is_equiv_string (422-447)
  • last_boxed_only_string (73-116)
  • remove_boxed (47-70)
  • strip_string_math (271-343)
research/eval/tasks/amc2023/utils.py (1)
research/eval/tasks/_common.py (3)
  • extract_answer_cascade (119-165)
  • is_equiv_combined (484-523)
  • strip_string_basic (173-199)
🪛 LanguageTool
research/trl/train_trl_grpo_README.md

[style] ~52-~52: ‘exact same’ might be wordy. Consider a shorter alternative.
Context: ...aset` CLI flag - GRAIL Parity: Uses exact same task sources, validation logic, and rew...

(EN_WORDINESS_PREMIUM_EXACT_SAME)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
🔇 Additional comments (50)
grail/trainer/config.py (1)

99-99: LGTM!

The simplified comment aligns with the PR's standardization of evaluation splits across the new harness scripts and task utilities. The default value remains unchanged.

research/eval/README.md (1)

1-379: Documentation is comprehensive and well-structured.

The README provides excellent coverage of evaluation workflows including base vs. reasoning model configurations, multiple benchmarks (MATH, AIME, AMC, GSM8K), and pass@k evaluation. The task structure diagram and argument tables are helpful.

grail/model/provider.py (2)

117-137: Well-structured attention backend selection with appropriate fallbacks.

The priority chain (Flash Attention 2 > SDPA > default) is correctly implemented with:

  • Proper import check for flash_attn availability
  • Clear fallback logging when Flash Attention is unavailable
  • SDPA as a sensible default for PyTorch 2.0+ on CUDA

The logging messages help users understand which attention implementation is active.
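For readers skimming without the diff, the described priority chain can be sketched roughly as follows (a paraphrase of the behaviour above, not the actual code in grail/model/provider.py):

import importlib.util
import torch

def select_attn_implementation(use_sdpa: bool = True) -> str | None:
    """Rough sketch of the Flash Attention 2 > SDPA > default selection chain."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    if use_sdpa and torch.cuda.is_available():
        return "sdpa"  # sensible default on PyTorch 2.0+ with CUDA
    return None  # let transformers pick its default attention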


62-75: Docstring accurately reflects the new behavior.

The updated docstring clearly documents the use_sdpa parameter and the priority relationship between Flash Attention 2 and SDPA.

eval_pass_at_k.py (1)

15-25: Pass@k estimator implementation is correct.

The unbiased pass@k estimator from the Codex paper is correctly implemented with proper handling of the edge case where n - c < k.
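For reference, the estimator has this shape (a generic sketch of the formula, not a copy of eval_pass_at_k.py):

import math

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k = 1 - C(n - c, k) / C(n, k), given n samples with c correct."""
    if n - c < k:
        return 1.0  # every k-sample subset must contain at least one correct answer
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

# Example: n=10 samples, c=3 correct -> pass@5 = 1 - C(7,5)/C(10,5) = 1 - 21/252 ≈ 0.917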

grail/environments/providers.py (2)

256-282: Enhanced MATHTaskSource with filtering support looks good.

The class properly supports filtering by level and subject with:

  • Cached filtered indices for performance
  • Appropriate validation error when no samples match filters
  • Stratified train/val split maintaining proportions

The _filtered_indices cache using string keys is a reasonable approach.


218-253: Remove duplicate definitions - dead code.

Lines 218-253 duplicate the definitions already present at lines 180-215:

  • _extract_boxed_answer (lines 218-236 duplicates 180-198)
  • _MATH_SUBSETS (lines 240-248 duplicates 202-210)
  • _MATH_VAL_SIZE (line 251 duplicates 213)
  • _MATH_VAL_SEED (line 253 duplicates 215)

The second definitions overwrite the first at runtime, making the earlier code dead. Remove the duplicate block.

🔎 Proposed fix
-def _extract_boxed_answer(solution: str) -> str:
-    """Extract answer from \\boxed{...} in solution, handling nested braces."""
-    import re
-
-    match = re.search(r"\\boxed\{", solution)
-    if not match:
-        return ""
-
-    start = match.end()
-    depth = 1
-    i = start
-    while i < len(solution) and depth > 0:
-        if solution[i] == "{":
-            depth += 1
-        elif solution[i] == "}":
-            depth -= 1
-        i += 1
-
-    return solution[start : i - 1] if depth == 0 else ""
-
-
-# Subsets in EleutherAI/hendrycks_math dataset
-_MATH_SUBSETS = (
-    "algebra",
-    "counting_and_probability",
-    "geometry",
-    "intermediate_algebra",
-    "number_theory",
-    "prealgebra",
-    "precalculus",
-)
-
-# Fixed validation set size (stratified across problem types)
-_MATH_VAL_SIZE = 500
-# Seed for deterministic stratified sampling of validation set
-_MATH_VAL_SEED = 42
-
-
 class MATHTaskSource(TaskSource):

Likely an incorrect or invalid review comment.

grail/cli/parallel_miner.py (3)

534-548: Critical: Proper sorting ensures validator compatibility.

The sorting by (rollout_group, rollout_index) is essential for correct seed derivation by validators. Good documentation explaining why this is critical.


682-704: Good time budget estimation before parallel mining.

The proactive check for remaining blocks before starting prevents wasted GPU cycles on windows that can't be completed in time.


202-218: No changes needed. The base_nonce = problem_index approach correctly ensures nonce uniqueness across parallel workers. The package_rollout_data function already implements the nonce collision fix using base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx (line 437 in grail/cli/mine.py), which prevents duplicates. Since each GPU processes non-overlapping problem ranges, different problem_index values guarantee unique base_nonce values, and the multiplier formula handles per-rollout uniqueness within each worker.

research/eval/tasks/_common.py (3)

119-165: Well-designed cascade extraction with configurable fallbacks.

The extract_answer_cascade function provides flexible answer extraction with clear priority ordering (SOLUTION tag → boxed → dollar sign → raw text). The boolean flags allow callers to customize the cascade for different evaluation contexts.


484-523: Robust combined equivalence check.

is_equiv_combined appropriately tries string comparison first (faster) before falling back to numeric comparison. The exception handling ensures graceful degradation to raw string comparison.


1-10: Excellent adherence to DRY principles.

Centralizing these utilities in _common.py for reuse across AIME, AMC, GSM8K, and MATH evaluation tasks is the right design choice.
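As a usage sketch, a task-specific process_results built on these helpers might look like the following (illustrative only: it assumes _common is importable on the path, that the cascade's optional flags have sensible defaults, and that the dataset exposes an "answer" field):

from _common import extract_answer_cascade, is_equiv_combined, strip_string_basic

def process_results(doc: dict, results: list[str]) -> dict[str, int]:
    """Extract the model's final answer and compare it against the dataset target."""
    response = results[0] if results else ""
    predicted = strip_string_basic(extract_answer_cascade(response))
    target = strip_string_basic(str(doc.get("answer", "")))
    return {"exact_match": int(is_equiv_combined(predicted, target))}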

grail/cli/__init__.py (1)

328-337: Clean integration of parallel_miner subcommand.

The new module is registered consistently with existing subcommands (mine, validate, train). The dynamic import pattern with callable(register) check handles modules gracefully.

tools/vllm-server/pyproject.toml (1)

15-16: LGTM - dependency aligns with evaluation harness integration.

The lm-eval>=0.4.0 dependency supports the new evaluation tooling introduced in this PR. The version constraint is appropriate; the latest release (0.4.9.2) is well above the minimum specified.
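For context, a rough sketch of driving the new task YAMLs through lm-eval's Python API (the simple_evaluate/TaskManager argument names are assumptions about the 0.4.x API, and the model ID is a placeholder; check the harness docs before relying on this):

# Sketch only: argument names are assumptions about lm-eval 0.4.x; <MODEL_ID> is a placeholder.
import lm_eval
from lm_eval.tasks import TaskManager

task_manager = TaskManager(include_path="research/eval/tasks")  # pick up the new task YAMLs
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=<MODEL_ID>,dtype=bfloat16",
    tasks=["hendrycks_math_thinking"],
    batch_size="auto",
    task_manager=task_manager,
)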

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml (1)

1-5: LGTM!

The configuration correctly references the default template and uses consistent naming for the number theory subtask. The task name matches the group file entry.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml (1)

1-5: LGTM!

Consistent configuration for the prealgebra subtask, following the established pattern.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml (1)

1-5: LGTM!

Standard algebra subtask configuration, correctly integrated with the task group.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml (1)

1-5: LGTM!

Geometry subtask configuration follows the consistent pattern.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml (1)

1-5: LGTM!

Intermediate algebra subtask configuration is correct and consistent.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml (1)

1-15: LGTM!

The group configuration correctly defines all seven Hendrycks MATH subtasks with appropriate aggregate metrics. The weight_by_size: true setting ensures fair aggregation across subtasks with varying sample counts.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml (1)

1-5: LGTM!

Counting and probability subtask configuration is correct. The abbreviated task name (counting_and_prob) paired with the full dataset name (counting_and_probability) follows the established pattern.

research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml (1)

1-5: LGTM!

Precalculus subtask configuration completes the seven-task suite with consistent naming conventions.

research/trl/train_trl_grpo_README.md (1)

1-136: Documentation looks comprehensive and well-structured!

The README provides clear quickstart instructions, detailed feature descriptions, and helpful reference tables. The GPU layout example and hyperparameter documentation will be valuable for users.

research/eval/tasks/hendrycks_math_thinking/_default_template.yaml (1)

1-21: LGTM! Evaluation configuration is well-structured.

The YAML configuration follows standard patterns with appropriate generation parameters (greedy decoding, temperature 0) and structured output format for reasoning models.

research/eval/tasks/amc2023/amc2023.yaml (1)

1-28: LGTM! AMC2023 configuration is appropriate.

The configuration correctly uses integer casting for numeric answers and sets reasonable token limits (512) for short AMC responses.

research/eval/tasks/aime24_thinking/aime24_thinking.yaml (2)

9-26: Configuration parameters look appropriate for AIME evaluation.

The structured output format and higher token limit (4096) are suitable for complex AIME problems that require extensive reasoning.


6-8: No action needed. The AIME_2024 dataset from Hugging Face only provides a single "train" split (30 rows), making this configuration the only valid option. Using "train" for all three evaluation splits is necessary and correct for this dataset.

Likely an incorrect or invalid review comment.

research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml (1)

15-31: Pass@k metric configuration is correct.

The sampling parameters (temperature 0.7, top_p 0.95, repeats 10) and dual metrics (pass@1 and pass@5) are properly configured for pass@k evaluation.

research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml (1)

1-26: LGTM! GSM8K thinking configuration is appropriate.

The configuration uses proper splits (test for evaluation), reasonable token limits (1024), and greedy decoding (temperature 0.0) for deterministic evaluation.

research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml (1)

1-25: LGTM! AMC2023 thinking configuration is well-designed.

The structured output format and high token limit (4096) are appropriate for reasoning models that require extended working out space. The configuration is consistent with other thinking task patterns in the repository.

research/eval/tasks/aime24_thinking/utils.py (1)

54-61: LGTM!

The _strip_string_aime function correctly handles the leading zero removal while preserving the edge case of "0" itself. The or "0" fallback is a clean pattern for this use case.

research/eval/tasks/amc2023/utils.py (1)

47-61: LGTM!

The float-to-int conversion logic is well-implemented with proper exception handling for non-numeric strings.

research/eval/tasks/hendrycks_math_thinking/utils.py (1)

38-54: LGTM!

The fallback logic for both model answer extraction and ground truth handling is robust. Using is_equiv_string with strip_string_math is appropriate for the MATH benchmark's symbolic answers.

research/eval/tasks/gsm8k_thinking/utils.py (1)

80-95: LGTM!

The _extract_last_number function has good coverage of numeric formats including negative numbers, comma-separated thousands, and decimals. The fallback to the stripped string is a reasonable default.

eval_math_harness.py (1)

74-86: LGTM!

The tensor parallel size calculation correctly finds the largest valid divisor that satisfies the attention head divisibility requirement.
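A generic sketch of that kind of calculation (not the exact code in eval_math_harness.py):

def pick_tensor_parallel_size(num_gpus: int, num_attention_heads: int) -> int:
    """Largest size that divides both the available GPU count and the attention head count."""
    for size in range(num_gpus, 0, -1):
        if num_gpus % size == 0 and num_attention_heads % size == 0:
            return size
    return 1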

grail/cli/mine.py (3)

434-437: Excellent fix for nonce collision bug.

The change from base_nonce * 10 to base_nonce * ROLLOUTS_PER_PROBLEM correctly prevents nonce collisions. The comment clearly documents the bug and fix rationale, which is valuable for future maintainers.
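A quick illustration of why the multiplier matters (the ROLLOUTS_PER_PROBLEM value below is assumed for the example, not taken from the code):

ROLLOUTS_PER_PROBLEM = 16  # assumed value, for illustration only

def nonce(base_nonce: int, rollout_idx: int) -> int:
    """Unique as long as rollout_idx < ROLLOUTS_PER_PROBLEM for every problem."""
    return base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx

# With the old fixed multiplier of 10 and 16 rollouts per problem,
# problem 0 / rollout 10 and problem 1 / rollout 0 would both map to nonce 10;
# multiplying by ROLLOUTS_PER_PROBLEM removes that overlap.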


545-591: LGTM!

The worker mode implementation cleanly supports parallel mining with environment-driven configuration. The precedence of environment variables over function arguments enables proper subprocess isolation.


627-665: LGTM!

The early termination logic and problem index calculation correctly support parallel mining coordination across GPU workers.

research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py (2)

52-65: LGTM!

The _pass_at_k function correctly implements the pass@k formula with proper handling of the edge case where n - c < k (guaranteed at least one correct in any k-sample).


68-88: LGTM!

The aggregation logic handles edge cases well, including the n < k scenario with a reasonable fallback heuristic.

grail/cli/multi_miner_config.py (1)

143-175: LGTM!

The environment setup correctly copies os.environ to avoid mutating global state, and the CUDA device handling appropriately removes constraints when no GPU is specified.

grail/cli/multi_miner_aggregator.py (1)

41-51: LGTM on AggregatorConfig dataclass.

The configuration structure is well-designed with sensible defaults and clear parameter naming.

research/trl/train_trl_gsm8k.py (4)

36-84: LGTM on Config dataclass.

The hyperparameters are well-documented with comments linking to their GRAIL environment variable equivalents, making it easy to verify alignment.


163-197: LGTM on reward computation logic.

The reward weights sum to 1.0 and the component breakdown (correctness, format, thinking, answer, trailing) is clear and matches the documented GSM8KEnv weights.


415-467: Retry logic with exponential backoff is well implemented.

The batch generation handles failures gracefully with retries, backoff, and falls back to empty completions rather than crashing. This ensures evaluation continues even with intermittent vLLM server issues.
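The general pattern being praised here, sketched generically (the retry count, delays, and generate_batch callable are illustrative, not the script's actual values):

import time

def generate_with_retries(generate_batch, prompts, max_retries=3, base_delay=1.0):
    """Retry a flaky batch-generation call with exponential backoff, then fall back to empty outputs."""
    for attempt in range(max_retries):
        try:
            return generate_batch(prompts)
        except Exception as exc:
            delay = base_delay * (2 ** attempt)
            print(f"generation failed ({exc!r}); retrying in {delay:.1f}s")
            time.sleep(delay)
    return ["" for _ in prompts]  # keep evaluation going rather than crashing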


537-588: LGTM on GRPOConfig initialization.

The configuration is comprehensive with clear inline comments mapping each parameter to its GRAIL environment variable equivalent.

research/trl/train_trl_grpo.py (3)

180-244: Well-designed dataset adapter abstraction.

The DatasetAdapter ABC provides a clean interface with clear responsibilities: data loading, answer parsing, validation, and reward computation. This enables the factory pattern to work elegantly.
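A condensed sketch of that adapter/factory shape (illustrative only; the real class in research/trl/train_trl_grpo.py has more methods, and the stub subclass here exists purely so the factory is runnable):

from abc import ABC, abstractmethod
from typing import Any

class DatasetAdapter(ABC):
    """Illustrative interface: load data, validate/parse answers, compute rewards."""

    @abstractmethod
    def load_train_data(self) -> list[dict[str, Any]]: ...

    @abstractmethod
    def load_eval_data(self) -> list[dict[str, Any]]: ...

    @abstractmethod
    def compute_reward(self, completion: str, target: str) -> float: ...

class DummyAdapter(DatasetAdapter):
    """Stand-in subclass for demonstration; not part of the PR."""

    def load_train_data(self) -> list[dict[str, Any]]:
        return []

    def load_eval_data(self) -> list[dict[str, Any]]:
        return []

    def compute_reward(self, completion: str, target: str) -> float:
        return 0.0

def make_adapter(name: str) -> DatasetAdapter:
    """Factory keyed by dataset name, mirroring the GSM8K/MATH selection described above."""
    adapters: dict[str, type[DatasetAdapter]] = {"dummy": DummyAdapter}
    return adapters[name]()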


1053-1069: LGTM on model loading.

This file correctly uses torch_dtype=torch.bfloat16 (unlike train_trl_gsm8k.py which uses dtype). The fallback pattern for Flash Attention availability is also good.


1199-1208: Good practice: Explicit WandB initialization before baseline eval.

Explicitly initializing WandB before the baseline evaluation ensures metrics are captured from step 0, rather than relying on the trainer's lazy initialization.

Comment on lines +295 to +300
    # Print aggregate (use weighted average based on number of samples per subject)
    accs = [d["accuracy"] for d in subject_results.values()]
    overall_acc = sum(accs) / len(accs) * 100 if accs else 0

    print("-" * 50)
    print(f" {'OVERALL':35s} {overall_acc:5.2f}%")

⚠️ Potential issue | 🟡 Minor

Comment-code discrepancy: "weighted average" comment but simple average implementation.

The comment on line 295 states "weighted average based on number of samples per subject," but lines 296-297 compute a simple mean. Either update the comment or implement weighted averaging:

-    # Print aggregate (use weighted average based on number of samples per subject)
-    accs = [d["accuracy"] for d in subject_results.values()]
-    overall_acc = sum(accs) / len(accs) * 100 if accs else 0
+    # Simple average across subjects (unweighted)
+    accs = [d["accuracy"] for d in subject_results.values()]
+    overall_acc = sum(accs) / len(accs) * 100 if accs else 0
🤖 Prompt for AI Agents
In eval_math_harness.py around lines 295 to 300 the comment says "weighted
average based on number of samples per subject" but the code computes a simple
mean; change the computation to a weighted average using each subject's sample
count: extract both accuracy and sample count from subject_results (e.g.,
accuracy and n_samples for each subject), compute overall_acc = (sum(accuracy *
n_samples) / sum(n_samples)) * 100 if total samples > 0 else 0, and keep the
existing printing logic; alternatively, if sample counts are not available in
subject_results, update the comment to say "simple average" or add the sample
counts where results are aggregated so weighted averaging is possible.
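A minimal sketch of the weighted variant described above, assuming each per-subject entry also carries a sample count (the key name n_samples is hypothetical):

def overall_accuracy(subject_results: dict[str, dict]) -> float:
    """Average per-subject accuracy weighted by sample count, in percent."""
    total = sum(d.get("n_samples", 0) for d in subject_results.values())
    if total == 0:
        return 0.0
    weighted = sum(d["accuracy"] * d.get("n_samples", 0) for d in subject_results.values())
    return weighted / total * 100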

Comment on lines +253 to +257
    async def batch_process_window(self, window_start: int) -> bool:
        """Process a single window in batch mode."""
        wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey)
        await self._process_window(wallet, window_start)
        return True

⚠️ Potential issue | 🟡 Minor

Return value doesn't reflect actual processing outcome.

batch_process_window always returns True regardless of whether _process_window succeeded or found any results. This could mask failures when used programmatically.

🔎 Proposed fix to propagate success status
     async def _process_window(self, wallet: bt.wallet, window_start: int) -> bool:
         """Process and upload a specific window."""
         aggregator = WindowAggregator(window_start, self.config)
         results = await aggregator.collect_results()

         # Check if we have any results
         total_inferences = sum(len(inf) for inf in results.values())
         if total_inferences == 0:
             logger.info(f"⊘ No results for window {window_start}, skipping")
-            return
+            return False

         # Upload aggregated results
-        await aggregator.aggregate_and_upload(wallet)
+        return await aggregator.aggregate_and_upload(wallet)

     async def batch_process_window(self, window_start: int) -> bool:
         """Process a single window in batch mode."""
         wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey)
-        await self._process_window(wallet, window_start)
-        return True
+        return await self._process_window(wallet, window_start)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In grail/cli/multi_miner_aggregator.py around lines 253 to 257,
batch_process_window currently always returns True which hides failures; change
it to await the call to self._process_window, capture its return value and
return that boolean result (or convert it to bool if it can be None/other), and
wrap the await in a try/except to log/handle exceptions and return False on
error so the caller sees real success/failure.

Comment on lines +300 to +306
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            pass  # Already set
        self._ctx = mp.get_context("spawn")
        self._result_queue: mp.Queue = self._ctx.Queue()
        self._shutdown_requested = False

⚠️ Potential issue | 🟡 Minor

Potential issue with multiprocessing start method initialization.

Calling mp.set_start_method("spawn", force=True) in __init__ can cause issues:

  1. If multiple ParallelMiningCoordinator instances are created
  2. If any CUDA operations occurred before this call

Consider moving this to module level or the CLI entry point before any CUDA operations.

🔎 Suggested refactor

Move to module level:

# At module level, before any class definitions
import multiprocessing as mp

# Set spawn method early, before any CUDA context
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass  # Already set

_mp_context = mp.get_context("spawn")

Then in __init__:

-        try:
-            mp.set_start_method("spawn", force=True)
-        except RuntimeError:
-            pass  # Already set
-        self._ctx = mp.get_context("spawn")
+        self._ctx = _mp_context

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In grail/cli/parallel_miner.py around lines 300 to 306, the call to
mp.set_start_method("spawn", force=True) inside __init__ can raise race/ordering
problems if multiple ParallelMiningCoordinator instances are created or if CUDA
contexts were initialized earlier; move the start_method setup to module import
time or the CLI entrypoint so it runs once before any CUDA or multiprocessing
use: perform the mp.set_start_method(...) (with the same try/except RuntimeError
guard) at module level (or in main) and obtain a single shared
mp.get_context("spawn") to store in a module-level variable, then have __init__
use that pre-created context and Queue rather than calling
set_start_method/get_context itself.

group: hendrycks_math_thinking_pass_at_k
task: hendrycks_math_thinking_pass_at_5
dataset_path: EleutherAI/hendrycks_math
dataset_name: algebra

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

fd -e yaml . research/eval/tasks/hendrycks_math_thinking_pass_at_k/

Repository: one-covenant/grail

Length of output: 154


🏁 Script executed:

cat -n research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml

Repository: one-covenant/grail

Length of output: 1124


🏁 Script executed:

fd -e yaml hendrycks_math . research/eval/tasks/

Repository: one-covenant/grail

Length of output: 1605


🏁 Script executed:

cat -n research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml | head -15

Repository: one-covenant/grail

Length of output: 604


🏁 Script executed:

cat -n research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml | head -15

Repository: one-covenant/grail

Length of output: 216


🏁 Script executed:

# Check if there are any related pass@k configs or documentation
fd -e yaml . research/eval/tasks/hendrycks_math_thinking_pass_at_k/
fd -e md . research/eval/tasks/hendrycks_math_thinking_pass_at_k/ 2>/dev/null || echo "No markdown files"

Repository: one-covenant/grail

Length of output: 154


🏁 Script executed:

# Check if there might be parametrization or if this is the only pass@k config
ls -la research/eval/tasks/hendrycks_math_thinking_pass_at_k/
# Also check parent directory structure
ls -la research/eval/tasks/ | grep pass

Repository: one-covenant/grail

Length of output: 402


🏁 Script executed:

# Check if there are any other pass@ configurations
fd -e yaml . research/eval/tasks/ | grep -i "pass_at"
# Also look at git history or comments that might explain the choice
git log --oneline research/eval/tasks/hendrycks_math_thinking_pass_at_k/ 2>/dev/null | head -5 || echo "No git history available"

Repository: one-covenant/grail

Length of output: 201


The pass@5 configuration should evaluate the full MATH dataset, not just algebra.

The single pass@5 configuration file is limited to dataset_name: algebra, which is inconsistent with the base hendrycks_math_thinking.yaml that aggregates all seven subjects (algebra, counting_and_prob, geometry, intermediate_algebra, num_theory, prealgebra, precalc). Either change dataset_name to evaluate the complete dataset, or create subject-specific pass@5 configurations following the existing pattern in hendrycks_math_thinking/ directory.

Comment on lines +281 to +309
    def load_train_data(self) -> list[dict[str, Any]]:
        """Load GSM8K training data via task source."""
        self._train_source._ensure_dataset()
        assert self._train_source._ds is not None
        data = []
        for i in range(len(self._train_source._ds)):
            sample = self._train_source._ds[i]
            data.append(
                {
                    "question": sample["question"],
                    "answer": sample["answer"],
                }
            )
        return data

    def load_eval_data(self) -> list[dict[str, Any]]:
        """Load GSM8K test data via task source."""
        self._eval_source._ensure_dataset()
        assert self._eval_source._ds is not None
        data = []
        for i in range(len(self._eval_source._ds)):
            sample = self._eval_source._ds[i]
            data.append(
                {
                    "question": sample["question"],
                    "answer": sample["answer"],
                }
            )
        return data

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find GSM8KTaskSource class definition
fd -e py -x rg -l "class GSM8KTaskSource" {} \;

Repository: one-covenant/grail

Length of output: 95


🏁 Script executed:

#!/bin/bash
# Once we find the file, examine the GSM8KTaskSource class
rg -n "class GSM8KTaskSource" -A 50 --type=py

Repository: one-covenant/grail

Length of output: 3611


🏁 Script executed:

#!/bin/bash
# Look for public methods in GSM8KTaskSource that could provide data access
rg -n "def " research/trl/ --type=py | grep -i "gsm8k\|task.*source" | head -30

Repository: one-covenant/grail

Length of output: 1421


🏁 Script executed:

#!/bin/bash
# Read the actual lines from train_trl_grpo.py to see what attribute is being accessed
sed -n '281,309p' research/trl/train_trl_grpo.py

Repository: one-covenant/grail

Length of output: 1091


🏁 Script executed:

#!/bin/bash
# Also check the TaskSource parent class to understand the interface
rg -n "class TaskSource" -A 30 grail/environments/providers.py

Repository: one-covenant/grail

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check if GSM8KTaskSource has any public data access methods (not starting with _)
rg -n "def [^_]" grail/environments/providers.py | grep -A 5 -B 5 "GSM8K"

Repository: one-covenant/grail

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Get full GSM8KTaskSource class implementation
rg -n "class GSM8KTaskSource" -A 150 grail/environments/providers.py | head -180

Repository: one-covenant/grail

Length of output: 5583


🏁 Script executed:

#!/bin/bash
# Search for _ds attribute definition anywhere in providers.py
rg -n "_ds" grail/environments/providers.py

Repository: one-covenant/grail

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check the full context of GSM8KAdapter in train_trl_grpo.py
sed -n '240,320p' research/trl/train_trl_grpo.py

Repository: one-covenant/grail

Length of output: 2776


The code accesses non-existent _ds attribute; use _data or the public API instead.

load_train_data and load_eval_data access _train_source._ds and _eval_source._ds, but GSM8KTaskSource has a _data attribute, not _ds. This will cause an AttributeError at runtime. Additionally, accessing private members creates fragile coupling to internal implementation details.

Either correct the attribute name to _data or, better, use the public API: call self._train_source.size() to get the dataset length and self._train_source.iter_ids() to get task IDs for iteration.

🤖 Prompt for AI Agents
In research/trl/train_trl_grpo.py around lines 281 to 309, the functions
load_train_data and load_eval_data reference private attribute _ds which does
not exist and creates fragile coupling; replace uses of self._train_source._ds
and self._eval_source._ds with the public API: use self._train_source.size()
(and self._eval_source.size()) to get lengths and iterate with
self._train_source.iter_ids() / self._eval_source.iter_ids() (or any public
accessor that returns task entries) to retrieve each sample, extracting
"question" and "answer" from the returned task dicts; do not access private
attributes like _data/_ds directly.
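A heavily hedged sketch of what the public-API version could look like (iter_ids comes from the suggestion above; the get accessor is hypothetical and must be replaced with whatever GSM8KTaskSource actually exposes):

from typing import Any

def load_qa_pairs(source) -> list[dict[str, Any]]:
    """Collect question/answer dicts through a task source's public interface."""
    data = []
    for task_id in source.iter_ids():
        sample = source.get(task_id)  # hypothetical accessor; confirm against GSM8KTaskSource
        data.append({"question": sample["question"], "answer": sample["answer"]})
    return data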

Comment on lines +998 to +1004
    parser.add_argument(
        "--eval-every",
        type=int,
        default=40,
        help="Run evaluation every N steps (default: 30)",
    )
    return parser.parse_args()

⚠️ Potential issue | 🟡 Minor

Help text default doesn't match actual default value.

The help text says "(default: 30)" but the actual default is 40.

🔎 Proposed fix
     parser.add_argument(
         "--eval-every",
         type=int,
         default=40,
-        help="Run evaluation every N steps (default: 30)",
+        help="Run evaluation every N steps (default: 40)",
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     parser.add_argument(
         "--eval-every",
         type=int,
         default=40,
-        help="Run evaluation every N steps (default: 30)",
+        help="Run evaluation every N steps (default: 40)",
     )
     return parser.parse_args()
🤖 Prompt for AI Agents
In research/trl/train_trl_grpo.py around lines 998 to 1004, the --eval-every
argument’s help string incorrectly states "(default: 30)" while the argument
default is 40; update the help text to "(default: 40)" so it matches the actual
default value (or alternatively set default=30 if you intended 30) and ensure
the help string and the default parameter remain consistent.
