Preserve hidden state fields in chat completions #679
Code Review
This pull request introduces return_hidden_states and layers fields to ChatCompletionParams to support requesting per-layer hidden state activations from backends like sglang. It also adds an extra map to ChatCompletionChunk and ChatCompletionResponse to preserve provider-specific fields during serialization and deserialization. The review feedback correctly identifies that the newly added fields are hardcoded to None during request conversions and service mappings, which prevents them from being populated from the incoming request's extra map. Actionable code suggestions are provided to extract these fields dynamically.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Pull request overview
This PR extends the chat-completions normalization layer to support hidden-state passthrough end-to-end, ensuring provider-specific response fields (e.g., hidden_states) are preserved when cloud-api deserializes and reserializes both full responses and streaming chunks.
Changes:
- Added `return_hidden_states` and `layers` to `ChatCompletionParams` to explicitly request hidden states from backends that support them.
- Added flattened `extra` maps to `ChatCompletionChunk` and `ChatCompletionResponse` to round-trip unknown/provider-specific response fields.
- Updated struct literals and added regression tests to ensure unknown fields survive deserialize/reserialize for both streaming and non-streaming responses.
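
To illustrate why a catch-all map keeps unknown fields from being dropped, here is a minimal std-only model of the round trip. It is a sketch, not the PR's code: `ResponseSketch`, `from_fields`, and `to_fields` are hypothetical names, values are plain strings rather than JSON, and the real types presumably get this behavior from a flattened serde map rather than hand-written matching.

```rust
use std::collections::BTreeMap;

/// Hypothetical, simplified stand-in for a response type with a catch-all map.
/// Only `id` and `model` are modeled as known fields here.
#[derive(Debug, PartialEq)]
struct ResponseSketch {
    id: String,
    model: String,
    /// Every field the struct does not know about lands here.
    extra: BTreeMap<String, String>,
}

impl ResponseSketch {
    /// "Deserialize": known keys fill typed fields, everything else goes to `extra`.
    fn from_fields(fields: &[(&str, &str)]) -> ResponseSketch {
        let mut id = String::new();
        let mut model = String::new();
        let mut extra = BTreeMap::new();
        for (key, value) in fields {
            match *key {
                "id" => id = value.to_string(),
                "model" => model = value.to_string(),
                other => {
                    extra.insert(other.to_string(), value.to_string());
                }
            }
        }
        ResponseSketch { id, model, extra }
    }

    /// "Reserialize": emit the known fields, then every preserved extra field.
    fn to_fields(&self) -> Vec<(String, String)> {
        let mut out = vec![
            ("id".to_string(), self.id.clone()),
            ("model".to_string(), self.model.clone()),
        ];
        out.extend(self.extra.iter().map(|(k, v)| (k.clone(), v.clone())));
        out
    }
}
```

Without the `extra` map, the `other` arm would have nowhere to store a field such as `hidden_states`, and it would be missing from the reserialized output.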
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| crates/services/src/inference_provider_pool/mod.rs | Updates test params to include new ChatCompletionParams fields. |
| crates/services/src/completions/mod.rs | Extracts return_hidden_states / layers from request extra when building provider params. |
| crates/inference_providers/tests/integration_tests.rs | Updates integration tests to populate new request fields. |
| crates/inference_providers/src/vllm/mod.rs | Updates vLLM tests to include new request fields. |
| crates/inference_providers/src/models.rs | Adds new request fields, adds extra passthrough on chunk/response, and adds round-trip regression tests; fixes doctest import path. |
| crates/inference_providers/src/mock.rs | Updates mock response builders to initialize new extra field. |
| crates/inference_providers/src/external/openai_compatible.rs | Updates tests to include new request fields. |
| crates/inference_providers/src/external/gemini/mod.rs | Ensures synthesized responses initialize new extra field. |
| crates/inference_providers/src/external/anthropic/mod.rs | Ensures synthesized responses initialize new extra field; updates tests to include new request fields. |
| crates/inference_providers/src/chunk_builder.rs | Ensures built chunks initialize new extra field. |
| crates/api/src/routes/completions.rs | Ensures SSE flush chunks initialize new extra field. |
| crates/api/src/conversions.rs | Extracts return_hidden_states / layers from API request extra when converting to provider params. |
return_hidden_states: extra.remove("return_hidden_states").and_then(|v| v.as_bool()),
layers: extra.remove("layers").and_then(|v| serde_json::from_value(v).ok()),
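
The suggested extraction can be sketched in isolation as follows. This is a std-only approximation under stated assumptions: the local `Value` enum stands in for `serde_json::Value`, `into_layers` stands in for `serde_json::from_value(v).ok()`, and `Vec<u32>` is an assumed type for `layers` (the PR text does not state it). `HiddenStateParams` and `extract_hidden_state_params` are hypothetical names.

```rust
use std::collections::HashMap;

/// Minimal stand-in for a JSON value; the suggested code uses `serde_json::Value`.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Bool(bool),
    Number(u32),
    Array(Vec<Value>),
    String(String),
}

impl Value {
    /// Returns the boolean if this value is a `Bool`, otherwise `None`.
    fn as_bool(&self) -> Option<bool> {
        match self {
            Value::Bool(b) => Some(*b),
            _ => None,
        }
    }

    /// Stand-in for `serde_json::from_value(v).ok()` with an assumed `Vec<u32>` target:
    /// succeeds only if the value is an array made up entirely of numbers.
    fn into_layers(self) -> Option<Vec<u32>> {
        match self {
            Value::Array(items) => items
                .into_iter()
                .map(|item| match item {
                    Value::Number(n) => Some(n),
                    _ => None,
                })
                .collect(),
            _ => None,
        }
    }
}

/// Hypothetical holder for the two new request fields.
#[derive(Debug, PartialEq)]
struct HiddenStateParams {
    return_hidden_states: Option<bool>,
    layers: Option<Vec<u32>>,
}

/// Moves the two keys out of `extra` into typed fields, leaving other keys in place.
fn extract_hidden_state_params(extra: &mut HashMap<String, Value>) -> HiddenStateParams {
    HiddenStateParams {
        return_hidden_states: extra
            .remove("return_hidden_states")
            .and_then(|v| v.as_bool()),
        layers: extra.remove("layers").and_then(Value::into_layers),
    }
}
```

One consequence of this pattern worth keeping in mind: `remove` runs before the type check, so a wrongly typed value (for example the string `"yes"` for `return_hidden_states`) yields `None` and is also gone from `extra`, meaning it is neither applied nor forwarded.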
Thanks for the PR. No backend currently has --enable-return-hidden-states, so this would be a no-op. May I ask what your use case is?
Hey mate, thanks for the prompt review and reply. I'm doing mechanistic interpretability research that requires per-layer activations from transformer models. SGLang and vLLM both support returning hidden states natively, so the main gap is the proxy layer stripping them before they reach the client(?). I understand that this may currently be a no-op, but it was actually suggested to me by Illia, so I hope we can find a way to enable it. Happy to have a chat on Telegram or a call @Evrard-Nil
Summary
Adds support for hidden-state passthrough in chat completions.
- Adds `return_hidden_states` and `layers` request fields to `ChatCompletionParams`
- Adds `extra` catch-all fields to `ChatCompletionChunk` and `ChatCompletionResponse`
- Preserves `hidden_states` during deserialize/reserialize
- Updates existing struct literals to initialize `extra` values
Testing
- `cargo fmt --check`
- `cargo test -p inference_providers`
- `cargo check`
- `git diff --check`
This mirrors the existing `ChatDelta` catch-all behavior so unknown upstream fields are not silently dropped.