-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
Copy pathhistory.py
141 lines (108 loc) · 4.4 KB
/
history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
from aider import models, prompts
from aider.dump import dump # noqa: F401
class ChatSummary:
    """Compress a chat transcript down to a token budget.

    Recent messages are kept verbatim; the older "head" of the conversation
    is sent to a model to be summarized.  If the combined result is still too
    large, summarization recurses (bounded by a depth limit).
    """

    def __init__(self, models=None, max_tokens=1024):
        """Create a summarizer.

        :param models: a single model or a list of models to try in order;
            at least one is required.
        :param max_tokens: target token ceiling for the summarized history.
        :raises ValueError: if no model is provided.
        """
        if not models:
            raise ValueError("At least one model must be provided")
        self.models = models if isinstance(models, list) else [models]
        self.max_tokens = max_tokens
        # Token counting is delegated to the first (preferred) model.
        self.token_count = self.models[0].token_count

    def too_big(self, messages):
        """Return True if *messages* exceed the configured token budget."""
        sized = self.tokenize(messages)
        total = sum(tokens for tokens, _msg in sized)
        return total > self.max_tokens

    def tokenize(self, messages):
        """Return a list of ``(token_count, message)`` pairs for *messages*."""
        return [(self.token_count(msg), msg) for msg in messages]

    def summarize(self, messages, depth=0):
        """Summarize *messages*, ensuring the result ends with an assistant turn.

        BUG FIX: previously the ``depth`` argument was accepted but not
        forwarded to ``summarize_real``, silently resetting recursion depth
        to 0; it is now passed through.  Default-argument callers see
        identical behavior.
        """
        messages = self.summarize_real(messages, depth)
        # Downstream code expects conversations to end on an assistant turn.
        if messages and messages[-1]["role"] != "assistant":
            messages.append(dict(role="assistant", content="Ok."))
        return messages

    def summarize_real(self, messages, depth=0):
        """Do the actual head-summarization / tail-preservation split.

        Returns the messages unchanged when they already fit the budget
        (only at depth 0); otherwise summarizes the head and recurses on
        summary + tail until it fits or the depth limit forces a full
        summarization.
        """
        if not self.models:
            raise ValueError("No models available for summarization")

        sized = self.tokenize(messages)
        total = sum(tokens for tokens, _msg in sized)
        if total <= self.max_tokens and depth == 0:
            return messages

        min_split = 4
        if len(messages) <= min_split or depth > 3:
            return self.summarize_all(messages)

        tail_tokens = 0
        split_index = len(messages)
        half_max_tokens = self.max_tokens // 2

        # Iterate over the messages in reverse order, growing the verbatim
        # tail until it would exceed half the budget.
        for i in range(len(sized) - 1, -1, -1):
            tokens, _msg = sized[i]
            if tail_tokens + tokens < half_max_tokens:
                tail_tokens += tokens
                split_index = i
            else:
                break

        # Ensure the head ends with an assistant message
        while messages[split_index - 1]["role"] != "assistant" and split_index > 1:
            split_index -= 1

        if split_index <= min_split:
            return self.summarize_all(messages)

        tail = messages[split_index:]
        sized_head = sized[:split_index]
        sized_head.reverse()
        keep = []
        total = 0

        # These sometimes come set with value = None
        model_max_input_tokens = self.models[0].info.get("max_input_tokens") or 4096
        model_max_input_tokens -= 512  # headroom for the summarize prompt

        # Keep only as much of the (most recent part of the) head as fits
        # into the model's input window.
        for tokens, msg in sized_head:
            total += tokens
            if total > model_max_input_tokens:
                break
            keep.append(msg)
        keep.reverse()

        summary = self.summarize_all(keep)

        tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
        summary_tokens = self.token_count(summary)

        result = summary + tail
        if summary_tokens + tail_tokens < self.max_tokens:
            return result

        return self.summarize_real(result, depth + 1)

    def summarize_all(self, messages):
        """Summarize *messages* into a single user message via the models.

        Tries each configured model in order; returns on the first success.

        :raises ValueError: if every model fails to produce a summary.
        """
        content = ""
        for msg in messages:
            role = msg["role"].upper()
            if role not in ("USER", "ASSISTANT"):
                continue
            content += f"# {role}\n"
            content += msg["content"]
            if not content.endswith("\n"):
                content += "\n"

        summarize_messages = [
            dict(role="system", content=prompts.summarize),
            dict(role="user", content=content),
        ]

        for model in self.models:
            try:
                summary = model.simple_send_with_retries(summarize_messages)
                if summary is not None:
                    summary = prompts.summary_prefix + summary
                    return [dict(role="user", content=summary)]
            except Exception as e:
                print(f"Summarization failed for model {model.name}: {str(e)}")

        raise ValueError("summarizer unexpectedly failed for all models")
def main():
    """CLI entry point: read a markdown chat transcript file and dump its summary.

    Usage: python history.py FILENAME
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="Markdown file to parse")
    args = parser.parse_args()

    model_names = ["gpt-3.5-turbo", "gpt-4"]  # Add more model names as needed
    model_list = [models.Model(name) for name in model_names]
    summarizer = ChatSummary(model_list)

    with open(args.filename, "r") as f:
        text = f.read()

    # BUG FIX: ChatSummary has no `summarize_chat_history_markdown` method, so
    # the original call raised AttributeError.  Feed the raw transcript as a
    # single user message through the real `summarize` entry point instead.
    summary = summarizer.summarize([dict(role="user", content=text)])
    dump(summary)


if __name__ == "__main__":
    main()