Skip to content

Commit 041e100

Browse files
author
Aman Rusia
committed
Reading file images
1 parent 3e546f9 commit 041e100

File tree

4 files changed

+140
-40
lines changed

4 files changed

+140
-40
lines changed

src/relay/serve.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Mdata(BaseModel):
4242

4343
@app.websocket("/register_serve_image/{uuid}")
4444
async def register_serve_image(websocket: WebSocket, uuid: UUID) -> None:
45+
raise Exception("Disabled")
4546
await websocket.accept()
4647
received_data = await websocket.receive_json()
4748
name = received_data["name"]

src/wcgw/basic.py

+92-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import base64
12
import json
3+
import mimetypes
24
from pathlib import Path
35
import sys
46
import traceback
@@ -8,17 +10,20 @@
810
from openai.types.chat import (
911
ChatCompletionMessageParam,
1012
ChatCompletionAssistantMessageParam,
13+
ChatCompletionUserMessageParam,
14+
ChatCompletionContentPartParam,
1115
ChatCompletionMessage,
1216
ParsedChatCompletionMessage,
1317
)
1418
import rich
19+
import petname
1520
from typer import Typer
1621
import uuid
1722

1823
from .common import Models, discard_input
1924
from .common import CostData, History
2025
from .openai_utils import get_input_cost, get_output_cost
21-
from .tools import ExecuteBash
26+
from .tools import ExecuteBash, ReadImage, ImageData
2227

