@eigenad commented Sep 18, 2025

feat(miner): vLLM backend for GRPO (EOS termination, logprobs fallback); + benchmark

Description
Swaps HF model.generate in the miner for a vLLM backend with batched n‑completions per prompt, preserving parity and validator compatibility.
Keeps the existing Qwen chat template; the miner renders the prompt with the tokenizer, sends plain text to vLLM, and re‑tokenizes the completion text for proofs/packing.
Sampling matches miner defaults (temperature=0.7, top_p=0.95, top_k=50, repetition_penalty=1.1, max_new_tokens=GRAIL_MAX_NEW_TOKENS). Requests use EOS termination (ignore_eos=false) to mirror HF behavior.
GRAIL proof path unchanged: one HF forward computes s_vals and commit‑binding.
Adds a small OpenAI‑style vLLM client and an opt‑in benchmark to compare vLLM vs HF under miner defaults.
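For reference, a minimal sketch of the request shape described above, assuming an OpenAI-compatible /v1/completions endpoint; the helper below is illustrative, not the PR's actual client API, and max_tokens stands in for GRAIL_MAX_NEW_TOKENS:

import requests

def sample_completions(base_url: str, model: str, prompt: str, n: int = 4) -> list[str]:
    # Sketch only: mirrors the miner sampling defaults listed above.
    payload = {
        "model": model,
        "prompt": prompt,          # plain text rendered via the Qwen chat template
        "n": n,                    # batched n-completions per prompt
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 50,               # vLLM extension, not part of the OpenAI spec
        "repetition_penalty": 1.1, # vLLM extension
        "max_tokens": 1024,        # placeholder for GRAIL_MAX_NEW_TOKENS
        "ignore_eos": False,       # EOS termination, mirroring HF generate
        "logprobs": 1,             # optional token logprobs for the fallback path
    }
    resp = requests.post(f"{base_url}/v1/completions", json=payload, timeout=30)
    resp.raise_for_status()
    return [choice["text"] for choice in resp.json()["choices"]]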

Why vLLM serve?
Simpler operations, stable batching via n, CUDA memory isolated from the miner process, and no tight coupling to vLLM internals (versus embedding vllm.LLM/AsyncLLMEngine).

If GPU memory is tight, reduce vLLM --gpu-memory-utilization so the proof pass still fits.

Performance
~4× end-to-end throughput over HF on an A100; success rates and mean rewards are consistent with HF when EOS termination is used and the server config doesn't override per-request sampling knobs.

Backwards compatibility
The HF path is unchanged when INFERENCE_BACKEND=hf.

Summary by CodeRabbit

  • New Features

    • Optional vLLM inference backend with environment-based configuration (backend selector, base URL, model, timeout, retries).
    • Faster rollout generation via batch completions when using vLLM, with automatic fallback to the existing path if unavailable.
    • Benchmarking script to compare vLLM and Hugging Face backends, reporting generation and end-to-end performance.
  • Documentation

    • Updated miner setup guide: streamlined environment steps and added an optional vLLM workflow with server startup instructions.
    • Quick Start now references selecting vLLM or Hugging Face backends alongside wallet/R2 configuration.


coderabbitai bot commented Sep 18, 2025

Walkthrough

Introduces vLLM inference backend support alongside existing HF path: adds env vars and constants, a vLLM HTTP client, a fast-path in GRPO rollout generation, documentation updates, and a benchmark script. Also initializes an inference package namespace and updates miner docs to include optional vLLM workflow.

Changes

  • Environment & constants (.env.example, grail/shared/constants.py): Adds INFERENCE_BACKEND and vLLM-related env vars; exposes constants (VLLM_BASE_URL, VLLM_MODEL, VLLM_TIMEOUT_S, VLLM_MAX_RETRIES).
  • Inference package & client (grail/inference/__init__.py, grail/inference/vllm_client.py): Creates the inference package; implements VLLMClient targeting OpenAI-style /v1/completions with retries, timeout, API key support, and optional token logprobs extraction.
  • Mining rollout fast-path (grail/mining/rollout_generator.py): Adds a vLLM fast-path for batched completions per prompt when INFERENCE_BACKEND == "vllm"; computes rollouts, handles logprobs, signs, and falls back to HF on error; adjusts env init per GRPO group; computes advantages post-collection (a gating sketch follows this list).
  • Docs (docs/miner.md): Updates the setup flow, adds optional vLLM backend instructions with a server run example, and adjusts quick-start guidance.
  • Benchmark script (scripts/benchmark_vllm.py): Adds benchmarking for vLLM vs HF, including prompt building, fake SAT problem generation, per-backend runs, optional proof-pass timing, and summarized metrics.
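The gating and fallback behavior summarized in the rollout fast-path entry, as a minimal sketch; the two callables are hypothetical stand-ins for the PR's real vLLM and HF generation paths, while the env var names match the PR:

import logging
import os
from typing import Callable

logger = logging.getLogger(__name__)

def generate_group(
    prompt: str,
    n: int,
    vllm_generate: Callable[[str, int], list[str]],
    hf_generate: Callable[[str], str],
) -> list[str]:
    backend = os.getenv("INFERENCE_BACKEND", "hf").lower()
    if backend == "vllm" and bool(os.getenv("VLLM_BASE_URL", "")):
        try:
            # One request yields n completions for the whole GRPO group.
            return vllm_generate(prompt, n)
        except Exception as e:
            logger.warning(f"vLLM fast-path failed, falling back to HF: {e}")
    # Existing path: one HF generate call per rollout.
    return [hf_generate(prompt) for _ in range(n)]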

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as Miner (GRPO)
  participant C as VLLMClient
  participant S as vLLM Server
  participant H as HF Model (fallback)

  Note over U: Generate GRPO rollouts for a group

  alt INFERENCE_BACKEND == "vllm" and VLLM_BASE_URL set
    U->>U: Build chat-templated prompt
    U->>C: generate(prompt, n, params)
    C->>S: POST /v1/completions
    S-->>C: 200 JSON (choices, logprobs?)
    C-->>U: texts [+ optional token logprobs]
    rect rgba(200,255,200,0.2)
      Note right of U: For each completion: parse action, compute reward, tokenize, derive s_vals, sign, assemble rollout
    end
  else Fallback
    U->>H: _generate_single_rollout() per rollout
    H-->>U: rollout artifacts
  end

  U->>U: Compute advantages across group
  U-->>U: Return GRPO rollouts

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my ears at lightning speed,
New paths to think, as prompts stampede—
vLLM’s hum, a gentle breeze,
HF stands by, in case it sneezes.
Benchmarks hop, the numbers gleam—
Carrots for throughput, dream by dream. 🥕🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 20.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title concisely and accurately summarizes the primary change: adding a vLLM inference backend for the miner's GRPO flow (including EOS termination and a logprobs fallback) plus a benchmark. It names the affected subsystem (miner) and the key behaviors, matching the diff, and will be clear to teammates scanning PR history.


coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (19)
grail/inference/__init__.py (1)

1-4: Drop shebang from package init (not executable).

Library modules don’t need a shebang; remove for cleanliness.

-#!/usr/bin/env python3
 # Namespace for inference backends (vLLM, etc.)
.env.example (1)

33-43: Fix dotenv lint nits and add optional API key placeholder.

  • Remove extra blank line (line 33).
  • Order keys to satisfy dotenv-linter (MAX_RETRIES before MODEL).
  • Provide VLLM_API_KEY stub for secured deployments.

If most users start without a vLLM server, consider defaulting INFERENCE_BACKEND=hf in .env.example to avoid accidental misconfig.

-
 # choose backend
 # vllm or hf
 INFERENCE_BACKEND=vllm

 # extra vllm settings
 VLLM_BASE_URL=http://127.0.0.1:8000
-VLLM_MODEL=Qwen/Qwen3-4B-Instruct-2507
 VLLM_TIMEOUT_S=30
 VLLM_MAX_RETRIES=2
+VLLM_MODEL=Qwen/Qwen3-4B-Instruct-2507
+VLLM_API_KEY=
grail/shared/constants.py (1)

21-27: Gate backend to supported set and centralize API key.

Avoid unexpected values; add VLLM_API_KEY here for consistency with other constants.

-# Inference backend selection
-INFERENCE_BACKEND = os.getenv("INFERENCE_BACKEND", "hf").lower()
-VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "")
-VLLM_MODEL = os.getenv("VLLM_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
-VLLM_TIMEOUT_S = float(os.getenv("VLLM_TIMEOUT_S", "30"))
-VLLM_MAX_RETRIES = int(os.getenv("VLLM_MAX_RETRIES", "2"))
+# Inference backend selection
+_SUPPORTED_BACKENDS = {"hf", "vllm"}
+_backend = os.getenv("INFERENCE_BACKEND", "hf").lower()
+INFERENCE_BACKEND = _backend if _backend in _SUPPORTED_BACKENDS else "hf"
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "")
+VLLM_MODEL = os.getenv("VLLM_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
+VLLM_TIMEOUT_S = float(os.getenv("VLLM_TIMEOUT_S", "30"))
+VLLM_MAX_RETRIES = int(os.getenv("VLLM_MAX_RETRIES", "2"))
+VLLM_API_KEY = os.getenv("VLLM_API_KEY", "")
grail/inference/vllm_client.py (3)

109-114: Use requests’ json= payload to avoid double-encoding.

Slightly cleaner and sets header automatically.

-                resp = self._session.post(
-                    self.completions_url,
-                    data=json.dumps(payload),
-                    headers=self._headers(),
-                    timeout=self.timeout,
-                )
+                resp = self._session.post(
+                    self.completions_url,
+                    json=payload,
+                    headers=self._headers(),
+                    timeout=self.timeout,
+                )

120-133: Defensive check: empty choices.

Surface a clear error if the server returns 200 with no choices.

-                choices = data.get("choices", [])
+                choices = data.get("choices", [])
+                if not choices:
+                    raise RuntimeError("vLLM server returned 200 but no choices in response")

14-19: Add simple exponential backoff between retries.

Reduces hammering on transient failures.

 import json
 import logging
+import time
 import os
         last_error: Optional[Exception] = None
-        for attempt in range(self.max_retries + 1):
+        for attempt in range(self.max_retries + 1):
             try:
                 resp = self._session.post(
@@
                 return results
             except Exception as e:  # noqa: BLE001
                 last_error = e
                 logger.warning(f"vLLM request failed (attempt {attempt+1}): {e}")
+                if attempt < self.max_retries:
+                    # capped exponential backoff
+                    delay = min(1.5 ** attempt, 5.0)
+                    time.sleep(delay)

If your vLLM OpenAI shim doesn’t accept non-standard fields (top_k, repetition_penalty, ignore_eos), confirm compatibility and remove unsupported keys to avoid 400s.

Also applies to: 106-137
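A possible shape for the key handling suggested above, assuming the client builds a plain dict payload; the retry-on-400 strategy is illustrative, not what the PR implements:

import requests

# Fields beyond the OpenAI completions spec that some shims reject.
VLLM_EXTENSIONS = {"top_k", "repetition_penalty", "ignore_eos"}

def post_completions(session: requests.Session, url: str, payload: dict, timeout: float) -> dict:
    resp = session.post(url, json=payload, timeout=timeout)
    if resp.status_code == 400:
        # Retry once without the non-standard sampling fields.
        slim = {k: v for k, v in payload.items() if k not in VLLM_EXTENSIONS}
        resp = session.post(url, json=slim, timeout=timeout)
    resp.raise_for_status()
    return resp.json()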

grail/mining/rollout_generator.py (5)

176-179: Simplify vLLM gating expression.

Minor readability tweak.

-        use_vllm = (
-            INFERENCE_BACKEND == "vllm" and isinstance(VLLM_BASE_URL, str) and len(VLLM_BASE_URL) > 0
-        )
+        use_vllm = INFERENCE_BACKEND == "vllm" and bool(VLLM_BASE_URL)

218-246: Avoid recomputing r_vec inside the loop.

Compute once per GRPO group; it’s invariant to completions.

-                for comp in comps:
+                # Precompute once
+                from ..grail import dot_mod_q, r_vec_from_randomness, sign_s_vals
+                r_vec = r_vec_from_randomness(randomness_hex, resolve_hidden_size(self.model))
+                for comp in comps:
@@
-                    # GRAIL proof + local logits in a single HF forward pass
-                    from ..grail import dot_mod_q, r_vec_from_randomness, sign_s_vals
-
-                    r_vec = r_vec_from_randomness(randomness_hex, resolve_hidden_size(self.model))
                     s_vals: list[int] = []

434-437: Use log_softmax directly for numerical stability and parity with vLLM path.

Matches the vLLM branch’s computation.

-                score_dist = torch.softmax(scores[i][0], dim=-1)
-                token_logprob = torch.log(score_dist[token_id]).item()
+                log_probs = torch.log_softmax(scores[i][0], dim=-1)
+                token_logprob = log_probs[token_id].item()

341-356: Safer pad/eos ids in HF generate.

Guard against missing eos_token_id to avoid runtime errors on some tokenizers.

-            outputs = self.model.generate(
+            pad_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else self.tokenizer.pad_token_id
+            outputs = self.model.generate(
                 prompt_ids,
                 attention_mask=attention_mask,
                 max_new_tokens=self.get_max_tokens(),
                 temperature=self.get_temperature(),
                 do_sample=True,
                 top_p=0.95,
                 top_k=50,
                 repetition_penalty=1.1,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=self.tokenizer.eos_token_id,
-                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=pad_id,
+                eos_token_id=self.tokenizer.eos_token_id,
             )

Confirm that self.tokenizer.eos_token_id is set for the default model; otherwise generation may not terminate as expected.


218-292: Optional: batch the HF proof/logprob pass for vLLM completions.

You currently run one forward per completion. Padding and batching the all_token_ids can cut proof/LP latency notably at larger n.

I can provide a batched variant if you want it in this PR.
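For reference, a rough sketch of what a batched variant could look like, assuming dot_mod_q(hidden_vector, r_vec) as listed in the code graph below and right-padded inputs; the import path and padding details are assumptions, not the PR's code:

import torch
from grail.grail import dot_mod_q  # path inferred from the PR's relative "..grail" import

def batched_s_vals(model, tokenizer, all_token_ids: list[list[int]], r_vec: torch.Tensor) -> list[list[int]]:
    # Right-pad to the longest sequence; with causal attention, padded tail
    # positions do not affect the hidden states of valid positions.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    max_len = max(len(ids) for ids in all_token_ids)
    input_ids = torch.full((len(all_token_ids), max_len), pad_id, dtype=torch.long, device=model.device)
    attention_mask = torch.zeros_like(input_ids)
    for row, ids in enumerate(all_token_ids):
        input_ids[row, : len(ids)] = torch.tensor(ids, dtype=torch.long, device=model.device)
        attention_mask[row, : len(ids)] = 1
    with torch.inference_mode():
        outs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden = outs.hidden_states[-1]  # [batch, seq, hidden]
    return [
        [dot_mod_q(hidden[row, pos], r_vec) for pos in range(len(ids))]
        for row, ids in enumerate(all_token_ids)
    ]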

scripts/benchmark_vllm.py (8)

42-51: Don’t silently swallow template application errors; log at debug.

Swallowing exceptions hides misconfigurations (e.g., incompatible tokenizer). Log once at debug so issues are discoverable without polluting INFO.

-    try:
-        tpl = _build_qwen_chat_template(SYSTEM_PROMPT, REASONING_START)
-        if getattr(tokenizer, "chat_template", None) != tpl:
-            tokenizer.chat_template = tpl
-    except Exception:
-        pass
+    try:
+        tpl = _build_qwen_chat_template(SYSTEM_PROMPT, REASONING_START)
+        if getattr(tokenizer, "chat_template", None) != tpl:
+            tokenizer.chat_template = tpl
+    except Exception as e:
+        logger.debug(f"chat_template apply skipped: {e}")

53-63: Fix duplicate “SAT Problem” header and unused loop index.

Minor output nit and a tiny readability tweak.

-    lines = [f"SAT Problem (seed: {seed[:8]}...):", f"Variables: {num_vars}", "Clauses:"]
-    for i in range(num_clauses):
+    lines = [f"SAT Problem (seed: {seed[:8]}...):", f"Variables: {num_vars}", "Clauses:"]
+    for _ in range(num_clauses):
         lines.append(f"  (x1 OR NOT x2 OR x3)")
     instr = (
         "Provide your final assignment between <SOLUTION></SOLUTION> as "
         "space-separated 0/1 values for x1..xN (e.g., <SOLUTION>0 1 0 1</SOLUTION>).\n"
     )
-    return "\n".join(["SAT Problem:", "\n".join(lines), instr])
+    return "\n".join(["\n".join(lines), instr])

117-126: Derive d_model robustly; hidden_size may be absent on some configs.

Falling back to 4096 can cause shape mismatch if the model’s hidden size differs. Prefer checking common aliases and, as a last resort, the embedding dim.

-    randomness_hex = "deadbeef" * 8
-    r_vec = r_vec_from_randomness(randomness_hex, getattr(model.config, "hidden_size", 4096))
+    randomness_hex = "deadbeef" * 8
+    d_model = (
+        getattr(model.config, "hidden_size", None)
+        or getattr(model.config, "n_embd", None)
+        or getattr(model.config, "d_model", None)
+        or int(getattr(model.get_input_embeddings(), "embedding_dim", 4096))
+    )
+    r_vec = r_vec_from_randomness(randomness_hex, d_model)

127-133: Vectorize s-value computation to reduce Python-loop overhead.

Same math, measurable speedup on long sequences; keeps timing closer to “forward + projection” rather than “forward + per-token Python loop.”

-    with torch.inference_mode():
-        token_tensor = torch.tensor([all_ids], dtype=torch.long, device=model.device)
-        outs = model(token_tensor, output_hidden_states=True)
-        h_layer = outs.hidden_states[-1][0]
-        _svals = [dot_mod_q(h_layer[pos], r_vec) for pos in range(min(len(all_ids), h_layer.size(0)))]
+    with torch.inference_mode():
+        token_tensor = torch.tensor([all_ids], dtype=torch.long, device=model.device)
+        outs = model(token_tensor, output_hidden_states=True)
+        h_layer = outs.hidden_states[-1][0]  # [seq, hidden]
+        seq_len = min(len(all_ids), h_layer.size(0))
+        scaled = torch.round(h_layer[:seq_len] * 1024.0)
+        # match dot_mod_q semantics: move r_vec to device and cast to float
+        svals = torch.matmul(scaled, r_vec.to(scaled.device).float())
+        _ = (svals.long())  # materialize; modulo is not required for timing

145-152: Add safe fallback when safetensors are unavailable; keep CUDA move and eval.

Some model cards don’t ship safetensors. Fall back cleanly without breaking the benchmark.

-    model = (
-        AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_safetensors=True)
-        .to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        .eval()
-    )
+    try:
+        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_safetensors=True)
+    except Exception as e:  # noqa: BLE001
+        logger.warning(f"use_safetensors=True failed, retrying without: {e}")
+        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")).eval()

137-140: Avoid ignoring CLI; wire minimal flags instead of parse_args([]).

Let users tweak problem count without code changes; maintains “no flags required” by giving sensible defaults.

-    parser = argparse.ArgumentParser(description="Benchmark vLLM vs HF miner defaults")
-    # Keep defaults fixed to simplify usage per request
-    args = parser.parse_args([])
+    parser = argparse.ArgumentParser(description="Benchmark vLLM vs HF miner defaults")
+    parser.add_argument("--problems", type=int, default=2, help="number of prompt groups")
+    args = parser.parse_args()

And later:

-    problems = 2  # fixed per request
+    problems = int(args.problems)

221-235: Guard against None in per-group vLLM timings to avoid format errors.

If a key is missing, formatting with “:.3f” on None raises. Safer defaults help future edits.

-        vg = results.get("vllm", {}).get("gen_times", [None] * problems)[i] if "vllm" in results else None
-        vp = results.get("vllm", {}).get("proof_times", [None] * problems)[i] if "vllm" in results else None
+        vg = (results.get("vllm", {}).get("gen_times") or [float("nan")] * problems)[i] if "vllm" in results else None
+        vp = (results.get("vllm", {}).get("proof_times") or [float("nan")] * problems)[i] if "vllm" in results else None

88-115: Optional: allow reproducible HF runs with an explicit generator.

Helpful for apples-to-apples backend comparisons.

-    for _ in range(n):
+    gen = torch.Generator(device=model.device)  # optionally seed via env
+    if os.getenv("BENCH_SEED"):
+        gen.manual_seed(int(os.getenv("BENCH_SEED")))
+    for _ in range(n):
         with torch.inference_mode():
             outs = model.generate(
                 input_ids,
                 attention_mask=attn,
                 max_new_tokens=int(MAX_NEW_TOKENS),
                 temperature=0.7,
                 do_sample=True,
                 top_p=0.95,
                 top_k=50,
                 repetition_penalty=1.1,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 return_dict_in_generate=True,
+                generator=gen if os.getenv("BENCH_SEED") else None,
             )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d9f6d5a and c005aca.

📒 Files selected for processing (7)
  • .env.example (1 hunks)
  • docs/miner.md (1 hunks)
  • grail/inference/__init__.py (1 hunks)
  • grail/inference/vllm_client.py (1 hunks)
  • grail/mining/rollout_generator.py (2 hunks)
  • grail/shared/constants.py (1 hunks)
  • scripts/benchmark_vllm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
grail/mining/rollout_generator.py (4)
grail/environments/sat.py (6)
  • create_environment (366-371)
  • reset_environment (373-379)
  • create_prompt (381-395)
  • get_max_tokens (521-523)
  • parse_action (397-416)
  • step_environment (418-485)
grail/inference/vllm_client.py (2)
  • VLLMClient (31-138)
  • generate (57-138)
grail/grail.py (3)
  • dot_mod_q (293-303)
  • r_vec_from_randomness (122-214)
  • sign_s_vals (306-327)
grail/shared/hf_compat.py (1)
  • resolve_hidden_size (6-49)
scripts/benchmark_vllm.py (3)
grail/grail.py (2)
  • dot_mod_q (293-303)
  • r_vec_from_randomness (122-214)
grail/mining/rollout_generator.py (1)
  • _build_qwen_chat_template (45-80)
grail/inference/vllm_client.py (2)
  • VLLMClient (31-138)
  • generate (57-138)
🪛 dotenv-linter (3.3.0)
.env.example

[warning] 33-33: [ExtraBlankLine] Extra blank line detected

(ExtraBlankLine)


[warning] 42-42: [UnorderedKey] The VLLM_MAX_RETRIES key should go before the VLLM_MODEL key

(UnorderedKey)

🔇 Additional comments (1)
docs/miner.md (1)

58-70: Document vLLM auth and per-request parity

  • Update docs/miner.md (lines 58–70, 73): if the vLLM server requires auth, start it with --api-key "$VLLM_API_KEY" and add VLLM_API_KEY to .env.
  • Ensure server defaults do not override per-request knobs (temperature / top_p / top_k / ignore_eos); document server-side defaults and how to preserve per-request settings so validator parity is maintained.
  • Verification: a quick curl to http://127.0.0.1:8000/v1/completions failed (curl: (7) Couldn't connect). Run this against your running vLLM to confirm the OpenAI-compatible endpoint honors top_k and ignore_eos:
curl -sS -X POST "${VLLM_BASE_URL:-http://127.0.0.1:8000}/v1/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer ${VLLM_API_KEY:-}" \
  -d '{"prompt":"test","n":1,"max_tokens":1,"temperature":0.7,"top_p":0.95,"top_k":50,"ignore_eos":false,"logprobs":1}' | jq '{choices, error}'
