Commit c35f03e

fix(autofix): Some fixes (#1904)
Not going ahead with DeepClaude, but discovered some bugs in converting tool calls between OpenAI, Anthropic, and Gemini, so fixing those. Also adding support for reasoning-related parameters for OpenAI, improving the get_valid_file_paths check that the AI suggested, and fixing the exception logging in get_file_content so we don't get spammed by Sentry. Also changing the index at which we insert the initial prompt in the solution/coding step to prevent some hallucinations we were seeing.
1 parent 43b0c2e commit c35f03e
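
For orientation, a rough sketch of how the new reasoning knob is meant to flow; values are illustrative, RunConfig has more fields than shown, and the agent/provider wiring is simplified (see the agent.py and client.py diffs below):

# Sketch under assumptions: only the RunConfig fields visible in this diff are
# shown, and the agent setup is elided. OpenAI accepts reasoning_effort values
# such as "low" / "medium" / "high"; when it is unset the client falls back to
# openai.NotGiven(), and when it is set the system prompt is sent under the
# "developer" role instead of "system".
config = RunConfig(
    system_prompt="You are a code-fixing agent.",  # illustrative prompt
    temperature=0.0,
    run_name="autofix-coding",                     # illustrative run name
    reasoning_effort="high",                       # new field added in this commit
)
completion = agent.get_completion(config)  # agent: an LlmAgent instance set up elsewhere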

File tree

7 files changed: +52 -38 lines changed

src/seer/automation/agent/agent.py

+2
@@ -35,6 +35,7 @@ class RunConfig(BaseModel):
     memory_storage_key: str | None = None
     temperature: float | None = 0.0
     run_name: str | None = None
+    reasoning_effort: str | None = None


 class LlmAgent:
@@ -62,6 +63,7 @@ def get_completion(self, run_config: RunConfig):
             system_prompt=run_config.system_prompt if run_config.system_prompt else None,
             tools=(self.tools if len(self.tools) > 0 else None),
             temperature=run_config.temperature or 0.0,
+            reasoning_effort=run_config.reasoning_effort,
         )

     def run_iteration(self, run_config: RunConfig):

src/seer/automation/agent/client.py

+26 -11
@@ -109,12 +109,14 @@ def generate_text(
         max_tokens: int | None = None,
         timeout: float | None = None,
         predicted_output: str | None = None,
+        reasoning_effort: str | None = None,
     ):
         message_dicts, tool_dicts = self._prep_message_and_tools(
             messages=messages,
             prompt=prompt,
             system_prompt=system_prompt,
             tools=tools,
+            reasoning_effort=reasoning_effort,
         )

         openai_client = self.get_client()
@@ -138,6 +140,7 @@ def generate_text(
                 if predicted_output
                 else openai.NotGiven()
             ),
+            reasoning_effort=reasoning_effort if reasoning_effort else openai.NotGiven(),
         )

         openai_message = completion.choices[0].message
@@ -183,12 +186,14 @@ def generate_structured(
         response_format: Type[StructuredOutputType],
         max_tokens: int | None = None,
         timeout: float | None = None,
+        reasoning_effort: str | None = None,
     ) -> LlmGenerateStructuredResponse[StructuredOutputType]:
         message_dicts, tool_dicts = self._prep_message_and_tools(
             messages=messages,
             prompt=prompt,
             system_prompt=system_prompt,
             tools=tools,
+            reasoning_effort=reasoning_effort,
         )

         openai_client = self.get_client()
@@ -205,6 +210,7 @@ def generate_structured(
             response_format=response_format,
             max_tokens=max_tokens or openai.NotGiven(),
             timeout=timeout or openai.NotGiven(),
+            reasoning_effort=reasoning_effort if reasoning_effort else openai.NotGiven(),
         )

         openai_message = completion.choices[0].message
@@ -244,6 +250,7 @@ def to_message_dict(message: Message) -> ChatCompletionMessageParam:
                 new_item["type"] = "function"
                 parsed_tool_calls.append(new_item)
             message_dict["tool_calls"] = parsed_tool_calls
+            message_dict["role"] = "assistant"

         if message.tool_call_id:
             message_dict["tool_call_id"] = message.tool_call_id
@@ -284,11 +291,18 @@ def _prep_message_and_tools(
         prompt: str | None = None,
         system_prompt: str | None = None,
         tools: list[FunctionTool] | None = None,
+        reasoning_effort: str | None = None,
     ):
         message_dicts = [cls.to_message_dict(message) for message in messages] if messages else []
         if system_prompt:
             message_dicts.insert(
-                0, cls.to_message_dict(Message(role="system", content=system_prompt))
+                0,
+                cls.to_message_dict(
+                    Message(
+                        role="system" if not reasoning_effort else "developer",
+                        content=system_prompt,
+                    )
+                ),
             )
         if prompt:
             message_dicts.append(cls.to_message_dict(Message(role="user", content=prompt)))
@@ -310,12 +324,14 @@ def generate_text_stream(
         temperature: float | None = None,
         max_tokens: int | None = None,
         timeout: float | None = None,
+        reasoning_effort: str | None = None,
     ) -> Iterator[str | ToolCall | Usage]:
         message_dicts, tool_dicts = self._prep_message_and_tools(
             messages=messages,
             prompt=prompt,
             system_prompt=system_prompt,
             tools=tools,
+            reasoning_effort=reasoning_effort,
         )

         openai_client = self.get_client()
@@ -333,6 +349,7 @@ def generate_text_stream(
             timeout=timeout or openai.NotGiven(),
             stream=True,
             stream_options={"include_usage": True},
+            reasoning_effort=reasoning_effort if reasoning_effort else openai.NotGiven(),
         )

         try:
@@ -515,7 +532,7 @@ def to_message_param(message: Message) -> MessageParam:
                     )
                 ],
             )
-        elif message.role == "tool_use":
+        elif message.role == "tool_use" or (message.role == "assistant" and message.tool_calls):
             if not message.tool_calls:
                 return MessageParam(role="assistant", content=[])
             tool_call = message.tool_calls[0]  # Assuming only one tool call per message
@@ -679,14 +696,6 @@ def construct_message_from_stream(

 @dataclass
 class GeminiProvider:
-    # !!! NOTE THE FOLLOWING LIMITATIONS FOR GEMINI:
-    # - super strict rate limits making it unusable for evals or prod
-    # - no multi-turn tool use
-    # - no nested Pydantic models for structured outputs
-    # - no nullable fields for structured outputs
-    # - no dynamic retrieval for google search
-    # These will likely be changed as the SDK matures. Make sure to keep an eye on updates and update these notes/our implementation as needed.
-
     model_name: str
     provider_name = LlmProviderType.GEMINI
     defaults: LlmProviderDefaults | None = None
@@ -985,7 +994,7 @@ def _prep_message_and_tools(

     @staticmethod
     def to_content(message: Message) -> Content:
-        if message.role == "tool_use":
+        if message.role == "tool_use" or (message.role == "assistant" and message.tool_calls):
             if not message.tool_calls:
                 return Content(
                     role="model",
@@ -1117,6 +1126,7 @@ def generate_text(
         run_name: str | None = None,
         timeout: float | None = None,
         predicted_output: str | None = None,
+        reasoning_effort: str | None = None,
     ) -> LlmGenerateTextResponse:
         try:
             if run_name:
@@ -1140,6 +1150,7 @@ def generate_text(
                     tools=tools,
                     timeout=timeout,
                     predicted_output=predicted_output,
+                    reasoning_effort=reasoning_effort,
                 )
             elif model.provider_name == LlmProviderType.ANTHROPIC:
                 model = cast(AnthropicProvider, model)
@@ -1182,6 +1193,7 @@ def generate_structured(
         max_tokens: int | None = None,
         run_name: str | None = None,
         timeout: float | None = None,
+        reasoning_effort: str | None = None,
     ) -> LlmGenerateStructuredResponse[StructuredOutputType]:
         try:
             if run_name:
@@ -1203,6 +1215,7 @@ def generate_structured(
                     temperature=temperature,
                     tools=tools,
                     timeout=timeout,
+                    reasoning_effort=reasoning_effort,
                 )
             elif model.provider_name == LlmProviderType.ANTHROPIC:
                 raise NotImplementedError("Anthropic structured outputs are not yet supported")
@@ -1236,6 +1249,7 @@ def generate_text_stream(
         max_tokens: int | None = None,
         run_name: str | None = None,
         timeout: float | None = None,
+        reasoning_effort: str | None = None,
     ) -> Iterator[str | ToolCall | Usage]:
         try:
             if run_name:
@@ -1260,6 +1274,7 @@ def generate_text_stream(
                     temperature=temperature or default_temperature,
                     tools=tools,
                     timeout=timeout,
+                    reasoning_effort=reasoning_effort,
                 )
             elif model.provider_name == LlmProviderType.ANTHROPIC:
                 model = cast(AnthropicProvider, model)
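
The to_message_param and to_content changes above are the tool-call conversion fix called out in the commit message: an assistant message that carries tool_calls previously fell through to the plain-content branch when converted for Anthropic or Gemini. Below is a minimal standalone sketch of the idea using simplified stand-in types, not the repo's actual Message and MessageParam models:

from dataclasses import dataclass, field

# Simplified stand-ins for the repo's models (assumption: the real ones are
# richer Pydantic classes with additional fields).
@dataclass
class ToolCall:
    id: str
    function: str
    args: dict

@dataclass
class Message:
    role: str
    content: str | None = None
    tool_calls: list[ToolCall] = field(default_factory=list)

def to_anthropic_param(message: Message) -> dict:
    # Treat an assistant message that carries tool calls like an explicit
    # "tool_use" message, so the call is converted instead of being dropped.
    if message.role == "tool_use" or (message.role == "assistant" and message.tool_calls):
        tool_call = message.tool_calls[0]  # assuming one tool call per message
        return {
            "role": "assistant",
            "content": [
                {
                    "type": "tool_use",
                    "id": tool_call.id,
                    "name": tool_call.function,
                    "input": tool_call.args,
                }
            ],
        }
    return {"role": message.role, "content": message.content or ""}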

src/seer/automation/autofix/autofix_agent.py

+1
@@ -84,6 +84,7 @@ def _get_completion(self, run_config: RunConfig):
             system_prompt=run_config.system_prompt if run_config.system_prompt else None,
             tools=(self.tools if len(self.tools) > 0 else None),
             temperature=run_config.temperature or 0.0,
+            reasoning_effort=run_config.reasoning_effort,
         )

         cleared = False

src/seer/automation/autofix/components/coding/component.py

+5 -9
@@ -223,15 +223,11 @@ def invoke(self, request: CodingRequest) -> CodingOutput | None:
         custom_solution = request.solution if isinstance(request.solution, str) else None

         if not request.initial_memory:
-            agent.memory.insert(
-                0,
-                Message(
-                    role="user",
-                    content=CodingPrompts.format_fix_msg(
-                        has_tools=not is_obvious,
-                        custom_solution=custom_solution,
-                        mode=request.mode,
-                    ),
+            agent.add_user_message(
+                CodingPrompts.format_fix_msg(
+                    has_tools=not is_obvious,
+                    custom_solution=custom_solution,
+                    mode=request.mode,
                 ),
             )

src/seer/automation/autofix/components/solution/component.py

+9 -13
@@ -142,19 +142,15 @@ def invoke(

         state = self.context.state.get()
         if not request.initial_memory:
-            agent.memory.insert(
-                0,
-                Message(
-                    role="user",
-                    content=SolutionPrompts.format_default_msg(
-                        event=request.event_details.format_event(),
-                        root_cause=request.root_cause_and_fix,
-                        summary=request.summary,
-                        repo_names=[repo.full_name for repo in state.request.repos],
-                        original_instruction=request.original_instruction,
-                        code_map=request.profile,
-                        has_tools=not is_obvious,
-                    ),
+            agent.add_user_message(
+                SolutionPrompts.format_default_msg(
+                    event=request.event_details.format_event(),
+                    root_cause=request.root_cause_and_fix,
+                    summary=request.summary,
+                    repo_names=[repo.full_name for repo in state.request.repos],
+                    original_instruction=request.original_instruction,
+                    code_map=request.profile,
+                    has_tools=not is_obvious,
                 ),
             )
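
Both this component and the coding component above replace agent.memory.insert(0, ...) with agent.add_user_message(...). The helper's body is not part of this diff; the sketch below only illustrates its assumed behavior (appending the instruction at the end of memory rather than placing it before existing context), which is the index change the commit message credits with reducing hallucinations:

# Hypothetical sketch of what add_user_message is assumed to do; the actual
# helper lives on the agent class and is not shown in this commit.
def add_user_message(self, content: str) -> None:
    # Append at the end of memory so the instruction follows existing context
    # instead of preceding it.
    self.memory.append(Message(role="user", content=content))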

src/seer/automation/codebase/repo_client.py

+5 -4
@@ -333,7 +333,9 @@ def get_file_content(
                 )
                 path = closest_match
             else:
-                logger.error(f"No matching file found for path: {path}")
+                logger.exception(
+                    "No matching file found for provided file path", extra={"path": path}
+                )
                 return None, "utf-8"

         try:
@@ -366,9 +368,8 @@ def get_valid_file_paths(self, sha: str | None = None, files_only=False) -> set[
         valid_file_paths: set[str] = set()

         for file in tree.tree:
-            if files_only and "." not in file.path:
-                continue
-            valid_file_paths.add(file.path)
+            if file.type == "blob":
+                valid_file_paths.add(file.path)

         return valid_file_paths
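
The old get_valid_file_paths heuristic treated any path containing a dot as a file; the new check keys off the git tree entry type instead. A rough standalone illustration of the same idea with PyGithub follows; the token/repo wiring here is a placeholder, not the RepoClient's actual setup:

from github import Github  # PyGithub

def list_file_paths(token: str, full_name: str, sha: str) -> set[str]:
    # Placeholder wiring; the real RepoClient handles auth and caching elsewhere.
    repo = Github(token).get_repo(full_name)
    tree = repo.get_git_tree(sha, recursive=True)
    # "blob" entries are files and "tree" entries are directories, so filtering
    # on the entry type replaces the old "path contains a dot" heuristic.
    return {entry.path for entry in tree.tree if entry.type == "blob"}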

tests/automation/codebase/test_repo_client.py

+4 -1
@@ -137,7 +137,10 @@ def test_fail_get_file_content(self, mock_requests, repo_client, mock_github):

     def test_get_valid_file_paths(self, repo_client, mock_github):
         mock_tree = MagicMock()
-        mock_tree.tree = [MagicMock(path="file1.py"), MagicMock(path="file2.py")]
+        mock_tree.tree = [
+            MagicMock(path="file1.py", type="blob"),
+            MagicMock(path="file2.py", type="blob"),
+        ]
         mock_tree.raw_data = {"truncated": False}
         mock_github.get_repo.return_value.get_git_tree.return_value = mock_tree
