@@ -9,11 +9,13 @@ use pgt_treesitter_queries::{
9
9
use crate :: sanitization:: SanitizedCompletionParams ;
10
10
11
11
#[ derive( Debug , PartialEq , Eq ) ]
12
- pub enum ClauseType {
12
+ pub enum WrappingClause < ' a > {
13
13
Select ,
14
14
Where ,
15
15
From ,
16
- Join ,
16
+ Join {
17
+ on_node : Option < tree_sitter:: Node < ' a > > ,
18
+ } ,
17
19
Update ,
18
20
Delete ,
19
21
}
@@ -24,38 +26,6 @@ pub(crate) enum NodeText<'a> {
24
26
Original ( & ' a str ) ,
25
27
}
26
28
27
- impl TryFrom < & str > for ClauseType {
28
- type Error = String ;
29
-
30
- fn try_from ( value : & str ) -> Result < Self , Self :: Error > {
31
- match value {
32
- "select" => Ok ( Self :: Select ) ,
33
- "where" => Ok ( Self :: Where ) ,
34
- "from" => Ok ( Self :: From ) ,
35
- "update" => Ok ( Self :: Update ) ,
36
- "delete" => Ok ( Self :: Delete ) ,
37
- "join" => Ok ( Self :: Join ) ,
38
- _ => {
39
- let message = format ! ( "Unimplemented ClauseType: {}" , value) ;
40
-
41
- // Err on tests, so we notice that we're lacking an implementation immediately.
42
- if cfg ! ( test) {
43
- panic ! ( "{}" , message) ;
44
- }
45
-
46
- Err ( message)
47
- }
48
- }
49
- }
50
- }
51
-
52
- impl TryFrom < String > for ClauseType {
53
- type Error = String ;
54
- fn try_from ( value : String ) -> Result < Self , Self :: Error > {
55
- Self :: try_from ( value. as_str ( ) )
56
- }
57
- }
58
-
59
29
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
60
30
/// That gives us a lot of insight for completions.
61
31
/// Other nodes, such as the "relation" node, gives us less but still
@@ -127,7 +97,7 @@ pub(crate) struct CompletionContext<'a> {
127
97
/// on u.id = i.user_id;
128
98
/// ```
129
99
pub schema_or_alias_name : Option < String > ,
130
- pub wrapping_clause_type : Option < ClauseType > ,
100
+ pub wrapping_clause_type : Option < WrappingClause < ' a > > ,
131
101
132
102
pub wrapping_node_kind : Option < WrappingNode > ,
133
103
@@ -266,7 +236,9 @@ impl<'a> CompletionContext<'a> {
266
236
267
237
match parent_node_kind {
268
238
"statement" | "subquery" => {
269
- self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
239
+ self . wrapping_clause_type =
240
+ self . get_wrapping_clause_from_current_node ( current_node, & mut cursor) ;
241
+
270
242
self . wrapping_statement_range = Some ( parent_node. range ( ) ) ;
271
243
}
272
244
"invocation" => self . is_invocation = true ,
@@ -277,39 +249,21 @@ impl<'a> CompletionContext<'a> {
277
249
if self . is_in_error_node {
278
250
let mut next_sibling = current_node. next_named_sibling ( ) ;
279
251
while let Some ( n) = next_sibling {
280
- if n. kind ( ) . starts_with ( "keyword_" ) {
281
- if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
282
- NodeText :: Original ( txt) => Some ( txt) ,
283
- NodeText :: Replaced => None ,
284
- } ) {
285
- match txt {
286
- "where" | "update" | "select" | "delete" | "from" | "join" => {
287
- self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
288
- break ;
289
- }
290
- _ => { }
291
- }
292
- } ;
252
+ if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
253
+ self . wrapping_clause_type = Some ( clause_type) ;
254
+ break ;
255
+ } else {
256
+ next_sibling = n. next_named_sibling ( ) ;
293
257
}
294
- next_sibling = n. next_named_sibling ( ) ;
295
258
}
296
259
let mut prev_sibling = current_node. prev_named_sibling ( ) ;
297
260
while let Some ( n) = prev_sibling {
298
- if n. kind ( ) . starts_with ( "keyword_" ) {
299
- if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
300
- NodeText :: Original ( txt) => Some ( txt) ,
301
- NodeText :: Replaced => None ,
302
- } ) {
303
- match txt {
304
- "where" | "update" | "select" | "delete" | "from" | "join" => {
305
- self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
306
- break ;
307
- }
308
- _ => { }
309
- }
310
- } ;
261
+ if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
262
+ self . wrapping_clause_type = Some ( clause_type) ;
263
+ break ;
264
+ } else {
265
+ prev_sibling = n. prev_named_sibling ( ) ;
311
266
}
312
- prev_sibling = n. prev_named_sibling ( ) ;
313
267
}
314
268
}
315
269
@@ -330,7 +284,8 @@ impl<'a> CompletionContext<'a> {
330
284
}
331
285
332
286
"where" | "update" | "select" | "delete" | "from" | "join" => {
333
- self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
287
+ self . wrapping_clause_type =
288
+ self . get_wrapping_clause_from_current_node ( current_node, & mut cursor) ;
334
289
}
335
290
336
291
"relation" | "binary_expression" | "assignment" => {
@@ -353,12 +308,67 @@ impl<'a> CompletionContext<'a> {
353
308
cursor. goto_first_child_for_byte ( self . position ) ;
354
309
self . gather_context_from_node ( cursor, current_node) ;
355
310
}
311
+
312
+ fn get_wrapping_clause_from_keyword_node (
313
+ & self ,
314
+ node : tree_sitter:: Node < ' a > ,
315
+ ) -> Option < WrappingClause < ' a > > {
316
+ if node. kind ( ) . starts_with ( "keyword_" ) {
317
+ if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
318
+ NodeText :: Original ( txt) => Some ( txt) ,
319
+ NodeText :: Replaced => None ,
320
+ } ) {
321
+ match txt {
322
+ "where" => return Some ( WrappingClause :: Where ) ,
323
+ "update" => return Some ( WrappingClause :: Update ) ,
324
+ "select" => return Some ( WrappingClause :: Select ) ,
325
+ "delete" => return Some ( WrappingClause :: Delete ) ,
326
+ "from" => return Some ( WrappingClause :: From ) ,
327
+ "join" => {
328
+ // TODO: not sure if we can infer it here.
329
+ return Some ( WrappingClause :: Join { on_node : None } ) ;
330
+ }
331
+ _ => { }
332
+ }
333
+ } ;
334
+ }
335
+
336
+ None
337
+ }
338
+
339
+ fn get_wrapping_clause_from_current_node (
340
+ & self ,
341
+ node : tree_sitter:: Node < ' a > ,
342
+ cursor : & mut tree_sitter:: TreeCursor < ' a > ,
343
+ ) -> Option < WrappingClause < ' a > > {
344
+ return match node. kind ( ) {
345
+ "where" => Some ( WrappingClause :: Where ) ,
346
+ "update" => Some ( WrappingClause :: Update ) ,
347
+ "select" => Some ( WrappingClause :: Select ) ,
348
+ "delete" => Some ( WrappingClause :: Delete ) ,
349
+ "from" => Some ( WrappingClause :: From ) ,
350
+ "join" => {
351
+ // sadly, we need to manually iterate over the children –
352
+ // `node.child_by_field_id(..)` does not work as expected
353
+ let mut on_node = None ;
354
+ for child in node. children ( cursor) {
355
+ // 28 is the id for "keyword_on"
356
+ if child. kind_id ( ) == 28 {
357
+ on_node = Some ( child) ;
358
+ }
359
+ }
360
+ cursor. goto_parent ( ) ;
361
+ Some ( WrappingClause :: Join { on_node } )
362
+ }
363
+ _ => None ,
364
+ } ;
365
+ }
356
366
}
357
367
358
368
#[ cfg( test) ]
359
369
mod tests {
360
370
use crate :: {
361
- context:: { ClauseType , CompletionContext , NodeText } ,
371
+ context:: { CompletionContext , NodeText , WrappingClause } ,
362
372
sanitization:: SanitizedCompletionParams ,
363
373
test_helper:: { CURSOR_POS , get_text_and_position} ,
364
374
} ;
@@ -375,29 +385,41 @@ mod tests {
375
385
#[ test]
376
386
fn identifies_clauses ( ) {
377
387
let test_cases = vec ! [
378
- ( format!( "Select {}* from users;" , CURSOR_POS ) , "select" ) ,
379
- ( format!( "Select * from u{};" , CURSOR_POS ) , "from" ) ,
388
+ (
389
+ format!( "Select {}* from users;" , CURSOR_POS ) ,
390
+ WrappingClause :: Select ,
391
+ ) ,
392
+ (
393
+ format!( "Select * from u{};" , CURSOR_POS ) ,
394
+ WrappingClause :: From ,
395
+ ) ,
380
396
(
381
397
format!( "Select {}* from users where n = 1;" , CURSOR_POS ) ,
382
- "select" ,
398
+ WrappingClause :: Select ,
383
399
) ,
384
400
(
385
401
format!( "Select * from users where {}n = 1;" , CURSOR_POS ) ,
386
- "where" ,
402
+ WrappingClause :: Where ,
387
403
) ,
388
404
(
389
405
format!( "update users set u{} = 1 where n = 2;" , CURSOR_POS ) ,
390
- "update" ,
406
+ WrappingClause :: Update ,
391
407
) ,
392
408
(
393
409
format!( "update users set u = 1 where n{} = 2;" , CURSOR_POS ) ,
394
- "where" ,
410
+ WrappingClause :: Where ,
411
+ ) ,
412
+ (
413
+ format!( "delete{} from users;" , CURSOR_POS ) ,
414
+ WrappingClause :: Delete ,
415
+ ) ,
416
+ (
417
+ format!( "delete from {}users;" , CURSOR_POS ) ,
418
+ WrappingClause :: From ,
395
419
) ,
396
- ( format!( "delete{} from users;" , CURSOR_POS ) , "delete" ) ,
397
- ( format!( "delete from {}users;" , CURSOR_POS ) , "from" ) ,
398
420
(
399
421
format!( "select name, age, location from public.u{}sers" , CURSOR_POS ) ,
400
- "from" ,
422
+ WrappingClause :: From ,
401
423
) ,
402
424
] ;
403
425
@@ -415,7 +437,7 @@ mod tests {
415
437
416
438
let ctx = CompletionContext :: new ( & params) ;
417
439
418
- assert_eq ! ( ctx. wrapping_clause_type, expected_clause . try_into ( ) . ok ( ) ) ;
440
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( expected_clause ) ) ;
419
441
}
420
442
}
421
443
@@ -518,7 +540,7 @@ mod tests {
518
540
519
541
assert_eq ! (
520
542
ctx. wrapping_clause_type,
521
- Some ( crate :: context:: ClauseType :: Select )
543
+ Some ( crate :: context:: WrappingClause :: Select )
522
544
) ;
523
545
}
524
546
}
@@ -596,6 +618,6 @@ mod tests {
596
618
ctx. get_ts_node_content( node) ,
597
619
Some ( NodeText :: Original ( "fro" ) )
598
620
) ;
599
- assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
621
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
600
622
}
601
623
}
0 commit comments