Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/api/src/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ impl From<ChatCompletionRequest> for ChatCompletionParams {
store: None,
stream_options: None,
modalities,
return_hidden_states: None,
layers: None,
Comment thread
satojandro marked this conversation as resolved.
Outdated
extra,
}
}
Expand Down
5 changes: 5 additions & 0 deletions crates/api/src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ fn build_flush_chunks(states: &mut StreamUnredactStates, template: &ChunkTemplat
usage: None,
prompt_token_ids: None,
modality: None,
extra: Default::default(),
};
if let Ok(s) = serde_json::to_string(&chunk) {
out.push(Bytes::from(format!("data: {s}\n\n")));
Expand Down Expand Up @@ -780,6 +781,7 @@ fn build_flush_chunks(states: &mut StreamUnredactStates, template: &ChunkTemplat
usage: None,
prompt_token_ids: None,
modality: None,
extra: Default::default(),
};
if let Ok(s) = serde_json::to_string(&chunk) {
out.push(Bytes::from(format!("data: {s}\n\n")));
Expand Down Expand Up @@ -1509,6 +1511,7 @@ mod tests {
usage: None,
prompt_token_ids: None,
modality: None,
extra: Default::default(),
})
}

Expand Down Expand Up @@ -1590,6 +1593,7 @@ mod tests {
usage: None,
prompt_token_ids: None,
modality: None,
extra: Default::default(),
})
}

Expand Down Expand Up @@ -1810,6 +1814,7 @@ mod tests {
usage: Some(inference_providers::models::TokenUsage::new(10, 5)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
};

let stream_chunk = inference_providers::StreamChunk::Chat(chunk.clone());
Expand Down
1 change: 1 addition & 0 deletions crates/inference_providers/src/chunk_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl ChunkContext {
usage,
prompt_token_ids: None,
modality: None,
extra: Default::default(),
}
}

Expand Down
3 changes: 3 additions & 0 deletions crates/inference_providers/src/external/anthropic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ impl ExternalBackend for AnthropicBackend {
prompt_logprobs: None,
prompt_token_ids: None,
kv_transfer_params: None,
extra: Default::default(),
};

// Serialize our normalized response. We intentionally overwrite fields
Expand Down Expand Up @@ -344,6 +345,8 @@ mod tests {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/inference_providers/src/external/gemini/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ fn convert_to_openai_response(
prompt_logprobs: None,
prompt_token_ids: None,
kv_transfer_params: None,
extra: Default::default(),
})
}

Expand Down
2 changes: 2 additions & 0 deletions crates/inference_providers/src/external/openai_compatible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,8 @@ mod tests {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra,
}
}
Expand Down
6 changes: 6 additions & 0 deletions crates/inference_providers/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ impl ResponseTemplate {
prompt_logprobs: None,
prompt_token_ids: None,
kv_transfer_params: None,
extra: Default::default(),
}
}

Expand Down Expand Up @@ -380,6 +381,7 @@ impl ResponseTemplate {
usage: Some(self.token_usage(input_tokens, output_token_count)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
});
}
}
Expand Down Expand Up @@ -425,6 +427,7 @@ impl ResponseTemplate {
usage: Some(self.token_usage(input_tokens, output_token_count)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
});
}
}
Expand Down Expand Up @@ -471,6 +474,7 @@ impl ResponseTemplate {
usage: Some(self.token_usage(input_tokens, output_token_count)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
});

// Stream arguments split by spaces (like content)
Expand Down Expand Up @@ -522,6 +526,7 @@ impl ResponseTemplate {
usage: Some(self.token_usage(input_tokens, output_token_count)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
});
}
}
Expand All @@ -538,6 +543,7 @@ impl ResponseTemplate {
usage: Some(self.token_usage(input_tokens, output_token_count)),
prompt_token_ids: None,
modality: None,
extra: Default::default(),
});

chunks
Expand Down
77 changes: 76 additions & 1 deletion crates/inference_providers/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@ pub struct ChatCompletionParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<String>>,

/// Request per-layer hidden state activations from the backend.
/// Supported by sglang; may not be supported by all backends.
#[serde(skip_serializing_if = "Option::is_none")]
pub return_hidden_states: Option<bool>,

/// Optional subset of layers to return (0-indexed). When empty or
/// omitted while return_hidden_states is true, all layers are returned.
#[serde(skip_serializing_if = "Option::is_none")]
pub layers: Option<Vec<i64>>,

#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}
Expand Down Expand Up @@ -436,6 +446,10 @@ pub struct ChatCompletionChunk {
/// Modality indicator for Qwen3-Omni streaming ("text" or "audio")
#[serde(skip_serializing_if = "Option::is_none")]
pub modality: Option<String>,

/// Additional provider-specific response fields.
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}

/// Text completion streaming chunk (matches OpenAI legacy format)
Expand Down Expand Up @@ -606,6 +620,10 @@ pub struct ChatCompletionResponse {
/// KV cache transfer parameters
#[serde(skip_serializing_if = "Option::is_none")]
pub kv_transfer_params: Option<serde_json::Value>,

/// Additional provider-specific response fields.
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}

