-
Notifications
You must be signed in to change notification settings - Fork 377
Expand file tree
/
Copy pathChatModel.swift
More file actions
113 lines (92 loc) · 2.95 KB
/
ChatModel.swift
File metadata and controls
113 lines (92 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright © 2025 Apple Inc.
import MLXLMHuggingFace
import MLXLLM
import MLXLMCommon
import MLXLMTransformers
import SwiftUI
/// which model to load -- Gemma 3 1B, 4-bit QAT-quantized, looked up in the LLM registry
private let modelConfiguration = LLMRegistry.gemma3_1B_qat_4bit
/// instructions for the model (the system prompt) -- prepended to every chat session
private let instructions =
"""
You are a friendly and helpful chatbot.
"""
/// parameters controlling generation; temperature 0.5 trades determinism for some variety
private let generateParameters = GenerateParameters(temperature: 0.5)
/// Downloads and loads the weights for the model -- we have one of these in the process
/// Downloads and loads the weights for the model -- we have one of these in the process.
///
/// All state lives on the main actor; `progress` and `isLoaded` are observable
/// so SwiftUI views can react to the download.
@MainActor @Observable public class ModelLoader {
    /// Lifecycle of the download/load.
    enum State {
        case idle
        case loading(Task<ModelContainer, Error>)
        case loaded(ModelContainer)
    }

    /// Download progress in [0, 1], updated on the main actor during loading.
    public var progress = 0.0

    /// True once the model container has finished loading.
    public var isLoaded: Bool {
        switch state {
        case .idle, .loading: false
        case .loaded: true
        }
    }

    private var state = State.idle

    /// Returns the loaded model container, starting the download on first call.
    ///
    /// Concurrent callers while a load is in flight await the same task.
    /// On failure the state is reset to `.idle` so a later call can retry --
    /// previously a failed task stayed cached in `.loading`, making every
    /// subsequent call rethrow the stale error with no way to recover.
    /// - Throws: any error from `loadModelContainer` (e.g. download failure).
    public func model() async throws -> ModelContainer {
        switch self.state {
        case .idle:
            let task = Task {
                // download and report progress
                try await loadModelContainer(
                    from: HubClient.default,
                    configuration: modelConfiguration
                ) { value in
                    Task { @MainActor in
                        self.progress = value.fractionCompleted
                    }
                }
            }
            self.state = .loading(task)
            do {
                let model = try await task.value
                self.state = .loaded(model)
                return model
            } catch {
                // reset so the next call retries instead of re-awaiting a dead task
                self.state = .idle
                self.progress = 0.0
                throw error
            }
        case .loading(let task):
            return try await task.value
        case .loaded(let model):
            return model
        }
    }
}
/// View model for the ChatSession
/// View model for the ChatSession.
///
/// Owns the conversation transcript and the in-flight generation task,
/// all isolated to the main actor for SwiftUI observation.
@MainActor @Observable public class ChatModel {
    private let session: ChatSession

    /// back and forth conversation between the user and LLM
    public var messages = [Chat.Message]()

    /// The in-flight generation task; nil when idle.
    private var task: Task<Void, Error>?

    /// True while a response is being generated.
    public var isBusy: Bool {
        task != nil
    }

    public init(model: ModelContainer) {
        self.session = ChatSession(
            model,
            instructions: instructions,
            generateParameters: generateParameters)
    }

    /// Cancels the in-flight generation, if any. The partial reply already
    /// streamed into `messages` is kept.
    public func cancel() {
        task?.cancel()
    }

    /// Sends `message` to the model and streams the reply into `messages`.
    ///
    /// No-op while a previous response is still in flight. Appends the user
    /// message plus a "..." assistant placeholder that the first streamed
    /// chunk replaces; later chunks are appended.
    public func respond(_ message: String) {
        guard task == nil else { return }
        self.messages.append(.init(role: .user, content: message))
        self.messages.append(.init(role: .assistant, content: "..."))
        let lastIndex = self.messages.count - 1
        self.task = Task {
            // Clear `task` on EVERY exit path -- success, cancellation, or a
            // thrown stream error. Previously `task = nil` only ran on normal
            // completion, so a failed/cancelled stream left `isBusy` stuck
            // true and blocked all future `respond` calls permanently.
            defer { self.task = nil }
            var first = true
            for try await item in session.streamResponse(to: message) {
                if first {
                    self.messages[lastIndex].content = item
                    first = false
                } else {
                    self.messages[lastIndex].content += item
                }
            }
        }
    }
}