GRPO with Olmo-Core working #1543
Changes from all commits: e501678, 1dff29d, 32c3399, 464c6c7, 9b2506b, 555eaf6, a1cd28c
The first changed file's hunk in `forward_for_logprobs` moves the shifted labels onto the logits' device:

```diff
@@ -300,7 +300,7 @@ def forward_for_logprobs(
     logits = logits / temperature
     # The logits at position i predict token i+1, so we align them with labels shifted by 1
     logits = logits[:, :-1]
-    labels = query_responses[:, 1:].clone()
+    labels = query_responses[:, 1:].clone().to(logits.device)
     # Replace pad tokens with 0 to avoid index out of bounds errors in gather
     labels[labels == pad_token_id] = 0
     logprob_BT = model_utils.log_softmax_and_gather(logits, labels)
```

finbarrtimbers marked this conversation as resolved.
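For context, the pattern this hunk touches can be sketched standalone; a minimal sketch, assuming `model_utils.log_softmax_and_gather` is equivalent to a `log_softmax` followed by a `gather` (the diff does not show its definition):

```python
import torch

def log_softmax_and_gather(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [B, T, V], labels: [B, T] -> per-token logprobs [B, T]
    logprobs = torch.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

B, T, V, pad_token_id = 2, 5, 11, 0
query_responses = torch.randint(1, V, (B, T))  # stand-in token ids
logits = torch.randn(B, T, V)                  # stand-in model output

# Logits at position i predict token i+1: drop the last logit and the first label.
logits = logits[:, :-1]
labels = query_responses[:, 1:].clone().to(logits.device)
labels[labels == pad_token_id] = 0  # keep gather indices in range
logprob_BT = log_softmax_and_gather(logits, labels)
print(logprob_BT.shape)  # torch.Size([2, 4])
```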
The second hunk, in `compute_logprobs`, adds a per-sample fallback for batches whose tensors have mismatched shapes (the inline review comment on it follows):

```diff
@@ -335,9 +335,33 @@ def compute_logprobs(
         end_idx = min(start_idx + batch_size, num_samples)
         batch_indices = list(range(start_idx, end_idx))

-        batch_query_responses = torch.cat([data_BT.query_responses[i] for i in batch_indices], dim=0)
-        batch_attention_masks = torch.cat([data_BT.attention_masks[i] for i in batch_indices], dim=0)
-        batch_position_ids = torch.cat([data_BT.position_ids[i] for i in batch_indices], dim=0)
+        query_responses = [data_BT.query_responses[i] for i in batch_indices]
+        attention_masks = [data_BT.attention_masks[i] for i in batch_indices]
+        position_ids = [data_BT.position_ids[i] for i in batch_indices]
+        shapes = [tuple(t.shape) for t in query_responses]
+
+        if len(set(shapes)) != 1:
+            for i in batch_indices:
+                single_logprobs, _ = forward_for_logprobs(
+                    model,
+                    data_BT.query_responses[i],
+                    data_BT.attention_masks[i],
+                    data_BT.position_ids[i],
+                    pad_token_id,
+                    temperature,
+                    False,
+                )
+                response_mask_BT = data_BT.response_masks[i].to(single_logprobs.device)
+                single_logprobs = torch.masked_fill(
+                    single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
+                )
+                logprobs_BT.append(single_logprobs)
+            continue
+
+        batch_query_responses = torch.cat(query_responses, dim=0)
+        batch_attention_masks = torch.cat(attention_masks, dim=0)
+        batch_position_ids = torch.cat(position_ids, dim=0)

         batch_logprobs, _ = forward_for_logprobs(
             model,
```
Comment on lines +343 to +360 (Contributor):

The code handles cases where the shapes of the batched tensors differ, but the fallback assumes each per-sample tensor is already 2D; unsqueezing 1D inputs first, and moving `torch.cuda.empty_cache()` outside the loop, makes it more robust:

```python
if len(set(shapes)) != 1:
    for i in batch_indices:
        single_query_responses = data_BT.query_responses[i]
        single_attention_mask = data_BT.attention_masks[i]
        single_position_ids = data_BT.position_ids[i]
        if single_query_responses.ndim == 1:
            single_query_responses = single_query_responses.unsqueeze(0)
            single_attention_mask = single_attention_mask.unsqueeze(0)
            single_position_ids = single_position_ids.unsqueeze(0)
        single_logprobs, _ = forward_for_logprobs(
            model,
            single_query_responses,
            single_attention_mask,
            single_position_ids,
            pad_token_id,
            temperature,
            False,
        )
        response_mask_BT = data_BT.response_masks[i]
        if response_mask_BT.ndim == 1:
            response_mask_BT = response_mask_BT.unsqueeze(0)
        response_mask_BT = response_mask_BT.to(single_logprobs.device)
        single_logprobs = torch.masked_fill(
            single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
        )
        logprobs_BT.append(single_logprobs)
    # torch.cuda.empty_cache()  # Move outside the loop
    continue
```
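`torch.cuda.empty_cache()` hands cached blocks back to the CUDA driver, which is comparatively slow and forces the allocator to re-request memory on the next allocation, so calling it once per sample inside the fallback loop would add avoidable overhead; that is presumably why the suggestion hoists it out.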
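The shape check exists because `torch.cat` along `dim=0` requires every other dimension to match, so ragged sequence lengths cannot be batched this way. A standalone illustration (the shapes are made up, not from the PR):

```python
import torch

a = torch.zeros(1, 512, dtype=torch.long)  # one packed sample, length 512
b = torch.zeros(1, 640, dtype=torch.long)  # another sample, length 640

shapes = [tuple(t.shape) for t in (a, b)]
if len(set(shapes)) != 1:
    # torch.cat([a, b], dim=0) would raise
    # "RuntimeError: Sizes of tensors must match except in dimension 0",
    # so each sample has to go through the model on its own.
    print("ragged shapes -> per-sample forward passes")
else:
    batch = torch.cat([a, b], dim=0)  # safe: shapes agree
```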
The final hunk in this file drops the `torch.cuda.empty_cache()` call that ran on every batch:

```diff
@@ -357,8 +381,6 @@ def compute_logprobs(
         logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
         logprobs_BT.append(logprob_BT)

-        torch.cuda.empty_cache()
-
     return logprobs_BT
```
The PR also adds a 36-line shell script that runs a small single-GPU GRPO job on GSM8K:

```bash
#!/bin/bash
set -euo pipefail

export TORCH_COMPILE_DISABLE=1
export VLLM_ALLOW_INSECURE_SERIALIZATION=1
export VLLM_DISABLE_COMPILE_CACHE=1
export VLLM_USE_V1=1

uv run --active open_instruct/grpo.py \
    --dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
    --dataset_mixer_list_splits train \
    --dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
    --dataset_mixer_eval_list_splits train \
    --max_prompt_token_length 512 \
    --response_length 512 \
    --pack_length 1024 \
    --per_device_train_batch_size 1 \
    --num_unique_prompts_rollout 8 \
    --num_samples_per_prompt_rollout 4 \
    --model_name_or_path Qwen/Qwen3-0.6B \
    --system_prompt_override_file scripts/train/qwen/math_system_prompt.txt \
    --apply_verifiable_reward true \
    --learning_rate 1e-6 \
    --total_episodes 128 \
    --deepspeed_stage 2 \
    --num_epochs 1 \
    --num_learners_per_node 1 \
    --vllm_tensor_parallel_size 1 \
    --beta 0.01 \
    --seed 3 \
    --local_eval_every 4 \
    --vllm_sync_backend gloo \
    --vllm_gpu_memory_utilization 0.4 \
    --vllm_enforce_eager \
    --single_gpu_mode \
    --push_to_hub false $@
```
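Because the last line forwards `$@` to `open_instruct/grpo.py`, extra flags appended at the call site override the defaults above; for example (the script path below is hypothetical, since the diff does not show the new file's name):

```bash
# Forwarded to open_instruct/grpo.py via $@; the path is hypothetical.
bash scripts/train/debug/grpo_single_gpu.sh --total_episodes 256 --seed 7
```

Note that the unquoted `$@` word-splits arguments containing spaces; `"$@"` would forward them verbatim.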
Finally, the PR adds a one-line system prompt file, presumably the `scripts/train/qwen/math_system_prompt.txt` referenced by `--system_prompt_override_file` above:

```
Please reason step by step, and put your final answer within \boxed{}.
```
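The prompt pins the final answer inside `\boxed{}` so a verifiable-reward function can extract and check it. A minimal extraction sketch (the function and regex are illustrative, not the repo's actual verifier):

```python
import re

def extract_boxed_answer(text: str) -> str | None:
    # Take the contents of the last \boxed{...}; nested braces would
    # need a real parser, but plain numeric answers match fine.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

print(extract_boxed_answer(r"So the total is \boxed{42}."))  # -> 42
```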