-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathgpt.py
More file actions
322 lines (272 loc) · 15 KB
/
Copy pathgpt.py
File metadata and controls
322 lines (272 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import asyncio
import json
import os
import re
from collections import defaultdict, deque
from datetime import datetime
from typing import AsyncGenerator, Deque, Dict, Generator, Type

from maubot import MessageEvent, Plugin
from maubot.handlers import command, event
from mautrix.client import Client
from mautrix.errors import MatrixRequestError, MNotFound, MUnknown
from mautrix.types import (
    EncryptedEvent,
    EventType,
    Format,
    MessageType,
    RelationType,
    RoomID,
    TextMessageEventContent,
    UserID,
)
from mautrix.util import markdown
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
class Config(BaseProxyConfig):
    """Plugin configuration: names the settings carried over on a config update."""

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        """Copy every known setting through ``helper``, one key at a time."""
        for setting in (
            "api_endpoint",
            "gpt_api_key",
            "model",
            "max_tokens",
            "enable_multi_user",
            "system_prompt",
            "name",
            "allowed_users",
            "addl_context",
            "max_words",
            "max_context_messages",
            "reply_in_thread",
            "allow_summarize",
            "allow_responses",
            "temperature",
            "respond_to_replies",
        ):
            helper.copy(setting)
class GPTPlugin(Plugin):
    """Maubot plugin that answers Matrix room messages via a chat-completion HTTP API.

    The plugin decides per message whether to respond (mention of the bot's
    name, explicit ``m.mentions``, two-member rooms, threads and replies),
    builds a message context from the room or thread history, posts it to the
    configured ``api_endpoint`` and sends the returned text back as a notice.
    """

    name: str  # name of the bot

    # Text appended to the system prompt when `enable_multi_user` is set.
    _MULTI_USER_PROMPT = """
User messages are in the context of multiperson chatrooms.
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
"username: hello world."
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
"""

    async def start(self) -> None:
        """Load the config and resolve the bot name and API endpoint."""
        await super().start()
        self.config.load_and_update()
        # Fall back from the configured name to the display name, then to the
        # first component returned by parse_user_id for the bot's own mxid.
        self.name = self.config['name'] or \
            await self.client.get_displayname(self.client.mxid) or \
            self.client.parse_user_id(self.client.mxid)[0]
        self.api_endpoint = self.config['api_endpoint']
        self.log.debug(f"DEBUG gpt plugin started with bot name: {self.name}")
        self.log.debug(f"DEBUG gpt endpoint set: {self.api_endpoint}")

    def user_allowed(self, mxid) -> bool:
        """Return True if ``mxid`` matches any regex in the ``allowed_users`` config list.

        Each entry is applied with ``re.match`` (anchored at the start of the
        mxid). Returns False when nothing matches.
        """
        for u in self.config['allowed_users']:
            self.log.debug(f"DEBUG {mxid} vs. {u}")
            if re.match(u, mxid):
                return True
            self.log.debug(f"DEBUG {mxid} doesn't match {u}")
        # Explicit False instead of falling off the end and returning None.
        return False

    async def _sender_permitted(self, evt: MessageEvent) -> bool:
        """Return True if the sender may use the bot; otherwise tell them and return False.

        An empty ``allowed_users`` list means everybody is permitted.
        """
        if self.config['allowed_users'] and not self.user_allowed(evt.sender):
            await evt.respond("sorry, you're not allowed to use this functionality.")
            return False
        return True

    async def should_respond(self, event: MessageEvent) -> bool:
        """ Determine if we should respond to an event """
        relates_to = event.content.relates_to
        if (
            not self.config['allow_responses'] or  # Ignore if allow_responses is false
            event.sender == self.client.mxid or  # Ignore ourselves
            # Don't respond to media or notices. Checked before touching the
            # body so only text messages reach the body test.
            event.content['msgtype'] != MessageType.TEXT or
            event.content.body.startswith('!') or  # Ignore commands
            (relates_to and hasattr(relates_to, 'rel_type')
             and relates_to.rel_type == RelationType.REPLACE)  # Ignore edits
        ):
            return False

        # Check if the message contains the bot's name. The name is escaped so
        # that regex metacharacters in it are matched literally.
        name_pattern = r"(^|[\s>])(@)?" + re.escape(self.name) + r"([ :,.!?]|$)"
        if re.search(name_pattern, event.content.body, re.IGNORECASE):
            return await self._sender_permitted(event)

        # check if there is an intentional mention that missed the above regex somehow
        # According to mautrix, mentions are held in event.content.get("m.mentions", {}).get("user_ids", [])
        if self.client.mxid in (event.content.get("m.mentions", {}).get("user_ids", [])):
            return await self._sender_permitted(event)

        # Reply to all DMs (rooms with exactly two joined members) as long as the person is allowed
        if len(await self.client.get_joined_members(event.room_id)) == 2:
            return await self._sender_permitted(event)

        # Reply to threads if the thread's parent should be replied to
        if self.config['reply_in_thread'] and relates_to and relates_to.rel_type == RelationType.THREAD:
            try:
                parent_event = await self.client.get_event(room_id=event.room_id,
                                                           event_id=event.content.get_thread_parent())
                if parent_event:
                    return await self.should_respond(parent_event)
            except (MNotFound, MatrixRequestError, AttributeError, TypeError) as e:
                # Parent message was deleted or inaccessible, don't respond
                self.log.debug(f"Could not retrieve thread parent message: {e}")
                return False

        # Reply to messages replying to the bot by checking if the parent message has the
        # `org.jobmachine.chatgpt` key.
        # NOTE(review): the `respond_to_replies` config key is copied by Config but is not
        # consulted here - confirm whether this branch was meant to be gated on it.
        if relates_to and relates_to.in_reply_to:
            try:
                parent_event = await self.client.get_event(room_id=event.room_id,
                                                           event_id=event.content.get_reply_to())
                if (parent_event and parent_event.sender == self.client.mxid
                        and "org.jobmachine.chatgpt" in parent_event.content):
                    return True
            except (MNotFound, MatrixRequestError, AttributeError, TypeError) as e:
                # Parent message was deleted or inaccessible, don't respond
                self.log.debug(f"Could not retrieve reply parent message: {e}")
                return False

        return False

    @event.on(EventType.ROOM_MESSAGE)
    async def on_message(self, event: MessageEvent) -> None:
        """Answer a room message with the API response when should_respond says so.

        Any exception is logged and reported back to the room.
        """
        if not await self.should_respond(event):
            return

        try:
            context = await self.get_context(event)
            await event.mark_read()

            # Call the chatGPT API to get a response
            await self.client.set_typing(event.room_id, timeout=99999)
            try:
                response = await self._call_gpt(context)
            finally:
                # Always clear the typing indicator, even when the API call fails;
                # otherwise it stays on for the whole 99999 timeout.
                await self.client.set_typing(event.room_id, timeout=0)

            # Send the response back to the chat room
            content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=response, format=Format.HTML,
                                              formatted_body=markdown.render(response))
            # Marker that should_respond looks for when somebody replies to this message.
            content["org.jobmachine.chatgpt"] = True
            await event.respond(content, in_thread=self.config['reply_in_thread'])
        except Exception as e:
            self.log.exception(f"Something went wrong: {e}")
            await event.respond(f"Something went wrong: {e}")

    async def _call_gpt(self, prompt) -> str:
        """POST ``prompt`` (an iterable of role/content dicts) to the API and return the reply text.

        Returns ``"Error: <body>"`` when the HTTP status is not 200.
        """
        full_context = list(prompt)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config['gpt_api_key']}"
        }

        data = {
            "model": self.config['model'],
            "messages": full_context,
        }

        if 'max_tokens' in self.config and self.config['max_tokens']:
            data["max_tokens"] = self.config['max_tokens']

        if 'temperature' in self.config:
            temperature = self.config['temperature']
            # A temperature of 0 is falsy, so test for "unset" explicitly
            # instead of relying on truthiness.
            if temperature is not None and temperature != "":
                data["temperature"] = temperature

        self.log.debug("CONTEXT:\n" + "\n".join([f'{m["role"]}: {m["content"]}' for m in full_context]))

        async with self.http.post(
            self.api_endpoint, headers=headers, data=json.dumps(data)
        ) as response:
            if response.status != 200:
                return f"Error: {await response.text()}"
            response_json = await response.json()
            content = response_json["choices"][0]["message"]["content"]
            # "usage" is only logged, so a response without it must not fail the call.
            self.log.debug(f'GPT tokens used: {response_json.get("usage")}')
            # strip off extra colons which the model seems to keep adding no matter how
            # much you tell it not to
            content = re.sub(r'^\w*\:+\s+', '', content)
            return str(content)

    @command.new(name='gpt', help='control chatGPT functionality', require_subcommand=True)
    async def gpt(self, evt: MessageEvent) -> None:
        """Root `gpt` command; does nothing on its own (a subcommand is required)."""
        pass

    @command.new(name='summarize', help='generate a summary of room or thread messages')
    async def summarize(self, evt: MessageEvent) -> None:
        """Reply with an API-generated summary of the room's or thread's messages."""
        await evt.mark_read()

        # check if the user has permission to use this bot,
        # and additionally if allow_summarize is true
        if not await self._sender_permitted(evt):
            return
        if not self.config["allow_summarize"]:
            await evt.respond("sorry, this functionality is not enabled.")
            return

        # normally we only are restricted to getting context from the thread,
        # but when explicitly asked to summarize the room messages, we can use the room context
        # and ignore the addl_context messages when building our context
        is_thread = (evt.content.relates_to and
                     evt.content.relates_to.rel_type == RelationType.THREAD)
        context = await self.get_context(evt, is_summary=True, is_thread=is_thread)
        context.append({
            "role": "user",
            "content": "Summarize the following messages concisely into four or five bullet points and key quotes: " + "\n".join([m["content"] for m in context])
        })
        response = await self._call_gpt(context)
        await evt.respond(response)

    async def get_context(self, event: MessageEvent, is_summary: bool = False,
                          is_thread: bool = False) -> Deque[Dict[str, str]]:
        """Build the message list sent to the API for ``event``.

        The result is a deque holding the system prompt, then (for non-summary
        requests) the configured ``addl_context`` entries, then the collected
        chat messages in the reverse of the order the generator yields them.
        Collection stops once ``max_words`` or ``max_context_messages`` is reached.

        Raises:
            ValueError: if ``addl_context`` has more than
                ``max_context_messages - 1`` entries.
        """
        system_context = deque()
        timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
        system_prompt = {"role": "system",
                         "content": self.config['system_prompt'].format(name=self.name, timestamp=timestamp)}
        if self.config['enable_multi_user']:
            system_prompt["content"] += self._MULTI_USER_PROMPT
        system_context.append(system_prompt)

        # Only a whole-room summary ignores thread structure; summaries of
        # either kind skip the addl_context messages.
        ignore_thread = is_summary and not is_thread
        if not is_summary:
            # JSON round-trip, as in the original code, before iterating the config value.
            addl_context = json.loads(json.dumps(self.config['addl_context']))
            if addl_context:
                if len(addl_context) > self.config["max_context_messages"] - 1:
                    raise ValueError(f"sorry, my configuration has too many additional prompts "
                                     f"({self.config['max_context_messages']}) and i'll never see your message. "
                                     f"Update my config to have fewer messages and i'll be able to answer your questions!")
                system_context.extend(addl_context)

        chat_context = deque()
        word_count = sum(len(m["content"].split()) for m in system_context)
        message_count = len(system_context) - 1
        async for next_event in self.generate_context_messages(event, ignore_thread):
            # Ignore events that aren't text messages
            try:
                if not next_event.content.msgtype.is_text:
                    continue
            except (KeyError, AttributeError):
                continue

            role = 'assistant' if next_event.sender == self.client.mxid else 'user'
            message = next_event['content']['body']
            user = ''
            if self.config['enable_multi_user']:
                try:
                    user = (await self.client.get_displayname(next_event.sender) or
                            self.client.parse_user_id(next_event.sender)[0]) + ": "
                except (MUnknown, MatrixRequestError) as e:
                    # Profile fetch failed (user left, privacy settings, etc.), fall back to user ID
                    self.log.debug(f"Could not fetch displayname for {next_event.sender}: {e}")
                    user = self.client.parse_user_id(next_event.sender)[0] + ": "

            word_count += len(message.split())
            message_count += 1
            if word_count >= self.config['max_words'] or message_count >= self.config['max_context_messages']:
                break

            # Each new message is pushed on the left, so the first one yielded ends up last.
            chat_context.appendleft({"role": role, "content": user + message})

        return system_context + chat_context

    async def generate_context_messages(self, evt: MessageEvent,
                                        ignore_thread: bool = False) -> AsyncGenerator[MessageEvent, None]:
        """Yield ``evt`` and then earlier events to use as context.

        With ``reply_in_thread`` enabled (and ``ignore_thread`` false) the
        reply chain is followed upward from ``evt``; otherwise the events
        before ``evt`` returned by ``get_event_context`` are yielded.
        """
        yield evt
        if self.config['reply_in_thread'] and not ignore_thread:
            while True:
                try:
                    # The relation lookup sits inside the try so that an event
                    # without the expected content shape ends the walk instead
                    # of escaping the handler below.
                    relates_to = evt.content.relates_to
                    if not relates_to or not relates_to.in_reply_to:
                        break
                    reply_to_id = evt.content.get_reply_to()
                    if not reply_to_id:
                        break
                    next_evt = await self.client.get_event(room_id=evt.room_id, event_id=reply_to_id)
                except (MNotFound, MatrixRequestError, AttributeError, TypeError) as e:
                    # Message was deleted or inaccessible, stop following the chain
                    self.log.debug(f"Could not retrieve parent message in thread: {e}")
                    break
                if not next_evt:
                    # Message was deleted or doesn't exist, stop following the chain
                    break
                evt = next_evt
                yield evt
        else:
            event_context = await self.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id,
                                                                limit=self.config["max_context_messages"] * 2)
            for evt in event_context.events_before:
                # We already have the event, but currently, get_event_context doesn't automatically decrypt events
                if isinstance(evt, EncryptedEvent) and self.client.crypto:
                    try:
                        decrypted_evt = await self.client.get_event(event_id=evt.event_id, room_id=evt.room_id)
                    except (MNotFound, MatrixRequestError, AttributeError, TypeError) as e:
                        # Skip if decryption failed or message doesn't exist
                        self.log.debug(f"Could not decrypt event: {e}")
                        continue
                    if not decrypted_evt:
                        # Skip if decryption failed or message doesn't exist
                        continue
                    evt = decrypted_evt
                yield evt

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        """Return the config class maubot should use for this plugin."""
        return Config