Works with OpenAI and Anthropic #71

Open · wants to merge 2 commits into main

73 changes: 53 additions & 20 deletions lib/aifunc.py
@@ -3,6 +3,7 @@
import asyncio
from tenacity import retry, wait_random_exponential, stop_after_attempt
from halo import Halo
from litellm import completion, acompletion, stream_chunk_builder

# OpenAI imports
from openai import AsyncOpenAI
@@ -24,6 +25,12 @@
from lib.util import custom_style
from prompt_toolkit.formatted_text import FormattedText

from prompt_toolkit import PromptSession, print_formatted_text
from lib.util import format_response

import sys

from lib.response_printer import ResponsePrinter

# Ensure the .webwright directory exists
webwright_dir = os.path.expanduser('~/.webwright')
@@ -100,9 +107,15 @@ async def execute_function_by_name(function_name, **kwargs):
        func_logger.error(f"Function {function_name} failed with error: {e}")
        return json.dumps({"error": str(e)})

def clear_lines(num_lines):
    for _ in range(num_lines):
        sys.stdout.write('\033[F')  # Move cursor up one line
        sys.stdout.write('\033[K')  # Clear the line
    sys.stdout.flush()

@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
async def openai_chat_completion_request(messages=None, config=None, tools=None, tool_choice="auto"):
    client = AsyncOpenAI(api_key=config.get_openai_api_key())
async def litelm_chat_completion_request(messages=None, config=None, tools=None, tool_choice="auto"):
    os.environ["OPENAI_API_KEY"] = config.get_openai_api_key()

    if tools:
        function_names = [tool['function']['name'] for tool in tools]
@@ -113,17 +126,31 @@ async def openai_chat_completion_request(messages=None, config=None, tools=None,
"content": SYSTEM_PROMPT
})

    rp = ResponsePrinter()

    try:
        response = await client.chat.completions.create(
            model=config.get_config_value("config", "OPENAI_MODEL"),
            messages=messages,
            tools=tools,
            tool_choice=tool_choice
        )
        return response
        full_res = []
        previous_num_lines = 0
        response = await acompletion(
            model=config.get_config_value("config", "OPENAI_MODEL"),
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            stream=True
        )
        async for chunk in response:
            full_res.append(chunk)
            content = chunk.choices[0].delta.content
            if content:
                rp.process_chunk(content)

        rp.process_final_chunk()
        reconstructed_res = stream_chunk_builder(full_res, messages=messages)
        return reconstructed_res

    except Exception as e:
        logger.error("Unable to generate OpenAI ChatCompletion response: %s", e)
        raise
        print(e)
        return None
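
For reference, here is the streaming pattern above in isolation: litellm yields delta chunks that can be rendered as they arrive and then reassembled into an ordinary completion object. A minimal sketch, assuming litellm's documented `acompletion`/`stream_chunk_builder` API, with `"gpt-4o-mini"` as a stand-in model name:

```python
import asyncio
from litellm import acompletion, stream_chunk_builder

async def demo():
    # Assumes OPENAI_API_KEY is set in the environment.
    messages = [{"role": "user", "content": "Say hello."}]
    chunks = []
    stream = await acompletion(model="gpt-4o-mini", messages=messages, stream=True)
    async for chunk in stream:
        chunks.append(chunk)                  # keep every raw chunk
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)  # render text as it arrives
    # Rebuild a normal (non-streaming) response, tool calls included, from the chunks
    return stream_chunk_builder(chunks, messages=messages)

asyncio.run(demo())
```

This is why `litelm_chat_completion_request` can both print incrementally through `ResponsePrinter` and still hand back a response whose `choices[0].message` is intact for the tool-call loop in `ai()`.
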

@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
async def anthropic_chat_completion_request(messages=None, config=None, tools=None):
@@ -153,6 +180,15 @@ async def anthropic_chat_completion_request(messages=None, config=None, tools=None):
            tools=anthropic_tools,
            system=SYSTEM_PROMPT
        )

        res_content = ""
        for content_item in response.content:
            if isinstance(content_item, TextBlock):
                res_content += content_item.text

        formatted_text = format_response(res_content)
        print_formatted_text(formatted_text, style=custom_style)

        return response
    except Exception as e:
        logger.error("Unable to generate Anthropic ChatCompletion response: %s", e)
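
The `isinstance(content_item, TextBlock)` filter above matters because an Anthropic message's `content` is a list of typed blocks, and when tools are supplied, text can be interleaved with tool-use requests. A small sketch of separating the two, assuming the SDK's `TextBlock`/`ToolUseBlock` types:

```python
from anthropic.types import TextBlock, ToolUseBlock

def split_content(response):
    """Separate plain text from tool-use requests in an Anthropic message."""
    text, tool_calls = "", []
    for block in response.content:
        if isinstance(block, TextBlock):
            text += block.text
        elif isinstance(block, ToolUseBlock):
            tool_calls.append((block.name, block.input))
    return text, tool_calls
```
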
@@ -181,14 +217,10 @@ async def ai(username="anonymous", query="help", config=None, upload_dir=UPLOAD_
    function_call_count = 0

    while function_call_count < max_function_calls:
        spinner = Halo(text='Calling the model...', spinner='dots')
        spinner.start()

        if api_to_use == "openai":
            chat_response = await openai_chat_completion_request(messages=messages, config=config, tools=tools)
            chat_response = await litelm_chat_completion_request(messages=messages, config=config, tools=tools)

            if not chat_response:
                spinner.stop()
                return False, {"error": "Failed to get a response from OpenAI"}
            assistant_message = chat_response.choices[0].message

@@ -205,10 +237,12 @@ async def ai(username="anonymous", query="help", config=None, upload_dir=UPLOAD_
                    print(f"Failed to process tool call: {tool_call}")
                    print(f"Error: {e}")
            else:
                spinner.stop()
                return True, {"response": assistant_message.content}

        else:  # Anthropic
            spinner = Halo(text='Calling the model...', spinner='dots')
            spinner.start()

            chat_response = await anthropic_chat_completion_request(messages=messages, config=config, tools=tools)
            if not chat_response:
                spinner.stop()
@@ -234,8 +268,7 @@ async def ai(username="anonymous", query="help", config=None, upload_dir=UPLOAD_
            if not function_calls:
                spinner.stop()
                return True, {"response": text_content.strip()}

            spinner.stop()
        spinner.stop()

        if not function_calls:
            break
@@ -308,7 +341,7 @@ async def execute_function(func_call):

    # Formulate final response using the tool results
    if api_to_use == "openai":
        final_response = await openai_chat_completion_request(messages=messages, config=config, tools=tools)
        final_response = await litelm_chat_completion_request(messages=messages, config=config, tools=tools)
        if not final_response:
            return False, {"error": "Failed to get a final response from OpenAI"}
        final_message = final_response.choices[0].message.content
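The portion of `ai()` elided by the diff follows the usual two-pass tool pattern: the first completion requests tool calls, each result is appended as a `role: "tool"` message, and the second completion above phrases the final answer. A condensed sketch of that round trip, assuming the helpers defined in this module and an enclosing async context:

```python
# Condensed sketch; relies on litelm_chat_completion_request and
# execute_function_by_name from lib/aifunc.py, plus the json module.
chat_response = await litelm_chat_completion_request(messages=messages, config=config, tools=tools)
assistant_message = chat_response.choices[0].message

if assistant_message.tool_calls:
    messages.append(assistant_message)  # keep the turn that requested the tools
    for call in assistant_message.tool_calls:
        result = await execute_function_by_name(call.function.name,
                                                **json.loads(call.function.arguments))
        messages.append({"role": "tool", "tool_call_id": call.id, "content": result})
    # Second pass: the model now sees the tool output and writes the final reply
    final_response = await litelm_chat_completion_request(messages=messages, config=config, tools=tools)
```
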
167 changes: 167 additions & 0 deletions lib/response_printer.py
@@ -0,0 +1,167 @@
import re

from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit import print_formatted_text

if __name__ == "__main__":
    import sys
    import os
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from lib.util import custom_style

class ResponsePrinter:
    def __init__(self):
        self.current_line = ""
        self.in_code_block = False

    def process_chunk(self, chunk):
        self.current_line += chunk
        if '\n' in self.current_line:
            lines = self.current_line.split('\n')
            for line in lines[:-1]:
                self.print_line(line)
            self.current_line = lines[-1]

    def process_final_chunk(self):
        if self.current_line:
            self.print_line(self.current_line)

    def print_line(self, line):
        formatted_text = []

        math_pattern = re.compile(r'\\\(.*?\\\)')
        # double-tick code must be matched before single-tick, or ``...``
        # spans are mis-split into empty `` pairs
        inline_code_pattern = re.compile(r'(``.*?``|`.*?`)')
        tag_pattern = re.compile(r'<(\w+)>|</(\w+)>')  # reserved for tag handling; unused for now

        # handle code block start and end
        if line.startswith('```'):
            self.in_code_block = not self.in_code_block
            return

        # handle middle of code block
        if self.in_code_block:
            formatted_text.append(('class:code', line))
            print_formatted_text(FormattedText(formatted_text), style=custom_style)
            return

        # handle header (strip the leading '#' markers and following space)
        if line.startswith('#'):
            formatted_text.append(('class:header', line.lstrip('#').lstrip()))
            print_formatted_text(FormattedText(formatted_text), style=custom_style)
            return

        # handle inline patterns
        current_class = ''
        parts = re.split(r'(\*\*.*?\*\*|``.*?``|`.*?`|\\\(.*?\\\))', line)
        for part in parts:
            if part.startswith('**') and part.endswith('**'):
                formatted_text.append((f'{current_class} class:bold', part[2:-2]))
            elif inline_code_pattern.match(part):
                # Handle both double and single ticks (double first)
                code_content = part[2:-2] if part.startswith('``') else part[1:-1]
                formatted_text.append((f'{current_class} class:inline-code', code_content))
            elif math_pattern.match(part):
                formatted_text.append((f'{current_class} class:math', part[2:-2]))
            else:
                formatted_text.append((current_class, part))
        print_formatted_text(FormattedText(formatted_text), style=custom_style)


if __name__ == "__main__":
    import random
    import time


    sample_response = """
Sure, here's a response combining all of those formatting elements in Markdown:

**This is a block of bold text within a line.**
# Header Level 1
Here's an `inline code block`.
```
# This is a
# multi-line code block
print("Hello, World!")
```
"""

    sample_response_tags = """
<thinking>
To generate a response with some HTML tags for testing purposes, I will use:
- Paragraph tags <p>
- Heading tags <h1> to <h3>
- A table <table> with table rows <tr> and cells <td>
- A div <div> with a class attribute
- An ordered list <ol> with list items <li>
- Emphasis <em> and bold <strong> tags
- A horizontal rule <hr>
</thinking>

<h1>Sample HTML Page</h1>

<p>This is an example <em>HTML response</em> generated for <strong>testing purposes</strong>.</p>

<h2>A Simple Table</h2>
<table>
<tr>
<td>Row 1, Cell 1</td>
<td>Row 1, Cell 2</td>
</tr>
<tr>
<td>Row 2, Cell 1</td>
<td>Row 2, Cell 2</td>
</tr>
</table>

<div class="example-class">
<h3>An Ordered List</h3>
<ol>
<li>First item</li>
<li>Second item</li>
<li>Third item</li>
</ol>
</div>

<hr>

<p>Feel free to let me know if you would like me to include any other specific HTML elements or attributes in the example!</p>
"""

    def split_into_random_chunks(text, min_chunk_size=5, max_chunk_size=15):
        """
        Splits the input text into chunks of random sizes.

        Args:
            text (str): The text to split into chunks.
            min_chunk_size (int): Minimum size of each chunk.
            max_chunk_size (int): Maximum size of each chunk.

        Returns:
            List[str]: A list of text chunks.
        """
        chunks = []
        index = 0
        text_length = len(text)
        while index < text_length:
            # Determine the size of the next chunk
            chunk_size = random.randint(min_chunk_size, max_chunk_size)
            # Extract the chunk from the text
            chunk = text[index:index + chunk_size]
            chunks.append(chunk)
            index += chunk_size
        return chunks


    rp = ResponsePrinter()
    chunks = split_into_random_chunks(sample_response, min_chunk_size=5, max_chunk_size=20)

    # Feed each chunk into the process_chunk method with a slight delay to simulate streaming
    for chunk in chunks:
        rp.process_chunk(chunk)
        #print(chunk, end='')
        time.sleep(0.2)
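
One mechanism worth noting in `print_line()`: `re.split` keeps the delimiters in its output whenever the pattern contains a capturing group, which is what lets bold/code/math spans be restyled in place while plain text passes through untouched. A quick illustration:

```python
import re

parts = re.split(r'(\*\*.*?\*\*|`.*?`)', "mix of **bold** and `code` here")
print(parts)
# ['mix of ', '**bold**', ' and ', '`code`', ' here']
```

Buffering in `process_chunk()` until a newline arrives is what makes per-line classification safe under streaming: a chunk boundary can fall mid-token (between the two asterisks of `**`, say), so only complete lines are ever formatted.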

2 changes: 0 additions & 2 deletions webwright/main.py
@@ -161,8 +161,6 @@ async def main(config):
        success, results = await process_shell_query(username, question, config, conversation_history)

        if success and "explanation" in results:
            formatted_response = format_response(results['explanation'])
            print_formatted_text(formatted_response, style=custom_style)
            conversation_history.append({"role": "assistant", "content": results["explanation"]})
        elif not success and "error" in results:
            # Error messages are already handled in process_shell_query