feat: support logit bias in chat request #3186

Open · wants to merge 14 commits into base: main
26 changes: 0 additions & 26 deletions .github/workflows/client-tests.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions backends/client/src/v3/client.rs
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,
2 changes: 2 additions & 0 deletions backends/client/src/v3/sharded_client.rs
@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
1 change: 1 addition & 0 deletions backends/v2/src/queue.rs
@@ -429,6 +429,7 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: None,
logit_bias: None,
},
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,
2 changes: 2 additions & 0 deletions backends/v3/src/client/grpc_client.rs
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,
2 changes: 2 additions & 0 deletions backends/v3/src/client/sharded_client.rs
@@ -10,6 +10,7 @@ use crate::client::{
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;

@@ -232,6 +233,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
10 changes: 10 additions & 0 deletions backends/v3/src/queue.rs
@@ -5,6 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::HashMap;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
logit_bias: value
.logit_bias
.map(|bias| {
bias.into_iter()
.map(|(token, bias)| (token.to_string(), bias as i32))
.collect::<HashMap<String, i32>>()
})
.unwrap_or_default(),
}
}
}
@@ -568,6 +577,7 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: None,
logit_bias: None,
},
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,
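The conversion above stringifies each token ID and truncates the float bias to an `i32` (Rust's `as i32` truncates toward zero), falling back to an empty map when no bias was supplied. A rough Python mirror of that transformation, purely illustrative and assuming the validated map carries integer token IDs and float biases as the casts suggest:

```python
# Illustrative Python mirror of the Rust conversion above; not code from this PR.
from typing import Dict, Optional


def to_proto_logit_bias(logit_bias: Optional[Dict[int, float]]) -> Dict[str, int]:
    """Stringify token IDs and truncate biases toward zero, like `as i32`."""
    if logit_bias is None:
        return {}
    return {str(token): int(bias) for token, bias in logit_bias.items()}
```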
1 change: 1 addition & 0 deletions benchmark/src/lib.rs
@@ -47,6 +47,7 @@ pub async fn run(
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: std::collections::HashMap::new(),
};

// Initialize terminal properties
4 changes: 2 additions & 2 deletions clients/python/text_generation/types.py
@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict

from text_generation.errors import ValidationError

@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Bias values for token selection
-logit_bias: Optional[List[float]] = None
+logit_bias: Optional[Dict[str, int]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position
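With the client type changed from a float list to a token-ID-to-bias map, a request can bias specific tokens directly. A minimal sketch against TGI's OpenAI-compatible chat route; the endpoint URL and token IDs below are placeholders, not values taken from this PR:

```python
# Hypothetical usage sketch; the URL and token IDs are placeholders.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",  # placeholder; a single-model server ignores this (assumption)
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 10,
        # Token IDs (as strings) mapped to integer biases in [-100, 100]:
        # +100 effectively forces a token, -100 effectively bans it.
        "logit_bias": {"1923": 100, "1924": -100},
    },
)
print(response.json()["choices"][0]["message"]["content"])
```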
21 changes: 16 additions & 5 deletions docs/openapi.json
@@ -995,12 +995,12 @@
"nullable": true
},
"logit_bias": {
"type": "array",
"items": {
"type": "number",
"format": "float"
"type": "object",
"description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
"additionalProperties": {
"type": "integer",
"format": "int32"
},
"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
"nullable": true
},
"logprobs": {
@@ -1589,6 +1589,17 @@
"default": "null",
"nullable": true
},
"logit_bias": {
"type": "object",
"description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
"default": "null",
"additionalProperties": {
"type": "integer",
"format": "int32"
},
"example": "{\"1923\": 100, \"1924\": -100}",
"nullable": true
},
"max_new_tokens": {
"type": "integer",
"format": "int32",
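The schema description above pins down the semantics: the bias is added to the model's logits before sampling. A minimal sketch of that arithmetic, illustrative only and not the server-side implementation from this PR:

```python
# Illustrative only: apply a logit_bias map to raw logits before sampling.
from typing import Dict, List


def apply_logit_bias(logits: List[float], logit_bias: Dict[str, int]) -> List[float]:
    """Add each bias to the logit of its token ID.

    Keys are token IDs serialized as strings, matching the schema above;
    values lie in [-100, 100].
    """
    biased = list(logits)
    for token_id, bias in logit_bias.items():
        biased[int(token_id)] += bias
    return biased


# Example: +100 effectively forces token 2; -100 effectively bans token 0.
print(apply_logit_bias([0.1, 0.2, 0.3], {"2": 100, "0": -100}))
```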
@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "Hello! How can I help you today?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1745337495,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}
@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "¡Hola! ¿Cómo puedo ayudarte?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}
@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Chat!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 3,
"prompt_tokens": 25,
"total_tokens": 28
}
}
@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.2.3-dev0-native",
"usage": null
}