feat: add extended evaluation support using EleutherAI eval-harness and improved TRL training #56
base: main
Conversation
…rycks MATH dataset
…TH datasets with detailed README
…K, including learning rate, batch size, and max sequence length
…ics on MATH dataset on the test set
…nd result aggregation
… and tensor parallelism features
…and add critical checks for time management in parallel mining
… collisions by using ROLLOUTS_PER_PROBLEM as multiplier
…ity handling for Flash Attention 2
…rness that support reasoning model evaluation as well.
…arameter configuration, add TrainingPassAtKTracker for pass@k metrics logging, and update evaluation callback for improved WandB integration
…ks, including answer extraction and string normalization functions.
Walkthrough

This PR introduces comprehensive parallel multi-GPU mining infrastructure, MATH benchmark evaluation frameworks, and training script modernization. New modules enable GPU worker coordination, multi-miner result aggregation, and extensible evaluation task definitions across multiple math datasets. Existing mining logic is enhanced with worker-mode support, and model loading is refactored to prioritize attention mechanisms.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Main as Coordinator
    participant GPU1 as GPU Worker 1
    participant GPU2 as GPU Worker 2
    participant Queue as Result Queue
    participant S3 as R2 Storage
    Main->>GPU1: spawn worker (problem_offset=0, max=N)
    Main->>GPU2: spawn worker (problem_offset=N, max=N)
    rect rgb(230, 240, 250)
    note over GPU1,GPU2: Parallel Rollout Generation
    GPU1->>GPU1: load model, run env loop<br/>generate rollouts [0, N)
    GPU2->>GPU2: load model, run env loop<br/>generate rollouts [N, 2N)
    end
    GPU1->>GPU1: package results<br/>write to JSON
    GPU2->>GPU2: package results<br/>write to JSON
    GPU1->>Queue: signal completion
    GPU2->>Queue: signal completion
    Main->>Queue: wait_for_workers()
    Main->>Main: collect & sort results<br/>by problem_index
    rect rgb(240, 250, 240)
    note over Main: Result Validation & Upload
    Main->>Main: validate aggregated<br/>inferences
    Main->>S3: upload_file_chunked<br/>(window payload)
    end
    S3-->>Main: success
    Main->>GPU1: cleanup temp files
    Main->>GPU2: cleanup temp files
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
research/trl/train_trl_gsm8k.py (1)
502-513: Fix incorrect parameter name: `dtype` should be `torch_dtype`.
`AutoModelForCausalLM.from_pretrained()` expects `torch_dtype`, not `dtype`. The current code at lines 505 and 512 will silently ignore the parameter, causing the model to load with default precision instead of bfloat16.

```diff
 try:
     model = AutoModelForCausalLM.from_pretrained(
         cfg.model_id,
-        dtype=torch.bfloat16,
+        torch_dtype=torch.bfloat16,
         attn_implementation="flash_attention_2",
     )
 except (ImportError, RuntimeError) as e:
     print(f"⚠️ Flash Attention 2 unavailable ({type(e).__name__}), using default attention")
     model = AutoModelForCausalLM.from_pretrained(
         cfg.model_id,
-        dtype=torch.bfloat16,
+        torch_dtype=torch.bfloat16,
     )
```
♻️ Duplicate comments (2)
research/trl/train_trl_grpo.py (2)
36-48: Same hardcoded path issue as in train_trl_gsm8k.py. Consider parameterizing the paths or using relative imports for better portability across environments.

431-463: Same private member access issue as GSM8KAdapter. The MATHAdapter also accesses the `_ensure_dataset` and `_data` private members. The same coupling concern applies.
🧹 Nitpick comments (25)
research/eval/README.md (1)
14-16: Consider using relative paths or environment variables for portability. Hardcoded paths like `/root/grail/` appear throughout this documentation (lines 15, 45, 88, 92-100, 137, 151-159, 234, 273, 330, 371). This reduces portability for users with different installation locations. Consider using:
- Relative paths from the repo root
- Environment variable placeholders like `$GRAIL_ROOT`
- A note at the top instructing users to substitute their installation path
eval_pass_at_k.py (4)
56-63: Type hint inconsistency with default value. The type hint `list[dict]` doesn't match the default `None`. Consider using `list[dict] | None = None` for accuracy.

```diff
-def build_prompt(problem: str, few_shot_examples: list[dict] = None) -> str:
+def build_prompt(problem: str, few_shot_examples: list[dict] | None = None) -> str:
```
28-34: Regex may miss deeply nested boxed content. The regex `r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"` handles one level of nesting but may fail on deeper nesting like `\boxed{\frac{1}{\sqrt{2}}}`. There's a more robust `_extract_boxed_answer` implementation in `grail/environments/providers.py` (lines 180-198) and `research/eval/tasks/_common.py` that uses brace-counting for arbitrary nesting depth. Consider using the shared utility or aligning the implementations.
37-53: Duplication with shared utilities. The `normalize_answer` and `is_correct` functions duplicate functionality available in `research/eval/tasks/_common.py` (`strip_string_basic`, `is_equiv_string`, `is_equiv_combined`). Consider importing from the shared module to maintain DRY principles.
82-91: High GPU memory utilization may cause OOM on smaller GPUs. `gpu_memory_utilization=0.95` leaves minimal headroom. Consider making this configurable via CLI argument or using a more conservative default (e.g., 0.90) for broader hardware compatibility.

grail/cli/parallel_miner.py (2)
363-378: Daemon processes may leave resources unclean on unexpected exit. Using `daemon=True` (line 368) means workers are killed immediately when the main process exits, potentially leaving temp files. The cleanup in the `finally` block (line 771) should handle this, but consider adding signal handlers for graceful shutdown.
840-842: Batch size validation is overly restrictive. The validation `batch_size > 16` rejects valid batch sizes. If 16 is optimal for A100, larger values might work on H100 or future hardware. Consider warning instead of erroring, or document why 16 is the hard limit.

research/eval/tasks/_common.py (1)
311-311: Invalid escape sequence warning. Line 311 has `"\%"`, which triggers a `W605` warning. The noqa comment suppresses it, but the string should be `"\\%"` for correctness (though functionally it may work due to Python's lenient handling).

```diff
-    string = string.replace("\%", "")  # noqa: W605
+    string = string.replace("\\%", "")
```

research/eval/tasks/amc2023_thinking/utils.py (3)
10-11: Consider using relative imports instead of sys.path manipulation. The `sys.path.insert()` approach is fragile and can cause issues if the file is moved or imported from different contexts. Consider restructuring as a proper package with relative imports:

```python
from .._common import extract_answer_cascade, is_equiv_combined, strip_string_basic
```

If the module structure doesn't support this yet, adding an `__init__.py` file in the parent directory would enable cleaner imports.
20-44: Consider adding validation for empty results list. Line 22 directly accesses `results[0]` without checking if the list is non-empty. While this may not occur in practice with the evaluation harness, adding a guard would make the code more robust:

```diff
 def process_results(doc: dict, results: list[str]) -> dict[str, int]:
     """Process model output and compare with target answer."""
+    if not results:
+        return {"exact_match": 0}
+
     response = results[0]
```
54-60: Consider using `is_integer()` for more robust float-to-int detection. Line 57's comparison `num == int(num)` can have floating-point precision issues. Python's `is_integer()` method is more reliable:

```diff
 try:
     num = float(string)
-    if num == int(num):
+    if num.is_integer():
         string = str(int(num))
 except ValueError:
     pass
```

This avoids edge cases where floating-point representation might cause unexpected behavior.
research/eval/tasks/aime24_thinking/utils.py (2)
10-17: Consider using relative imports instead of `sys.path` manipulation. The `sys.path.insert(0, ...)` pattern is fragile and can cause import conflicts if multiple modules modify `sys.path`. Since this is part of a package structure, consider using relative imports or restructuring to avoid path manipulation.

```python
# Alternative: relative import if package structure allows
from . import _common  # or ensure the package is properly installed
```

This pattern appears across multiple evaluation utility files in this PR.
20-38: Consider a defensive check for an empty `results` list. Accessing `results[0]` on line 25 will raise an `IndexError` if the list is empty. While the lm-eval framework likely guarantees at least one result, a defensive check would improve robustness:

```diff
 def process_results(doc: dict, results: list[str]) -> dict[str, int]:
     """Process model output and compare with target answer.

     AIME answers are always integers from 000-999.
     """
+    if not results:
+        return {"exact_match": 0}
     response = results[0]
```

research/eval/tasks/amc2023/utils.py (1)
32-33: Inconsistent answer key handling compared to AIME utilities. This uses `doc.get("answer", "")` directly, while `aime24_thinking/utils.py` uses a case-insensitive lookup. If the dataset might have varying key casing (e.g., "Answer" vs "answer"), consider aligning the approach:

```diff
-    target = str(doc.get("answer", ""))
+    answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None)
+    target = str(doc[answer_key]) if answer_key else ""
```

Alternatively, if the AMC dataset consistently uses lowercase "answer", the current approach is acceptable.
research/eval/tasks/hendrycks_math_thinking/utils.py (1)
23-35: `process_docs` drops additional document fields that may be needed. The mapping function returns only `problem`, `solution`, and `answer`, which discards other fields like `type` (subject) and `level` (difficulty) that might be valuable for the per-subject or per-difficulty analysis mentioned in the PR objectives. Consider preserving original fields:

```diff
 def _process_doc(doc: dict) -> dict:
     boxed = last_boxed_only_string(doc["solution"])
     answer = remove_boxed(boxed) if boxed else ""
     return {
+        **doc,  # Preserve original fields
         "problem": doc["problem"],
         "solution": doc["solution"],
         "answer": answer,
     }
```

research/eval/tasks/gsm8k_thinking/utils.py (1)
20-34: `doc_to_target` may raise `KeyError` if the `answer` field is missing. Unlike `process_results`, which uses `doc["answer"]` after checking the #### format, `doc_to_target` accesses `doc["answer"]` directly on line 22. Consider using `.get()` with a default for defensive coding:

```diff
 def doc_to_target(doc: dict) -> str:
     """Convert document to target format for thinking."""
-    answer = doc["answer"]
+    answer = doc.get("answer", "")
```

However, if the GSM8K dataset schema guarantees this field, the current approach is acceptable.
eval_math_harness.py (2)
45-61: Subprocess fallback should handle errors more explicitly. The `nvidia-smi` fallback silently returns 1 on any failure (including command not found). Consider logging a warning:

```diff
     result = subprocess.run(
         ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"],
         capture_output=True,
         text=True,
     )
-    return len(result.stdout.strip().split("\n")) if result.returncode == 0 else 1
+    if result.returncode != 0:
+        logger.warning("nvidia-smi failed, defaulting to 1 GPU")
+        return 1
+    return len(result.stdout.strip().split("\n"))
```
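For context, a standalone sketch of this detection-with-fallback pattern, including the suggested warning (the helper name is illustrative, not the script's actual function):

```python
import logging
import subprocess

logger = logging.getLogger(__name__)


def detect_gpu_count() -> int:
    """Prefer torch's device count, fall back to nvidia-smi, and warn on failure."""
    try:
        import torch

        if torch.cuda.is_available():
            return torch.cuda.device_count()
    except ImportError:
        pass
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"],
            capture_output=True,
            text=True,
        )
        if result.returncode == 0 and result.stdout.strip():
            return len(result.stdout.strip().splitlines())
    except FileNotFoundError:
        pass
    logger.warning("Could not detect GPUs; defaulting to 1")
    return 1
```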
312-318: `save_results` mutates the input `results` dictionary. Adding `"metadata"` to the input dict is a side effect that may surprise callers who don't expect their data to be modified. Consider creating a copy:

```diff
+    # Create copy to avoid mutating input
+    output = {**results}

     # Add metadata
-    results["metadata"] = {
+    output["metadata"] = {
         "model": args.model,
         "num_fewshot": args.num_fewshot,
         "batch_size": args.batch_size,
         "timestamp": timestamp,
     }

     with open(output_file, "w") as f:
-        json.dump(results, f, indent=2, default=str)
+        json.dump(output, f, indent=2, default=str)
```

grail/cli/multi_miner_config.py (2)
70-80: Potential issue: GPU assignment when `gpus` is an empty list. Line 73 `gpus[i % len(gpus)]` would raise `ZeroDivisionError` if `gpus` is an empty list (as opposed to `None`). The current check `if gpus` is falsy for an empty list, but for safety:

```diff
-    gpu = gpus[i % len(gpus)] if gpus else None
+    gpu = gpus[i % len(gpus)] if gpus and len(gpus) > 0 else None
```

Or ensure the parsing logic in `from_environment` never produces an empty list (it currently sets `gpus = None` when empty, which is correct).
37-43: Minor: Consider moving the `time` import to module level. The `import time` inside `__post_init__` works but is unconventional. Moving it to the module-level imports improves clarity and avoids repeated import overhead if the method is called multiple times.

grail/cli/multi_miner_aggregator.py (3)
170-181: Synchronous file I/O blocks the event loop in an async method. `_load_and_parse` uses synchronous `open()` and `json.load()` inside an async context. While the file sizes are likely small, this blocks the event loop during file operations.

🔎 Proposed fix using asyncio.to_thread

```diff
 async def _load_and_parse(self, result_file: Path, hotkey: str) -> list[dict]:
     """Load and parse inferences from result file."""
+    def _read_file() -> list[dict]:
+        with open(result_file) as f:
+            data = json.load(f)
+        inferences = data.get("inferences", [])
+        if not isinstance(inferences, list):
+            raise ValueError(f"Expected list of inferences, got {type(inferences)}")
+        return inferences
+
     try:
-        with open(result_file) as f:
-            data = json.load(f)
-        inferences = data.get("inferences", [])
-        if not isinstance(inferences, list):
-            raise ValueError(f"Expected list of inferences, got {type(inferences)}")
-        return inferences
+        return await asyncio.to_thread(_read_file)
     except Exception as e:
         logger.debug(f"Failed to parse {result_file}: {e}")
         raise
```
214-231: Consider caching the subtensor instance outside the loop. A new `bt.subtensor()` is instantiated on every poll iteration (default 5s). This may create unnecessary connection overhead depending on the implementation.

🔎 Proposed fix

```diff
     try:
+        subtensor = bt.subtensor()
         while not self.stop_event.is_set():
             # Get current window
-            subtensor = bt.subtensor()
             current_block = await asyncio.to_thread(subtensor.get_current_block)
             current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH
```
340-344: Consider preserving exception context for debugging. Using `from None` suppresses the original exception chain. For credential loading failures, the original exception often contains useful context (e.g., missing file path, permission errors).

🔎 Proposed fix

```diff
 try:
     credentials = load_r2_credentials()
 except Exception as e:
     console.print(f"[red]Failed to load R2 credentials: {e}[/red]")
-    raise typer.Exit(code=1) from None
+    raise typer.Exit(code=1) from e
```

research/trl/train_trl_gsm8k.py (1)
26-32: Hardcoded paths reduce portability. The hardcoded paths `/root/grail/.env` and `/root/grail` assume a specific deployment environment. For a research script this may be acceptable, but consider using environment variables or relative paths for broader compatibility.

🔎 Proposed fix using relative paths

```diff
-# Load environment from .env for WandB
-load_dotenv("/root/grail/.env")  # Load WandB API key and project
+# Load environment from .env for WandB (try common locations)
+from pathlib import Path
+
+_script_dir = Path(__file__).resolve().parent
+_dotenv_candidates = [
+    _script_dir / ".env",
+    _script_dir.parent.parent / ".env",  # repo root
+    Path.home() / "grail" / ".env",
+]
+for _env_path in _dotenv_candidates:
+    if _env_path.exists():
+        load_dotenv(_env_path)
+        break

-sys.path.append("/root/grail")
+sys.path.insert(0, str(_script_dir.parent.parent))  # Add repo root to path
```

research/trl/train_trl_grpo.py (1)
681-691: Silent exception swallowing may hide logging issues. The bare `pass` after catching any exception hides all WandB errors. Consider at least logging at debug level to aid troubleshooting when metrics unexpectedly don't appear.

🔎 Proposed fix

```diff
 def _log_to_wandb(self, metrics: dict[str, float]) -> None:
     """Log metrics to WandB."""
     try:
         import wandb

         if wandb.run is not None and metrics:
             wandb_data = {f"train/{k}": v for k, v in metrics.items()}
             wandb.log(wandb_data)
-    except Exception:
-        pass  # Silently ignore WandB errors
+    except Exception as e:
+        # Log at debug level to aid troubleshooting without flooding output
+        import logging
+        logging.getLogger(__name__).debug(f"WandB logging failed: {e}")
```
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (36)
- eval_math_harness.py
- eval_pass_at_k.py
- grail/cli/__init__.py
- grail/cli/mine.py
- grail/cli/multi_miner_aggregator.py
- grail/cli/multi_miner_config.py
- grail/cli/parallel_miner.py
- grail/environments/providers.py
- grail/model/provider.py
- grail/trainer/config.py
- research/eval/README.md
- research/eval/tasks/_common.py
- research/eval/tasks/aime24_thinking/aime24_thinking.yaml
- research/eval/tasks/aime24_thinking/utils.py
- research/eval/tasks/amc2023/amc2023.yaml
- research/eval/tasks/amc2023/utils.py
- research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml
- research/eval/tasks/amc2023_thinking/utils.py
- research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml
- research/eval/tasks/gsm8k_thinking/utils.py
- research/eval/tasks/hendrycks_math_thinking/_default_template.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml
- research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml
- research/eval/tasks/hendrycks_math_thinking/utils.py
- research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml
- research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py
- research/trl/train_trl_grpo.py
- research/trl/train_trl_grpo_README.md
- research/trl/train_trl_gsm8k.py
- tools/vllm-server/pyproject.toml
🧰 Additional context used
🧬 Code graph analysis (5)
research/eval/tasks/gsm8k_thinking/utils.py (1)
research/eval/tasks/_common.py (2)
extract_solution_tag (17-29), is_equiv_combined (484-523)
research/eval/tasks/aime24_thinking/utils.py (1)
research/eval/tasks/_common.py (3)
extract_answer_cascade (119-165), is_equiv_combined (484-523), strip_string_basic (173-199)
research/eval/tasks/amc2023_thinking/utils.py (1)
research/eval/tasks/_common.py (3)
extract_answer_cascade (119-165), is_equiv_combined (484-523), strip_string_basic (173-199)
research/eval/tasks/hendrycks_math_thinking/utils.py (1)
research/eval/tasks/_common.py (5)
extract_solution_tag (17-29), is_equiv_string (422-447), last_boxed_only_string (73-116), remove_boxed (47-70), strip_string_math (271-343)
research/eval/tasks/amc2023/utils.py (1)
research/eval/tasks/_common.py (3)
extract_answer_cascade (119-165), is_equiv_combined (484-523), strip_string_basic (173-199)
🪛 LanguageTool
research/trl/train_trl_grpo_README.md
[style] ~52-~52: ‘exact same’ might be wordy. Consider a shorter alternative.
Context: ...aset` CLI flag - GRAIL Parity: Uses exact same task sources, validation logic, and rew...
(EN_WORDINESS_PREMIUM_EXACT_SAME)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
🔇 Additional comments (50)
grail/trainer/config.py (1)
99-99: LGTM! The simplified comment aligns with the PR's standardization of evaluation splits across the new harness scripts and task utilities. The default value remains unchanged.
research/eval/README.md (1)
1-379: Documentation is comprehensive and well-structured. The README provides excellent coverage of evaluation workflows including base vs. reasoning model configurations, multiple benchmarks (MATH, AIME, AMC, GSM8K), and pass@k evaluation. The task structure diagram and argument tables are helpful.
grail/model/provider.py (2)
117-137: Well-structured attention backend selection with appropriate fallbacks. The priority chain (Flash Attention 2 > SDPA > default) is correctly implemented with:
- Proper import check for flash_attn availability
- Clear fallback logging when Flash Attention is unavailable
- SDPA as a sensible default for PyTorch 2.0+ on CUDA
The logging messages help users understand which attention implementation is active.
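For readers unfamiliar with the pattern, a minimal sketch of such a priority chain (function name and log wording are illustrative, not the module's actual code):

```python
import importlib.util
import logging

import torch

logger = logging.getLogger(__name__)


def select_attn_implementation(use_sdpa: bool = True) -> str | None:
    """Illustrative priority chain: Flash Attention 2 > SDPA > framework default."""
    # Prefer Flash Attention 2 when the package is importable and CUDA is available.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        logger.info("Using flash_attention_2")
        return "flash_attention_2"
    # Fall back to PyTorch scaled-dot-product attention on CUDA with PyTorch 2.0+.
    if use_sdpa and torch.cuda.is_available():
        logger.info("Flash Attention 2 unavailable, falling back to sdpa")
        return "sdpa"
    logger.info("Using default attention implementation")
    return None  # Let the framework pick its default.
```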
62-75: Docstring accurately reflects the new behavior. The updated docstring clearly documents the `use_sdpa` parameter and the priority relationship between Flash Attention 2 and SDPA.

eval_pass_at_k.py (1)
15-25: Pass@k estimator implementation is correct. The unbiased pass@k estimator from the Codex paper is correctly implemented with proper handling of the edge case where `n - c < k`.

grail/environments/providers.py (2)
256-282: Enhanced MATHTaskSource with filtering support looks good. The class properly supports filtering by `level` and `subject` with:
- Cached filtered indices for performance
- Appropriate validation error when no samples match filters
- Stratified train/val split maintaining proportions

The `_filtered_indices` cache using string keys is a reasonable approach.
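For intuition, a rough sketch of proportional (stratified) validation sampling — the names and the exact scheme here are assumptions, not the PR's implementation:

```python
import random
from collections import defaultdict


def stratified_val_indices(subjects: list[str], val_size: int, seed: int) -> set[int]:
    """Pick validation indices so each subject contributes roughly in
    proportion to its share of the dataset (illustrative only)."""
    by_subject: dict[str, list[int]] = defaultdict(list)
    for idx, subj in enumerate(subjects):
        by_subject[subj].append(idx)
    rng = random.Random(seed)  # fixed seed keeps the split deterministic
    total = len(subjects)
    chosen: set[int] = set()
    for subj, idxs in sorted(by_subject.items()):
        n_take = max(1, round(val_size * len(idxs) / total))
        chosen.update(rng.sample(idxs, min(n_take, len(idxs))))
    return chosen
```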
218-253: Remove duplicate definitions - dead code. Lines 218-253 duplicate the definitions already present at lines 180-215:
- `_extract_boxed_answer` (lines 218-236 duplicates 180-198)
- `_MATH_SUBSETS` (lines 240-248 duplicates 202-210)
- `_MATH_VAL_SIZE` (line 251 duplicates 213)
- `_MATH_VAL_SEED` (line 253 duplicates 215)

The second definitions overwrite the first at runtime, making the earlier code dead. Remove the duplicate block.

🔎 Proposed fix

```diff
-def _extract_boxed_answer(solution: str) -> str:
-    """Extract answer from \\boxed{...} in solution, handling nested braces."""
-    import re
-
-    match = re.search(r"\\boxed\{", solution)
-    if not match:
-        return ""
-
-    start = match.end()
-    depth = 1
-    i = start
-    while i < len(solution) and depth > 0:
-        if solution[i] == "{":
-            depth += 1
-        elif solution[i] == "}":
-            depth -= 1
-        i += 1
-
-    return solution[start : i - 1] if depth == 0 else ""
-
-
-# Subsets in EleutherAI/hendrycks_math dataset
-_MATH_SUBSETS = (
-    "algebra",
-    "counting_and_probability",
-    "geometry",
-    "intermediate_algebra",
-    "number_theory",
-    "prealgebra",
-    "precalculus",
-)
-
-# Fixed validation set size (stratified across problem types)
-_MATH_VAL_SIZE = 500
-# Seed for deterministic stratified sampling of validation set
-_MATH_VAL_SEED = 42
-
-
 class MATHTaskSource(TaskSource):
```

Likely an incorrect or invalid review comment.
grail/cli/parallel_miner.py (3)
534-548: Critical: Proper sorting ensures validator compatibility. The sorting by `(rollout_group, rollout_index)` is essential for correct seed derivation by validators. Good documentation explaining why this is critical.
682-704: Good time budget estimation before parallel mining. The proactive check for remaining blocks before starting prevents wasted GPU cycles on windows that can't be completed in time.
202-218: No changes needed. The `base_nonce = problem_index` approach correctly ensures nonce uniqueness across parallel workers. The `package_rollout_data` function already implements the nonce collision fix using `base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx` (line 437 in grail/cli/mine.py), which prevents duplicates. Since each GPU processes non-overlapping problem ranges, different `problem_index` values guarantee unique `base_nonce` values, and the multiplier formula handles per-rollout uniqueness within each worker.
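For intuition, a toy illustration of why the multiplier keeps nonce ranges disjoint (the constant and helper name here are illustrative, not the repository's actual values):

```python
ROLLOUTS_PER_PROBLEM = 16  # illustrative value, not the real constant


def rollout_nonce(problem_index: int, rollout_idx: int) -> int:
    # Each worker owns a disjoint problem_index range, so base_nonce is unique
    # per problem; multiplying by ROLLOUTS_PER_PROBLEM leaves room for every
    # rollout index without overlapping the next problem's range.
    base_nonce = problem_index
    assert 0 <= rollout_idx < ROLLOUTS_PER_PROBLEM
    return base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx


# Problems 0 and 1 never collide: their nonces occupy [0..15] and [16..31].
assert {rollout_nonce(0, i) for i in range(16)}.isdisjoint(
    {rollout_nonce(1, i) for i in range(16)}
)
```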
119-165: Well-designed cascade extraction with configurable fallbacks.The
extract_answer_cascadefunction provides flexible answer extraction with clear priority ordering (SOLUTION tag → boxed → dollar sign → raw text). The boolean flags allow callers to customize the cascade for different evaluation contexts.
484-523: Robust combined equivalence check.
is_equiv_combinedappropriately tries string comparison first (faster) before falling back to numeric comparison. The exception handling ensures graceful degradation to raw string comparison.
1-10: Excellent adherence to DRY principles.Centralizing these utilities in
_common.pyfor reuse across AIME, AMC, GSM8K, and MATH evaluation tasks is the right design choice.grail/cli/__init__.py (1)
328-337: Clean integration of parallel_miner subcommand.The new module is registered consistently with existing subcommands (mine, validate, train). The dynamic import pattern with
callable(register)check handles modules gracefully.tools/vllm-server/pyproject.toml (1)
15-16: LGTM - dependency aligns with evaluation harness integration.The
lm-eval>=0.4.0dependency supports the new evaluation tooling introduced in this PR. The version constraint is appropriate; the latest release (0.4.9.2) is well above the minimum specified.research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml (1)
1-5: LGTM!The configuration correctly references the default template and uses consistent naming for the number theory subtask. The task name matches the group file entry.
research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml (1)
1-5: LGTM!Consistent configuration for the prealgebra subtask, following the established pattern.
research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml (1)
1-5: LGTM!Standard algebra subtask configuration, correctly integrated with the task group.
research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml (1)
1-5: LGTM!Geometry subtask configuration follows the consistent pattern.
research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml (1)
1-5: LGTM!Intermediate algebra subtask configuration is correct and consistent.
research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml (1)
1-15: LGTM!The group configuration correctly defines all seven Hendrycks MATH subtasks with appropriate aggregate metrics. The
weight_by_size: truesetting ensures fair aggregation across subtasks with varying sample counts.research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml (1)
1-5: LGTM!Counting and probability subtask configuration is correct. The abbreviated task name (
counting_and_prob) paired with the full dataset name (counting_and_probability) follows the established pattern.research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml (1)
1-5: LGTM!Precalculus subtask configuration completes the seven-task suite with consistent naming conventions.
research/trl/train_trl_grpo_README.md (1)
1-136: Documentation looks comprehensive and well-structured!The README provides clear quickstart instructions, detailed feature descriptions, and helpful reference tables. The GPU layout example and hyperparameter documentation will be valuable for users.
research/eval/tasks/hendrycks_math_thinking/_default_template.yaml (1)
1-21: LGTM! Evaluation configuration is well-structured.The YAML configuration follows standard patterns with appropriate generation parameters (greedy decoding, temperature 0) and structured output format for reasoning models.
research/eval/tasks/amc2023/amc2023.yaml (1)
1-28: LGTM! AMC2023 configuration is appropriate.The configuration correctly uses integer casting for numeric answers and sets reasonable token limits (512) for short AMC responses.
research/eval/tasks/aime24_thinking/aime24_thinking.yaml (2)
9-26: Configuration parameters look appropriate for AIME evaluation.The structured output format and higher token limit (4096) are suitable for complex AIME problems that require extensive reasoning.
6-8: No action needed. The AIME_2024 dataset from Hugging Face only provides a single "train" split (30 rows), making this configuration the only valid option. Using "train" for all three evaluation splits is necessary and correct for this dataset.Likely an incorrect or invalid review comment.
research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml (1)
15-31: Pass@k metric configuration is correct.The sampling parameters (temperature 0.7, top_p 0.95, repeats 10) and dual metrics (pass@1 and pass@5) are properly configured for pass@k evaluation.
research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml (1)
1-26: LGTM! GSM8K thinking configuration is appropriate.The configuration uses proper splits (test for evaluation), reasonable token limits (1024), and greedy decoding (temperature 0.0) for deterministic evaluation.
research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml (1)
1-25: LGTM! AMC2023 thinking configuration is well-designed.The structured output format and high token limit (4096) are appropriate for reasoning models that require extended working out space. The configuration is consistent with other thinking task patterns in the repository.
research/eval/tasks/aime24_thinking/utils.py (1)
54-61: LGTM!The
_strip_string_aimefunction correctly handles the leading zero removal while preserving the edge case of "0" itself. Theor "0"fallback is a clean pattern for this use case.research/eval/tasks/amc2023/utils.py (1)
47-61: LGTM!The float-to-int conversion logic is well-implemented with proper exception handling for non-numeric strings.
research/eval/tasks/hendrycks_math_thinking/utils.py (1)
38-54: LGTM!The fallback logic for both model answer extraction and ground truth handling is robust. Using
is_equiv_stringwithstrip_string_mathis appropriate for the MATH benchmark's symbolic answers.research/eval/tasks/gsm8k_thinking/utils.py (1)
80-95: LGTM!The
_extract_last_numberfunction has good coverage of numeric formats including negative numbers, comma-separated thousands, and decimals. The fallback to the stripped string is a reasonable default.eval_math_harness.py (1)
74-86: LGTM!The tensor parallel size calculation correctly finds the largest valid divisor that satisfies the attention head divisibility requirement.
grail/cli/mine.py (3)
434-437: Excellent fix for nonce collision bug.The change from
base_nonce * 10tobase_nonce * ROLLOUTS_PER_PROBLEMcorrectly prevents nonce collisions. The comment clearly documents the bug and fix rationale, which is valuable for future maintainers.
545-591: LGTM!The worker mode implementation cleanly supports parallel mining with environment-driven configuration. The precedence of environment variables over function arguments enables proper subprocess isolation.
627-665: LGTM!The early termination logic and problem index calculation correctly support parallel mining coordination across GPU workers.
research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py (2)
52-65: LGTM! The `_pass_at_k` function correctly implements the pass@k formula with proper handling of the edge case where `n - c < k` (guaranteed at least one correct in any k-sample).
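For reference, a minimal sketch of the unbiased estimator from the Codex paper, 1 - C(n-c, k)/C(n, k) (illustrative, not the PR's exact helper):

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k sampled completions
    (out of n total, c of them correct) is correct."""
    if n - c < k:
        # Fewer than k incorrect samples exist, so any k-subset
        # must contain a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


# Example: 10 samples, 3 correct, k=5 -> ~0.9167
print(round(pass_at_k(10, 3, 5), 4))
```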
68-88: LGTM! The aggregation logic handles edge cases well, including the `n < k` scenario with a reasonable fallback heuristic.

grail/cli/multi_miner_config.py (1)
143-175: LGTM! The environment setup correctly copies `os.environ` to avoid mutating global state, and the CUDA device handling appropriately removes constraints when no GPU is specified.

grail/cli/multi_miner_aggregator.py (1)
41-51: LGTM on the AggregatorConfig dataclass. The configuration structure is well-designed with sensible defaults and clear parameter naming.
research/trl/train_trl_gsm8k.py (4)
36-84: LGTM on the Config dataclass. The hyperparameters are well-documented with comments linking to their GRAIL environment variable equivalents, making it easy to verify alignment.
163-197: LGTM on the reward computation logic. The reward weights sum to 1.0 and the component breakdown (correctness, format, thinking, answer, trailing) is clear and matches the documented GSM8KEnv weights.
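For illustration only — the weights below are placeholders, not the documented GSM8KEnv values — a weighted-sum shape like this is what the check verifies:

```python
# Placeholder weights for illustration; the real values live in the training config.
REWARD_WEIGHTS = {
    "correctness": 0.6,
    "format": 0.1,
    "thinking": 0.1,
    "answer": 0.1,
    "trailing": 0.1,
}
assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9


def combined_reward(components: dict[str, float]) -> float:
    """Weighted sum of per-component rewards, each assumed to be in [0, 1]."""
    return sum(REWARD_WEIGHTS[name] * components.get(name, 0.0) for name in REWARD_WEIGHTS)
```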
415-467: Retry logic with exponential backoff is well implemented. The batch generation handles failures gracefully with retries and backoff, and falls back to empty completions rather than crashing. This ensures evaluation continues even with intermittent vLLM server issues.
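A minimal sketch of that retry-then-degrade pattern (names such as `generate_fn` are assumptions, not the script's API):

```python
import logging
import time

logger = logging.getLogger(__name__)


def generate_with_retries(generate_fn, prompts, max_retries: int = 3, base_delay: float = 2.0):
    """Back off exponentially on failure, then degrade to empty completions
    instead of crashing the evaluation loop."""
    for attempt in range(max_retries):
        try:
            return generate_fn(prompts)
        except Exception as e:  # e.g. transient server errors
            delay = base_delay * (2 ** attempt)
            logger.warning("Generation failed (%s), retrying in %.1fs", e, delay)
            time.sleep(delay)
    return ["" for _ in prompts]  # give up gracefully
```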
537-588: LGTM on the GRPOConfig initialization. The configuration is comprehensive with clear inline comments mapping each parameter to its GRAIL environment variable equivalent.
research/trl/train_trl_grpo.py (3)
180-244: Well-designed dataset adapter abstraction. The `DatasetAdapter` ABC provides a clean interface with clear responsibilities: data loading, answer parsing, validation, and reward computation. This enables the factory pattern to work elegantly.
1053-1069: LGTM on model loading. This file correctly uses `torch_dtype=torch.bfloat16` (unlike train_trl_gsm8k.py, which uses `dtype`). The fallback pattern for Flash Attention availability is also good.
1199-1208: Good practice: Explicit WandB initialization before the baseline eval. Explicitly initializing WandB before the baseline evaluation ensures metrics are captured from step 0, rather than relying on the trainer's lazy initialization.
```python
    # Print aggregate (use weighted average based on number of samples per subject)
    accs = [d["accuracy"] for d in subject_results.values()]
    overall_acc = sum(accs) / len(accs) * 100 if accs else 0

    print("-" * 50)
    print(f"  {'OVERALL':35s} {overall_acc:5.2f}%")
```
Comment-code discrepancy: "weighted average" comment but simple average implementation.
The comment on line 295 states "weighted average based on number of samples per subject," but lines 296-297 compute a simple mean. Either update the comment or implement weighted averaging:
- # Print aggregate (use weighted average based on number of samples per subject)
- accs = [d["accuracy"] for d in subject_results.values()]
- overall_acc = sum(accs) / len(accs) * 100 if accs else 0
+ # Simple average across subjects (unweighted)
+ accs = [d["accuracy"] for d in subject_results.values()]
+ overall_acc = sum(accs) / len(accs) * 100 if accs else 0🤖 Prompt for AI Agents
In eval_math_harness.py around lines 295 to 300 the comment says "weighted
average based on number of samples per subject" but the code computes a simple
mean; change the computation to a weighted average using each subject's sample
count: extract both accuracy and sample count from subject_results (e.g.,
accuracy and n_samples for each subject), compute overall_acc = (sum(accuracy *
n_samples) / sum(n_samples)) * 100 if total samples > 0 else 0, and keep the
existing printing logic; alternatively, if sample counts are not available in
subject_results, update the comment to say "simple average" or add the sample
counts where results are aggregated so weighted averaging is possible.
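A minimal sketch of the weighted variant the prompt describes, assuming each subject entry carries a sample count (the `n_samples` field name is hypothetical):

```python
def overall_accuracy(subject_results: dict[str, dict]) -> float:
    """Weighted average accuracy across subjects, weighted by sample count."""
    total = sum(d["n_samples"] for d in subject_results.values())
    if total == 0:
        return 0.0
    weighted = sum(d["accuracy"] * d["n_samples"] for d in subject_results.values())
    return weighted / total * 100
```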
```python
    async def batch_process_window(self, window_start: int) -> bool:
        """Process a single window in batch mode."""
        wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey)
        await self._process_window(wallet, window_start)
        return True
```
Return value doesn't reflect the actual processing outcome.
`batch_process_window` always returns True regardless of whether `_process_window` succeeded or found any results. This could mask failures when used programmatically.

🔎 Proposed fix to propagate success status

```diff
 async def _process_window(self, wallet: bt.wallet, window_start: int) -> bool:
     """Process and upload a specific window."""
     aggregator = WindowAggregator(window_start, self.config)
     results = await aggregator.collect_results()
     # Check if we have any results
     total_inferences = sum(len(inf) for inf in results.values())
     if total_inferences == 0:
         logger.info(f"⊘ No results for window {window_start}, skipping")
-        return
+        return False
     # Upload aggregated results
-    await aggregator.aggregate_and_upload(wallet)
+    return await aggregator.aggregate_and_upload(wallet)

 async def batch_process_window(self, window_start: int) -> bool:
     """Process a single window in batch mode."""
     wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey)
-    await self._process_window(wallet, window_start)
-    return True
+    return await self._process_window(wallet, window_start)
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In grail/cli/multi_miner_aggregator.py around lines 253 to 257,
batch_process_window currently always returns True which hides failures; change
it to await the call to self._process_window, capture its return value and
return that boolean result (or convert it to bool if it can be None/other), and
wrap the await in a try/except to log/handle exceptions and return False on
error so the caller sees real success/failure.
```python
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            pass  # Already set
        self._ctx = mp.get_context("spawn")
        self._result_queue: mp.Queue = self._ctx.Queue()
        self._shutdown_requested = False
```
Potential issue with multiprocessing start method initialization.
Calling `mp.set_start_method("spawn", force=True)` in `__init__` can cause issues:
- If multiple `ParallelMiningCoordinator` instances are created
- If any CUDA operations occurred before this call

Consider moving this to module level or the CLI entry point before any CUDA operations.

🔎 Suggested refactor

Move to module level:

```python
# At module level, before any class definitions
import multiprocessing as mp

# Set spawn method early, before any CUDA context
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass  # Already set

_mp_context = mp.get_context("spawn")
```

Then in `__init__`:

```diff
-        try:
-            mp.set_start_method("spawn", force=True)
-        except RuntimeError:
-            pass  # Already set
-        self._ctx = mp.get_context("spawn")
+        self._ctx = _mp_context
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In grail/cli/parallel_miner.py around lines 300 to 306, the call to
mp.set_start_method("spawn", force=True) inside __init__ can raise race/ordering
problems if multiple ParallelMiningCoordinator instances are created or if CUDA
contexts were initialized earlier; move the start_method setup to module import
time or the CLI entrypoint so it runs once before any CUDA or multiprocessing
use: perform the mp.set_start_method(...) (with the same try/except RuntimeError
guard) at module level (or in main) and obtain a single shared
mp.get_context("spawn") to store in a module-level variable, then have __init__
use that pre-created context and Queue rather than calling
set_start_method/get_context itself.
```yaml
group: hendrycks_math_thinking_pass_at_k
task: hendrycks_math_thinking_pass_at_5
dataset_path: EleutherAI/hendrycks_math
dataset_name: algebra
```
🧩 Analysis chain
🏁 Script executed:
```shell
fd -e yaml . research/eval/tasks/hendrycks_math_thinking_pass_at_k/
```

Repository: one-covenant/grail
Length of output: 154

🏁 Script executed:

```shell
cat -n research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml
```

Repository: one-covenant/grail
Length of output: 1124

🏁 Script executed:

```shell
fd -e yaml hendrycks_math . research/eval/tasks/
```

Repository: one-covenant/grail
Length of output: 1605

🏁 Script executed:

```shell
cat -n research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml | head -15
```

Repository: one-covenant/grail
Length of output: 604

🏁 Script executed:

```shell
cat -n research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml | head -15
```

Repository: one-covenant/grail
Length of output: 216

🏁 Script executed:

```shell
# Check if there are any related pass@k configs or documentation
fd -e yaml . research/eval/tasks/hendrycks_math_thinking_pass_at_k/
fd -e md . research/eval/tasks/hendrycks_math_thinking_pass_at_k/ 2>/dev/null || echo "No markdown files"
```

Repository: one-covenant/grail
Length of output: 154

🏁 Script executed:

```shell
# Check if there might be parametrization or if this is the only pass@k config
ls -la research/eval/tasks/hendrycks_math_thinking_pass_at_k/
# Also check parent directory structure
ls -la research/eval/tasks/ | grep pass
```

Repository: one-covenant/grail
Length of output: 402

🏁 Script executed:

```shell
# Check if there are any other pass@ configurations
fd -e yaml . research/eval/tasks/ | grep -i "pass_at"
# Also look at git history or comments that might explain the choice
git log --oneline research/eval/tasks/hendrycks_math_thinking_pass_at_k/ 2>/dev/null | head -5 || echo "No git history available"
```

Repository: one-covenant/grail
Length of output: 201
The pass@5 configuration should evaluate the full MATH dataset, not just algebra.
The single pass@5 configuration file is limited to `dataset_name: algebra`, which is inconsistent with the base hendrycks_math_thinking.yaml that aggregates all seven subjects (algebra, counting_and_prob, geometry, intermediate_algebra, num_theory, prealgebra, precalc). Either change `dataset_name` to evaluate the complete dataset, or create subject-specific pass@5 configurations following the existing pattern in the hendrycks_math_thinking/ directory.
```python
    def load_train_data(self) -> list[dict[str, Any]]:
        """Load GSM8K training data via task source."""
        self._train_source._ensure_dataset()
        assert self._train_source._ds is not None
        data = []
        for i in range(len(self._train_source._ds)):
            sample = self._train_source._ds[i]
            data.append(
                {
                    "question": sample["question"],
                    "answer": sample["answer"],
                }
            )
        return data

    def load_eval_data(self) -> list[dict[str, Any]]:
        """Load GSM8K test data via task source."""
        self._eval_source._ensure_dataset()
        assert self._eval_source._ds is not None
        data = []
        for i in range(len(self._eval_source._ds)):
            sample = self._eval_source._ds[i]
            data.append(
                {
                    "question": sample["question"],
                    "answer": sample["answer"],
                }
            )
        return data
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Find GSM8KTaskSource class definition
fd -e py -x rg -l "class GSM8KTaskSource" {} \;
```

Repository: one-covenant/grail
Length of output: 95

🏁 Script executed:

```shell
#!/bin/bash
# Once we find the file, examine the GSM8KTaskSource class
rg -n "class GSM8KTaskSource" -A 50 --type=py
```

Repository: one-covenant/grail
Length of output: 3611

🏁 Script executed:

```shell
#!/bin/bash
# Look for public methods in GSM8KTaskSource that could provide data access
rg -n "def " research/trl/ --type=py | grep -i "gsm8k\|task.*source" | head -30
```

Repository: one-covenant/grail
Length of output: 1421

🏁 Script executed:

```shell
#!/bin/bash
# Read the actual lines from train_trl_grpo.py to see what attribute is being accessed
sed -n '281,309p' research/trl/train_trl_grpo.py
```

Repository: one-covenant/grail
Length of output: 1091

🏁 Script executed:

```shell
#!/bin/bash
# Also check the TaskSource parent class to understand the interface
rg -n "class TaskSource" -A 30 grail/environments/providers.py
```

Repository: one-covenant/grail
Length of output: 44

🏁 Script executed:

```shell
#!/bin/bash
# Check if GSM8KTaskSource has any public data access methods (not starting with _)
rg -n "def [^_]" grail/environments/providers.py | grep -A 5 -B 5 "GSM8K"
```

Repository: one-covenant/grail
Length of output: 44

🏁 Script executed:

```shell
#!/bin/bash
# Get full GSM8KTaskSource class implementation
rg -n "class GSM8KTaskSource" -A 150 grail/environments/providers.py | head -180
```

Repository: one-covenant/grail
Length of output: 5583

🏁 Script executed:

```shell
#!/bin/bash
# Search for _ds attribute definition anywhere in providers.py
rg -n "_ds" grail/environments/providers.py
```

Repository: one-covenant/grail
Length of output: 44

🏁 Script executed:

```shell
#!/bin/bash
# Check the full context of GSM8KAdapter in train_trl_grpo.py
sed -n '240,320p' research/trl/train_trl_grpo.py
```

Repository: one-covenant/grail
Length of output: 2776
The code accesses non-existent _ds attribute; use _data or the public API instead.
load_train_data and load_eval_data access _train_source._ds and _eval_source._ds, but GSM8KTaskSource has a _data attribute, not _ds. This will cause an AttributeError at runtime. Additionally, accessing private members creates fragile coupling to internal implementation details.
Either correct the attribute name to _data or, better, use the public API: call self._train_source.size() to get the dataset length and self._train_source.iter_ids() to get task IDs for iteration.
🤖 Prompt for AI Agents
In research/trl/train_trl_grpo.py around lines 281 to 309, the functions
load_train_data and load_eval_data reference private attribute _ds which does
not exist and creates fragile coupling; replace uses of self._train_source._ds
and self._eval_source._ds with the public API: use self._train_source.size()
(and self._eval_source.size()) to get lengths and iterate with
self._train_source.iter_ids() / self._eval_source.iter_ids() (or any public
accessor that returns task entries) to retrieve each sample, extracting
"question" and "answer" from the returned task dicts; do not access private
attributes like _data/_ds directly.
```python
    parser.add_argument(
        "--eval-every",
        type=int,
        default=40,
        help="Run evaluation every N steps (default: 30)",
    )
    return parser.parse_args()
```
Help text default doesn't match actual default value.
The help text says "(default: 30)" but the actual default is 40.
🔎 Proposed fix
parser.add_argument(
"--eval-every",
type=int,
default=40,
- help="Run evaluation every N steps (default: 30)",
+ help="Run evaluation every N steps (default: 40)",
)📝 Committable suggestion
🤖 Prompt for AI Agents
In research/trl/train_trl_grpo.py around lines 998 to 1004, the --eval-every
argument’s help string incorrectly states "(default: 30)" while the argument
default is 40; update the help text to "(default: 40)" so it matches the actual
default value (or alternatively set default=30 if you intended 30) and ensure
the help string and the default parameter remain consistent.
This PR introduces significant enhancements to GRAIL's training, evaluation, and data mining capabilities.