
Log chosen/rejected entropy #1159

Open · wants to merge 8 commits into base: online_training
4 changes: 1 addition & 3 deletions src/fairseq2/recipes/lm/_online_finetune/_grpo.py
@@ -343,9 +343,7 @@
(per_token_loss * target_mask).sum(dim=-1) / target_mask.sum(dim=-1)
).mean(dim=1)
else:
- per_seq_loss = (
- (per_token_loss * target_mask).sum(dim=-1)
- ).mean(dim=1)
+ per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)
Contributor Author commented: formatting
# if self._gangs.root.rank == 0:
# from pudb.remote import set_trace
37 changes: 34 additions & 3 deletions src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py
@@ -263,14 +263,21 @@
rejected_logps, average_rejected_logps = _gather_lprobs_avg(
rejected_output, rejected_target_batch
)
- tgt_logit_entropy = compute_token_level_entropy(
+ chosen_tgt_logit_entropy = compute_token_level_entropy(
chosen_output.logits, chosen_target_batch.target_mask
) # [Batch x Rollouts, 1]
+ rejected_tgt_logit_entropy = compute_token_level_entropy(
+ rejected_output.logits, rejected_target_batch.target_mask
+ ) # [Batch x Rollouts, 1]

max_entropy_regularizer = (
-tgt_logit_entropy.sum() * self._loss_config.entropy_regularizer_scale
-chosen_tgt_logit_entropy.sum()
* self._loss_config.entropy_regularizer_scale
)
- self.metric_bag.update_logit_entropy(tgt_logit_entropy)
+
+ self.metric_bag.update_logit_entropy(chosen_tgt_logit_entropy)
+ self.metric_bag.update_chosen_logit_entropy(chosen_tgt_logit_entropy)
+ self.metric_bag.update_rejected_logit_entropy(rejected_tgt_logit_entropy)

if self._reference_offload:
token_ref_chosen_logps = self.compute_reference_logps(batch.chosen)
@@ -399,6 +406,8 @@
num_dummy_batches: Mean
avg_reward: Mean
avg_loss_zeroer: Mean
+ chosen_logit_entropy: Mean
+ rejected_logit_entropy: Mean
logit_entropy: Mean

def __init__(self, gang: Gang) -> None:
@@ -415,6 +424,28 @@
self.register_metric(
"logit_entropy", Mean(device=gang.device), persistent=False
)
+ self.register_metric(
+ "chosen_logit_entropy", Mean(device=gang.device), persistent=False
+ )
+ self.register_metric(
+ "rejected_logit_entropy", Mean(device=gang.device), persistent=False
+ )
+
+ @torch.inference_mode()
+ def update_chosen_logit_entropy(self, logit_entropy: Tensor):
+ # logit_entropy for chosen sequences
+ batch_size = logit_entropy.size(0)
+ self.chosen_logit_entropy.update(
+ logit_entropy.sum() / batch_size, weight=batch_size
+ )
+
+ @torch.inference_mode()
+ def update_rejected_logit_entropy(self, logit_entropy: Tensor):
+ # logit_entropy for rejected sequences
+ batch_size = logit_entropy.size(0)
+ self.rejected_logit_entropy.update(
+ logit_entropy.sum() / batch_size, weight=batch_size
+ )

@torch.inference_mode()
def update_logit_entropy(self, logit_entropy: Tensor):
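
Note (not part of this PR's diff): compute_token_level_entropy is called above but its implementation is not shown. A minimal sketch of what such a helper could look like, assuming logits of shape [batch, seq, vocab], a target_mask of shape [batch, seq], and a mean over target tokens per sequence; the actual fairseq2 helper may reduce differently:

import torch
from torch import Tensor

def token_level_entropy_sketch(logits: Tensor, target_mask: Tensor) -> Tensor:
    # Per-token entropy H = -sum_v p(v) * log p(v), computed from log-probs for stability.
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    per_token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [batch, seq]
    mask = target_mask.float()
    # Mean entropy over the target (completion) positions of each sequence.
    return (per_token_entropy * mask).sum(dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True)

With one value per sequence, update_chosen_logit_entropy and update_rejected_logit_entropy above log the batch average (logit_entropy.sum() / batch_size) into a Mean metric weighted by the batch size.
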
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/lm/_online_finetune/_remote_vllm.py
@@ -221,7 +221,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None):

return outputs

- def reward_from_model(self, prompt_list, batch_size=64):
+ def reward_from_model(self, prompt_list, batch_size=16):
@jacklanchantin (Contributor Author) commented on May 1, 2025: was getting some vllm cuda OOM with batch_size=64 because we have custom vllm model
# NOTE: need to batch inputs to vllm.encode model for current models that aren't supported by vllm
rewards = []
for i in range(0, len(prompt_list), batch_size):
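
Not part of the diff: the batching pattern the comment describes, sketched in isolation. score_batch is a hypothetical stand-in for whatever call actually invokes the custom vLLM reward model:

from typing import Callable, List

def batched_rewards(
    prompt_list: List[str],
    score_batch: Callable[[List[str]], List[float]],
    batch_size: int = 16,
) -> List[float]:
    # Bound the number of prompts per model call to keep reward-model memory use low.
    rewards: List[float] = []
    for i in range(0, len(prompt_list), batch_size):
        rewards.extend(score_batch(prompt_list[i : i + batch_size]))
    return rewards
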
11 changes: 0 additions & 11 deletions src/fairseq2/recipes/lm/_online_finetune/_rewards.py
@@ -365,22 +365,11 @@
"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
)

- def extract_text_from_llama3_wrapper(self, input_string):
- start_pattern = r"<\|start_header_id\|>user<\|end_header_id\|>"
- end_pattern = r"<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>"
- start_index = re.search(start_pattern, input_string).end()
- end_index = re.search(end_pattern, input_string).start()
- # Extract the text between the start and end indices
- extracted_text = input_string[start_index:end_index].strip()
- return extracted_text
-
def wrap_text(self, prompt_text, rollout_text):
wrapped_text = [
{"role": "user", "content": prompt_text},
{"role": "assistant", "content": rollout_text},
]
- # templated_text = self.tokenizer.apply_chat_template(wrapped_text, tokenize=True)
- # tokens_prompt = TokensPrompt(prompt_token_ids=templated_text)
chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False)
chat_str = chat_str.replace("<|begin_of_text|>", "")

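
Not part of the diff: a small usage example of the chat-template path that wrap_text keeps, with a made-up message pair; the exact template string depends on the tokenizer checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
wrapped_text = [
    {"role": "user", "content": "What is 2 + 2?"},  # hypothetical prompt_text
    {"role": "assistant", "content": "4"},  # hypothetical rollout_text
]
chat_str = tokenizer.apply_chat_template(wrapped_text, tokenize=False)
# Strip the literal BOS token; presumably this avoids a duplicated BOS when the
# string is tokenized again before being scored by the reward model.
chat_str = chat_str.replace("<|begin_of_text|>", "")
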
26 changes: 14 additions & 12 deletions src/fairseq2/setup/_metrics.py
@@ -69,18 +69,20 @@ def register(name: str, *args: Any) -> None:
register("generator_cache_capacity", "Generator/Cache Capacity", 904, format_as_byte_size)

# Preference Optimization
register("cpo_loss", "CPO Loss", 0, format_as_float)
register("dpo_loss", "DPO Loss", 0, format_as_float)
register("orpo_loss", "ORPO Loss", 0, format_as_float)
register("simpo_loss", "SimPO Loss", 0, format_as_float)
register("grpo_loss", "GRPO Loss", 0, format_as_float)
register("avg_reward", "Reward", 1, format_as_float)
register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float)
register("rejected_logps", "Rejected Sequence Log Probabilities", 50, format_as_float)
register("logit_entropy", "Logit Entropy", 51, format_as_float)
register("rollout_lengths", "Rollout Length", 70, format_as_float)
register("chosen_lengths", "Chosen Sequence Length", 70, format_as_float)
register("rejected_lengths", "Rejected Sequence Length", 70, format_as_float)
register("cpo_loss", "CPO Loss", 0, format_as_float)
register("dpo_loss", "DPO Loss", 0, format_as_float)
register("orpo_loss", "ORPO Loss", 0, format_as_float)
register("simpo_loss", "SimPO Loss", 0, format_as_float)
register("grpo_loss", "GRPO Loss", 0, format_as_float)
register("avg_reward", "Reward", 1, format_as_float)
register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float)
register("rejected_logps", "Rejected Sequence Log Probabilities", 50, format_as_float)
register("logit_entropy", "Logit Entropy", 51, format_as_float)
register("chosen_logit_entropy", "Chosen Logit Entropy", 51, format_as_float)
register("rejected_logit_entropy","Rejected Logit Entropy", 51, format_as_float)
Contributor Author commented on lines +81 to +82: these are the only two added. rest are formatted
register("rollout_lengths", "Rollout Length", 70, format_as_float)
register("chosen_lengths", "Chosen Sequence Length", 70, format_as_float)
register("rejected_lengths", "Rejected Sequence Length", 70, format_as_float)

# Memory
register("peak_active_mem", "Peak Active Device Memory", 920, format_as_byte_size)