2328
from .tools import (
2429
BASH_CLF_OUTPUT,
@@ -80,6 +85,38 @@ def save_history(history: History, session_id: str) -> None:
8085
json.dump(history, f, indent=3)
8186

8287

88+
def parse_user_message_special(msg: str) -> ChatCompletionUserMessageParam:
89+
# Search for lines starting with `%` and treat them as special commands
90+
parts: list[ChatCompletionContentPartParam] = []
91+
for line in msg.split("\n"):
92+
if line.startswith("%"):
93+
args = line[1:].strip().split(" ")
94+
command = args[0]
95+
assert command == 'image'
96+
image_path = args[1]
97+
with open(image_path, 'rb') as f:
98+
image_bytes = f.read()
99+
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
100+
image_type = mimetypes.guess_type(image_path)[0]
101+
dataurl=f'data:{image_type};base64,{image_b64}'
102+
parts.append({
103+
'type': 'image_url',
104+
'image_url': {
105+
'url': dataurl,
106+
'detail': 'auto'
107+
}
108+
})
109+
else:
110+
if len(parts) > 0 and parts[-1]['type'] == 'text':
111+
parts[-1]['text'] += '\n' + line
112+
else:
113+
parts.append({'type': 'text', 'text': line})
114+
return {
115+
'role': 'user',
116+
'content': parts
117+
}
118+
119+
83120
app = Typer(pretty_exceptions_show_locals=False)
84121

85122

@@ -94,6 +131,7 @@ def loop(
94131
session_id = str(uuid.uuid4())[:6]
95132

96133
history: History = []
134+
waiting_for_assistant = False
97135
if resume:
98136
if resume == "latest":
99137
resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
@@ -108,6 +146,7 @@ def loop(
108146
if history[1]["role"] != "user":
109147
raise ValueError("Invalid history file, second message should be user")
110148
first_message = ""
149+
waiting_for_assistant = history[-1]['role'] != 'assistant'
111150

112151
my_dir = os.path.dirname(__file__)
113152
config_file = os.path.join(my_dir, "..", "..", "config.toml")
@@ -164,12 +203,11 @@ def loop(
164203
- Machine: {uname_machine}
165204
"""
166205

167-
has_tool_output = False
168206
if not history:
169207
history = [{"role": "system", "content": system}]
170208
else:
171209
if history[-1]["role"] == "tool":
172-
has_tool_output = True
210+
waiting_for_assistant = True
173211

174212
client = OpenAI()
175213

@@ -188,16 +226,16 @@ def loop(
188226
)
189227
break
190228

191-
if not has_tool_output:
229+
if not waiting_for_assistant:
192230
if first_message:
193231
msg = first_message
194232
first_message = ""
195233
else:
196234
msg = text_from_editor(user_console)
197235

198-
history.append({"role": "user", "content": msg})
236+
history.append(parse_user_message_special(msg))
199237
else:
200-
has_tool_output = False
238+
waiting_for_assistant = False
201239

202240
cost_, input_toks_ = get_input_cost(
203241
config.cost_file[config.model], enc, history
@@ -222,6 +260,7 @@ def loop(
222260
_histories: History = []
223261
item: ChatCompletionMessageParam
224262
full_response: str = ""
263+
image_histories: History = []
225264
try:
226265
for chunk in stream:
227266
if chunk.choices[0].finish_reason == "tool_calls":
@@ -235,7 +274,7 @@ def loop(
235274
"type": "function",
236275
"function": {
237276
"arguments": tool_args,
238-
"name": "execute_bash",
277+
"name": type(which_tool(tool_args)).__name__,
239278
},
240279
}
241280
for tool_call_id, toolcallargs in tool_call_args_by_id.items()
@@ -251,7 +290,7 @@ def loop(
251290
)
252291
system_console.print(f"\nTotal cost: {config.cost_unit}{cost:.3f}")
253292
output_toks += output_toks_
254-
293+
255294
_histories.append(item)
256295
for tool_call_id, toolcallargs in tool_call_args_by_id.items():
257296
for toolindex, tool_args in toolcallargs.items():
@@ -283,21 +322,58 @@ def loop(
283322
f"\nTotal cost: {config.cost_unit}{cost:.3f}"
284323
)
285324
return output_or_done.task_output, cost
325+
286326
output = output_or_done
287327

288-
item = {
289-
"role": "tool",
290-
"content": str(output),
291-
"tool_call_id": tool_call_id + str(toolindex),
292-
}
328+
if isinstance(output, ImageData):
329+
randomId = petname.Generate(2, "-")
330+
if not image_histories:
331+
image_histories.extend([
332+
{
333+
'role': 'assistant',
334+
'content': f'Share images with ids: {randomId}'
335+
336+
},
337+
{
338+
'role': 'user',
339+
'content': [{
340+
'type': 'image_url',
341+
'image_url': {
342+
'url': output.dataurl,
343+
'detail': 'auto'
344+
}
345+
}]
346+
}]
347+
)
348+
else:
349+
image_histories[0]['content'] += ', ' + randomId
350+
image_histories[1]["content"].append({ # type: ignore
351+
'type': 'image_url',
352+
'image_url': {
353+
'url': output.dataurl,
354+
'detail': 'auto'
355+
}
356+
})
357+
358+
item = {
359+
"role": "tool",
360+
"content": f'Ask user for image id: {randomId}',
361+
"tool_call_id": tool_call_id + str(toolindex),
362+
}
363+
else:
364+
item = {
365+
"role": "tool",
366+
"content": str(output),
367+
"tool_call_id": tool_call_id + str(toolindex),
368+
}
293369
cost_, output_toks_ = get_output_cost(
294370
config.cost_file[config.model], enc, item
295371
)
296372
cost += cost_
297373
output_toks += output_toks_
298374

299375
_histories.append(item)
300-
has_tool_output = True
376+
waiting_for_assistant = True
301377
break
302378
elif chunk.choices[0].finish_reason:
303379
assistant_console.print("")
@@ -326,11 +402,11 @@ def loop(
326402
assistant_console.print(chunk_str, end="")
327403
full_response += chunk_str
328404
except KeyboardInterrupt:
329-
has_tool_output = False
405+
waiting_for_assistant = False
330406
input("Interrupted...enter to redo the current turn")
331407
else:
332408
history.extend(_histories)
333-
409+
history.extend(image_histories)
334410
save_history(history, session_id)
335411

336412
return "Couldn't finish the task", cost

src/wcgw/openai_utils.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,19 @@ def get_input_cost(
2828
input_tokens = 0
2929
for msg in history:
3030
content = msg["content"]
31-
if not isinstance(content, str):
31+
refusal = msg.get("refusal")
32+
if isinstance(content, list):
33+
for part in content:
34+
if 'text' in part:
35+
input_tokens += len(enc.encode(part['text']))
36+
elif content is None:
37+
if refusal is None:
38+
raise ValueError("Expected content or refusal to be present")
39+
input_tokens += len(enc.encode(str(refusal)))
40+
elif not isinstance(content, str):
3241
raise ValueError(f"Expected content to be string, got {type(content)}")
33-
input_tokens += len(enc.encode(content))
42+
else:
43+
input_tokens += len(enc.encode(content))
3444
cost = input_tokens * cost_map.cost_per_1m_input_tokens / 1_000_000
3545
return cost, input_tokens
3646

src/wcgw/tools.py

+35-22
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import threading
77
import traceback
8-
from typing import Callable, Literal, Optional, ParamSpec, Sequence, TypeVar, TypedDict
8+
from typing import Callable, Literal, NewType, Optional, ParamSpec, Sequence, TypeVar, TypedDict
99
import uuid
1010
from pydantic import BaseModel, TypeAdapter
1111
from websockets.sync.client import connect as syncconnect
@@ -70,7 +70,7 @@ class Writefile(BaseModel):
7070

7171
def start_shell():
7272
SHELL = pexpect.spawn(
73-
"/bin/bash",
73+
"/bin/bash --noprofile --norc",
7474
env={**os.environ, **{"PS1": "#@@"}},
7575
echo=False,
7676
encoding="utf-8",
@@ -236,6 +236,7 @@ def execute_bash(
236236

237237
class ReadImage(BaseModel):
238238
file_path: str
239+
type: Literal['ReadImage'] = 'ReadImage'
239240

240241

241242
def serve_image_in_bg(file_path: str, client_uuid: str, name: str) -> None:
@@ -257,15 +258,9 @@ def serve_image_in_bg(file_path: str, client_uuid: str, name: str) -> None:
257258
print(f"Connection closed for UUID: {client_uuid}, retrying")
258259
serve_image_in_bg(file_path, client_uuid, name)
259260

261+
class ImageData(BaseModel):
262+
dataurl: str
260263

261-
def read_image_from_shell(file_path: str) -> str:
262-
name = petname.Generate(3)
263-
client_uuid = str(uuid.uuid4())
264-
thread = threading.Thread(
265-
target=serve_image_in_bg, args=(file_path, client_uuid, name), daemon=True
266-
)
267-
thread.start()
268-
return f"https://wcgw.arcfu.com/get_image/{client_uuid}/{name}"
269264

270265

271266
Param = ParamSpec("Param")
@@ -286,6 +281,24 @@ def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> T:
286281

287282
return wrapper
288283

284+
@ensure_no_previous_output
285+
def read_image_from_shell(file_path: str) -> ImageData:
286+
if not os.path.isabs(file_path):
287+
SHELL.sendline("pwd")
288+
SHELL.expect("#@@")
289+
assert isinstance(SHELL.before, str)
290+
current_dir = SHELL.before.strip()
291+
file_path = os.path.join(current_dir, file_path)
292+
293+
if not os.path.exists(file_path):
294+
raise ValueError(f"File {file_path} does not exist")
295+
296+
with open(file_path, "rb") as image_file:
297+
image_bytes = image_file.read()
298+
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
299+
image_type = mimetypes.guess_type(file_path)[0]
300+
return ImageData(dataurl=f'data:{image_type};base64,{image_b64}')
301+
289302

290303
@ensure_no_previous_output
291304
def write_file(writefile: Writefile) -> str:
@@ -330,22 +343,22 @@ def take_help_of_ai_assistant(
330343

331344
def which_tool(args: str) -> BaseModel:
332345
adapter = TypeAdapter[
333-
Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag
334-
](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag)
346+
Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag | ReadImage
347+
](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag | ReadImage)
335348
return adapter.validate_python(json.loads(args))
336349

337350

338351
def get_tool_output(
339-
args: dict | Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag,
352+
args: dict | Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag | ReadImage,
340353
enc: tiktoken.Encoding,
341354
limit: float,
342355
loop_call: Callable[[str, float], tuple[str, float]],
343356
is_waiting_user_input: Callable[[str], tuple[BASH_CLF_OUTPUT, float]],
344-
) -> tuple[str | DoneFlag, float]:
357+
) -> tuple[str | ImageData | DoneFlag, float]:
345358
if isinstance(args, dict):
346359
adapter = TypeAdapter[
347-
Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag
348-
](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag)
360+
Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag | ReadImage
361+
](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag | ReadImage)
349362
arg = adapter.validate_python(args)
350363
else:
351364
arg = args
@@ -365,9 +378,9 @@ def get_tool_output(
365378
elif isinstance(arg, AIAssistant):
366379
console.print("Calling AI assistant tool")
367380
output = take_help_of_ai_assistant(arg, limit, loop_call)
368-
elif isinstance(arg, get_output_of_last_command):
369-
console.print("Calling get output of last program tool")
370-
output = get_output_of_last_command(enc), 0
381+
elif isinstance(arg, ReadImage):
382+
console.print("Calling read image tool")
383+
output = read_image_from_shell(arg.file_path), 0.0
371384
else:
372385
raise ValueError(f"Unknown tool: {arg}")
373386

@@ -438,7 +451,7 @@ def execute_user_input() -> None:
438451
ExecuteBash(
439452
send_ascii=[ord(x) for x in user_input] + [ord("\n")]
440453
),
441-
lambda x: ("wont_exit", 0),
454+
lambda x: ("waiting_for_input", 0),
442455
)[0]
443456
)
444457
except Exception as e:
@@ -451,10 +464,10 @@ async def register_client(server_url: str, client_uuid: str = "") -> None:
451464
# Generate a unique UUID for this client
452465
if not client_uuid:
453466
client_uuid = str(uuid.uuid4())
454-
print(f"Connecting with UUID: {client_uuid}")
455467

456468
# Create the WebSocket connection
457469
async with websockets.connect(f"{server_url}/{client_uuid}") as websocket:
470+
print(f"Connected. Share this user id with the chatbot: {client_uuid}")
458471
try:
459472
while True:
460473
# Wait to receive data from the server
@@ -481,7 +494,7 @@ async def register_client(server_url: str, client_uuid: str = "") -> None:
481494
assert not isinstance(output, DoneFlag)
482495
await websocket.send(output)
483496

484-
except websockets.ConnectionClosed:
497+
except (websockets.ConnectionClosed, ConnectionError):
485498
print(f"Connection closed for UUID: {client_uuid}, retrying")
486499
await register_client(server_url, client_uuid)
487500

0 commit comments

Comments
 (0)