@@ -13,6 +13,7 @@ pub enum ClauseType {
13
13
Select ,
14
14
Where ,
15
15
From ,
16
+ Join ,
16
17
Update ,
17
18
Delete ,
18
19
}
@@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType {
33
34
"from" => Ok ( Self :: From ) ,
34
35
"update" => Ok ( Self :: Update ) ,
35
36
"delete" => Ok ( Self :: Delete ) ,
37
+ "join" => Ok ( Self :: Join ) ,
36
38
_ => {
37
39
let message = format ! ( "Unimplemented ClauseType: {}" , value) ;
38
40
@@ -106,14 +108,35 @@ pub(crate) struct CompletionContext<'a> {
106
108
pub schema_cache : & ' a SchemaCache ,
107
109
pub position : usize ,
108
110
109
- pub schema_name : Option < String > ,
111
+ /// If the cursor is on a node that uses dot notation
112
+ /// to specify an alias or schema, this will hold the schema's or
113
+ /// alias's name.
114
+ ///
115
+ /// Here, `auth` is a schema name:
116
+ /// ```sql
117
+ /// select * from auth.users;
118
+ /// ```
119
+ ///
120
+ /// Here, `u` is an alias name:
121
+ /// ```sql
122
+ /// select
123
+ /// *
124
+ /// from
125
+ /// auth.users u
126
+ /// left join identities i
127
+ /// on u.id = i.user_id;
128
+ /// ```
129
+ pub schema_or_alias_name : Option < String > ,
110
130
pub wrapping_clause_type : Option < ClauseType > ,
111
131
112
132
pub wrapping_node_kind : Option < WrappingNode > ,
113
133
114
134
pub is_invocation : bool ,
115
135
pub wrapping_statement_range : Option < tree_sitter:: Range > ,
116
136
137
+ /// Some incomplete statements can't be correctly parsed by TreeSitter.
138
+ pub is_in_error_node : bool ,
139
+
117
140
pub mentioned_relations : HashMap < Option < String > , HashSet < String > > ,
118
141
119
142
pub mentioned_table_aliases : HashMap < String , String > ,
@@ -127,13 +150,14 @@ impl<'a> CompletionContext<'a> {
127
150
schema_cache : params. schema ,
128
151
position : usize:: from ( params. position ) ,
129
152
node_under_cursor : None ,
130
- schema_name : None ,
153
+ schema_or_alias_name : None ,
131
154
wrapping_clause_type : None ,
132
155
wrapping_node_kind : None ,
133
156
wrapping_statement_range : None ,
134
157
is_invocation : false ,
135
158
mentioned_relations : HashMap :: new ( ) ,
136
159
mentioned_table_aliases : HashMap :: new ( ) ,
160
+ is_in_error_node : false ,
137
161
} ;
138
162
139
163
ctx. gather_tree_context ( ) ;
@@ -246,34 +270,77 @@ impl<'a> CompletionContext<'a> {
246
270
self . wrapping_statement_range = Some ( parent_node. range ( ) ) ;
247
271
}
248
272
"invocation" => self . is_invocation = true ,
249
-
250
273
_ => { }
251
274
}
252
275
276
+ // try to gather context from the siblings if we're within an error node.
277
+ if self . is_in_error_node {
278
+ let mut next_sibling = current_node. next_named_sibling ( ) ;
279
+ while let Some ( n) = next_sibling {
280
+ if n. kind ( ) . starts_with ( "keyword_" ) {
281
+ if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
282
+ NodeText :: Original ( txt) => Some ( txt) ,
283
+ NodeText :: Replaced => None ,
284
+ } ) {
285
+ match txt {
286
+ "where" | "update" | "select" | "delete" | "from" | "join" => {
287
+ self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
288
+ break ;
289
+ }
290
+ _ => { }
291
+ }
292
+ } ;
293
+ }
294
+ next_sibling = n. next_named_sibling ( ) ;
295
+ }
296
+ let mut prev_sibling = current_node. prev_named_sibling ( ) ;
297
+ while let Some ( n) = prev_sibling {
298
+ if n. kind ( ) . starts_with ( "keyword_" ) {
299
+ if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
300
+ NodeText :: Original ( txt) => Some ( txt) ,
301
+ NodeText :: Replaced => None ,
302
+ } ) {
303
+ match txt {
304
+ "where" | "update" | "select" | "delete" | "from" | "join" => {
305
+ self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
306
+ break ;
307
+ }
308
+ _ => { }
309
+ }
310
+ } ;
311
+ }
312
+ prev_sibling = n. prev_named_sibling ( ) ;
313
+ }
314
+ }
315
+
253
316
match current_node_kind {
254
- "object_reference" => {
317
+ "object_reference" | "field" => {
255
318
let content = self . get_ts_node_content ( current_node) ;
256
319
if let Some ( node_txt) = content {
257
320
match node_txt {
258
321
NodeText :: Original ( txt) => {
259
322
let parts: Vec < & str > = txt. split ( '.' ) . collect ( ) ;
260
323
if parts. len ( ) == 2 {
261
- self . schema_name = Some ( parts[ 0 ] . to_string ( ) ) ;
324
+ self . schema_or_alias_name = Some ( parts[ 0 ] . to_string ( ) ) ;
262
325
}
263
326
}
264
327
NodeText :: Replaced => { }
265
328
}
266
329
}
267
330
}
268
331
269
- "where" | "update" | "select" | "delete" | "from" => {
332
+ "where" | "update" | "select" | "delete" | "from" | "join" => {
270
333
self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
271
334
}
272
335
273
336
"relation" | "binary_expression" | "assignment" => {
274
337
self . wrapping_node_kind = current_node_kind. try_into ( ) . ok ( ) ;
275
338
}
276
339
340
+ "ERROR" => {
341
+ self . is_in_error_node = true ;
342
+ }
343
+
277
344
_ => { }
278
345
}
279
346
@@ -380,7 +447,10 @@ mod tests {
380
447
381
448
let ctx = CompletionContext :: new ( & params) ;
382
449
383
- assert_eq ! ( ctx. schema_name, expected_schema. map( |f| f. to_string( ) ) ) ;
450
+ assert_eq ! (
451
+ ctx. schema_or_alias_name,
452
+ expected_schema. map( |f| f. to_string( ) )
453
+ ) ;
384
454
}
385
455
}
386
456
0 commit comments