-
Notifications
You must be signed in to change notification settings - Fork 106
Log chosen/rejected entropy #1159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: online_training
Are you sure you want to change the base?
Conversation
@@ -221,7 +221,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None): | |||
|
|||
return outputs | |||
|
|||
def reward_from_model(self, prompt_list, batch_size=64): | |||
def reward_from_model(self, prompt_list, batch_size=16): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
was getting some vllm cuda OOM with batch_size=64
because we have custom vllm model
register("chosen_logit_entropy", "Chosen Logit Entropy", 51, format_as_float) | ||
register("rejected_logit_entropy","Rejected Logit Entropy", 51, format_as_float) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are the only two added. rest are formatted
per_seq_loss = ( | ||
(per_token_loss * target_mask).sum(dim=-1) | ||
).mean(dim=1) | ||
per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
formatting
) # [Batch x Rollouts, 1] | ||
|
||
# entropy for all N rollouts | ||
logit_entropy = self.get_all_rollouts_entropy(rollouts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if we want this. previously logit_entropy
was computed for the chosen sequences. here i'm computing it for all rollouts
What does this PR do? Please describe:
Check list: