Skip to content

Commit 5bb9efe

Browse files
Basic support for #match? predicate
This can and should be done in a cleaner way.
1 parent 285a90b commit 5bb9efe

File tree

6 files changed

+106
-2
lines changed

6 files changed

+106
-2
lines changed

spec/predicate_spec.cr

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
require "./spec_helper"
2+
3+
describe TreeSitter::Predicate do
4+
it "supports `#match?`" do
5+
parser = TreeSitter::Parser.new("json")
6+
source = <<-JSON
7+
{
8+
"hello": 2
9+
"goodnight": [
10+
"moon", "sky", "earth", 1
11+
]
12+
}
13+
JSON
14+
15+
tree = parser.parse nil, source
16+
17+
query = TreeSitter::Query.new(parser.language, <<-SCM)
18+
((number) @test
19+
(#match? @test "1"))
20+
SCM
21+
22+
cursor = TreeSitter::QueryCursor.new(query)
23+
cursor.exec(tree.root_node)
24+
25+
idx = 0
26+
cursor.each_capture do |capture|
27+
if idx == 0
28+
capture.text(source).should eq("2")
29+
TreeSitter::Predicate.resolve(query, capture, source).should eq(false)
30+
elsif idx == 1
31+
capture.text(source).should eq("1")
32+
TreeSitter::Predicate.resolve(query, capture, source).should eq(true)
33+
else
34+
raise "shouldn't be here"
35+
end
36+
ensure
37+
idx += 1
38+
end
39+
end
40+
end

src/tree_sitter.cr

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require "./tree_sitter/query"
66
require "./tree_sitter/query_cursor"
77
require "./tree_sitter/editor"
88
require "./tree_sitter/range"
9+
require "./tree_sitter/predicate"
910

1011
private def calloc(n : LibC::SizeT, size : LibC::SizeT) : Pointer(Void)
1112
GC.malloc(n * size)

src/tree_sitter/capture.cr

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module TreeSitter
2-
record Capture, rule : String, node : Node do
2+
record Capture, rule : String, node : Node, capture_index : UInt32 do
33
def includes_line?(line_n : Int32) : Bool
44
node.start_point.row == line_n || node.end_point.row == line_n
55
end
@@ -13,6 +13,10 @@ module TreeSitter
1313
node.to_s(io)
1414
end
1515

16+
def text(source : String) : String
17+
node.text(source)
18+
end
19+
1620
def inspect(io : IO)
1721
io << "#<Capture "
1822
to_s(io)

src/tree_sitter/node.cr

+7
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ module TreeSitter
122122
io.write(bytes)
123123
end
124124

125+
def text(source : String) : String
126+
start_pos = start_byte
127+
end_pos = end_byte
128+
slice = source.byte_slice(start_pos, end_pos - start_pos)
129+
@@string_pool.get(slice)
130+
end
131+
125132
# :nodoc:
126133
def to_unsafe
127134
@node

src/tree_sitter/predicate.cr

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module TreeSitter
2+
PREDICATES = {
3+
"match?" => MatchPredicate,
4+
}
5+
6+
abstract class Predicate
7+
@query : TreeSitter::Query
8+
@steps : Array(LibTreeSitter::TSQueryPredicateStep)
9+
10+
def initialize(@query, @steps)
11+
end
12+
13+
def self.resolve(
14+
query : TreeSitter::Query,
15+
capture : TreeSitter::Capture,
16+
source : String,
17+
) : Bool
18+
unsafe_steps = LibTreeSitter.ts_query_predicates_for_pattern(
19+
query,
20+
capture.capture_index,
21+
out step_count
22+
)
23+
steps = Slice.new(unsafe_steps, step_count).to_a
24+
25+
name_ptr = LibTreeSitter.ts_query_string_value_for_id(
26+
query.to_unsafe, steps[0].value_id, out name_len
27+
)
28+
29+
name = String.new(name_ptr, name_len)
30+
31+
!!PREDICATES[name]?.try(&.new(query, steps).call(capture, source))
32+
end
33+
end
34+
35+
class MatchPredicate < Predicate
36+
def call(capture : TreeSitter::Capture, source : String) : Bool
37+
# Get the regex pattern from the third step (steps[2])
38+
pattern_ptr = LibTreeSitter.ts_query_string_value_for_id(
39+
@query.to_unsafe,
40+
@steps[2].value_id,
41+
out pattern_len
42+
)
43+
pattern = String.new(pattern_ptr, pattern_len)
44+
45+
# Get the text from the captured node
46+
node_text = capture.text(source)
47+
48+
# Match the captured text against the regex pattern
49+
Regex.new(pattern).matches?(node_text)
50+
end
51+
end
52+
end

src/tree_sitter/query_cursor.cr

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ module TreeSitter
6363
capture = match.captures[capture_index]
6464
ptr = LibTreeSitter.ts_query_capture_name_for_id(@query, capture.index, out strlen)
6565
rule = TreeSitter.string_pool.get(ptr, strlen)
66-
Capture.new(rule, Node.new(capture.node))
66+
Capture.new(rule, Node.new(capture.node), capture_index)
6767
end
6868

6969
def each_capture(& : Capture -> Nil)

0 commit comments

Comments
 (0)