@@ -19,7 +19,7 @@ struct WordWithIndex {
19
19
}
20
20
21
21
impl WordWithIndex {
22
- fn under_cursor ( & self , cursor_pos : usize ) -> bool {
22
+ fn is_under_cursor ( & self , cursor_pos : usize ) -> bool {
23
23
self . start <= cursor_pos && self . end > cursor_pos
24
24
}
25
25
@@ -30,54 +30,58 @@ impl WordWithIndex {
30
30
}
31
31
}
32
32
33
+ /// Note: A policy name within quotation marks will be considered a single word.
33
34
fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
34
35
let mut words = vec ! [ ] ;
35
36
36
- let mut start : Option < usize > = None ;
37
+ let mut start_of_word : Option < usize > = None ;
37
38
let mut current_word = String :: new ( ) ;
38
39
let mut in_quotation_marks = false ;
39
40
40
- for ( pos , c ) in sql. char_indices ( ) {
41
- if ( c . is_ascii_whitespace ( ) || c == ';' )
41
+ for ( current_position , current_char ) in sql. char_indices ( ) {
42
+ if ( current_char . is_ascii_whitespace ( ) || current_char == ';' )
42
43
&& !current_word. is_empty ( )
43
- && start . is_some ( )
44
+ && start_of_word . is_some ( )
44
45
&& !in_quotation_marks
45
46
{
46
47
words. push ( WordWithIndex {
47
48
word : current_word,
48
- start : start . unwrap ( ) ,
49
- end : pos ,
49
+ start : start_of_word . unwrap ( ) ,
50
+ end : current_position ,
50
51
} ) ;
52
+
51
53
current_word = String :: new ( ) ;
52
- start = None ;
53
- } else if ( c. is_ascii_whitespace ( ) || c == ';' ) && current_word. is_empty ( ) {
54
+ start_of_word = None ;
55
+ } else if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
56
+ && current_word. is_empty ( )
57
+ {
54
58
// do nothing
55
- } else if c == '"' && start . is_none ( ) {
59
+ } else if current_char == '"' && start_of_word . is_none ( ) {
56
60
in_quotation_marks = true ;
57
- start = Some ( pos ) ;
58
- current_word . push ( c ) ;
59
- } else if c == '"' && start . is_some ( ) {
60
- current_word. push ( c ) ;
61
+ current_word . push ( current_char ) ;
62
+ start_of_word = Some ( current_position ) ;
63
+ } else if current_char == '"' && start_of_word . is_some ( ) {
64
+ current_word. push ( current_char ) ;
61
65
words. push ( WordWithIndex {
62
66
word : current_word,
63
- start : start . unwrap ( ) ,
64
- end : pos + 1 ,
67
+ start : start_of_word . unwrap ( ) ,
68
+ end : current_position + 1 ,
65
69
} ) ;
66
70
in_quotation_marks = false ;
67
- start = None ;
71
+ start_of_word = None ;
68
72
current_word = String :: new ( )
69
- } else if start . is_some ( ) {
70
- current_word. push ( c )
73
+ } else if start_of_word . is_some ( ) {
74
+ current_word. push ( current_char )
71
75
} else {
72
- start = Some ( pos ) ;
73
- current_word. push ( c ) ;
76
+ start_of_word = Some ( current_position ) ;
77
+ current_word. push ( current_char ) ;
74
78
}
75
79
}
76
80
77
- if !current_word. is_empty ( ) && start . is_some ( ) {
81
+ if !current_word. is_empty ( ) && start_of_word . is_some ( ) {
78
82
words. push ( WordWithIndex {
79
83
word : current_word,
80
- start : start . unwrap ( ) ,
84
+ start : start_of_word . unwrap ( ) ,
81
85
end : sql. len ( ) ,
82
86
} ) ;
83
87
}
@@ -100,6 +104,10 @@ pub(crate) struct PolicyContext {
100
104
pub node_kind : String ,
101
105
}
102
106
107
+ /// Simple parser that'll turn a policy-related statement into a context object required for
108
+ /// completions.
109
+ /// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`.
110
+ /// It can only parse policy statements.
103
111
pub ( crate ) struct PolicyParser {
104
112
tokens : Peekable < std:: vec:: IntoIter < WordWithIndex > > ,
105
113
previous_token : Option < WordWithIndex > ,
@@ -136,7 +144,7 @@ impl PolicyParser {
136
144
137
145
fn parse ( mut self ) -> PolicyContext {
138
146
while let Some ( token) = self . advance ( ) {
139
- if token. under_cursor ( self . cursor_position ) {
147
+ if token. is_under_cursor ( self . cursor_position ) {
140
148
self . handle_token_under_cursor ( token) ;
141
149
} else {
142
150
self . handle_token ( token) ;
@@ -161,9 +169,8 @@ impl PolicyParser {
161
169
}
162
170
"on" => {
163
171
if token. word . contains ( '.' ) {
164
- let mut parts = token . word . split ( '.' ) ;
172
+ let ( schema_name , table_name ) = self . schema_and_table_name ( & token ) ;
165
173
166
- let schema_name: String = parts. next ( ) . unwrap ( ) . into ( ) ;
167
174
let schema_name_len = schema_name. len ( ) ;
168
175
self . context . schema_name = Some ( schema_name) ;
169
176
@@ -176,8 +183,16 @@ impl PolicyParser {
176
183
. expect ( "Text too long" ) ;
177
184
178
185
self . context . node_range = range_without_schema;
179
- self . context . node_text = parts. next ( ) . unwrap ( ) . into ( ) ;
180
186
self . context . node_kind = "policy_table" . into ( ) ;
187
+
188
+ self . context . node_text = match table_name {
189
+ Some ( node_text) => node_text,
190
+
191
+ // In practice, this should never happen.
192
+ // The completion sanitization will add a word after a `.` if nothing follows it;
193
+ // the token_text will then look like `schema.REPLACED_TOKEN`.
194
+ None => String :: new ( ) ,
195
+ } ;
181
196
} else {
182
197
self . context . node_range = token. get_range ( ) ;
183
198
self . context . node_text = token. word ;
@@ -209,7 +224,7 @@ impl PolicyParser {
209
224
}
210
225
"on" => self . table_with_schema ( ) ,
211
226
212
- // skip the "to" so we don't parse it as the TO rolename
227
+ // skip the "to" so we don't parse it as the TO rolename when it's under the cursor
213
228
"rename" if self . next_matches ( "to" ) => {
214
229
self . advance ( ) ;
215
230
}
@@ -231,17 +246,18 @@ impl PolicyParser {
231
246
}
232
247
233
248
fn advance ( & mut self ) -> Option < WordWithIndex > {
249
+ // we can't peek back n an iterator, so we'll have to keep track manually.
234
250
self . previous_token = self . current_token . take ( ) ;
235
251
self . current_token = self . tokens . next ( ) ;
236
252
self . current_token . clone ( )
237
253
}
238
254
239
255
fn table_with_schema ( & mut self ) {
240
256
self . advance ( ) . map ( |token| {
241
- if token. under_cursor ( self . cursor_position ) {
257
+ if token. is_under_cursor ( self . cursor_position ) {
242
258
self . handle_token_under_cursor ( token) ;
243
259
} else if token. word . contains ( '.' ) {
244
- let ( schema, maybe_table) = self . schema_and_table_name ( token) ;
260
+ let ( schema, maybe_table) = self . schema_and_table_name ( & token) ;
245
261
self . context . schema_name = Some ( schema) ;
246
262
self . context . table_name = maybe_table;
247
263
} else {
@@ -250,7 +266,7 @@ impl PolicyParser {
250
266
} ) ;
251
267
}
252
268
253
- fn schema_and_table_name ( & self , token : WordWithIndex ) -> ( String , Option < String > ) {
269
+ fn schema_and_table_name ( & self , token : & WordWithIndex ) -> ( String , Option < String > ) {
254
270
let mut parts = token. word . split ( '.' ) ;
255
271
256
272
(
0 commit comments