Skip to content

Commit 43c4171

Browse files
committed
feat: add queries for function parameters and identifiers
1 parent 9182b4c commit 43c4171

File tree

7 files changed

+138
-5
lines changed

7 files changed

+138
-5
lines changed

scubatrace/cpp/language.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ class C(Language):
88
extensions = ["c", "h", "cc", "cpp", "cxx", "hxx", "hpp"]
99
tslanguage = TSLanguage(tscpp.language())
1010

11+
query_identifier = """
12+
(identifier)@name
13+
(field_identifier)@name
14+
"""
1115
query_call = "(call_expression)@name"
1216
query_import_identifier = """
1317
(preproc_include
@@ -17,6 +21,37 @@ class C(Language):
1721
]
1822
)
1923
"""
24+
query_function_parameter = """
25+
(parameter_declaration
26+
declarator: [
27+
(identifier)@name
28+
(pointer_declarator
29+
(identifier)@name
30+
)
31+
(pointer_declarator
32+
(pointer_declarator
33+
(identifier)@name
34+
)
35+
)
36+
(pointer_declarator
37+
(pointer_declarator
38+
(pointer_declarator
39+
(identifier)@name
40+
)
41+
)
42+
)
43+
(pointer_declarator
44+
(pointer_declarator
45+
(pointer_declarator
46+
(pointer_declarator
47+
(identifier)@name
48+
)
49+
)
50+
)
51+
)
52+
]
53+
)
54+
"""
2055

2156
query_struct = "(struct_specifier)@name"
2257
query_class = "(class_specifier)@name"

scubatrace/file.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,47 @@ def query_oneshot(self, query: str) -> Statement | None:
454454
if len(matched_statements) == 0:
455455
return None
456456
return matched_statements[0]
457+
458+
def query_identifiers(
459+
self, query: str, node: Node | None = None
460+
) -> list[Identifier]:
461+
"""
462+
Executes a tree-sitter query to find identifiers in the file.
463+
464+
Args:
465+
identifier (str): The identifier to search for.
466+
node (Node | None): The tree-sitter node to query. If None, uses the root node of the file.
467+
468+
Returns:
469+
list[Identifier]: A list of identifiers that contain the specified identifier.
470+
"""
471+
if node is None:
472+
node = self.node
473+
matched_nodes = self.parser.query_all(node, query)
474+
matched_identifiers = []
475+
for identifier in self.identifiers:
476+
for node in matched_nodes:
477+
if (
478+
identifier.node.start_byte >= node.start_byte
479+
and identifier.node.end_byte <= node.end_byte
480+
):
481+
matched_identifiers.append(identifier)
482+
break
483+
return matched_identifiers
484+
485+
def query_identifier(
486+
self, query: str, node: Node | None = None
487+
) -> Identifier | None:
488+
"""
489+
Executes a tree-sitter oneshot query to find an identifier in the file.
490+
491+
Args:
492+
query (str): The tree-sitter oneshot query to execute.
493+
494+
Returns:
495+
Identifier | None: The first identifier that matches the query, or None if no match is found.
496+
"""
497+
matched_identifiers = self.query_identifiers(query, node)
498+
if len(matched_identifiers) == 0:
499+
return None
500+
return matched_identifiers[0]

scubatrace/function.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from . import language as lang
1212
from .call import Call
13+
from .identifier import Identifier
1314
from .statement import BlockStatement, Statement
1415

1516
if TYPE_CHECKING:
@@ -163,12 +164,20 @@ def parameter_lines(self) -> list[int]:
163164
"""
164165
The lines where the parameters of the function are defined.
165166
"""
166-
parameters_node = self.node.child_by_field_name("parameters")
167-
if parameters_node is None:
167+
params = self.query_identifiers(self.language.query_function_parameter)
168+
if len(params) == 0:
168169
return [self.start_line]
169-
param_start_line = parameters_node.start_point[0] + 1
170-
param_end_line = parameters_node.end_point[0] + 1
171-
return list(range(param_start_line, param_end_line + 1))
170+
return list(range(params[0].start_line, params[-1].end_line + 1))
171+
172+
@cached_property
173+
def parameters(self) -> list[Identifier]:
174+
"""
175+
The parameter statements of the function.
176+
"""
177+
params = self.query_identifiers(self.language.query_function_parameter)
178+
if len(params) == 0:
179+
return self.block_variables
180+
return params
172181

173182
@cached_property
174183
def name_node(self) -> Node:

scubatrace/java/language.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ class JAVA(Language):
1919
(identifier)@name
2020
)
2121
"""
22+
query_function_parameter = """
23+
(formal_parameter
24+
name: (identifier)@name
25+
)
26+
"""
2227

2328
query_package = "(package_declaration)@name"
2429
query_class = "(class_declaration)@name"

scubatrace/language.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class Language:
4141
For example, in C/C++, this would match the `header.h` in `#include <header.h>`.
4242
"""
4343

44+
query_function_parameter: str
45+
"""
46+
The tree-sitter query to match function parameters.
47+
"""
48+
4449
EXIT_STATEMENTS: list[str] = []
4550
"""
4651
The tree-sitter AST types of exit statements.

scubatrace/python/language.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ class PYTHON(Language):
2727
]
2828
)
2929
"""
30+
query_function_parameter = """
31+
(parameters
32+
(identifier)@name
33+
)
34+
(typed_parameter
35+
(identifier)@name
36+
)
37+
"""
3038

3139
query_class = "(class_definition)@name"
3240

scubatrace/statement.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,3 +1024,30 @@ def query_oneshot(self, query: str) -> Statement | None:
10241024
if len(matched_statements) == 0:
10251025
return None
10261026
return matched_statements[0]
1027+
1028+
def query_identifiers(self, query: str) -> list[Identifier]:
1029+
"""
1030+
Executes a tree-sitter query to find identifiers in the block.
1031+
1032+
Args:
1033+
query (str): The tree-sitter query to execute.
1034+
1035+
Returns:
1036+
list[Identifier]: A list of identifiers that match the query.
1037+
"""
1038+
return self.file.query_identifiers(query, self.node)
1039+
1040+
def query_identifier(self, query: str) -> Identifier | None:
1041+
"""
1042+
Executes a tree-sitter oneshot query to find an identifier in the block.
1043+
1044+
Args:
1045+
query (str): The tree-sitter oneshot query to execute.
1046+
1047+
Returns:
1048+
Identifier | None: The identifier that matches the query, or None if not found.
1049+
"""
1050+
matched_identifiers = self.query_identifiers(query)
1051+
if len(matched_identifiers) == 0:
1052+
return None
1053+
return matched_identifiers[0]

0 commit comments

Comments
 (0)