diff --git a/align_anything/configs/train/text_to_text/grpo_remote_rm.yaml b/align_anything/configs/train/text_to_text/grpo_remote_rm.yaml new file mode 100644 index 00000000..17f50208 --- /dev/null +++ b/align_anything/configs/train/text_to_text/grpo_remote_rm.yaml @@ -0,0 +1,189 @@ +# Copyright 2025 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# The training configurations +train_cfgs: + # Whether to save the model checkpoint + # if `False`, only save the 16-bit model + save_checkpoint: True + # Whether to load the model from checkpoint + load_checkpoint: False + # The DeepSpeed configuration + ds_cfgs: ds_z3_config.json + # Number of training epochs + epochs: 3 + # Seed for random number generator + seed: 42 + # Batch size per device for prompt sampling + per_device_prompt_batch_size: 1 + # Batch size per device for training + per_device_train_batch_size: 1 + # Batch size per device for evaluation + per_device_eval_batch_size: 1 + # The number of gradient accumulation steps + gradient_accumulation_steps: 1 + # Whether to use gradient checkpointing for the actor model + actor_gradient_checkpointing: True + # Initial learning rate for the actor model + actor_lr: 1.e-5 + # Type of learning rate scheduler for the actor model + actor_lr_scheduler_type: cosine + # Ratio of warmup steps for learning rate for the actor model + actor_lr_warmup_ratio: 0.03 + # Weight decay coefficient for the actor model + actor_weight_decay: 0.01 + # Initial learning rate for the critic model + critic_lr: 5.e-6 + # Type of learning rate scheduler for the critic model + critic_lr_scheduler_type: constant + # Ratio of warmup steps for learning rate for the critic model + critic_lr_warmup_ratio: 0.03 + # Weight decay coefficient for the critic model + critic_weight_decay: 0.0 + # Hyper-parameters for the Adam optimizer + adam_betas: [0.9, 0.95] + # Enable bfloat 16 precision + bf16: True + # Enable float 16 precision + fp16: False + # The strategy of evaluation, choosing from [epoch, steps] + eval_strategy: epoch + # The evaluation interval in step-wise evaluation case + eval_interval: 10 + # Whether to normalize the reward during RL training. + normalize_reward: False + # The number of repeated updates on a generated batch.
+ update_iters: 1 + # The number of generations per prompt (group size) for GRPO training + num_generations: 10 + # The hyper-parameter beta for GRPO training + beta: 0.01 +# The data configurations +data_cfgs: + # Datasets to use for training + train_datasets: null + # The format template for training + train_template: null + # The total number for training + train_size: null + # The split of train datasets + train_split: null + # The name of training datasets + train_name: null + # The training data files to be used + train_data_files: null + # The optional arguments for loading training datasets + train_optional_args: [] + # Datasets to use for evaluation + eval_datasets: null + # The format template for evaluation + eval_template: null + # The total number for evaluation + eval_size: null + # The split of evaluation datasets + eval_split: null + # The subset of evaluation datasets + eval_subset: null + # The evaluation data files to be used + eval_data_files: null + # The optional arguments for loading evaluation datasets + eval_optional_args: [] + # Datasets to use for ptx loss + ptx_datasets: null + # The format template for ptx training + ptx_template: null + # The total number for ptx training + ptx_size: null + # The subset of datasets + ptx_subset: null + # The split of ptx datasets + ptx_split: null + # The ptx training data files to be used + ptx_data_files: null + # The optional arguments for loading ptx training datasets + ptx_optional_args: [] +# The logging configurations +logger_cfgs: + # Type of logging to use, choosing from [wandb, tensorboard] + log_type: wandb + # Project name for logging + log_project: align-anything + # Run name for logging + log_run_name: grpo + # Output directory name + output_dir: null + # The directory to cache the downloaded model + cache_dir: null + # The interval of saving models + save_total_limit: 1 +# The model configurations +model_cfgs: + # Pretrained model name or path for the actor model in RLHF + actor_model_name_or_path: null + # Pretrained model name or path for the reward model in RLHF + reward_model_name_or_path: null + # Pretrained model name or path for the critic model in RLHF + reward_critic_model_name_or_path: null + # The endpoint of the remote reward model + remote_rm_url: http://localhost:6000/get_reward + # The timeout for the remote reward model + remote_rm_timeout: 100 + # The number of retries for the remote reward model + remote_rm_retry_times: 3 + # Whether to trust remote code + trust_remote_code: True + # The max token length + model_max_length: 18000 + # The value used to modulate the next token probabilities. + temperature: 1.0 + # If set to float < 1, only the smallest set of most probable tokens with + # probabilities that add up to `top_p` or higher are kept for generation. + top_p: 1.0 + # The parameter for repetition penalty. 1.0 means no penalty.
+ repetition_penalty: 1.0 +# The LoRA configurations +lora_cfgs: + # Whether to use LoRA + use_lora: False + # Task type for LoRA configuration + task_type: TaskType.CAUSAL_LM + # Inference mode + inference_mode: False + # Rank of the low-rank adaptation matrices + r: 16 + # Alpha parameter for LoRA + lora_alpha: 16 + # Dropout rate for LoRA + lora_dropout: 0.1 + # Target modules for applying LoRA + target_modules: ["q_proj", "v_proj"] + # Whether to save the full model + save_full_model: True +# The QLoRA configurations +bnb_cfgs: + # Whether to use BNB(For QLoRA) + use_bnb: False + # Whether to use 4-bit quantization + load_in_4bit: True + # Whether to use 8-bit quantization + load_in_8bit: False + # The quantization type for 4-bit quantization + bnb_4bit_quant_type: nf4 + # Whether to use double quantization + bnb_4bit_use_double_quant: True + # The compute dtype for 4-bit quantization + bnb_4bit_compute_dtype: float16 +# Customized special tokens +special_tokens: null diff --git a/align_anything/trainers/text_to_text/grpo_remote_rm.py b/align_anything/trainers/text_to_text/grpo_remote_rm.py new file mode 100644 index 00000000..1a8ff216 --- /dev/null +++ b/align_anything/trainers/text_to_text/grpo_remote_rm.py @@ -0,0 +1,477 @@ +# Copyright 2025 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Trainer for GRPO Training""" + + +import argparse +import copy +import os +import sys + +import deepspeed +import torch +import torch.distributed as dist +from tqdm import tqdm +from transformers import GenerationConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig + +from align_anything.datasets.text_to_text import ( + PromptOnlyBatch, + PromptOnlyDataset, + SupervisedDataset, +) +from align_anything.models.pretrained_model import load_pretrained_models +from align_anything.models.remote_rm.remote_rm_client import RemoteRewardModel +from align_anything.trainers.text_to_text.grpo import GRPOTrainer +from align_anything.utils.device_utils import torch_set_device +from align_anything.utils.multi_process import ( + get_all_reduce_mean, + get_current_device, + is_main_process, +) +from align_anything.utils.tools import ( + custom_cfgs_to_dict, + dict_to_namedtuple, + prepare_ds_eval_cfgs, + prepare_ds_train_cfgs, + read_cfgs, + seed_everything, + update_dict, +) + + +class GRPOTrainerRemoteRM(GRPOTrainer): + + def __init__(self, cfgs, ds_cfgs) -> None: + self.cfgs = cfgs + self.ds_train_cfgs = prepare_ds_train_cfgs(custom_cfgs=cfgs.train_cfgs, raw_ds_cfgs=ds_cfgs) + self.ds_eval_cfgs = prepare_ds_eval_cfgs(custom_cfgs=cfgs.train_cfgs, raw_ds_cfgs=ds_cfgs) + self.global_step = 0 + + self.init_check() + dist.barrier() + self.infer_batch = lambda batch: {k: v for k, v in batch.items() if k != 'meta_info'} + dist.barrier() + self.init_models() + dist.barrier() + self.init_datasets() + dist.barrier() + self.init_engines() + dist.barrier() + self.init_logger() + + self.beta = self.cfgs.train_cfgs.beta # KL regularization coefficient + self.num_generations = ( + self.cfgs.train_cfgs.num_generations + ) # number of sequences generated for each prompt + + def init_check(self) -> None: + super().init_check() + if ( + self.cfgs.train_cfgs.per_device_prompt_batch_size + % self.cfgs.train_cfgs.per_device_train_batch_size + != 0 + ): + raise ValueError('Every prompt batch size must be divisible by the micro-batch size.') + + def init_models(self) -> None: + # DeepSpeed configuration, different from that in RLTrainerBase, we don't need critic model in GRPO + if self.ds_train_cfgs['zero_optimization']['stage'] == 3: + self.dstchf_train = HfDeepSpeedConfig(self.ds_train_cfgs) + if self.ds_eval_cfgs['zero_optimization']['stage'] == 3: + self.dsechf_eval = HfDeepSpeedConfig(self.ds_eval_cfgs) + + self.bnb_cfgs = self.cfgs.bnb_cfgs + self.lora_cfgs = self.cfgs.lora_cfgs + + self.actor_model, self.tokenizer, self.processor = load_pretrained_models( + self.cfgs.model_cfgs.actor_model_name_or_path, + model_max_length=self.cfgs.model_cfgs.model_max_length, + padding_side='left', + trust_remote_code=self.cfgs.model_cfgs.trust_remote_code, + bnb_cfgs=self.bnb_cfgs, + lora_cfgs=self.lora_cfgs, + processor_kwargs=self.cfgs.train_cfgs.processor_kwargs, + ) + + self.actor_reference_model, _, _ = load_pretrained_models( + self.cfgs.model_cfgs.actor_model_name_or_path, + model_max_length=self.cfgs.model_cfgs.model_max_length, + padding_side='left', + trust_remote_code=self.cfgs.model_cfgs.trust_remote_code, + bnb_cfgs=self.bnb_cfgs, + lora_cfgs=self.lora_cfgs, + processor_kwargs=self.cfgs.train_cfgs.processor_kwargs, + ) + + self.generation_config = GenerationConfig( + max_length=self.cfgs.model_cfgs.model_max_length, + temperature=self.cfgs.model_cfgs.temperature, + top_p=self.cfgs.model_cfgs.top_p, + 
repetition_penalty=self.cfgs.model_cfgs.repetition_penalty, + do_sample=True, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + # loading remote reward models # NOTE + + self.remote_rm_url = self.cfgs.model_cfgs.remote_rm_url + + if self.remote_rm_url is not None: + self.reward_model = None # NOTE for debug deepseed engine init + self.reward_tokenizer = ( + self.tokenizer + ) # NOTE the reward tokenizer can be set to be the same as the actor tokenizer + # NOTE using is_reward_model=True to load the reward critic model, and init score head for reward critic model + + self.remote_rm_client = RemoteRewardModel( + self.remote_rm_url, + timeout=self.cfgs.model_cfgs.remote_rm_timeout, + retry_times=self.cfgs.model_cfgs.remote_rm_retry_times, + ) + + else: + raise ValueError( + 'You are using remote reward model for training. But remote reward model endpoint is not provided.' + ) + + def init_datasets(self) -> None: + """Initialize training and evaluation datasets""" + self.prompt_only_dataloader, self.eval_dataloader, _ = self.get_dataloaders( + PromptOnlyDataset, PromptOnlyDataset, SupervisedDataset + ) + + def set_train(self, mode: bool = True) -> None: + """Set training mode for all models.""" + if mode: + self.actor_model.train() + if self.cfgs.train_cfgs.actor_gradient_checkpointing and not self.lora_enabled: + self.actor_model.gradient_checkpointing_enable() + else: + self.actor_model.eval() + if self.cfgs.train_cfgs.actor_gradient_checkpointing and not self.lora_enabled: + self.actor_model.gradient_checkpointing_disable() + return + + def init_engines(self) -> None: + """Initialize DeepSpeed engines.""" + # different from that in RLTrainerBase, we don't need critic model in GRPO + + self.total_training_steps: int = ( + len(self.prompt_only_dataloader) + * self.cfgs.train_cfgs.epochs + * self.cfgs.train_cfgs.update_iters + * self.cfgs.train_cfgs.per_device_prompt_batch_size + // self.cfgs.train_cfgs.per_device_train_batch_size + ) + # initialize the actor model engines + actor_ds_cfgs = copy.deepcopy(self.ds_train_cfgs) + actor_total_training_steps = self.total_training_steps + if self.use_ptx: + actor_ds_cfgs['train_batch_size'] *= 2 + actor_ds_cfgs['gradient_accumulation_steps'] *= 2 + actor_total_training_steps *= 2 + self.actor_model = self._init_train_deepspeed_engine( + model=self.actor_model, + weight_decay=self.cfgs.train_cfgs.actor_weight_decay, + lr=self.cfgs.train_cfgs.actor_lr, + lr_scheduler_type=self.cfgs.train_cfgs.actor_lr_scheduler_type, + lr_warmup_ratio=self.cfgs.train_cfgs.actor_lr_warmup_ratio, + total_training_steps=actor_total_training_steps, + ds_cfgs=actor_ds_cfgs, + ) + # initialize the actor reference model engines + self.actor_reference_model = self._init_eval_deepspeed_engine( + model=self.actor_reference_model, + ds_cfgs=self.ds_eval_cfgs, + ) + self.actor_reference_model.eval() + + # load the checkpoint if specified + if self.cfgs.train_cfgs.load_checkpoint: + self.actor_model.load_checkpoint(load_dir=self.cfgs.model_cfgs.actor_model_name_or_path) + # setup the gradient checkpointing + if self.cfgs.train_cfgs.actor_gradient_checkpointing and not self.lora_enabled: + self.actor_model.gradient_checkpointing_enable() + + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + """ + Compute the log-probabilities of the model on the specified tokens. 
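+ Logits are shifted by one position so that position t scores token t+1; only the last `logits_to_keep` positions are kept, and the log-probability of each realized token is gathered, giving a tensor of shape (B, logits_to_keep).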
+ """ + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + logits = outputs.logits # shape: (B, L, V) + logits = logits[:, :-1, :] # (B, L-1, V) + logits = logits[:, -logits_to_keep:, :] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + target_ids = input_ids[:, -logits_to_keep:] + per_token_logps = log_probs.gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1) + return per_token_logps # shape: (B, logits_to_keep) + + def generate_completions(self, prompt_batch: dict) -> torch.Tensor: + """ + Generate multiple completions based on the given prompt. + Here, we set num_return_sequences = self.num_generations, which requires that each sample in the dataloader contains only one prompt, + and then generate multiple completions for comparison within the group. + """ + self.actor_model.eval() + with torch.no_grad(): + sequences = self.actor_model.module.generate( + **prompt_batch, + generation_config=self.generation_config, + num_return_sequences=self.cfgs.train_cfgs.num_generations, + synced_gpus=True, + do_sample=True, + ) + return sequences # shape: (B * num_generations, L_total) + + def decode_prompt_responses(self, actor_batch: PromptOnlyBatch) -> tuple[list[str], list[str]]: + """ + Decode the input_ids to get prompts and responses + + Args: + actor_batch: Batch containing input_ids + + Returns: + tuple[list[str], list[str]]: (prompts, responses) + """ + decoded = self.tokenizer.batch_decode(actor_batch, skip_special_tokens=True) + prompts = [] + responses = [] + + for text in decoded: + try: + if 'user\n\n' in text: + parts = text.split('user\n\n', 1) + elif 'user\n' in text: + parts = text.split('user\n', 1) + elif 'user' in text: + parts = text.split('user', 1) + else: + raise ValueError('No user marker found') + + user_part = parts[1] + if 'assistant\n\n' in user_part: + user_assistant = user_part.split('assistant\n\n', 1) + elif 'assistant\n' in user_part: + user_assistant = user_part.split('assistant\n', 1) + elif 'assistant' in user_part: + user_assistant = user_part.split('assistant', 1) + else: + raise ValueError('No assistant marker found') + + try: + user_text = eval(user_assistant[0].strip())[0]['text'] + except: + user_text = user_assistant[0].strip() + + assistant_response = user_assistant[1].strip() + + prompts.append(user_text) + responses.append(assistant_response) + + except Exception as e: + print(f'Error parsing text: {str(e)}') + print(f'Problematic text: {text}') + raise ValueError('Error parsing text') + + assert len(prompts) == len(responses), 'Prompts and responses must have same length' + + return prompts, responses + + def compute_rewards(self, sequences: torch.Tensor, prompt_length: int) -> torch.Tensor: + """ + Compute the rewards for the generated completions. 
+ """ + + prompts, responses = self.decode_prompt_responses(sequences) + if prompts is None: + raise ValueError('prompt is not found in the actor_batch') + reward_tensor = self.remote_rm_client.score(prompts, responses) + return reward_tensor + + def train_step(self, prompt_batch: dict) -> dict[str, float]: + """Single training step""" + device = self.actor_model.module.parameters().__next__().device + prompt_batch = {k: v.to(device) for k, v in prompt_batch.items()} + + # record the original prompt length + prompt_length = prompt_batch['input_ids'].size(1) + + # generate multiple completions (each prompt generates num_generations sequences) + sequences = self.generate_completions(prompt_batch) # shape: (B * num_generations, L_total) + # restore train mode + self.actor_model.train() + + # compute rewards + rewards = self.compute_rewards(sequences, prompt_length) # shape: (B * num_generations,) + if isinstance(rewards, torch.Tensor): + rewards = rewards.clone().detach().to(prompt_batch['input_ids'].device) + else: + rewards = torch.tensor(rewards, device=prompt_batch['input_ids'].device) + + + + B = prompt_batch['input_ids'].size(0) + G = self.num_generations + rewards = rewards.view(B, G) + group_mean = rewards.mean(dim=1, keepdim=True) + group_std = rewards.std(dim=1, keepdim=True) + 1e-4 + advantages = (rewards - group_mean) / group_std # shape: (B, G) + advantages = advantages.view(-1, 1) + + # compute the attention mask of the generated sequences + attention_mask = (sequences != self.tokenizer.pad_token_id).long() + logits_to_keep = sequences.size(1) - prompt_length + + # compute the per-token log-probabilities of the actor_model on the generated sequences + per_token_logps = self._get_per_token_logps( + self.actor_model, sequences, attention_mask, logits_to_keep + ) + # the log-probabilities of the reference model (no gradient) + with torch.no_grad(): + ref_per_token_logps = self._get_per_token_logps( + self.actor_reference_model, sequences, attention_mask, logits_to_keep + ) + + # compute the per-token KL divergence: KL = exp(ref - logp) - (ref - logp) - 1 + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) + - (ref_per_token_logps - per_token_logps) + - 1 + ) + + # expand the advantages to each token (assume the same advantages for all completion tokens) + advantages_expanded = advantages.expand(-1, logits_to_keep) + + # formula: loss = - ( exp(logp - detach(logp)) * advantage - beta * KL ) + per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages_expanded + per_token_loss = -(per_token_loss - self.beta * per_token_kl) + + # construct the completion mask: only count the loss for the valid tokens (not truncated by eos) + completion_tokens = sequences[:, prompt_length:] + eos_token_id = self.tokenizer.eos_token_id + completion_mask = torch.ones_like(completion_tokens) + for i in range(completion_tokens.size(0)): + eos_positions = (completion_tokens[i] == eos_token_id).nonzero(as_tuple=False) + if eos_positions.numel() > 0: + first_eos = eos_positions[0].item() + completion_mask[i, first_eos + 1 :] = 0 + completion_mask = completion_mask.to(per_token_loss.dtype) + + # compute the total loss + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() + + self.actor_model.zero_grad() + self.actor_model.backward(loss) + self.actor_model.step() + + loss_val = get_all_reduce_mean(loss).item() + avg_reward = get_all_reduce_mean(rewards.mean()).item() + + return {'train/loss': loss_val, 'train/reward': avg_reward} + + def train(self) 
-> None: + """Training main loop""" + self.logger.print('***** Running GRPO training *****') + + total_training_steps = self.total_training_steps + progress_bar = tqdm( + total=total_training_steps, + desc=f'Training 1/{self.cfgs.train_cfgs.epochs} epoch', + position=0, + leave=True, + disable=not is_main_process(), + ) + progress_bar.update(self.global_step) + + if self.cfgs.data_cfgs.eval_datasets: + self.eval() + + remain_epoch = self.cfgs.train_cfgs.epochs - ( + self.global_step // len(self.prompt_only_dataloader) + ) + + start_batch_idx = self.global_step % len(self.prompt_only_dataloader) + + for epoch in range(int(remain_epoch)): + for batch_idx, prompt_batch in enumerate(self.prompt_only_dataloader): + if epoch == 0 and batch_idx < start_batch_idx: + continue + + train_info = self.train_step(prompt_batch) + self.global_step += 1 + + self.logger.log(train_info, step=self.global_step) + progress_bar.set_description( + f"Epoch {epoch + 1}/{self.cfgs.train_cfgs.epochs} (reward {train_info['train/reward']:.4f})" + ) + progress_bar.update(1) + + save_interval = ( + self.cfgs.train_cfgs.epochs + * len(self.prompt_only_dataloader) + // self.cfgs.logger_cfgs.save_total_limit + ) + if self.global_step % save_interval == 0: + self.logger.print(f'Saving checkpoint at step {self.global_step} ...') + self.save(tag=self.global_step) + self.logger.print('Checkpoint saved.') + + if ( + self.cfgs.data_cfgs.eval_datasets + and self.cfgs.train_cfgs.eval_strategy == 'steps' + and self.global_step % self.cfgs.train_cfgs.eval_interval == 0 + ): + self.logger.print(f'\n***** Evaluating at step {self.global_step} *****') + self.eval() + + self.save() + + def save(self, model: deepspeed.DeepSpeedEngine | None = None, tag: int | None = None) -> None: + self.save_transformers(model=model, tag=tag) + + +def main(): + # initialize distributed training + deepspeed.init_distributed() + current_device = get_current_device() + torch_set_device(current_device) + + # read the default configuration from the yaml file + task = os.path.join('text_to_text', 'grpo_remote_rm') + dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=task) + + # read the custom configuration from the command line + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + _, unparsed_args = parser.parse_known_args() + keys = [k[2:] for k in unparsed_args[1::2]] + values = list(unparsed_args[2::2]) + unparsed_args = dict(zip(keys, values)) + for k, v in unparsed_args.items(): + dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(k, v)) + + cfgs = dict_to_namedtuple(dict_cfgs) + seed_everything(cfgs.train_cfgs.seed) + + # initialize and start training the GRPO model + trainer = GRPOTrainerRemoteRM(cfgs=cfgs, ds_cfgs=ds_cfgs) + trainer.train() + trainer.save() + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/cookbooks/en/text_to_text_grpo.ipynb b/cookbooks/en/text_to_text_grpo.ipynb new file mode 100644 index 00000000..9e489c23 --- /dev/null +++ b/cookbooks/en/text_to_text_grpo.ipynb @@ -0,0 +1,613 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-tuning Models with GRPO and RLVR Algorithms\n", + "\n", + "This tutorial demonstrates how to fine-tune large language models (using **Llama-3.1-8B-Instruct** as an example) using the **Group Relative Policy Optimization (GRPO)** algorithm. 
Through this tutorial, you will learn how to customize reward functions for your tasks under the **Align Anything** framework, and combine them with **Reinforcement Learning with Verifiable Rewards (RLVR)** to further improve model performance on specific tasks.\n", + "\n", + "## 1.1 What is GRPO?\n", + "\n", + "**Group Relative Policy Optimization (GRPO)** is a reinforcement learning algorithm designed to enhance model reasoning capabilities through grouping and relative reward mechanisms. GRPO was first introduced in the paper *DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models* and was successfully applied in the post-training phase of DeepSeek-R1.\n", + "\n", + "GRPO aims to optimize model behavior through relative comparison policies rather than absolute rewards. Specifically, GRPO groups multiple model outputs and calculates reward values based on their relative performance. This approach helps mitigate issues in traditional reinforcement learning where absolute rewards are difficult to define or lack precision, making it particularly suitable for complex reasoning tasks.\n", + "\n", + "## 1.2 What is RLVR?\n", + "\n", + "**Reinforcement Learning with Verifiable Rewards (RLVR)** is a novel language model training method designed for tasks with verifiable outcomes (such as mathematical problem-solving and instruction following). RLVR uses existing reinforcement learning reward mechanisms (like RLHF) but replaces traditional reward models with a verification function.\n", + "\n", + "Unlike traditional methods, RLVR trains models using binary signals through answer matching or constraint verification (e.g., whether an answer is correct). When applied to mathematical domains or other verifiable tasks, RLVR not only improves performance on specific benchmarks (like GSM8K) but also maintains stable performance across other tasks.\n", + "\n", + "RLVR can be viewed as a simplified version of existing methods, such as RL with execution feedback or bootstrapping methods for language model reasoning. Its core idea is to use verifiable signals as direct rewards, avoiding the complex process of building sophisticated reward models.\n", + "\n", + "## 2. 
Environment Setup\n", + "\n", + "Before starting, please make sure you have installed the ``align-anything`` package.\n", + "\n", + "```bash\n", + "# Clone the repository\n", + "git clone git@github.com:PKU-Alignment/align-anything.git\n", + "cd align-anything\n", + "\n", + "# Create a virtual environment using conda\n", + "conda create -n align-anything python==3.11\n", + "conda activate align-anything\n", + "```\n", + "\n", + "- **`[Optional]`** We recommend installing [CUDA](https://anaconda.org/nvidia/cuda) in the conda environment and setting the environment variable.\n", + "\n", + "```bash\n", + "# We have tested this version of CUDA on the H800 computing cluster and it worked well.\n", + "# You can adjust this version according to your actual computing cluster.\n", + "\n", + "conda install nvidia/label/cuda-12.2.0::cuda\n", + "export CUDA_HOME=$CONDA_PREFIX\n", + "```\n", + "\n", + "> If your CUDA is installed in a different location, such as `/usr/local/cuda/bin/nvcc`, you can set the environment variable as follows:\n", + "\n", + "```bash\n", + "export CUDA_HOME=\"/usr/local/cuda\"\n", + "```\n", + "\n", + "Finally, install `align-anything` using the following command:\n", + "\n", + "```bash\n", + "# We have prepared a quick installation for training and evaluation.\n", + "# If you only need to use the training or evaluation module,\n", + "# you can install the corresponding dependencies.\n", + "pip install -e .[train] # Install training dependencies\n", + "pip install -e .[evaluate] # Install evaluation dependencies\n", + "\n", + "# If you need to install all dependencies, you can use the following command:\n", + "pip install -e .[all]\n", + "```\n", + "\n", + "Lastly, following the instructions at https://github.com/PKU-Alignment/align-anything/tree/main/align_anything/models/remote_rm,\n", + "\n", + "install the dependencies required by the remote reward model:\n", + "```bash\n", + "pip install Levenshtein flask latex2sympy2_extended math_verify\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Llama-3.1-8B-Instruct Model Output Example\n", + "Next, let's test the zero-shot capability of the Llama-3.1-8B-Instruct model.\n", + "\n", + "### 3.1 Import Required Libraries\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1742778596.498488] [dsw-519274-66f65ff576-678dh:4051137:f] vfs_fuse.c:281 UCX ERROR inotify_add_watch(/tmp) failed: No space left on device\n" + ] + } + ], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "import os\n", + "import torch\n", + "\n", + "os.environ[\"TRANSFORMERS_OFFLINE\"] = \"1\"\n", + "os.environ[\"HF_DATASETS_OFFLINE\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Load the Original Llama Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00, 2.29it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(128256, 4096)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\" # Set device to \"cuda\" to use GPU\n", + "model_path = (\n", + " \"/PATH/TO/YOUR/Meta-Llama-3.1-8B-Instruct\" # Please replace with your actual model path\n", + ")\n", + "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n", + "\n", + "# Set the model to evaluation mode\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 Test the Performance of the Original Model\n", + "\n", + "Let's test the Llama-3.1-8B-Instruct model with a sample question." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Generated Text: The sequence of square roots of the positive integers is increasing. The largest term of the sequence that is less than or equal to 20 is $\\sqrt{19}$, the square root of 16. Therefore, 16 terms of the sequence are less than or equal to 20. 
The sequence of 16 terms is\n", + "\n", + "$\\sqrt{1},\\sqrt{2},\\sqrt{3},\\sqrt{4},\\sqrt{5},\\sqrt{6},\\sqrt{7},\\sqrt{8},\\sqrt{9},\\sqrt{10},\\sqrt{11},\\sqrt{12},\\sqrt{13},\\sqrt{14},\\sqrt{15},\\sqrt{16}$\n" + ] + } + ], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have?\",\n", + " },\n", + "]\n", + "\n", + "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n", + "\n", + "# the model generate new tokens\n", + "with torch.no_grad():\n", + " output = model.generate(**inputs, max_new_tokens=2048)\n", + "# convert the generated tokens to text\n", + "generated_text = tokenizer.decode(\n", + " output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n", + ")\n", + "print(\"\\nGenerated Text:\", generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As the correct answer is 400, this demonstrates that there is still room for improvement in Llama 3.1's mathematical capabilities." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Training the Model Using the GRPO Algorithm\n", + "\n", + "**Note**: If you cannot access huggingface.co, set the Hugging Face endpoint to hf-mirror.com. You can do this with the following command:\n", + "\n", + "`export HF_ENDPOINT=\"https://hf-mirror.com\"`\n", + "\n", + "Here, we take the PKU-SafeRLHF series dataset as an example. The PKU-SafeRLHF dataset is a preference dataset focused on safety alignment. Each data entry in this dataset contains two responses to the same question, along with their corresponding safety meta-tags and preference annotations.\n", + "\n", + "You can refer to the training script below:\n", + "\n", + "```bash\n", + "# NOTE need to start the remote rm server first\n", + "bash start_remote_rm.sh\n", + "\n", + "# NOTE need to change the model path\n", + "ACTOR_MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # actor model path\n", + "\n", + "TRAIN_DATASETS=\"../align_anything/models/remote_rm/math_verify_dataset/mathvl_345_example.json\" # dataset path\n", + "TRAIN_TEMPLATE=\"Math-Zero-RL\" # math zero rlhf dataset template, note that for math zero rl, you are recommended to expand token length to longer length such as 18000\n", + "TRAIN_SPLIT=\"train\" # split the input dataset\n", + "\n", + "OUTPUT_DIR=\"../output/llama_grpo_remote_rm\" # output dir\n", + "# For wandb online logging\n", + "export WANDB_API_KEY=\"\"\n", + "\n", + "export REMOTE_RM_URL=\"http://127.0.0.1:6000/get_reward\"\n", + "# Source the setup script\n", + "source ./setup.sh\n", + "\n", + "# Execute deepspeed command\n", + "deepspeed \\\n", + " --master_port ${MASTER_PORT} \\\n", + " --module align_anything.trainers.text_to_text.grpo_remote_rm \\\n", + " --actor_model_name_or_path ${ACTOR_MODEL_NAME_OR_PATH} \\\n", + " --remote_rm_url ${REMOTE_RM_URL} \\\n", + " --train_datasets ${TRAIN_DATASETS} \\\n", + " --train_split ${TRAIN_SPLIT} \\\n", + " --train_template ${TRAIN_TEMPLATE} \\\n", + " --output_dir ${OUTPUT_DIR}\n", + "```\n", + "\n", + "After training is completed, you can find the trained model weights under the `OUTPUT_DIR`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Test the Performance of the Model After GRPO Training\n", + "\n", + "After training is complete, we test whether the trained model's mathematical ability has improved.\n", + "\n", + "### 5.1 Load the New Model Weights\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(128257, 4096, padding_idx=128256)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=128257, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_path = \"/PATH/TO/YOUR/TRAINED_MODEL\" # Please replace with your actual model path\n", + "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n", + "\n", + "# Set the model to evaluation mode\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.2 Test the Performance of the New Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Generated Text: To find out how many terms are less than or equal to $20$, we can find out which term is greater than $20$, and then subtract $1$ to find the answer.\n", + "\n", + "Recognize that $\\sqrt{400} = 20$.\n", + "\n", + "The sequence goes by consecutive integers (1, 2, 3, 4, ect), so $\\sqrt{400}$ will be the 400th term.\n", + "\n", + "Thus, we can say every term up to the 400th term is less than or equal to $20$, except $\\sqrt{400}$.\n" + ] + } + ], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have?\",\n", + " },\n", + "]\n", + "\n", + "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n", + "\n", + "# the model generate new tokens\n", + "with torch.no_grad():\n", + " output = model.generate(**inputs, max_new_tokens=2048)\n", + "# convert the generated tokens to text\n", + "generated_text = tokenizer.decode(\n", + " output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n", + ")\n", + "print(\"\\nGenerated Text:\", generated_text)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This shows that the fine-tuned model did indeed solve the problem correctly.\n", + "\n", + "(Strictly speaking, the test question was from the training dataset, so this is an in-distribution test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. Customizing Reward Functions\n", + "In this section, we will learn how to customize reward functions, allowing you to design specific scoring mechanisms based on your task requirements." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.1 Creating Reward Function Files\n", + "First, create a new reward function file in the reward_functions directory of the project:\n", + "```bash\n", + "cd align-anything/align_anything/models/remote_rm/reward_functions/\n", + "touch my_verifier.py\n", + "```\n", + "We can refer to the examples in examples.py to implement our own reward function. In this example, we'll implement a simple format verification reward function that focuses on whether the answer format is correct, without considering the accuracy of the answer.\n", + "Here's the specific implementation code\n", + "```python\n", + "# align_anything/models/remote_rm/reward_functions/my_verifier.py\n", + "import random\n", + "import re\n", + "from typing import List, Optional\n", + "\n", + "from flask import jsonify\n", + "\n", + "format_pattern = r'^(?:(?!).)*(?:(?!).)*\\Z'\n", + "\n", + "\n", + "def verify_format(content):\n", + " \"\"\"\n", + " Verify if the string meets the format requirements:\n", + " - Must start with and end with \n", + " - Must contain exactly one pair of ... and ... tags\n", + " - No extra characters allowed between and tags\n", + " \"\"\"\n", + " think_count = content.count('')\n", + " answer_count = content.count('')\n", + " return (\n", + " bool(re.match(format_pattern, content, re.DOTALL))\n", + " and think_count == 1\n", + " and answer_count == 1\n", + " )\n", + "\n", + "def my_verifier_reward_function(\n", + " prompts: List[str], responses: List[str], golden_responses: Optional[List[str]] = None\n", + ") -> List[float]:\n", + " \"\"\"\n", + " Math verifier reward function, evaluate the accuracy of the answer\n", + "\n", + " Args:\n", + " prompts: List of math problems\n", + " responses: List of model answers\n", + " golden_responses: Optional list of golden responses\n", + " Returns:\n", + " List of reward scores for each (prompt, response) pair\n", + " \"\"\"\n", + " rewards = []\n", + " format_rewards = []\n", + " for prompt, response, golden_response in zip(prompts, responses, golden_responses):\n", + " if prompt is None:\n", + " return jsonify({'error': f'problem not found from {prompt}'}), 400\n", + " if golden_response is None:\n", + " return jsonify({'error': f'golden response not found from {prompt}'}), 400\n", + " # TODO: processing the error code 400\n", + "\n", + " format_reward = float(verify_format(response))\n", + " rewards.append(format_reward)\n", + " format_rewards.append(format_reward)\n", + "\n", + " do_print = random.randint(1, 10) == 1\n", + " if do_print:\n", + " info = f'Query: {prompt}\\n\\nAnswer: {golden_response}\\n\\nResponse: {response}\\n\\nFormat Reward: {format_reward}\\n\\n'\n", + " info = re.sub(r'<\\|.*?\\|>', '', info)\n", + " print(info)\n", + " return rewards\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.2 Registering Custom Reward Functions\n", + "\n", + "After implementing the reward function, you need to register 
it in the framework:\n", + "\n", + "1. Add the following to `align_anything/models/remote_rm/reward_functions/__init__.py`:\n", + "```python\n", + "from .my_verifier import *\n", + "```\n", + "\n", + "2. Register the function in `align_anything/models/remote_rm/run_reward_server.py`:\n", + "```python\n", + "reward_functions = {\n", + " 'example_math': example_math_reward_function,\n", + " 'example_coding': example_coding_reward_function,\n", + " 'example_safety': example_safety_reward_function,\n", + " 'math_verifier': math_verifier_reward_function,\n", + " 'my_verifier': my_verifier_reward_function,\n", + "}\n", + "```\n", + "\n", + "3. Modify the configuration in `scripts/start_remote_rm.sh`:\n", + "```bash\n", + "export REWARD_TYPE=\"my_verifier\"\n", + "```\n", + "\n", + "With this, the custom reward function configuration is complete." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.3 Training with Custom Reward Functions\n", + "We use the same training command, but now our custom reward function is calculating the rewards behind the scenes\n", + "\n", + "```bash\n", + "# NOTE need to start the remote rm server first\n", + "bash start_remote_rm.sh\n", + "\n", + "# NOTE need to change the model path\n", + "ACTOR_MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # actor model path\n", + "\n", + "TRAIN_DATASETS=\"../align_anything/models/remote_rm/math_verify_dataset/mathvl_345_example.json\" # dataset path\n", + "TRAIN_TEMPLATE=\"Math-Zero-RL\" # math zero rlhf dataset template, note that for math zero rl, you are recommended to expand token length to longer length such as 18000\n", + "TRAIN_SPLIT=\"train\" # split the input dataset\n", + "\n", + "OUTPUT_DIR=\"../output/llama_grpo_remote_rm\" # output dir\n", + "# For wandb online logging\n", + "export WANDB_API_KEY=\"\"\n", + "\n", + "export REMOTE_RM_URL=\"http://127.0.0.1:6000/get_reward\"\n", + "# Source the setup script\n", + "source ./setup.sh\n", + "\n", + "# Execute deepspeed command\n", + "deepspeed \\\n", + " --master_port ${MASTER_PORT} \\\n", + " --module align_anything.trainers.text_to_text.grpo_remote_rm \\\n", + " --actor_model_name_or_path ${ACTOR_MODEL_NAME_OR_PATH} \\\n", + " --remote_rm_url ${REMOTE_RM_URL} \\\n", + " --train_datasets ${TRAIN_DATASETS} \\\n", + " --train_split ${TRAIN_SPLIT} \\\n", + " --train_template ${TRAIN_TEMPLATE} \\\n", + " --output_dir ${OUTPUT_DIR}\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.4 Checking Reward Outputs\n", + "\n", + "To prevent reward hacking (where the model exploits loopholes in the reward function), we need to verify if the model's behavior meets expectations:\n", + "\n", + "1. Check the reward server logs:\n", + "```bash\n", + "tail -f align-anything/debug_logs/reward_server.log\n", + "```\n", + "\n", + "If any anomalies are detected, adjust the reward function's evaluation logic promptly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
Acknowledgements\n", + "\n", + "- [Hugging Face Transformers 文档](https://huggingface.co/docs/transformers/index)\n", + "- [GRPO Paper](https://arxiv.org/pdf/2402.03300)\n", + "- [DeepSeek-R1 Paper](https://arxiv.org/abs/2501.12948)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jy-align", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/cookbooks/zh/text_to_text_grpo.ipynb b/cookbooks/zh/text_to_text_grpo.ipynb new file mode 100644 index 00000000..c230f95d --- /dev/null +++ b/cookbooks/zh/text_to_text_grpo.ipynb @@ -0,0 +1,613 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 使用 GRPO、RLVR 算法微调模型\n", + "\n", + "本教程将演示如何使用 **Group Relative Policy Optimization (GRPO)** 算法微调大型语言模型(以 **Llama-3.1-8B-Instruct** 模型为例)。通过本教程,你将学习如何在 **Align Anything** 框架下,为你的任务自定义奖励函数(reward function),并结合 **Reinforcement Learning with Verifiable Rewards (RLVR)** 方法,进一步提升模型在特定任务上的性能。\n", + "\n", + "\n", + "\n", + "## 1.1 什么是 GRPO 算法?\n", + "\n", + "**Group Relative Policy Optimization (GRPO)** 是一种强化学习算法,旨在通过分组和相对奖励机制提升模型的推理能力。GRPO 最早在论文 *DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models* 中提出,并在 DeepSeek-R1 的后训练阶段中被成功应用。\n", + "\n", + "GRPO 的目标是通过相对比较策略(而非绝对奖励)来优化模型的行为。具体而言,GRPO 会将多个模型输出分组,并根据它们的相对表现计算奖励值。这种方法能够缓解传统强化学习中绝对奖励难以定义或不够精确的问题,尤其适用于复杂推理任务。\n", + "\n", + "\n", + "## 1.2 什么是 RLVR?\n", + "\n", + "**Reinforcement Learning with Verifiable Rewards (RLVR)** 是一种新颖的语言模型训练方法,专为具有可验证结果的任务(如数学问题求解和指令跟随)设计。RLVR 使用现有的强化学习奖励机制(如 RLHF),但用一种验证函数替代传统的奖励模型。\n", + "\n", + "与传统方法不同,RLVR 通过答案匹配或约束验证(例如答案是否正确)作为二元信号训练模型。当应用于数学领域或其他可验证任务时,RLVR 不仅能够提升特定基准(如 GSM8K)的性能,还能在其他任务中保持稳定表现。\n", + "\n", + "可以将 RLVR 看作是现有方法的简化版本,比如基于执行反馈的强化学习(RL with execution feedback)或语言模型推理的自举方法。它的核心思想是利用可验证信号作为直接奖励,避免了构建复杂奖励模型的繁琐过程。\n", + "\n", + "\n", + "## 2. 
环境配置\n", + "\n", + "在开始之前,请确保您已安装 ``align-anything`` 包。\n", + "\n", + "```bash\n", + "# 克隆仓库\n", + "git clone git@github.com:PKU-Alignment/align-anything.git\n", + "cd align-anything\n", + "\n", + "# 使用conda创建虚拟环境\n", + "conda create -n align-anything python==3.11\n", + "conda activate align-anything\n", + "```\n", + "\n", + "- **`[Optional]`** We recommend installing [CUDA](https://anaconda.org/nvidia/cuda) in the conda environment and set the environment variable.\n", + "\n", + "```bash\n", + "# 我们在 H800 计算集群上测试过,这个版本的 CUDA 效果很好。\n", + "# 您可以根据计算集群的实际情况调整此版本。\n", + "\n", + "conda install nvidia/label/cuda-12.2.0::cuda\n", + "export CUDA_HOME=$CONDA_PREFIX\n", + "```\n", + "\n", + "> 如果您的 CUDA 安装在不同的位置,例如 `/usr/local/cuda/bin/nvcc`,您可以按如下方式设置环境变量:\n", + "\n", + "```bash\n", + "export CUDA_HOME=\"/usr/local/cuda\"\n", + "```\n", + "\n", + "接着通过以下命令安装 `align-anything`:\n", + "\n", + "```bash\n", + "# 我们为训练和评估准备了快速安装。\n", + "# 如果您只需要使用训练或评估模块,\n", + "# 您可以安装相应的依赖项。\n", + "pip install -e .[train] # 安装训练依赖项\n", + "pip install -e .[evaluate] # 安装评估依赖项\n", + "\n", + "# 如果您需要安装所有依赖项,可以使用以下命令:\n", + "pip install -e .[all]\n", + "```\n", + "\n", + "最后, 参照 https://github.com/PKU-Alignment/align-anything/tree/main/align_anything/models/remote_rm\n", + "\n", + "您还需要:\n", + "```bash\n", + "pip install Levenshtein flask latex2sympy2_extended math_verify\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Llama-3.1-8B-Instruct模型输出示例\n", + "下面,让我们首先测试Llama-3.1-8B-Instruct模型的zero-shot能力。\n", + "### 3.1 导入所需的库" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1742778596.498488] [dsw-519274-66f65ff576-678dh:4051137:f] vfs_fuse.c:281 UCX ERROR inotify_add_watch(/tmp) failed: No space left on device\n" + ] + } + ], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "import os\n", + "import torch\n", + "\n", + "os.environ[\"TRANSFORMERS_OFFLINE\"] = \"1\"\n", + "os.environ[\"HF_DATASETS_OFFLINE\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 加载原始的Llama 模型" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00, 2.29it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(128256, 4096)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\" # 将device设置为\"cuda\"以使用GPU\n", + "model_path = \"/PATH/TO/YOUR/Llama-3.1-8B-Instruct\" # 请更换为实际的模型路径\n", + "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n", + "\n", + "# 将模型设置为eval模式\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 测试原始模型的性能\n", + "\n", + "让我们用一个示例问题测试 Llama-3.1-8B-Instruct 模型。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Generated Text: The sequence of square roots of the positive integers is increasing. The largest term of the sequence that is less than or equal to 20 is $\\sqrt{19}$, the square root of 16. Therefore, 16 terms of the sequence are less than or equal to 20. 
The sequence of 16 terms is\n", + "\n", + "$\\sqrt{1},\\sqrt{2},\\sqrt{3},\\sqrt{4},\\sqrt{5},\\sqrt{6},\\sqrt{7},\\sqrt{8},\\sqrt{9},\\sqrt{10},\\sqrt{11},\\sqrt{12},\\sqrt{13},\\sqrt{14},\\sqrt{15},\\sqrt{16}$\n" + ] + } + ], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have?\",\n", + " },\n", + "]\n", + "\n", + "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n", + "\n", + "# the model generate new tokens\n", + "with torch.no_grad():\n", + " output = model.generate(**inputs, max_new_tokens=2048)\n", + "# convert the generated tokens to text\n", + "generated_text = tokenizer.decode(\n", + " output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n", + ")\n", + "print(\"\\nGenerated Text:\", generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "而正确答案是400, 由此可见,llama 3.1的数学能力仍有提升的空间" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 使用GRPO算法对齐模型\n", + "\n", + "**注意**:如果您无法访问huggingface.co,请将huggingface的endpoint设置为hf-mirror.com。您可以进行以下操作:\n", + "\n", + "`export HF_ENDPOINT=\"https://hf-mirror.com\"`\n", + "\n", + "在这里,我们以 **Align Anything** 框架自带的示例数据集 mathvl_345_example.json 为例。mathvl_345_example 是一个简单数学数据集, 包含了10个数学问题和答案\n", + "\n", + "可以参考如下的训练脚本:\n", + "\n", + "```bash\n", + "# NOTE need to start the remote rm server first\n", + "bash start_remote_rm.sh\n", + "\n", + "# NOTE need to change the model path\n", + "ACTOR_MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # actor model path\n", + "\n", + "TRAIN_DATASETS=\"../align_anything/models/remote_rm/math_verify_dataset/mathvl_345_example.json\" # dataset path\n", + "TRAIN_TEMPLATE=\"Math-Zero-RL\" # math zero rlhf dataset template, note that for math zero rl, you are recommended to expand token length to longer length such as 18000\n", + "TRAIN_SPLIT=\"train\" # split the input dataset\n", + "\n", + "OUTPUT_DIR=\"../output/llama_grpo_remote_rm\" # output dir\n", + "# For wandb online logging\n", + "export WANDB_API_KEY=\"\"\n", + "\n", + "export REMOTE_RM_URL=\"http://127.0.0.1:6000/get_reward\"\n", + "# Source the setup script\n", + "source ./setup.sh\n", + "\n", + "# Execute deepspeed command\n", + "deepspeed \\\n", + " --master_port ${MASTER_PORT} \\\n", + " --module align_anything.trainers.text_to_text.grpo_remote_rm \\\n", + " --actor_model_name_or_path ${ACTOR_MODEL_NAME_OR_PATH} \\\n", + " --remote_rm_url ${REMOTE_RM_URL} \\\n", + " --train_datasets ${TRAIN_DATASETS} \\\n", + " --train_split ${TRAIN_SPLIT} \\\n", + " --train_template ${TRAIN_TEMPLATE} \\\n", + " --output_dir ${OUTPUT_DIR}\n", + "```\n", + "\n", + "训练完成后,您可以在`OUTPUT_DIR`下找到训练的模型权重。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
测试GRPO训练后的模型性能\n", + "\n", + "在训练结束后,我们试图测试训练后的模型数学能力是否有所改观。\n", + "\n", + "### 5.1 加载新的模型权重\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(128257, 4096, padding_idx=128256)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=128257, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_path = \"/PATH/TO/YOUR/TRAINED_MODEL\" # 请更换为实际的模型路径\n", + "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n", + "\n", + "# 将模型设置为eval模式\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.2 测试新模型的性能" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Generated Text: To find out how many terms are less than or equal to $20$, we can find out which term is greater than $20$, and then subtract $1$ to find the answer.\n", + "\n", + "Recognize that $\\sqrt{400} = 20$.\n", + "\n", + "The sequence goes by consecutive integers (1, 2, 3, 4, ect), so $\\sqrt{400}$ will be the 400th term.\n", + "\n", + "Thus, we can say every term up to the 400th term is less than or equal to $20$, except $\\sqrt{400}$.\n" + ] + } + ], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have?\",\n", + " },\n", + "]\n", + "\n", + "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n", + "\n", + "# the model generate new tokens\n", + "with torch.no_grad():\n", + " output = model.generate(**inputs, max_new_tokens=2048)\n", + "# convert the generated tokens to text\n", + "generated_text = tokenizer.decode(\n", + " output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n", + ")\n", + "print(\"\\nGenerated Text:\", generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "由此可见,训练后的模型确实答对了问题. 
\n", + "\n", + "(当然严格来说, 测试的题目是在训练数据集里的, 这是 in distribution test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. 自定义奖励函数\n", + "在本节中,我们将学习如何自定义奖励函数,以便您可以根据具体任务需求设计专属的评分机制。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.1 创建奖励函数文件\n", + "首先需要在项目的reward_functions目录下创建新的奖励函数文件:\n", + "```bash\n", + "cd align-anything/align_anything/models/remote_rm/reward_functions/\n", + "touch my_verifier.py\n", + "```\n", + "我们可以参考examples.py中的示例来实现自己的奖励函数。本示例中,我们实现了一个简单的格式验证奖励函数,主要关注回答的格式是否正确,暂不考虑答案的准确性。\n", + "下面是具体实现代码\n", + "```python\n", + "# align_anything/models/remote_rm/reward_functions/my_verifier.py\n", + "import random\n", + "import re\n", + "from typing import List, Optional\n", + "\n", + "from flask import jsonify\n", + "\n", + "format_pattern = r'^(?:(?!).)*(?:(?!).)*\\Z'\n", + "\n", + "\n", + "def verify_format(content):\n", + " \"\"\"\n", + " Verify if the string meets the format requirements:\n", + " - Must start with and end with \n", + " - Must contain exactly one pair of ... and ... tags\n", + " - No extra characters allowed between and tags\n", + " \"\"\"\n", + " think_count = content.count('')\n", + " answer_count = content.count('')\n", + " return (\n", + " bool(re.match(format_pattern, content, re.DOTALL))\n", + " and think_count == 1\n", + " and answer_count == 1\n", + " )\n", + "\n", + "def my_verifier_reward_function(\n", + " prompts: List[str], responses: List[str], golden_responses: Optional[List[str]] = None\n", + ") -> List[float]:\n", + " \"\"\"\n", + " Math verifier reward function, evaluate the accuracy of the answer\n", + "\n", + " Args:\n", + " prompts: List of math problems\n", + " responses: List of model answers\n", + " golden_responses: Optional list of golden responses\n", + " Returns:\n", + " List of reward scores for each (prompt, response) pair\n", + " \"\"\"\n", + " rewards = []\n", + " format_rewards = []\n", + " for prompt, response, golden_response in zip(prompts, responses, golden_responses):\n", + " if prompt is None:\n", + " return jsonify({'error': f'problem not found from {prompt}'}), 400\n", + " if golden_response is None:\n", + " return jsonify({'error': f'golden response not found from {prompt}'}), 400\n", + " # TODO: processing the error code 400\n", + "\n", + " format_reward = float(verify_format(response))\n", + " rewards.append(format_reward)\n", + " format_rewards.append(format_reward)\n", + "\n", + " do_print = random.randint(1, 10) == 1\n", + " if do_print:\n", + " info = f'Query: {prompt}\\n\\nAnswer: {golden_response}\\n\\nResponse: {response}\\n\\nFormat Reward: {format_reward}\\n\\n'\n", + " info = re.sub(r'<\\|.*?\\|>', '', info)\n", + " print(info)\n", + " return rewards\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.2 注册自定义奖励函数\n", + "\n", + "完成奖励函数实现后,需要将其注册到框架中:\n", + "\n", + "1. 在`align_anything/models/remote_rm/reward_functions/__init__.py`添加:\n", + "```python\n", + "from .my_verifier import *\n", + "```\n", + "\n", + "2. 在`align_anything/models/remote_rm/run_reward_server.py`中注册函数:\n", + "```python\n", + "reward_functions = {\n", + " 'example_math': example_math_reward_function,\n", + " 'example_coding': example_coding_reward_function,\n", + " 'example_safety': example_safety_reward_function,\n", + " 'math_verifier': math_verifier_reward_function,\n", + " 'my_verifier': my_verifier_reward_function,\n", + "}\n", + "```\n", + "\n", + "3. 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 6.2 Register the custom reward function\n",
+    "\n",
+    "After implementing the reward function, register it with the framework:\n",
+    "\n",
+    "1. Add the following to `align_anything/models/remote_rm/reward_functions/__init__.py`:\n",
+    "```python\n",
+    "from .my_verifier import *\n",
+    "```\n",
+    "\n",
+    "2. Register the function in `align_anything/models/remote_rm/run_reward_server.py`:\n",
+    "```python\n",
+    "reward_functions = {\n",
+    "    'example_math': example_math_reward_function,\n",
+    "    'example_coding': example_coding_reward_function,\n",
+    "    'example_safety': example_safety_reward_function,\n",
+    "    'math_verifier': math_verifier_reward_function,\n",
+    "    'my_verifier': my_verifier_reward_function,\n",
+    "}\n",
+    "```\n",
+    "\n",
+    "3. Update the configuration in `scripts/start_remote_rm.sh`:\n",
+    "```bash\n",
+    "export REWARD_TYPE=\"my_verifier\"\n",
+    "```\n",
+    "\n",
+    "With that, the custom reward function is fully configured."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 6.3 Train with the custom reward function\n",
+    "\n",
+    "We launch training with exactly the same command as before; behind the scenes, it is now our custom reward function that computes the rewards.\n",
+    "\n",
+    "```bash\n",
+    "# NOTE: start the remote reward model server first\n",
+    "bash start_remote_rm.sh\n",
+    "\n",
+    "# NOTE: change the model path as needed\n",
+    "ACTOR_MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # actor model path\n",
+    "\n",
+    "TRAIN_DATASETS=\"../align_anything/models/remote_rm/math_verify_dataset/mathvl_345_example.json\" # dataset path\n",
+    "TRAIN_TEMPLATE=\"Math-Zero-RL\" # math zero RL dataset template; for math zero RL it is recommended to increase the max token length, e.g. to 18000\n",
+    "TRAIN_SPLIT=\"train\" # split of the input dataset\n",
+    "\n",
+    "OUTPUT_DIR=\"../output/llama_grpo_remote_rm\" # output dir\n",
+    "# For wandb online logging\n",
+    "export WANDB_API_KEY=\"\"\n",
+    "\n",
+    "export REMOTE_RM_URL=\"http://127.0.0.1:6000/get_reward\"\n",
+    "# Source the setup script\n",
+    "source ./setup.sh\n",
+    "\n",
+    "# Execute deepspeed command\n",
+    "deepspeed \\\n",
+    "    --master_port ${MASTER_PORT} \\\n",
+    "    --module align_anything.trainers.text_to_text.grpo_remote_rm \\\n",
+    "    --actor_model_name_or_path ${ACTOR_MODEL_NAME_OR_PATH} \\\n",
+    "    --remote_rm_url ${REMOTE_RM_URL} \\\n",
+    "    --train_datasets ${TRAIN_DATASETS} \\\n",
+    "    --train_split ${TRAIN_SPLIT} \\\n",
+    "    --train_template ${TRAIN_TEMPLATE} \\\n",
+    "    --output_dir ${OUTPUT_DIR}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 6.4 Check the reward outputs\n",
+    "\n",
+    "To guard against reward hacking (the model exploiting loopholes in the reward function), check whether the model's behavior matches your expectations:\n",
+    "\n",
+    "1. Watch the reward server log:\n",
+    "```bash\n",
+    "tail -f align-anything/debug_logs/reward_server.log\n",
+    "```\n",
+    "\n",
+    "If anything looks off, adjust the reward function's decision logic promptly."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 7. Acknowledgements\n",
+    "\n",
+    "- [Hugging Face Transformers documentation](https://huggingface.co/docs/transformers/index)\n",
+    "- [GRPO paper](https://arxiv.org/pdf/2402.03300)\n",
+    "- [DeepSeek-R1 paper](https://arxiv.org/abs/2501.12948)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "jy-align",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/scripts/llama/llama_grpo_remote_rm.sh b/scripts/llama/llama_grpo_remote_rm.sh
new file mode 100755
index 00000000..96d9612e
--- /dev/null
+++ b/scripts/llama/llama_grpo_remote_rm.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+#
+# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# NOTE: start the remote reward model server first
+bash start_remote_rm.sh
+
+# NOTE: change the model path as needed
+ACTOR_MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B-Instruct" # actor model path
+
+TRAIN_DATASETS="../align_anything/models/remote_rm/math_verify_dataset/mathvl_345_example.json" # dataset path
+TRAIN_TEMPLATE="Math-Zero-RL" # math zero RL dataset template; for math zero RL it is recommended to increase the max token length, e.g. to 18000
+TRAIN_SPLIT="train" # split of the input dataset
+
+OUTPUT_DIR="../output/llama_grpo_remote_rm" # output dir
+# For wandb online logging
+export WANDB_API_KEY=""
+
+export REMOTE_RM_URL="http://127.0.0.1:6000/get_reward"
+# Source the setup script
+source ./setup.sh
+
+# Execute deepspeed command
+deepspeed \
+    --master_port ${MASTER_PORT} \
+    --module align_anything.trainers.text_to_text.grpo_remote_rm \
+    --actor_model_name_or_path ${ACTOR_MODEL_NAME_OR_PATH} \
+    --remote_rm_url ${REMOTE_RM_URL} \
+    --train_datasets ${TRAIN_DATASETS} \
+    --train_split ${TRAIN_SPLIT} \
+    --train_template ${TRAIN_TEMPLATE} \
+    --output_dir ${OUTPUT_DIR}