@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
33use async_trait:: async_trait;
44use enclose:: enclose;
55use openai_chat:: {
6- chat:: { self , ChatBuffer , Message , Model } ,
6+ chat:: { self , ChatBuffer , Exchange , Message , Model } ,
77 ConnectionPolicy ,
88} ;
99use tokio:: sync:: { Mutex , Semaphore } ;
@@ -53,9 +53,9 @@ impl Translator for ChatTranslator {
5353 let chatgpt = & settings. chat ;
5454
5555 let permit = self . semaphore . clone ( ) . acquire_owned ( ) . await . unwrap ( ) ;
56- let chat_request = {
56+ let mut exchange = {
5757 let mut chat = self . chat . lock ( ) . await ;
58- chat. begin_exchange (
58+ chat. start_exchange (
5959 Message {
6060 role : chat:: Role :: System ,
6161 content : Some ( chatgpt. system_prompt . clone ( ) ) ,
@@ -66,68 +66,59 @@ impl Translator for ChatTranslator {
6666 content : Some ( text. clone ( ) ) ,
6767 ..Default :: default ( )
6868 } ,
69- ) ;
70- chat. enforce_context_limit ( chatgpt. max_context_tokens ) ;
71-
72- chat:: Request {
73- model : chatgpt. model ,
74- messages : chat. prompt ( ) ,
75- temperature : chatgpt. temperature ,
76- top_p : chatgpt. top_p ,
77- max_tokens : chatgpt. max_tokens ,
78- presence_penalty : chatgpt. presence_penalty ,
79- ..Default :: default ( )
80- }
69+ )
8170 } ;
8271
83- let stream = self . client . stream ( chat_request ) . await ;
84- let mut stream = match stream {
85- Ok ( stream ) => stream ,
86- Err ( err ) => {
87- let mut chat = self . chat . lock ( ) . await ;
88- chat . cancel_exchange ( ) ;
89- return Err ( err . into ( ) ) ;
90- }
72+ let chat_request = chat :: Request {
73+ model : chatgpt . model ,
74+ messages : exchange . prompt ( ) ,
75+ temperature : chatgpt . temperature ,
76+ top_p : chatgpt . top_p ,
77+ max_tokens : chatgpt . max_tokens ,
78+ presence_penalty : chatgpt . presence_penalty ,
79+ .. Default :: default ( )
9180 } ;
9281
82+ let exchange = Arc :: new ( Mutex :: new ( exchange) ) ;
83+ let mut stream = self . client . stream ( chat_request) . await ?;
9384 let token = CancellationToken :: new ( ) ;
9485 let chat = & self . chat ;
95- tokio:: spawn ( enclose ! { ( chat, token) async move {
96- // Hold permit: We are not allowed to begin another translation
97- // request until this one is complete.
98- let _permit = permit;
99- loop {
100- tokio:: select! {
101- msg = stream. next( ) => match msg {
102- Some ( Ok ( completion) ) => {
103- let mut chat = chat. lock( ) . await ;
104- let message = & completion. choices. first( ) . unwrap( ) . delta;
105- chat. append_partial_response( message)
86+ tokio:: spawn (
87+ enclose ! { ( chat, token, exchange, chatgpt. max_context_tokens => max_context_tokens) async move {
88+ // Hold permit: We are not allowed to begin another translation
89+ // request until this one is complete.
90+ let _permit = permit;
91+ loop {
92+ tokio:: select! {
93+ msg = stream. next( ) => match msg {
94+ Some ( Ok ( completion) ) => {
95+ let mut exchange = exchange. lock( ) . await ;
96+ let message = & completion. choices. first( ) . unwrap( ) . delta;
97+ exchange. append( message)
98+ } ,
99+ Some ( Err ( err) ) => {
100+ tracing:: error!( %err, "stream" ) ;
101+ break
102+ } ,
103+ None => {
104+ let mut chat = chat. lock( ) . await ;
105+ let mut exchange = exchange. lock( ) . await ;
106+ chat. commit( & mut exchange) ;
107+ chat. enforce_context_limit( max_context_tokens) ;
108+ break
109+ }
106110 } ,
107- Some ( Err ( err) ) => {
108- tracing:: error!( %err, "stream" ) ;
109- let mut chat = chat. lock( ) . await ;
110- chat. cancel_exchange( ) ;
111- break
112- } ,
113- None => {
114- let mut chat = chat. lock( ) . await ;
115- chat. end_exchange( ) ;
111+ _ = token. cancelled( ) => {
116112 break
117113 }
118- } ,
119- _ = token. cancelled( ) => {
120- let mut chat = chat. lock( ) . await ;
121- chat. cancel_exchange( ) ;
122- break
123114 }
124115 }
125- }
126- } . instrument ( tracing :: Span :: current ( ) ) } ) ;
116+ } . instrument ( tracing :: Span :: current ( ) ) } ,
117+ ) ;
127118
128- Ok ( Box :: new ( ChatTranslation :: Translated {
119+ Ok ( Box :: new ( ChatTranslation {
129120 model : chatgpt. model ,
130- chat : chat . clone ( ) ,
121+ exchange ,
131122 _guard : token. drop_guard ( ) ,
132123 } ) )
133124 }
@@ -137,12 +128,10 @@ impl Translator for ChatTranslator {
137128 }
138129}
139130
140- pub enum ChatTranslation {
141- Translated {
142- model : Model ,
143- chat : Arc < Mutex < ChatBuffer > > ,
144- _guard : DropGuard ,
145- } ,
131+ pub struct ChatTranslation {
132+ pub model : Model ,
133+ pub exchange : Arc < Mutex < Exchange > > ,
134+ _guard : DropGuard ,
146135}
147136impl Translation for ChatTranslation {
148137 fn view ( & self ) -> Box < dyn View + ' _ > {
0 commit comments