Skip to content

Commit d1053e7

Browse files
authored
fix(openai): tolerate object-form tool-call arguments in streaming (0xPlaygrounds#1822)
* fix(openai): accept object-form tool-call arguments in streaming
* refactor(openai): reuse shared json_utils deserializer for tool-call arguments
1 parent 9ac2c67 commit d1053e7

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

crates/rig-core/src/json_utils.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ pub fn value_to_json_string(value: &serde_json::Value) -> String {
3535
}
3636
}
3737

38+
/// Deserialize a field that may arrive as either a JSON-encoded string or any other
39+
/// JSON value, into `Option<String>`.
40+
///
41+
/// - A string is taken verbatim.
42+
/// - Any other JSON value is re-serialized to its compact JSON-string form (via
43+
/// [`value_to_json_string`]). Object key order is not preserved, which is fine
44+
/// because callers re-parse the string.
45+
/// - `null` or a missing field becomes `None`.
46+
///
47+
/// Tolerates OpenAI-compatible gateways that stream `tool_calls[].function.arguments`
48+
/// as an object (e.g. `{}`) instead of the spec-mandated JSON string (`"{}"`).
49+
pub fn deserialize_json_string_or_value<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
50+
where
51+
D: Deserializer<'de>,
52+
{
53+
let value = Option::<serde_json::Value>::deserialize(deserializer)?;
54+
Ok(match value {
55+
None | Some(serde_json::Value::Null) => None,
56+
Some(v) => Some(value_to_json_string(&v)),
57+
})
58+
}
59+
3860
/// Parse tool arguments from a streamed string payload.
3961
/// Some providers emit an empty string for parameterless tool calls; normalize that to `{}`.
4062
pub fn parse_tool_arguments(arguments: &str) -> serde_json::Result<serde_json::Value> {
@@ -207,6 +229,65 @@ mod tests {
207229
data: serde_json::Value,
208230
}
209231

232+
/// Minimal test fixture: a single `arguments` field routed through
/// `deserialize_json_string_or_value`, so the tests below exercise the deserializer alone.
#[derive(serde::Deserialize)]
struct ArgWrapper {
    // `default` lets a missing field fall back to `None` instead of erroring.
    #[serde(default, deserialize_with = "deserialize_json_string_or_value")]
    arguments: Option<String>,
}
237+
238+
/// Spec-compliant input: when `arguments` is already a JSON-encoded string it must come
/// back unchanged.
#[test]
fn json_string_or_value_string_passthrough() {
    let payload = r#"{"arguments":"{\"a\":1}"}"#;
    let wrapper: ArgWrapper = serde_json::from_str(payload).unwrap();
    assert_eq!(wrapper.arguments.as_deref(), Some(r#"{"a":1}"#));
}
244+
245+
/// Non-compliant gateway: an empty object `{}` must become the string `"{}"` rather than
/// being mistaken for an absent value (`None`).
#[test]
fn json_string_or_value_empty_object() {
    let payload = r#"{"arguments":{}}"#;
    let wrapper: ArgWrapper = serde_json::from_str(payload).unwrap();
    assert_eq!(wrapper.arguments.as_deref(), Some("{}"));
}
252+
253+
/// Non-compliant gateway: an object with fields is re-serialized into a string.
#[test]
fn json_string_or_value_nested_object() {
    let wrapper: ArgWrapper =
        serde_json::from_str(r#"{"arguments":{"path":"/tmp","depth":2}}"#).unwrap();
    // The string was produced by re-serializing a `Value`, and object key order depends on
    // serde_json's `preserve_order` feature. Parse it back and check individual fields
    // instead of comparing raw text.
    let raw = wrapper.arguments.as_deref().unwrap();
    let reparsed: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(reparsed["path"], "/tmp");
    assert_eq!(reparsed["depth"], 2);
}
266+
267+
/// Non-compliant gateway: arrays also count as "any other JSON value" and are rendered to
/// a string. Element order is significant and serde_json keeps it, so the raw string can
/// be compared directly.
#[test]
fn json_string_or_value_array() {
    let payload = r#"{"arguments":[1,2,3]}"#;
    let wrapper: ArgWrapper = serde_json::from_str(payload).unwrap();
    assert_eq!(wrapper.arguments.as_deref(), Some("[1,2,3]"));
}
275+
276+
/// Regression test: an explicit JSON `null` must collapse to `None`, never to the
/// string `"null"`.
#[test]
fn json_string_or_value_null_is_none() {
    let payload = r#"{"arguments":null}"#;
    let wrapper: ArgWrapper = serde_json::from_str(payload).unwrap();
    assert!(wrapper.arguments.is_none());
}
283+
284+
/// An absent `arguments` key also yields `None`; this depends on the field carrying
/// `#[serde(default)]`.
#[test]
fn json_string_or_value_missing_is_none() {
    let payload = r#"{}"#;
    let wrapper: ArgWrapper = serde_json::from_str(payload).unwrap();
    assert!(wrapper.arguments.is_none());
}
290+
210291
#[test]
211292
fn test_merge() {
212293
let a = serde_json::json!({"key1": "value1"});

crates/rig-core/src/providers/openai/completion/streaming.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ use crate::streaming;
1919
/// Function payload of a streamed tool call; both fields are optional.
#[derive(Default, Deserialize, Debug)]
pub(crate) struct StreamingFunction {
    /// Function name, when present in the payload.
    pub(crate) name: Option<String>,
    /// Tool-call arguments as a JSON string.
    ///
    /// The custom deserializer also accepts a non-string JSON value (e.g. an object sent by
    /// a non-compliant gateway) and re-serializes it to a string; `null` becomes `None`.
    /// `default` keeps a missing field as `None`.
    #[serde(
        default,
        deserialize_with = "crate::json_utils::deserialize_json_string_or_value"
    )]
    pub(crate) arguments: Option<String>,
}
2428

@@ -237,6 +241,32 @@ mod tests {
237241
);
238242
}
239243

244+
/// Some OpenAI-compatible gateways send `arguments` as a JSON object instead of the
/// spec-mandated JSON-encoded string. The chunk must be accepted, with the object
/// re-serialized to its string form, rather than dropped.
#[test]
fn test_streaming_function_object_arguments() {
    // Empty object.
    let empty_args: StreamingFunction =
        serde_json::from_str(r#"{"name": "list_dir", "arguments": {}}"#).unwrap();
    assert_eq!(empty_args.name.as_deref(), Some("list_dir"));
    assert_eq!(empty_args.arguments.as_deref(), Some("{}"));

    // Object with a single key, so the compact output is unambiguous.
    let with_args: StreamingFunction =
        serde_json::from_str(r#"{"name": "get_weather", "arguments": {"city": "London"}}"#)
            .unwrap();
    assert_eq!(with_args.arguments.as_deref(), Some(r#"{"city":"London"}"#));
}
258+
259+
/// Both an explicit `null` and a missing `arguments` key must deserialize to `None`.
#[test]
fn test_streaming_function_null_arguments() {
    let explicit_null: StreamingFunction =
        serde_json::from_str(r#"{"name": "list_dir", "arguments": null}"#).unwrap();
    assert!(explicit_null.arguments.is_none());

    let key_absent: StreamingFunction =
        serde_json::from_str(r#"{"name": "list_dir"}"#).unwrap();
    assert!(key_absent.arguments.is_none());
}
269+
240270
#[test]
241271
fn test_streaming_tool_call_deserialization() {
242272
let json = r#"{

0 commit comments

Comments
 (0)