feat: support logit bias in chat request #3186


Open

drbh wants to merge 14 commits into main from the support-logit-bias-in-chat branch

Conversation

Collaborator

@drbh drbh commented Apr 22, 2025

This PR adds support for the previously unused logit_bias parameter in a chat request.

Example request without logit_bias using Qwen/Qwen2-VL-2B-Instruct

curl http://localhost:3000/v1/chat/completions -X POST \
-H 'Content-Type: application/json' \
-d '{
    "model": "tgi",
    "seed": 42,
    "max_tokens": 10,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "say Hello"
                }
            ]
        }
    ]
}'

response

{
    "object": "chat.completion",
    "id": "",
    "created": 1745338432,
    "model": "Qwen/Qwen2-VL-2B-Instruct",
    "system_fingerprint": "3.2.3-dev0-native",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "Hello! How can I help you today?"
            },
            "logprobs": null,
            "finish_reason": "length"
        }
    ],
    "usage": {
        "prompt_tokens": 21,
        "completion_tokens": 10,
        "total_tokens": 31
    }
}

With logit_bias specified (specifically, the "Hello" token given a large negative bias to avoid generating it):

curl http://localhost:3000/v1/chat/completions -X POST \
-H 'Content-Type: application/json' \
-d '{
    "model": "tgi",
    "seed": 42,
    "max_tokens": 10,
    "logit_bias": {
        "9707": -100
    },
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "say Hello"
                }
            ]
        }
    ]
}'

it returns a different response, which happens to be a nice greeting in Spanish:

{
    "object": "chat.completion",
    "id": "",
    "created": 1745338592,
    "model": "Qwen/Qwen2-VL-2B-Instruct",
    "system_fingerprint": "3.2.3-dev0-native",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "¡Hola! ¿Cómo puedo ayudarte?"
            },
            "logprobs": null,
            "finish_reason": "length"
        }
    ],
    "usage": {
        "prompt_tokens": 21,
        "completion_tokens": 10,
        "total_tokens": 31
    }
}
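For reference, the ID 9707 used in the request above corresponds to the "Hello" token in the Qwen2-VL tokenizer. A quick way to look up such IDs (an illustrative snippet, not part of this PR):

from transformers import AutoTokenizer

# Illustrative only: find the token ID to bias against.
# "Hello" encodes to ID 9707 with the Qwen2-VL tokenizer, which is the ID
# given the -100 bias in the request above.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
print(tokenizer.encode("Hello", add_special_tokens=False))  # expected: [9707]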

Important

This PR contains breaking changes, as the logit_bias type has changed from a list to a map.
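Since the parameter now follows the OpenAI-style map of token-ID string to bias value, an OpenAI-compatible client call would look roughly like the sketch below (assuming the same local TGI server as in the curl examples above; not code from this PR):

from openai import OpenAI

# Point an OpenAI-style client at the local TGI server used in the curl examples.
client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")
resp = client.chat.completions.create(
    model="tgi",
    seed=42,
    max_tokens=10,
    logit_bias={"9707": -100},  # map of token-ID string -> bias value
    messages=[{"role": "user", "content": "say Hello"}],
)
print(resp.choices[0].message.content)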

Collaborator Author

@drbh drbh commented Apr 24, 2025

Pinging @hanouticelina regarding changes to the hub library (huggingface/huggingface_hub#2724) once this is merged.

Comment on lines 267 to 270
# Initialize with empty logit biases if none provided
if logit_biases is None:
    logit_biases = [None] * len(do_sample)

Collaborator

Can we do it like the other arguments, and just send everything initialized instead of None?

Collaborator Author

Yep, that makes sense; updated in the latest commit. Thanks!

self.tokenizer = tokenizer
self.logit_bias = logit_bias

Collaborator

This is not necessary; once we have the processor we should let go of the other object (think of this as logit_bias taking ownership).

Collaborator Author

oo yea removed 🙏

Comment on lines 646 to 651
for token_str, bias_value in self.logit_biases.items():
    # Get token ID, either from cache or by computing it
    if token_str not in self.token_id_mapping:
        if token_str.isdigit():
            # If the token string is already a numeric ID
            token_id = int(token_str)
        else:
            # Otherwise, use the tokenizer to get the ID
            tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
            token_id = tokens[0] if tokens else -1  # Use -1 for not found

        self.token_id_mapping[token_str] = token_id

    token_id = self.token_id_mapping[token_str]

    # Apply bias if token ID is valid
    if 0 <= token_id < scores.size(-1):
        scores[:, token_id] += bias_value

Collaborator

This implementation is too slow; the logit_bias must be a precalculated tensor that we just need to add to the scores.

Collaborator Author

Updated, along with a much better LogitBiasProcessor. The bias tensor is now created in __init__ and simply added via add_ in __call__. Thanks 🙏
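For context, a minimal sketch of that design (class and argument names here are assumptions, not the PR's exact code): the bias tensor is built once in __init__, and __call__ reduces to a single in-place add.

import torch
from transformers import PreTrainedTokenizerBase


class LogitBiasProcessor:
    """Apply a fixed per-token bias to the logits with one in-place add."""

    def __init__(self, logit_biases: dict, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        # Only construct the processor when there is something to apply.
        assert logit_biases, "LogitBiasProcessor created with empty biases"
        vocab_size = len(tokenizer)
        bias = torch.zeros(vocab_size, dtype=torch.float32)
        for token_id_str, value in logit_biases.items():
            bias[int(token_id_str)] = value
        # Shape (1, vocab_size) so it broadcasts over the batch dimension.
        self.bias_tensor = bias.unsqueeze(0).to(device)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # __call__ is just an addition; all the work happened in __init__.
        scores.add_(self.bias_tensor)
        return scores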

Comment on lines 642 to 643
if not self.logit_biases:
    return scores

Collaborator

We should never be in that case, since we only add the processor when it's not empty; I'd happily switch to an assert here.

Collaborator Author

good point, updated along with logit bias processor changes.

    self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
):
    self.tokenizer = tokenizer
    self.logit_biases = logit_biases or {}

Collaborator

Suggested change
self.logit_biases = logit_biases or {}
self.logit_biases = logit_biases

Collaborator Author

updated along with logit bias processor changes

self.logit_biases = logit_biases or {}

# Pre-compute token IDs for each token string
self.token_id_mapping = {}

Collaborator

Where is the pre-computing?

Collaborator Author

updated along with logit bias processor changes

if 0 <= token_id < scores.size(-1):
    scores[i, token_id] += bias_value

return scores

Collaborator

Same here: no for-loop; a single tensor addition should be doable.

Collaborator Author

updated along with logit bias processor changes
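To illustrate the shape of that change, a minimal sketch of a batched processor (class and argument names are assumptions, not the exact code in this PR): the per-request bias rows are assembled once up front, so applying them is a single addition.

import torch
from typing import List, Optional


class HeterogeneousLogitBiasProcessor:
    """Per-request logit biases, applied as a single tensor addition."""

    def __init__(self, logit_biases: List[Optional[dict]], vocab_size: int, device: torch.device):
        # One row per request; rows without biases stay all-zero.
        bias_matrix = torch.zeros((len(logit_biases), vocab_size), dtype=torch.float32)
        for row, biases in enumerate(logit_biases):
            for token_id_str, value in (biases or {}).items():
                bias_matrix[row, int(token_id_str)] = value
        self.bias_matrix = bias_matrix.to(device)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # scores is (batch, vocab) during decode, so this is one add with no loop.
        scores.add_(self.bias_matrix)
        return scores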

@@ -125,6 +136,7 @@ def from_pb(
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,

Collaborator

Why is pb.logit_bias not possible? It would maintain consistency better.

Collaborator Author

oh good catch, simplified in the latest commit

@@ -500,6 +530,9 @@ def from_pb(
fsm_grammar_states=(
    fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
logit_biases=[
    dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb

Collaborator

Same here.

Collaborator Author

simplified in the latest commit

Comment on lines 709 to 717
if token_str.isdigit():
# If the token string is already a numeric ID
token_id = int(token_str)
else:
# Otherwise, use the tokenizer to get the ID
tokens = self.tokenizer.encode(
token_str, add_special_tokens=False
)
token_id = tokens[0] if tokens else -1 # Use -1 for not found

Collaborator

We should do the sanitization much earlier, way up in the Rust code; we also have the tokenizer there, and we can reject requests that contain an invalid logit_bias early, without having to encode or fail here.

Collaborator Author

Great point. I've removed the extra checking logic from the Python side and now reject in validate if any values are invalid (token ID not in vocab range).

@drbh drbh requested a review from Narsil April 29, 2025 15:35
@drbh drbh force-pushed the support-logit-bias-in-chat branch from 88010ba to 2b996b0 on April 30, 2025 14:47

Collaborator

@Narsil Narsil left a comment

Looks much better. I think I found a flaw in the actual computation, but other than that it looks great.

.collect(),
)
}
_ => None,

Collaborator

Suggested change
_ => None,
None => None,

For readability (having all explicit variants makes it easier to know there are no shenanigans if the values change).

Using parameters.logit_bias.map(|bias| {....}) is another option.

Collaborator Author

Great points, and I totally agree that map is cleaner here. Updated in the latest changes with a simpler solution using map.

Comment on lines 406 to 430
Some(bias) if !bias.is_empty() => {
for (token_str, _) in bias.iter() {
let token_id = token_str.parse::<u32>().map_err(|_| {
ValidationError::LogitBiasInvalid(format!(
"Token ID {} is not a valid number.",
token_str
))
})?;

if token_id >= self.vocab_size {
return Err(ValidationError::LogitBiasInvalid(format!(
"Token ID {} is out of range. Must be between 0 and {}.",
token_id,
self.vocab_size - 1
)));
}
}

// Transform into the required format
Some(
bias.iter()
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
.collect(),
)
}

Collaborator

Didn't you accept actual strings that are tokens before? I'm fine with this version, but it seems different from before; just making sure it wasn't lost in translation.

I think the code can be simplified a bit.

let logit_bias = request.parameters.logit_bias.map(|bias| {
    let bias: Result<Vec<_>, _> = bias
        .into_iter()
        .map(|(token_str, value)| {
            let token_id: u32 = token_str.parse().map_err(...)?;
            if token_id > self.vocab_size {
                ....
            }
            Ok((token_id, value))
        })
})

The current code is fine, but we could most likely remove some unwrapping and the double looping.

Collaborator Author

The intention is to accept token IDs as strings, not the token values.

I believe the early changes included some token encoding/decoding that is now removed, which is an improvement.

Additionally, I've updated the logic to be simpler and avoid the unwrap/extra loop in the latest changes.

Comment on lines 703 to 705
self.bias_matrix = torch.nn.functional.pad(
self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
)

Collaborator

Why do we need the padding? I'm surprised here.

It seems to me that self.bias_matrix is (BS, VOCAB), while scores is (SEQ_LENGTHS, VOCAB).

In the most common scenario (decode), it's easy, as BS == SEQ_LENGTHS.

But in prefill, and mixed prefill + decode, by using pad, you're effectively spilling the bias_matrix onto other users, no?
It seems to me we need cu_seqlengths (or whatever we currently have) in order to expand (probably using https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave ).

Again, ideally we do this at init time, not at call time (so the call function is literally just an add).
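To make the suggestion concrete, a tiny sketch of that expansion (shapes and variable names are assumed, not code from the PR):

import torch

# bias_matrix has one row per request: (BS, VOCAB)
bias_matrix = torch.tensor([[0.0, -100.0, 0.0],
                            [5.0, 0.0, 0.0]])
# one entry per request: how many score rows each request contributes
input_lengths = torch.tensor([3, 1])  # request 0 prefilling 3 tokens, request 1 decoding

# Expand to one row per score row, (sum(input_lengths), VOCAB), so a plain
# scores.add_(expanded) biases only the rows belonging to each request.
expanded = torch.repeat_interleave(bias_matrix, input_lengths, dim=0)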

Collaborator Author

@drbh drbh commented May 5, 2025

I was running into an error that appears to be a bug in the configuration of Qwen/Qwen2-VL-2B-Instruct (used in the test), where the vocab size returned by Qwen2TokenizerFast is not the correct size (151643 instead of 151936) https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/blob/main/config.json.

The correct vocab size is set downstream within the custom modeling code, but it is not accessible to the logit_processor, so I've added a hacky patch that sets _vocab_size on the tokenizer if the vocab_size after loading does not match tokenizer.vocab_size.

This solution feels hacky and obtuse, yet it reliably resolves the issue. Any ideas on a cleaner approach?

Aside from this, I've removed the padding step and now the forward is simply an add_.

@drbh drbh force-pushed the support-logit-bias-in-chat branch from a174f63 to 7659925 on May 5, 2025 21:34