/// Wrapper for chat completion response that includes raw bytes.
Expand Down Expand Up @@ -1083,7 +1101,7 @@ pub enum AudioTranscriptionError {
/// # Examples
///
/// ```
/// # use inference_providers::detect_audio_content_type;
/// # use inference_providers::models::detect_audio_content_type;
/// assert_eq!(detect_audio_content_type("speech.mp3"), "audio/mpeg");
/// assert_eq!(detect_audio_content_type("recording.wav"), "audio/wav");
/// assert_eq!(detect_audio_content_type("unknown.xyz"), "application/octet-stream");
Expand Down Expand Up @@ -1256,6 +1274,63 @@ mod tests {
assert!(reserialized.contains("\"web_context_search\""));
assert!(reserialized.contains("\"call_1\""));
}

/// Checks that a streaming chunk keeps a top-level key it has no dedicated
/// field for (`hidden_states`): the key must be reachable through `extra`
/// after deserialization and must appear again in the serialized output.
#[test]
fn chat_completion_chunk_preserves_unknown_fields_round_trip() {
    // Chunk payload carrying one key outside the modelled fields.
    let payload = r#"{
"id": "chatcmpl-hidden",
"object": "chat.completion.chunk",
"created": 1,
"model": "sglang-model",
"choices": [],
"hidden_states": {
"0": [[0.1, 0.2]],
"1": [[0.3, 0.4]]
}
}"#;

    let parsed = serde_json::from_str::<ChatCompletionChunk>(payload).unwrap();

    // Deserialization side: the extra key is retained and its nested values
    // are intact (layer "0", first row, first element).
    let captured = parsed
        .extra
        .get("hidden_states")
        .expect("hidden_states must survive chunk deserialization");
    assert_eq!(captured["0"][0][0], 0.1);

    // Serialization side: both the key and a layer index are written back.
    let wire = serde_json::to_string(&parsed).unwrap();
    for needle in ["\"hidden_states\"", "\"0\""] {
        assert!(wire.contains(needle));
    }
}

/// Checks that a non-streaming response keeps a top-level key it has no
/// dedicated field for (`hidden_states`): the key must be reachable through
/// `extra` after deserialization and must appear again in the serialized
/// output.
#[test]
fn chat_completion_response_preserves_unknown_fields_round_trip() {
    // Response payload carrying one key outside the modelled fields.
    let payload = r#"{
"id": "chatcmpl-hidden",
"object": "chat.completion",
"created": 1,
"model": "sglang-model",
"choices": [],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2
},
"hidden_states": {
"0": [[0.1, 0.2]],
"1": [[0.3, 0.4]]
}
}"#;

    let parsed = serde_json::from_str::<ChatCompletionResponse>(payload).unwrap();

    // Deserialization side: the extra key is retained and its nested values
    // are intact (layer "1", first row, second element).
    let captured = parsed
        .extra
        .get("hidden_states")
        .expect("hidden_states must survive response deserialization");
    assert_eq!(captured["1"][0][1], 0.4);

    // Serialization side: both the key and a layer index are written back.
    let wire = serde_json::to_string(&parsed).unwrap();
    for needle in ["\"hidden_states\"", "\"1\""] {
        assert!(wire.contains(needle));
    }
}
}

// Score models for text similarity endpoint
Expand Down
2 changes: 2 additions & 0 deletions crates/inference_providers/src/vllm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3010,6 +3010,8 @@ mod tests {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down
8 changes: 8 additions & 0 deletions crates/inference_providers/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ async fn test_chat_completion_streaming() {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down Expand Up @@ -353,6 +355,8 @@ async fn test_error_handling() {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down Expand Up @@ -443,6 +447,8 @@ async fn test_chat_completion_streaming_with_tool_calls() {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down Expand Up @@ -636,6 +642,8 @@ async fn test_reasoning_content() {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down
10 changes: 10 additions & 0 deletions crates/services/src/completions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ impl ports::CompletionServiceTrait for CompletionServiceImpl {
store: request.store,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
Comment thread
satojandro marked this conversation as resolved.
Outdated
extra,
};

Expand Down Expand Up @@ -1294,6 +1296,8 @@ impl ports::CompletionServiceTrait for CompletionServiceImpl {
store: request.store,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
Comment thread
satojandro marked this conversation as resolved.
Outdated
extra,
};

Expand Down Expand Up @@ -1808,6 +1812,7 @@ mod tests {
prompt_token_ids: None,
system_fingerprint: None,
modality: None,
extra: Default::default(),
}),
};

Expand All @@ -1834,6 +1839,7 @@ mod tests {
prompt_token_ids: None,
system_fingerprint: None,
modality: None,
extra: Default::default(),
}),
};

Expand Down Expand Up @@ -1946,6 +1952,7 @@ mod tests {
prompt_token_ids: None,
system_fingerprint: None,
modality: None,
extra: Default::default(),
}),
};

Expand All @@ -1961,6 +1968,7 @@ mod tests {
prompt_token_ids: None,
system_fingerprint: None,
modality: None,
extra: Default::default(),
}),
};

Expand All @@ -1987,6 +1995,7 @@ mod tests {
prompt_token_ids: None,
modality: None,
system_fingerprint: None,
extra: Default::default(),
}),
};

Expand Down Expand Up @@ -2110,6 +2119,7 @@ mod tests {
prompt_token_ids: None,
modality: None,
system_fingerprint: None,
extra: Default::default(),
}),
};

Expand Down
2 changes: 2 additions & 0 deletions crates/services/src/inference_provider_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3648,6 +3648,8 @@ mod tests {
store: None,
stream_options: None,
modalities: None,
return_hidden_states: None,
layers: None,
extra: std::collections::HashMap::new(),
};

Expand Down