Skip to content

Commit 3847516

Browse files
authored
Fix decoding error on assistants response (#31)
* Fix decoding error on assistants response * Fix API reference links
1 parent 4b8caa7 commit 3847516

File tree

3 files changed

+169
-32
lines changed

3 files changed

+169
-32
lines changed

Sources/OpenAI/Public/Shared/ResponseFormat.swift

+39-13
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,44 @@
88
import Foundation
99

1010

11-
public struct ResponseFormat: Codable {
12-
13-
/// Defaults to text
14-
/// Setting to `json_object` enables JSON mode. This guarantees that the message the model generates is valid JSON.
15-
/// Note that your system prompt must still instruct the model to produce JSON, and to help ensure you don't forget, the API will throw an error if the string JSON does not appear in your system message.
16-
/// Also note that the message content may be partial (i.e. cut off) if `finish_reason="length"`, which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.
17-
/// Must be one of `text `or `json_object`.
18-
public var type: String?
19-
20-
public init(
21-
type: String?)
22-
{
23-
self.type = type
11+
/// Defaults to text
12+
/// Setting to `json_object` enables JSON mode. This guarantees that the message the model generates is valid JSON.
13+
/// Note that your system prompt must still instruct the model to produce JSON, and to help ensure you don't forget, the API will throw an error if the string JSON does not appear in your system message.
14+
/// Also note that the message content may be partial (i.e. cut off) if `finish_reason="length"`, which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.
15+
/// Must be one of `text `or `json_object`.
16+
public enum ResponseFormat: Codable, Equatable {
17+
case auto
18+
case type(String)
19+
20+
enum CodingKeys: String, CodingKey {
21+
case type = "type"
22+
}
23+
24+
public func encode(to encoder: Encoder) throws {
25+
var container = encoder.container(keyedBy: CodingKeys.self)
26+
switch self {
27+
case .auto:
28+
try container.encode("text", forKey: .type)
29+
case .type(let responseType):
30+
try container.encode(responseType, forKey: .type)
31+
}
32+
}
33+
34+
public init(from decoder: Decoder) throws {
35+
// Handle the 'type' case:
36+
if let container = try? decoder.container(keyedBy: CodingKeys.self),
37+
let responseType = try? container.decode(String.self, forKey: .type) {
38+
self = .type(responseType)
39+
return
40+
}
41+
42+
// Handle the 'auto' case:
43+
let container = try decoder.singleValueContainer()
44+
switch try container.decode(String.self) {
45+
case "auto":
46+
self = .auto
47+
default:
48+
throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid response_format structure")
49+
}
2450
}
2551
}

Sources/OpenAI/Public/Shared/ToolChoice.swift

+18-13
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,23 @@ public enum ToolChoice: Codable, Equatable {
4343
}
4444

4545
public init(from decoder: Decoder) throws {
46-
let container = try decoder.container(keyedBy: CodingKeys.self)
47-
if let _ = try? container.decode(String.self, forKey: .none) {
48-
self = .none
49-
return
50-
}
51-
if let _ = try? container.decode(String.self, forKey: .auto) {
52-
self = .auto
53-
return
54-
}
55-
let functionContainer = try container.nestedContainer(keyedBy: FunctionCodingKeys.self, forKey: .function)
56-
let name = try functionContainer.decode(String.self, forKey: .name)
57-
// Assuming the type is always "function" as default if decoding this case.
58-
self = .function(type: "function", name: name)
46+
// Handle the 'function' case:
47+
if let container = try? decoder.container(keyedBy: CodingKeys.self),
48+
let functionContainer = try? container.nestedContainer(keyedBy: FunctionCodingKeys.self, forKey: .function) {
49+
let name = try functionContainer.decode(String.self, forKey: .name)
50+
self = .function(type: "function", name: name)
51+
return
52+
}
53+
54+
// Handle the 'auto' and 'none' cases
55+
let container = try decoder.singleValueContainer()
56+
switch try container.decode(String.self) {
57+
case "none":
58+
self = .none
59+
case "auto":
60+
self = .auto
61+
default:
62+
throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid tool_choice structure")
63+
}
5964
}
6065
}

Tests/OpenAITests/OpenAITests.swift

+112-6
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,117 @@ import XCTest
22
@testable import SwiftOpenAI
33

44
final class OpenAITests: XCTestCase {
5-
func testExample() throws {
6-
// XCTest Documentation
7-
// https://developer.apple.com/documentation/xctest
85

9-
// Defining Test Cases and Test Methods
10-
// https://developer.apple.com/documentation/xctest/defining_test_cases_and_test_methods
11-
}
6+
// OpenAI is loose with their API contract, unfortunately.
7+
// Here we test that `tool_choice` is decodable from a string OR an object,
8+
// which is required for deserializing responses from assistants:
9+
// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tool_choice
10+
func testToolChoiceIsDecodableFromStringOrObject() throws {
11+
let expectedResponseMappings: [(String, ToolChoice)] = [
12+
("\"auto\"", .auto),
13+
("\"none\"", .none),
14+
("{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}", .function(type: "function", name: "my_function"))
15+
]
16+
let decoder = JSONDecoder()
17+
for (response, expectedToolChoice) in expectedResponseMappings {
18+
print(response)
19+
guard let jsonData = response.data(using: .utf8) else {
20+
XCTFail("Could not create json from sample response")
21+
return
22+
}
23+
let toolChoice = try decoder.decode(ToolChoice.self, from: jsonData)
24+
XCTAssertEqual(toolChoice, expectedToolChoice, "Mapping from \(response) did not yield expected result")
25+
}
26+
}
27+
28+
// Here we test that `response_format` is decodable from a string OR an object,
29+
// which is required for deserializing responses from assistants:
30+
// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format
31+
func testResponseFormatIsDecodableFromStringOrObject() throws {
32+
let expectedResponseMappings: [(String, ResponseFormat)] = [
33+
("\"auto\"", .auto),
34+
("{\"type\": \"json_object\"}", .type("json_object")),
35+
("{\"type\": \"text\"}", .type("text"))
36+
]
37+
let decoder = JSONDecoder()
38+
for (response, expectedResponseFormat) in expectedResponseMappings {
39+
print(response)
40+
guard let jsonData = response.data(using: .utf8) else {
41+
XCTFail("Could not create json from sample response")
42+
return
43+
}
44+
let responseFormat = try decoder.decode(ResponseFormat.self, from: jsonData)
45+
XCTAssertEqual(responseFormat, expectedResponseFormat, "Mapping from \(response) did not yield expected result")
46+
}
47+
}
48+
49+
// ResponseFormat is used in other places, and in those places it can *only* be populated with an object.
50+
// OpenAI really suffers in API consistency.
51+
// If a client sets the ResponseFormat to `auto` (which is now a valid case in the codebase), we
52+
// encode to {"type": "text"} to satisfy when response_format can only be an object, such as:
53+
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
54+
func testAutoResponseFormatEncodesToText() throws {
55+
let jsonData = try JSONEncoder().encode(ResponseFormat.auto)
56+
XCTAssertEqual(String(data: jsonData, encoding: .utf8), "{\"type\":\"text\"}")
57+
}
58+
59+
// Verifies that our custom encoding of ResponseFormat supports the 'text' type:
60+
func testTextResponseFormatIsEncodable() throws {
61+
let jsonData = try JSONEncoder().encode(ResponseFormat.type("text"))
62+
XCTAssertEqual(String(data: jsonData, encoding: .utf8), "{\"type\":\"text\"}")
63+
64+
}
65+
66+
// Verifies that our custom encoding of ResponseFormat supports the 'json_object' type:
67+
func testJSONResponseFormatIsEncodable() throws {
68+
let jsonData = try JSONEncoder().encode(ResponseFormat.type("json_object"))
69+
XCTAssertEqual(String(data: jsonData, encoding: .utf8), "{\"type\":\"json_object\"}")
70+
}
71+
72+
// Regression test for decoding assistant runs. Thank you to Martin Brian for the repro:
73+
// https://gist.github.com/mbrian23/6863ffa705ccbb5097bd07efb2355a30
74+
func testThreadRunResponseIsDecodable() throws {
75+
let response = """
76+
{
77+
"id": "run_ZWntP0jJr391lwVu3JqFZbKV",
78+
"object": "thread.run",
79+
"created_at": 1713979538,
80+
"assistant_id": "asst_qxhQxXsecIjqw9cBjFTB6yvd",
81+
"thread_id": "thread_CT4hxsN5N0A5vXg4FeR4pOPD",
82+
"status": "queued",
83+
"started_at": null,
84+
"expires_at": 1713980138,
85+
"cancelled_at": null,
86+
"failed_at": null,
87+
"completed_at": null,
88+
"required_action": null,
89+
"last_error": null,
90+
"model": "gpt-4-1106-preview",
91+
"instructions": "You answer ever question with ‘hello world’",
92+
"tools": [],
93+
"file_ids": [],
94+
"metadata": {},
95+
"temperature": 1.0,
96+
"top_p": 1.0,
97+
"max_completion_tokens": null,
98+
"max_prompt_tokens": null,
99+
"truncation_strategy": {
100+
"type": "auto",
101+
"last_messages": null
102+
},
103+
"incomplete_details": null,
104+
"usage": null,
105+
"response_format": "auto",
106+
"tool_choice": "auto"
107+
}
108+
"""
109+
110+
guard let jsonData = response.data(using: .utf8) else {
111+
XCTFail("Could not create json from sample response")
112+
return
113+
}
114+
let decoder = JSONDecoder()
115+
let runObject = try decoder.decode(RunObject.self, from: jsonData)
116+
XCTAssertEqual(runObject.id, "run_ZWntP0jJr391lwVu3JqFZbKV")
117+
}
12118
}

0 commit comments

Comments
 (0)