Skip to content

Commit c9682f8

Browse files
committed
pkg/codesearch: support searching for references
Extend codesearch clang tool to export info about function references (calls, takes-address-of). Add pkg/codesearch command find-references. Export find-references in pkg/aflow/tools/codesearcher to LLMs. Update #6469
1 parent 99063cc commit c9682f8

26 files changed

+621
-45
lines changed

pkg/aflow/tool/codesearcher/codesearcher.go

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ Tool provides full source code for an entity with the given name.
4444
Entity can be function, struct, or global variable.
4545
Use it to understand implementation details of an entity.
4646
For example, how a function works, what precondition error checks it has, etc.
47+
`),
48+
aflow.NewFuncTool("codesearch-find-references", findReferences, `
49+
Tool finds and lists all references to (uses of) the given entity.
50+
Entity can be function, struct, or global variable.
51+
If can be used to find all calls or other uses of the given function.
4752
`),
4853
}
4954

@@ -95,8 +100,8 @@ type indexEntity struct {
95100

96101
// nolint: lll
97102
type defCommentArgs struct {
98-
SourceFile string `jsonschema:"Source file path that references the entity. It helps to restrict scope of the search, if there are different definitions with the same name in different source files."`
99-
Name string `jsonschema:"Name of the entity of interest."`
103+
ContextFile string `jsonschema:"Source file path that references the entity. It helps to restrict scope of the search, if there are different definitions with the same name in different source files."`
104+
Name string `jsonschema:"Name of the entity of interest."`
100105
}
101106

102107
type defCommentResult struct {
@@ -106,7 +111,7 @@ type defCommentResult struct {
106111

107112
// nolint: lll
108113
type defSourceArgs struct {
109-
SourceFile string `jsonschema:"Source file path that references the entity. It helps to restrict scope of the search, if there are different definitions with the same name in different source files."`
114+
ContextFile string `jsonschema:"Source file path that references the entity. It helps to restrict scope of the search, if there are different definitions with the same name in different source files."`
110115
Name string `jsonschema:"Name of the entity of interest."`
111116
IncludeLines bool `jsonschema:"Whether to include line numbers in the output or not. Line numbers may distract you, so ask for them only if you need to match lines elsewhere with the source code."`
112117
}
@@ -181,7 +186,7 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil
181186
}
182187

183188
func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) {
184-
info, err := state.Index.DefinitionComment(args.SourceFile, args.Name)
189+
info, err := state.Index.DefinitionComment(args.ContextFile, args.Name)
185190
if err != nil {
186191
return defCommentResult{}, err
187192
}
@@ -192,7 +197,7 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA
192197
}
193198

194199
func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) {
195-
info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines)
200+
info, err := state.Index.DefinitionSource(args.ContextFile, args.Name, args.IncludeLines)
196201
if err != nil {
197202
return defSourceResult{}, err
198203
}
@@ -201,3 +206,56 @@ func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArg
201206
SourceCode: info.Body,
202207
}, nil
203208
}
209+
210+
// nolint: lll
211+
type findReferencesArgs struct {
212+
ContextFile string `jsonschema:"Source file path that references the entity. It helps to restrict scope of the search, if there are different definitions with the same name in different source files." json:",omitempty"`
213+
Name string `jsonschema:"Name of the entity of interest."`
214+
SourceTreePrefix string `jsonschema:"Prefix of the sources tree where to search for references. Can be used to restrict search to e.g. net/ipv4/. Pass an empty string to find all references." json:",omitempty"`
215+
IncludeSnippetLines uint `jsonschema:"If set to non-0, output will include source code snippets with that many lines of context. If set to 0, no source snippets will be included. Snippets only show the referencing entity, so to see e.g. whole referencing functions pass a large value, e.g. 10000" json:",omitempty"`
216+
}
217+
218+
// nolint: lll
219+
type findReferencesResult struct {
220+
TruncatedOutput bool `jsonschema:"Set if there were too many references, and the output is truncated. If you get truncated output, you may try to either request w/o source code snippets by passing IncludeSnippetLines=0 (which has higher limit on the number of output references), or restrict search to some prefix of the source tree with SourceTreePrefix argument."`
221+
References []Reference `jsonschema:"List of requested references."`
222+
}
223+
224+
type Reference struct {
225+
ReferencingEntityKind string `jsonschema:"Kind of the referencing entity (function, struct, etc)."`
226+
ReferencingEntityName string `jsonschema:"Name of the referencing entity."`
227+
ReferenceKind string `jsonschema:"Kind of the reference (calls, takes-address, reads, writes-to, etc)."`
228+
SourceFile string `jsonschema:"Source file of the reference."`
229+
SourceLine int `jsonschema:"Source line of the reference."`
230+
SourceSnippet string `jsonschema:"Surrounding code snippet, if requested." json:",omitempty"`
231+
}
232+
233+
func findReferences(ctx *aflow.Context, state prepareResult, args findReferencesArgs) (findReferencesResult, error) {
234+
outputLimit := 20
235+
if args.IncludeSnippetLines == 0 {
236+
outputLimit = 1000
237+
} else if args.IncludeSnippetLines < 10 {
238+
outputLimit = 100
239+
}
240+
refs, totalCount, err := state.Index.FindReferences(
241+
args.ContextFile, args.Name, args.SourceTreePrefix,
242+
int(args.IncludeSnippetLines), outputLimit)
243+
if err != nil {
244+
return findReferencesResult{}, err
245+
}
246+
var results []Reference
247+
for _, ref := range refs {
248+
results = append(results, Reference{
249+
ReferencingEntityKind: ref.ReferencingEntityKind,
250+
ReferencingEntityName: ref.ReferencingEntityName,
251+
ReferenceKind: ref.ReferenceKind,
252+
SourceFile: ref.SourceFile,
253+
SourceLine: ref.SourceLine,
254+
SourceSnippet: ref.SourceSnippet,
255+
})
256+
}
257+
return findReferencesResult{
258+
TruncatedOutput: totalCount > len(refs),
259+
References: results,
260+
}, nil
261+
}

pkg/clangtool/tooltest/tooltest.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"flag"
99
"fmt"
10+
"io/fs"
1011
"os"
1112
"path/filepath"
1213
"testing"
@@ -82,7 +83,13 @@ func ForEachTestFile(t *testing.T, fn func(t *testing.T, cfg *clangtool.Config,
8283
}
8384

8485
func forEachTestFile(t *testing.T, fn func(t *testing.T, file string)) {
85-
files, err := filepath.Glob(filepath.Join(osutil.Abs("testdata"), "*.c"))
86+
var files []string
87+
err := filepath.WalkDir(osutil.Abs("testdata"), func(path string, d fs.DirEntry, err error) error {
88+
if d.Name()[0] != '.' && filepath.Ext(d.Name()) == ".c" {
89+
files = append(files, path)
90+
}
91+
return err
92+
})
8693
if err != nil {
8794
t.Fatal(err)
8895
}

pkg/codesearch/codesearch.go

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"os"
1111
"path/filepath"
1212
"slices"
13+
"strconv"
1314
"strings"
1415
"syscall"
1516

@@ -80,6 +81,28 @@ var Commands = []Command{
8081
}
8182
return fmt.Sprintf("%v %v is defined in %v:\n\n%v", info.Kind, args[1], info.File, info.Body), nil
8283
}},
84+
{"find-references", 5, func(index *Index, args []string) (string, error) {
85+
contextLines, err := strconv.Atoi(args[3])
86+
if err != nil {
87+
return "", fmt.Errorf("failed to parse number of context lines %q: %v", args[3], err)
88+
}
89+
outputLimit, err := strconv.Atoi(args[4])
90+
if err != nil {
91+
return "", fmt.Errorf("failed to parse output limit %q: %v", args[4], err)
92+
}
93+
refs, totalCount, err := index.FindReferences(args[0], args[1], args[2], contextLines, outputLimit)
94+
if err != nil {
95+
return "", err
96+
}
97+
b := new(strings.Builder)
98+
fmt.Fprintf(b, "%v has %v references:\n\n", args[1], totalCount)
99+
for _, ref := range refs {
100+
fmt.Fprintf(b, "%v %v %v it at %v:%v\n%v\n\n",
101+
ref.ReferencingEntityKind, ref.ReferencingEntityName, ref.ReferenceKind,
102+
ref.SourceFile, ref.SourceLine, ref.SourceSnippet)
103+
}
104+
return b.String(), nil
105+
}},
83106
}
84107

85108
func IsSourceFile(file string) bool {
@@ -225,6 +248,68 @@ func (index *Index) definitionSource(contextFile, name string, comment, includeL
225248
}, nil
226249
}
227250

251+
type ReferenceInfo struct {
252+
ReferencingEntityKind string
253+
ReferencingEntityName string
254+
ReferenceKind string
255+
SourceFile string
256+
SourceLine int
257+
SourceSnippet string
258+
}
259+
260+
func (index *Index) FindReferences(contextFile, name, srcPrefix string, contextLines, outputLimit int) ([]*ReferenceInfo, int, error) {
261+
target := index.findDefinition(contextFile, name)
262+
if target == nil {
263+
return nil, 0, aflow.BadCallError("requested entity does not exist")
264+
}
265+
if srcPrefix != "" {
266+
srcPrefix = filepath.Clean(srcPrefix)
267+
}
268+
totalCount := 0
269+
var results []*ReferenceInfo
270+
for _, def := range index.db.Definitions {
271+
if !strings.HasPrefix(def.Body.File, srcPrefix) {
272+
continue
273+
}
274+
for _, ref := range def.Refs {
275+
// TODO: this mis-handles the following case:
276+
// the target is a non-static 'foo' in some file,
277+
// the reference is in another file and refers to a static 'foo'
278+
// defined in that file (which is not the target 'foo').
279+
if !(ref.EntityKind == target.Kind && ref.Name == target.Name &&
280+
(!target.IsStatic || target.Body.File == def.Body.File)) {
281+
continue
282+
}
283+
totalCount++
284+
if totalCount > outputLimit {
285+
continue
286+
}
287+
snippet := ""
288+
if contextLines > 0 {
289+
lines := LineRange{
290+
File: def.Body.File,
291+
StartLine: max(def.Body.StartLine, ref.Line-contextLines),
292+
EndLine: min(def.Body.EndLine, ref.Line+contextLines),
293+
}
294+
var err error
295+
snippet, err = index.formatSource(lines, true)
296+
if err != nil {
297+
return nil, 0, err
298+
}
299+
}
300+
results = append(results, &ReferenceInfo{
301+
ReferencingEntityKind: def.Kind,
302+
ReferencingEntityName: def.Name,
303+
ReferenceKind: ref.Kind,
304+
SourceFile: def.Body.File,
305+
SourceLine: ref.Line,
306+
SourceSnippet: snippet,
307+
})
308+
}
309+
}
310+
return results, totalCount, nil
311+
}
312+
228313
func (index *Index) findDefinition(contextFile, name string) *Definition {
229314
var weakMatch *Definition
230315
for _, def := range index.db.Definitions {
@@ -269,7 +354,7 @@ func formatSourceFile(file string, start, end int, includeLines bool) (string, e
269354
b := new(strings.Builder)
270355
for line := start; line <= end; line++ {
271356
if includeLines {
272-
fmt.Fprintf(b, "%4v:\t%s\n", line, lines[line])
357+
fmt.Fprintf(b, "%4v:\t%s\n", line+1, lines[line])
273358
} else {
274359
fmt.Fprintf(b, "%s\n", lines[line])
275360
}

pkg/codesearch/codesearch_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,24 @@ func testCommand(t *testing.T, index *Index, covered map[string]bool, file strin
4747
t.Fatal(err)
4848
}
4949
query, _, _ := bytes.Cut(data, []byte{'\n'})
50-
args := strings.Fields(string(query))
51-
if len(args) == 0 {
50+
fields := strings.Fields(string(query))
51+
if len(fields) == 0 {
5252
t.Fatal("no command found")
5353
}
54-
result, err := index.Command(args[0], args[1:])
54+
cmd := fields[0]
55+
var args []string
56+
for _, arg := range fields[1:] {
57+
if arg == `""` {
58+
arg = ""
59+
}
60+
args = append(args, arg)
61+
}
62+
result, err := index.Command(cmd, args)
5563
if err != nil {
5664
// This is supposed to test aflow.BadCallError messages.
5765
result = err.Error() + "\n"
5866
}
59-
got := append([]byte(strings.Join(args, " ")+"\n\n"), result...)
67+
got := append([]byte(strings.Join(fields, " ")+"\n\n"), result...)
6068
tooltest.CompareGoldenData(t, file, got)
61-
covered[args[0]] = true
69+
covered[cmd] = true
6270
}

pkg/codesearch/database.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ type Database struct {
1414
}
1515

1616
type Definition struct {
17-
Kind string `json:"kind,omitempty"`
18-
Name string `json:"name,omitempty"`
19-
Type string `json:"type,omitempty"`
20-
IsStatic bool `json:"is_static,omitempty"`
21-
Body LineRange `json:"body,omitempty"`
22-
Comment LineRange `json:"comment,omitempty"`
17+
Kind string `json:"kind,omitempty"`
18+
Name string `json:"name,omitempty"`
19+
Type string `json:"type,omitempty"`
20+
IsStatic bool `json:"is_static,omitempty"`
21+
Body LineRange `json:"body,omitempty"`
22+
Comment LineRange `json:"comment,omitempty"`
23+
Refs []Reference `json:"refs,omitempty"`
24+
}
25+
26+
type Reference struct {
27+
Kind string `json:"kind,omitempty"`
28+
EntityKind string `json:"entity_kind,omitempty"`
29+
Name string `json:"name,omitempty"`
30+
Line int `json:"line,omitempty"`
2331
}
2432

2533
type LineRange struct {

pkg/codesearch/testdata/mm/refs.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
int refs2();
2+
3+
void ref_in_mm()
4+
{
5+
refs2();
6+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"definitions": [
3+
{
4+
"kind": "function",
5+
"name": "ref_in_mm",
6+
"type": "void ()",
7+
"body": {
8+
"file": "mm/refs.c",
9+
"start_line": 3,
10+
"end_line": 6
11+
},
12+
"comment": {},
13+
"refs": [
14+
{
15+
"kind": "calls",
16+
"entity_kind": "function",
17+
"name": "refs2",
18+
"line": 5
19+
}
20+
]
21+
}
22+
]
23+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}

pkg/codesearch/testdata/query-def-source-header

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ def-source source0.c function_with_comment_in_header yes
22

33
function function_with_comment_in_header is defined in source0.c:
44

5-
18: void function_with_comment_in_header()
6-
19: {
7-
20: same_name_in_several_files();
8-
21: }
5+
19: void function_with_comment_in_header()
6+
20: {
7+
21: same_name_in_several_files();
8+
22: }

pkg/codesearch/testdata/query-def-source-open

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ def-source source0.c open yes
22

33
function open is defined in source0.c:
44

5-
5: /*
6-
6: * Comment about open.
7-
7: */
8-
8: int open()
9-
9: {
10-
10: return 0;
11-
11: }
5+
6: /*
6+
7: * Comment about open.
7+
8: */
8+
9: int open()
9+
10: {
10+
11: return 0;
11+
12: }

0 commit comments

Comments
 (0)