Skip to content

Commit 3fee63c

Browse files
committed
chat: refactor chat buffer
1 parent def3cc2 commit 3fee63c

File tree

6 files changed

+165
-220
lines changed

6 files changed

+165
-220
lines changed

niinii/src/app.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,20 +308,20 @@ impl App {
308308
ui.dockspace_over_viewport();
309309
};
310310

311-
let nx_state = get_scroll_lock();
311+
let no_inputs = get_scroll_lock();
312312
let mut niinii = ui
313313
.window("niinii")
314314
.opened(run)
315315
.menu_bar(true)
316316
.draw_background(!self.settings().transparent);
317-
if nx_state {
317+
if no_inputs {
318318
niinii = niinii.no_inputs().draw_background(false);
319319
}
320320
niinii.build(|| {
321321
self.show_menu(ctx, ui);
322322
self.show_error_modal(ctx, ui);
323323

324-
if nx_state {
324+
if no_inputs {
325325
stroke_text_with_highlight(
326326
ui,
327327
&ui.get_window_draw_list(),

niinii/src/translator/chat.rs

Lines changed: 47 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
33
use async_trait::async_trait;
44
use enclose::enclose;
55
use openai_chat::{
6-
chat::{self, ChatBuffer, Message, Model},
6+
chat::{self, ChatBuffer, Exchange, Message, Model},
77
ConnectionPolicy,
88
};
99
use tokio::sync::{Mutex, Semaphore};
@@ -53,9 +53,9 @@ impl Translator for ChatTranslator {
5353
let chatgpt = &settings.chat;
5454

5555
let permit = self.semaphore.clone().acquire_owned().await.unwrap();
56-
let chat_request = {
56+
let mut exchange = {
5757
let mut chat = self.chat.lock().await;
58-
chat.begin_exchange(
58+
chat.start_exchange(
5959
Message {
6060
role: chat::Role::System,
6161
content: Some(chatgpt.system_prompt.clone()),
@@ -66,68 +66,59 @@ impl Translator for ChatTranslator {
6666
content: Some(text.clone()),
6767
..Default::default()
6868
},
69-
);
70-
chat.enforce_context_limit(chatgpt.max_context_tokens);
71-
72-
chat::Request {
73-
model: chatgpt.model,
74-
messages: chat.prompt(),
75-
temperature: chatgpt.temperature,
76-
top_p: chatgpt.top_p,
77-
max_tokens: chatgpt.max_tokens,
78-
presence_penalty: chatgpt.presence_penalty,
79-
..Default::default()
80-
}
69+
)
8170
};
8271

83-
let stream = self.client.stream(chat_request).await;
84-
let mut stream = match stream {
85-
Ok(stream) => stream,
86-
Err(err) => {
87-
let mut chat = self.chat.lock().await;
88-
chat.cancel_exchange();
89-
return Err(err.into());
90-
}
72+
let chat_request = chat::Request {
73+
model: chatgpt.model,
74+
messages: exchange.prompt(),
75+
temperature: chatgpt.temperature,
76+
top_p: chatgpt.top_p,
77+
max_tokens: chatgpt.max_tokens,
78+
presence_penalty: chatgpt.presence_penalty,
79+
..Default::default()
9180
};
9281

82+
let exchange = Arc::new(Mutex::new(exchange));
83+
let mut stream = self.client.stream(chat_request).await?;
9384
let token = CancellationToken::new();
9485
let chat = &self.chat;
95-
tokio::spawn(enclose! { (chat, token) async move {
96-
// Hold permit: We are not allowed to begin another translation
97-
// request until this one is complete.
98-
let _permit = permit;
99-
loop {
100-
tokio::select! {
101-
msg = stream.next() => match msg {
102-
Some(Ok(completion)) => {
103-
let mut chat = chat.lock().await;
104-
let message = &completion.choices.first().unwrap().delta;
105-
chat.append_partial_response(message)
86+
tokio::spawn(
87+
enclose! { (chat, token, exchange, chatgpt.max_context_tokens => max_context_tokens) async move {
88+
// Hold permit: We are not allowed to begin another translation
89+
// request until this one is complete.
90+
let _permit = permit;
91+
loop {
92+
tokio::select! {
93+
msg = stream.next() => match msg {
94+
Some(Ok(completion)) => {
95+
let mut exchange = exchange.lock().await;
96+
let message = &completion.choices.first().unwrap().delta;
97+
exchange.append(message)
98+
},
99+
Some(Err(err)) => {
100+
tracing::error!(%err, "stream");
101+
break
102+
},
103+
None => {
104+
let mut chat = chat.lock().await;
105+
let mut exchange = exchange.lock().await;
106+
chat.commit(&mut exchange);
107+
chat.enforce_context_limit(max_context_tokens);
108+
break
109+
}
106110
},
107-
Some(Err(err)) => {
108-
tracing::error!(%err, "stream");
109-
let mut chat = chat.lock().await;
110-
chat.cancel_exchange();
111-
break
112-
},
113-
None => {
114-
let mut chat = chat.lock().await;
115-
chat.end_exchange();
111+
_ = token.cancelled() => {
116112
break
117113
}
118-
},
119-
_ = token.cancelled() => {
120-
let mut chat = chat.lock().await;
121-
chat.cancel_exchange();
122-
break
123114
}
124115
}
125-
}
126-
}.instrument(tracing::Span::current())});
116+
}.instrument(tracing::Span::current())},
117+
);
127118

128-
Ok(Box::new(ChatTranslation::Translated {
119+
Ok(Box::new(ChatTranslation {
129120
model: chatgpt.model,
130-
chat: chat.clone(),
121+
exchange,
131122
_guard: token.drop_guard(),
132123
}))
133124
}
@@ -137,12 +128,10 @@ impl Translator for ChatTranslator {
137128
}
138129
}
139130

140-
pub enum ChatTranslation {
141-
Translated {
142-
model: Model,
143-
chat: Arc<Mutex<ChatBuffer>>,
144-
_guard: DropGuard,
145-
},
131+
pub struct ChatTranslation {
132+
pub model: Model,
133+
pub exchange: Arc<Mutex<Exchange>>,
134+
_guard: DropGuard,
146135
}
147136
impl Translation for ChatTranslation {
148137
fn view(&self) -> Box<dyn View + '_> {

niinii/src/view/translator/chat.rs

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ impl View for ViewChatTranslator<'_> {
2525
// ui.menu("Settings", || {
2626
// });
2727
// ui.separator();
28-
ui.disabled(chat.pending_response(), || {
29-
if ui.menu_item("Clear") {
30-
chat.clear();
31-
}
32-
});
28+
if ui.menu_item("Clear") {
29+
chat.clear();
30+
}
3331
});
3432
if ui.collapsing_header("Tuning", TreeNodeFlags::DEFAULT_OPEN) {
3533
if let Some(_token) = ui.begin_table("##", 2) {
@@ -173,26 +171,6 @@ impl View for ViewChatTranslator<'_> {
173171
_ => {}
174172
}
175173

176-
ui.disabled(true, || {
177-
for message in chat.response_mut() {
178-
let _id = ui.push_id_ptr(message);
179-
ui.table_next_column();
180-
ui.table_next_column();
181-
ui.table_next_column();
182-
ui.table_next_column();
183-
ui.set_next_item_width(ui.current_font_size() * 6.0);
184-
let mut system_role = openai_chat::chat::Role::Assistant;
185-
combo_enum(ui, "##role", &mut system_role);
186-
ui.table_next_column();
187-
ui.disabled(true, || {
188-
if let Some(content) = &mut message.content {
189-
ui.set_next_item_width(ui.content_region_avail()[0]);
190-
ui.input_text("##content", content).build();
191-
}
192-
});
193-
}
194-
});
195-
196174
ui.table_next_column();
197175
ui.table_next_column();
198176
if ui.button_with_size("+", [ui.frame_height(), 0.0]) {
@@ -216,52 +194,51 @@ impl View for ViewChatTranslation<'_> {
216194
let _wrap_token = ui.push_text_wrap_pos_with_pos(0.0);
217195
ui.text(""); // anchor for line wrapping
218196
ui.same_line();
219-
match self.0 {
220-
ChatTranslation::Translated { chat, .. } => {
221-
let chat = chat.blocking_lock();
222-
let draw_list = ui.get_window_draw_list();
223-
stroke_text_with_highlight(
224-
ui,
225-
&draw_list,
226-
"[ChatGPT]",
227-
1.0,
228-
Some(StyleColor::NavHighlight),
229-
);
230-
for content in chat.response().iter().flat_map(|c| c.content.as_ref()) {
231-
ui.same_line();
232-
stroke_text_with_highlight(
233-
ui,
234-
&draw_list,
235-
content,
236-
1.0,
237-
Some(StyleColor::TextSelectedBg),
238-
);
239-
}
240-
if chat.pending_response() {
241-
if chat.response().is_empty() {
242-
ui.same_line();
243-
} else {
244-
ui.same_line_with_spacing(0.0, 0.0);
245-
}
246-
stroke_text_with_highlight(
247-
ui,
248-
&draw_list,
249-
ellipses(ui),
250-
1.0,
251-
Some(StyleColor::TextSelectedBg),
252-
);
253-
}
197+
let ChatTranslation { exchange, .. } = self.0;
198+
let exchange = exchange.blocking_lock();
199+
let draw_list = ui.get_window_draw_list();
200+
stroke_text_with_highlight(
201+
ui,
202+
&draw_list,
203+
"[ChatGPT]",
204+
1.0,
205+
Some(StyleColor::NavHighlight),
206+
);
207+
for content in exchange.response().iter().flat_map(|c| c.content.as_ref()) {
208+
ui.same_line();
209+
stroke_text_with_highlight(
210+
ui,
211+
&draw_list,
212+
content,
213+
1.0,
214+
Some(StyleColor::TextSelectedBg),
215+
);
216+
}
217+
if exchange.usage().is_none() {
218+
if exchange.response().is_none() {
219+
ui.same_line();
220+
} else {
221+
ui.same_line_with_spacing(0.0, 0.0);
254222
}
223+
stroke_text_with_highlight(
224+
ui,
225+
&draw_list,
226+
ellipses(ui),
227+
1.0,
228+
Some(StyleColor::TextSelectedBg),
229+
);
255230
}
256231
}
257232
}
258233

259234
pub struct ViewChatTranslationUsage<'a>(pub &'a ChatTranslation);
260235
impl View for ViewChatTranslationUsage<'_> {
261236
fn ui(&mut self, ui: &imgui::Ui) {
262-
let ChatTranslation::Translated { model, chat, .. } = self.0;
263-
let chat = chat.blocking_lock();
264-
let usage = chat.usage();
237+
let ChatTranslation {
238+
model, exchange, ..
239+
} = self.0;
240+
let exchange = exchange.blocking_lock();
241+
let usage = exchange.usage();
265242
if let Some(usage) = usage {
266243
ui.same_line();
267244
ProgressBar::new(0.0)

0 commit comments

Comments
 (0)