-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathChat.swift
More file actions
156 lines (135 loc) · 6.62 KB
/
Chat.swift
File metadata and controls
156 lines (135 loc) · 6.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
/// the context in memory between each message sent.
public final class Chat: Sendable {
  /// The model used to produce every response in this chat session.
  private let model: GenerativeModel

  /// Storage for the accumulated conversation turns.
  private let historyStore: History

  init(model: GenerativeModel, history: [ModelContent]) {
    self.model = model
    historyStore = History(history: history)
  }

  /// The previous content from the chat that has been successfully sent and received from the
  /// model. This will be provided to the model for each message sent as context for the discussion.
  public var history: [ModelContent] {
    get { historyStore.history }
    set { historyStore.history = newValue }
  }

  /// The generation configuration inherited from the underlying model, if any.
  var generationConfig: GenerationConfig? { model.generationConfig }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter parts: The new content to send as a single chat message.
  /// - Returns: The model's response if no error occurred.
  /// - Throws: A ``GenerateContentError`` if an error occurred.
  public func sendMessage(_ parts: any PartsRepresentable...) async throws
    -> GenerateContentResponse {
    return try await sendMessage([ModelContent(parts: parts)])
  }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter content: The new content to send as a single chat message.
  /// - Returns: The model's response if no error occurred.
  /// - Throws: A ``GenerateContentError`` if an error occurred.
  public func sendMessage(_ content: [ModelContent]) async throws -> GenerateContentResponse {
    return try await sendMessage(content, generationConfig: generationConfig)
  }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter parts: The new content to send as a single chat message.
  /// - Returns: A stream containing the model's response or an error if an error occurred.
  @available(macOS 12.0, watchOS 8.0, *)
  public func sendMessageStream(_ parts: any PartsRepresentable...) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    return try sendMessageStream([ModelContent(parts: parts)])
  }

  /// Sends a message using the existing history of this chat as context. If successful, the message
  /// and response will be added to the history. If unsuccessful, history will remain unchanged.
  /// - Parameter content: The new content to send as a single chat message.
  /// - Returns: A stream containing the model's response or an error if an error occurred.
  @available(macOS 12.0, watchOS 8.0, *)
  public func sendMessageStream(_ content: [ModelContent]) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    return try sendMessageStream(content, generationConfig: generationConfig)
  }

  // MARK: - Internal

  func sendMessage(_ content: [ModelContent],
                   generationConfig: GenerationConfig?) async throws -> GenerateContentResponse {
    // Chat turns require an explicit role; default any missing roles to "user" before sending.
    let outgoing = content.map(populateContentRole(_:))

    // The full conversation so far, plus the new turn, forms the request payload.
    let response = try await model.generateContent(history + outgoing,
                                                   generationConfig: generationConfig)

    guard let candidateContent = response.candidates.first?.content else {
      let underlying = NSError(domain: "com.google.generative-ai",
                               code: -1,
                               userInfo: [
                                 NSLocalizedDescriptionKey: "No candidates with content available.",
                               ])
      throw GenerateContentError.internalError(underlying: underlying)
    }

    // Tag the reply with the "model" role before persisting it.
    let modelTurn = ModelContent(role: "model", parts: candidateContent.parts)

    // Only a fully successful exchange is recorded: request first, then the reply.
    historyStore.append(contentsOf: outgoing)
    historyStore.append(modelTurn)
    return response
  }

  @available(macOS 12.0, watchOS 8.0, *)
  func sendMessageStream(_ content: [ModelContent], generationConfig: GenerationConfig?) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    // Chat turns require an explicit role; default any missing roles to "user" before sending.
    let outgoing: [ModelContent] = content.map(populateContentRole(_:))

    // The full conversation so far, plus the new turn, forms the request payload.
    let responseStream = try model.generateContentStream(history + outgoing,
                                                         generationConfig: generationConfig)

    return AsyncThrowingStream { continuation in
      Task {
        var streamedChunks: [ModelContent] = []
        do {
          for try await chunk in responseStream {
            // Collect streamed content so the complete reply can be recorded afterwards.
            if let partial = chunk.candidates.first?.content {
              streamedChunks.append(partial)
            }
            // Forward each chunk to the consumer as it arrives.
            continuation.yield(chunk)
          }
        } catch {
          // Propagate the failure unchanged; a failed exchange never touches history.
          continuation.finish(throwing: error)
          return
        }

        // Record the request, then the reply assembled from all streamed chunks.
        self.historyStore.append(contentsOf: outgoing)
        self.historyStore.append(self.historyStore.aggregatedChunks(streamedChunks))
        continuation.finish()
      }
    }
  }

  /// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions.
  private func populateContentRole(_ content: ModelContent) -> ModelContent {
    guard content.role == nil else { return content }
    return ModelContent(role: "user", parts: content.parts)
  }
}