Skip to content

Commit aa118ab

Browse files
committed
Refactor highlight_tokens_cover to enforce invariants
1 parent 112805e commit aa118ab

File tree

1 file changed

+79
-36
lines changed

1 file changed

+79
-36
lines changed

src/highlighting_lexer/query.rs

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::{
33
collections::HashMap,
44
ops::{Deref, Range},
55
sync::Arc,
6-
usize,
76
};
87

98
use jni::{
@@ -14,7 +13,7 @@ use jni::{
1413
use streaming_iterator::StreamingIterator as _;
1514
use tree_sitter::{
1615
ffi::{self, TSTree},
17-
Node, Query, QueryCursor, TextProvider, Tree,
16+
Node, Query, QueryCursor, TextProvider, Tree, TreeCursor,
1817
};
1918

2019
use crate::language_registry::{LanguageId, LANGUAGE_REGISTRY};
@@ -73,19 +72,59 @@ impl<'a> TextProvider<Vec<u8>> for RecodingUtf16TextProvider<'a> {
7372
}
7473
}
7574

76-
pub fn highlight_tokens_cover(
75+
// Finds the start byte of the minimal token cover for a range that begins at `byte_start`.
76+
// Returns `(cover_start_byte, parent_stack, tree_cursor)`; `parent_stack` holds `(node id, byte range)` entries collected while descending from the root.
77+
fn find_cover_start<'tree>(
78+
tree: &'tree Tree,
79+
byte_start: usize,
80+
) -> (usize, Vec<(usize, Range<usize>)>, TreeCursor<'tree>) {
81+
let root = tree.root_node();
82+
let mut tree_cursor = root.walk();
83+
let mut parent_stack = Vec::new();
84+
loop {
85+
let node = tree_cursor.node();
86+
parent_stack.push((node.id(), node.start_byte()..node.end_byte()));
87+
if tree_cursor.goto_first_child_for_byte(byte_start).is_none() {
88+
break;
89+
}
90+
}
91+
debug_assert_eq!(
92+
parent_stack.last().map(|(node_id, _)| *node_id),
93+
Some(tree_cursor.node().id())
94+
);
95+
let mut cover_start_byte = tree_cursor.node().start_byte();
96+
while cover_start_byte > byte_start {
97+
// Need to extend cover to the left, but
98+
// there is no node between cover_start and current node
99+
if tree_cursor.goto_previous_sibling() {
100+
let node = tree_cursor.node();
101+
*parent_stack
102+
.last_mut()
103+
.expect("has stack entries if has previous sibling") =
104+
(node.id(), node.start_byte()..node.end_byte());
105+
cover_start_byte = tree_cursor.node().end_byte();
106+
} else if tree_cursor.goto_parent() {
107+
parent_stack.pop();
108+
cover_start_byte = tree_cursor.node().start_byte();
109+
} else {
110+
// start of the file, no nodes before start of range
111+
cover_start_byte = 0;
112+
}
113+
}
114+
debug_assert!(cover_start_byte <= byte_start);
115+
(cover_start_byte, parent_stack, tree_cursor)
116+
}
117+
118+
fn collect_highlights_for_range(
77119
tree: &Tree,
78120
query: &Query,
79121
text: &[u16],
80-
range: Range<usize>,
81-
) -> (usize, Vec<HighlightToken>) {
122+
byte_range: Range<usize>,
123+
) -> HashMap<Range<usize>, (u16, usize)> {
82124
let mut query_cursor = QueryCursor::new();
83-
let byte_start = range.start * 2;
84-
let byte_end = range.end * 2;
85-
query_cursor.set_byte_range(byte_start..byte_end);
86-
let root = tree.root_node();
125+
query_cursor.set_byte_range(byte_range);
87126
let text_provider = RecodingUtf16TextProvider { text };
88-
let mut captures = query_cursor.captures(query, root, text_provider);
127+
let mut captures = query_cursor.captures(query, tree.root_node(), text_provider);
89128
let mut highlights: HashMap<Range<usize>, (u16, usize)> = HashMap::new();
90129
while let Some((next_match, cidx)) = captures.next() {
91130
let capture = next_match.captures[*cidx];
@@ -98,8 +137,29 @@ pub fn highlight_tokens_cover(
98137
}
99138
highlights.insert(range, (capture_id, next_match.pattern_index));
100139
}
140+
highlights
141+
}
142+
143+
pub fn highlight_tokens_cover(
144+
tree: &Tree,
145+
query: &Query,
146+
text: &[u16],
147+
range: Range<usize>,
148+
) -> (usize, Vec<HighlightToken>) {
149+
let (byte_start, parent_stack, mut tree_cursor) = find_cover_start(&tree, range.start * 2);
150+
let byte_end = range.end * 2;
151+
152+
let highlights = collect_highlights_for_range(tree, query, text, byte_start..byte_end);
153+
154+
let mut highlight_stack: Vec<(usize, u16)> = parent_stack
155+
.into_iter()
156+
.filter_map(|(node_id, range)| {
157+
highlights
158+
.get(&range)
159+
.map(|(capture_id, _)| (node_id, *capture_id))
160+
})
161+
.collect();
101162

102-
let mut highlight_stack: Vec<(usize, u16)> = Vec::new();
103163
let mut highlight_tokens: Vec<HighlightToken> = Vec::new();
104164
let token_from_node = |node: Node<'_>, highlight_stack: &[(usize, u16)]| HighlightToken {
105165
kind_id: node.kind_id(),
@@ -110,35 +170,24 @@ pub fn highlight_tokens_cover(
110170
length: ((node.end_byte() - node.start_byte()) / 2) as u32,
111171
};
112172
let token_from_node_subrange =
113-
|node: Node<'_>, range: Range<usize>, highlight_stack: &[(usize, u16)]| HighlightToken {
114-
kind_id: node.kind_id(),
173+
|range: Range<usize>, highlight_stack: &[(usize, u16)]| HighlightToken {
174+
kind_id: u16::MAX,
115175
capture_id: highlight_stack
116176
.last()
117177
.map(|(_, capture_id)| *capture_id)
118178
.unwrap_or(u16::MAX),
119179
length: ((range.end - range.start) / 2) as u32,
120180
};
121-
let mut tree_cursor = root.walk();
122-
loop {
123-
let node_id = tree_cursor.node().id();
124-
let range = tree_cursor.node().start_byte()..tree_cursor.node().end_byte();
125-
if let Some((capture_id, _)) = highlights.get(&range).copied() {
126-
highlight_stack.push((node_id, capture_id));
127-
}
128-
if tree_cursor.goto_first_child_for_byte(byte_start).is_none() {
129-
break;
130-
}
131-
}
132-
let actual_byte_start = tree_cursor.node().start_byte();
133-
let mut byte_current = actual_byte_start;
181+
182+
let mut byte_current = byte_start;
134183
while byte_current < byte_end {
135184
let node = tree_cursor.node();
136185
let node_id = node.id();
186+
debug_assert!(byte_current >= node.start_byte());
137187
if byte_current < node.end_byte() {
138188
if tree_cursor.goto_first_child() {
139189
if tree_cursor.node().start_byte() > byte_current {
140190
highlight_tokens.push(token_from_node_subrange(
141-
node,
142191
byte_current..tree_cursor.node().start_byte(),
143192
&highlight_stack,
144193
));
@@ -155,19 +204,14 @@ pub fn highlight_tokens_cover(
155204
byte_current = node.end_byte();
156205
}
157206
} else {
158-
if let Some((highlight_id, _)) = highlight_stack.last() {
159-
if node_id == *highlight_id {
207+
if let Some((highlight_node_id, _)) = highlight_stack.last() {
208+
if node_id == *highlight_node_id {
160209
highlight_stack.pop();
161210
}
162211
}
163212
if tree_cursor.goto_next_sibling() {
164213
if tree_cursor.node().start_byte() > byte_current {
165-
let parent = tree_cursor
166-
.node()
167-
.parent()
168-
.expect("common parent with a sibling");
169214
highlight_tokens.push(token_from_node_subrange(
170-
parent,
171215
byte_current..tree_cursor.node().start_byte(),
172216
&highlight_stack,
173217
));
@@ -182,7 +226,6 @@ pub fn highlight_tokens_cover(
182226
} else if tree_cursor.goto_parent() {
183227
if tree_cursor.node().end_byte() > byte_current {
184228
highlight_tokens.push(token_from_node_subrange(
185-
tree_cursor.node(),
186229
byte_current..tree_cursor.node().end_byte(),
187230
&highlight_stack,
188231
));
@@ -193,7 +236,7 @@ pub fn highlight_tokens_cover(
193236
}
194237
}
195238
}
196-
(actual_byte_start / 2, highlight_tokens)
239+
(byte_start / 2, highlight_tokens)
197240
}
198241

199242
#[no_mangle]

0 commit comments

Comments
 (0)