|
1 | | -import json |
2 | 1 | from typing import Any |
3 | 2 |
|
4 | 3 | import httpx |
@@ -126,136 +125,142 @@ def set_stop_sequences(self, sequences: str): |
126 | 125 | """Set the stop sequences.""" |
127 | 126 | self.stop_sequences = sequences |
128 | 127 |
|
129 | | - @rx.event(background=True) |
| 128 | + @rx.event |
130 | 129 | async def process_question(self, form_data: dict[str, Any]): |
| 130 | + # Get the question from the form |
131 | 131 | question = form_data["question"] |
| 132 | + |
| 133 | + # Check if the question is empty |
132 | 134 | if not question: |
133 | 135 | return |
134 | 136 |
|
135 | | - # Check auth, set up initial state, and snapshot all needed values |
136 | | - async with self: |
137 | | - if not self.is_authenticated or not self.api_key: |
138 | | - return |
139 | | - |
140 | | - qa = QA(question=question, answer="") |
141 | | - self._messages.append(qa) |
142 | | - self.processing = True |
143 | | - |
144 | | - url = self.opengatellm_url |
145 | | - api_key = self.api_key |
146 | | - model = self.model |
147 | | - temperature = self.temperature |
148 | | - top_p = self.top_p |
149 | | - max_completion_tokens = self.max_completion_tokens |
150 | | - frequency_penalty = self.frequency_penalty |
151 | | - presence_penalty = self.presence_penalty |
152 | | - stream = self.stream |
153 | | - seed_str = self.seed_str |
154 | | - stop_sequences_str = self.stop_sequences |
155 | | - messages_snapshot = list(self._messages) |
156 | | - yield |
| 137 | + async for value in self.api_process_question(question): |
| 138 | + yield value |
| 139 | + |
| 140 | + @rx.event |
| 141 | + async def api_process_question(self, question: str): |
| 142 | + """Get the response from the API. |
| 143 | +
|
| 144 | + Args: |
| 145 | + question: The user's question. |
| 146 | + """ |
| 147 | + |
| 148 | + # Check if authenticated |
| 149 | + if not self.is_authenticated or not self.api_key: |
| 150 | + return |
157 | 151 |
|
158 | | - # Build messages outside the lock |
| 152 | + # Add the question to the list of questions. |
| 153 | + qa = QA(question=question, answer="") |
| 154 | + self._messages.append(qa) |
| 155 | + |
| 156 | + # Clear the input and start the processing. |
| 157 | + self.processing = True |
| 158 | + yield |
| 159 | + |
| 160 | + # Build the messages. |
159 | 161 | messages = [] |
160 | | - for qa in messages_snapshot: |
| 162 | + for qa in self._messages: |
161 | 163 | messages.append({"role": "user", "content": qa["question"]}) |
162 | 164 | if qa["answer"]: |
163 | 165 | messages.append({"role": "assistant", "content": qa["answer"]}) |
164 | 166 |
|
| 167 | + # Remove the last empty answer. |
165 | 168 | if messages and messages[-1]["role"] == "assistant" and not messages[-1]["content"]: |
166 | 169 | messages = messages[:-1] |
167 | 170 |
|
| 171 | + # Prepare the request payload |
168 | 172 | payload = { |
169 | | - "model": model, |
| 173 | + "model": self.model, |
170 | 174 | "messages": messages, |
171 | | - "temperature": temperature, |
172 | | - "top_p": top_p, |
173 | | - "max_completion_tokens": max_completion_tokens, |
174 | | - "frequency_penalty": frequency_penalty, |
175 | | - "presence_penalty": presence_penalty, |
176 | | - "stream": stream, |
| 175 | + "temperature": self.temperature, |
| 176 | + "top_p": self.top_p, |
| 177 | + "max_completion_tokens": self.max_completion_tokens, |
| 178 | + "frequency_penalty": self.frequency_penalty, |
| 179 | + "presence_penalty": self.presence_penalty, |
| 180 | + "stream": self.stream, |
177 | 181 | } |
178 | 182 |
|
179 | | - if seed_str: |
| 183 | + # Add optional parameters |
| 184 | + if self.seed_str: |
180 | 185 | try: |
181 | | - payload["seed"] = int(seed_str) |
| 186 | + payload["seed"] = int(self.seed_str) |
182 | 187 | except ValueError: |
183 | 188 | pass |
184 | 189 |
|
185 | | - if stop_sequences_str: |
186 | | - stop_list = [s.strip() for s in stop_sequences_str.split("\n") if s.strip()] |
| 190 | + if self.stop_sequences: |
| 191 | + stop_list = [s.strip() for s in self.stop_sequences.split("\n") if s.strip()] |
187 | 192 | if stop_list: |
188 | 193 | payload["stop"] = stop_list |
189 | 194 |
|
190 | 195 | try: |
191 | | - if stream: |
| 196 | + if self.stream: |
| 197 | + # Streaming response |
192 | 198 | async with httpx.AsyncClient() as client: |
193 | 199 | async with client.stream( |
194 | 200 | "POST", |
195 | | - f"{url}/v1/chat/completions", |
| 201 | + f"{self.opengatellm_url}/v1/chat/completions", |
196 | 202 | headers={ |
197 | | - "Authorization": f"Bearer {api_key}", |
| 203 | + "Authorization": f"Bearer {self.api_key}", |
198 | 204 | "Content-Type": "application/json", |
199 | 205 | }, |
200 | 206 | json=payload, |
201 | 207 | timeout=configuration.settings.playground_opengatellm_timeout, |
202 | 208 | ) as response: |
203 | 209 | if response.status_code != 200: |
204 | 210 | error_text = await response.aread() |
205 | | - async with self: |
206 | | - self._messages[-1]["answer"] = f"Error: {error_text.decode()}" |
207 | | - self.processing = False |
208 | | - yield |
| 211 | + self._messages[-1]["answer"] = f"Error: {error_text.decode()}" |
| 212 | + self.processing = False |
| 213 | + yield |
209 | 214 | return |
210 | 215 |
|
211 | 216 | async for line in response.aiter_lines(): |
212 | 217 | if line.startswith("data: "): |
213 | 218 | data = line[6:] |
214 | 219 | if data == "[DONE]": |
215 | 220 | break |
| 221 | + |
216 | 222 | try: |
| 223 | + import json |
| 224 | + |
217 | 225 | chunk = json.loads(data) |
218 | 226 | if chunk.get("choices") and len(chunk["choices"]) > 0: |
219 | 227 | delta = chunk["choices"][0].get("delta", {}) |
220 | 228 | content = delta.get("content") |
221 | 229 | if content: |
222 | | - async with self: |
223 | | - self._messages[-1]["answer"] += content |
224 | | - self._messages = self._messages |
225 | | - yield |
| 230 | + self._messages[-1]["answer"] += content |
| 231 | + self._messages = self._messages |
| 232 | + yield |
226 | 233 | except Exception: |
227 | 234 | continue |
228 | 235 | else: |
| 236 | + # Non-streaming response |
229 | 237 | async with httpx.AsyncClient() as client: |
230 | 238 | response = await client.post( |
231 | | - f"{url}/v1/chat/completions", |
| 239 | + f"{self.opengatellm_url}/v1/chat/completions", |
232 | 240 | headers={ |
233 | | - "Authorization": f"Bearer {api_key}", |
| 241 | + "Authorization": f"Bearer {self.api_key}", |
234 | 242 | "Content-Type": "application/json", |
235 | 243 | }, |
236 | 244 | json=payload, |
237 | 245 | timeout=configuration.settings.playground_opengatellm_timeout, |
238 | 246 | ) |
239 | 247 |
|
240 | | - async with self: |
241 | 248 | if response.status_code != 200: |
242 | 249 | self._messages[-1]["answer"] = f"Error: {response.text}" |
243 | 250 | else: |
244 | 251 | data = response.json() |
245 | 252 | if data.get("choices") and len(data["choices"]) > 0: |
246 | 253 | content = data["choices"][0]["message"]["content"] |
247 | 254 | self._messages[-1]["answer"] = content |
| 255 | + |
248 | 256 | yield |
249 | 257 |
|
250 | 258 | except httpx.TimeoutException: |
251 | | - async with self: |
252 | | - self._messages[-1]["answer"] = "Error: Request timeout" |
253 | | - yield |
| 259 | + self._messages[-1]["answer"] = "Error: Request timeout" |
| 260 | + yield |
254 | 261 | except Exception as e: |
255 | | - async with self: |
256 | | - self._messages[-1]["answer"] = f"Error: {str(e)}" |
257 | | - yield |
258 | | - finally: |
259 | | - async with self: |
260 | | - self.processing = False |
261 | | - yield |
| 262 | + self._messages[-1]["answer"] = f"Error: {str(e)}" |
| 263 | + yield |
| 264 | + |
| 265 | + # Toggle the processing flag. |
| 266 | + self.processing = False |
0 commit comments