diff --git a/AGENTS.md b/AGENTS.md
index 48249eaf4d..69f501618a 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -15,15 +15,16 @@
 - `scripts/train/debug/single_gpu_on_beaker.sh`: single GPU, no tools (~8 minutes).
 - `scripts/train/debug/tools/olmo_3_parser_multigpu.sh`: multi GPU, with tools.
 - `scripts/train/debug/large_test_script.sh`: two 8x GPU nodes, no tools (~32 minutes).
-- For DPO, we have two test scripts:
-  - `scripts/train/debug/dpo.sh`: single GPU.
-  - `scripts/train/debug/large_dpo.sh`: four 8x GPU nodes.
+- For DPO, we have three test scripts:
+  - `scripts/train/debug/dpo/local.sh`: local single GPU (no Beaker).
+  - `scripts/train/debug/dpo/single_gpu.sh`: single GPU on Beaker.
+  - `scripts/train/debug/dpo/multi_node.sh`: two 8x GPU nodes on Beaker.
 - To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
 - Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tools/olmo_3_parser_multigpu.sh`.
 - Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
-- Launch DPO experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo.sh`.
-- Launch multi-node DPO experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/medium_dpo.sh`.
-- Launch the GPU tests with `./scripts/train/build_image_and_launch.sh scripts/test/run_gpu_tests.sh`.
+- Launch DPO experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/single_gpu.sh`.
+- Launch multi-node DPO experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/multi_node.sh`.
+- Launch the GPU tests with `./scripts/train/build_image_and_launch.sh scripts/train/debug/run_gpu_tests.sh`.
 - If you are given a Beaker URL (beaker\.allen\.ai.*) use the Beaker CLI tool to interact with it.
 
 # Coding conventions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0c995b11a..8c50969f32 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.
 
 ### Added
 
+- Added OLMo-core based DPO training script (https://github.com/allenai/open-instruct/pull/1391).
 - Added SLURM scripts for OLMo SFT training with checkpoint resume support and configurable shuffle seed. https://github.com/allenai/open-instruct/pull/1368
 - Added retry logic with exponential backoff to `make_api_request` for tool API calls (retries on timeouts, connection errors, 429, and 5xx). Also added configurable `max_concurrency` parameter to tool configs for controlling Ray actor concurrency per-tool. https://github.com/allenai/open-instruct/pull/1388
 - Added support for generic MCP tools during training, with some limitations (no changing tools, no tool discovery during training). For details: https://github.com/allenai/open-instruct/pull/1384
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 943af346b3..f26ed9d5a5 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -2,6 +2,14 @@
 
 Thank you for your interest in contributing to Open Instruct!
 
+## Adding Olmo-core models
+
+For our new infrastructure, which is based on [Olmo-core](https://github.com/allenai/OLMo-core), we need to manually add support for each model so it can be converted from Hugging Face. You don't need to merge the PR into `olmo-core` (though we encourage it!), since you can point `pyproject.toml` at a specific commit of `olmo-core` (or a fork).
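+
+For example, a minimal pin might look like the following in `pyproject.toml` (a sketch assuming the dependency is consumed as `ai2-olmo-core` via uv's `tool.uv.sources`; the commit hash is a placeholder, substitute your own fork/commit):
+
+```toml
+[tool.uv.sources]
+ai2-olmo-core = { git = "https://github.com/allenai/OLMo-core", rev = "<commit-sha>" }
+```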
+ +Here are some example PRs adding models: [Qwen3](https://github.com/allenai/OLMo-core/pull/533), [Gemma 3](https://github.com/allenai/OLMo-core/pull/534). + +Once you have modified `pyproject.toml` to point to the specific commit, run `uv sync`, and then you should be able to run your experiment with the new model type. + ## External contributors ### CI (Fork PRs) diff --git a/mason.py b/mason.py index a456bba0a6..b34b86c4ab 100644 --- a/mason.py +++ b/mason.py @@ -26,6 +26,7 @@ # Open Instruct logic OPEN_INSTRUCT_COMMANDS = [ "open_instruct/finetune.py", + "open_instruct/dpo.py", "open_instruct/dpo_tune_cache.py", "open_instruct/grpo_fast.py", "open_instruct/reward_modeling.py", diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 09cde8c56f..65f8116c89 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -82,6 +82,7 @@ def __init__( automatic_reshuffle: bool = False, collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None, device: torch.device | None = None, + drop_last: bool = True, ) -> None: """Initialize the HFDataLoader. @@ -96,6 +97,8 @@ def __init__( collator: Optional collation function for batching examples. If None, batches will be dictionaries of the form `{'examples': [example_1, example_2, ...]}`. device: Device to move tensors to. + drop_last: If True, drop the last incomplete batch. If False, pad the last batch + with repeated indices to fill a complete batch. Note: The dataset must have an 'index' column for tracking samples across epochs. @@ -131,6 +134,7 @@ def __init__( self._per_rank_batch_size = batch_size // dp_world_size self._collator = collator if collator is not None else (lambda x: {"examples": x}) self._automatic_reshuffle = automatic_reshuffle + self._drop_last = drop_last self._excluded_indices: set[int] = set() self._epoch: int = 0 self._current_iter: Iterator[dict[str, Any]] | None = None @@ -224,6 +228,13 @@ def _reshard(self, epoch: int) -> None: total_batches = global_size // self._batch_size usable_size = total_batches * self._batch_size + if not self._drop_last and usable_size < global_size: + remainder = global_size - usable_size + pad_indices = all_indices[: self._batch_size - remainder] + all_indices = np.concatenate([all_indices, pad_indices]) + total_batches += 1 + usable_size = total_batches * self._batch_size + # Distribute examples from global batches to ranks. This is a form of strided sampling where each # rank gets a subset of examples from each global batch, ensuring a diverse set of examples. rank_indices = all_indices[:usable_size].reshape(total_batches, self._batch_size) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py new file mode 100644 index 0000000000..c851de19f0 --- /dev/null +++ b/open_instruct/dpo.py @@ -0,0 +1,474 @@ +""" +DPO training with OLMo-core's Trainer. + +This module provides DPO (Direct Preference Optimization) training using +OLMo-core's native training infrastructure. 
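+
+Example launch (a minimal single-GPU sketch; the flags mirror
+scripts/train/debug/dpo/local.sh):
+
+    torchrun --nproc_per_node=1 open_instruct/dpo.py \
+        --model_name_or_path allenai/OLMo-2-0425-1B \
+        --max_seq_length 1024 \
+        --per_device_train_batch_size 1 \
+        --mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 50 \
+        --output_dir output/dpo_local_test/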
+""" + +import os +import pathlib +import shutil +from functools import partial + +import bitsandbytes.optim +import torch +import torch.distributed as dist +import transformers +from olmo_core import train +from olmo_core.config import DType +from olmo_core.distributed import utils as distributed_utils +from olmo_core.distributed.parallel import DataParallelType, build_world_mesh, get_dp_model_mesh +from olmo_core.nn.attention.backend import has_flash_attn_3 +from olmo_core.nn.hf.checkpoint import load_hf_model +from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode +from olmo_core.optim import ConstantWithWarmup, CosWithWarmup, LinearWithWarmup +from olmo_core.train import callbacks +from olmo_core.train.callbacks import CheckpointerCallback +from olmo_core.train.train_module.transformer import ( + TransformerDataParallelConfig, + TransformerDataParallelWrappingStrategy, +) + +from open_instruct import data_loader as data_loader_lib +from open_instruct import dataset_transformation, dpo_utils, logger_utils, model_utils, olmo_core_utils, utils +from open_instruct.beaker_callback import BeakerCallbackV2 +from open_instruct.olmo_core_train_modules import DPOTrainModule +from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO + +logger = logger_utils.setup_logger(__name__) + + +def export_to_hf( + model, model_config, tokenizer, save_dir: str, original_model_name_or_path: str, is_main_process: bool +): + """Export an FSDP-wrapped model to HuggingFace format. + + All ranks must call this function as state_dict() and full_tensor() are collective operations. + Only the main process saves to disk. + """ + logger.info("Gathering FSDP state dict...") + state_dict = model.state_dict() + state_dict = {k: v.full_tensor().cpu() if hasattr(v, "full_tensor") else v.cpu() for k, v in state_dict.items()} + + if is_main_process: + logger.info(f"Exporting model to HuggingFace format at {save_dir}") + olmo_core_utils.save_state_dict_as_hf( + model_config, state_dict, save_dir, original_model_name_or_path, tokenizer + ) + + +def _load_dataset_distributed( + args: dpo_utils.ExperimentConfig, + tc: dataset_transformation.TokenizerConfig, + transform_fn_args: list[dict], + is_main_process: bool, +): + """Load dataset with distributed coordination.""" + + def _load(): + return dataset_transformation.get_cached_dataset_tulu( + dataset_mixer_list=args.mixer_list, + dataset_mixer_list_splits=args.mixer_list_splits, + tc=tc, + dataset_transform_fn=args.transform_fn, + transform_fn_args=transform_fn_args, + target_columns=args.target_columns, + dataset_cache_mode=args.cache_mode, + dataset_config_hash=args.config_hash, + hf_entity=args.hf_entity, + dataset_local_cache_dir=args.local_cache_dir, + dataset_skip_cache=args.skip_cache, + ) + + if is_main_process: + dataset = _load() + if distributed_utils.is_distributed(): + dist.barrier() + if not is_main_process: + dataset = _load() + return dataset + + +def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device): + """Load and configure OLMo-core model.""" + hf_config = transformers.AutoConfig.from_pretrained(args.model_name_or_path) + vocab_size = hf_config.vocab_size + logger.info(f"Building OLMo-core model with vocab_size={vocab_size}") + config_name_for_lookup = args.config_name if args.config_name else args.model_name_or_path + + attn_backend = args.attn_backend + if attn_backend == "auto": + device_name = torch.cuda.get_device_name(0).lower() if torch.cuda.is_available() else "" + is_h100 = 
"h100" in device_name or "h800" in device_name + attn_backend = "flash_3" if (is_h100 and has_flash_attn_3()) else "flash_2" + logger.info(f"Auto-detected attn_backend={attn_backend} for device: {device_name}") + + model_config = olmo_core_utils.get_transformer_config( + config_name_for_lookup, vocab_size, attn_backend=attn_backend + ) + model = model_config.build(init_device="cpu") + + logger.info(f"Loading HuggingFace weights from {args.model_name_or_path}") + load_hf_model(args.model_name_or_path, model.state_dict(), work_dir=args.output_dir) + model = model.to(device=device, dtype=torch.bfloat16) + + if args.gradient_checkpointing: + logger.info("Enabling activation checkpointing...") + model.apply_activation_checkpointing(TransformerActivationCheckpointingMode.full) + + return model, model_config + + +def _apply_parallelism( + model, + device: torch.device, + tensor_parallel_degree: int = 1, + context_parallel_degree: int = 1, + pipeline_parallel_degree: int = 1, +): + """Apply parallelism strategies to model (HSDP, TP, CP, PP). + + Args: + model: The model to apply parallelism to. + device: The device to use. + tensor_parallel_degree: Tensor parallelism degree (default 1, disabled). + context_parallel_degree: Context parallelism degree (default 1, disabled). + pipeline_parallel_degree: Pipeline parallelism degree (default 1, disabled). + + Returns: + The model with parallelism applied. + """ + if tensor_parallel_degree > 1 and context_parallel_degree > 1: + raise ValueError("Cannot use both tensor parallelism and context parallelism simultaneously.") + + dp_config = TransformerDataParallelConfig( + name=DataParallelType.hsdp, + num_replicas=None, + shard_degree=None, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + wrapping_strategy=TransformerDataParallelWrappingStrategy.blocks, + ) + + tp_config = tensor_parallel_degree if tensor_parallel_degree > 1 else None + cp_config = context_parallel_degree if context_parallel_degree > 1 else None + pp_config = pipeline_parallel_degree if pipeline_parallel_degree > 1 else None + + world_mesh = build_world_mesh(dp=dp_config, tp=tp_config, cp=cp_config, pp=pp_config, device_type=device.type) + dp_mesh = get_dp_model_mesh(world_mesh) + + if tensor_parallel_degree > 1: + logger.info(f"Applying tensor parallelism with degree={tensor_parallel_degree}") + tp_mesh = world_mesh["tp"] + model.apply_tp(tp_mesh) + + if context_parallel_degree > 1: + logger.info(f"Applying context parallelism with degree={context_parallel_degree}") + + if pipeline_parallel_degree > 1: + logger.info(f"Applying pipeline parallelism with degree={pipeline_parallel_degree}") + pp_mesh = world_mesh["pp"] + model.apply_pp(pp_mesh) + + logger.info(f"Applying HSDP with dp_mesh: {dp_mesh}") + model.apply_fsdp( + dp_mesh=dp_mesh, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + wrapping_strategy=dp_config.wrapping_strategy, + ) + return model + + +def _setup_optimizer_and_scheduler(args: dpo_utils.ExperimentConfig, model, num_training_steps: int): + """Return (optimizer, scheduler).""" + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + if args.dpo_use_paged_optimizer: + optim = bitsandbytes.optim.AdamW( + optimizer_grouped_parameters, + lr=args.learning_rate, + optim_bits=8 if 
args.use_8bit_optimizer else 32, + is_paged=True, + ) + else: + optim = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer) + + warmup_steps = int(num_training_steps * args.warmup_ratio) + if args.lr_scheduler_type == "cosine": + scheduler = CosWithWarmup(warmup_steps=warmup_steps) + elif args.lr_scheduler_type == "linear": + scheduler = LinearWithWarmup(warmup_steps=warmup_steps, alpha_f=0.0) + else: + scheduler = ConstantWithWarmup(warmup_steps=warmup_steps) + + return optim, scheduler + + +def _setup_callbacks(args: dpo_utils.ExperimentConfig, model): + """Return callbacks dict.""" + json_config = dpo_utils.config_to_json_serializable(vars(args)) + trainer_callbacks: dict[str, callbacks.Callback] = {"beaker": BeakerCallbackV2(config=json_config)} + device_name = utils.get_device_name(torch.cuda.get_device_name(0)) + device_peak_flops = int(utils.GPU_SPECS[device_name]["flops"]) + trainer_callbacks["speed_monitor"] = callbacks.SpeedMonitorCallback( + num_flops_per_token=model.num_flops_per_token(args.max_seq_length), device_peak_flops=device_peak_flops + ) + trainer_callbacks["gpu_memory"] = callbacks.GPUMemoryMonitorCallback() + slack_webhook_url = os.environ.get("SLACK_WEBHOOK_URL") + if slack_webhook_url: + trainer_callbacks["slack"] = callbacks.SlackNotifierCallback( + name=args.run_name or args.exp_name, webhook_url=slack_webhook_url + ) + if args.with_tracking: + trainer_callbacks["wandb"] = callbacks.WandBCallback( + name=args.run_name or args.exp_name, + project=args.wandb_project, + entity=args.wandb_entity, + config=json_config, + ) + checkpointing_steps = int(args.checkpointing_steps) + trainer_callbacks["checkpointer"] = CheckpointerCallback(save_interval=checkpointing_steps, save_async=False) + return trainer_callbacks + + +def _handle_post_training( + args: dpo_utils.ExperimentConfig, + model, + model_config, + tokenizer, + trainer_callbacks, + beaker_config, + is_main_process: bool, +): + """Save HF model, copy to beaker, launch evals, push to hub.""" + hf_model_path = os.path.join(args.output_dir, "hf_model") + export_to_hf(model, model_config, tokenizer, hf_model_path, args.model_name_or_path, is_main_process) + + if distributed_utils.is_distributed(): + dist.barrier() + + output_path = pathlib.Path(args.output_dir).resolve() + beaker_output_path = pathlib.Path("/output").resolve() + if ( + args.try_auto_save_to_beaker + and is_main_process + and utils.is_beaker_job() + and beaker_config is not None + and len(beaker_config.beaker_dataset_id_urls) > 0 + and output_path != beaker_output_path + ): + shutil.copytree(hf_model_path, "/output", dirs_exist_ok=True) + + if utils.is_beaker_job() and is_main_process and args.try_launch_beaker_eval_jobs: + wandb_url = None + if args.with_tracking: + wandb_tracker = trainer_callbacks.get("wandb") + if wandb_tracker is not None and hasattr(wandb_tracker, "run") and wandb_tracker.run is not None: + wandb_url = wandb_tracker.run.get_url() + if args.hf_repo_revision is not None: + eval_path = hf_model_path + if beaker_config is not None and beaker_config.beaker_dataset_ids: + eval_path = beaker_config.beaker_dataset_ids[-1] + utils.launch_ai2_evals_on_weka( + path=eval_path, + leaderboard_name=args.hf_repo_revision, + oe_eval_max_length=args.oe_eval_max_length, + wandb_url=wandb_url, + oe_eval_tasks=args.oe_eval_tasks, + gs_bucket_path=args.gs_bucket_path, + eval_workspace=args.eval_workspace, + eval_priority=args.eval_priority, + oe_eval_gpu_multiplier=args.oe_eval_gpu_multiplier, + ) + + 
if args.push_to_hub and is_main_process: + model_utils.push_folder_to_hub(hf_model_path, args.hf_repo_id, args.hf_repo_revision) + + +def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerConfig) -> None: + """Main entry point for DPO training with OLMo-core.""" + if args.model_name_or_path is None: + raise ValueError("--model_name_or_path is required. Specify a HuggingFace model name or path.") + + if args.use_lora: + raise ValueError("LoRA is not supported with OLMo-core DPO training. Use dpo_tune_cache.py instead.") + + tc.tokenizer_name_or_path = ( + args.model_name_or_path if tc.tokenizer_name_or_path is None else tc.tokenizer_name_or_path + ) + tokenizer = tc.tokenizer + + args.local_cache_dir = os.path.abspath(args.local_cache_dir) + if utils.is_beaker_job(): + args.local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache" + + transform_fn_args = [{"max_seq_length": args.max_seq_length}, {}] + ref_cache_hash = dpo_utils.compute_reference_cache_hash(args, tc) + reference_cache_path = pathlib.Path(dpo_utils.REFERENCE_LOGPROBS_CACHE_PATH) / f"{ref_cache_hash}.pt" + logger.info(f"Reference logprobs cache path: {reference_cache_path}") + + if args.cache_dataset_only: + dataset_transformation.get_cached_dataset_tulu( + dataset_mixer_list=args.mixer_list, + dataset_mixer_list_splits=args.mixer_list_splits, + tc=tc, + dataset_transform_fn=args.transform_fn, + transform_fn_args=transform_fn_args, + target_columns=args.target_columns, + dataset_cache_mode=args.cache_mode, + dataset_config_hash=args.config_hash, + hf_entity=args.hf_entity, + dataset_local_cache_dir=args.local_cache_dir, + dataset_skip_cache=args.skip_cache, + ) + logger.info("Dataset cached successfully. Exiting because --cache_dataset_only was set.") + return + + train.prepare_training_environment(seed=args.seed) + + dp_rank = distributed_utils.get_rank() if distributed_utils.is_distributed() else 0 + is_main_process = dp_rank == 0 + + dataset = _load_dataset_distributed(args, tc, transform_fn_args, is_main_process) + dataset = dataset.shuffle(seed=args.seed) + dataset.set_format(type="pt") # Must be after shuffle (shuffle resets format) + + world_size = distributed_utils.get_world_size() if distributed_utils.is_distributed() else 1 + parallelism_factor = args.tensor_parallel_degree * args.context_parallel_degree * args.pipeline_parallel_degree + dp_world_size = world_size // parallelism_factor + + logger_utils.setup_logger(rank=dp_rank) + + beaker_config = utils.setup_experiment_paths(args, is_main_process) + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + if distributed_utils.is_distributed(): + dist.barrier() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model, model_config = _setup_model(args, device) + + if args.packing: + collator = TensorDataCollatorWithFlatteningDPO(return_position_ids=True, return_flash_attn_kwargs=True) + else: + collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest") + + global_batch_size = args.per_device_train_batch_size * dp_world_size + data_loader = data_loader_lib.HFDataLoader( + dataset=dataset, + batch_size=global_batch_size, + seed=args.seed, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + work_dir=args.output_dir, + collator=collator, + device=device, + ) + # 4x batch size: forward-only (no backward), so no activation storage needed. 
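+    # (Illustrative arithmetic: with per_device_train_batch_size=1 and dp_world_size=16,
+    # the training batch is 16 examples and the cache batch is 1 * 4 * 16 = 64.)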
+ cache_batch_size = args.per_device_train_batch_size * 4 * dp_world_size + cache_data_loader = data_loader_lib.HFDataLoader( + dataset=dataset, + batch_size=cache_batch_size, + seed=args.seed, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + work_dir=args.output_dir, + collator=collator, + device=device, + drop_last=False, + ) + + forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo + if args.packing: + forward_fn = partial(dpo_utils.concatenated_forward_olmo, packing=True) + average_log_prob = args.loss_type.is_average_loss + + cache_kwargs = dict( + dataloader=cache_data_loader, + average_log_prob=average_log_prob, + forward_fn=forward_fn, + full_dataset_size=len(dataset), + device=device, + cache_path=reference_cache_path, + is_main_process=is_main_process, + model_dims=utils.ModelDims.from_hf_config(args.model_name_or_path), + use_lora=False, + disable_adapter_context=None, + ) + + model_is_sharded = False + logger.info("Caching reference logprobs (trying unsharded first)...") + try: + reference_cache = dpo_utils.build_reference_logprobs_cache(model=model, **cache_kwargs) + logger.info("Reference logprobs cached (unsharded).") + except torch.cuda.OutOfMemoryError: + logger.warning("OOM with unsharded model, falling back to FSDP-sharded.") + torch.cuda.empty_cache() + model_is_sharded = True + model = _apply_parallelism( + model, device, args.tensor_parallel_degree, args.context_parallel_degree, args.pipeline_parallel_degree + ) + reference_cache = dpo_utils.build_reference_logprobs_cache(model=model, **cache_kwargs) + logger.info("Reference logprobs cached (sharded).") + + if args.cache_logprobs_only: + logger.info("--cache_logprobs_only set, exiting after cache build.") + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + return + + if not model_is_sharded: + model = _apply_parallelism( + model, device, args.tensor_parallel_degree, args.context_parallel_degree, args.pipeline_parallel_degree + ) + data_loader.reshuffle(epoch=0) + + num_training_steps = len(data_loader) * args.num_epochs + optim, scheduler = _setup_optimizer_and_scheduler(args, model, num_training_steps) + + max_grad_norm = args.max_grad_norm if args.max_grad_norm > 0 else None + train_module = DPOTrainModule( + model=model, + optim=optim, + args=args, + reference_cache=reference_cache, + scheduler=scheduler, + device=device, + max_grad_norm=max_grad_norm, + ) + + trainer_callbacks = _setup_callbacks(args, model) + + trainer = train.TrainerConfig( + save_folder=args.output_dir, + max_duration=train.Duration.epochs(args.num_epochs), + metrics_collect_interval=args.logging_steps, + callbacks=trainer_callbacks, + save_overwrite=True, + ).build(train_module, data_loader) + + logger.info("Starting training...") + trainer.fit() + logger.info("Training complete.") + + _handle_post_training(args, model, model_config, tokenizer, trainer_callbacks, beaker_config, is_main_process) + + train.teardown_training_environment() + + +if __name__ == "__main__": + from open_instruct.utils import ArgumentParserPlus + + parser = ArgumentParserPlus((dpo_utils.ExperimentConfig, dataset_transformation.TokenizerConfig)) + args, tc = parser.parse() + main(args, tc) diff --git a/open_instruct/dpo_utils.py b/open_instruct/dpo_utils.py index 5869a858be..4170b0958e 100644 --- a/open_instruct/dpo_utils.py +++ b/open_instruct/dpo_utils.py @@ -17,26 +17,28 @@ Adapted from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py """ 
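+# For orientation, a schematic of the standard DPO objective these utilities serve
+# (a sketch, not the exact compute_loss signature used below; the dpo_norm loss
+# type length-averages the logps before taking the difference):
+#
+#   logits = (policy_chosen_logps - ref_chosen_logps)
+#            - (policy_rejected_logps - ref_rejected_logps)
+#   loss = -F.logsigmoid(beta * logits)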
+import contextlib import enum import functools import hashlib import json import os import pathlib +import time from collections.abc import Callable from dataclasses import dataclass, field from typing import Literal +import numpy as np import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from accelerate import Accelerator from tqdm.auto import tqdm from transformers import DataCollatorForSeq2Seq from transformers.training_args import _convert_str_dict -from open_instruct import logger_utils, model_utils +from open_instruct import logger_utils, model_utils, utils from open_instruct.dataset_transformation import ( TOKENIZED_PREFERENCE_DATASET_KEYS, TokenizerConfig, @@ -50,6 +52,17 @@ logger = logger_utils.setup_logger(__name__) +def config_to_json_serializable(obj: object) -> object: + """Convert config object to JSON-serializable format.""" + if isinstance(obj, dict): + return {k: config_to_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [config_to_json_serializable(v) for v in obj] + if isinstance(obj, enum.Enum): + return obj.value + return obj + + class DPOLossType(enum.StrEnum): dpo = "dpo" dpo_norm = "dpo_norm" @@ -137,6 +150,14 @@ class TrainingConfig: """Use paged optimizer from bitsandbytes.""" fused_optimizer: bool = True """Whether to use fused AdamW or not.""" + tensor_parallel_degree: int = 1 + """Tensor parallelism degree. Default 1 (disabled).""" + context_parallel_degree: int = 1 + """Context parallelism degree. Default 1 (disabled).""" + pipeline_parallel_degree: int = 1 + """Pipeline parallelism degree. Default 1 (disabled).""" + cache_logprobs_only: bool = False + """Exit after building the reference logprobs cache (for benchmarking).""" @dataclass @@ -217,7 +238,7 @@ class CheckpointConfig: output_dir: str = "output/" """The output directory where the model predictions and checkpoints will be written.""" - checkpointing_steps: int | str | None = None + checkpointing_steps: int | str = 500 """Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.""" keep_last_n_checkpoints: int = 3 """How many checkpoints to keep in the output directory. -1 for all.""" @@ -255,6 +276,8 @@ class ModelConfig: """The model checkpoint for weights initialization.""" use_flash_attn: bool = True """Whether to use flash attention in the model training""" + attn_backend: str = "auto" + """Attention backend for OLMo-core models. Options: flash_2, flash_3, auto.""" model_revision: str | None = None """The specific model version to use (can be a branch name, tag name or commit id).""" low_cpu_mem_usage: bool = False @@ -352,12 +375,6 @@ class ExperimentConfig( default=None, metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"} ) use_liger_kernel: bool = field(default=False, metadata={"help": "Whether to use LigerKernel for training."}) - checkpointing_steps: str | None = field( - default=None, - metadata={ - "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa - }, - ) hf_metadata_dataset: str | None = "allenai/tulu-3-evals" """What dataset to upload the metadata to. 
If unset, don't upload metadata""" @@ -461,18 +478,40 @@ def compute_reference_cache_hash(args: ExperimentConfig, tc: TokenizerConfig) -> def build_reference_logprobs_cache( model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, average_log_prob: bool, forward_fn: Callable, full_dataset_size: int, - reference_cache_hash: str, + device: torch.device, + cache_path: pathlib.Path, + is_main_process: bool, + model_dims: utils.ModelDims, use_lora: bool = False, + disable_adapter_context: Callable[[], contextlib.AbstractContextManager] | None = None, ) -> model_utils.TensorCache: - """Build a TensorCache with reference logprobs by computing logprobs once for all samples.""" - cache_path = pathlib.Path(REFERENCE_LOGPROBS_CACHE_PATH) / f"{reference_cache_hash}.pt" - if not cache_path.exists(): + """Build a TensorCache with reference logprobs by computing logprobs once for all samples. + + Args: + model: The model to compute logprobs with. + dataloader: DataLoader providing batches with 'index' key. + average_log_prob: Whether to average log probs over sequence length. + forward_fn: Forward function to compute logprobs. + full_dataset_size: Total number of samples in the dataset. + device: Device to place tensors on. + cache_path: Path to save/load cache from. + is_main_process: Whether this is the main process. + use_lora: Whether LoRA is enabled (requires disable_adapter_context). + disable_adapter_context: Callable returning context manager to disable LoRA adapter. + + Returns: + TensorCache containing 'chosen_logps' and 'rejected_logps' tensors. + """ + if cache_path.exists(): + logger.info(f"Loading reference logprobs cache from {cache_path}") + return model_utils.TensorCache.from_disk(cache_path, device=device) + + if is_main_process: cache_path.parent.mkdir(parents=True, exist_ok=True) - test_file = cache_path.parent / f".write_test_{reference_cache_hash}" + test_file = cache_path.parent / f".write_test_{cache_path.stem}" try: test_file.touch() test_file.unlink() @@ -481,21 +520,22 @@ def build_reference_logprobs_cache( f"Cannot write to cache directory {cache_path.parent}: {e}. " f"Set REFERENCE_LOGPROBS_CACHE_PATH to a writable location." 
) from e - if cache_path.exists(): - logger.info(f"Loading reference logprobs cache from {cache_path}") - return model_utils.TensorCache.from_disk(cache_path, device=accelerator.device) + if dist.is_initialized(): + dist.barrier() model.eval() - device = accelerator.device chosen_tensor = torch.full((full_dataset_size,), float("-inf"), dtype=torch.float32, device=device) rejected_tensor = torch.full((full_dataset_size,), float("-inf"), dtype=torch.float32, device=device) + total_tokens = 0 + total_examples = 0 + with torch.no_grad(): - for batch in tqdm( - dataloader, disable=not accelerator.is_local_main_process, desc="Caching reference logprobs" - ): - if use_lora: - with accelerator.unwrap_model(model).disable_adapter(): + pbar = tqdm(dataloader, disable=not is_main_process, desc="Caching reference logprobs") + for batch in pbar: + batch_start = time.perf_counter() + if use_lora and disable_adapter_context is not None: + with disable_adapter_context(): chosen_logps, rejected_logps, _ = forward_fn(model, batch, average_log_prob=average_log_prob) else: chosen_logps, rejected_logps, _ = forward_fn(model, batch, average_log_prob=average_log_prob) @@ -503,6 +543,22 @@ def build_reference_logprobs_cache( chosen_tensor[batch["index"]] = chosen_logps rejected_tensor[batch["index"]] = rejected_logps + batch_tokens = batch["chosen_input_ids"].numel() + batch["rejected_input_ids"].numel() + total_tokens += batch_tokens + total_examples += len(batch["index"]) + + bs = len(batch["index"]) + chosen_lengths = [batch["chosen_input_ids"].shape[1]] * bs + rejected_lengths = [batch["rejected_input_ids"].shape[1]] * bs + pbar.set_postfix( + { + "avg_tok/ex": f"{total_tokens / total_examples:.0f}", + "MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - batch_start):.1f}", + "mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}", + "mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}", + } + ) + if dist.is_initialized(): dist.all_reduce(chosen_tensor, op=dist.ReduceOp.MAX) dist.all_reduce(rejected_tensor, op=dist.ReduceOp.MAX) @@ -519,7 +575,7 @@ def build_reference_logprobs_cache( model.train() cache = model_utils.TensorCache(tensors={"chosen_logps": chosen_tensor, "rejected_logps": rejected_tensor}) - if accelerator.is_main_process: + if is_main_process: logger.info(f"Saving reference logprobs cache to {cache_path}") cache.to_disk(cache_path) @@ -1021,7 +1077,20 @@ class DataCollatorForSeq2SeqDPO(DataCollatorForSeq2Seq): def __call__(self, features, return_tensors=None): # call the original collator on chosen and rejected separately, then combine def filter_batch(match_string, features): - return [{k.replace(match_string, ""): v for k, v in f.items() if match_string in k} for f in features] + filtered = [] + for f in features: + item = {} + for k, v in f.items(): + if match_string in k: + key = k.replace(match_string, "") + if isinstance(v, np.ndarray): + item[key] = torch.as_tensor(v) + elif isinstance(v, list): + item[key] = torch.tensor(v) + else: + item[key] = v + filtered.append(item) + return filtered chosen_features = super().__call__(filter_batch("chosen_", features), return_tensors=return_tensors) rejected_features = super().__call__(filter_batch("rejected_", features), return_tensors=return_tensors) diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py new file mode 100644 index 0000000000..a574f7b274 --- /dev/null +++ 
b/open_instruct/olmo_core_train_modules.py @@ -0,0 +1,127 @@ +"""OLMo-core TrainModule classes for various training objectives.""" + +from functools import partial +from typing import Any + +import torch +import torch.nn as nn +from olmo_core.optim.scheduler import Scheduler +from olmo_core.train.common import ReduceType +from olmo_core.train.train_module import EvalBatchSpec, TrainModule + +from open_instruct import dpo_utils, model_utils + + +class DPOTrainModule(TrainModule): + """Training module for DPO with OLMo-core's Trainer. + + Uses OLMo-core's scheduler.set_lr() pattern for learning rate scheduling. + """ + + def __init__( + self, + model: nn.Module, + optim: torch.optim.Optimizer, + args: dpo_utils.ExperimentConfig, + reference_cache: model_utils.TensorCache, + scheduler: Scheduler, + device: torch.device | None = None, + max_grad_norm: float | None = None, + ) -> None: + super().__init__() + self.model = model + self.optim = optim + self.args = args + self.reference_cache = reference_cache + self.scheduler = scheduler + self.device = device + self.max_grad_norm = max_grad_norm + + if args.packing: + self._forward_fn = partial(dpo_utils.concatenated_forward_olmo, packing=True) + elif args.concatenated_forward: + self._forward_fn = dpo_utils.concatenated_forward_olmo + else: + self._forward_fn = dpo_utils.separate_forward_olmo + + def state_dict(self, *, optim: bool | None = None) -> dict[str, Any]: + state_dict: dict[str, Any] = {"model": self.model.state_dict()} + if optim is not False: + state_dict["optim"] = self.optim.state_dict() + return state_dict + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.model.load_state_dict(state_dict["model"]) + if "optim" in state_dict: + self.optim.load_state_dict(state_dict["optim"]) + + def zero_grads(self) -> None: + self.optim.zero_grad() + + def optim_step(self) -> None: + if self.max_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.trainer.record_metric("total grad norm", grad_norm, reduce_type=None, namespace="optim") + for group_idx, group in enumerate(self.optim.param_groups): + new_lr = self.scheduler.set_lr(group, self.trainer) + self.trainer.record_metric(f"LR (group {group_idx})", new_lr, namespace="optim") + self.optim.step() + + def num_flops_per_token(self, seq_len: int) -> int: + return self.model.num_flops_per_token(seq_len) + + def global_num_flops_in_batch(self, batch: dict[str, Any]) -> int: + seq_len = batch["input_ids"].shape[1] + flops_per_token = self.num_flops_per_token(seq_len) + global_num_tokens = self.trainer.data_loader.global_num_tokens_in_batch(batch) + return flops_per_token * global_num_tokens + + @property + def eval_batch_spec(self) -> EvalBatchSpec: + return EvalBatchSpec(rank_batch_size=1) + + def eval_batch(self, batch: dict[str, Any]) -> torch.Tensor: + self.model.eval() + with torch.no_grad(): + return self.model(**batch) + + def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: + self.model.train() + + policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn( + self.model, + batch, + average_log_prob=self.args.loss_type.is_average_loss, + output_router_logits=self.args.load_balancing_loss, + ) + + losses, chosen_rewards, rejected_rewards = dpo_utils.compute_loss( + self.args, + batch, + policy_chosen_logps, + policy_rejected_logps, + self.reference_cache if self.args.loss_type.needs_reference_model else None, + ) + + loss = losses.mean() + + if 
self.args.load_balancing_loss and aux_loss is not None: + loss = loss + self.args.load_balancing_weight * aux_loss + + if not dry_run: + self.record_metric("train/loss", loss.detach(), ReduceType.mean) + self.record_metric("train/logps_chosen", policy_chosen_logps.mean().detach(), ReduceType.mean) + self.record_metric("train/logps_rejected", policy_rejected_logps.mean().detach(), ReduceType.mean) + + if self.args.loss_type.computes_reward_metrics: + accuracy = (chosen_rewards > rejected_rewards).float().mean() + margin = (chosen_rewards - rejected_rewards).mean() + self.record_metric("train/rewards_chosen", chosen_rewards.mean().detach(), ReduceType.mean) + self.record_metric("train/rewards_rejected", rejected_rewards.mean().detach(), ReduceType.mean) + self.record_metric("train/rewards_accuracy", accuracy.detach(), ReduceType.mean) + self.record_metric("train/rewards_margin", margin.detach(), ReduceType.mean) + + if self.args.load_balancing_loss and aux_loss is not None: + self.record_metric("train/aux_loss", aux_loss.detach(), ReduceType.mean) + + loss.backward() diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py new file mode 100644 index 0000000000..eebc5940d7 --- /dev/null +++ b/open_instruct/olmo_core_utils.py @@ -0,0 +1,75 @@ +""" +OLMo-core utility functions and configuration mappings. + +This module provides common utilities for working with OLMo-core models, +including model configuration mappings and helper functions. +""" + +import transformers +from olmo_core.nn.attention import AttentionBackendName +from olmo_core.nn.hf.checkpoint import save_hf_model +from olmo_core.nn.transformer import TransformerConfig + +from open_instruct import logger_utils + +logger = logger_utils.setup_logger(__name__) + +OLMO_MODEL_CONFIG_MAP: dict[str, str] = { + "allenai/OLMo-2-0425-1B": "olmo2_1B_v2", + "allenai/OLMo-2-1124-7B": "olmo2_7B", + "allenai/OLMo-2-1124-13B": "olmo2_13B", + "allenai/OLMo-2-0325-32B": "olmo2_32B", + "allenai/Olmo-3-1025-7B": "olmo3_7B", + "allenai/OLMoE-1B-7B-0924": "olmoe_1B_7B", + "Qwen/Qwen3-0.6B": "qwen3_0_6B", + "Qwen/Qwen3-1.7B": "qwen3_1_7B", + "Qwen/Qwen3-4B": "qwen3_4B", + "Qwen/Qwen3-8B": "qwen3_8B", + "Qwen/Qwen3-14B": "qwen3_14B", + "Qwen/Qwen3-32B": "qwen3_32B", +} + + +def get_transformer_config( + model_name_or_config: str, vocab_size: int, attn_backend: str | None = None +) -> TransformerConfig: + """Get the appropriate TransformerConfig for a given model name or config name. + + Args: + model_name_or_config: HuggingFace model name, path, or direct config name (e.g., 'olmo3_7B'). + vocab_size: Vocabulary size for the model. + attn_backend: Attention backend name (e.g., 'flash_2', 'flash_3'). If None, uses default. + + Returns: + TransformerConfig for the specified model. + + Raises: + ValueError: If model/config not found. + """ + config_name = OLMO_MODEL_CONFIG_MAP.get(model_name_or_config) + if config_name is None: + config_name = model_name_or_config + + if not hasattr(TransformerConfig, config_name): + available_models = ", ".join(OLMO_MODEL_CONFIG_MAP.keys()) + available_configs = [ + name for name in dir(TransformerConfig) if name.startswith(("olmo", "qwen")) and not name.startswith("_") + ] + raise ValueError( + f"Model/config '{model_name_or_config}' not found. " + f"Available models: {available_models}. 
" + f"Available config names: {', '.join(available_configs)}" + ) + kwargs: dict = {"vocab_size": vocab_size} + if attn_backend is not None: + kwargs["attn_backend"] = AttentionBackendName(attn_backend) + return getattr(TransformerConfig, config_name)(**kwargs) + + +def save_state_dict_as_hf(model_config, state_dict, save_dir, original_model_name_or_path, tokenizer): + unwrapped_model = model_config.build(init_device="cpu") + unwrapped_model.load_state_dict(state_dict) + save_hf_model(save_dir=save_dir, model_state_dict=state_dict, model=unwrapped_model, save_overwrite=True) + tokenizer.save_pretrained(save_dir) + original_config = transformers.AutoConfig.from_pretrained(original_model_name_or_path) + original_config.save_pretrained(save_dir) diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index d85459f72a..2c11f58b03 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -15,6 +15,12 @@ def single_example_collator(examples: list[dict]) -> dict: return examples[0] +def batch_collator(examples: list[dict]) -> dict: + """Collator that stacks example values into lists.""" + keys = examples[0].keys() + return {key: [ex[key] for ex in examples] for key in keys} + + def make_test_dataset(num_examples: int) -> datasets.Dataset: """Create a test dataset with the required 'index' column.""" data = {"text": [f"example_{i}" for i in range(num_examples)], "label": list(range(num_examples))} @@ -348,6 +354,73 @@ def test_global_num_tokens_in_batch(self): with self.assertRaises(ValueError): loader.global_num_tokens_in_batch(batch_without_tokens) + @parameterized.parameterized.expand( + [ + ("size_17_batch_4", 17, 4), + ("size_23_batch_8", 23, 8), + ("size_10_batch_3", 10, 3), + ("size_33_batch_16", 33, 16), + ] + ) + def test_drop_last_true_drops_remainder(self, name, num_examples, batch_size): + dataset = make_test_dataset(num_examples) + loader = open_instruct.data_loader.HFDataLoader( + dataset=dataset, + batch_size=batch_size, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=tempfile.gettempdir(), + collator=batch_collator, + drop_last=True, + ) + indices = [idx for batch in loader for idx in batch["index"]] + expected_count = (num_examples // batch_size) * batch_size + self.assertEqual(len(indices), expected_count) + + @parameterized.parameterized.expand( + [ + ("size_17_batch_4", 17, 4), + ("size_23_batch_8", 23, 8), + ("size_10_batch_3", 10, 3), + ("size_33_batch_16", 33, 16), + ] + ) + def test_drop_last_false_covers_all_indices(self, name, num_examples, batch_size): + dataset = make_test_dataset(num_examples) + loader = open_instruct.data_loader.HFDataLoader( + dataset=dataset, + batch_size=batch_size, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=tempfile.gettempdir(), + collator=batch_collator, + drop_last=False, + ) + indices = [idx for batch in loader for idx in batch["index"]] + self.assertEqual(set(indices), set(range(num_examples))) + + @parameterized.parameterized.expand( + [("dp2_size_17_batch_4", 17, 4, 2), ("dp4_size_23_batch_8", 23, 8, 4), ("dp2_size_33_batch_16", 33, 16, 2)] + ) + def test_drop_last_false_multi_rank_covers_all_indices(self, name, num_examples, batch_size, dp_world_size): + dataset = make_test_dataset(num_examples) + all_indices = [] + for dp_rank in range(dp_world_size): + loader = open_instruct.data_loader.HFDataLoader( + dataset=dataset, + batch_size=batch_size, + seed=42, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + work_dir=tempfile.gettempdir(), + collator=batch_collator, 
+ drop_last=False, + ) + all_indices.extend(idx for batch in loader for idx in batch["index"]) + self.assertEqual(set(all_indices), set(range(num_examples))) + class TestStreamingDataLoaderConfigSaveTraces(unittest.TestCase): def test_save_traces_requires_rollouts_save_path(self): diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 2da97b3a5e..1d3690e8cb 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -58,6 +58,7 @@ import ray import requests import torch +import torch.distributed as dist import torch.nn.functional as F import wandb from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk @@ -1333,6 +1334,42 @@ def maybe_use_ai2_hf_entity() -> str | None: return None +def setup_experiment_paths(args, is_main_process: bool) -> BeakerRuntimeConfig | None: + """Set up exp_name, output_dir, HF Hub config, wandb_entity. + + Modifies args in-place. Returns BeakerRuntimeConfig if on Beaker. + """ + if getattr(args, "add_seed_and_date_to_exp_name", False): + args.exp_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.output_dir = os.path.join(args.output_dir, args.exp_name) + + if dist.is_initialized(): + path_list = [args.output_dir] + dist.broadcast_object_list(path_list, src=0) + args.output_dir = path_list[0] + + beaker_config = None + if is_beaker_job() and is_main_process: + beaker_config = maybe_get_beaker_config() + + if getattr(args, "push_to_hub", False) and is_main_process: + if args.hf_repo_id is None: + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: + args.hf_repo_revision = args.exp_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if getattr(args, "wandb_entity", None) is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + return beaker_config + + @retry_on_exception() def upload_metadata_to_hf(metadata_dict, filename, hf_dataset_name, hf_dataset_save_dir): # upload a random dict to HF. 
Originally for uploading metadata to HF diff --git a/scripts/benchmarking/launch_dpo_cache_benchmark.sh b/scripts/benchmarking/launch_dpo_cache_benchmark.sh new file mode 100755 index 0000000000..df4acadab8 --- /dev/null +++ b/scripts/benchmarking/launch_dpo_cache_benchmark.sh @@ -0,0 +1,53 @@ +#!/bin/bash +BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" +MODEL_NAME=/weka/oe-adapt-default/scottg/olmo/merging/ckpts/olmo3-7b-instruct-sft-1115 + +uv run python mason.py \ + --cluster ai2/saturn \ + --cluster ai2/jupiter \ + --description "DPO cache forward-pass benchmark: OLMo3-7B, 2 nodes, 8k seq" \ + --workspace ai2/open-instruct-dev \ + --priority urgent \ + --image "$BEAKER_IMAGE" \ + --pure_docker_mode \ + --preemptible \ + --num_nodes 2 \ + --budget ai2/oe-adapt \ + --no_auto_dataset_cache \ + --env OLMO_SHARED_FS=1 \ + --env "REFERENCE_LOGPROBS_CACHE_PATH=/tmp/benchmark_cache_\$(date +%s)" \ + --gpus 8 -- torchrun \ + --nnodes=2 \ + --node_rank=\$BEAKER_REPLICA_RANK \ + --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ + --master_port=29400 \ + --nproc_per_node=8 \ + open_instruct/dpo.py \ + --model_name_or_path "$MODEL_NAME" \ + --config_name olmo3_7B \ + --chat_template_name olmo123 \ + --max_seq_length 8192 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate 1e-6 \ + --lr_scheduler_type linear \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_epochs 1 \ + --output_dir output/benchmark_dpo_cache/ \ + --mixer_list allenai/olmo-3-pref-mix-deltas-complement2-DECON-tpc-kwd-ch-dedup5-lbc100-grafmix-unbal 125000 \ + allenai/dpo-yolo1-200k-gpt4.1-2w2s-maxdelta_reje-426124-rm-gemma3-kwd-ftd-ch-ftd-topic-ftd-dedup5-lbc100 125000 \ + allenai/related-query_qwen_pairs_filtered_lbc100 1250 \ + allenai/paraphrase_qwen_pairs_filtered_lbc100 938 \ + allenai/repeat_qwen_pairs_filtered_lbc100 312 \ + allenai/self-talk_qwen_pairs_filtered_lbc100 2500 \ + allenai/related-query_gpt_pairs_filtered_lbc100 1250 \ + allenai/paraphrase_gpt_pairs_filtered_lbc100 938 \ + allenai/repeat_gpt_pairs_filtered_lbc100 312 \ + allenai/self-talk_gpt_pairs_filtered_lbc100 2500 \ + --seed 123 \ + --use_flash_attn \ + --loss_type dpo_norm \ + --beta 5 \ + --gradient_checkpointing \ + --cache_logprobs_only diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index c55320ed65..6a2c063cd3 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -7,7 +7,7 @@ import yaml -from open_instruct import utils +from open_instruct import launch_utils, utils ######################################## @@ -181,7 +181,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): # remove nfs if asked or jupiter in cluster list. 
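# Datasets on /weka are mounted only if all requested clusters are in launch_utils.WEKA_CLUSTERS.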
weka_available = False -if all(c in utils.WEKA_CLUSTERS for c in cluster): +if all(c in launch_utils.WEKA_CLUSTERS for c in cluster): d1["tasks"][0]["datasets"].append({"mountPath": "/weka/oe-adapt-default", "source": {"weka": "oe-adapt-default"}}) d1["tasks"][0]["datasets"].append( {"mountPath": "/weka/oe-training-default", "source": {"weka": "oe-training-default"}} diff --git a/scripts/test/run_gpu_pytest.sh b/scripts/test/run_gpu_pytest.sh index 14a11126a9..486b4811bb 100755 --- a/scripts/test/run_gpu_pytest.sh +++ b/scripts/test/run_gpu_pytest.sh @@ -14,7 +14,7 @@ uv run python mason.py \ --description "GPU tests for test_*_gpu.py" \ --pure_docker_mode \ --workspace ai2/open-instruct-dev \ - --priority high \ + --priority urgent \ --preemptible \ --num_nodes 1 \ --max_retries 0 \ diff --git a/scripts/train/convert_olmo_core_to_hf.py b/scripts/train/convert_olmo_core_to_hf.py new file mode 100644 index 0000000000..f6c2964191 --- /dev/null +++ b/scripts/train/convert_olmo_core_to_hf.py @@ -0,0 +1,48 @@ +"""Convert an olmo-core distributed checkpoint to HuggingFace format. + +Example usage: + uv run python scripts/train/convert_olmo_core_to_hf.py \ + --checkpoint-dir /path/to/checkpoint/step1000 \ + --model-name allenai/OLMo-2-1124-7B \ + --output-dir /path/to/output/hf_checkpoint +""" + +import argparse + +import torch +import torch.distributed.checkpoint.state_dict as dcp_state_dict +import transformers + +from open_instruct import logger_utils, olmo_core_utils + +logger = logger_utils.setup_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser(description="Convert olmo-core checkpoint to HuggingFace format") + parser.add_argument("--checkpoint-dir", required=True, help="Path to olmo-core checkpoint directory") + parser.add_argument("--model-name", required=True, help="HF model name or olmo-core config name") + parser.add_argument("--output-dir", required=True, help="Where to save the HF checkpoint") + parser.add_argument("--tokenizer-name", default=None, help="HF tokenizer name (defaults to --model-name)") + args = parser.parse_args() + + tokenizer_name = args.tokenizer_name or args.model_name + + hf_config = transformers.AutoConfig.from_pretrained(args.model_name) + vocab_size = hf_config.vocab_size + + model_config = olmo_core_utils.get_transformer_config(args.model_name, vocab_size) + model = model_config.build(init_device="cpu") + + state_dict = {"model": model.state_dict()} + dcp_state_dict.load_state_dict(state_dict, checkpoint_id=args.checkpoint_dir) + + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) + olmo_core_utils.save_state_dict_as_hf( + model_config, state_dict["model"], args.output_dir, args.model_name, tokenizer + ) + logger.info(f"Saved HuggingFace checkpoint to {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/train/debug/dpo/7b_instruct_dpo_olmo_core.sh b/scripts/train/debug/dpo/7b_instruct_dpo_olmo_core.sh new file mode 100644 index 0000000000..e5c5cf7062 --- /dev/null +++ b/scripts/train/debug/dpo/7b_instruct_dpo_olmo_core.sh @@ -0,0 +1,54 @@ +#!/bin/bash +BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" +MODEL_NAME=/weka/oe-adapt-default/scottg/olmo/merging/ckpts/olmo3-7b-instruct-sft-1115 +LR=1e-6 +EXP_NAME=olmo3-7b-DPO-olmo-core-8k-${LR} + +uv run python mason.py \ + --cluster ai2/saturn \ + --cluster ai2/jupiter \ + --description "OLMo3-7B DPO with OLMo-core, 2 nodes, 8k seq len" \ + --workspace ai2/open-instruct-dev \ + --priority urgent \ + --image "$BEAKER_IMAGE" \ + --pure_docker_mode \ + 
--preemptible \ + --num_nodes 2 \ + --budget ai2/oe-adapt \ + --no_auto_dataset_cache \ + --env OLMO_SHARED_FS=1 \ + --gpus 8 -- torchrun \ + --nnodes=2 \ + --node_rank=\$BEAKER_REPLICA_RANK \ + --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ + --master_port=29400 \ + --nproc_per_node=8 \ + open_instruct/dpo.py \ + --exp_name "$EXP_NAME" \ + --model_name_or_path "$MODEL_NAME" \ + --config_name olmo3_7B \ + --chat_template_name olmo123 \ + --max_seq_length 8192 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate "$LR" \ + --lr_scheduler_type linear \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_epochs 1 \ + --mixer_list allenai/olmo-3-pref-mix-deltas-complement2-DECON-tpc-kwd-ch-dedup5-lbc100-grafmix-unbal 125000 \ + allenai/dpo-yolo1-200k-gpt4.1-2w2s-maxdelta_reje-426124-rm-gemma3-kwd-ftd-ch-ftd-topic-ftd-dedup5-lbc100 125000 \ + allenai/related-query_qwen_pairs_filtered_lbc100 1250 \ + allenai/paraphrase_qwen_pairs_filtered_lbc100 938 \ + allenai/repeat_qwen_pairs_filtered_lbc100 312 \ + allenai/self-talk_qwen_pairs_filtered_lbc100 2500 \ + allenai/related-query_gpt_pairs_filtered_lbc100 1250 \ + allenai/paraphrase_gpt_pairs_filtered_lbc100 938 \ + allenai/repeat_gpt_pairs_filtered_lbc100 312 \ + allenai/self-talk_gpt_pairs_filtered_lbc100 2500 \ + --seed 123 \ + --logging_steps 1 \ + --loss_type dpo_norm \ + --beta 5 \ + --gradient_checkpointing \ + --with_tracking diff --git a/scripts/train/debug/dpo/local.sh b/scripts/train/debug/dpo/local.sh new file mode 100755 index 0000000000..bb3fa6be8b --- /dev/null +++ b/scripts/train/debug/dpo/local.sh @@ -0,0 +1,18 @@ +#!/bin/bash +uv run torchrun --nproc_per_node=1 open_instruct/dpo.py \ + --model_name_or_path allenai/OLMo-2-0425-1B \ + --tokenizer_name allenai/OLMo-2-0425-1B \ + --attn_backend flash_2 \ + --max_seq_length 1024 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate 5e-07 \ + --lr_scheduler_type linear \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_epochs 1 \ + --output_dir output/dpo_local_test/ \ + --logging_steps 1 \ + --mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 50 \ + --chat_template_name olmo \ + --seed 123 diff --git a/scripts/train/debug/medium_dpo.sh b/scripts/train/debug/dpo/multi_node.sh similarity index 64% rename from scripts/train/debug/medium_dpo.sh rename to scripts/train/debug/dpo/multi_node.sh index 2dc0ab8984..7d55d3b93e 100755 --- a/scripts/train/debug/medium_dpo.sh +++ b/scripts/train/debug/dpo/multi_node.sh @@ -1,12 +1,13 @@ #!/bin/bash BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" -MODEL_NAME=allenai/Olmo-3-1025-7B +MODEL_NAME=allenai/OLMo-2-1124-7B LR=1e-6 -EXP_NAME=olmo3-7b-DPO-debug-32k-${LR} +EXP_NAME=olmo2-7b-DPO-debug-16k-${LR} uv run python mason.py \ + --cluster ai2/saturn \ --cluster ai2/jupiter \ - --description "2 node DPO run with OLMo3-7B, 16k sequence length." \ + --description "2 node DPO run with OLMo2-7B, 16k sequence length (OLMo-core)." 
\ --workspace ai2/open-instruct-dev \ --priority urgent \ --image "$BEAKER_IMAGE" \ @@ -15,13 +16,14 @@ uv run python mason.py \ --num_nodes 2 \ --budget ai2/oe-adapt \ --no_auto_dataset_cache \ - --gpus 8 -- accelerate launch \ - --mixed_precision bf16 \ - --num_processes 8 \ - --use_deepspeed \ - --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ - --deepspeed_multinode_launcher standard \ - open_instruct/dpo_tune_cache.py \ + --env OLMO_SHARED_FS=1 \ + --gpus 8 -- torchrun \ + --nnodes=2 \ + --node_rank=\$BEAKER_REPLICA_RANK \ + --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ + --master_port=29400 \ + --nproc_per_node=8 \ + open_instruct/dpo.py \ --exp_name "$EXP_NAME" \ --model_name_or_path "$MODEL_NAME" \ --chat_template_name olmo \ @@ -33,10 +35,9 @@ uv run python mason.py \ --warmup_ratio 0.1 \ --weight_decay 0.0 \ --num_epochs 1 \ - --output_dir output/dpo_olmo3_debug/ \ + --output_dir output/dpo_olmo2_debug/ \ --mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 1000 \ --seed 123 \ - --use_flash_attn \ --logging_steps 1 \ --loss_type dpo_norm \ --beta 5 \ diff --git a/scripts/train/debug/dpo.sh b/scripts/train/debug/dpo/single_gpu.sh similarity index 65% rename from scripts/train/debug/dpo.sh rename to scripts/train/debug/dpo/single_gpu.sh index 18d9fd6729..3cb4c02d3a 100755 --- a/scripts/train/debug/dpo.sh +++ b/scripts/train/debug/dpo/single_gpu.sh @@ -4,7 +4,7 @@ BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" uv run python mason.py \ --cluster ai2/saturn \ --cluster ai2/jupiter \ - --description "Single GPU DPO run, for debugging purposes." \ + --description "Single GPU DPO run with OLMo-core, for debugging purposes." \ --workspace ai2/open-instruct-dev \ --priority urgent \ --image "$BEAKER_IMAGE" \ @@ -13,13 +13,9 @@ uv run python mason.py \ --num_nodes 1 \ --budget ai2/oe-adapt \ --no_auto_dataset_cache \ - --gpus 1 -- accelerate launch \ - --mixed_precision bf16 \ - --num_processes 1 \ - open_instruct/dpo_tune_cache.py \ - --model_name_or_path Qwen/Qwen3-0.6B \ - --tokenizer_name Qwen/Qwen3-0.6B \ - --use_flash_attn false \ + --gpus 1 -- torchrun --nproc_per_node=1 open_instruct/dpo.py \ + --model_name_or_path allenai/OLMo-2-0425-1B \ + --tokenizer_name allenai/OLMo-2-0425-1B \ --max_seq_length 1024 \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 4 \ @@ -28,11 +24,11 @@ uv run python mason.py \ --warmup_ratio 0.1 \ --weight_decay 0.0 \ --num_epochs 3 \ - --output_dir output/dpo_pythia_14m/ \ + --output_dir output/dpo_olmo_core_debug/ \ --logging_steps 1 \ --mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 100 \ - --add_bos \ --chat_template_name olmo \ --seed 123 \ - --try_launch_beaker_eval_jobs false \ + --try_launch_beaker_eval_jobs true \ + --hf_repo_revision dpo_olmo_core_debug_test \ --with_tracking diff --git a/scripts/train/debug/large_dpo.sh b/scripts/train/debug/large_dpo.sh deleted file mode 100755 index 4d14e564d9..0000000000 --- a/scripts/train/debug/large_dpo.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" -MODEL_NAME=allenai/Olmo-3-1025-7B -LR=1e-6 -EXP_NAME=olmo3-7b-DPO-debug-32k-${LR} - -uv run python mason.py \ - --cluster ai2/jupiter \ - --description "Multi-node DPO run with OLMo3-7B, 32k sequence length." 
\ - --workspace ai2/open-instruct-dev \ - --priority urgent \ - --image "$BEAKER_IMAGE" \ - --pure_docker_mode \ - --preemptible \ - --num_nodes 4 \ - --budget ai2/oe-adapt \ - --no_auto_dataset_cache \ - --gpus 8 -- accelerate launch \ - --mixed_precision bf16 \ - --num_processes 8 \ - --use_deepspeed \ - --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ - --deepspeed_multinode_launcher standard \ - open_instruct/dpo_tune_cache.py \ - --exp_name "$EXP_NAME" \ - --model_name_or_path "$MODEL_NAME" \ - --chat_template_name olmo \ - --max_seq_length 32768 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 4 \ - --learning_rate "$LR" \ - --lr_scheduler_type linear \ - --warmup_ratio 0.1 \ - --weight_decay 0.0 \ - --num_epochs 1 \ - --output_dir output/dpo_olmo3_debug/ \ - --mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 1000 \ - --seed 123 \ - --use_flash_attn \ - --logging_steps 1 \ - --loss_type dpo_norm \ - --beta 5 \ - --gradient_checkpointing \ - --with_tracking diff --git a/scripts/train/debug/single_gpu_on_beaker.sh b/scripts/train/debug/single_gpu_on_beaker.sh index 780901650e..d78bc2f6fa 100755 --- a/scripts/train/debug/single_gpu_on_beaker.sh +++ b/scripts/train/debug/single_gpu_on_beaker.sh @@ -34,7 +34,8 @@ uv run python mason.py \ --per_device_train_batch_size 1 \ --num_unique_prompts_rollout 8 \ --num_samples_per_prompt_rollout 4 \ - --model_name_or_path Qwen/Qwen3-1.7B \ + --model_name_or_path /weka/oe-adapt-default/allennlp/deletable_checkpoint/finbarrt/dpo_utils__123__1769051928/hf_model \ + --add_bos \ --stop_strings "" \ --apply_r1_style_format_reward \ --apply_verifiable_reward true \