Skip to content

Commit b206de0

Browse files
authored
Merge pull request #6 from gaodeng/main
2 parents 72a5032 + 9d68c5c commit b206de0

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

src/apple-ai.swift

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ private func describeTranscriptEntry(_ entry: Transcript.Entry) -> String {
171171
}
172172

173173
struct Guardrails {
174-
static var developerProvided: LanguageModelSession.Guardrails {
175-
var guardrails = LanguageModelSession.Guardrails.default
174+
static var developerProvided: SystemLanguageModel.Guardrails {
175+
var guardrails = SystemLanguageModel.Guardrails.default
176176

177177
withUnsafeMutablePointer(to: &guardrails) { ptr in
178178
let rawPtr = UnsafeMutableRawPointer(ptr)
@@ -929,28 +929,32 @@ private func convertJSONSchemaToDynamic(_ dict: [String: Any], name: String? = n
929929

930930
@available(macOS 26.0, *)
931931
private func generatedContentToJSON(_ content: GeneratedContent) -> Any {
932-
// Try object
933-
if let dict = try? content.properties() {
932+
switch content.kind {
933+
case .structure(let properties, _):
934934
var result: [String: Any] = [:]
935-
for (k, v) in dict {
936-
result[k] = generatedContentToJSON(v)
935+
for (key, value) in properties {
936+
result[key] = generatedContentToJSON(value)
937937
}
938938
return result
939+
940+
case .array(let elements):
941+
return elements.map { generatedContentToJSON($0) }
942+
943+
case .string(let stringValue):
944+
return stringValue
945+
946+
case .number(let numberValue):
947+
return numberValue
948+
949+
case .bool(let boolValue):
950+
return boolValue
951+
952+
case .null:
953+
return NSNull()
954+
955+
@unknown default:
956+
return content.jsonString
939957
}
940-
941-
// Try array
942-
if let arr = try? content.elements() {
943-
return arr.map { generatedContentToJSON($0) }
944-
}
945-
946-
// Try basic scalar types
947-
if let str = try? content.value(String.self) { return str }
948-
if let intVal = try? content.value(Int.self) { return intVal }
949-
if let dbl = try? content.value(Double.self) { return dbl }
950-
if let boolVal = try? content.value(Bool.self) { return boolVal }
951-
952-
// Fallback to description
953-
return String(describing: content)
954958
}
955959

956960
@available(macOS 26.0, *)
@@ -1191,8 +1195,9 @@ public func appleAIGenerateUnified(
11911195
private func handleBasicMode(context: ConversationContext) async throws -> String {
11921196
let transcript = Transcript(entries: context.transcriptEntries)
11931197
debugPrintTranscript(transcript, prompt: context.currentPrompt)
1198+
let model = SystemLanguageModel(guardrails: Guardrails.developerProvided)
11941199
let session = LanguageModelSession(
1195-
guardrails: Guardrails.developerProvided, transcript: transcript)
1200+
model: model, transcript: transcript)
11961201
let response = try await session.respond(to: context.currentPrompt, options: context.options)
11971202

11981203
// Return as JSON for consistency
@@ -1208,15 +1213,16 @@ private func handleBasicModeStream(
12081213
) async throws {
12091214
let transcript = Transcript(entries: context.transcriptEntries)
12101215
debugPrintTranscript(transcript, prompt: context.currentPrompt)
1216+
let model = SystemLanguageModel(guardrails: Guardrails.developerProvided)
12111217
let session = LanguageModelSession(
1212-
guardrails: Guardrails.developerProvided, transcript: transcript)
1218+
model: model, transcript: transcript)
12131219

12141220
var prev = ""
12151221
for try await cumulative in session.streamResponse(
12161222
to: context.currentPrompt, options: context.options)
12171223
{
1218-
let delta = String(cumulative.dropFirst(prev.count))
1219-
prev = cumulative
1224+
let delta = String(cumulative.content.dropFirst(prev.count))
1225+
prev = cumulative.content
12201226
guard !delta.isEmpty else { continue }
12211227

12221228
delta.withCString { cStr in
@@ -1245,8 +1251,9 @@ private func handleStructuredMode(
12451251
// Create session without tools (structured generation doesn't use tools constructor)
12461252
let transcript = Transcript(entries: context.transcriptEntries)
12471253
debugPrintTranscript(transcript, prompt: context.currentPrompt)
1254+
let model = SystemLanguageModel(guardrails: Guardrails.developerProvided)
12481255
let session = LanguageModelSession(
1249-
guardrails: Guardrails.developerProvided, transcript: transcript)
1256+
model: model, transcript: transcript)
12501257

12511258
// Generate structured response
12521259
let response = try await session.respond(
@@ -1338,8 +1345,9 @@ private func handleToolsMode(
13381345

13391346
let transcript = Transcript(entries: finalEntries)
13401347
debugPrintTranscript(transcript, prompt: context.currentPrompt)
1348+
let model = SystemLanguageModel(guardrails: Guardrails.developerProvided)
13411349
let session = LanguageModelSession(
1342-
guardrails: Guardrails.developerProvided, tools: tools, transcript: transcript)
1350+
model: model, tools: tools, transcript: transcript)
13431351

13441352
// Reset tool call collection
13451353
ToolCallCollector.shared.reset()
@@ -1401,8 +1409,8 @@ private func handleToolsMode(
14011409
}
14021410
}
14031411

1404-
let delta = String(cumulative.dropFirst(prev.count))
1405-
prev = cumulative
1412+
let delta = String(cumulative.content.dropFirst(prev.count))
1413+
prev = cumulative.content
14061414
guard !delta.isEmpty else { continue }
14071415

14081416
delta.withCString { cStr in

0 commit comments

Comments
 (0)