feat(miner): vLLM backend for GRPO (EOS termination, logprobs fallback); + benchmark #25
base: main
Conversation
Walkthrough

Introduces vLLM inference backend support alongside the existing HF path: adds env vars and constants, a vLLM HTTP client, a fast path in GRPO rollout generation, documentation updates, and a benchmark script. Also initializes an inference package namespace and updates the miner docs to include an optional vLLM workflow.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant U as Miner (GRPO)
participant C as VLLMClient
participant S as vLLM Server
participant H as HF Model (fallback)
Note over U: Generate GRPO rollouts for a group
alt INFERENCE_BACKEND == "vllm" and VLLM_BASE_URL set
U->>U: Build chat-templated prompt
U->>C: generate(prompt, n, params)
C->>S: POST /v1/completions
S-->>C: 200 JSON (choices, logprobs?)
C-->>U: texts [+ optional token logprobs]
rect rgba(200,255,200,0.2)
Note right of U: For each completion: parse action, compute reward, tokenize, derive s_vals, sign, assemble rollout
end
else Fallback
U->>H: _generate_single_rollout() per rollout
H-->>U: rollout artifacts
end
U->>U: Compute advantages across group
U-->>U: Return GRPO rollouts
```
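For orientation, a minimal sketch of the gating the diagram describes. Function and variable names here (`generate_group`, `hf_fallback`, the `VLLMClient` constructor and `generate` signature) are illustrative assumptions, not the PR's exact code; only the constants and the `POST /v1/completions` flow come from the walkthrough.

```python
from typing import Callable

from grail.shared.constants import INFERENCE_BACKEND, VLLM_BASE_URL


def generate_group(
    prompt: str,
    n: int,
    sampling: dict,
    hf_fallback: Callable[[str, dict], str],
) -> list[str]:
    """Return n completion texts, preferring the vLLM server when configured."""
    if INFERENCE_BACKEND == "vllm" and VLLM_BASE_URL:
        from grail.inference.vllm_client import VLLMClient

        client = VLLMClient()  # constructor args assumed; settings come from env-backed constants
        # One HTTP request batches all n completions via POST /v1/completions.
        texts = client.generate(prompt, n=n, **sampling)  # may also return per-token logprobs; simplified here
        return texts
    # Fallback: the existing HF path, one rollout at a time.
    return [hf_fallback(prompt, sampling) for _ in range(n)]
```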
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (19)
grail/inference/__init__.py (1)
1-4: Drop shebang from package init (not executable).

Library modules don't need a shebang; remove for cleanliness.
```diff
-#!/usr/bin/env python3
 # Namespace for inference backends (vLLM, etc.)
```

.env.example (1)
33-43: Fix dotenv lint nits and add optional API key placeholder.
- Remove extra blank line (line 33).
- Order keys to satisfy dotenv-linter (MAX_RETRIES before MODEL).
- Provide VLLM_API_KEY stub for secured deployments.
If most users start without a vLLM server, consider defaulting INFERENCE_BACKEND=hf in .env.example to avoid accidental misconfig.
```diff
-
 # choose backend
 # vllm or hf
 INFERENCE_BACKEND=vllm
 # extra vllm settings
 VLLM_BASE_URL=http://127.0.0.1:8000
-VLLM_MODEL=Qwen/Qwen3-4B-Instruct-2507
 VLLM_TIMEOUT_S=30
 VLLM_MAX_RETRIES=2
+VLLM_MODEL=Qwen/Qwen3-4B-Instruct-2507
+VLLM_API_KEY=
```

grail/shared/constants.py (1)
21-27: Gate backend to supported set and centralize API key.

Avoid unexpected values; add VLLM_API_KEY here for consistency with other constants.
```diff
-# Inference backend selection
-INFERENCE_BACKEND = os.getenv("INFERENCE_BACKEND", "hf").lower()
-VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "")
-VLLM_MODEL = os.getenv("VLLM_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
-VLLM_TIMEOUT_S = float(os.getenv("VLLM_TIMEOUT_S", "30"))
-VLLM_MAX_RETRIES = int(os.getenv("VLLM_MAX_RETRIES", "2"))
+# Inference backend selection
+_SUPPORTED_BACKENDS = {"hf", "vllm"}
+_backend = os.getenv("INFERENCE_BACKEND", "hf").lower()
+INFERENCE_BACKEND = _backend if _backend in _SUPPORTED_BACKENDS else "hf"
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "")
+VLLM_MODEL = os.getenv("VLLM_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
+VLLM_TIMEOUT_S = float(os.getenv("VLLM_TIMEOUT_S", "30"))
+VLLM_MAX_RETRIES = int(os.getenv("VLLM_MAX_RETRIES", "2"))
+VLLM_API_KEY = os.getenv("VLLM_API_KEY", "")
```

grail/inference/vllm_client.py (3)
109-114: Use requests' json= payload to avoid double-encoding.

Slightly cleaner and sets the header automatically.
```diff
-                resp = self._session.post(
-                    self.completions_url,
-                    data=json.dumps(payload),
-                    headers=self._headers(),
-                    timeout=self.timeout,
-                )
+                resp = self._session.post(
+                    self.completions_url,
+                    json=payload,
+                    headers=self._headers(),
+                    timeout=self.timeout,
+                )
```
120-133: Defensive check: empty choices.

Surface a clear error if the server returns 200 with no choices.
```diff
-                choices = data.get("choices", [])
+                choices = data.get("choices", [])
+                if not choices:
+                    raise RuntimeError("vLLM server returned 200 but no choices in response")
```
14-19: Add simple exponential backoff between retries.

Reduces hammering on transient failures.
```diff
 import json
 import logging
+import time
 import os
@@
        last_error: Optional[Exception] = None
-        for attempt in range(self.max_retries + 1):
+        for attempt in range(self.max_retries + 1):
            try:
                resp = self._session.post(
@@
                return results
            except Exception as e:  # noqa: BLE001
                last_error = e
                logger.warning(f"vLLM request failed (attempt {attempt+1}): {e}")
+                if attempt < self.max_retries:
+                    # capped exponential backoff
+                    delay = min(1.5 ** attempt, 5.0)
+                    time.sleep(delay)
```

If your vLLM OpenAI shim doesn't accept non-standard fields (top_k, repetition_penalty, ignore_eos), confirm compatibility and remove unsupported keys to avoid 400s.
Also applies to: 106-137
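As a standalone illustration of the suggested capped backoff (generic code, not the client itself; `do_post` is a hypothetical callable standing in for the HTTP request):

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def post_with_retries(do_post: Callable[[], T], max_retries: int = 2) -> T:
    """Call do_post() up to max_retries+1 times, sleeping with capped exponential backoff."""
    last_error: Exception | None = None
    for attempt in range(max_retries + 1):
        try:
            return do_post()
        except Exception as e:  # noqa: BLE001
            last_error = e
            if attempt < max_retries:
                time.sleep(min(1.5 ** attempt, 5.0))  # 1.0s, 1.5s, 2.25s, ... capped at 5s
    raise RuntimeError("all retries failed") from last_error
```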
grail/mining/rollout_generator.py (5)
176-179: Simplify vLLM gating expression.

Minor readability tweak.
```diff
-        use_vllm = (
-            INFERENCE_BACKEND == "vllm" and isinstance(VLLM_BASE_URL, str) and len(VLLM_BASE_URL) > 0
-        )
+        use_vllm = INFERENCE_BACKEND == "vllm" and bool(VLLM_BASE_URL)
```
218-246: Avoid recomputing r_vec inside the loop.

Compute once per GRPO group; it's invariant across completions.
```diff
-        for comp in comps:
+        # Precompute once
+        from ..grail import dot_mod_q, r_vec_from_randomness, sign_s_vals
+        r_vec = r_vec_from_randomness(randomness_hex, resolve_hidden_size(self.model))
+        for comp in comps:
@@
-            # GRAIL proof + local logits in a single HF forward pass
-            from ..grail import dot_mod_q, r_vec_from_randomness, sign_s_vals
-
-            r_vec = r_vec_from_randomness(randomness_hex, resolve_hidden_size(self.model))
             s_vals: list[int] = []
```
434-437: Use log_softmax directly for numerical stability and parity with the vLLM path.

Matches the vLLM branch's computation.
```diff
-                score_dist = torch.softmax(scores[i][0], dim=-1)
-                token_logprob = torch.log(score_dist[token_id]).item()
+                log_probs = torch.log_softmax(scores[i][0], dim=-1)
+                token_logprob = log_probs[token_id].item()
```
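As a quick standalone illustration of why this matters (not from the PR): with extreme logits, `log(softmax(...))` produces `-inf` where `log_softmax` keeps finite values.

```python
import torch

logits = torch.tensor([1000.0, 0.0, -1000.0])  # extreme logits

naive = torch.log(torch.softmax(logits, dim=-1))  # tiny probabilities underflow to 0 -> log(0) = -inf
stable = torch.log_softmax(logits, dim=-1)        # log-sum-exp trick keeps values finite

print(naive)   # tensor([0., -inf, -inf]) -- the -inf entries lose information
print(stable)  # tensor([0., -1000., -2000.])
```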
341-356: Safer pad/eos ids in HF generate.

Guard against a missing eos_token_id to avoid runtime errors on some tokenizers.
```diff
-        outputs = self.model.generate(
+        pad_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else self.tokenizer.pad_token_id
+        outputs = self.model.generate(
             prompt_ids,
             attention_mask=attention_mask,
             max_new_tokens=self.get_max_tokens(),
             temperature=self.get_temperature(),
             do_sample=True,
             top_p=0.95,
             top_k=50,
             repetition_penalty=1.1,
             output_scores=True,
             return_dict_in_generate=True,
-            pad_token_id=self.tokenizer.eos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=pad_id,
+            eos_token_id=self.tokenizer.eos_token_id,
         )
```

Confirm that `self.tokenizer.eos_token_id` is set for the default model; otherwise generation may not terminate as expected.
218-292: Optional: batch the HF proof/logprob pass for vLLM completions.

You currently run one forward per completion. Padding and batching the `all_token_ids` can cut proof/logprob latency notably at larger n.

I can provide a batched variant if you want it in this PR.
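For reference, a rough sketch of what such a batched pass could look like. The function name, argument layout, and import path are assumptions; `dot_mod_q` is the existing helper referenced in the code graph below.

```python
import torch


def batched_svals(model, tokenizer, sequences: list[list[int]], r_vec) -> list[list[int]]:
    """One right-padded forward pass over all completions, then per-position s_vals."""
    from grail.grail import dot_mod_q  # import path assumed

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    max_len = max(len(seq) for seq in sequences)
    ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long, device=model.device)
    mask = torch.zeros_like(ids)
    for row, seq in enumerate(sequences):
        ids[row, : len(seq)] = torch.tensor(seq, dtype=torch.long, device=model.device)
        mask[row, : len(seq)] = 1

    with torch.inference_mode():
        # Right padding + attention mask leaves hidden states at valid positions unchanged for causal LMs.
        hidden = model(ids, attention_mask=mask, output_hidden_states=True).hidden_states[-1]

    return [
        [dot_mod_q(hidden[row, pos], r_vec) for pos in range(len(seq))]
        for row, seq in enumerate(sequences)
    ]
```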
scripts/benchmark_vllm.py (8)
42-51: Don't silently swallow template application errors; log at debug.

Swallowing exceptions hides misconfigurations (e.g., an incompatible tokenizer). Log once at debug so issues are discoverable without polluting INFO.
```diff
-    try:
-        tpl = _build_qwen_chat_template(SYSTEM_PROMPT, REASONING_START)
-        if getattr(tokenizer, "chat_template", None) != tpl:
-            tokenizer.chat_template = tpl
-    except Exception:
-        pass
+    try:
+        tpl = _build_qwen_chat_template(SYSTEM_PROMPT, REASONING_START)
+        if getattr(tokenizer, "chat_template", None) != tpl:
+            tokenizer.chat_template = tpl
+    except Exception as e:
+        logger.debug(f"chat_template apply skipped: {e}")
```
53-63: Fix duplicate "SAT Problem" header and unused loop index.

Minor output nit and a tiny readability tweak.
```diff
-    lines = [f"SAT Problem (seed: {seed[:8]}...):", f"Variables: {num_vars}", "Clauses:"]
-    for i in range(num_clauses):
+    lines = [f"SAT Problem (seed: {seed[:8]}...):", f"Variables: {num_vars}", "Clauses:"]
+    for _ in range(num_clauses):
         lines.append(f"  (x1 OR NOT x2 OR x3)")
     instr = (
         "Provide your final assignment between <SOLUTION></SOLUTION> as "
         "space-separated 0/1 values for x1..xN (e.g., <SOLUTION>0 1 0 1</SOLUTION>).\n"
     )
-    return "\n".join(["SAT Problem:", "\n".join(lines), instr])
+    return "\n".join(["\n".join(lines), instr])
```
117-126: Derive d_model robustly; hidden_size may be absent on some configs.

Falling back to 4096 can cause a shape mismatch if the model's hidden size differs. Prefer checking common aliases and, as a last resort, the embedding dim.
```diff
-    randomness_hex = "deadbeef" * 8
-    r_vec = r_vec_from_randomness(randomness_hex, getattr(model.config, "hidden_size", 4096))
+    randomness_hex = "deadbeef" * 8
+    d_model = (
+        getattr(model.config, "hidden_size", None)
+        or getattr(model.config, "n_embd", None)
+        or getattr(model.config, "d_model", None)
+        or int(getattr(model.get_input_embeddings(), "embedding_dim", 4096))
+    )
+    r_vec = r_vec_from_randomness(randomness_hex, d_model)
```
127-133: Vectorize s-value computation to reduce Python-loop overhead.

Same math, measurable speedup on long sequences; keeps timing closer to "forward + projection" rather than "forward + per-token Python loop."
```diff
-    with torch.inference_mode():
-        token_tensor = torch.tensor([all_ids], dtype=torch.long, device=model.device)
-        outs = model(token_tensor, output_hidden_states=True)
-        h_layer = outs.hidden_states[-1][0]
-        _svals = [dot_mod_q(h_layer[pos], r_vec) for pos in range(min(len(all_ids), h_layer.size(0)))]
+    with torch.inference_mode():
+        token_tensor = torch.tensor([all_ids], dtype=torch.long, device=model.device)
+        outs = model(token_tensor, output_hidden_states=True)
+        h_layer = outs.hidden_states[-1][0]  # [seq, hidden]
+        seq_len = min(len(all_ids), h_layer.size(0))
+        scaled = torch.round(h_layer[:seq_len] * 1024.0)
+        # match dot_mod_q semantics: move r_vec to device and cast to float
+        svals = torch.matmul(scaled, r_vec.to(scaled.device).float())
+        _ = (svals.long())  # materialize; modulo is not required for timing
```
145-152: Add safe fallback when safetensors are unavailable; keep CUDA move and eval.

Some model cards don't ship safetensors. Fall back cleanly without breaking the benchmark.
```diff
-    model = (
-        AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_safetensors=True)
-        .to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        .eval()
-    )
+    try:
+        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_safetensors=True)
+    except Exception as e:  # noqa: BLE001
+        logger.warning(f"use_safetensors=True failed, retrying without: {e}")
+        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")).eval()
```
137-140: Avoid ignoring CLI; wire minimal flags instead of parse_args([]).

Let users tweak the problem count without code changes; maintains "no flags required" by giving sensible defaults.
```diff
-    parser = argparse.ArgumentParser(description="Benchmark vLLM vs HF miner defaults")
-    # Keep defaults fixed to simplify usage per request
-    args = parser.parse_args([])
+    parser = argparse.ArgumentParser(description="Benchmark vLLM vs HF miner defaults")
+    parser.add_argument("--problems", type=int, default=2, help="number of prompt groups")
+    args = parser.parse_args()
```

And later:
```diff
-    problems = 2  # fixed per request
+    problems = int(args.problems)
```
221-235: Guard against None in per-group vLLM timings to avoid format errors.

If a key is missing, formatting None with ":.3f" raises. Safer defaults help future edits.
```diff
-        vg = results.get("vllm", {}).get("gen_times", [None] * problems)[i] if "vllm" in results else None
-        vp = results.get("vllm", {}).get("proof_times", [None] * problems)[i] if "vllm" in results else None
+        vg = (results.get("vllm", {}).get("gen_times") or [float("nan")] * problems)[i] if "vllm" in results else None
+        vp = (results.get("vllm", {}).get("proof_times") or [float("nan")] * problems)[i] if "vllm" in results else None
```
88-115: Optional: allow reproducible HF runs with an explicit generator.

Helpful for apples-to-apples backend comparisons.
```diff
-    for _ in range(n):
+    gen = torch.Generator(device=model.device)  # optionally seed via env
+    if os.getenv("BENCH_SEED"):
+        gen.manual_seed(int(os.getenv("BENCH_SEED")))
+    for _ in range(n):
         with torch.inference_mode():
             outs = model.generate(
                 input_ids,
                 attention_mask=attn,
                 max_new_tokens=int(MAX_NEW_TOKENS),
                 temperature=0.7,
                 do_sample=True,
                 top_p=0.95,
                 top_k=50,
                 repetition_penalty=1.1,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 return_dict_in_generate=True,
+                generator=gen if os.getenv("BENCH_SEED") else None,
             )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- .env.example (1 hunks)
- docs/miner.md (1 hunks)
- grail/inference/__init__.py (1 hunks)
- grail/inference/vllm_client.py (1 hunks)
- grail/mining/rollout_generator.py (2 hunks)
- grail/shared/constants.py (1 hunks)
- scripts/benchmark_vllm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
grail/mining/rollout_generator.py (4)
grail/environments/sat.py (6)
create_environment (366-371)
reset_environment (373-379)
create_prompt (381-395)
get_max_tokens (521-523)
parse_action (397-416)
step_environment (418-485)

grail/inference/vllm_client.py (2)
VLLMClient (31-138)
generate (57-138)

grail/grail.py (3)
dot_mod_q (293-303)
r_vec_from_randomness (122-214)
sign_s_vals (306-327)

grail/shared/hf_compat.py (1)
resolve_hidden_size (6-49)
scripts/benchmark_vllm.py (3)
grail/grail.py (2)
dot_mod_q (293-303)
r_vec_from_randomness (122-214)

grail/mining/rollout_generator.py (1)
_build_qwen_chat_template (45-80)

grail/inference/vllm_client.py (2)
VLLMClient (31-138)
generate (57-138)
🪛 dotenv-linter (3.3.0)
.env.example
[warning] 33-33: [ExtraBlankLine] Extra blank line detected
(ExtraBlankLine)
[warning] 42-42: [UnorderedKey] The VLLM_MAX_RETRIES key should go before the VLLM_MODEL key
(UnorderedKey)
🔇 Additional comments (1)
docs/miner.md (1)
58-70: Document vLLM auth and per-request parity
- Update docs/miner.md (lines 58–70, 73): if the vLLM server requires auth, start it with --api-key "$VLLM_API_KEY" and add VLLM_API_KEY to .env.
- Ensure server defaults do not override per-request knobs (temperature / top_p / top_k / ignore_eos); document server-side defaults and how to preserve per-request settings so validator parity is maintained.
- Verification: a quick curl to http://127.0.0.1:8000/v1/completions failed (curl: (7) Couldn't connect). Run this against your running vLLM to confirm the OpenAI-compatible endpoint honors top_k and ignore_eos:
```sh
curl -sS -X POST "${VLLM_BASE_URL:-http://127.0.0.1:8000}/v1/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer ${VLLM_API_KEY:-}" \
  -d '{"prompt":"test","n":1,"max_tokens":1,"temperature":0.7,"top_p":0.95,"top_k":50,"ignore_eos":false,"logprobs":1}' | jq '{choices, error}'
```
Description
Swaps HF model.generate in the miner for a vLLM backend with batched n‑completions per prompt, preserving parity and validator compatibility.
Keeps the existing Qwen chat template; we render with the tokenizer, send a plain prompt to vLLM, and re‑tokenize completion text for proofs/packing.
Sampling matches miner defaults (temperature=0.7, top_p=0.95, top_k=50, repetition_penalty=1.1, max_new_tokens=GRAIL_MAX_NEW_TOKENS). Requests use EOS termination (ignore_eos=false) to mirror HF behavior.
GRAIL proof path unchanged: one HF forward computes s_vals and commit‑binding.
Adds a small OpenAI‑style vLLM client and an opt‑in benchmark to compare vLLM vs HF under miner defaults.
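A hedged sketch of that request flow (variable names such as `vllm_client`, `problem_text`, and `rollouts_per_problem`, and the exact `generate` signature, are assumptions; the sampling values are the miner defaults listed above):

```python
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": problem_text},
]
# Render the existing Qwen chat template locally; vLLM receives a plain prompt string.
prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids

# One request, n completions, miner-default sampling, EOS termination (ignore_eos=False).
completions = vllm_client.generate(
    prompt_text,
    n=rollouts_per_problem,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.1,
    max_tokens=GRAIL_MAX_NEW_TOKENS,
)

for text in completions:
    completion_ids = tokenizer(text, add_special_tokens=False).input_ids
    all_token_ids = prompt_ids + completion_ids  # fed to the single HF forward for s_vals / packing
```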
Why vLLM serve?
Simple ops, stable batching via n, isolates CUDA memory from the miner process, and avoids tight coupling to vLLM internals (vs vllm.LLM/AsyncLLMEngine).
If GPU memory is tight, reduce vLLM --gpu-memory-utilization so the proof pass still fits.
Performance
~4× end-to-end throughput over HF on an A100; success counts and mean rewards are consistent with HF when EOS termination is used and the server config doesn't override per-request knobs.
Backwards compatibility
HF path unchanged when INFERENCE_BACKEND=hf