Conversation
Summary of Changes (Gemini Code Assist)

This pull request enables the Group Relative Policy Optimization (GRPO) framework to function correctly with Olmo-Core. It introduces a dedicated local debug script to streamline development and addresses several critical issues related to Ray initialization, data tensor dimensionality, and model configuration within the training modules. These changes ensure a more robust and debuggable GRPO training pipeline, particularly when working with varying batch sizes and model architectures.
Code Review
This pull request introduces several changes to improve the GRPO implementation with Olmo-Core, including a debug script, Ray initialization enhancements, and various fixes in data loading and log probability calculations. The changes aim to simplify debugging and ensure correct execution. The review focuses on correctness and potential improvements in code clarity and efficiency.
dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long)
dummy_response_mask = torch.zeros_like(dummy_qr)
dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float)
dummy_response_mask = torch.tensor([0, 1], dtype=torch.long)
The dummy_response_mask is hardcoded. It would be better to generate it dynamically from the dummy_qr tensor to avoid potential mismatches in length or content; the hardcoded value could lead to unexpected behavior if dummy_qr changes.
Suggested change:
- dummy_response_mask = torch.tensor([0, 1], dtype=torch.long)
+ dummy_response_mask = torch.zeros_like(dummy_qr, dtype=torch.long)
+ dummy_response_mask[-1] = 1  # Ensure the last token is part of the response
if len(set(shapes)) != 1:
    for i in batch_indices:
        single_query_responses = data_BT.query_responses[i]
        single_attention_mask = data_BT.attention_masks[i]
        single_position_ids = data_BT.position_ids[i]
        if single_query_responses.ndim == 1:
            single_query_responses = single_query_responses.unsqueeze(0)
            single_attention_mask = single_attention_mask.unsqueeze(0)
            single_position_ids = single_position_ids.unsqueeze(0)

        single_logprobs, _ = forward_for_logprobs(
            model,
            single_query_responses,
            single_attention_mask,
            single_position_ids,
            pad_token_id,
            temperature,
            False,
        )

        response_mask_BT = data_BT.response_masks[i]
        if response_mask_BT.ndim == 1:
            response_mask_BT = response_mask_BT.unsqueeze(0)
        response_mask_BT = response_mask_BT.to(single_logprobs.device)
        single_logprobs = torch.masked_fill(
            single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
        )
        logprobs_BT.append(single_logprobs)
        torch.cuda.empty_cache()
    continue
The code handles cases where the shapes of query_responses are not uniform within a batch. However, the torch.cuda.empty_cache() call inside the loop might be inefficient. It would be better to move this call outside the loop to reduce overhead, or even better, rely on PyTorch's memory management to handle the caching.
Suggested change:
if len(set(shapes)) != 1:
    for i in batch_indices:
        single_query_responses = data_BT.query_responses[i]
        single_attention_mask = data_BT.attention_masks[i]
        single_position_ids = data_BT.position_ids[i]
        if single_query_responses.ndim == 1:
            single_query_responses = single_query_responses.unsqueeze(0)
            single_attention_mask = single_attention_mask.unsqueeze(0)
            single_position_ids = single_position_ids.unsqueeze(0)
        single_logprobs, _ = forward_for_logprobs(
            model,
            single_query_responses,
            single_attention_mask,
            single_position_ids,
            pad_token_id,
            temperature,
            False,
        )
        response_mask_BT = data_BT.response_masks[i]
        if response_mask_BT.ndim == 1:
            response_mask_BT = response_mask_BT.unsqueeze(0)
        response_mask_BT = response_mask_BT.to(single_logprobs.device)
        single_logprobs = torch.masked_fill(
            single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
        )
        logprobs_BT.append(single_logprobs)
        # torch.cuda.empty_cache()  # Move outside the loop
    continue

sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)

for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i].to(logprob_BT.device)
    response_mask_BT = data_BT.response_masks[i]
    if response_mask_BT.ndim == 1:
        response_mask_BT = response_mask_BT.unsqueeze(0)
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
    logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
    logprobs_BT.append(logprob_BT)
The code adds a dimension to the mask if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in compute_logprobs). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability.
Current:
sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)
for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i].to(logprob_BT.device)
    response_mask_BT = data_BT.response_masks[i]
    if response_mask_BT.ndim == 1:
        response_mask_BT = response_mask_BT.unsqueeze(0)
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
    logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
    logprobs_BT.append(logprob_BT)

Suggested:
sample_sizes = [1 if data_BT.query_responses[i].ndim == 1 else data_BT.query_responses[i].shape[0] for i in batch_indices]
split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)
for i, logprob_BT in zip(batch_indices, split_logprobs):
    response_mask_BT = data_BT.response_masks[i]
    response_mask_BT = response_mask_BT.unsqueeze(0) if response_mask_BT.ndim == 1 else response_mask_BT  # Added line
    response_mask_BT = response_mask_BT.to(logprob_BT.device)
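If the duplicated unsqueeze logic were factored out as suggested, one possible shape for the helper is sketched below. This is only an illustration: the name ensure_2d, its signature, and the optional device argument are assumptions and not part of this PR; it builds on the torch.atleast_2d approach mentioned in the commit history.

import torch
from typing import Optional

def ensure_2d(t: torch.Tensor, device: Optional[torch.device] = None) -> torch.Tensor:
    """Promote a 1D (T,) tensor to shape (1, T); leave 2D tensors unchanged."""
    t = torch.atleast_2d(t)  # no-op for tensors that are already 2D
    return t if device is None else t.to(device)

# Hypothetical call site, mirroring the loop above:
# response_mask_BT = ensure_2d(data_BT.response_masks[i], logprob_BT.device)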
(
    data_BT.response_masks[i].unsqueeze(0)
    if data_BT.response_masks[i].ndim == 1
    else data_BT.response_masks[i]
)[:, 1:].sum().float()
for i in range(num_samples)
The code adds a dimension to data_BT.response_masks[i] if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in grpo_utils.py). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability.
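With a helper like the hypothetical ensure_2d sketched earlier, the per-sample count expression above could shrink to something like the following; the name valid_token_counts is invented for illustration and the helper itself is an assumption, not code from this PR.

valid_token_counts = [
    ensure_2d(data_BT.response_masks[i])[:, 1:].sum().float()  # ensure_2d is the assumed helper above
    for i in range(num_samples)
]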
advantages = data_BT.advantages[sample_idx]
if advantages.ndim == 1:
    advantages = advantages.unsqueeze(0)
The code adds a dimension to advantages if it's 1D. This is good for handling different input shapes. However, it's repeated in multiple places (here and in grpo_utils.py). Consider creating a utility function to encapsulate this logic to avoid code duplication and improve maintainability.
Suggested change:
advantages = data_BT.advantages[sample_idx]
if advantages.ndim == 1:
    advantages = advantages.unsqueeze(0)  # Added line
advantages = advantages.to(new_logprobs.device)
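The same hypothetical ensure_2d helper would also cover this site, for example:

advantages = ensure_2d(data_BT.advantages[sample_idx], new_logprobs.device)  # sketch; relies on the assumed helper above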
# source configs/beaker_configs/ray_node_setup.sh && \
python open_instruct/grpo.py \
    --dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
    --dataset_mixer_list_splits train \
    --dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
    --dataset_mixer_eval_list_splits train \
    --max_prompt_token_length 512 \
    --response_length 512 \
    --pack_length 1024 \
    --per_device_train_batch_size 1 \
    --num_unique_prompts_rollout 8 \
    --num_samples_per_prompt_rollout 4 \
    --model_name_or_path Qwen/Qwen3-0.6B \
    --apply_verifiable_reward true \
    --learning_rate 1e-6 \
    --total_episodes 200 \
    --deepspeed_stage 2 \
    --num_epochs 1 \
    --num_learners_per_node 1 \
    --vllm_tensor_parallel_size 1 \
    --local_eval_every 1 \
    --vllm_sync_backend gloo \
    --vllm_gpu_memory_utilization 0.3 \
    --vllm_enforce_eager \
    --gradient_checkpointing \
    --single_gpu_mode \
    --push_to_hub false $@
    # --system_prompt_override_file
The script uses hardcoded paths and parameters. Consider using environment variables or command-line arguments for greater flexibility and portability; users could then change the script's behavior without editing the file directly.
python open_instruct/grpo.py \
    --dataset_mixer_list "${DATASET_MIXER_LIST:-ai2-adapt-dev/rlvr_gsm8k_zs 64}" \
    --dataset_mixer_list_splits "${DATASET_MIXER_LIST_SPLITS:-train}" \
    --dataset_mixer_eval_list "${DATASET_MIXER_EVAL_LIST:-ai2-adapt-dev/rlvr_gsm8k_zs 16}" \
    --dataset_mixer_eval_list_splits "${DATASET_MIXER_EVAL_LIST_SPLITS:-train}" \
    --max_prompt_token_length "${MAX_PROMPT_TOKEN_LENGTH:-512}" \
    --response_length "${RESPONSE_LENGTH:-512}" \
    --pack_length "${PACK_LENGTH:-1024}" \
    --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-1}" \
    --num_unique_prompts_rollout "${NUM_UNIQUE_PROMPTS_ROLLOUT:-8}" \
    --num_samples_per_prompt_rollout "${NUM_SAMPLES_PER_PROMPT_ROLLOUT:-4}" \
    --model_name_or_path "${MODEL_NAME_OR_PATH:-Qwen/Qwen3-0.6B}" \
    --apply_verifiable_reward "${APPLY_VERIFIABLE_REWARD:-true}" \
    --learning_rate "${LEARNING_RATE:-1e-6}" \
    --total_episodes "${TOTAL_EPISODES:-200}" \
    --deepspeed_stage "${DEEPSPEED_STAGE:-2}" \
    --num_epochs "${NUM_EPOCHS:-1}" \
    --num_learners_per_node "${NUM_LEARNERS_PER_NODE:-1}" \
    --vllm_tensor_parallel_size "${VLLM_TENSOR_PARALLEL_SIZE:-1}" \
    --local_eval_every "${LOCAL_EVAL_EVERY:-1}" \
    --vllm_sync_backend "${VLLM_SYNC_BACKEND:-gloo}" \
    --vllm_gpu_memory_utilization "${VLLM_GPU_MEMORY_UTILIZATION:-0.3}" \
    --vllm_enforce_eager "${VLLM_ENFORCE_EAGER:-true}" \
    --gradient_checkpointing \
    --single_gpu_mode \
    --push_to_hub false "$@"
| """ | ||
| accumulation_counts: dict[int, float] = {} | ||
| local_counts = [mask[:, 1:].sum().float() for mask in data_BT.response_masks] | ||
| local_counts = [] |
Can we keep this as a comprehension?
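One way to keep this as a comprehension while still tolerating 1D masks is sketched below; it uses torch.atleast_2d, the same approach mentioned in the collate_fn commit, and is only an assumption about what the author intends here.

local_counts = [
    torch.atleast_2d(mask)[:, 1:].sum().float()  # promotes a 1D (T,) mask to (1, T); no-op for 2D
    for mask in data_BT.response_masks
]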
for _epoch_idx in range(self.grpo_config.num_epochs):
    for sample_idx in range(num_samples):
        query_responses = data_BT.query_responses[sample_idx]
Let's keep these inlined!
* grpo debug working
* Fix collate_fn to always return 2D tensors, remove ndim guards (#1545)
  * Fix collate_fn to always return 2D tensors, removing ~12 downstream ndim guards
    Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
  * Use torch.atleast_2d instead of manual ndim check in collate_fn
    Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* initial script for gsm8k 3 gpu
* inline and 0.6B base debug scripts
* inline and make grpo and grpo_fast debugs the same
* deleted extra script and added changelog

Co-authored-by: Finbarr Timbers <finbarrtimbers@gmail.com>
Added a local debug script grpo.sh that avoids using ray_node_setup.sh for simplicity, and fixed up all the issues so it runs.

Still some confusing issues to figure out.