DRAFT: feat: Enable simulated user for multi-turn GRPO #1412
base: main
Conversation
Signed-off-by: Jialei Chen <[email protected]>
Signed-off-by: Ahmad Kiswani <[email protected]>
📝 Walkthrough
This PR introduces a new simulated-user environment called "unique numbers" that uses Google ADK agents to simulate user-agent interactions within a GRPO training framework. It includes configuration files, an environment implementation, utility functions, an example training script, infrastructure updates, and comprehensive tests.
Changes
Sequence Diagram(s)
sequenceDiagram
participant User as Training User
participant GRPO as GRPO Training
participant UniqueEnv as Unique Numbers Env
participant SimUser as Simulated User Runner
participant Grader as Grader Runner
User->>GRPO: Start training with config
GRPO->>UniqueEnv: Initialize environment
UniqueEnv->>SimUser: Create ADK agent
UniqueEnv->>Grader: Create ADK agent
loop Per training step
GRPO->>UniqueEnv: step(message_log, metadata)
UniqueEnv->>SimUser: extract last assistant message
UniqueEnv->>SimUser: run_prompt_async(query or statement)
SimUser-->>UniqueEnv: simulated user response
UniqueEnv->>UniqueEnv: check if guess pattern matched
alt Guess detected
UniqueEnv->>UniqueEnv: compute reward (correct/incorrect)
UniqueEnv->>Grader: run_prompt_async(grade conversation)
Grader-->>UniqueEnv: optional score adjustment
UniqueEnv-->>GRPO: EnvironmentReturn (reward, terminated=True)
else Query or other
UniqueEnv-->>GRPO: EnvironmentReturn (response, turn increment)
end
end
GRPO->>UniqueEnv: shutdown()
UniqueEnv->>SimUser: cleanup
UniqueEnv->>Grader: cleanup
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
This PR introduces substantial new functionality spanning multiple files with heterogeneous changes: a new environment class with orchestration logic, async utilities with retry mechanisms, integration with the external ADK library, rollout processing modifications, and comprehensive test coverage. The logic density is moderate-to-high, with coordination between the simulated user/grader runners and reward computation requiring careful review.
Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
19-26: Honor NEMO_RL_PY_EXECUTABLES_SYSTEM for ADK for parity.
Other entries (VLLM/MCORE) respect the system override; do the same for ADK.
Apply:
 USE_SYSTEM_EXECUTABLE = os.environ.get("NEMO_RL_PY_EXECUTABLES_SYSTEM", "0") == "1"
 VLLM_EXECUTABLE = (
     PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM
 )
 MCORE_EXECUTABLE = (
     PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE
 )
+ADK_EXECUTABLE = (
+    PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.ADK
+)
@@
-    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK,
+    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": ADK_EXECUTABLE,
Also applies to: 40-41
nemo_rl/environments/interfaces.py (1)
80-82: Update step() return docs to include answers.
Docstring enumerates 5 fields; add the optional answers to avoid confusion.
Apply:
-        - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags.
+        - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, terminateds, and answers (optional).
nemo_rl/experience/rollouts.py (2)
475-491: Bug: Tensor used in if condition; will raise truth-value error.
active_input_lengths[i] is a Tensor; if (... + active_input_lengths[i] >= max_seq_len) yields a Tensor boolean, which is invalid in if.
Apply:
-            if (
-                len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
-                >= max_seq_len
-            ):
+            input_len = int(active_input_lengths[i].item())
+            if len(tokenized_obs) + len(generated_ids[i]) + input_len >= max_seq_len:
                 tokens_left_for_obs = max_seq_len - (
-                    len(generated_ids[i]) + active_input_lengths[i]
+                    len(generated_ids[i]) + input_len
                 )
758-766: Bug: Tensor truth value in if condition (single-sample path).
input_lengths is a Tensor; comparing directly in if is invalid.
Apply:
-        if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:
+        input_len = int(input_lengths.item())
+        if input_len + gen_token_count + len(tokenized_obs) >= max_seq_len:
             # Truncate environment observation
-            max_env_tokens = max_seq_len - input_lengths - gen_token_count
+            max_env_tokens = max_seq_len - input_len - gen_token_count
🧹 Nitpick comments (13)
examples/configs/grpo_adk_llama8b.yaml (1)
37-43: Confirm intentional batching differences vs Gemma config.
dynamic_batching.enabled is False here but True in the Gemma config. If this is model-specific tuning, consider a short YAML comment noting why.
nemo_rl/experience/rollouts.py (1)
376-380: Prefer logger over print for per-turn progress.
Switch to logging.getLogger(__name__).info/debug and allow callers to control verbosity.
To support this outside the hunk, add:
# at top-level
import logging

logger = logging.getLogger(__name__)
Then:
-        if max_rollout_turns > 1:
-            print(
-                f"▶ ▶ ▶ Running rollout turn {turn + 1} / {max_rollout_turns} with {len(active_indices)} active samples..."
-            )
+        if max_rollout_turns > 1:
+            logger.info(
+                "▶ ▶ ▶ Running rollout turn %d / %d with %d active samples...",
+                turn + 1, max_rollout_turns, len(active_indices)
+            )
nemo_rl/environments/simulated_user/prompt.py (3)
1-7: Promote constants to UPPER_SNAKE_CASE and keep aliases.
Module-level prompt strings are constants. Rename to UPPER_SNAKE_CASE and keep lowercase aliases for compatibility.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+from typing import Final
+
-starting_user_prompt = (
+STARTING_USER_PROMPT: Final[str] = (
     "I will play a game with you. I have a list of integers in mind and can NOT tell you. "
     "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: "
     "You can either ask me 'what is number k?' to get the number at position k in my list, "
     "or answer 'there are m unique numbers' whenever you feel you want to make a guess. "
     "Please do not say anything else. You cannot ask me to provide the list of integers."
 )
+simulated_user_instruction = SIMULATED_USER_INSTRUCTION  # type: ignore  # back-compat alias
+starting_user_prompt = STARTING_USER_PROMPT  # back-compat alias
+grader_instruction = GRADER_INSTRUCTION  # back-compat alias
Follow-up diff below adjusts definitions and typos.
As per coding guidelines.
2-7: Fix grammar in the starting prompt.
Minor clarity/grammar tweaks.
Apply:
-    "I will play a game with you. I have a list of integers in mind and can NOT tell you. "
-    "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: "
+    "I will play a game with you. I have a list of integers that I will not reveal. "
+    "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do are: "
10-19: Use consistent naming and strip once.
Define as constant and avoid trailing strip duplication.
Apply:
-simulated_user_instruction = """
+SIMULATED_USER_INSTRUCTION: Final[str] = """
 ...
-""".strip()
+""".strip()
tests/unit/environments/test_simulated_user.py (2)
90-107: Fixture stubs: keep scope local and silence ruff ARG warnings.
Good isolation overall. Consider prefixing unused fixture/function args with “_” or add “# noqa: ARG001/ARG005” where appropriate to keep linters quiet without affecting readability.
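For instance, an illustrative sketch (not code from this PR; the fixture and test names here are made up):
```python
import pytest


@pytest.fixture
def patch_adk_utils(monkeypatch):
    # Hypothetical stand-in for the real fixture: applied only for its side effect.
    monkeypatch.setenv("ADK_STUBBED", "1")


# Request the fixture without binding it as an unused argument (avoids ARG001).
@pytest.mark.usefixtures("patch_adk_utils")
def test_uses_stub():
    assert True


# For unused lambda arguments (ARG005), underscore-prefixed names are enough.
stub_run_prompt = lambda *_args, **_kwargs: "stubbed response"  # noqa: E731
```
Note that underscore-prefixing a fixture parameter itself would break injection, since pytest resolves fixtures by parameter name; pytest.mark.usefixtures or a noqa comment sidesteps that.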
220-249: Retry test is solid; minor style nit.
The monkeypatch arg is unused. Prefix it with _monkeypatch for clarity.
examples/run_grpo_unique_numbers_w_adk.py (2)
195-198: Timezone: avoid manual UTC offsets.
Use zoneinfo to format local time reliably (handles DST).
Apply:
-from datetime import datetime, timedelta
+from datetime import datetime
+from zoneinfo import ZoneInfo
 ...
-    now_pst = datetime.utcnow() + timedelta(hours=-7)
+    now_pst = datetime.now(ZoneInfo("America/Los_Angeles"))
229-240: Unused variable ‘cluster’.
Prefix with underscore to silence linters.
Apply:
-        cluster,
+        _cluster,
nemo_rl/environments/simulated_user/adk_utils.py (3)
22-36: Minor grammar in default instruction.
“help people” → “helps people”.
Apply:
-        instruction=instruction
-        or "You are a helpful assistant that help people answer questions.",
+        instruction=instruction
+        or "You are a helpful assistant that helps people answer questions.",
100-116: Use logger.exception and tighten generic except.
Prefer logger.exception to capture traceback. Keep a narrow ServerError except, and retain a broad fallback but clearly log traceback.
Apply:
-    except ServerError as e:
+    except ServerError as e:
         retries += 1
         delay_with_jitter = delay + (random.random() * 2 - 1) * (delay * 0.5)
-        logger.error(
+        logger.exception(
             f"Gemini API call (with message {new_message}) failed with ServerError {e} (attempt {retries}/{max_retries}). Retrying in {delay_with_jitter} seconds..."
         )
         await asyncio.sleep(delay_with_jitter)
         delay *= 2  # Exponential backoff
-    except Exception as e:
-        logger.error(
+    except Exception as e:  # keep as last-resort
+        logger.exception(
             f"Gemini API call (with message {new_message}) failed with an unexpected error: {e}."
         )
         return f"<No response due to unexpected error: {e}>"
-    logger.error(
+    logger.error(
         f"Gemini API call (with message {new_message}) reached maximum retries ({max_retries}) without success."
     )
-    return f"<No response due after {max_retries} retries>"
+    return f"<No response after {max_retries} retries>"
39-46: Session access assertions are brittle.
Assuming exactly one app/user/session can break multi-user scenarios. Return helpful errors or search by keys defensively.
Apply (illustrative):
-    assert len(app_session_map) == 1, "Expected exactly one app in session_service"
-    user_sessions_map = next(iter(app_session_map.values()))
-    sessions = user_sessions_map[user_id]
-    assert len(sessions) == 1, "Expected exactly one user in app session"
-    return next(iter(sessions.values()))
+    if not app_session_map:
+        raise RuntimeError("No sessions available in session_service")
+    # Prefer the first app containing the user_id
+    for user_sessions_map in app_session_map.values():
+        if user_id in user_sessions_map:
+            sessions = user_sessions_map[user_id]
+            if not sessions:
+                raise RuntimeError(f"No sessions for user_id={user_id}")
+            return next(iter(sessions.values()))
+    raise KeyError(f"user_id={user_id} not found in session_service")
246-261: zip(strict=...) for clarity and static analysis.
Be explicit with strict=False to document intent and silence linters.
Apply:
-        for log, meta in zip(message_log_batch, metadata):
+        for log, meta in zip(message_log_batch, metadata, strict=False):
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
uv.lock is excluded by !**/*.lock
📒 Files selected for processing (13)
examples/configs/grpo_adk_gemma.yaml (1 hunks)
examples/configs/grpo_adk_llama8b.yaml (1 hunks)
examples/run_grpo_unique_numbers_w_adk.py (1 hunks)
nemo_rl/distributed/ray_actor_environment_registry.py (1 hunks)
nemo_rl/distributed/virtual_cluster.py (1 hunks)
nemo_rl/environments/interfaces.py (1 hunks)
nemo_rl/environments/simulated_user/adk_utils.py (1 hunks)
nemo_rl/environments/simulated_user/prompt.py (1 hunks)
nemo_rl/environments/simulated_user/unique_numbers.py (1 hunks)
nemo_rl/experience/rollouts.py (6 hunks)
pyproject.toml (1 hunks)
pyrefly.toml (2 hunks)
tests/unit/environments/test_simulated_user.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/distributed/ray_actor_environment_registry.py
nemo_rl/environments/simulated_user/prompt.py
nemo_rl/environments/interfaces.py
nemo_rl/environments/simulated_user/adk_utils.py
nemo_rl/distributed/virtual_cluster.py
examples/run_grpo_unique_numbers_w_adk.py
nemo_rl/environments/simulated_user/unique_numbers.py
nemo_rl/experience/rollouts.py
tests/unit/environments/test_simulated_user.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/distributed/ray_actor_environment_registry.py
nemo_rl/environments/simulated_user/prompt.py
nemo_rl/environments/interfaces.py
nemo_rl/environments/simulated_user/adk_utils.py
nemo_rl/distributed/virtual_cluster.py
nemo_rl/environments/simulated_user/unique_numbers.py
nemo_rl/experience/rollouts.py
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
examples/configs/grpo_adk_gemma.yaml
examples/configs/grpo_adk_llama8b.yaml
🧬 Code graph analysis (6)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
PY_EXECUTABLES(42-60)
nemo_rl/environments/simulated_user/adk_utils.py (1)
tests/unit/environments/test_simulated_user.py (2)
from_text (26-27)
create_session (125-130)
examples/run_grpo_unique_numbers_w_adk.py (8)
nemo_rl/algorithms/utils.py (1)
get_tokenizer (157-288)
nemo_rl/data/interfaces.py (1)
DatumSpec (32-40)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
get_actor_python_env (50-65)
nemo_rl/distributed/virtual_cluster.py (1)
init_ray (86-172)
nemo_rl/environments/simulated_user/unique_numbers.py (2)
UniqueNumbersEnv (229-296)
UniqueNumbersMetadata (44-52)
nemo_rl/models/generation/__init__.py (1)
configure_generation_config (24-45)
nemo_rl/utils/config.py (1)
parse_hydra_overrides (146-166)
nemo_rl/utils/logger.py (1)
get_next_experiment_dir(1311-1345)
nemo_rl/environments/simulated_user/unique_numbers.py (2)
nemo_rl/environments/interfaces.py (2)
EnvironmentInterface (52-88)
EnvironmentReturn (26-49)
nemo_rl/environments/simulated_user/adk_utils.py (3)
extract_conversation_history (52-61)
create_agent (14-36)
run_prompt_async (64-116)
nemo_rl/experience/rollouts.py (1)
tests/unit/data/test_data_processor.py (1)
apply_chat_template(45-57)
tests/unit/environments/test_simulated_user.py (2)
nemo_rl/environments/simulated_user/unique_numbers.py (1)
_UniqueNumbersRunner (55-225)
nemo_rl/environments/simulated_user/adk_utils.py (2)
run_prompt_async (64-116)
extract_conversation_history (52-61)
🪛 Ruff (0.14.1)
nemo_rl/environments/simulated_user/adk_utils.py
101-101: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
102-104: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
107-107: Do not catch blind exception: Exception
(BLE001)
108-110: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
175-175: Unpacked variable convo2 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/run_grpo_unique_numbers_w_adk.py
82-82: Unused function argument: add_system_prompt
(ARG001)
100-100: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
101-101: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
232-232: Unpacked variable cluster is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
nemo_rl/environments/simulated_user/unique_numbers.py
197-197: Do not catch blind exception: Exception
(BLE001)
246-246: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tests/unit/environments/test_simulated_user.py
92-92: Unused function argument: user_id
(ARG001)
92-92: Unused function argument: silence
(ARG001)
105-105: Unused lambda argument: a
(ARG005)
105-105: Unused lambda argument: k
(ARG005)
166-166: Unused function argument: patch_unique_numbers
(ARG001)
180-180: Unused function argument: patch_unique_numbers
(ARG001)
194-194: Unused function argument: patch_unique_numbers
(ARG001)
207-207: Unused function argument: patch_unique_numbers
(ARG001)
220-220: Unused function argument: patch_adk_utils
(ARG001)
228-228: Unused function argument: monkeypatch
(ARG001)
228-228: Unused function argument: patch_adk_utils
(ARG001)
252-252: Unused function argument: patch_adk_utils
(ARG001)
🔇 Additional comments (5)
pyrefly.toml (1)
19-20: Stub additions look good.
Import stubs for google.adk.* and google.genai.* plus includes for simulated_user modules are appropriate for static checks.
Also applies to: 89-92
nemo_rl/distributed/virtual_cluster.py (1)
60-61: ADK executable entry is consistent.
Matches the existing pattern for extras-based executables.
pyproject.toml (1)
104-107: Verify ADK extra resolves and imports using pip.
The uv command isn’t available; please run:
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install .[adk]
python - <<'PY'
import google.adk, google.genai
print("OK", google.adk.__version__, getattr(google.genai, "__version__", "n/a"))
PY
nemo_rl/environments/simulated_user/unique_numbers.py (1)
145-156: Async orchestration inside thread pool: verify no running event loop.
asyncio.run inside threads works, but if an event loop is active it raises RuntimeError. Consider extracting an async process_turn_async and running it at the caller, or add a small helper that uses asyncio.get_running_loop() and falls back appropriately.
Would you like a patch to make process_turn async and wire it through ThreadPoolExecutor with asyncio.run at the top-level call site?
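One possible shape for such a helper (a minimal sketch using only the standard library; run_coroutine_blocking is a hypothetical name, not something in this PR):
```python
import asyncio
import concurrent.futures


def run_coroutine_blocking(coro):
    """Run a coroutine to completion from synchronous code.

    Uses asyncio.run() when no loop is running in this thread; otherwise hands the
    coroutine to a fresh loop on a worker thread, so asyncio.run() is never invoked
    on a thread that already owns a running loop.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: asyncio.run is safe.
        return asyncio.run(coro)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```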
186-193: Default config and registry verified: examples/configs/grpo_adk_llama8b.yaml is present and UniqueNumbersEnv is correctly registered in the environment registry.
env:
  unique_numbers:
    cfg:
      max_turns: 15
      min_length: 5
      max_length: 10
      max_integer: 15
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Document defaults for new env cfg keys in exemplar YAML.
Per repo guideline, exemplar configs must include documented defaults. Please add brief comments describing each key’s purpose and recommended default. Also confirm that defaults: "grpo_math_8B.yaml" resolves correctly at runtime.
Apply this inline comment diff in-place:
env:
unique_numbers:
cfg:
- max_turns: 15
- min_length: 5
- max_length: 10
- max_integer: 15
+ # Maximum dialogue turns per episode before forced termination (recommended default: 15)
+ max_turns: 15
+ # Minimum required unique numbers the agent must produce (recommended default: 5)
+ min_length: 5
+ # Maximum allowed unique numbers the agent may produce (recommended default: 10)
+ max_length: 10
+ # Upper bound on integer values the agent may choose (recommended default: 15)
+ max_integer: 15If grpo_math_8B.yaml lives outside the same directory, please adjust the path accordingly or add a short comment explaining search semantics.
🏁 Script executed:
#!/bin/bash
set -e
echo "Snippet around env.unique_numbers in examples/configs/grpo_adk_gemma.yaml:"
rg -n -C2 'unique_numbers' examples/configs/grpo_adk_gemma.yaml
echo
echo "Defaults declaration in examples/configs/grpo_adk_gemma.yaml:"
rg -n '^defaults' examples/configs/grpo_adk_gemma.yaml
echo
echo "Searching for grpo_math_8B.yaml in examples/configs:"
fd -t f 'grpo_math_8B.yaml' examples/configs
Length of output: 326
Document defaults for new env.unique_numbers.cfg keys.
Add inline comments describing each key’s purpose and recommended default for max_turns, min_length, max_length, and max_integer. The defaults: "grpo_math_8B.yaml" file exists alongside this config, so no path adjustments are needed.
🤖 Prompt for AI Agents
In examples/configs/grpo_adk_gemma.yaml around lines 23 to 30, the new
env.unique_numbers.cfg keys lack inline comments; add concise inline comments
after each key describing its purpose and the recommended default value
(max_turns: maximum number of interaction turns per example, default 15;
min_length: minimum token/character length of generated numbers, default 5;
max_length: maximum token/character length, default 10; max_integer: upper bound
for generated integers, default 15). Keep comments short, use the existing YAML
file’s formatting style, and do not change the defaults file path (defaults:
"grpo_math_8B.yaml").
env:
  unique_numbers:
    cfg:
      max_turns: 15
      min_length: 5
      max_length: 10
      max_integer: 15
Add inline docs for env cfg defaults.
Mirror the exemplar requirement by documenting the new keys’ purpose and recommended defaults in this YAML as well.
Apply this inline diff:
env:
unique_numbers:
cfg:
- max_turns: 15
- min_length: 5
- max_length: 10
- max_integer: 15
+ # Maximum dialogue turns per episode before forced termination (recommended default: 15)
+ max_turns: 15
+ # Minimum required unique numbers the agent must produce (recommended default: 5)
+ min_length: 5
+ # Maximum allowed unique numbers the agent may produce (recommended default: 10)
+ max_length: 10
+ # Upper bound on integer values the agent may choose (recommended default: 15)
+ max_integer: 15📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| env: | |
| unique_numbers: | |
| cfg: | |
| max_turns: 15 | |
| min_length: 5 | |
| max_length: 10 | |
| max_integer: 15 | |
| env: | |
| unique_numbers: | |
| cfg: | |
| # Maximum dialogue turns per episode before forced termination (recommended default: 15) | |
| max_turns: 15 | |
| # Minimum required unique numbers the agent must produce (recommended default: 5) | |
| min_length: 5 | |
| # Maximum allowed unique numbers the agent may produce (recommended default: 10) | |
| max_length: 10 | |
| # Upper bound on integer values the agent may choose (recommended default: 15) | |
| max_integer: 15 |
🤖 Prompt for AI Agents
In examples/configs/grpo_adk_llama8b.yaml around lines 23 to 30, the
env.unique_numbers.cfg block lacks inline documentation; add YAML comments
directly above or beside each key (max_turns, min_length, max_length,
max_integer) describing the key's purpose and recommended default values (e.g.,
what max_turns controls, acceptable ranges, and recommended default), mirroring
the style and wording used in the exemplar config so readers understand expected
behavior and defaults.
"""Run GRPO with the Unique Numbers Simulator using ADK.
This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm
in a multi-turn conversational environment powered by the ADK framework.
### Task Overview
The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user.
The interaction is structured as a turn-based dialogue:
- The user generates a list of integers.
- The agent queries specific positions in the list (by index).
- The user replies with the value at that index (if available).
- The agent continues the interaction until it makes a final guess at the number of unique integers.
### Environment Details
The environment is a simulated user that:
- Randomly generates a list of integers at setup.
- Responds to the agent's queries using an LLM via the ADK endpoint.
- Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task).
### Example Usage
uv run python examples/run_grpo_unique_numbers_w_adk.py
### Requirements
- A working ADK environment with access to a compatible LLM endpoint.
  For the default Gemini endpoint, the following environment variables must be set:
  - `GOOGLE_GENAI_USE_VERTEXAI=1`
  - `GOOGLE_CLOUD_PROJECT="your-project-id"`
  - `GOOGLE_CLOUD_LOCATION="your-location"`
- A properly configured GRPO YAML file.
  By default, the script uses:
  `examples/configs/grpo_adk_llama8b.yaml`
"""
🛠️ Refactor suggestion | 🟠 Major
Add NVIDIA copyright header.
Required for non-test Python files.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
As per coding guidelines.
🤖 Prompt for AI Agents
In examples/run_grpo_unique_numbers_w_adk.py lines 1–33, the file is missing the
required NVIDIA copyright header for non-test Python files; add the standard
header at the very top of the file (insert the NVIDIA copyright notice, year,
and SPDX license identifier per project guidelines — e.g., a Copyright (c)
<year> NVIDIA CORPORATION. All rights reserved. plus the SPDX-License-Identifier
line) before the existing module docstring so the file complies with
licensing/coding standards.
# please check the specific chat_template in the yaml file
formatted_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": starting_user_prompt}],
    tokenize=False,
    # add_system_prompt=add_system_prompt,
    add_bos_token=True,
    add_generation_prompt=True,
    add_special_tokens=False,
)
Plumb through add_system_prompt; avoid commented param.
Currently add_system_prompt is unused. Pass it to apply_chat_template.
Apply:
- # add_system_prompt=add_system_prompt,
+ add_system_prompt=add_system_prompt,
🤖 Prompt for AI Agents
In examples/run_grpo_unique_numbers_w_adk.py around lines 84 to 92, the
add_system_prompt parameter is currently commented out and not passed to
tokenizer.apply_chat_template; update the call to include
add_system_prompt=add_system_prompt (remove the comment) so the function
receives and uses the system prompt flag.
def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
    env_config = env_cfg[task_name]
    env = UniqueNumbersEnv.options(  # type: ignore # it's wrapped with ray.remote
        num_gpus=0,
        runtime_env={
            "py_executable": get_actor_python_env(
                "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
            ),
            "env_vars": dict(os.environ),  # Pass thru all user environment variables
        },
    ).remote(cfg=dict(env_config["cfg"]))
Do not pass all env vars into Ray actors; whitelist only what’s needed.
Passing dict(os.environ) risks leaking secrets. Restrict to required ADK/GenAI variables.
Apply:
- runtime_env={
- "py_executable": get_actor_python_env(
+ runtime_env={
+ "py_executable": get_actor_python_env(
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
- ),
- "env_vars": dict(os.environ), # Pass thru all user environment variables
- },
+ ),
+ "env_vars": {
+ k: os.environ[k]
+ for k in [
+ "GOOGLE_GENAI_USE_VERTEXAI",
+ "GOOGLE_CLOUD_PROJECT",
+ "GOOGLE_CLOUD_LOCATION",
+ "GOOGLE_API_KEY", # if using direct GenAI
+ "GOOGLE_APPLICATION_CREDENTIALS", # if using Vertex AI
+ ]
+ if k in os.environ
+ },
+ },As per coding guidelines.
🤖 Prompt for AI Agents
In examples/run_grpo_unique_numbers_w_adk.py around lines 152 to 163, the Ray
actor is currently given runtime_env={"env_vars": dict(os.environ)} which leaks
all environment variables; replace this by constructing a small whitelist of
required ADK/GenAI variables (e.g., ADK_API_KEY, ADK_ENDPOINT, GENAI_MODEL —
whatever this app actually needs) and build env_vars = {k: os.environ[k] for k
in WHITELIST if k in os.environ}; pass that env_vars dict to runtime_env instead
of dict(os.environ). Ensure the whitelist is defined near the function (or
imported from config) and only the minimal keys are forwarded.
query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE)
guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)
Regex misses plural ‘numbers’; guesses won’t register.
Pattern only matches “unique number”. It should accept both singular and plural.
Apply:
- guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)
+ guess_re = re.compile(r"there are (\d+)\s+unique numbers?\b", re.IGNORECASE)
🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 56 to 58,
the guess_re regex only matches the singular phrase "unique number" so plural
guesses like "unique numbers" won't register; update the pattern to accept both
singular and plural (e.g., make "number" optional "s" or use "numbers?" with
appropriate word boundaries) and keep the re.IGNORECASE flag so both "number"
and "numbers" are matched.
grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
grading_response = asyncio.run(
    run_prompt_async(
        metadata["grader_runner"],
        "grader",
        grading_prompt,
        silence=True,
    )
)
try:
    grade = int(re.search(r"(\d+)", grading_response).group(1))
    reward = (reward + grade) / 2.0
except Exception as e:
    print(
        f"Failed to parse grade from grader response '{grading_response}': {e}"
    )
Grader integration: fix typo and parse numeric [0,1] robustly.
- “converstation” → “conversation”.
- Parse floats, clamp to [0,1]; don’t cast to int (drops partial credit).
Apply:
- grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
+ grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1."
grading_response = asyncio.run(
run_prompt_async(
metadata["grader_runner"],
"grader",
grading_prompt,
silence=True,
)
)
try:
- grade = int(re.search(r"(\d+)", grading_response).group(1))
- reward = (reward + grade) / 2.0
+ m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response)
+ if m:
+ grade = float(m.group(1))
+ grade = max(0.0, min(1.0, grade))
+ reward = 0.5 * (reward + grade)
+ else:
+ raise ValueError("No numeric score found")
except Exception as e:
print(
f"Failed to parse grade from grader response '{grading_response}': {e}"
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1." | |
| grading_response = asyncio.run( | |
| run_prompt_async( | |
| metadata["grader_runner"], | |
| "grader", | |
| grading_prompt, | |
| silence=True, | |
| ) | |
| ) | |
| try: | |
| grade = int(re.search(r"(\d+)", grading_response).group(1)) | |
| reward = (reward + grade) / 2.0 | |
| except Exception as e: | |
| print( | |
| f"Failed to parse grade from grader response '{grading_response}': {e}" | |
| ) | |
| grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1." | |
| grading_response = asyncio.run( | |
| run_prompt_async( | |
| metadata["grader_runner"], | |
| "grader", | |
| grading_prompt, | |
| silence=True, | |
| ) | |
| ) | |
| try: | |
| m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response) | |
| if m: | |
| grade = float(m.group(1)) | |
| grade = max(0.0, min(1.0, grade)) | |
| reward = 0.5 * (reward + grade) | |
| else: | |
| raise ValueError("No numeric score found") | |
| except Exception as e: | |
| print( | |
| f"Failed to parse grade from grader response '{grading_response}': {e}" | |
| ) |
🧰 Tools
🪛 Ruff (0.14.1)
197-197: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 185 to
201, fix the typo in the grading prompt ("converstation" → "conversation") and
change the grade parsing to robustly extract a float (allow integers or
decimals, e.g., via regex for optional decimal), do not cast to int, clamp the
parsed value to the [0.0, 1.0] range, compute reward = (reward + parsed_grade) /
2.0, and keep the existing exception handling but include the raw
grading_response in the log; ensure any ValueError or None from regex is handled
gracefully and does not crash the function.
@ray.remote
class UniqueNumbersEnv(EnvironmentInterface):
    """Environment where the LLM must deduce the count of unique numbers."""

    def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
        cfg = cfg or UniqueNumbersConfig()
        self.min_length = cfg.get("min_length", 3)
        self.max_length = cfg.get("max_length", 7)
        self.default_max_turns = cfg.get("max_turns", 10)
🛠️ Refactor suggestion | 🟠 Major
@ray.remote requires pragma and avoid hidden defaults.
- Add “# pragma: no cover” on the class line.
- Do not set non-None config defaults in code; YAML is the single source of truth. Either require cfg or read keys directly without defaults. The current self.min_length/max_length/default_max_turns are unused; remove or read from cfg directly.
Apply:
-@ray.remote
-class UniqueNumbersEnv(EnvironmentInterface):
+@ray.remote
+class UniqueNumbersEnv(EnvironmentInterface): # pragma: no cover
@@
- def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
- cfg = cfg or UniqueNumbersConfig()
- self.min_length = cfg.get("min_length", 3)
- self.max_length = cfg.get("max_length", 7)
- self.default_max_turns = cfg.get("max_turns", 10)
+ def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
+ if cfg is None:
+ raise ValueError("cfg is required; defaults must be provided via YAML.")
+ # If needed later, access required keys directly:
+ # self.min_length = cfg["min_length"]; self.max_length = cfg["max_length"]; self.default_max_turns = cfg["max_turns"]
As per coding guidelines.
🤖 Prompt for AI Agents
nemo_rl/environments/simulated_user/unique_numbers.py around lines 228-236: the
@ray.remote class decorator needs a coverage pragma and the constructor must not
embed hidden defaults from YAML; add "# pragma: no cover" to the @ray.remote
line, and stop assigning non-None fallback values inside __init__. Either make
cfg required (remove the Optional and raise a clear error if None) or read
config keys directly without providing hardcoded defaults (e.g., assign
self.min_length = cfg["min_length"] etc.), and if those attributes are actually
unused remove them altogether; ensure no hidden defaults remain in code and
update the type signature and any callers accordingly.
if (
    len(formatted_obs) > 0
    and hasattr(tokenizer, "bos_token_id")
    and formatted_obs[0] == tokenizer.bos_token_id
):
    formatted_obs = formatted_obs[1:]
else:
Bug: BOS stripping compares string to token id; strip tokens instead.
formatted_obs[0] == tokenizer.bos_token_id mixes types. Check and slice tokenized_obs by id.
Apply:
- # remove the bos token if added after `apply_chat_template`
- if (
- len(formatted_obs) > 0
- and hasattr(tokenizer, "bos_token_id")
- and formatted_obs[0] == tokenizer.bos_token_id
- ):
- formatted_obs = formatted_obs[1:]
+ # Remove BOS if present as the first token
+ if hasattr(tokenizer, "bos_token_id") and tokenized_obs.numel() > 0:
+ if int(tokenized_obs[0].item()) == getattr(tokenizer, "bos_token_id", None):
+ tokenized_obs = tokenized_obs[1:]Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In nemo_rl/experience/rollouts.py around lines 459-465, the code compares
formatted_obs (a string) to tokenizer.bos_token_id (an int) and slices the wrong
variable; instead check the first element of the tokenized observation (e.g.,
tokenized_obs or tokens list) against tokenizer.bos_token_id and remove that id
from the token list, then regenerate formatted_obs from the trimmed token list
(e.g., via tokenizer.decode or the existing decode path). Ensure you guard for
length>0 and hasattr(tokenizer, "bos_token_id") before comparing and slicing,
and keep formatted_obs consistent with the updated tokenized sequence.
env_role = env_output.observations[0]["role"].lower()
if env_role in {"user", "assistant", "system"}:
    formatted_obs = tokenizer.apply_chat_template(
        [{"role": env_role, "content": env_obs_content.strip()}],
        tokenize=False,
        add_generation_prompt=True,
    )
    tokenized_obs = tokenizer(
        formatted_obs, return_tensors="pt", add_special_tokens=False
    ).input_ids[0]
    # remove the bos token if added after `apply_chat_template`
    if (
        len(formatted_obs) > 0
        and hasattr(tokenizer, "bos_token_id")
        and formatted_obs[0] == tokenizer.bos_token_id
    ):
        formatted_obs = formatted_obs[1:]
else:
Same BOS stripping bug in single-sample path.
Mirror the token-based check here as well.
Apply:
- # remove the bos token if added after `apply_chat_template`
- if (
- len(formatted_obs) > 0
- and hasattr(tokenizer, "bos_token_id")
- and formatted_obs[0] == tokenizer.bos_token_id
- ):
- formatted_obs = formatted_obs[1:]
+ # Remove BOS if present as the first token
+ if hasattr(tokenizer, "bos_token_id") and tokenized_obs.numel() > 0:
+ if int(tokenized_obs[0].item()) == getattr(tokenizer, "bos_token_id", None):
+ tokenized_obs = tokenized_obs[1:]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| env_role = env_output.observations[0]["role"].lower() | |
| if env_role in {"user", "assistant", "system"}: | |
| formatted_obs = tokenizer.apply_chat_template( | |
| [{"role": env_role, "content": env_obs_content.strip()}], | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| tokenized_obs = tokenizer( | |
| formatted_obs, return_tensors="pt", add_special_tokens=False | |
| ).input_ids[0] | |
| # remove the bos token if added after `apply_chat_template` | |
| if ( | |
| len(formatted_obs) > 0 | |
| and hasattr(tokenizer, "bos_token_id") | |
| and formatted_obs[0] == tokenizer.bos_token_id | |
| ): | |
| formatted_obs = formatted_obs[1:] | |
| else: | |
| env_role = env_output.observations[0]["role"].lower() | |
| if env_role in {"user", "assistant", "system"}: | |
| formatted_obs = tokenizer.apply_chat_template( | |
| [{"role": env_role, "content": env_obs_content.strip()}], | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| tokenized_obs = tokenizer( | |
| formatted_obs, return_tensors="pt", add_special_tokens=False | |
| ).input_ids[0] | |
| # Remove BOS if present as the first token | |
| if hasattr(tokenizer, "bos_token_id") and tokenized_obs.numel() > 0: | |
| if int(tokenized_obs[0].item()) == getattr(tokenizer, "bos_token_id", None): | |
| tokenized_obs = tokenized_obs[1:] | |
| else: |
🤖 Prompt for AI Agents
In nemo_rl/experience/rollouts.py around lines 735 to 752, the single-sample
path currently compares formatted_obs[0] (a string) to tokenizer.bos_token_id
which is incorrect; instead mirror the multi-sample token-based check: inspect
tokenized_obs[0] against tokenizer.bos_token_id and if equal remove the first
token from tokenized_obs (e.g., tokenized_obs = tokenized_obs[1:]); also ensure
any downstream use that expects the string version is updated to use the
token-trimmed representation or rebuild the string/tensor consistently after
stripping the BOS.
Replacing PR #732
Waiting for runs to confirm the convergence graph attached in the original PR
What does this PR do?
Add a simple example of multi-turn GRPO using ADK.
Issues
List issues that this PR closes (syntax):
Usage
Training reward:
Validation acc:
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Dependencies