Skip to content

Commit b58f1c2

Browse files
committed
Implement agent-driven profiling
1 parent bd4eeaf commit b58f1c2

File tree

6 files changed

+189
-77
lines changed

6 files changed

+189
-77
lines changed

accelerant/agent.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def run_agent(
3434
ag_tools: list[Tool] = [
3535
tools.edit_code,
3636
tools.check_codebase_for_errors,
37-
tools.get_profiler_data,
37+
tools.run_perf_profiler,
3838
tools.get_info,
3939
tools.get_references,
4040
tools.get_surrounding_code,
@@ -47,22 +47,17 @@ def run_agent(
4747
tools=ag_tools,
4848
)
4949

50-
with project.new_fs_sandbox() as fs:
51-
ag_context = AgentContext(
52-
project=project,
53-
active_fs=fs,
54-
initial_perf_data_path=ag_input["perf_data_path"],
55-
)
56-
prompt = user_prompt(
57-
lang=project.lang(), hotspot_lines=ag_input["hotspot_lines"] or []
58-
)
59-
result = Runner.run_sync(
60-
agent,
61-
prompt,
62-
context=ag_context,
63-
max_turns=100,
64-
).final_output
65-
assert result is not None
66-
final_message = str(result)
67-
fs.persist_all()
68-
return AgentResult(final_message=final_message)
50+
ag_context = AgentContext(project=project)
51+
prompt = user_prompt(
52+
lang=project.lang(), hotspot_lines=ag_input["hotspot_lines"] or []
53+
)
54+
result = Runner.run_sync(
55+
agent,
56+
prompt,
57+
context=ag_context,
58+
max_turns=100,
59+
).final_output
60+
assert result is not None
61+
final_message = str(result)
62+
project.fs_sandbox().persist_all()
63+
return AgentResult(final_message=final_message)

accelerant/fs_sandbox.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
from dataclasses import dataclass
2+
import hashlib
13
from pathlib import Path
24
from typing import Literal
35

46

7+
@dataclass(frozen=True)
class FsVersion:
    """Immutable, hashable tag identifying one state of the sandboxed files.

    Being a frozen dataclass, instances compare by value and can be used as
    dictionary keys.
    """

    # Digest string describing the sandbox contents.
    hash: str
10+
11+
512
class FsSandbox:
613
base_dir: Path
714
old_versions: dict[Path, str]
15+
cur_hashes: dict[Path, str]
816
status: Literal["fresh"] | Literal["entered"] | Literal["done"] = "fresh"
917

1018
def __init__(self, base_dir: Path) -> None:
1119
self.base_dir = base_dir
1220
self.old_versions = {}
21+
self.cur_hashes = {}
1322

1423
def __enter__(self) -> "FsSandbox":
1524
self.status = "entered"
@@ -38,6 +47,10 @@ def write_file(self, relpath: Path, new_text: str) -> None:
3847
self.old_versions[relpath] = f.read()
3948
with open(abspath, "w") as f:
4049
f.write(new_text)
50+
self.cur_hashes[relpath] = hashlib.sha256(new_text.encode()).hexdigest()
51+
if self.old_versions[relpath] == new_text:
52+
del self.old_versions[relpath]
53+
del self.cur_hashes[relpath]
4154

4255
def persist(self, relpath: Path) -> None:
4356
assert self.status == "entered"
@@ -47,3 +60,12 @@ def persist(self, relpath: Path) -> None:
4760
def persist_all(self) -> None:
    """Keep every edit made so far by forgetting all saved original contents.

    Must only be called while the sandbox is in the "entered" state.
    """
    assert self.status == "entered"
    # Dropping the saved originals means nothing is left to restore.
    self.old_versions = dict()
63+
64+
def version(self) -> FsVersion:
    """Return a short fingerprint of the files currently tracked in ``cur_hashes``.

    Each tracked path (POSIX form) and its stored content hash are fed, in
    sorted path order and each followed by a NUL byte, into a single SHA-256.
    The first 8 hex characters of that digest become the version hash.
    """
    # Keys are unique, so sorting the (path, hash) pairs orders by path alone.
    payload = b"".join(
        tracked_path.as_posix().encode() + b"\0" + content_hash.encode() + b"\0"
        for tracked_path, content_hash in sorted(self.cur_hashes.items())
    )
    digest = hashlib.sha256(payload).hexdigest()
    return FsVersion(hash=digest[:8])

accelerant/project.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,117 @@
1+
import os
2+
import shutil
3+
import subprocess
4+
import time
15
from multilspy.lsp_protocol_handler import lsp_types
26

37

48
from pathlib import Path
59
from typing import List, Optional
610

7-
from accelerant.fs_sandbox import FsSandbox
11+
from accelerant.fs_sandbox import FsSandbox, FsVersion
812
from accelerant.lsp import LSP
913
from accelerant.perf import PerfData
1014

1115

1216
class Project:
1317
_root: Path
18+
# FIXME: this should probably not be here to allow for multiple targets
19+
_target_binary: Path
1420
_lang: str
21+
_fs: FsSandbox
1522
_lsp: Optional[LSP]
16-
_perf_data: dict[Path, PerfData]
23+
_perf_per_version: dict[FsVersion, Path]
24+
_perf_data_map: dict[Path, PerfData]
1725

18-
def __init__(self, root, lang):
26+
def __init__(self, root: Path, target_binary: Path, lang: str) -> None:
1927
self._root = root
28+
self._target_binary = target_binary
2029
self._lang = lang
30+
self._fs = FsSandbox(root)
2131
self._lsp = None
22-
self._perf_data = {}
32+
self._perf_per_version = {}
33+
self._perf_data_map = {}
34+
35+
def target_binary(self) -> Path:
    """Return the path of the binary this project profiles."""
    binary = self._target_binary
    return binary
2337

2438
def lsp(self) -> LSP:
    """Return the project's LSP client, constructing it on first use.

    The client is built from the project root and language and then cached
    on the instance, so later calls return the same object.
    """
    client = self._lsp
    if client is None:
        client = LSP(self._root, self._lang)
        self._lsp = client
    return client
2842

29-
def perf_data(self, perf_data_path: Path) -> PerfData:
30-
if perf_data_path not in self._perf_data:
31-
self._perf_data[perf_data_path] = PerfData(perf_data_path, self._root)
32-
return self._perf_data[perf_data_path]
43+
def perf_data(self, version: Optional[FsVersion] = None) -> Optional[PerfData]:
    """Return the parsed profile recorded for a sandbox version, if any.

    Args:
        version: Sandbox version to look up. When omitted, the sandbox's
            current version is used.

    Returns:
        The ``PerfData`` for that version, or ``None`` when no profile has
        been registered for it. Parsed results are cached per data-file path.
    """
    key = self.fs_sandbox().version() if version is None else version
    try:
        recorded_path = self._perf_per_version[key]
    except KeyError:
        return None

    parsed_cache = self._perf_data_map
    try:
        return parsed_cache[recorded_path]
    except KeyError:
        parsed = PerfData(recorded_path, self._root)
        parsed_cache[recorded_path] = parsed
        return parsed
52+
53+
def add_perf_data(self, version: FsVersion, perf_data_path: Path) -> None:
    """Register ``perf_data_path`` as the profile recorded for ``version``.

    An existing entry for the same version is overwritten.
    """
    self._perf_per_version.update({version: perf_data_path})
55+
56+
def build_for_profiling(self) -> None:
    """Build the project in release mode with debug info, for profiling.

    Runs ``cargo build --config profile.release.debug=true --release
    --all-targets`` in the project root. The child process receives an
    environment containing only ``PATH``.

    Raises:
        NotImplementedError: If the project language is not Rust.
        FileNotFoundError: If ``cargo`` cannot be found on ``PATH``.
        RuntimeError: If the ``PATH`` environment variable is not set.
        subprocess.CalledProcessError: If the build exits with a non-zero
            status.
    """
    if self._lang != "rust":
        raise NotImplementedError(
            f"Build for profiling not implemented for language: {self._lang}"
        )

    # Explicit raises rather than asserts: asserts are stripped under
    # ``python -O``, which would let ``None`` reach ``subprocess.run``.
    cargo_path = shutil.which("cargo")
    if cargo_path is None:
        raise FileNotFoundError("cargo not found in PATH")

    path_env_var = os.environ.get("PATH")
    if path_env_var is None:
        raise RuntimeError("PATH environment variable is not set")

    subprocess.run(
        [
            cargo_path,
            "build",
            "--config",
            "profile.release.debug=true",
            "--release",
            "--all-targets",
        ],
        check=True,
        cwd=str(self._root),
        env={"PATH": path_env_var},
    )
81+
82+
def run_profiler(self) -> None:
    """Profile the target binary with ``perf record`` and register the result.

    Runs ``perf record -F99 --call-graph dwarf`` on the target binary from
    the project root, writing to a uniquely named ``perf<time_ns>.data`` file
    there. The output path is then registered against the sandbox's current
    version. The child process receives an environment containing only
    ``PATH``.

    Raises:
        NotImplementedError: If the project language is not Rust.
        FileNotFoundError: If ``perf`` cannot be found on ``PATH``.
        RuntimeError: If the ``PATH`` environment variable is not set.
        subprocess.CalledProcessError: If ``perf`` exits with a non-zero
            status.
    """
    if self._lang != "rust":
        raise NotImplementedError(
            f"Profiler run not implemented for language: {self._lang}"
        )

    # Resolve perf up front so a missing tool fails with a clear message.
    perf_path = shutil.which("perf")
    if perf_path is None:
        raise FileNotFoundError("perf not found in PATH")

    # Explicit raise rather than assert: asserts are stripped under
    # ``python -O``, which would let ``None`` reach the child environment.
    path_env_var = os.environ.get("PATH")
    if path_env_var is None:
        raise RuntimeError("PATH environment variable is not set")

    # Nanosecond timestamp keeps successive recordings from overwriting
    # each other.
    perf_data_path = self._root / f"perf{time.time_ns()}.data"

    subprocess.run(
        [
            perf_path,
            "record",
            "-F99",
            "--call-graph",
            "dwarf",
            "-o",
            str(perf_data_path),
            str(self._target_binary),
        ],
        check=True,
        cwd=str(self._root),
        env={"PATH": path_env_var},
    )
    version = self.fs_sandbox().version()
    self.add_perf_data(version, perf_data_path)
33112

34-
def new_fs_sandbox(self) -> FsSandbox:
35-
return FsSandbox(self._root)
113+
def fs_sandbox(self) -> FsSandbox:
    """Return the filesystem sandbox held by this project."""
    sandbox = self._fs
    return sandbox
36115

37116
def get_line(self, filename: str, line: int) -> str:
38117
assert line >= 0

accelerant/prompts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
"Use these tools effectively to gather the necessary information before making optimization suggestions.\n"
1212
"Always provide clear, concise, and actionable suggestions that can be directly implemented in the codebase.\n"
1313
"Take full control and apply edits to the code without needing approval from the user.\n"
14-
"Check the codebase for errors after making edits to ensure correctness."
14+
"Check the codebase for errors after making edits to ensure correctness.\n"
15+
"Rerun the performance profiler after making edits to measure improvements."
1516
),
1617
}
1718

