diff --git a/src/fairseq2/metrics/recorders/_wandb.py b/src/fairseq2/metrics/recorders/_wandb.py index 60a1c574e..e3c56d897 100644 --- a/src/fairseq2/metrics/recorders/_wandb.py +++ b/src/fairseq2/metrics/recorders/_wandb.py @@ -45,6 +45,7 @@ def __init__( name: str, output_dir: Path, metric_descriptors: Provider[MetricDescriptor], + id: str | None = None, ) -> None: """ :param project: The W&B project name. @@ -60,7 +61,7 @@ def __init__( self._run = None else: self._run = wandb.init( - project=project, name=name, dir=output_dir.parent, resume="allow" + project=project, name=name, id=id, dir=output_dir.parent, resume="allow" ) self._metric_descriptors = metric_descriptors @@ -77,6 +78,12 @@ def record_metrics( if self._run is None: return + # try: + # self._run.log({"_step": step_nr}) # Log to the specific step + # self._run.step = step_nr # Directly update the internal step counter + # except: + # ... + for name, value in values.items(): try: descriptor = self._metric_descriptors.get(name) @@ -88,6 +95,8 @@ def record_metrics( else: display_name = descriptor.display_name + display_name = self._run.name + "/" + display_name  # prefix metrics with the run name + try: self._run.log({display_name: value}, step=step_nr) except RuntimeError as ex: @@ -112,6 +121,8 @@ class WandbRecorderConfig: run: str | None = None + id: str | None = None + def validate(self) -> None: result = ValidationResult() @@ -151,7 +162,11 @@ def create(self, output_dir: Path, config: object) -> MetricRecorder: wandb_dir = output_dir.joinpath("wandb") return WandbRecorder( - config.project, config.run, wandb_dir, self._metric_descriptors + config.project, + config.run, + wandb_dir, + self._metric_descriptors, + id=config.id, ) @property diff --git a/src/fairseq2/recipes/common/_trainer.py b/src/fairseq2/recipes/common/_trainer.py index 5611cde32..8754c1020 100644 --- a/src/fairseq2/recipes/common/_trainer.py +++ b/src/fairseq2/recipes/common/_trainer.py @@ -96,6 +96,7 @@ def create_trainer( valid_data_readers=valid_data_readers, validate_after_n_steps=regime_section.validate_after_n_steps, validate_every_n_steps=regime_section.validate_every_n_steps, + validate_step_0=regime_section.validate_step_0, validate_after_n_data_epochs=regime_section.validate_after_n_data_epochs, validate_every_n_data_epochs=regime_section.validate_every_n_data_epochs, checkpoint_manager=checkpoint_manager, diff --git a/src/fairseq2/recipes/config.py b/src/fairseq2/recipes/config.py index dfb4d66ff..e35d35b9c 100644 --- a/src/fairseq2/recipes/config.py +++ b/src/fairseq2/recipes/config.py @@ -219,6 +219,9 @@ class RegimeSection: validate_every_n_steps: int | None = None """The step interval at which to validate the model.""" + validate_step_0: bool = False + """Validate the model before the first training step.""" + validate_after_n_data_epochs: int = 0 validate_every_n_data_epochs: int | None = None diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index fcb9d24d5..0cd185804 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -161,6 +161,14 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( SkyworkVerifierHandler as SkyworkVerifierHandler, ) + +from fairseq2.recipes.lm._online_finetune._rewards import ( + AtheneVerifier as AtheneVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + AtheneVerifierHandler as AtheneVerifierHandler, +) + from fairseq2.recipes.lm._online_finetune._rewards import ( NuminaMathVerifier as NuminaMathVerifier, ) diff --git
a/src/fairseq2/recipes/lm/_online_finetune/_diversity_metrics.py b/src/fairseq2/recipes/lm/_online_finetune/_diversity_metrics.py new file mode 100644 index 000000000..6b8700a6d --- /dev/null +++ b/src/fairseq2/recipes/lm/_online_finetune/_diversity_metrics.py @@ -0,0 +1,100 @@ +import string as string_lib +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +import gzip +import torch + + +def get_compression_ratio(strings): + + flattened_generation = " ".join(strings) + original_byte_size = len(bytes(flattened_generation, "UTF-8")) + compressed_bytes_size = len(gzip.compress(bytes(flattened_generation, "UTF-8"))) + + cr = compressed_bytes_size / original_byte_size + cr_tensor = torch.Tensor([cr]) + return cr_tensor + + +def get_self_bleu_score(strings): + # Create a translation table to remove punctuation + translator = str.maketrans("", "", string_lib.punctuation) + + # Preprocess the strings: convert to lowercase and remove punctuation + cleaned_strings = [s.lower().translate(translator) for s in strings] + + # Tokenize the cleaned strings into lists of words + tokenized_strings = [s.split() for s in cleaned_strings] + + # Initialize a dictionary to store BLEU scores + bleu_scores = [] + + # Calculate BLEU scores for all pairs of strings + for i in range(len(tokenized_strings)): + for j in range(i + 1, len(tokenized_strings)): + # Use smoothing to handle cases where there are no n-grams in common + smoothie = SmoothingFunction().method4 + bleu = sentence_bleu( + [tokenized_strings[i]], + tokenized_strings[j], + smoothing_function=smoothie, + ) + + # Store the BLEU score + bleu_scores.append(bleu) + + mean_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0 + mean_bleu_score_tensor = torch.Tensor([mean_bleu_score]) + return mean_bleu_score_tensor + + +def get_unique_1grams(strings): + + # Initialize an empty set to store unique 1-grams + unique_words = set() + total_words = 0 + + # Create a translation table to remove punctuation + translator = str.maketrans("", "", string_lib.punctuation) + + # Iterate over each string in the list + for string in strings: + # Convert the string to lowercase and remove punctuation + cleaned_string = string.lower().translate(translator) + + # Split the cleaned string into words (1-grams) and update the set + words = cleaned_string.split() + total_words += len(words) + unique_words.update(words) + + # Return the set of unique 1-grams + num_unique_1grams = len(unique_words) + num_unique_1grams_norm = len(unique_words) / total_words if total_words > 0 else 0 + num_unique_1grams_tensor = torch.Tensor([num_unique_1grams]) + num_unique_1grams_norm = torch.Tensor([num_unique_1grams_norm]) + return num_unique_1grams_tensor, num_unique_1grams_norm + + +def extract_logprobs(data): + logprobs = [] + for item in data: + for key, logprob in item.items(): + logprobs.append(logprob.logprob) + return logprobs + + +def get_entropy(rollouts): + batch_sum_logprobs = [] + batch_sum_logprobs_per_tok = [] + for rollout_idx in range(len(rollouts[0].outputs)): + logprobs = extract_logprobs(rollouts[0].outputs[rollout_idx].logprobs) + + sum_logprobs = -sum(logprobs) + sum_logprobs_per_tok = -sum(logprobs) / len(logprobs) + + batch_sum_logprobs.append(sum_logprobs) + batch_sum_logprobs_per_tok.append(sum_logprobs_per_tok) + + entropy = sum(batch_sum_logprobs) / len(batch_sum_logprobs) + entropy_norm = sum(batch_sum_logprobs_per_tok) / len(batch_sum_logprobs_per_tok) + + return entropy, entropy_norm diff --git 
a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py index cf56e636b..5fc491cde 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py @@ -17,6 +17,8 @@ from torch import Tensor from torch.nn import Module from torcheval.metrics import Mean + +# from fairseq2.metrics import String from typing_extensions import override from vllm import SamplingParams from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -53,10 +55,7 @@ generate_rollouts, ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler -from fairseq2.recipes.lm._online_finetune._remote_vllm import ( - RemoteVllmModel, - RemoteVllmModelHandler, -) +from fairseq2.recipes.lm._online_finetune._remote_vllm import RemoteVllmModel from fairseq2.recipes.lm._online_finetune._rewards import ( RewardSection, VLLMOutputReward, @@ -74,6 +73,13 @@ from fairseq2.utils.structured import structure from fairseq2.utils.validation import validate +from fairseq2.recipes.lm._online_finetune._diversity_metrics import ( + get_compression_ratio, + get_self_bleu_score, + get_unique_1grams, + get_entropy, +) + @final class OnlineDpoFinetuneUnit(TrainUnit[SequenceBatch]): @@ -89,6 +95,7 @@ class OnlineDpoFinetuneUnit(TrainUnit[SequenceBatch]): _sync_ref_model_every_n_steps: int _display_name: str _reward: VLLMOutputReward + _valid_reward: VLLMOutputReward | None _reference_offload: bool def __init__( @@ -99,6 +106,7 @@ def __init__( vllm_model: RemoteVllmModel, vllm_actors: List[RemoteVllmModel], reward, + valid_reward, gangs: Gangs, loss_config: DpoLossConfig, sync_vllm_model_every_n_steps: int = 1, @@ -115,8 +123,8 @@ def __init__( self._sync_vllm_model_every_n_steps = sync_vllm_model_every_n_steps self._sync_ref_model_every_n_steps = sync_ref_model_every_n_step self._reward = reward + self._valid_reward = valid_reward self._metric_bag = OnlineDpoFinetuneMetricBag(gangs.dp) - self._display_name = "online_dpo" @property @@ -124,12 +132,12 @@ def __init__( def display_name(self) -> str | None: return self._display_name - def maybe_sync_models(self): + def maybe_sync_models(self, force_sync=False): if ( self._sync_vllm_model_every_n_steps > 0 and self._step_nr % self._sync_vllm_model_every_n_steps == 0 - ): + ) or force_sync: with self._model.summon_full_parameters(): if self._gangs.root.rank == 0: self._vllm_model.sync_weights_with_vllm(train_model=self._model) @@ -155,22 +163,53 @@ def maybe_sync_models(self): self._gangs.root.barrier() broadcast_model(self._reference_model, self._gangs) + def maybe_log_rollouts(self, prompt_batch: PromptBatch, rollouts, split_name): + if self._loss_config.log_rollouts: + prompt0 = prompt_batch.meta_info.get("prompt_raw")[0] + rollout0 = rollouts[0].outputs[0].text + log.info(f"{split_name} Prompt: {prompt0}") + log.info(f"{split_name} Rollout: {rollout0}") + def validate_reward(self, prompt_batch: PromptBatch) -> tuple[Tensor, int]: if self._gangs.dp.rank == 0: policy_sampling_params = copy(self._vllm_model.sampling_params) - policy_sampling_params.n = 1 + policy_sampling_params.n = self._vllm_model.valid_n + policy_sampling_params.temperature = 0.6 # FIXME add to config + policy_sampling_params.top_p = 0.9 # FIXME add to config else: policy_sampling_params = None + rollouts = generate_rollouts( prompt_batch.prompts, dp_gang=self._gangs.dp, vllm_model=self._vllm_model, sampling_params=policy_sampling_params, ) - reward_output = 
self._reward.process_rollouts(rollouts, prompt_batch) - avg_reward = torch.tensor(reward_output["rewards"]).float().mean() - self._metric_bag.update_avg_reward(avg_reward) + self.maybe_log_rollouts(prompt_batch, rollouts, "Valid") + + if self._valid_reward is None: + reward_output = self._reward.process_rollouts(rollouts, prompt_batch) + else: + reward_output = self._valid_reward.process_rollouts(rollouts, prompt_batch) + + self._metric_bag.update_batch_metrics(prompt_batch) + total_reward = torch.tensor(reward_output["rewards"]).float().mean() + self._metric_bag.update_avg_reward(total_reward) + + # Diversity metrics + unique_1grams, unique_1grams_norm = get_unique_1grams(reward_output["text"][0]) + self_bleu_score = get_self_bleu_score(reward_output["text"][0]) + compression_ratio = get_compression_ratio(reward_output["text"][0]) + entropy, entropy_norm = get_entropy(rollouts) + self._metric_bag.update_diversity_metrics( + unique_1grams, + unique_1grams_norm, + self_bleu_score, + compression_ratio, + entropy, + entropy_norm, + ) # returning dummy loss since trainer expects it return torch.tensor(0.0, device=self._gangs.dp.device), prompt_batch.batch_size @@ -216,10 +255,18 @@ def __call__(self, prompt_batch: PromptBatch) -> tuple[Tensor, int]: prompt_batch.prompts, dp_gang=self._gangs.dp, vllm_model=self._vllm_model ) + self.maybe_log_rollouts(prompt_batch, rollouts, "Train") + batch: PreferenceBatch batch, is_bad_batch, reward_output = self._reward.prepare_preference_batch( - prompt_batch, rollouts + prompt_batch, rollouts, divpo_p=self._loss_config.divpo_p ) # loss_zeroer is used when entire batch has no valid prefrence pair + + unique_1grams, unique_1grams_norm = get_unique_1grams(reward_output["text"][0]) + self_bleu_score = get_self_bleu_score(reward_output["text"][0]) + compression_ratio = get_compression_ratio(reward_output["text"][0]) + entropy, entropy_norm = get_entropy(rollouts) + if is_bad_batch: loss_zeroer = 0.0 else: @@ -255,14 +302,38 @@ def __call__(self, prompt_batch: PromptBatch) -> tuple[Tensor, int]: rejected_logps, average_rejected_logps = _gather_lprobs_avg( rejected_output, rejected_target_batch ) - tgt_logit_entropy = compute_token_level_entropy( + chosen_tgt_logit_entropy = compute_token_level_entropy( chosen_output.logits, chosen_target_batch.target_mask ) # [Batch x Rollouts, 1] + rejected_tgt_logit_entropy = compute_token_level_entropy( + rejected_output.logits, rejected_target_batch.target_mask + ) # [Batch x Rollouts, 1] + + all_entropy = [] + all_entropy_first100 = [] + # FIXME better way to get entropy from all rollouts? 
+ for rollout_idx in range(len(rollouts[0].outputs)): + logprobs = rollouts[0].outputs[rollout_idx].logprobs + logprobs = [next(iter(x.values())).logprob for x in logprobs] + entropy = sum(logprobs) / len(logprobs) + entropy_first100 = sum(logprobs[0:100]) / len(logprobs[0:100]) + all_entropy.append(entropy) + all_entropy_first100.append(entropy_first100) + total_logit_entropy = torch.tensor(all_entropy, device=self._gangs.dp.device) + total_logit_entropy_first100 = torch.tensor( + all_entropy_first100, device=self._gangs.dp.device + ) max_entropy_regularizer = ( - -tgt_logit_entropy.sum() * self._loss_config.entropy_regularizer_scale + -chosen_tgt_logit_entropy.sum() + * self._loss_config.entropy_regularizer_scale + ) + self.metric_bag.update_chosen_logit_entropy(chosen_tgt_logit_entropy) + self.metric_bag.update_rejected_logit_entropy(rejected_tgt_logit_entropy) + self.metric_bag.update_total_logit_entropy(total_logit_entropy) + self.metric_bag.update_total_logit_entropy_first100( + total_logit_entropy_first100 ) - self.metric_bag.update_logit_entropy(tgt_logit_entropy) if self._reference_offload: token_ref_chosen_logps = self.compute_reference_logps(batch.chosen) @@ -315,6 +386,15 @@ def __call__(self, prompt_batch: PromptBatch) -> tuple[Tensor, int]: self._metric_bag.update_batch_metrics(batch.chosen) + self._metric_bag.update_diversity_metrics( + unique_1grams, + unique_1grams_norm, + self_bleu_score, + compression_ratio, + entropy, + entropy_norm, + ) + avg_reward = torch.tensor(reward_output["rewards"]).float().mean() self._metric_bag.update_avg_reward(avg_reward) @@ -371,6 +451,10 @@ def _compute_dpo_loss( def set_step_nr(self, step_nr: int) -> None: self._step_nr = step_nr + @override + def set_data_epoch_nr(self, data_epoch_nr: int) -> None: + self._data_epoch_nr = data_epoch_nr + @property @override def model(self) -> Model: @@ -389,7 +473,16 @@ class OnlineDpoFinetuneMetricBag(POFinetuneMetricBag): num_dummy_batches: Mean avg_reward: Mean avg_zeroed_loss: Mean - logit_entropy: Mean + unique_1grams: Mean + unique_1grams_norm: Mean + self_bleu_score: Mean + compression_ratio: Mean + entropy: Mean + entropy_norm: Mean + chosen_logit_entropy: Mean + rejected_logit_entropy: Mean + total_logit_entropy: Mean + total_logit_entropy_first100: Mean def __init__(self, gang: Gang) -> None: super().__init__(gang) @@ -403,14 +496,63 @@ def __init__(self, gang: Gang) -> None: "avg_zeroed_loss", Mean(device=gang.device), persistent=False ) self.register_metric( - "logit_entropy", Mean(device=gang.device), persistent=False + "unique_1grams", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "unique_1grams_norm", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "self_bleu_score", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "compression_ratio", Mean(device=gang.device), persistent=False + ) + self.register_metric("entropy", Mean(device=gang.device), persistent=False) + self.register_metric("entropy_norm", Mean(device=gang.device), persistent=False) + self.register_metric( + "chosen_logit_entropy", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "rejected_logit_entropy", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "total_logit_entropy", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "total_logit_entropy_first100", Mean(device=gang.device), persistent=False + ) + + @torch.inference_mode() + def update_chosen_logit_entropy(self, logit_entropy: 
Tensor): + # logit_entropy is expected to contain token-level entropy for every sequence in the current batch + batch_size = logit_entropy.size(0) + self.chosen_logit_entropy.update( + logit_entropy.sum() / batch_size, weight=batch_size + ) + + @torch.inference_mode() + def update_rejected_logit_entropy(self, logit_entropy: Tensor): + # logit_entropy is expected to contain token-level entropy for every sequence in the current batch + batch_size = logit_entropy.size(0) + self.rejected_logit_entropy.update( + logit_entropy.sum() / batch_size, weight=batch_size + ) + + @torch.inference_mode() + def update_total_logit_entropy(self, logit_entropy: Tensor): + # logit_entropy is expected to contain token-level entropy for every sequence in the current batch + batch_size = logit_entropy.size(0) + self.total_logit_entropy.update( + logit_entropy.sum() / batch_size, weight=batch_size ) @torch.inference_mode() - def update_logit_entropy(self, logit_entropy: Tensor): + def update_total_logit_entropy_first100(self, logit_entropy: Tensor): # logit_entropy is expected to contain token-level entropy for every sequence in the current batch batch_size = logit_entropy.size(0) - self.logit_entropy.update(logit_entropy.sum() / batch_size, weight=batch_size) + self.total_logit_entropy_first100.update( + logit_entropy.sum() / batch_size, weight=batch_size + ) @torch.inference_mode() def update_dpo_loss(self, batch: PreferenceBatch, loss: Tensor) -> None: @@ -435,10 +577,43 @@ def update_num_dummy_batches(self, batch: PreferenceBatch, num_dummy_batches: in def update_avg_reward(self, avg_reward): self.avg_reward.update(avg_reward, weight=1) + @torch.inference_mode() + def update_batch_metrics(self, batch: PreferenceBatch): + # if self._gang.rank == 0: + # breakpoint() + + num_examples = batch.batch_size + self.num_examples.update(num_examples) + if self._train: + assert self.total_num_examples is not None + self.total_num_examples.update(num_examples) + @torch.inference_mode() def update_avg_zeroed_loss(self, avg_zeroed_loss): self.avg_zeroed_loss.update(avg_zeroed_loss, weight=1) + @torch.inference_mode() + def update_diversity_metrics( + self, + unique_1grams, + unique_1grams_norm, + self_bleu_score, + compression_ratio, + entropy, + entropy_norm, + ): + self.unique_1grams.update(unique_1grams, weight=1) + self.unique_1grams_norm.update(unique_1grams_norm, weight=1) + self.self_bleu_score.update(self_bleu_score, weight=1) + self.compression_ratio.update(compression_ratio, weight=1) + + self.entropy.update(torch.Tensor([entropy]), weight=1) + self.entropy_norm.update(torch.Tensor([entropy_norm]), weight=1) + + # @torch.inference_mode() + # def update_rollouts(self, rollouts): + # self.rollouts.update(rollouts) + ONLINE_DPO_FINETUNE_UNIT: Final = "online_dpo" @@ -455,6 +630,11 @@ class DpoLossConfig: length_normalization: bool = False """Use length normalized DPO, which uses the average log probability of a sequence as the implicit reward.""" + divpo_p: float = 0.0 + """Fraction of the reward range used to select DivPO candidates: among rollouts within this margin of the best (worst) reward, the least (most) likely one becomes chosen (rejected). 0 disables diverse preference optimization.""" + + log_rollouts: bool = True + """Log an example prompt and rollout for each train/validation batch.""" entropy_regularizer_scale: float = 0.0 @@ -474,12 +654,15 @@ class OnlineDpoFinetuneConfig: loss_config: DpoLossConfig = field(default_factory=lambda: DpoLossConfig()) ray_policy_actor_name: str = "vllm_policy" - vllm_reward_model_name: str = None + vllm_reward_model_name: str | None = None reward: RewardSection = field( default_factory=lambda: RewardSection(name="gsm8k_verifier") ) + vllm_valid_reward_model_name: str | None = None + valid_reward:
RewardSection | None = None + sync_ref_model_every_n_steps: int = -1 sync_vllm_model_every_n_steps: int = -1 @@ -504,7 +687,7 @@ def create( validate(config) if isinstance(config.reference_model, ReferenceModelSection): - log.info("Setting up GRPO with reference model.") + log.info("Setting up Online DPO with reference model.") trainer_section = get_config_section( recipe_config, "trainer", TrainerSection @@ -548,6 +731,22 @@ def create( gangs=gangs, ) + # VALID REWARD MODEL + if config.vllm_valid_reward_model_name is not None: + vllm_valid_reward_model = vllm_actors.get( + config.vllm_valid_reward_model_name, None + ) + reward_registry = self._context.get_registry(VLLMOutputRewardHandler) + reward_handler = reward_registry.get(config.valid_reward.name) + valid_reward = reward_handler.create( + reward_model=vllm_valid_reward_model, + reward_config=config.valid_reward.config, + gangs=gangs, + ) + log.info("Setting up Online DPO with valid reward model.") + else: + valid_reward = None + return OnlineDpoFinetuneUnit( model, reference_model, @@ -555,6 +754,7 @@ def create( vllm_model, vllm_actors, reward, + valid_reward, gangs, config.loss_config, config.sync_vllm_model_every_n_steps, diff --git a/src/fairseq2/recipes/lm/_online_finetune/_recipe.py b/src/fairseq2/recipes/lm/_online_finetune/_recipe.py index a05977627..ee3d0a8ca 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_recipe.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_recipe.py @@ -67,18 +67,13 @@ OnlineFinetuneUnitHandler, UnknownOnlineFinetuneUnitError, ) -from fairseq2.recipes.lm._online_finetune._online_dpo import ( # ONLINE_DPO_FINETUNE_UNIT, - OnlineDpoFinetuneConfig, -) -from fairseq2.recipes.lm._online_finetune._grpo import ( - GrpoFinetuneConfig, -) - from fairseq2.recipes.lm._online_finetune._remote_vllm import ( RemoteVllmModelHandler, - VllmEngineArgs, VllmRayActorConfig, ) +from fairseq2.recipes.lm._online_finetune._grpo import ( + GrpoFinetuneConfig, +) from fairseq2.recipes.trainer import Trainer from fairseq2.typing import CPU from fairseq2.utils.rng import manual_seed diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_vllm.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_vllm.py index 7e9b5eeb8..ecbb24bd9 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_vllm.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_vllm.py @@ -32,18 +32,15 @@ class RemoteModelHandler(ABC): @abstractmethod - def create(self, gangs: Gangs, unit_config: object) -> RemoteVllmModel: - ... + def create(self, gangs: Gangs, unit_config: object) -> RemoteVllmModel: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @property @abstractmethod - def config_kls(self) -> type[object]: - ... + def config_kls(self) -> type[object]: ... 
@dataclass(kw_only=True) @@ -51,19 +48,21 @@ class VllmEngineArgs: model: str = "/checkpoint/ram/kulikov/gsm8k_8b_sft/checkpoints/step_20" tokenizer: str = "/datasets/pretrained-llms/Llama-3.1-8B-Instruct" task: str = "generate" - tensor_parallel_size: int = 4 + trust_remote_code: bool = False + model_impl: str = "auto" enforce_eager: bool = True + tensor_parallel_size: int = 4 hf_overrides: object = None + dtype: str = "auto" override_pooler_config: PoolerConfig = field(default_factory=lambda: PoolerConfig()) + valid_n: int = 1 @dataclass(kw_only=True) class VllmRayActorConfig: ray_actor_name: str = "dummy" vllm_engine_args: VllmEngineArgs = field(default_factory=lambda: VllmEngineArgs()) - vllm_sampling_params: Dict[str, Any] = field( - default_factory=lambda: {} - ) + vllm_sampling_params: Dict[str, Any] = field(default_factory=lambda: {}) init_update_process_group: bool = False @@ -114,11 +113,10 @@ def __init__( self.vllm_model = self.setup_vllm_worker( ray_actor_name, vllm_engine_args, gangs ) - + self.valid_n = vllm_engine_args.valid_n + # populate sampling params using all values that were passed in the config - self.sampling_params = SamplingParams( - **sampling_params - ) + self.sampling_params = SamplingParams(**sampling_params) if init_update_process_group: self.update_process_group = self.setup_process_group_for_model_sync( @@ -159,8 +157,11 @@ def setup_vllm_worker( worker_cls=MyWorker, tensor_parallel_size=vllm_engine_args.tensor_parallel_size, task=vllm_engine_args.task, + trust_remote_code=vllm_engine_args.trust_remote_code, + model_impl=vllm_engine_args.model_impl, hf_overrides=vllm_engine_args.hf_overrides, override_pooler_config=vllm_engine_args.override_pooler_config, + dtype=vllm_engine_args.dtype, distributed_executor_backend="ray", ) @@ -222,7 +223,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None): return outputs - def reward_from_model(self, prompt_list, batch_size=64): + def reward_from_model(self, prompt_list, batch_size=16): # NOTE: need to batch inputs to vllm.encode model for current models that aren't supported by vllm rewards = [] for i in range(0, len(prompt_list), batch_size): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 23bc742bf..776d99c07 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -34,6 +34,8 @@ ) from fairseq2.recipes.model import Model from fairseq2.recipes.trainer import TrainUnit +import numpy as np +from fairseq2.logging import log @dataclass(kw_only=True) @@ -240,22 +242,11 @@ def __init__(self, gangs, reward_model, answer_key, prompt_key): "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2" ) - def extract_text_from_llama3_wrapper(self, input_string): - start_pattern = r"<\|start_header_id\|>user<\|end_header_id\|>" - end_pattern = r"<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>" - start_index = re.search(start_pattern, input_string).end() - end_index = re.search(end_pattern, input_string).start() - # Extract the text between the start and end indices - extracted_text = input_string[start_index:end_index].strip() - return extracted_text - def wrap_text(self, prompt_text, rollout_text): wrapped_text = [ {"role": "user", "content": prompt_text}, {"role": "assistant", "content": rollout_text}, ] - # templated_text = self.tokenizer.apply_chat_template(wrapped_text, tokenize=True) - # tokens_prompt = TokensPrompt(prompt_token_ids=templated_text) chat_str = 
self.tokenizer.apply_chat_template(wrapped_text, tokenize=False) chat_str = chat_str.replace("<|begin_of_text|>", "") @@ -299,8 +290,48 @@ def process_rollouts( return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + def get_divpo_indices(self, rewards, rollouts, p=0.10): + cumulative_logprobs_norm = [] + for rollout_idx in range(len(rollouts[0].outputs)): + logprobs = self.extract_logprobs(rollouts[0].outputs[rollout_idx].logprobs) + cumulative_logprob_norm = sum(logprobs) / len(logprobs) + cumulative_logprobs_norm.append(cumulative_logprob_norm) + + assert len(rewards) == len( + cumulative_logprobs_norm + ), "Rewards and logprobs must have the same length" + + # Convert the list to a numpy array + max_val = max(rewards) + min_val = min(rewards) + + diff = max_val - min_val + thresh_offset = diff * p + top_thresh = max_val - thresh_offset + bot_thresh = min_val + thresh_offset + + chosen_set = [idx for idx, val in enumerate(rewards) if val >= top_thresh] + rejected_set = [idx for idx, val in enumerate(rewards) if val <= bot_thresh] + + # Debugging output + # log.info(f"rewards: {rewards}") + # log.info(f"top_thresh: {top_thresh}, bot_thresh: {bot_thresh}") + # log.info(f"chosen_set: {chosen_set}, rejected_set: {rejected_set}") + + max_reward_idx = min(chosen_set, key=lambda i: cumulative_logprobs_norm[i]) + min_reward_idx = max(rejected_set, key=lambda i: cumulative_logprobs_norm[i]) + + return max_reward_idx, min_reward_idx + + def extract_logprobs(self, data): + logprobs = [] + for item in data: + for key, logprob in item.items(): + logprobs.append(logprob.logprob) + return logprobs + def prepare_preference_batch( - self, prompt_batch: PromptBatch, rollouts + self, prompt_batch: PromptBatch, rollouts, divpo_p=0 ) -> PreferenceBatch: reward_output = self.process_rollouts(rollouts, prompt_batch) @@ -314,8 +345,260 @@ def prepare_preference_batch( for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( zip(reward_output["rewards"], reward_output["tokens"]) ): - chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) - rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + # if self._gangs.root.rank == 0: + # breakpoint() + + if divpo_p > 0: + chosen_rollout_position, rejected_rollout_position = ( + self.get_divpo_indices(i_batch_rewards, rollouts, divpo_p) + ) + else: + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = 
filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output + + def prepare_grpo_batch(self, prompt_batch: PromptBatch, rollouts): + + prompt_rollouts = [] + prompt_lens = [] + rewards = [] + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + prompt = prompt_batch.prompts[i_batch] + rollout_tokens = [ + torch.tensor(prompt + list(c), device=self._gangs.dp.device) + for c in i_batch_tokens + ] + prompt_rollouts.extend(rollout_tokens) + + prompt_lens.extend([len(prompt)] * len(rollout_tokens)) + + rewards.append(i_batch_rewards) + + prompt_rollout_batch = collate_with_target_mask( + prompt_rollouts, prompt_lens, device=self._gangs.dp.device + ) + + rewards = torch.tensor( + rewards, device=self._gangs.dp.device + ).float() # [Batch, Rollouts] + rewards_normalized = (rewards - rewards.mean(dim=1, keepdim=True)) / ( + rewards.std(dim=1, keepdim=True) + 1e-6 + ) # small epsilon to compensate 0 std + + grpo_batch = GRPOBatch( + prompt_rollouts=prompt_rollout_batch, rewards=rewards_normalized + ) + + return grpo_batch, reward_output + + +class AtheneVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_config, gangs): + return AtheneVerifier( + gangs, + reward_model, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + ) + + @property + @override + def name(self): + return "athene_verifier" + + @property + @override + def config_kls(self): + return None + + +class AtheneVerifier(VLLMOutputReward): + def __init__(self, gangs, reward_model, answer_key, prompt_key): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self.reward_model = reward_model + self.tokenizer = AutoTokenizer.from_pretrained( + "/checkpoint/ram/shared/Athene-RM-8B" # FIXME move to configs + ) + + def wrap_text(self, prompt_text, rollout_text): + wrapped_text = [ + {"role": "user", "content": prompt_text}, + {"role": "assistant", "content": rollout_text}, + ] + chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False) + chat_str += "<|reserved_special_token_1|>" + + # if self._gangs.root.rank == 0: + # breakpoint() + + return chat_str + + @override + def process_rollouts( + self, vllm_outputs: List[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + + rollouts_text = [] + rollouts_tokens = [] + for 
rollout_output in i_batch_request_output.outputs: + rollout_text = rollout_output.text + vllm_input = self.wrap_text(prompt_text, rollout_text) + vllm_inputs.append(vllm_input) + rollouts_text.append(rollout_output.text) + rollouts_tokens.append(rollout_output.token_ids) + + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + batch_rewards = generate_rewards( + vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model + ) + + # reshape batch_rewards to [Batch, Rollouts] + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)] + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def get_divpo_indices(self, rewards, rollouts, p=0.10): + cumulative_logprobs_norm = [] + for rollout_idx in range(len(rollouts[0].outputs)): + logprobs = self.extract_logprobs(rollouts[0].outputs[rollout_idx].logprobs) + cumulative_logprob_norm = sum(logprobs) / len(logprobs) + cumulative_logprobs_norm.append(cumulative_logprob_norm) + + assert len(rewards) == len( + cumulative_logprobs_norm + ), "Rewards and logprobs must have the same length" + + # Convert the list to a numpy array + max_val = max(rewards) + min_val = min(rewards) + + diff = max_val - min_val + thresh_offset = diff * p + top_thresh = max_val - thresh_offset + bot_thresh = min_val + thresh_offset + + chosen_set = [idx for idx, val in enumerate(rewards) if val >= top_thresh] + rejected_set = [idx for idx, val in enumerate(rewards) if val <= bot_thresh] + + max_reward_idx = min(chosen_set, key=lambda i: cumulative_logprobs_norm[i]) + min_reward_idx = max(rejected_set, key=lambda i: cumulative_logprobs_norm[i]) + + return max_reward_idx, min_reward_idx + + def extract_logprobs(self, data): + logprobs = [] + for item in data: + for key, logprob in item.items(): + logprobs.append(logprob.logprob) + return logprobs + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts, divpo_p=0 + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + + if divpo_p > 0: + chosen_rollout_position, rejected_rollout_position = ( + self.get_divpo_indices(i_batch_rewards, rollouts, divpo_p) + ) + else: + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) if chosen_rollout_position == rejected_rollout_position: # cant form preference pair when we dont have such rollouts diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py index e0e7a6a13..0530df572 100644 --- a/src/fairseq2/recipes/trainer.py +++ b/src/fairseq2/recipes/trainer.py @@ -118,6 +118,7 @@ class Trainer(StatefulObjectBag, Generic[BatchT]): _valid_data_readers: Sequence[DataReader[BatchT]] _validate_after_n_steps: int _validate_every_n_steps: int | None + _validate_step_0: bool _validate_after_n_data_epochs: int _validate_every_n_data_epochs: int | None _checkpoint_manager: CheckpointManager @@ -177,6 +178,7 @@ def __init__( valid_data_readers: Sequence[DataReader[BatchT]] | None = None, validate_after_n_steps: int = 0, 
validate_every_n_steps: int | None = None, + validate_step_0: bool = False, validate_after_n_data_epochs: int = 0, validate_every_n_data_epochs: int | None = None, checkpoint_after_n_steps: int = 0, @@ -235,6 +237,8 @@ def __init__( The number of steps after which to start validating the model. :param validate_every_n_steps: The step interval at which to validate the model. + :param validate_step_0: + Validate the model before the first training step. :param validate_after_n_data_epochs: The number of data epochs after which to start validating the model. :param validate_every_n_data_epochs: @@ -390,6 +394,7 @@ def __init__( self._validate_after_n_data_epochs = validate_after_n_data_epochs self._validate_every_n_data_epochs = validate_every_n_data_epochs + self._validate_step_0 = validate_step_0 self._checkpoint_manager = checkpoint_manager @@ -580,6 +585,12 @@ def _maybe_restore_state(self) -> None: log.info("Training restored. Resuming.") + try: + self._unit.set_step_nr(self._step_nr) + self._unit.maybe_sync_models(force_sync=True) + except AttributeError:  # only online finetune units implement maybe_sync_models + pass + def _do_run(self) -> None: self._model.module.train() @@ -590,11 +601,22 @@ def _do_run(self) -> None: "train", total=self._max_num_steps, completed=self._step_nr ) + + self._device_stat_tracker.reset() first_iter = True + if self._validate_step_0: + self._validate() + while self._should_run_step(): + + try: + self._unit.set_data_epoch_nr(self._data_epoch_nr) + except AttributeError:  # the unit does not track data epochs + pass + self._maybe_advance_data_epoch() self._step_nr += 1 @@ -614,7 +636,6 @@ def _do_run(self) -> None: if self._should_validate(): self._validate() - self._maybe_request_early_stop() if self._should_checkpoint(): diff --git a/src/fairseq2/setup/_metrics.py b/src/fairseq2/setup/_metrics.py index 0e132c160..842169152 100644 --- a/src/fairseq2/setup/_metrics.py +++ b/src/fairseq2/setup/_metrics.py @@ -69,18 +69,22 @@ def register(name: str, *args: Any) -> None: register("generator_cache_capacity", "Generator/Cache Capacity", 904, format_as_byte_size) # Preference Optimization - register("cpo_loss", "CPO Loss", 0, format_as_float) - register("dpo_loss", "DPO Loss", 0, format_as_float) - register("orpo_loss", "ORPO Loss", 0, format_as_float) - register("simpo_loss", "SimPO Loss", 0, format_as_float) - register("grpo_loss", "GRPO Loss", 0, format_as_float) - register("avg_reward", "Reward", 1, format_as_float) - register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float) - register("rejected_logps", "Rejected Sequence Log Probabilities", 50, format_as_float) - register("logit_entropy", "Logit Entropy", 51, format_as_float) - register("rollout_lengths", "Rollout Length", 70, format_as_float) - register("chosen_lengths", "Chosen Sequence Length", 70, format_as_float) - register("rejected_lengths", "Rejected Sequence Length", 70, format_as_float) + register("cpo_loss", "CPO Loss", 0, format_as_float) + register("dpo_loss", "DPO Loss", 0, format_as_float) + register("orpo_loss", "ORPO Loss", 0, format_as_float) + register("simpo_loss", "SimPO Loss", 0, format_as_float) + register("grpo_loss", "GRPO Loss", 0, format_as_float) + register("avg_reward", "Reward", 1, format_as_float) + register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float) + register("rejected_logps", "Rejected Sequence Log Probabilities", 50, format_as_float) + register("logit_entropy", "Logit Entropy", 51, format_as_float) + register("chosen_logit_entropy", "Chosen Logit Entropy", 51, format_as_float) + register("rejected_logit_entropy", "Rejected Logit Entropy", 51, format_as_float) +
register("total_logit_entropy", "Total Logit Entropy", 51, format_as_float) + register("total_logit_entropy_first100", "Total Logit Entropy (first 100)",51, format_as_float) + register("rollout_lengths", "Rollout Length", 70, format_as_float) + register("chosen_lengths", "Chosen Sequence Length", 70, format_as_float) + register("rejected_lengths", "Rejected Sequence Length", 70, format_as_float) # Memory register("peak_active_mem", "Peak Active Device Memory", 920, format_as_byte_size) diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index 3f2b76350..c7caeb10d 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -19,6 +19,7 @@ GSM8kVerifierHandler, NuminaMathVerifierHandler, SkyworkVerifierHandler, + AtheneVerifierHandler, VLLMOutputRewardHandler, ) @@ -72,9 +73,14 @@ def register_online_finetune_units(context: RuntimeContext) -> None: handler = GSM8kVerifierHandler() registry.register(handler.name, handler) + # SkyworkVerifier handler = SkyworkVerifierHandler() registry.register(handler.name, handler) + # AtheneVerifier + handler = AtheneVerifierHandler() + registry.register(handler.name, handler) + # NuminaMathVerifier handler = NuminaMathVerifierHandler() registry.register(handler.name, handler)