
Commit e44d44b

Use structured responses; improve prompting and tool call errors
1 parent 21a953d commit e44d44b

2 files changed: 46 additions and 12 deletions


optastic/chat.py
Lines changed: 39 additions & 7 deletions

@@ -3,24 +3,37 @@
 from llm_utils import number_group_of_lines
 import openai
 from openai.types.chat import (
-    ChatCompletion,
     ChatCompletionDeveloperMessageParam,
     ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
+    ParsedChatCompletion,
 )
+from pydantic import BaseModel
 from rich import print as rprint

 from optastic.project import Project
 from optastic.tools import GetCodeTool, GetInfoTool, LLMToolRunner, LookupDefinitionTool


+class OptimizationSuggestion(BaseModel):
+    filename: str
+    startLine: int
+    endLine: int
+    newCode: str
+
+
+class OptimizationSuite(BaseModel):
+    highLevelSummary: str
+    suggestions: List[OptimizationSuggestion]
+
+
 def run_chat(project: Project, filename: str, lineno: int, model_id=None):
     if model_id is None:
         model_id = "o3-mini"
     lang = project.lang()

     try:
-        client = openai.OpenAI(timeout=30)
+        client = openai.OpenAI(timeout=90)
     except openai.OpenAIError:
         print("you need an OpenAI key to use this tool.")
         print("You can get a key here: https://platform.openai.com/api-keys.")
@@ -31,15 +44,17 @@ def run_chat(project: Project, filename: str, lineno: int, model_id=None):
         project, [LookupDefinitionTool(), GetCodeTool(), GetInfoTool()]
     )

-    prettyline = number_group_of_lines([project.get_line(filename, lineno - 1)], lineno)
+    prettyline = number_group_of_lines(
+        project.get_lines(filename, lineno - 1 - 5, lineno - 1 + 5), lineno - 5
+    )
     messages: List[ChatCompletionMessageParam] = [
         _make_system_message(
             model_id,
-            f"You are a {lang} performance optimization assistant. Please optimize the user's program, making use of the provided tool calls that will let you explore the program. Never make assumptions about the program; use tool calls if you are not sure.",
+            f"You are a {lang} performance optimization assistant. You NEVER make assumptions or express hypotheticals about what the user's program does. Instead, you make ample use of the tool calls available to you to thoroughly explore the user's program. You always give CONCRETE code suggestions.",
         ),
         {
             "role": "user",
-            "content": f"I've identified line {filename}:{lineno} as a hotspot, reproduced below. Please help me optimize it.\n\n```{lang}\n{prettyline}\n```",
+            "content": f"I've identified line {filename}:{lineno} as a hotspot, reproduced below. Please help me optimize it by exploring the program and giving me CONCRETE suggestions.\n\n```{lang}\n{prettyline}\n```",
         },
     ]
     for msg in messages:
@@ -50,11 +65,12 @@ def run_chat(project: Project, filename: str, lineno: int, model_id=None):
     round_num = 0
     while (response_msg is None or response_msg.tool_calls) and round_num <= MAX_ROUNDS:
         tool_schemas = tool_runner.all_schemas()
-        response = client.chat.completions.create(
+        response = client.beta.chat.completions.parse(
             model=model_id,
             messages=messages,
             tools=tool_schemas,
             tool_choice="auto",
+            response_format=OptimizationSuite,
         )
         _print_completion(response)
         response_msg = response.choices[0].message
@@ -103,10 +119,26 @@ def _print_message(msg: Any):
     rprint(f"[purple]{role}:[/purple] {content}")


-def _print_completion(completion: ChatCompletion):
+def _print_parsed_completion(parsed: OptimizationSuite):
+    rprint("[underline]High-level Summary[/underline]")
+    rprint(parsed.highLevelSummary)
+    rprint()
+
+    for sugg in parsed.suggestions:
+        rprint(
+            f"[underline]In {sugg.filename}, replace lines {sugg.startLine} to {sugg.endLine}:[/underline]"
+        )
+        rprint()
+        rprint(sugg.newCode)
+        rprint("------\n")
+
+
+def _print_completion(completion: ParsedChatCompletion[OptimizationSuite]):
     response = completion.choices[0].message
     rprint(f"[orange]Choice 1/{len(completion.choices)}[/orange]:")
     _print_message(response)
+    if response.parsed is not None:
+        _print_parsed_completion(response.parsed)
     if response.tool_calls:
         rprint("[blue]Tool calls:[/blue]")
         for call in response.tool_calls:
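
The diff above adopts the OpenAI Python SDK's structured-output API: a Pydantic schema is passed as response_format to client.beta.chat.completions.parse, and the SDK validates the reply into message.parsed. As a standalone illustration of that pattern (a minimal sketch outside optastic; the model name and prompt text are placeholders, not values taken from this commit):

# Minimal sketch of the structured-response pattern used in chat.py above.
# The schema mirrors the one in the diff; the model and prompt are placeholders.
from typing import List

import openai
from pydantic import BaseModel


class OptimizationSuggestion(BaseModel):
    filename: str
    startLine: int
    endLine: int
    newCode: str


class OptimizationSuite(BaseModel):
    highLevelSummary: str
    suggestions: List[OptimizationSuggestion]


client = openai.OpenAI(timeout=90)
completion = client.beta.chat.completions.parse(
    model="o3-mini",  # placeholder; any model that supports structured outputs
    messages=[{"role": "user", "content": "Suggest optimizations for ..."}],
    response_format=OptimizationSuite,
)

message = completion.choices[0].message
if message.parsed is not None:  # None if the reply was a refusal or failed to parse
    for sugg in message.parsed.suggestions:
        print(f"{sugg.filename}:{sugg.startLine}-{sugg.endLine}\n{sugg.newCode}")
else:
    print(message.refusal or "no structured output returned")

Passing tools= and response_format= on the same parse call, as the diff does, lets the model keep issuing tool calls in earlier rounds and still return a schema-conforming OptimizationSuite once it is done exploring.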

optastic/tools.py
Lines changed: 7 additions & 5 deletions

@@ -43,7 +43,9 @@ def exec(self, req: dict, project: Project) -> Any:
         srclines = project.get_lines(r.filename)
         result = find_symbol(srclines, r.line - 1, r.symbol)
         if result is None:
-            return {"error": f"symbol {r.symbol} not found at {r.filename}:{r.line}"}
+            return {
+                "error": f"symbol {r.symbol} not found at {r.filename}:{r.line} (wrong line number?)"
+            }
         line, column = result

         resp = syncexec(
@@ -101,14 +103,14 @@ def exec(self, req: dict, project: Project) -> Any:
         srclines = project.get_lines(r.filename)
         result = find_symbol(srclines, r.line - 1, r.symbol)
         if result is None:
-            return {"error": "symbol {r.symbol} not found at {r.filename}:{r.line}"}
+            return {
+                "error": f"symbol {r.symbol} not found at {r.filename}:{r.line} (wrong line number?)"
+            }
         line, column = result

         resp = project.lsp().request_hover(r.filename, line, column)
         if resp is None:
-            return {
-                "error": "no info found for that location (maybe off-by-one error?)"
-            }
+            return {"error": "no info found for that location (wrong line number?)"}
         return {"contents": resp["contents"]}

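
These error strings are returned to the model as tool output, so the added "(wrong line number?)" hint gives it a concrete cue to retry with a corrected location. For context, a sketch of the standard OpenAI tool-calling flow that feeds such a result dict back into the conversation (illustrative only: run_tool is a hypothetical dispatcher, and optastic's LLMToolRunner may handle this differently):

# Sketch of the standard tool-call round trip, assuming the usual OpenAI
# chat-completions conventions; run_tool is a hypothetical stand-in.
import json


def append_tool_results(messages, response_msg, run_tool):
    """Run each requested tool and append its result (including error dicts such
    as {"error": "symbol ... (wrong line number?)"}) as a role="tool" message so
    the model can react on the next round."""
    messages.append(response_msg)  # the assistant message that issued the calls
    for call in response_msg.tool_calls or []:
        args = json.loads(call.function.arguments)
        result = run_tool(call.function.name, args)  # may return {"error": ...}
        messages.append(
            {
                "role": "tool",
                "tool_call_id": call.id,
                "content": json.dumps(result),
            }
        )
    return messages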
