-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenai_model.py
More file actions
398 lines (342 loc) · 13.9 KB
/
genai_model.py
File metadata and controls
398 lines (342 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import os
import tempfile
import asyncio
import subprocess
import glob
import aiohttp
import tiktoken
import logging
from dotenv import load_dotenv
from openai import AsyncOpenAI
from .prompts import system_prompt_json, system_prompt_user_content, system_prompt_org_content
from .models import SoftwareSourceCode, GitHubOrganization, GitHubUser
from ..utils.utils import *
from .verification import Verification
# Load variables from a .env file into os.environ; this must run before the
# settings below are read from the environment.
load_dotenv()
# Required settings: os.environ[...] raises KeyError at import time when any
# of these variables is missing.
OPENROUTER_API_KEY = os.environ["OPENROUTER_API_KEY"]
OPENROUTER_ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"
# Model identifier forwarded to whichever provider PROVIDER selects.
MODEL = os.environ["MODEL"]
# Backend selector; the request functions in this module handle "openrouter"
# and "openai" and log an error for anything else.
PROVIDER = os.environ["PROVIDER"]
# Create async OpenAI client (used by get_openai_response_async).
# NOTE(review): OPENAI_API_KEY is read with .get(), so None may be passed here
# even when PROVIDER is "openrouter" - confirm the SDK's behavior in that case.
async_openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Setup logger
logger = logging.getLogger(__name__)
def reduce_input_size(input_text, max_tokens=800000):
    """
    Reduce the size of the input text to fit within the specified token limit.

    Tokens are counted with tiktoken's ``cl100k_base`` encoding. Text that is
    already within the limit is returned unchanged; otherwise only the first
    ``max_tokens`` tokens are kept and decoded back into a string.

    Args:
        input_text: Text to measure and possibly truncate.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        ``input_text`` itself when it fits, otherwise the truncated text.
    """
    limiter_encoding = tiktoken.get_encoding("cl100k_base")
    # disallowed_special=() makes tiktoken treat strings such as
    # "<|endoftext|>" as ordinary text. With the default ("all"), encode()
    # raises ValueError whenever the repository content contains one of them.
    tokens = limiter_encoding.encode(input_text, disallowed_special=())
    logger.info(f"Original amount of tokens: {len(tokens)}")
    if len(tokens) <= max_tokens:
        return input_text
    # Cutting at a token boundary may split a multi-byte character; decode()
    # substitutes a replacement character for such partial bytes.
    reduced_text = limiter_encoding.decode(tokens[:max_tokens])
    logger.warning(f"Token count exceeded limit, truncated to {max_tokens} tokens")
    return reduced_text
def sort_files_by_priority(file_paths):
    """
    Sort file paths so that the most informative files come first.

    Files are ranked by extension, compared case-insensitively:

    0. Documentation and citation metadata (.cff, .md, .txt, .html)
    1. Code files (.py, .r)
    2. All other files (including files without an extension)

    Paths with the same priority are ordered by their full path string, so
    the result is deterministic.

    Args:
        file_paths: Iterable of file path strings.

    Returns:
        A new list containing the paths in priority order.
    """
    priority_order = {
        # Priority 0: documentation / citation metadata
        ".cff": 0,
        ".md": 0,
        ".txt": 0,
        ".html": 0,
        # Priority 1: code
        ".py": 1,
        ".r": 1,
    }
    default_priority = 2  # any extension not listed above

    def get_sort_key(filepath):
        _, ext = os.path.splitext(filepath)
        # (priority, path): the path breaks ties between equal priorities.
        return (priority_order.get(ext.lower(), default_priority), filepath)

    return sorted(file_paths, key=get_sort_key)
def combine_text_files(directory):
    """
    Combine all ``*.txt`` files directly inside ``directory`` into one string.

    Only the top level of ``directory`` is searched (no recursion). Files are
    read in sorted path order, so the result does not depend on the
    filesystem's listing order. Each file's content is followed by a newline.

    Args:
        directory: Directory to scan.

    Returns:
        The concatenated text, or ``""`` when no ``.txt`` file is found.
    """
    txt_files = sorted(glob.glob(os.path.join(directory, "*.txt")))
    logger.info(f"Found {len(txt_files)} text files in {directory}")
    parts = []
    for file in txt_files:
        logger.debug(f"Reading file: {file}")
        # errors="replace": a stray non-UTF-8 .txt file from the cloned
        # repository must not abort the whole run with UnicodeDecodeError.
        with open(file, "r", encoding="utf-8", errors="replace") as f:
            parts.append(f.read() + "\n")
    # join() avoids re-copying an ever-growing string on every +=.
    return "".join(parts)
def store_combined_text(input_text, output_file):
    """
    Persist text to disk as UTF-8.

    Args:
        input_text: Text to write.
        output_file: Destination path; an existing file is overwritten.

    Returns:
        ``output_file``, unchanged.
    """
    with open(output_file, mode="w", encoding="utf-8") as out_stream:
        out_stream.write(input_text)
    logger.info(f"Combined text saved to {output_file}")
    return output_file
async def clone_repo(repo_url, temp_dir):
    """
    Clone a Git repository into ``temp_dir`` without blocking the event loop.

    The clone is run with ``core.symlinks=false``, so symbolic links in the
    repository are checked out as small plain files containing the link text
    rather than as real links.

    Args:
        repo_url: URL of the repository to clone (may come from a user).
        temp_dir: Target directory (empty or not yet existing).

    Returns:
        ``temp_dir`` on success, ``None`` if git fails or cannot be started.
    """
    logger.info(f"Cloning {repo_url} into {temp_dir}...")
    # Never let git wait for interactive credentials (e.g. for a private or
    # non-existent repository); fail instead of hanging the request.
    env = {**os.environ, "GIT_TERMINAL_PROMPT": "0"}
    try:
        process = await asyncio.create_subprocess_exec(
            'git', 'clone', '-c', 'core.symlinks=false',
            # "--" ends option parsing, so a repo_url starting with "-"
            # (e.g. "--upload-pack=...") cannot be read as a git option.
            '--', repo_url, temp_dir,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
        )
        _stdout, stderr = await process.communicate()
    except Exception as e:
        logger.error(f"Failed to clone repository: {e}")
        return None
    if process.returncode != 0:
        logger.error(f"Failed to clone repository: {stderr.decode(errors='replace')}")
        return None
    logger.info("Repository cloned successfully.")
    return temp_dir
async def run_repo_to_text(temp_dir):
    """
    Execute the external ``repo-to-text`` tool with ``temp_dir`` as its
    working directory.

    Returns:
        ``True`` when the tool exits with status 0; ``False`` when it exits
        non-zero or cannot be launched at all. Both failure cases are logged.
    """
    try:
        proc = await asyncio.create_subprocess_exec(
            'repo-to-text',
            cwd=temp_dir,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        _, err_output = await proc.communicate()
        succeeded = proc.returncode == 0
        if not succeeded:
            logger.error(f"'repo-to-text' command failed: {err_output.decode()}")
        else:
            logger.info("repo-to-text command completed successfully.")
        return succeeded
    except Exception as exc:
        logger.error(f"'repo-to-text' command failed: {exc}")
        return False
def sanitize_special_tokens(text):
    """
    Remove tiktoken special-token strings (e.g. ``<|endoftext|>``) from text.

    Encoding with ``disallowed_special=()`` and decoding again round-trips
    the text unchanged, so the special tokens of the ``cl100k_base`` encoding
    are deleted explicitly instead. Deletion repeats until nothing changes,
    because removing one marker can join its neighbours into a new one
    (``"<|endo<|endoftext|>ftext|>"``).

    If the encoding cannot be loaded, falls back to deleting anything shaped
    like ``<|...|>``.

    Args:
        text: Input text.

    Returns:
        The text with special-token strings removed.
    """
    import re

    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        special_tokens = encoding.special_tokens_set
    except Exception as e:
        logger.warning(f"Failed to sanitize with tiktoken: {e}")
        # Fallback to simple regex cleanup
        return re.sub(r'<\|[^|]*\|>', '', text)

    if not special_tokens:
        return text
    # Longest alternatives first so the regex never stops at a shorter token.
    pattern = re.compile(
        "|".join(re.escape(tok) for tok in sorted(special_tokens, key=len, reverse=True))
    )
    previous = None
    while previous != text:  # each pass either shrinks the text or stops
        previous = text
        text = pattern.sub("", text)
    return text
async def llm_request_repo_infos(repo_url, output_format="json-ld", gimie_output=None, max_tokens=40000):
    """
    Extract structured ``SoftwareSourceCode`` metadata for a repository via an LLM.

    Pipeline: clone -> ``repo-to-text`` -> combine ``*.txt`` output -> strip
    special tokens -> truncate to ``max_tokens`` -> append ``gimie_output``
    -> LLM request (``PROVIDER`` / ``MODEL``) -> parse -> ``Verification``
    -> optional JSON-LD conversion.

    Args:
        repo_url: Repository URL to clone.
        output_format: ``"json-ld"`` or ``"json"``.
        gimie_output: Optional extra metadata, appended as ``str`` after
            truncation (so it is not counted against ``max_tokens``).
        max_tokens: Token budget for the repository text (cl100k_base).

    Returns:
        The ``json_to_jsonLD`` result for ``"json-ld"``, the sanitized
        metadata for ``"json"``, or ``None`` on any failure (logged).
    """
    # Reject bad formats before spending a clone and an LLM call on them.
    if output_format not in ("json-ld", "json"):
        logger.error(f"Unsupported output format: {output_format}")
        return None

    input_text = await _build_repo_input_text(repo_url, gimie_output, max_tokens)
    if input_text is None:
        return None

    if PROVIDER == "openrouter":
        response = await get_openrouter_response_async(input_text, model=MODEL)
    elif PROVIDER == "openai":
        response = await get_openai_response_async(input_text, model=MODEL)
    else:
        logger.error("No provider provided")
        return None
    if response is None:
        # The request helpers return None after logging their own failure.
        logger.error("No response received from the LLM provider")
        return None

    try:
        if PROVIDER == "openrouter":
            raw_result = response["choices"][0]["message"]["content"]
            parsed_result = clean_json_string(raw_result)
            json_data = json.loads(parsed_result)
        else:  # "openai"; any other provider returned above
            json_data = response.choices[0].message.parsed
            logger.info("Clean result from OpenAI response:")
            json_data = json_data.model_dump(mode='json')
        logger.info("Successfully parsed API response")
        # Run verification before converting to JSON-LD
        verifier = Verification(json_data)
        verifier.run()
        verifier.summary()
        cleaned_json = verifier.sanitize_metadata()
        if output_format == "json-ld":
            # NOTE(review): relative to the process working directory.
            context_path = "src/files/json-ld-context.json"
            return json_to_jsonLD(cleaned_json, context_path)
        return cleaned_json
    except Exception as e:
        logger.error(f"Error parsing response: {e}")
        return None


async def _build_repo_input_text(repo_url, gimie_output, max_tokens):
    """Clone ``repo_url`` and return the LLM prompt text, or ``None`` on failure."""
    # The clone only needs to exist while its text is extracted; leaving the
    # with-block before the (slow) LLM call frees the disk space early.
    with tempfile.TemporaryDirectory() as temp_dir:
        if not await clone_repo(repo_url, temp_dir):
            return None
        if not await run_repo_to_text(temp_dir):
            return None
        input_text = combine_text_files(temp_dir)
        input_text = sanitize_special_tokens(input_text)
        input_text = reduce_input_size(input_text, max_tokens=max_tokens)
        if gimie_output:
            input_text += "\n\n" + str(gimie_output)
        # Written inside temp_dir, so it is removed together with the clone.
        combined_file_path = os.path.join(temp_dir, "combined_repo.txt")
        store_combined_text(input_text, combined_file_path)
    return input_text
async def get_openrouter_response_async(input_text, system_prompt=system_prompt_json, model="google/gemini-2.5-flash", temperature=0.2, schema=SoftwareSourceCode, max_attempts=3):
    """
    Request a JSON-Schema-constrained chat completion from OpenRouter.

    Network errors, timeouts, HTTP 408/429 and 5xx responses are retried
    with exponential backoff (1s, 2s, ...). Other 4xx responses are not
    retried, because repeating the same request cannot succeed.

    Args:
        input_text: User message content.
        system_prompt: System message content.
        model: OpenRouter model identifier.
        temperature: Sampling temperature.
        schema: Pydantic model class whose JSON Schema constrains the output.
        max_attempts: Total number of tries.

    Returns:
        The decoded JSON body (dict) of the first HTTP 200 response, or
        ``None`` if every attempt failed.
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_text},
        ],
        "response_format": {
            "type": "json_schema",
            # OpenRouter (like OpenAI) expects the schema wrapped in an
            # object carrying a "name" and the actual "schema".
            "json_schema": {
                "name": schema.__name__,
                "schema": schema.model_json_schema(),
            },
        },
        "temperature": temperature,
    }
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    timeout = aiohttp.ClientTimeout(total=300)  # 5 minute timeout per attempt
    for attempt in range(max_attempts):
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(OPENROUTER_ENDPOINT, headers=headers, json=payload) as response:
                    logger.info(f"API response status: {response.status}")
                    if response.status == 200:
                        return await response.json()
                    body = await response.text(errors="replace")
                    logger.error(f"API request failed with status {response.status}: {body[:500]}")
                    # Client errors other than timeout/rate limiting are final.
                    if 400 <= response.status < 500 and response.status not in (408, 429):
                        return None
        except aiohttp.ClientError as e:
            logger.error(f"Request failed (attempt {attempt + 1}): {e}")
        except asyncio.TimeoutError as e:
            logger.error(f"Request timeout (attempt {attempt + 1}): {e}")
        if attempt < max_attempts - 1:
            await asyncio.sleep(2 ** attempt)
    return None
async def get_openai_response_async(prompt, system_prompt=system_prompt_json, model="gpt-4o", temperature=0.2, schema=SoftwareSourceCode):
    """
    Get a structured response from the OpenAI API asynchronously.

    The response is parsed into ``schema`` (after ``convert_httpurl_to_str``)
    by the SDK's ``parse`` helper.

    Args:
        prompt: User message content.
        system_prompt: System message content.
        model: OpenAI model name.
        temperature: Sampling temperature; omitted for o-series reasoning
            models (o1, o3, o4-mini, ...), which reject non-default values.
        schema: Pydantic model class describing the expected output.

    Returns:
        The SDK's parsed completion object, or ``None`` on any error (logged).
    """
    try:
        request = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "response_format": convert_httpurl_to_str(schema),
        }
        # "o3-mini" -> "o3", "o4-mini" -> "o4": the letter "o" followed by digits.
        # TODO(review): confirm whether other families (e.g. gpt-5) need this too.
        family = model.split("-")[0]
        is_o_series = family[:1] == "o" and family[1:].isdigit()
        if not is_o_series:
            request["temperature"] = temperature
        return await async_openai_client.beta.chat.completions.parse(**request)
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        return None
async def llm_request_userorg_infos(metadata, item_type="user"):
    """
    Ask the configured LLM to structure GitHub user or organization metadata.

    Args:
        metadata: Pydantic model instance; it is serialized with
            ``model_dump_json()`` and sent as the user message.
        item_type: ``"user"`` (``GitHubUser`` schema) or ``"org"``
            (``GitHubOrganization`` schema).

    Returns:
        The parsed JSON data, or ``None`` for an unsupported ``item_type``,
        an unknown provider, a failed request, or an unparsable response.
    """
    input_text = metadata.model_dump_json()
    if item_type == "user":
        schema = GitHubUser
        system_prompt = system_prompt_user_content
    elif item_type == "org":
        schema = GitHubOrganization
        system_prompt = system_prompt_org_content
    else:
        # Without this branch schema/system_prompt would be unbound below.
        logger.error(f"Unsupported item_type: {item_type}")
        return None

    if PROVIDER == "openrouter":
        response = await get_openrouter_response_async(input_text,
                                                       system_prompt=system_prompt,
                                                       model=MODEL,
                                                       schema=schema)
    elif PROVIDER == "openai":
        response = await get_openai_response_async(input_text,
                                                   system_prompt=system_prompt,
                                                   model=MODEL,
                                                   schema=schema)
    else:
        logger.error("No provider provided")
        return None
    if response is None:
        # The request helpers return None after logging their own failure.
        logger.error("No response received from the LLM provider")
        return None

    try:
        if PROVIDER == "openrouter":
            raw_result = response["choices"][0]["message"]["content"]
            parsed_result = clean_json_string(raw_result)
            json_data = json.loads(parsed_result)
        else:  # "openai"; any other provider returned above
            json_data = response.choices[0].message.parsed
            json_data = json_data.model_dump(mode='json')
        logger.info("Successfully parsed API response")
        return json_data
    except Exception as e:
        logger.error(f"Error parsing response: {e}")
        return None
# Keep the synchronous versions for backward compatibility
def get_openrouter_response(input_text, system_prompt=system_prompt_json, model="google/gemini-2.5-flash", temperature=0.2, schema=SoftwareSourceCode):
    """
    Blocking counterpart of ``get_openrouter_response_async``.

    Drives the coroutine to completion with ``asyncio.run``, which starts a
    fresh event loop. It therefore cannot be called from a thread that is
    already running an event loop.
    """
    coroutine = get_openrouter_response_async(
        input_text, system_prompt, model, temperature, schema
    )
    return asyncio.run(coroutine)
def get_openai_response(prompt, system_prompt=system_prompt_json, model="gpt-4o", temperature=0.2, schema=SoftwareSourceCode):
    """
    Synchronous wrapper for backward compatibility.

    Blocking counterpart of ``get_openai_response_async``. It uses its own
    short-lived ``OpenAI`` client, which is closed before returning.

    Args:
        prompt: User message content.
        system_prompt: System message content.
        model: OpenAI model name.
        temperature: Sampling temperature; omitted for o-series reasoning
            models (o1, o3, o4-mini, ...), which reject non-default values.
        schema: Pydantic model class describing the expected output.

    Returns:
        The SDK's parsed completion object, or ``None`` on any API error (logged).
    """
    from openai import OpenAI
    with OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) as sync_client:
        try:
            request = {
                "model": model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                "response_format": convert_httpurl_to_str(schema),
            }
            # "o3-mini" -> "o3", "o4-mini" -> "o4": the letter "o" followed by digits.
            # TODO(review): confirm whether other families (e.g. gpt-5) need this too.
            family = model.split("-")[0]
            is_o_series = family[:1] == "o" and family[1:].isdigit()
            if not is_o_series:
                request["temperature"] = temperature
            return sync_client.beta.chat.completions.parse(**request)
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            return None