Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ All notable changes to this project will be documented in this file.
- OLMo-core GRPO actor with Ray-distributed FSDP2 training (https://github.com/allenai/open-instruct/pull/1398).

### Fixed
- Got OLMo-core GRPO running in single-GPU mode and added a grpo.py debug script (https://github.com/allenai/open-instruct/pull/1543).
- Batch vLLM weight sync broadcasts to reduce Ray RPCs from ~200+ to 1, fixing timeouts with 32k response lengths (https://github.com/allenai/open-instruct/pull/1535).
- Fix `wandb_tracker.run.url` `AttributeError` on non-main processes in multi-node SFT training by guarding accesses with `accelerator.is_main_process` checks (https://github.com/allenai/open-instruct/pull/1539).
- Fix `UnboundLocalError` for `beaker_config` in SFT tracking setup when `push_to_hub` is disabled (https://github.com/allenai/open-instruct/pull/1539).
11 changes: 6 additions & 5 deletions open_instruct/data_loader.py
@@ -603,11 +603,11 @@ def reshuffle(self, epoch: int | None = None, **kwargs):
self.current_epoch = epoch

def get_mock_batch(self) -> dict[str, Any]:
dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long)
dummy_attention = torch.tensor([1, 1], dtype=torch.long)
dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long)
dummy_response_mask = torch.zeros_like(dummy_qr)
dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float)
dummy_qr = torch.tensor([[self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]], dtype=torch.long)
dummy_attention = torch.tensor([[1, 1]], dtype=torch.long)
dummy_position_ids = torch.arange(dummy_qr.shape[-1], dtype=torch.long).unsqueeze(0)
dummy_response_mask = torch.tensor([[0, 1]], dtype=torch.long)
dummy_advantage = torch.tensor([[0.0, 1.0]], dtype=torch.float)

batch = data_types.CollatedBatchData(
query_responses=[dummy_qr],
@@ -631,6 +631,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]:

def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor:
padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id)
padded_tensor = torch.atleast_2d(padded_tensor)
if pin_memory and torch.cuda.is_available():
padded_tensor = padded_tensor.pin_memory()
return padded_tensor
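A quick illustration of the reshaped mock batch and the `torch.atleast_2d` guard in `collate_fn` — this is a standalone sketch with placeholder token ids, not the repo's defaults:

```python
import torch

PAD_ID, EOS_ID = 0, 1  # placeholder ids; the real values come from the tokenizer

# The mock batch tensors now carry an explicit batch dimension: shape [1, 2].
dummy_qr = torch.tensor([[PAD_ID, EOS_ID]], dtype=torch.long)
print(dummy_qr.shape)  # torch.Size([1, 2])

def collate(tensors: list[torch.Tensor], pad_id: int) -> torch.Tensor:
    # pad_sequence stacks variable-length 1-D tensors into one [B, T] tensor...
    padded = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=pad_id)
    # ...and atleast_2d keeps a [B, T] shape even in degenerate single-row cases.
    return torch.atleast_2d(padded)

print(collate([torch.tensor([2, 3, 4]), torch.tensor([5, 6])], PAD_ID))
# tensor([[2, 3, 4],
#         [5, 6, 0]])
```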
19 changes: 13 additions & 6 deletions open_instruct/grpo.py
@@ -127,14 +127,16 @@ def main(
os.makedirs(args.output_dir, exist_ok=True)
pprint([args, model_config])

ray.init(
address="auto",
dashboard_host="0.0.0.0",
runtime_env={
ray_init_kwargs = {
"dashboard_host": "0.0.0.0",
"runtime_env": {
"excludes": [".git/"],
"env_vars": {k: v for k, v in os.environ.items() if k not in grpo_fast.EXCLUDED_ENV_VARS},
},
)
}
if ray_address := utils.get_ray_address():
ray_init_kwargs["address"] = ray_address
ray.init(**ray_init_kwargs)

pool_size = tools_config.pool_size
if pool_size is None:
@@ -196,6 +198,8 @@ def main(
model_dims = utils.ModelDims.from_hf_config(model_config.model_name_or_path)

data_prep_actor_name = "data_prep_singleton"
base_env_config = grpo_fast.build_base_env_config(tools_config, pools)

_data_prep_actor = DataPreparationActor.options(name=data_prep_actor_name, num_cpus=2).remote( # type: ignore[attr-defined]
dataset=train_dataset,
inference_results_Q=inference_results_Q,
@@ -213,8 +217,11 @@ def main(
model_dims=model_dims,
verbose=args.verbose,
work_dir=args.output_dir,
tool_names=tools_config.tool_call_names if tools_config else [],
run_name=args.run_name,
model_name=model_config.model_name_or_path,
base_env_config=base_env_config,
initial_state=None,
allow_world_padding=False,
)

wait_for_gpus(sum(args.num_learners_per_node))
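The `ray.init` change above boils down to only passing `address` when a cluster address is actually configured, so a local single-GPU run can start its own Ray instance. A minimal sketch of the pattern, assuming a helper that returns `None` when no address is set (the repo uses `utils.get_ray_address`):

```python
import os
import ray

def get_ray_address() -> str | None:
    # Assumed behaviour: return the configured cluster address, or None for local runs.
    return os.environ.get("RAY_ADDRESS")

ray_init_kwargs = {
    "dashboard_host": "0.0.0.0",
    "runtime_env": {"excludes": [".git/"]},
}
# Only connect to an existing cluster when an address is configured;
# otherwise ray.init() starts a fresh local instance.
if ray_address := get_ray_address():
    ray_init_kwargs["address"] = ray_address
ray.init(**ray_init_kwargs)
```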
13 changes: 2 additions & 11 deletions open_instruct/grpo_olmo_core_actor.py
@@ -26,8 +26,8 @@

from open_instruct import data_loader as data_loader_lib
from open_instruct import grpo_utils, logger_utils, olmo_core_utils, vllm_utils
from open_instruct.beaker_callback import BeakerCallbackV2
from open_instruct.grpo_callbacks import RefPolicyUpdateCallback, VLLMWeightSyncCallback, olmo_core_to_hf_name
from open_instruct.olmo_core_callbacks import BeakerCallbackV2
from open_instruct.olmo_core_train_modules import GRPOTrainModule
from open_instruct.utils import RayProcess, is_beaker_job, ray_get_with_progress

@@ -95,15 +95,6 @@ def setup_model(self) -> int:
f"[Rank {self.rank}] Set CUDA device to 0, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}"
)

if not torch.distributed.is_initialized():
logger.info(f"[Rank {self.rank}] Calling init_process_group with NCCL backend...")
torch.distributed.init_process_group(
backend="nccl", timeout=timedelta(minutes=self.grpo_config.backend_timeout)
)
logger.info(f"[Rank {self.rank}] init_process_group completed successfully")
else:
logger.info(f"[Rank {self.rank}] Process group already initialized")

backend = "cpu:gloo,cuda:nccl"
logger.info(f"[Rank {self.rank}] Calling train.prepare_training_environment...")
train.prepare_training_environment(seed=self.grpo_config.seed, backend=backend)
@@ -166,7 +157,7 @@ def setup_model(self) -> int:
self.train_module = GRPOTrainModule(
model=self.model,
optim=optim_config,
rank_microbatch_size=self.grpo_config.per_device_train_batch_size,
sample_microbatch_size=self.grpo_config.per_device_train_batch_size,
max_sequence_length=self.max_sequence_length,
grpo_config=self.grpo_config,
tokenizer=self.tokenizer,
34 changes: 28 additions & 6 deletions open_instruct/grpo_utils.py
@@ -300,7 +300,7 @@ def forward_for_logprobs(
logits = logits / temperature
# The logits at position i predict token i+1, so we align them with labels shifted by 1
logits = logits[:, :-1]
labels = query_responses[:, 1:].clone()
labels = query_responses[:, 1:].clone().to(logits.device)
# Replace pad tokens with 0 to avoid index out of bounds errors in gather
labels[labels == pad_token_id] = 0
logprob_BT = model_utils.log_softmax_and_gather(logits, labels)
@@ -335,9 +335,33 @@ def compute_logprobs(
end_idx = min(start_idx + batch_size, num_samples)
batch_indices = list(range(start_idx, end_idx))

batch_query_responses = torch.cat([data_BT.query_responses[i] for i in batch_indices], dim=0)
batch_attention_masks = torch.cat([data_BT.attention_masks[i] for i in batch_indices], dim=0)
batch_position_ids = torch.cat([data_BT.position_ids[i] for i in batch_indices], dim=0)
query_responses = [data_BT.query_responses[i] for i in batch_indices]
attention_masks = [data_BT.attention_masks[i] for i in batch_indices]
position_ids = [data_BT.position_ids[i] for i in batch_indices]
shapes = [tuple(t.shape) for t in query_responses]

if len(set(shapes)) != 1:
for i in batch_indices:
single_logprobs, _ = forward_for_logprobs(
model,
data_BT.query_responses[i],
data_BT.attention_masks[i],
data_BT.position_ids[i],
pad_token_id,
temperature,
False,
)

response_mask_BT = data_BT.response_masks[i].to(single_logprobs.device)
single_logprobs = torch.masked_fill(
single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
)
logprobs_BT.append(single_logprobs)
continue
Comment on lines +343 to +360
medium

The code handles cases where the shapes of query_responses are not uniform within a batch. However, the torch.cuda.empty_cache() call inside the loop might be inefficient. It would be better to move this call outside the loop to reduce overhead, or even better, rely on PyTorch's memory management to handle the caching.

            if len(set(shapes)) != 1:
                for i in batch_indices:
                    single_query_responses = data_BT.query_responses[i]
                    single_attention_mask = data_BT.attention_masks[i]
                    single_position_ids = data_BT.position_ids[i]
                    if single_query_responses.ndim == 1:
                        single_query_responses = single_query_responses.unsqueeze(0)
                        single_attention_mask = single_attention_mask.unsqueeze(0)
                        single_position_ids = single_position_ids.unsqueeze(0)

                    single_logprobs, _ = forward_for_logprobs(
                        model,
                        single_query_responses,
                        single_attention_mask,
                        single_position_ids,
                        pad_token_id,
                        temperature,
                        False,
                    )

                    response_mask_BT = data_BT.response_masks[i]
                    if response_mask_BT.ndim == 1:
                        response_mask_BT = response_mask_BT.unsqueeze(0)
                    response_mask_BT = response_mask_BT.to(single_logprobs.device)
                    single_logprobs = torch.masked_fill(
                        single_logprobs, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB
                    )
                    logprobs_BT.append(single_logprobs)
                # torch.cuda.empty_cache() # Move outside the loop
                continue


batch_query_responses = torch.cat(query_responses, dim=0)
batch_attention_masks = torch.cat(attention_masks, dim=0)
batch_position_ids = torch.cat(position_ids, dim=0)

batch_logprobs, _ = forward_for_logprobs(
model,
@@ -357,8 +381,6 @@ def compute_logprobs(
logprob_BT = torch.masked_fill(logprob_BT, ~response_mask_BT[:, 1:].bool(), INVALID_LOGPROB)
logprobs_BT.append(logprob_BT)

torch.cuda.empty_cache()

return logprobs_BT


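For context on the `compute_logprobs` changes: `torch.cat` along the batch dimension only works when every sample has the same padded length, which is why non-uniform batches now fall back to per-sample forwards, and labels are moved onto the logits' device before the gather. A self-contained sketch of the shift-and-mask logic, with toy shapes and `INVALID_LOGPROB` as an assumed sentinel:

```python
import torch

INVALID_LOGPROB = 1.0  # assumed sentinel; the repo defines its own constant

query_response = torch.tensor([[11, 12, 13, 14, 15]])   # [B=1, T=5]
logits = torch.randn(1, 5, 32)                          # [B, T, vocab]

# Logits at position i predict token i+1: drop the last step, shift the labels,
# and move labels to the logits' device before gathering.
shifted_logits = logits[:, :-1]
labels = query_response[:, 1:].clone().to(shifted_logits.device)
logprobs = torch.log_softmax(shifted_logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

# Only response tokens keep their logprob; prompt positions get the sentinel.
response_mask = torch.tensor([[0, 0, 1, 1, 1]])
logprobs = torch.masked_fill(logprobs, ~response_mask[:, 1:].bool(), INVALID_LOGPROB)
print(logprobs.shape)  # torch.Size([1, 4])
```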
12 changes: 10 additions & 2 deletions open_instruct/olmo_core_train_modules.py
@@ -282,7 +282,7 @@ def __init__(
self,
model: Transformer,
optim: OptimConfig,
rank_microbatch_size: int,
sample_microbatch_size: int,
max_sequence_length: int,
grpo_config: grpo_utils.ExperimentConfig,
tokenizer: PreTrainedTokenizer,
@@ -294,10 +294,11 @@ def __init__(
state_dict_save_opts: dist_cp_sd.StateDictOptions | None = None,
state_dict_load_opts: dist_cp_sd.StateDictOptions | None = None,
):
rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length
super().__init__(
model=model,
optim=optim,
rank_microbatch_size=rank_microbatch_size,
rank_microbatch_size=rank_microbatch_size_tokens,
max_sequence_length=max_sequence_length,
dp_config=dp_config,
max_grad_norm=max_grad_norm,
@@ -307,6 +308,7 @@ def __init__(
state_dict_load_opts=state_dict_load_opts,
)

self.sample_microbatch_size = sample_microbatch_size
self.grpo_config = grpo_config
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
@@ -315,6 +317,12 @@ def __init__(
if ref_policy is not None:
self.ref_policy = ref_policy.to(device=self.device).eval().requires_grad_(False)

def pre_train(self):
# GRPO batches are prompt-grouped and do their own accumulation/token normalization
# inside train_batch(), so the base TransformerTrainModule global-batch validation
# does not apply here.
pass

def state_dict(self, *, optim: bool | None = None) -> dict[str, Any]:
state = super().state_dict(optim=optim)
if self.ref_policy is not None:
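The rename above reflects a unit change: the GRPO config expresses the micro-batch size in samples, while the base OLMo-core `TransformerTrainModule` expects `rank_microbatch_size` in tokens. A sketch of the conversion with illustrative numbers (not the repo's defaults):

```python
# Illustrative values only.
sample_microbatch_size = 2      # sequences per micro-batch per rank (GRPO config)
max_sequence_length = 1024      # padded prompt + response length

# OLMo-core counts rank_microbatch_size in tokens, so convert samples -> tokens.
rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length
assert rank_microbatch_size_tokens == 2048
```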
17 changes: 14 additions & 3 deletions open_instruct/olmo_core_utils.py
@@ -5,6 +5,9 @@
including model configuration mappings and helper functions.
"""

import os

import torch
import transformers
from olmo_core.nn.attention import AttentionBackendName
from olmo_core.nn.hf.checkpoint import save_hf_model
@@ -22,6 +25,7 @@
"allenai/Olmo-3-1025-7B": "olmo3_7B",
"allenai/OLMoE-1B-7B-0924": "olmoe_1B_7B",
"Qwen/Qwen3-0.6B": "qwen3_0_6B",
"Qwen/Qwen3-0.6B-Base": "qwen3_0_6B",
"Qwen/Qwen3-1.7B": "qwen3_1_7B",
"Qwen/Qwen3-4B": "qwen3_4B",
"Qwen/Qwen3-8B": "qwen3_8B",
@@ -67,9 +71,16 @@ def get_transformer_config(


def save_state_dict_as_hf(model_config, state_dict, save_dir, original_model_name_or_path, tokenizer):
unwrapped_model = model_config.build(init_device="cpu")
unwrapped_model.load_state_dict(state_dict)
save_hf_model(save_dir=save_dir, model_state_dict=state_dict, model=unwrapped_model, save_overwrite=True)
try:
unwrapped_model = model_config.build(init_device="cpu")
unwrapped_model.load_state_dict(state_dict)
save_hf_model(save_dir=save_dir, model_state_dict=state_dict, model=unwrapped_model, save_overwrite=True)
except NotImplementedError as exc:
logger.warning(
"Falling back to raw state_dict save because HF export is unsupported for this OLMo-core model: %s", exc
)
os.makedirs(save_dir, exist_ok=True)
torch.save(state_dict, os.path.join(save_dir, "model_state_dict.pt"))
tokenizer.save_pretrained(save_dir)
original_config = transformers.AutoConfig.from_pretrained(original_model_name_or_path)
original_config.save_pretrained(save_dir)
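One consequence of the fallback above is that a checkpoint directory may hold either a HF-format export or a raw `model_state_dict.pt`. A hedged sketch of how a consumer might reload the raw fallback (the helper name and layout here are assumptions, not repo code):

```python
import os
import torch

def load_fallback_state_dict(save_dir: str) -> dict:
    """Reload a checkpoint written by the raw state_dict fallback path.

    Assumes the fallback saved `model_state_dict.pt` next to the tokenizer and
    original HF config, as in the diff above.
    """
    path = os.path.join(save_dir, "model_state_dict.pt")
    return torch.load(path, map_location="cpu")
```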
36 changes: 36 additions & 0 deletions scripts/train/debug/grpo.sh
@@ -0,0 +1,36 @@
#!/bin/bash
set -euo pipefail

export TORCH_COMPILE_DISABLE=1
export VLLM_ALLOW_INSECURE_SERIALIZATION=1
export VLLM_DISABLE_COMPILE_CACHE=1
export VLLM_USE_V1=1

uv run --active open_instruct/grpo.py \
--dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
--dataset_mixer_eval_list_splits train \
--max_prompt_token_length 512 \
--response_length 512 \
--pack_length 1024 \
--per_device_train_batch_size 1 \
--num_unique_prompts_rollout 8 \
--num_samples_per_prompt_rollout 4 \
--model_name_or_path Qwen/Qwen3-0.6B \
--system_prompt_override_file scripts/train/qwen/math_system_prompt.txt \
--apply_verifiable_reward true \
--learning_rate 1e-6 \
--total_episodes 128 \
--deepspeed_stage 2 \
--num_epochs 1 \
--num_learners_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.01 \
--seed 3 \
--local_eval_every 4 \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.4 \
--vllm_enforce_eager \
--single_gpu_mode \
--push_to_hub false $@
26 changes: 11 additions & 15 deletions scripts/train/debug/grpo_fast.sh
@@ -1,8 +1,12 @@
#!/bin/bash
set -euo pipefail

export TORCH_COMPILE_DISABLE=1
export VLLM_ALLOW_INSECURE_SERIALIZATION=1
export VLLM_DISABLE_COMPILE_CACHE=1
export VLLM_USE_V1=1
uv run python open_instruct/grpo_fast.py \

uv run --active open_instruct/grpo_fast.py \
--dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
@@ -14,27 +18,19 @@ uv run python open_instruct/grpo_fast.py \
--num_unique_prompts_rollout 8 \
--num_samples_per_prompt_rollout 4 \
--model_name_or_path Qwen/Qwen3-0.6B \
--stop_strings "</answer>" \
--system_prompt_override_file scripts/train/qwen/math_system_prompt.txt \
--apply_verifiable_reward true \
--temperature 0.7 \
--ground_truths_key ground_truth \
--chat_template_name r1_simple_chat_postpend_think \
--learning_rate 3e-7 \
--total_episodes 200 \
--learning_rate 1e-6 \
--total_episodes 128 \
--deepspeed_stage 2 \
--num_epochs 1 \
--num_learners_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.01 \
--seed 3 \
--local_eval_every 1 \
--local_eval_every 4 \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.3 \
--save_traces \
--vllm_gpu_memory_utilization 0.4 \
--vllm_enforce_eager \
--gradient_checkpointing \
--single_gpu_mode \
--push_to_hub false \
--system_prompt_override_file scripts/train/debug/cute_debug_system_prompt.txt \
--active_sampling --async_steps 8
# --with_tracking
--push_to_hub false $@
1 change: 1 addition & 0 deletions scripts/train/qwen/math_system_prompt.txt
@@ -0,0 +1 @@
Please reason step by step, and put your final answer within \boxed{}.