
GRPO with Olmo-Core working #1543

Merged
mnoukhov merged 7 commits into main from grpo-workin on Mar 23, 2026

Conversation

@mnoukhov
Contributor

Added a local debug script, grpo.sh, that avoids using ray_node_setup.sh for simplicity, and fixed the issues so it runs.

There are still some confusing issues to figure out.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables Group Relative Policy Optimization (GRPO) training to function correctly with Olmo-Core. It introduces a dedicated local debug script to streamline development and addresses several critical issues related to Ray initialization, data tensor dimensionality, and model configuration within the training modules. These changes ensure a more robust and debuggable GRPO training pipeline, particularly when working with varying batch sizes and model architectures.

Highlights

  • Local Debug Script for GRPO: A new grpo.sh script was added to simplify local debugging of GRPO with Olmo-Core, bypassing the ray_node_setup.sh script and configuring necessary environment variables and arguments.
  • Ray Initialization Flexibility: The ray.init call in grpo.py was updated to dynamically determine the Ray address, allowing for more flexible deployment scenarios (see the sketch after this list).
  • Data Handling and Tensor Shape Consistency: Multiple fixes were implemented in grpo_utils.py and olmo_core_train_modules.py to correctly handle tensor shapes, particularly for single-sample batches (ndim=1), ensuring proper logprob computation and mask application.
  • Refactored GRPO Train Module: The GRPOTrainModule in olmo_core_train_modules.py was refactored to use sample_microbatch_size instead of rank_microbatch_size and includes a pre_train method to bypass default batch validation, aligning with GRPO's specific batching logic.
  • Robust Model Saving: A fallback mechanism was added to olmo_core_utils.py for saving Hugging Face models, allowing the raw state dictionary to be saved if the standard HF export is not supported for a given OLMo-core model.
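
As a rough illustration of the Ray address flexibility described above — a minimal sketch only, assuming the address comes from a RAY_ADDRESS environment variable; the PR may derive it differently:

    import os
    import ray

    # Connect to an existing cluster when RAY_ADDRESS is set (e.g. by ray_node_setup.sh);
    # when it is unset, ray.init(address=None) starts a local Ray instance, which is what
    # the local debug script relies on.
    ray.init(address=os.environ.get("RAY_ADDRESS"))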


Contributor

@gemini-code-assist Bot left a comment


Code Review

This pull request introduces several changes to improve the GRPO implementation with Olmo-Core, including a debug script, Ray initialization enhancements, and various fixes in data loading and log probability calculations. The changes aim to simplify debugging and ensure correct execution. The review focuses on correctness and potential improvements in code clarity and efficiency.

Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/olmo_core_train_modules.py
Comment thread open_instruct/data_loader.py Outdated
dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long)
dummy_response_mask = torch.zeros_like(dummy_qr)
dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float)
dummy_response_mask = torch.tensor([0, 1], dtype=torch.long)
Contributor

medium

The dummy_response_mask is hardcoded. It would be better to dynamically generate this based on the dummy_qr tensor to avoid potential mismatches in length or content. This could lead to unexpected behavior if the dummy_qr tensor changes.

Suggested change
dummy_response_mask = torch.tensor([0, 1], dtype=torch.long)
dummy_response_mask = torch.zeros_like(dummy_qr, dtype=torch.long)
dummy_response_mask[-1] = 1 # Ensure the last token is part of the response

Comment thread open_instruct/grpo.py
Comment on lines +343 to +372
            if len(set(shapes)) != 1:
                for i in batch_indices:
                    single_query_responses = data_BT.query_responses[i]
                    single_attention_mask = data_BT.attention_masks[i]
                    single_position_ids = data_BT.position_ids[i]
                    if single_query_responses.ndim == 1:
                        single_query_responses = single_query_responses.unsqueeze(0)
                        single_attention_mask = single_attention_mask.unsqueeze(0)
                        single_position_ids = single_position_ids.unsqueeze(0)

                    single_logprobs, _ = forward_for_logprobs(
                        model,
                        single_query_responses,
                        single_attention_mask,
                        single_position_ids,
                        pad_token_id,
                        temperature,
                        False,
                    )

                    response_mask_BT = data_BT.response_masks[i]
                    if response_mask_BT.ndim == 1:
                        response_mask_BT = response_mask_BT.unsqueeze(0)
                    response_mask_BT = response_mask_BT.to(single_logprobs.device)
                    single_logprobs = torch.masked_fill(
                        single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
                    )
                    logprobs_BT.append(single_logprobs)
                    torch.cuda.empty_cache()
                continue
Contributor

medium

The code handles cases where the shapes of query_responses are not uniform within a batch. However, the torch.cuda.empty_cache() call inside the loop might be inefficient. It would be better to move this call outside the loop to reduce overhead, or even better, rely on PyTorch's memory management to handle the caching.

            if len(set(shapes)) != 1:
                for i in batch_indices:
                    single_query_responses = data_BT.query_responses[i]
                    single_attention_mask = data_BT.attention_masks[i]
                    single_position_ids = data_BT.position_ids[i]
                    if single_query_responses.ndim == 1:
                        single_query_responses = single_query_responses.unsqueeze(0)
                        single_attention_mask = single_attention_mask.unsqueeze(0)
                        single_position_ids = single_position_ids.unsqueeze(0)

                    single_logprobs, _ = forward_for_logprobs(
                        model,
                        single_query_responses,
                        single_attention_mask,
                        single_position_ids,
                        pad_token_id,
                        temperature,
                        False,
                    )

                    response_mask_BT = data_BT.response_masks[i]
                    if response_mask_BT.ndim == 1:
                        response_mask_BT = response_mask_BT.unsqueeze(0)
                    response_mask_BT = response_mask_BT.to(single_logprobs.device)
                    single_logprobs = torch.masked_fill(
                        single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
                    )
                    logprobs_BT.append(single_logprobs)
                # torch.cuda.empty_cache() # Move outside the loop
                continue

Comment thread open_instruct/grpo_utils.py Outdated
Comment on lines 393 to 402
sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)

for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i].to(logprob_BT.device)
    response_mask_BT = data_BT.response_masks[i]
    if response_mask_BT.ndim == 1:
        response_mask_BT = response_mask_BT.unsqueeze(0)
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
    logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
    logprobs_BT.append(logprob_BT)
Contributor

medium

The code adds a dimension to the mask if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in compute_logprobs). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability (a sketch of such a helper follows the suggestion below).

Suggested change
sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)
for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i].to(logprob_BT.device)
    response_mask_BT = data_BT.response_masks[i]
    if response_mask_BT.ndim == 1:
        response_mask_BT = response_mask_BT.unsqueeze(0)
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
    logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
    logprobs_BT.append(logprob_BT)
sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)
for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i]
    response_mask_BT = response_mask_BT.unsqueeze(0) if response_mask_BT.ndim == 1 else response_mask_BT  # Added line
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
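
For illustration, the repeated guard could live in one small helper — a hypothetical sketch, not code from this PR; the follow-up commit listed below instead makes collate_fn always return 2D tensors via torch.atleast_2d, which achieves the same effect:

    import torch

    def ensure_2d(t: torch.Tensor) -> torch.Tensor:
        """Add a leading batch dimension to 1D tensors; leave 2D tensors unchanged."""
        return t.unsqueeze(0) if t.ndim == 1 else t

    # Hypothetical usage at the call sites flagged in this review:
    # response_mask_BT = ensure_2d(data_BT.response_masks[i]).to(logprob_BT.device)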

Comment on lines +394 to +399
(
    data_BT.response_masks[i].unsqueeze(0)
    if data_BT.response_masks[i].ndim == 1
    else data_BT.response_masks[i]
)[:, 1:].sum().float()
for i in range(num_samples)
Contributor

medium

The code adds a dimension to data_BT.response_masks[i] if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in grpo_utils.py). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability.

Suggested change
(
    data_BT.response_masks[i].unsqueeze(0)
    if data_BT.response_masks[i].ndim == 1
    else data_BT.response_masks[i]
)[:, 1:].sum().float()
for i in range(num_samples)
(
    data_BT.response_masks[i].unsqueeze(0)
    if data_BT.response_masks[i].ndim == 1
    else data_BT.response_masks[i]
)[:, 1:].sum().float()

Comment on lines +435 to +437
advantages = data_BT.advantages[sample_idx]
if advantages.ndim == 1:
    advantages = advantages.unsqueeze(0)
Contributor

medium

The code adds a dimension to advantages if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in grpo_utils.py). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability.

Suggested change
advantages = data_BT.advantages[sample_idx]
if advantages.ndim == 1:
    advantages = advantages.unsqueeze(0)
advantages = data_BT.advantages[sample_idx]
if advantages.ndim == 1:
    advantages = advantages.unsqueeze(0)  # Added line
advantages = advantages.to(new_logprobs.device)

Comment thread open_instruct/olmo_core_utils.py
Comment thread scripts/train/debug/grpo.sh Outdated
Comment on lines +9 to +36
# source configs/beaker_configs/ray_node_setup.sh && \
python open_instruct/grpo.py \
--dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
--dataset_mixer_eval_list_splits train \
--max_prompt_token_length 512 \
--response_length 512 \
--pack_length 1024 \
--per_device_train_batch_size 1 \
--num_unique_prompts_rollout 8 \
--num_samples_per_prompt_rollout 4 \
--model_name_or_path Qwen/Qwen3-0.6B \
--apply_verifiable_reward true \
--learning_rate 1e-6 \
--total_episodes 200 \
--deepspeed_stage 2 \
--num_epochs 1 \
--num_learners_per_node 1 \
--vllm_tensor_parallel_size 1 \
--local_eval_every 1 \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.3 \
--vllm_enforce_eager \
--gradient_checkpointing \
--single_gpu_mode \
--push_to_hub false $@
# --system_prompt_override_file
Contributor

medium

The script uses hardcoded paths and parameters. Consider using environment variables or command-line arguments for greater flexibility and portability. This would allow users to easily modify the script without directly editing the file.

python open_instruct/grpo.py \
    --dataset_mixer_list "${DATASET_MIXER_LIST:-ai2-adapt-dev/rlvr_gsm8k_zs 64}" \
    --dataset_mixer_list_splits "${DATASET_MIXER_LIST_SPLITS:-train}" \
    --dataset_mixer_eval_list "${DATASET_MIXER_EVAL_LIST:-ai2-adapt-dev/rlvr_gsm8k_zs 16}" \
    --dataset_mixer_eval_list_splits "${DATASET_MIXER_EVAL_LIST_SPLITS:-train}" \
    --max_prompt_token_length "${MAX_PROMPT_TOKEN_LENGTH:-512}" \
    --response_length "${RESPONSE_LENGTH:-512}" \
    --pack_length "${PACK_LENGTH:-1024}" \
    --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-1}" \
    --num_unique_prompts_rollout "${NUM_UNIQUE_PROMPTS_ROLLOUT:-8}" \
    --num_samples_per_prompt_rollout "${NUM_SAMPLES_PER_PROMPT_ROLLOUT:-4}" \
    --model_name_or_path "${MODEL_NAME_OR_PATH:-Qwen/Qwen3-0.6B}" \
    --apply_verifiable_reward "${APPLY_VERIFIABLE_REWARD:-true}" \
    --learning_rate "${LEARNING_RATE:-1e-6}" \
    --total_episodes "${TOTAL_EPISODES:-200}" \
    --deepspeed_stage "${DEEPSPEED_STAGE:-2}" \
    --num_epochs "${NUM_EPOCHS:-1}" \
    --num_learners_per_node "${NUM_LEARNERS_PER_NODE:-1}" \
    --vllm_tensor_parallel_size "${VLLM_TENSOR_PARALLEL_SIZE:-1}" \
    --local_eval_every "${LOCAL_EVAL_EVERY:-1}" \
    --vllm_sync_backend "${VLLM_SYNC_BACKEND:-gloo}" \
    --vllm_gpu_memory_utilization "${VLLM_GPU_MEMORY_UTILIZATION:-0.3}" \
    --vllm_enforce_eager "${VLLM_ENFORCE_EAGER:-true}" \
    --gradient_checkpointing \
    --single_gpu_mode \
    --push_to_hub false "$@"
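
If the script were parameterized this way, a shorter run could be launched with, for example, TOTAL_EPISODES=50 bash scripts/train/debug/grpo.sh, and any extra CLI flags would still be forwarded through "$@" (illustrative invocation, assuming the environment-variable version above is adopted).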

* Fix collate_fn to always return 2D tensors, removing ~12 downstream ndim guards Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Use torch.atleast_2d instead of manual ndim check in collate_fn Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread open_instruct/grpo_utils.py Outdated
"""
accumulation_counts: dict[int, float] = {}
local_counts = [mask[:, 1:].sum().float() for mask in data_BT.response_masks]
local_counts = []
Collaborator

Can we keep this as a comprehension?

Contributor Author

done!


for _epoch_idx in range(self.grpo_config.num_epochs):
    for sample_idx in range(num_samples):
        query_responses = data_BT.query_responses[sample_idx]
Collaborator

Let's keep these inlined!

@mnoukhov mnoukhov marked this pull request as ready for review March 23, 2026 02:46
@mnoukhov mnoukhov enabled auto-merge March 23, 2026 03:04
@mnoukhov mnoukhov added this pull request to the merge queue Mar 23, 2026
Merged via the queue into main with commit e5c9d22 Mar 23, 2026
6 of 7 checks passed
@mnoukhov mnoukhov deleted the grpo-workin branch March 23, 2026 03:19
hamishivi pushed a commit that referenced this pull request Mar 30, 2026
* grpo debug working

* Fix collate_fn to always return 2D tensors, remove ndim guards (#1545)

* Fix collate_fn to always return 2D tensors, removing ~12 downstream ndim guards Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Use torch.atleast_2d instead of manual ndim check in collate_fn Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* initial script for gsm8k 3 gpu

* inline and 0.6B base debug scripts

* inline and make grpo and grpo_fast debugs the same

* deleted extra script and added changelog

---------

Co-authored-by: Finbarr Timbers <finbarrtimbers@gmail.com>