feat: support logit bias in chat request #3186
base: main
Conversation
Pinging @hanouticelina regarding changes to the hub library (huggingface/huggingface_hub#2724) once this is merged.
# Initialize with empty logit biases if none provided
if logit_biases is None:
    logit_biases = [None] * len(do_sample)
Can we do it like the other arguments, and just send everything initialized instead of None?
Yep, that makes sense; updated in the latest commit. Thanks!
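For illustration, a minimal sketch of the suggested default, assuming the same pattern as the other per-request parameters (one entry per request, empty dict instead of None); the exact commit is not shown here:

```python
# Sketch only: initialize every request with an empty bias dict, mirroring how
# the other heterogeneous parameters are defaulted, instead of passing None.
if logit_biases is None:
    logit_biases = [{} for _ in do_sample]
```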
self.tokenizer = tokenizer
self.logit_bias = logit_bias
This is not necessary; once we have the processor we should let go of the other object (think of this as logit_bias taking ownership).
oo yea removed 🙏
for token_str, bias_value in self.logit_biases.items():
    # Get token ID, either from cache or by computing it
    if token_str not in self.token_id_mapping:
        if token_str.isdigit():
            # If the token string is already a numeric ID
            token_id = int(token_str)
        else:
            # Otherwise, use the tokenizer to get the ID
            tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
            token_id = tokens[0] if tokens else -1  # Use -1 for not found

        self.token_id_mapping[token_str] = token_id

    token_id = self.token_id_mapping[token_str]

    # Apply bias if token ID is valid
    if 0 <= token_id < scores.size(-1):
        scores[:, token_id] += bias_value
This implementation is too slow; the logit_bias must be a precalculated tensor that we just need to add to the scores.
Updated along with a much better LogitBiasProcessor. The bias tensor is now created in __init__ and simply added via add_ in __call__. Thanks 🙏
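For context, a minimal sketch of what such a processor could look like, with the bias tensor built once in __init__ and __call__ reduced to a single in-place add. The class name follows the discussion, but the constructor signature and dtype handling are assumptions, not the PR's actual code:

```python
import torch


class LogitBiasProcessor:
    """Sketch: apply a precomputed logit bias tensor to the scores."""

    def __init__(self, logit_bias: dict, vocab_size: int, device: torch.device):
        # Build the (1, vocab_size) bias tensor once, at construction time.
        token_ids = torch.tensor([int(t) for t in logit_bias.keys()], dtype=torch.long, device=device)
        values = torch.tensor(list(logit_bias.values()), dtype=torch.float32, device=device)
        # Assumes scores are float32 at this point in the pipeline.
        self.bias_tensor = torch.zeros(1, vocab_size, dtype=torch.float32, device=device)
        self.bias_tensor[0, token_ids] = values

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # The call itself is just one broadcasted in-place addition.
        scores.add_(self.bias_tensor)
        return scores
```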
if not self.logit_biases:
    return scores
We should never be in that case, since we only add the processor when it's not empty; I'd happily switch to an assert here.
good point, updated along with logit bias processor changes.
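As a tiny illustration, the suggested assert might look roughly like this (placement in __init__ and the message are assumptions):

```python
# Sketch only: the processor should never be constructed without biases,
# since it is only added to the pipeline when logit_bias is provided.
assert logit_bias, "LogitBiasProcessor requires a non-empty logit_bias"
```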
    self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
):
    self.tokenizer = tokenizer
    self.logit_biases = logit_biases or {}
Suggested change:
- self.logit_biases = logit_biases or {}
+ self.logit_biases = logit_biases
updated along with logit bias processor changes
self.logit_biases = logit_biases or {}

# Pre-compute token IDs for each token string
self.token_id_mapping = {}
Where is the pre-computing?
updated along with logit bias processor changes
    if 0 <= token_id < scores.size(-1):
        scores[i, token_id] += bias_value

return scores
Same here, no for-loop; a single tensor addition should be doable.
updated along with logit bias processor changes
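A rough sketch of the batched variant the reviewer is asking for, with one precomputed bias row per request; the class and argument names are assumptions, and the prefill caveat discussed later in this review still applies:

```python
import torch


class HeterogeneousLogitBiasProcessor:
    """Sketch: per-request logit bias applied as one tensor addition."""

    def __init__(self, logit_biases: list, vocab_size: int, device: torch.device):
        # One row per request; requests without a bias keep an all-zero row.
        self.bias_matrix = torch.zeros(len(logit_biases), vocab_size, dtype=torch.float32, device=device)
        for i, bias in enumerate(logit_biases):
            if bias:
                ids = torch.tensor([int(t) for t in bias.keys()], dtype=torch.long, device=device)
                vals = torch.tensor(list(bias.values()), dtype=torch.float32, device=device)
                self.bias_matrix[i, ids] = vals

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Assumes one score row per request (decode); the prefill case needs
        # the expansion discussed further down in this review.
        scores.add_(self.bias_matrix)
        return scores
```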
@@ -125,6 +136,7 @@ def from_pb(
    tokenizer=tokenizer,
    grammar=pb.grammar,
    grammar_type=pb.grammar_type,
    logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
Why is pb.logit_bias not possible? It would maintain consistency better.
oh good catch, simplified in the latest commit
@@ -500,6 +530,9 @@ def from_pb(
    fsm_grammar_states=(
        fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
    ),
    logit_biases=[
        dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
Same here.
simplified in the latest commit
if token_str.isdigit():
    # If the token string is already a numeric ID
    token_id = int(token_str)
else:
    # Otherwise, use the tokenizer to get the ID
    tokens = self.tokenizer.encode(
        token_str, add_special_tokens=False
    )
    token_id = tokens[0] if tokens else -1  # Use -1 for not found
We should do the sanitization much earlier, way up in the Rust code. We also have the tokenizer there, and we can reject requests that contain an invalid logit_bias early, without having to encode or fail here.
Great point, I've removed the extra checking logic from the Python side and now reject in validate if any values are invalid (token ID not in vocab range).
Force-pushed from 88010ba to 2b996b0.
Looks much better, I think I found a flaw in the actual computation, but other than that it looks great.
router/src/validation.rs (Outdated)
            .collect(),
        )
    }
    _ => None,
Suggested change:
- _ => None,
+ None => None,
For readability (having all explicit variants makes it easier to know there are no shenanigans if the values change). Using parameters.logit_bias.map(|bias| {....}) is another option.
Great points, and I totally agree that a map is cleaner here. Updated in the latest changes with a simpler solution using map.
router/src/validation.rs (Outdated)
Some(bias) if !bias.is_empty() => {
    for (token_str, _) in bias.iter() {
        let token_id = token_str.parse::<u32>().map_err(|_| {
            ValidationError::LogitBiasInvalid(format!(
                "Token ID {} is not a valid number.",
                token_str
            ))
        })?;

        if token_id >= self.vocab_size {
            return Err(ValidationError::LogitBiasInvalid(format!(
                "Token ID {} is out of range. Must be between 0 and {}.",
                token_id,
                self.vocab_size - 1
            )));
        }
    }

    // Transform into the required format
    Some(
        bias.iter()
            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
            .collect(),
    )
}
Didn't you accept actual strings that are tokens before? I'm fine with this version, but it seems different from before, just making sure it wasn't lost in translation.
I think the code can be simplified a bit, for example:
let logit_bias = request.parameters.logit_bias.map(|bias| {
    bias.into_iter()
        .map(|(token_str, value)| {
            let token_id: u32 = token_str.parse().map_err(|_| {
                ValidationError::LogitBiasInvalid(format!("Token ID {} is not a valid number.", token_str))
            })?;
            if token_id >= self.vocab_size {
                return Err(ValidationError::LogitBiasInvalid(format!(
                    "Token ID {} is out of range. Must be between 0 and {}.",
                    token_id,
                    self.vocab_size - 1
                )));
            }
            Ok((token_id, value as f32))
        })
        .collect::<Result<Vec<_>, _>>()
});
Current code is fine, but we could remove some unwrapping + double looping most likely.
The intention is to accept token IDs as strings and not the token values.
I believe the early changes included some token encoding/decoding that has now been removed, which is an improvement.
Additionally, I've updated the logic to be simpler and avoid the unwrap/extra loop in the latest changes.
self.bias_matrix = torch.nn.functional.pad(
    self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
)
Why do we need the padding? I'm surprised here.
It seems to me that self.bias_matrix is (BS, VOCAB), while scores is (SEQ_LENGTHS, VOCAB).
In the most common scenario (decode), it's easy, as BS == SEQ_LENGTHS.
But in prefill, and mixed prefill + decode, by using pad, you're effectively spilling the bias_matrix onto other users, no?
It seems to me we need cu_seqlengths (or whatever we currently have) in order to expand (probably using https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave).
Again, ideally we do this at init time, not at call time (so the call function is literally just an add).
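A small sketch of the expansion the reviewer describes, using torch.repeat_interleave with made-up shapes and per-request lengths (all names and values here are illustrative):

```python
import torch

vocab_size = 8
bias_matrix = torch.zeros(2, vocab_size)   # (BS, VOCAB): one bias row per request
bias_matrix[0, 3] = -100.0                 # request 0 suppresses token 3
bias_matrix[1, 5] = 5.0                    # request 1 boosts token 5

seq_lengths = torch.tensor([4, 1])         # request 0 prefilling 4 tokens, request 1 decoding
scores = torch.randn(int(seq_lengths.sum()), vocab_size)  # (SEQ_LENGTHS, VOCAB)

# Repeat each request's bias row once per token position it owns, so the bias
# never spills onto other requests in a mixed prefill + decode batch.
expanded = torch.repeat_interleave(bias_matrix, seq_lengths, dim=0)
scores.add_(expanded)
```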
I was running into an error that appears to be a bug in the configuration of Qwen/Qwen2-VL-2B-Instruct (used in the test), where the vocab size returned by Qwen2TokenizerFast is not the correct size (151643 instead of 151936): https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/blob/main/config.json. The correct vocab size is set downstream within the custom modeling code, but this is not accessible to the logit processor, so I've added a hacky patch that sets _vocab_size on the tokenizer if the vocab size after loading does not match tokenizer.vocab_size.
This solution feels hacky and obtuse, yet it reliably resolves the issue. Any ideas on a cleaner approach?
Aside from this, I've removed the padding step and the forward is now simply an add_.
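For clarity, the workaround described above might look roughly like this; the attribute name _vocab_size comes from the comment, while the exact place this runs in the server is an assumption:

```python
# Sketch only: patch the tokenizer when the model's real vocab size differs
# from what the tokenizer reports, so the logit bias validation and the
# processor see the correct range.
if model_vocab_size != tokenizer.vocab_size:
    tokenizer._vocab_size = model_vocab_size
```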
Force-pushed from a174f63 to 7659925.
This PR adds support for the previously unused logit_bias parameter in a chat request.
An example request without logit_bias using Qwen/Qwen2-VL-2B-Instruct produces the normal response; with logit_bias specified (specifically "Hello" with a large negative bias to avoid generating it), it returns a different response, which happens to be a nice greeting in Spanish.
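For reference, a hedged sketch of what such a chat request could look like against TGI's OpenAI-compatible endpoint; the URL, the token ID used for "Hello", and the bias value are placeholders, not taken from the PR:

```python
import requests

payload = {
    "model": "Qwen/Qwen2-VL-2B-Instruct",
    "messages": [{"role": "user", "content": "Greet me!"}],
    # Map of token ID (as a string) -> bias value; a large negative bias
    # strongly discourages that token. The ID below is a placeholder.
    "logit_bias": {"12345": -100},
    "max_tokens": 32,
}

response = requests.post("http://localhost:8080/v1/chat/completions", json=payload, timeout=60)
print(response.json()["choices"][0]["message"]["content"])
```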
Important
This PR contains breaking changes, as the logit_bias type has changed from a list to a map.