get_embeddings.py
from openai import OpenAI
import json
import tiktoken

client = OpenAI()


# Count the number of tokens in a string using tiktoken
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(string))
# Truncate text to a specified max token count using tiktoken
def truncate_text_to_max_tokens(text: str, max_tokens: int = 8000, encoding_name: str = "cl100k_base") -> str:
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    if len(tokens) > max_tokens:
        # Keep only the first max_tokens tokens and decode back to text
        return encoding.decode(tokens[:max_tokens])
    return text
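
# Example (a sketch, not from the original script): a string longer than the cap
# comes back truncated to at most 8,000 tokens, safely under the 8,191-token
# input limit of text-embedding-3-small.
# long_text = "word " * 20000
# assert num_tokens_from_string(truncate_text_to_max_tokens(long_text)) <= 8000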
# Load conversation data from a JSON file
def load_data(file_path: str):
    with open(file_path, 'r') as file:
        return json.load(file)
# Generate an embedding for a single text using the OpenAI API
def generate_embedding(text: str):
    truncated_text = truncate_text_to_max_tokens(text)
    # The embeddings endpoint rejects empty strings, so substitute a single space
    if not truncated_text:
        truncated_text = " "
    return client.embeddings.create(input=[truncated_text], model="text-embedding-3-small").data[0].embedding
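
# A batched variant (a sketch, not part of the original script): the embeddings
# endpoint also accepts a list of inputs, which cuts round trips when embedding
# many messages at once. Assumes `texts` holds non-empty, already-truncated strings.
def generate_embeddings_batch(texts):
    response = client.embeddings.create(input=texts, model="text-embedding-3-small")
    # The API returns one embedding per input, in request order
    return [item.embedding for item in response.data]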
# Process conversations and generate an embedding for each message
def process_conversations(conversations):
    processed_data = []
    for i, conv in enumerate(conversations):
        print(f"Processing conversation {i}")
        message_embeddings = []
        for message in conv['conversations']:
            embedding = generate_embedding(message['value'])
            message_embeddings.append({"embedding": embedding})
        processed_data.append({"message_embeddings": message_embeddings})
    return processed_data
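
# A simple retry helper (a sketch, not part of the original script): embedding a
# few hundred messages serially can hit transient API errors or rate limits, so
# the per-message call in process_conversations could be wrapped like this.
import time

def generate_embedding_with_retry(text, retries=3, backoff=2.0):
    for attempt in range(retries):
        try:
            return generate_embedding(text)
        except Exception:
            if attempt == retries - 1:
                raise
            # Exponential backoff between attempts: 2s, 4s, 8s, ...
            time.sleep(backoff * (2 ** attempt))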
def gen_embedding(conversation_data, output_file_path):
    # Process the loaded conversation data and write the embeddings to disk
    processed_conversations = process_conversations(conversation_data)
    with open(output_file_path, 'w') as f:
        json.dump(processed_conversations, f, indent=2)
    print(f"Processed data written to {output_file_path}")
if __name__ == "__main__":
    # Load the conversation data from the JSON file
    file_path = 'data/chats_data2023-09-27.json'
    conversation_data = load_data(file_path)
    # Process the conversations in batches of 50, writing each batch to its own file
    ranges = [
        [0, 50],
        [50, 100],
        [100, 150],
        [150, 200],
    ]
    for r in ranges:
        sample_conversations = conversation_data[r[0]:r[1]]
        # Write the processed data to a JSON file
        output_file_path = f'processed_conversations-{r[0]}-{r[1]}.json'
        gen_embedding(sample_conversations, output_file_path)