diff --git a/recipes/configs/dev/3B_full_grpo.yaml b/recipes/configs/dev/3B_full_grpo.yaml new file mode 100644 index 0000000000..3fd1d7e240 --- /dev/null +++ b/recipes/configs/dev/3B_full_grpo.yaml @@ -0,0 +1,140 @@ +# Config for multi-node GRPO in dev/grpo_full_finetune_distributed.py +# using a Llama3.2 3B Base model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth" +# +# It can be beneficial to first train the base model with SFT using the 3B_sft recipe. +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_grpo +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# +# Furthermore, you can launch it on multiple nodes by going to recipes/dev/ and using +# sbatch multinode_grpo.sbatch + +name: grpo_llama3b + +output_dir: /tmp/checkpoints/${name} +base_model_path: /tmp/llama3B_gsm8k_sft_part0/epoch_0 # Use this to train from the slightly trained SFT model + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-3B/original/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.dev.grpo.gsm8k.gsm8k_dataset + partition: 1-9/10 +seed: null +shuffle: False + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.llama3_2_3b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: ${base_model_path} + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + + +ref_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: ${base_model_path} + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir}/ref # shouldn't be used? 
+ model_type: LLAMA3 + + +resume_from_checkpoint: False +save_every_n_epochs: 1 + +# Fine-tuning arguments +batch_size: 1 +grpo_samples: 16 +forward_batch_size: 1 +max_generated_tokens: 512 +top_k: null +temperature: 1.0 + +ppo_epochs: 1 + +num_steps: 200 + +clip_grad_norm: 1.0 + +epochs: 10 +optimizer: + _component_: torch.optim.AdamW + lr: 1e-5 + fused: True +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 50 +loss: + _component_: torchtune.dev.grpo.loss.GRPOSimpleLoss + kl_coeff: 0.01 + epsilon: 0.2 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +compile: False # pytorch compile, set to true for better perf/memory + +# Reduced precision +dtype: bf16 + + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Profiler (enabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: True + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: True + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/dev/3B_sft_for_grpo.yaml b/recipes/configs/dev/3B_sft_for_grpo.yaml new file mode 100644 index 0000000000..1713de2cc1 --- /dev/null +++ b/recipes/configs/dev/3B_sft_for_grpo.yaml @@ -0,0 +1,109 @@ +# Config for multi-device SFT for reasoning in full_finetune_distributed.py +# using a Llama3.2 3B Base model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_sft_for_grpo +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_sft_for_grpo checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs.
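+# Checkpoints from this run are written under /tmp/llama3B_gsm8k_sft_part0, which is where the
+# default base_model_path of dev/3B_full_grpo.yaml (/tmp/llama3B_gsm8k_sft_part0/epoch_0)
+# expects to find the SFT'd model.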
+ + +name: llama3B_gsm8k_sft_part0 + +output_dir: /tmp/${name} + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-3B/original/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.dev.grpo.gsm8k.gsm8k_sft + partition: 0-0/10 +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.llama3_2_3b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-3B/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 + +optimizer: + _component_: torch.optim.AdamW + lr: 1e-5 + fused: True +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 +gradient_accumulation_steps: 1 # Use to increase effective batch size + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py new file mode 100644 index 0000000000..8d1511c362 --- /dev/null +++ b/recipes/dev/grpo_full_finetune_distributed.py @@ -0,0 +1,1077 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
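+# Recipe overview: for each prompt the current policy samples ``grpo_samples`` completions,
+# scores them with a verifiable reward (``batch_shaped_correctness_reward``), normalizes the
+# rewards within each group to obtain advantages, and updates the policy with a clipped GRPO
+# objective plus a KL penalty against a frozen reference copy of the model (``ref_model``).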
+ +import sys +import time +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.distributed import destroy_process_group, init_process_group +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, generation, modules, rlhf, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.datasets import ConcatDataset +from torchtune.dev.grpo.generation import generate +from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward +from torchtune.dev.grpo.types import GRPOStats, GRPOTrajectory +from torchtune.modules import local_kv_cache +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY +from torchtune.training.lr_schedulers import get_lr +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class FullGRPOFinetuneRecipeDistributed(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2, trained with GRPO. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + disabled for faster generation (corresponding to FULL_SHARD sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. 
If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + device_type = cfg.device + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False) + + self.distributed_backend = training.get_distributed_backend( + device_type, offload_ops_to_cpu=self.fsdp_cpu_offload + ) + init_process_group(self.distributed_backend) + + world_size, rank = utils.get_world_size_and_rank() + self.rank = rank + self.world_size = world_size + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # activation checkpointing + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.total_epochs = cfg.epochs + self.global_step = 0 + self._steps_run = 0 + self._total_steps = 0 + self._epochs_run = 0 + self._rng = torch.Generator(self._device).manual_seed(self.seed) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def load_ref_checkpoint(self, cfg_ref_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the reference checkpoint state from file and validate. + """ + self._ref_checkpointer = config.instantiate( + cfg_ref_checkpointer, resume_from_checkpoint=False + ) + + ref_checkpoint_dict = self._ref_checkpointer.load_checkpoint() + + return ref_checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. 
+ """ + try: + self._epochs_run = ckpt_dict[training.EPOCHS_KEY] + self._rng.set_state(ckpt_dict[training.RNG_KEY]) + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. + """ + if self.fsdp_cpu_offload: + training.set_torch_num_threads() + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + ref_checkpoint_dict = self.load_ref_checkpoint( + cfg_ref_checkpointer=cfg.ref_checkpointer + ) + + self._compile = cfg.get("compile", False) + self._model, self._ref_model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=self.fsdp_cpu_offload, + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ref_model_state_dict=ref_checkpoint_dict[training.MODEL_KEY], + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + utils.log_rank_zero(log, "Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get( + "collate_fn", "torchtune.dev.grpo.data.padded_collate_rl" + ) + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader. + # This value is used for logging and tracking + # training state. 
The computation should happen after the dataloader has been setup + self._steps_per_epoch = len(self._dataloader) + self.global_step = self._epochs_run * self._steps_per_epoch + + # Setup lr scheduler + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.get("lr_scheduler", None), + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # RL params + self.grpo_samples = cfg.grpo_samples + self._temperature = cfg.temperature + self._top_k = cfg.top_k + self._max_generated_tokens = cfg.max_generated_tokens + self.batch_size = cfg.batch_size + self._forward_batch_size = cfg.forward_batch_size + + self._ppo_epochs = cfg.ppo_epochs + + self._save_every_n_epochs = cfg.save_every_n_epochs + + self._total_steps = cfg.num_steps + + if cfg.get("stop_token_ids", False): + stop_token_ids = cfg.stop_token_ids + if self._tokenizer.eos_id not in stop_token_ids: + warn( + f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids})." + "This may lead to unexpected behaviour." + ) + else: + if not hasattr(self._tokenizer, "stop_tokens"): + warn( + "No stop tokens defined in tokenizer, and no stop_token_ids provided. This may lead to unexpected behaviour." + ) + stop_token_ids = [] + else: + stop_token_ids = self._tokenizer.stop_tokens + self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device) + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: Optional[DictConfig], + num_training_steps: int, + last_epoch: int, + ) -> Optional[Optimizer]: + """ + Set up the learning rate scheduler based on the provided configuration. + It supports both standard optimization and optimizer-in-backward cases. + + Args: + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + num_training_steps (int): The total number of training steps. + last_epoch (int): The index of the last epoch. + + Returns: + lr_scheduler (Optional[Optimizer]): The learning rate scheduler. + """ + if cfg_lr_scheduler is None: + if self._is_rank_zero: + log.info( + "No learning rate scheduler configured. Using constant learning rate." + ) + return None + + optimizer = self._optimizer + + # Instantiate the learning rate scheduler + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + + return lr_scheduler + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. 
code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + utils.log_rank_zero( + log, f" Profiler config after instantiation: {profiler_cfg}" + ) + if self._is_rank_zero: + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + fsdp_cpu_offload: bool, + model_state_dict: Dict[str, Any], + ref_model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ) -> tuple[nn.Module, nn.Module]: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + utils.log_rank_zero( + log, + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + ref_model = config.instantiate(cfg_model) + + ref_model.eval() + for p in ref_model.parameters(): + p.requires_grad = False + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + training.compile_model(ref_model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + + # Policy doesn't reshard after forward for faster generation. 
+ # Reference net reshards after forward because it never calls .backward() + # See: https://github.com/pytorch/torchtune/pull/2326/#issuecomment-2654684159 + + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=False, + ) + + training.shard_model( + model=ref_model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=True, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + for m in ref_model.modules(): + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + training.load_from_full_model_state_dict( + ref_model, + ref_model_state_dict, + self._device, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + training.validate_no_params_on_meta_device(ref_model) + + utils.log_rank_zero( + log, + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", + ) + if self._is_rank_zero: + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + disable_dropout(model) + disable_dropout(ref_model) + + # synchronize before training begins + torch.distributed.barrier() + + return model, ref_model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + self._model, + optimizer, + opt_state_dict, + self._device, + ) + + utils.log_rank_zero(log, "Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + + # Instantiate collate_fn + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, + num_replicas=self.world_size, + rank=self.rank, + shuffle=shuffle, + seed=self.seed, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ) + ), + ) + + utils.log_rank_zero(log, "Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. 
The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + utils.log_rank_zero( + log, + "Saving checkpoint. This may take some time. Retrieving full model state dict...", + ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.gather_cpu_state_dict( + self._model, + self._is_rank_zero, + device=self._device, + ) + + utils.log_rank_zero( + log, + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs", + ) + + if intermediate_checkpoint: + start = time.perf_counter() + utils.log_rank_zero(log, "Getting optimizer state dict...") + opt_state_dict = training.get_full_optimizer_state_dict( + self._model, + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + + utils.log_rank_zero( + log, + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs", + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + if self._is_rank_zero: + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self._epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.RNG_KEY: self._rng.get_state(), + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def generate_trajectory( + self, input_ids: torch.Tensor, answers: List[str] + ) -> GRPOTrajectory: + """ + Generates a trajectory given the current policy model, the reference policy model, the reward function, + and batch of inputs. This is done over the following steps: + + 1: Generate responses, and logits corresponding to the responses using the current policy, + generating (query, response) pairs. + 2. Estimate logprobs of the generated responses using the current policy. + 3. Compute rewards and successes for the generated responses. + 4. Estimate advantages using GRPO. + 5. Replace any tokens in the response after the first stop token (usually EOS token) with padding, + producing truncated responses. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + answers (List[str]): list of answers corresponding to the input_ids + + Returns: + Trajectory: An instance of :class:`~torchtune.rlhf.GRPOTrajectory` comprising + the current trajectory. 
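+            Every tensor in the returned trajectory has a leading dimension of
+            ``batch_size * grpo_samples``, i.e. one row per sampled completion.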
+ """ + batch_size, context_length = input_ids.shape + grpo_size = self.grpo_samples + + batch_input_ids = input_ids[:, None, :].expand(-1, grpo_size, -1) # [B, G, L] + batch_input_ids = batch_input_ids.reshape(batch_size * grpo_size, -1) + + # step 1: generate responses, and logits corresponding to the responses using the current policy + + with local_kv_cache( + model=self._model, + batch_size=batch_size * grpo_size, + device=self._device, + dtype=self._dtype, + decoder_max_seq_len=context_length + self._max_generated_tokens, + ): + query_responses, _ = generate( # [B x G, L], [B x G, L, V] + model=self._model, + prompt=batch_input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + stop_tokens=self._tokenizer.stop_tokens, + return_logits=False, + ) + + torch.distributed.barrier() + training._distributed.recursive_reshard(self._model) + torch.cuda.empty_cache() + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses != self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = generation.get_causal_mask_from_padding_mask( + query_response_padding_masks + ) + position_ids = generation.get_position_ids_from_padding_mask( + query_response_padding_masks + ) + + del query_response_padding_masks + + logits = self._model(query_responses, input_pos=position_ids, mask=masks) + + # step 2. estimate logprobs of the responses using the current policy + logits = logits[:, context_length - 1 :] + logprobs = rlhf.batched_logits_to_logprobs(logits, responses, self._temperature) + + del logits + torch.cuda.empty_cache() + + # step 2.1 estimate logprobs of the responses using the reference policy + ref_logits = self._ref_model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.batched_logits_to_logprobs( + ref_logits, responses, self._temperature + ) + + del ref_logits + torch.cuda.empty_cache() + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + ( + response_padding_masks, + responses, + ) = rlhf.truncate_sequence_at_first_stop_token( # [B x G, L] + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # responses :: [B x G, L] + responses = responses.reshape(batch_size, grpo_size, -1) # [B, G, L] + + rewards, successes = batch_shaped_correctness_reward( + self._tokenizer, responses, answers + ) # [B, G] + rewards = rewards.to(self._device) + successes = successes.to(self._device) + + advantages = (rewards - rewards.mean(1, keepdim=True)) / ( + rewards.std(1, keepdim=True) + 1e-4 + ) + advantages = advantages.reshape(batch_size * grpo_size) # flatten + + del responses + torch.cuda.empty_cache() + + seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) + + # step 6. 
mask out all the invalid values in the trajectory due to padding tokens + logprobs[response_padding_masks] = 1.0 + ref_logprobs[response_padding_masks] = 1.0 + + return GRPOTrajectory( + query_responses=query_responses, + logprobs=logprobs, + ref_logprobs=ref_logprobs, + rewards=rewards.reshape(batch_size * grpo_size), + successes=successes.reshape(batch_size * grpo_size), + advantages=advantages, + masks=masks, + position_ids=position_ids, + response_padding_masks=response_padding_masks, + seq_lens=seq_lens, + ) + + def generate_trajectory_batched( + self, input_ids: torch.Tensor, answers: List[str] + ) -> GRPOTrajectory: + """ + Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. + See ``generate_trajectory`` for more details. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + answers: (List[str]): list of answers corresponding to the input_ids + + Returns: + Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory`, comprising + the current trajectory. + """ + trajectories: List[GRPOTrajectory] = [] + with torch.no_grad(): + for batch_start in range(0, self.batch_size, self._forward_batch_size): + batch_input_ids = input_ids[ + batch_start : batch_start + self._forward_batch_size + ] + batch_answers = answers[ + batch_start : batch_start + self._forward_batch_size + ] + torch.cuda.empty_cache() + trajectories.append( + self.generate_trajectory(batch_input_ids, batch_answers) + ) + torch.cuda.empty_cache() + return GRPOTrajectory(*map(torch.cat, zip(*trajectories))) + + def grpo_step( + self, + trajectory: GRPOTrajectory, + context_length: int, + ) -> GRPOStats: + """ + Perform a single GRPO optimization step over a batch of trajectories and corresponding advantages and returns. + + Args: + trajectory (Trajectory): a batch of trajectories + context_length (int): input ids sequence length + + Returns: + GRPOStats: An instance of :class:`~torchtune.rlhf.PPOStats`, a NamedTuple containing: + - loss (torch.Tensor): The total PPO loss. + - ratios (torch.Tensor): The ratio between the current and old policy probabilities. + - clipfrac (torch.Tensor): The fraction of ratios that were clipped. + - approx_policy_kls: Average estimated KL divergence between the policy before and after the optimisation step. + + """ + # estimate logprobs from the policy at the current optimisation step + + torch.cuda.empty_cache() + + pi_logits = self._model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + + pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length) + pi_logprobs = rlhf.batched_logits_to_logprobs( + pi_logits, + trajectory.query_responses[:, context_length:], + self._temperature, + chunk_size=1, + ) + + pi_logprobs[trajectory.response_padding_masks] = 1.0 + + del pi_logits + torch.cuda.empty_cache() + + # calculate grpo loss + loss, policy_loss, kl_loss, ratios, clipfrac = self._loss_fn( + trajectory.logprobs, + pi_logprobs, + trajectory.ref_logprobs, + trajectory.advantages, + padding_masks=~trajectory.response_padding_masks, + ) + + torch.cuda.empty_cache() + loss.backward() + + with torch.no_grad(): + approx_policy_kls = ( + 0.5 * (pi_logprobs - trajectory.logprobs).pow(2) + ).mean() + + return GRPOStats( + loss, + policy_loss, + kl_loss, + ratios, + clipfrac, + approx_policy_kls, + ) + + def train(self) -> None: + """ + The core training loop. 
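+        For each batch, trajectories are generated with the current policy via
+        ``generate_trajectory_batched`` and ``ppo_epochs`` GRPO optimization steps are run on
+        them. Gradients are clipped when ``clip_grad_norm`` is set, the optimizer and (optional)
+        LR scheduler are stepped, and metrics are logged every ``log_every_n_steps``. Training
+        stops after ``num_steps`` steps or ``epochs`` epochs, whichever comes first; checkpoints
+        are saved every ``save_every_n_epochs`` epochs.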
+ """ + # clean up before training begins + training.cleanup_before_training() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + grad_norm = None + + training_completed = False + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self._epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) + for idx, batch in enumerate(self._dataloader): + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + tokens = batch["tokens"] # type: ignore + answers = batch["answers"] # type: ignore + tokens = tokens.to(self._device) # [B, P] + + _, context_length = tokens.shape + + trajectory = self.generate_trajectory_batched(tokens, answers) + torch.distributed.barrier() + + grpo_stats: list[GRPOStats] = [] + for _ in range(self._ppo_epochs): + + step_stats = self.grpo_step(trajectory, context_length) + + grpo_stats.append(step_stats) + + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + torch.distributed.barrier() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + torch.distributed.barrier() + + self.global_step += 1 + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + extra_metrics = {} + extra_metrics["lr"] = get_lr(self._optimizer) + if grad_norm is not None: + extra_metrics["grad_norm"] = grad_norm + + self.log_metrics( + trajectory, + GRPOStats(*map(torch.stack, zip(*grpo_stats))), + **extra_metrics, + ) + + self.cleanup_after_step(trajectory, grpo_stats) + pbar.update(1) + + if self._steps_run == self._total_steps: + training_completed = True + break + + self._epochs_run += 1 + if self._epochs_run % self._save_every_n_epochs == 0: + self.save_checkpoint(curr_epoch) + if training_completed: + return + + self._profiler.stop() + + def log_metrics( + self, trajectory: GRPOTrajectory, grpo_stats: GRPOStats, **extras + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. 
+ """ + rewards = trajectory.rewards.mean() + torch.distributed.reduce(rewards, dst=0, op=torch.distributed.ReduceOp.AVG) + + successes = trajectory.successes.mean() + torch.distributed.reduce(successes, dst=0, op=torch.distributed.ReduceOp.AVG) + + log_dict = { + "rewards": rewards, + "successes": successes, + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "loss": grpo_stats.loss.mean(), + "policy_loss": grpo_stats.policy_loss.mean(), + "kl_loss": grpo_stats.kl_loss.mean(), + "clipfrac": grpo_stats.clipfrac.mean(), + "ratios": grpo_stats.ratios.mean(), + "approx_policy_kl": grpo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + **extras, + } + + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + if self._is_rank_zero: + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + def cleanup_after_step( + self, + trajectory: GRPOTrajectory, + l_grpo_stats: list[GRPOStats], + ) -> None: + for v in trajectory: + del v + del trajectory + for g in l_grpo_stats: + for v in g: + del v + del g + del l_grpo_stats + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + + recipe = FullGRPOFinetuneRecipeDistributed(cfg=cfg) + config.log_config(recipe_name="FullGRPOFinetuneRecipeDistributed", cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/recipes/dev/gsm8k_sft.sbatch b/recipes/dev/gsm8k_sft.sbatch new file mode 100644 index 0000000000..9e1ca8e6c5 --- /dev/null +++ b/recipes/dev/gsm8k_sft.sbatch @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --time=01:00:00 +#SBATCH --constraint=volta32gb +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=8 +#SBATCH --no-requeue +#SBATCH --exclusive + +#SBATCH --job-name=torchtune +#SBATCH --output=slurm_logs/%j.out +#SBATCH --error=slurm_logs/%j.err + +# /\ Customize SBATCH directives to custommize your hardware + +# \/ Customize the virtual env/module load - this assumes a virtual env in root of torchtune +source ../../.venv/bin/activate + +srun tune run \ +--nnodes 1 \ +--nproc_per_node 8 \ +full_finetune_distributed --config dev/3B_sft_for_grpo "$@" diff --git a/recipes/dev/multinode_grpo.sbatch b/recipes/dev/multinode_grpo.sbatch new file mode 100644 index 0000000000..39f105d065 --- /dev/null +++ b/recipes/dev/multinode_grpo.sbatch @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# ---------- SBATCH commands ---------- # +#SBATCH --time=72:00:00 +#SBATCH --job-name=torchtune-multi-node +#SBATCH --constraint=volta32gb +#SBATCH --ntasks-per-node=1 +#SBATCH --nodes=2 +#SBATCH --exclusive +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-task=80 +#SBATCH --output=slurm_logs/%j/%N.out +#SBATCH --error=slurm_logs/%j/%N.err + + +# ---------- Set env variables ---------- # +# Grab the IP for head node: +# You may need to set this to the fully qualified domain name of your head node +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +echo Node IP: $head_node_ip + +# You might need to explicitly set the network interface for distributed backends: +# export NCCL_SOCKET_IFNAME=... +# export GLOO_SOCKET_IFNAME=... + +export TORCH_DIST_INIT_BARRIER=1 +export LOGLEVEL=INFO + +# ---------- Launch training ---------- # +# You probably want to load in a virtual env w/ conda... +# module load conda +# conda activate torchtune +# ...or venv +# source torchtune/bin/activate + +source ../../.venv/bin/activate + +# Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count +srun --export=ALL,OMP_NUM_THREADS=8 tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \ + dev/grpo_full_finetune_distributed --config dev/3B_full_grpo "$@" diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index a675b64dfa..3de8f1a7bc 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -23,6 +23,14 @@ class Recipe: _ALL_RECIPES = [ + Recipe( + name="dev/grpo_full_finetune_distributed", + file_path="dev/grpo_full_finetune_distributed.py", + configs=[ + Config(name="dev/3B_full_grpo", file_path="dev/3B_full_grpo.yaml"), + ], + supports_distributed=True, + ), Recipe( name="full_finetune_single_device", file_path="full_finetune_single_device.py", @@ -102,6 +110,7 @@ class Recipe: name="full_finetune_distributed", file_path="full_finetune_distributed.py", configs=[ + Config(name="dev/3B_grpo_sft", file_path="dev/3B_grpo_sft.yaml"), Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"), Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"), Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"), diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 0d1461dd0d..d252072c81 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -92,6 +92,7 @@ class SFTDataset(Dataset): filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See the Hugging Face `docs `_ for more details. + filter_kwargs (Optional[Dict[str, Any]]): additional keyword arguments to pass to ``filter_fn``. **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging Face's `API ref `_ for more details. 
@@ -104,6 +105,7 @@ def __init__( message_transform: Transform, model_transform: Transform, filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, **load_dataset_kwargs: Dict[str, Any], ) -> None: self._message_transform = message_transform @@ -111,7 +113,9 @@ def __init__( self._data = load_dataset(source, **load_dataset_kwargs) if filter_fn is not None: - self._data = self._data.filter(filter_fn) + if filter_kwargs is None: + filter_kwargs = {} + self._data = self._data.filter(filter_fn, **filter_kwargs) self._prepare_sample = SFTTransform( message_transform=self._message_transform, diff --git a/torchtune/dev/grpo/__init__.py b/torchtune/dev/grpo/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtune/dev/grpo/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtune/dev/grpo/data.py b/torchtune/dev/grpo/data.py new file mode 100644 index 0000000000..ed44b9b352 --- /dev/null +++ b/torchtune/dev/grpo/data.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, List, Mapping, Optional, TypedDict, Union + +import torch +from datasets import load_dataset +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset + +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms import Transform + +BASE_PROMPT = ( + "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. " + "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. " + "The reasoning process and answer are enclosed within and tags, respectively, " + "i.e., reasoning process here answer here. User: %s. Assistant: " +) + + +class ReasoningProblem(TypedDict): + question: str + cot: str + answer: str + + +class RLDataset(Dataset): + """ + Base class for datasets used in reinforcement learning, + which provide a reference answer that can be verified to compute rewards. 
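+    Each sample is mapped through ``problem_transform`` into ``question`` and ``answer`` fields;
+    the question is embedded in ``BASE_PROMPT`` and tokenized without an EOS token, and the item
+    is returned as ``tokens``, ``mask``, and the raw ``answer`` string used for reward computation.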
+ """ + + def __init__( + self, + *, + source: str, + problem_transform: Transform, + tokenizer: ModelTokenizer, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + self._problem_transform = problem_transform + self._tokenizer = tokenizer + + self._data = load_dataset(source, **load_dataset_kwargs) + if filter_fn is not None: + if filter_kwargs is None: + filter_kwargs = {} + self._data = self._data.filter(filter_fn, **filter_kwargs) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, Any]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + transformed_sample = self._problem_transform( + sample + ) # keys "question" and "answer" + + question = BASE_PROMPT % transformed_sample["question"] + + q_tokens = self._tokenizer.encode(question, add_eos=False) + mask = [1 for _ in q_tokens] + answer = transformed_sample["answer"] + + return {"tokens": q_tokens, "mask": mask, "answer": answer} + + +def padded_collate_rl( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, +) -> Dict[str, Union[torch.Tensor, List[str]]]: + """Pad a batch of sequences to the longest sequence length in the batch, and + convert integer lists to tensors. Answers are simply concatenated into a list. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing tokens. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Dict[str, Union[torch.Tensor, List[str]]]: Collated input tensors and string answers. + + Example: + >>> token_pairs = [ + >>> {"tokens": [1, 2, 3], "answer": "15"}, + >>> {"tokens": [7,], "answer": "bromance"}, + >>> ] + >>> collated = padded_collate_rl( + >>> batch=token_pairs, + >>> padding_idx=padding_idx, + >>> ignore_idx=ignore_idx, + >>> ) + >>> collated["tokens"] + >>> tensor([[1, 2, 3], [7, 0, 0]]) + >>> collated["answers"] + >>> ["15", "bromance"] + """ + input_ids = pad_sequence( + [torch.tensor(x["tokens"]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + + answers = [x["answer"] for x in batch] + + return {"tokens": input_ids.long(), "answers": answers} diff --git a/torchtune/dev/grpo/generation.py b/torchtune/dev/grpo/generation.py new file mode 100644 index 0000000000..98032b81d6 --- /dev/null +++ b/torchtune/dev/grpo/generation.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +import torch + +from torchtune import utils +from torchtune.generation import generate_next_token, get_causal_mask_from_padding_mask +from torchtune.generation._generation import ( + get_position_ids_from_padding_mask, + update_stop_tokens_tracker, +) +from torchtune.modules import TransformerDecoder + +from tqdm.auto import trange + + +# NOTE: This is almost the same as torchtune.generation.generate, with a few changes necessary for GRPO. +# Namely: +# 1. The `return_logits` argument - we can optionally omit keeping track of logits during generation, which +# drastically improves generation speed. +# 2. 
Stop token-based breaking now communicates across multiple devices in a distributed setting. +# TODO: Figure out the right abstractions to be used in the main repository, and remove this function. +@torch.no_grad() +def generate( + model: TransformerDecoder, + prompt: torch.Tensor, + *, + max_generated_tokens: int, + pad_id: int = 0, + temperature: float = 1.0, + top_k: Optional[int] = None, + stop_tokens: Optional[List[int]] = None, + rng: Optional[torch.Generator] = None, + custom_generate_next_token: Optional[Callable] = None, + return_logits: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Generates tokens from a model conditioned on a prompt, and also returns logits for the generations. + + Args: + model (TransformerDecoder): model used for generation + prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape either [seq_length] or [bsz x seq_length]. + max_generated_tokens (int): number of tokens to be generated + pad_id (int): token ID to use for padding, default 0. + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, + default None. + stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated, + default None. + rng (Optional[torch.Generator]): random number generator, default None. + custom_generate_next_token (Optional[Callable]): If specified, we'll use the + ``custom_generate_next_token function``. This is generally only useful if + you want to specify a ``torch.compile`` version of the generate next token for + performance reasons. If None, we use the default :func:`generate_next_token`. + Default is None. + return_logits (bool): whether to return logits associated with the generated tokens, default True. + + Note: + This function has only been tested with decoder-only models. + + Examples: + >>> model = torchtune.models.llama3.llama3_8b() + >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() + >>> prompt = tokenizer.encode("Hi my name is") + >>> rng.manual_seed(42) + >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0) + >>> print(tokenizer.decode(output[0].tolist())) + Hi my name is Jeremy and I'm a friendly language model assistant! + + Returns: + Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: + - tokens (torch.Tensor): tensor with the generated tokens, + with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens`` + may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided. + - logits (torch.Tensor): tensor with the logits associated with the generated tokens, + with shape ``[bsz x num_generated_tokens x vocab_size]``. 
+ """ + prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt + + if custom_generate_next_token is None: + custom_generate_next_token = generate_next_token + + bsz, prompt_length = prompt.size() + total_response_length = prompt_length + max_generated_tokens + + generated_tokens = prompt.clone() + incremental_decoding = model.caches_are_enabled() + + # grab the correct max_seq_len to generate full causal masks/position ids + # this is the model's max cache len if incremental decoding, or the sequence + # length otherwise + max_seq_len = ( + total_response_length + if not incremental_decoding + else model.decoder_max_cache_seq_len + ) + + padding_masks = generated_tokens != pad_id + + if not padding_masks.all(): + # we have padding in the prompt due to varying-length sequences in a batch + # extend padding masks out to the correct seq len + padding_masks = torch.nn.functional.pad( + padding_masks, (0, max_generated_tokens), value=True + ) + + # generate the full causal mask for the whole padding mask with padding ignored + masks = get_causal_mask_from_padding_mask( + padding_masks, target_seq_len=max_seq_len + ) + + # right-shift position IDs to account for padding + input_pos = get_position_ids_from_padding_mask(padding_masks) + else: + # just use a regular causal mask if there is no padding + masks = torch.tril( + torch.ones( + total_response_length, + max_seq_len, + dtype=torch.bool, + device=prompt.device, + ) + ).unsqueeze(0) + input_pos = torch.arange( + 0, total_response_length, device=generated_tokens.device + ).unsqueeze(0) + + if incremental_decoding: + # if KV-caches are enabled, we need a causal mask of shape [bsz, prompt_length, max_cache_len] + # to match the key/value cache tensor shapes + curr_masks = masks[:, :prompt_length] + else: + # otherwise the causal mask is shape [bsz, prompt_length, prompt_length] because key/value + # tensors are of identical shape to the prompt + curr_masks = masks[:, :prompt_length, :prompt_length] + + q = None + if rng is not None: + q = torch.empty( + (bsz, model.tok_embeddings.num_embeddings), device=prompt.device + ).exponential_(1, generator=rng) + tokens, generated_logits = generate_next_token( + model, + input_pos=input_pos[:, :prompt_length].squeeze(), + mask=curr_masks, + x=prompt, + temperature=temperature, + top_k=top_k, + q=q, + ) + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + + curr_pos = prompt_length + + # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop + stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) + stop_tokens = ( + torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) + if stop_tokens + else None + ) + + # everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences + # that already hit a stop token + stop_token_mask = torch.ones( + (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device + ) + + # stop early if we reach a stop token in every seq + if stop_tokens is not None: + stop_token_reached = update_stop_tokens_tracker( + tokens, stop_tokens, stop_token_reached + ) + if stop_token_reached.all().item(): + return generated_tokens, generated_logits if return_logits else None + + world_size, rank = utils.get_world_size_and_rank() + for _ in (pbar := trange(max_generated_tokens - 1, leave=False, disable=rank > 0)): + # update stop_token_mask if we reached a stop token in a previous step + # by appending the logical not of stop_token_reached to the end of the mask + # 
reshaped to be bsz first
+        if stop_tokens is not None:
+            stop_token_mask = torch.cat(
+                [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1
+            )
+
+        # if incremental decoding is enabled, we can use the current position
+        # otherwise, we take the whole sequence up to the current position
+        if incremental_decoding:
+            curr_input_pos = input_pos[:, curr_pos].contiguous()
+            curr_masks = masks[:, curr_pos, None, :].contiguous()
+        else:
+            tokens = generated_tokens.clone()
+            curr_input_pos = input_pos[:, : curr_pos + 1]
+            curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1]
+
+        q = None
+        if rng is not None:
+            q = torch.empty(
+                (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+            ).exponential_(1, generator=rng)
+        tokens, logits = custom_generate_next_token(
+            model,
+            input_pos=curr_input_pos,
+            x=tokens.clone(),
+            mask=curr_masks,
+            temperature=temperature,
+            top_k=top_k,
+            q=q,
+        )
+        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+        if return_logits:
+            generated_logits = torch.cat([generated_logits, logits], dim=1)
+        curr_pos += 1
+
+        if stop_tokens is not None:
+            stop_token_reached = update_stop_tokens_tracker(
+                tokens, stop_tokens, stop_token_reached
+            )
+            if world_size == 1:
+                # Single device
+                if stop_token_reached.all():
+                    break
+            else:
+                # Multiple devices: only stop once every rank has hit stop tokens in all sequences
+                all_done = stop_token_reached.all().int()
+                torch.distributed.all_reduce(all_done)
+                if all_done == world_size:
+                    break
+
+    # mask out generated tokens in seqs that already hit a stop token
+    if stop_tokens is not None:
+        generated_tokens *= stop_token_mask
+        if return_logits:
+            generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None]
+
+    return generated_tokens, generated_logits
diff --git a/torchtune/dev/grpo/gsm8k.py b/torchtune/dev/grpo/gsm8k.py
new file mode 100644
index 0000000000..afdb6a23ec
--- /dev/null
+++ b/torchtune/dev/grpo/gsm8k.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import re
+from typing import Any, Callable, Dict, Optional
+
+from torchtune.datasets import SFTDataset
+from torchtune.modules.tokenizers import ModelTokenizer
+
+from .data import ReasoningProblem, RLDataset
+
+# TODO: dedup this between here and _rl
+PREAMBLE_PROMPT = (
+    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
+    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
+    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, "
+    "i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {question} Assistant: "
+)
+
+TRAINABLE_PROMPT = "<think>{cot}</think> <answer>{answer}</answer>"
+
+
+def normalize_gsm(problem: dict[str, str]) -> ReasoningProblem:
+    """
+    Parses an item from the GSM8K dataset into a ReasoningProblem by splitting it up into the question, cot, and answer.
+    """
+    question = problem["question"]
+    solution = problem["answer"]
+
+    cot, answer = solution.split("#### ")
+
+    return {"question": question, "cot": cot, "answer": answer}
+
+
+def sft_gsm_transform(problem: dict[str, str]) -> dict[str, str]:
+    """
+    Prepares an item from the GSM8k dataset into a format that can be used for SFT.
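+
+    Example (illustrative sample; the "#### " separator is the GSM8K answer convention this transform splits on):
+        >>> sample = {"question": "Tom has 3 apples and buys 2 more. How many does he have?",
+        ...           "answer": "3 + 2 = 5 #### 5"}
+        >>> out = sft_gsm_transform(sample)
+        >>> # out["preamble"] is PREAMBLE_PROMPT with the question filled in, ending in "Assistant: "
+        >>> # out["trainable"] wraps the chain of thought in <think> tags and "5" in <answer> tags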
+ """ + question = problem["question"] + solution = problem["answer"] + + cot, answer = solution.split("#### ") + + preamble = PREAMBLE_PROMPT.format(question=question) + trainable = TRAINABLE_PROMPT.format(cot=cot, answer=answer) + + return {"preamble": preamble, "trainable": trainable} + + +def gsm8k_dataset( + tokenizer: ModelTokenizer, + *, + source: str = "openai/gsm8k", + filter_fn: Optional[Callable] = None, + split: str = "train", + name: str = "main", + partition: Optional[str] = None, + **load_dataset_kwargs: Dict[str, Any], +) -> RLDataset: + """ + GSM8k dataset from OpenAI, prepared for RL-based training with verifiable rewards. + """ + + def default_filter_fn(example: dict, idx: int): + if partition is None: + return True + + match = re.match(r"^(\d+)-(\d+)/(\d+)$", partition) + if not match: + raise ValueError( + f"Invalid partition format: {partition}. Expected format: start-end/total" + ) + + start, end, total = map(int, match.groups()) + + current = idx % total + return start <= current <= end + + filter_fn = filter_fn if filter_fn is not None else default_filter_fn + + ds = RLDataset( + source=source, + name=name, + tokenizer=tokenizer, + problem_transform=normalize_gsm, + filter_fn=filter_fn, + filter_kwargs=dict(with_indices=True), + split=split, + **load_dataset_kwargs, + ) + + return ds + + +def gsm8k_sft( + tokenizer: ModelTokenizer, + *, + source: str = "openai/gsm8k", + filter_fn: Optional[Callable] = None, + split: str = "train", + name: str = "main", + partition: Optional[str] = None, + **load_dataset_kwargs: Dict[str, Any], +) -> SFTDataset: + """ + GSM8k dataset from OpenAI, prepared for SFT-based training with CoT. + """ + + def model_transform(problem: dict[str, str]) -> dict[str, list[int]]: + pre_tokens = tokenizer.encode(problem["preamble"], add_eos=False) + trainable_tokens = tokenizer.encode(problem["trainable"], add_bos=False) + + # 1 == discard the token, 0 == include the token in training + mask = [1 for t in pre_tokens] + [0 for t in trainable_tokens] + + return {"tokens": pre_tokens + trainable_tokens, "mask": mask} + + def default_filter_fn(example: dict, idx: int): + if partition is None: + return True + + match = re.match(r"^(\d+)-(\d+)/(\d+)$", partition) + if not match: + raise ValueError( + f"Invalid partition format: {partition}. Expected format: start-end/total" + ) + + start, end, total = map(int, match.groups()) + + current = idx % total + return start <= current <= end + + filter_fn = filter_fn if filter_fn is not None else default_filter_fn + + ds = SFTDataset( + source=source, + message_transform=sft_gsm_transform, + model_transform=model_transform, + filter_fn=filter_fn, + filter_kwargs=dict(with_indices=True), + split=split, + name=name, + **load_dataset_kwargs, + ) + + return ds diff --git a/torchtune/dev/grpo/loss.py b/torchtune/dev/grpo/loss.py new file mode 100644 index 0000000000..3f2b2dd260 --- /dev/null +++ b/torchtune/dev/grpo/loss.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torchtune import rlhf +from torchtune.rlhf import masked_sum + + +class GRPOLoss(nn.Module): + """ + Group Relative Policy Optimization (GRPO) Loss module. + Introduced by https://arxiv.org/abs/2402.03300, popularized by https://arxiv.org/abs/2501.12948. 
+ + This loss implementation follows the usual formulation of GRPO with clipped ratios of token-wise logprobs. + Currently not validated to perform well. + + Args: + epsilon (float): clipping range for GRPO update. + kl_coeff (float): KL divergence coefficient (also known as beta). + """ + + def __init__( + self, + epsilon: float = 0.1, + kl_coeff: float = 0.1, + ): + super().__init__() + self.epsilon = epsilon + self.kl_coeff = kl_coeff + + def forward( + self, + pi_old_logprobs: torch.Tensor, # [B x G, L] + pi_logprobs: torch.Tensor, # [B x G, L] + ref_logprobs: torch.Tensor, # [B x G, L] + advantages: torch.Tensor, # [B x G] + padding_masks: Optional[torch.Tensor] = None, # [B x G, L] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass of the GRPO loss module. + + Args: + pi_old_logprobs (torch.Tensor): Log probabilities of the old policy. Shape: [batch_size * num_groups, seq_len] + pi_logprobs (torch.Tensor): Log probabilities of the current policy. Shape: [batch_size * num_groups, seq_len] + ref_logprobs (torch.Tensor): Log probabilities of the reference model. Shape: [batch_size * num_groups, seq_len] + advantages (torch.Tensor): Advantage values. Shape: [batch_size * num_groups] + padding_masks (Optional[torch.Tensor]): Padding token masks where True indicates tokens to include in loss calculation. + Shape: [batch_size * num_groups, seq_len] + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - loss: Total GRPO loss (policy loss + KL penalty) + - policy_loss: Clipped policy loss + - kl_loss: KL divergence loss between policy and reference model + - ratios: Mean ratio between current and old policy probabilities + - clipfrac: Fraction of clipped policy ratios + """ + + ratios = torch.exp(pi_logprobs - pi_old_logprobs) # [B x G, L] + clipped_ratios = torch.clamp( + ratios, 1.0 - self.epsilon, 1.0 + self.epsilon + ) # [B x G, L] + + advantages = advantages[:, None] # [B x G, 1] + + policy_losses_clipped = advantages * clipped_ratios # [B x G, L] + policy_losses_unclipped = advantages * ratios # [B x G, L] + + clipfrac = ( + policy_losses_clipped < policy_losses_unclipped + ).float() # [B x G, L] + clipfrac = rlhf.masked_mean(clipfrac, padding_masks) # scalar + + policy_loss = torch.minimum( + policy_losses_clipped, policy_losses_unclipped + ) # [B x G, L] + policy_loss = rlhf.masked_mean(policy_loss, padding_masks) + + kl_loss = ( + torch.exp(ref_logprobs - pi_logprobs) - (ref_logprobs - pi_logprobs) - 1 + ) # [B x G] + kl_loss = rlhf.masked_mean(kl_loss, padding_masks) + + loss = -(policy_loss - self.kl_coeff * kl_loss) + + return ( + loss, + policy_loss.detach(), + kl_loss.detach(), + ratios.mean().detach(), + clipfrac.detach(), + ) + + +class GRPOCompletionLoss(nn.Module): + """ + Group Relative Policy Optimization (GRPO) Loss module. + Introduced by https://arxiv.org/abs/2402.03300, popularized by https://arxiv.org/abs/2501.12948. + + This loss implementation follows the usual formulation of GRPO with clipped ratios of full completion logprobs. + Currently not validated to perform well. + + Args: + epsilon (float): clipping range for GRPO update. + kl_coeff (float): KL divergence coefficient (also known as beta). 
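+
+    Example (illustrative shapes only; the tensors are random placeholders):
+        >>> loss_fn = GRPOCompletionLoss(epsilon=0.2, kl_coeff=0.01)
+        >>> logprobs = torch.randn(4, 8)  # [B x G, L]
+        >>> advantages = torch.randn(4)  # [B x G]
+        >>> masks = torch.ones(4, 8, dtype=torch.bool)
+        >>> loss, policy_loss, kl_loss, ratios, clipfrac = loss_fn(
+        ...     logprobs, logprobs, logprobs, advantages, masks
+        ... )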
+ """ + + def __init__( + self, + epsilon: float = 0.1, + kl_coeff: float = 0.1, + ): + super().__init__() + self.epsilon = epsilon + self.kl_coeff = kl_coeff + + def forward( + self, + pi_old_logprobs: torch.Tensor, # [B x G, L] + pi_logprobs: torch.Tensor, # [B x G, L] + ref_logprobs: torch.Tensor, # [B x G, L] + advantages: torch.Tensor, # [B x G] + padding_masks: Optional[torch.Tensor] = None, # [B x G, L] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass of the GRPO loss module. + + Args: + pi_old_logprobs (torch.Tensor): Log probabilities of the old policy. Shape: [batch_size * num_groups, seq_len] + pi_logprobs (torch.Tensor): Log probabilities of the current policy. Shape: [batch_size * num_groups, seq_len] + ref_logprobs (torch.Tensor): Log probabilities of the reference model. Shape: [batch_size * num_groups, seq_len] + advantages (torch.Tensor): Advantage values. Shape: [batch_size * num_groups] + padding_masks (Optional[torch.Tensor]): Padding token masks where True indicates tokens to include in loss calculation. + Shape: [batch_size * num_groups, seq_len] + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - loss: Total GRPO loss (policy loss + KL penalty) + - policy_loss: Clipped policy loss + - kl_loss: KL divergence loss between policy and reference model + - ratios: Mean ratio between current and old policy probabilities + - clipfrac: Fraction of clipped policy ratios + """ + + pi_old_logprobs = masked_sum(pi_old_logprobs, padding_masks) # [B x G] + pi_logprobs = masked_sum(pi_logprobs, padding_masks) # [B x G] + ref_logprobs = masked_sum(ref_logprobs, padding_masks) # [B x G] + + ratios = torch.exp(pi_logprobs - pi_old_logprobs) # [B x G] + clipped_ratios = torch.clamp( + ratios, 1.0 - self.epsilon, 1.0 + self.epsilon + ) # [B x G] + + policy_losses_clipped = advantages * clipped_ratios # [B x G] + policy_losses_unclipped = advantages * ratios # [B x G] + + clipfrac = (policy_losses_clipped < policy_losses_unclipped).float() # [B x G] + clipfrac = clipfrac.mean() # scalar, only for logging + + policy_loss = torch.minimum( + policy_losses_clipped, policy_losses_unclipped + ) # [B x G] + policy_loss = policy_loss.mean() # scalar + + kl_loss = ( + torch.exp(ref_logprobs - pi_logprobs) - (ref_logprobs - pi_logprobs) - 1 + ) # [B x G] + kl_loss = rlhf.masked_mean(kl_loss, padding_masks) + + loss = -(policy_loss - self.kl_coeff * kl_loss) + + return ( + loss, + policy_loss.detach(), + kl_loss.detach(), + ratios.mean().detach(), + clipfrac.detach(), + ) + + +class GRPOSimpleLoss(nn.Module): + """ + Group Relative Policy Optimization (GRPO) Loss module. + Introduced by https://arxiv.org/abs/2402.03300, popularized by https://arxiv.org/abs/2501.12948. + + This loss implementation is based on TRL's implementation of GRPO, + which only takes a single gradient step per batch, trivializing some parts of the computation. + This empirically seems to perform well. + + Args: + epsilon (float): clipping range for GRPO update. + kl_coeff (float): KL divergence coefficient (also known as beta). 
+ """ + + def __init__( + self, + epsilon: float = 0.1, + kl_coeff: float = 0.1, + ): + super().__init__() + self.epsilon = epsilon + self.kl_coeff = kl_coeff + + def forward( + self, + pi_old_logprobs: torch.Tensor, # [B x G, L] + pi_logprobs: torch.Tensor, # [B x G, L] + ref_logprobs: torch.Tensor, # [B x G, L] + advantages: torch.Tensor, # [B x G] + padding_masks: Optional[torch.Tensor] = None, # [B x G, L] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass of the GRPO loss module. + + Args: + pi_old_logprobs (torch.Tensor): *UNUSED* Log probabilities of the old policy. + Shape: [batch_size * num_groups, seq_len] + pi_logprobs (torch.Tensor): Log probabilities of the current policy. + Shape: [batch_size * num_groups, seq_len] + ref_logprobs (torch.Tensor): *UNUSED* Log probabilities of the reference model. + Shape: [batch_size * num_groups, seq_len] + advantages (torch.Tensor): Advantage values. + Shape: [batch_size * num_groups] + padding_masks (Optional[torch.Tensor]): Padding token masks where True indicates tokens to include in loss calculation. + Shape: [batch_size * num_groups, seq_len] + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - loss: Total GRPO loss (policy loss + KL penalty) + - policy_loss: Clipped policy loss + - kl_loss: KL divergence loss between policy and reference model + - ratios: Mean ratio between current and old policy probabilities + - clipfrac: Fraction of clipped policy ratios + """ + + # [B x G, L] + per_token_kl = ( + torch.exp(ref_logprobs.detach() - pi_logprobs) + - (ref_logprobs.detach() - pi_logprobs) + - 1 + ) + + advantages = advantages[:, None] # [B x G, 1] + + per_token_policy_loss = ( + torch.exp(pi_logprobs - pi_logprobs.detach()) * advantages + ) + + per_token_loss = -(per_token_policy_loss - self.kl_coeff * per_token_kl) + + loss = rlhf.masked_mean(per_token_loss, padding_masks, dim=1).mean() + + policy_loss = ( + rlhf.masked_mean(per_token_policy_loss, padding_masks, dim=1) + .mean() + .detach() + ) + kl_loss = rlhf.masked_mean(per_token_kl, padding_masks, dim=1).mean().detach() + + return ( # This loss doesn't track clipfrac and ratios + loss, + policy_loss, + kl_loss, + torch.tensor(1.0), + torch.tensor(0.0), + ) diff --git a/torchtune/dev/grpo/rewards.py b/torchtune/dev/grpo/rewards.py new file mode 100644 index 0000000000..2cba5ee4a4 --- /dev/null +++ b/torchtune/dev/grpo/rewards.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from xml.etree import ElementTree as ET + +import torch + +from torchtune.modules.transforms.tokenizers import ModelTokenizer + + +def extract_tags(text: str) -> dict[str, list[str]]: + """ + Parse XML-like tags from text. Returns a dictionary with keys 'think' and 'answer'. + The values are lists of strings, with each string being the content of a tag. + """ + xml_string = f"{text}" + root = ET.fromstring(xml_string) + + return { + "think": [ + elem.text if elem.text is not None else "" for elem in root.findall("think") + ], + "answer": [ + elem.text if elem.text is not None else "" + for elem in root.findall("answer") + ], + } + + +def shaped_correctness_reward(answer: str, completion: str) -> tuple[float, float]: + """ + Reward function for verifiable rewards with some mild shaping. 
+ + Args: + answer (str): ground-truth answer to the current problem + completion (str): model's completion, starting immediately after "Assistant: " + Returns: + reward: (float) a shaped reward indicating the correct answer and the correct format + success: (float) a binary measure of success (1 if the answer is correct and correctly formatted, 0 otherwise) + """ + reward = 0.0 + success = 0.0 + + try: + tags = extract_tags("" + completion.replace("<<", "").replace(">>", "")) + except ET.ParseError: + tags = {"think": [], "answer": []} + + if len(tags["answer"]) == 1: + reward += 5.0 + + if len(tags["think"]) == 1: + reward += 5.0 + + if any(attempt == answer for attempt in tags["answer"]): + # One of the answer tags has the right answer + reward += 20.0 + + if any((answer in attempt) for attempt in tags["answer"]): + # One of the answer tags contains the right answer (might be e.g. $20 instead of 20) + reward += 10.0 + + if len(tags["answer"]) > 0 and tags["answer"][-1] == answer: + reward = 100.0 + success = 1 + + return reward, success + + +def batch_shaped_correctness_reward( + tokenizer: ModelTokenizer, completions: torch.Tensor, answers: list[str] +) -> [torch.Tensor, torch.Tensor]: + """Utility function to apply the shaped reward function to a GRPO-style batch of completions.""" + + batch_size, grpo_size, *_ = completions.shape + rewards = torch.zeros(batch_size, grpo_size, dtype=torch.float32) + successes = torch.zeros(batch_size, grpo_size, dtype=torch.float32) + # completions :: [B, G, L] + for b in range(batch_size): + for g in range(grpo_size): + text_completion = tokenizer.decode( + completions[b, g].tolist() + ) # skips special tokens, stops at eos + reward, success = shaped_correctness_reward( + answer=answers[b], completion=text_completion + ) + rewards[b, g] = reward + successes[b, g] = success + + return rewards, successes diff --git a/torchtune/dev/grpo/types.py b/torchtune/dev/grpo/types.py new file mode 100644 index 0000000000..6f510a386c --- /dev/null +++ b/torchtune/dev/grpo/types.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import NamedTuple + +import torch + + +class GRPOTrajectory(NamedTuple): + """ + Contains a collection of tensors describing a generated trajectory during GRPO training. + + Attributes: + query_responses (torch.Tensor): (query, response) pairs with shape [B x G, P+L]. + logprobs (torch.Tensor): Log probabilities of the generated responses with shape [B x G, L]. + ref_logprobs (torch.Tensor): Log probabilities of the generated responses using the reference policy with shape [B x G, L]. + rewards (torch.Tensor): Rewards obtained from the environment or reward model with shape [B x G]. + successes (torch.Tensor): Success indicators for each trajectory. + advantages (torch.Tensor): Advantage estimates for the generated responses with shape [B x G]. + masks (torch.Tensor): Attention masks for input ids-generated responses pairs with shape [B x G, P+L, P+L]. + position_ids (torch.Tensor): Position IDs for input ids-generated responses pairs with shape [B x G, P+L]. + response_padding_masks (torch.Tensor): Padding masks for the truncated and padded generated responses with shape [B x G, L]. + seq_lens (torch.Tensor): Sequence lengths of truncated generated responses. 
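+
+    Shape notation: ``B`` is the batch size, ``G`` the number of GRPO generations per prompt (the group
+    size), ``P`` the padded prompt length, and ``L`` the padded length of the generated responses.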
+ """ + + query_responses: torch.Tensor # [B x G, P+L] + logprobs: torch.Tensor # [B x G, L] + ref_logprobs: torch.Tensor # [B x G, L] + rewards: torch.Tensor # [B x G] + successes: torch.Tensor + advantages: torch.Tensor # [B x G] + masks: torch.Tensor # [B x G, P+L, P+L] + position_ids: torch.Tensor # [B x G, P+L] + response_padding_masks: torch.Tensor # [B x G, L] + seq_lens: torch.Tensor + + +class GRPOStats(NamedTuple): + """ + Contains GRPO loss statistics (metrics). + + Attributes: + loss (torch.Tensor): The total GRPO loss. + policy_loss (torch.Tensor): The policy function loss. + kl_loss (torch.Tensor): The KL divergence loss. + ratios (torch.Tensor): The ratio between the current and old policy probabilities. + clipfrac (torch.Tensor): The fraction of ratios that were clipped. + approx_policy_kls (torch.Tensor): Average estimated KL divergence between the policy before and after the optimization step. + """ + + loss: torch.Tensor + policy_loss: torch.Tensor + kl_loss: torch.Tensor + ratios: torch.Tensor + clipfrac: torch.Tensor + approx_policy_kls: torch.Tensor diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index 76d4acb743..af0a3b2b84 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple import torch + from torchtune.modules.transformer import TransformerDecoder diff --git a/torchtune/rlhf/__init__.py b/torchtune/rlhf/__init__.py index 506ff2192b..3589d096f3 100644 --- a/torchtune/rlhf/__init__.py +++ b/torchtune/rlhf/__init__.py @@ -6,15 +6,18 @@ from ._types import PPOStats, Trajectory + from .rewards import ( estimate_advantages, get_reward_penalty_mask, get_rewards_ppo, masked_mean, + masked_sum, masked_var, whiten, ) from .sequence_processing import ( + batched_logits_to_logprobs, get_batch_log_probs, logits_to_logprobs, truncate_sequence_at_first_stop_token, @@ -24,12 +27,14 @@ __all__ = [ "truncate_sequence_at_first_stop_token", "logits_to_logprobs", + "batched_logits_to_logprobs", "truncate_sequence_for_logprobs", "get_reward_penalty_mask", "estimate_advantages", "get_rewards_ppo", "whiten", "masked_mean", + "masked_sum", "masked_var", "PPOStats", "get_batch_log_probs", diff --git a/torchtune/rlhf/loss/__init__.py b/torchtune/rlhf/loss/__init__.py index 4058979f4a..1ab98d34c9 100644 --- a/torchtune/rlhf/loss/__init__.py +++ b/torchtune/rlhf/loss/__init__.py @@ -8,4 +8,8 @@ from .dpo import DPOLoss, RSOLoss from .ppo import PPOLoss -__all__ = ["DPOLoss", "RSOLoss", "PPOLoss"] +__all__ = [ + "DPOLoss", + "RSOLoss", + "PPOLoss", +] diff --git a/torchtune/rlhf/rewards.py b/torchtune/rlhf/rewards.py index f5882908bc..83744bb5d7 100644 --- a/torchtune/rlhf/rewards.py +++ b/torchtune/rlhf/rewards.py @@ -112,6 +112,24 @@ def masked_mean( return (x * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8) +def masked_sum( + x: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + """ + Compute sum of tensor with masked values. + + Args: + x (torch.Tensor): The input tensor. + mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x`` + should participate in the sum calculation. + dim (Optional[int]): The axis to calculate the sum over. Default None. + + Returns: + torch.Tensor: The sum tensor. 
+ """ + return (x * mask).sum(dim=dim) + + def masked_var( centered_values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True ) -> torch.Tensor: diff --git a/torchtune/rlhf/sequence_processing.py b/torchtune/rlhf/sequence_processing.py index 9844dd001c..272571cc46 100644 --- a/torchtune/rlhf/sequence_processing.py +++ b/torchtune/rlhf/sequence_processing.py @@ -101,6 +101,42 @@ def logits_to_logprobs( ).squeeze(-1) +def batched_logits_to_logprobs( + logits: torch.Tensor, + sequences: torch.Tensor, + temperature: float = 1.0, + chunk_size: int = 4, +): + """ + Converts to logits to logprobs in a batched manner, to minimize the memory spike. + + Args: + logits (torch.Tensor): The logits tensor of shape [b, response_length, vocab_size]. + sequences (torch.Tensor): The corresponding tokens of shape [b, response_length]. + temperature (float): The temperature to scale the logits. Default 1.0 + chunk_size (int): The size of the chunks to process at a time. Default 4. + Returns: + torch.Tensor: The log probabilities corresponding to each token in ``sequences``. Shape [b, response_length]. + """ + batch_size = logits.shape[0] + result = torch.empty_like(sequences, dtype=torch.float32, device=logits.device) + + for chunk_start in range(0, batch_size, chunk_size): + chunk_end = min(chunk_start + chunk_size, batch_size) + + # Process log_softmax for this batch chunk + chunk_log_probs = F.log_softmax( + logits[chunk_start:chunk_end] / temperature, dim=-1 + ) + + # Gather for this chunk + result[chunk_start:chunk_end] = torch.gather( + chunk_log_probs, 2, sequences[chunk_start:chunk_end].unsqueeze(-1) + ).squeeze(-1) + + return result + + def get_batch_log_probs( logits: torch.FloatTensor, labels: torch.LongTensor, diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 3dcc77f859..287342fac4 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -136,5 +136,6 @@ "OffloadActivations", "FormattedCheckpointFiles", "scale_grads", + "get_distributed_backend", "disable_dropout", ] diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 75e7544c2d..04431d328d 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -568,6 +568,17 @@ def shard_model( fully_shard(model, **fsdp_kwargs) +def recursive_reshard(module: nn.Module): + """ + Manually reshard all modules in the model. + This might be useful for memory management when a model isn't automatically resharded after forward. + """ + for n, m in reversed(list(module.named_modules())): + module.reshard() + + module.reshard() + + def prepare_mha_for_tp( model: nn.Module, tp_mesh: DeviceMesh,