accelerant/tools.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from dataclasses import dataclass
22
from itertools import islice
3-
from pathlib import Path
3+
import shutil
44
import subprocess
55
from typing import Any, Optional
66
from agents import RunContextWrapper, function_tool
77
from llm_utils import number_group_of_lines
88
from perfparser import LineLoc
99

1010
from accelerant.chat_interface import CodeSuggestion
11-
from accelerant.fs_sandbox import FsSandbox
1211
from accelerant.lsp import TOP_LEVEL_SYMBOL_KINDS, uri_to_relpath
1312
from accelerant.patch import apply_simultaneous_suggestions
1413
from accelerant.util import find_symbol, truncate_for_llm
@@ -18,8 +17,6 @@
1817
@dataclass
1918
class AgentContext:
2019
project: Project
21-
active_fs: FsSandbox
22-
initial_perf_data_path: Optional[Path]
2320

2421

2522
@function_tool
@@ -32,7 +29,9 @@ def edit_code(
3229
Args:
3330
suggs: A list of code suggestions that should be applied.
3431
"""
35-
apply_simultaneous_suggestions(ctx.context.project, ctx.context.active_fs, suggs)
32+
apply_simultaneous_suggestions(
33+
ctx.context.project, ctx.context.project.fs_sandbox(), suggs
34+
)
3635

3736

3837
@function_tool
@@ -43,26 +42,34 @@ def check_codebase_for_errors(
4342
assert ctx.context.project._lang == "rust", (
4443
"Only Rust is supported for code checking"
4544
)
45+
46+
cargo_path = shutil.which("cargo")
47+
assert cargo_path is not None, "cargo not found in PATH"
4648
try:
4749
subprocess.run(
48-
["cargo", "check", "--all"], check=True, cwd=str(ctx.context.project._root)
50+
[cargo_path, "check", "--all-targets"],
51+
check=True,
52+
cwd=str(ctx.context.project._root),
4953
)
5054
except subprocess.CalledProcessError as e:
5155
return f"ERROR: Codebase has errors:\n\n{e}"
5256
return "OK: Codebase has no errors!"
5357

5458

5559
@function_tool
56-
def get_profiler_data(
60+
def run_perf_profiler(
5761
ctx: RunContextWrapper[AgentContext],
5862
) -> list[dict[str, Any]]:
59-
"""Get a summary of the objective performance data gathered by a profiler."""
63+
"""Run a performance profiler on the target binary and return the top hotspots."""
6064
try:
61-
perf_data_path = ctx.context.initial_perf_data_path
62-
if perf_data_path is None:
63-
raise ValueError("No initial performance data path provided")
6465
project = ctx.context.project
65-
perf_data = project.perf_data(perf_data_path)
66+
version = project.fs_sandbox().version()
67+
perf_data = project.perf_data(version)
68+
if perf_data is None:
69+
project.build_for_profiling()
70+
project.run_profiler()
71+
perf_data = project.perf_data(version)
72+
assert perf_data is not None, "perf data should be available after profiling"
6673
perf_tabulated = perf_data.tabulate()
6774
NUM_HOTSPOTS = 5
6875

@@ -76,18 +83,20 @@ def get_parent_region(loc: LineLoc) -> Optional[str]:
7683
return None
7784
return parent_sym["name"]
7885

79-
hotspots = islice(
80-
map(
81-
lambda x: {
82-
"parent_region": get_parent_region(x[0]) or "<unknown>",
83-
"loc": x[0],
84-
"pct_time": x[1] * 100,
85-
},
86-
filter(lambda x: x[0].line > 0, perf_tabulated),
87-
),
88-
NUM_HOTSPOTS,
86+
hotspots = list(
87+
islice(
88+
map(
89+
lambda x: {
90+
"parent_region": get_parent_region(x[0]) or "<unknown>",
91+
"loc": x[0],
92+
"pct_time": x[1] * 100,
93+
},
94+
filter(lambda x: x[0].line > 0, perf_tabulated),
95+
),
96+
NUM_HOTSPOTS,
97+
)
8998
)
90-
return list(hotspots)
99+
return hotspots
91100
except Exception as e:
92101
print("ERROR", e)
93102
raise e
@@ -240,8 +249,8 @@ def get_surrounding_code(
240249
filename, line - 1, TOP_LEVEL_SYMBOL_KINDS
241250
),
242251
)
243-
# FIXME: avoid crashing
244-
assert parent_sym is not None
252+
if parent_sym is None:
253+
raise ValueError(f"no surrounding top-level symbol found at {filename}:{line}")
245254
sline = parent_sym["range"]["start"]["line"] + 1
246255
lines = project.get_range(filename, parent_sym["range"])
247256
return {

0 commit comments

Comments
 (0)