Skip to content

Commit 3b7971d

Browse files
feat(completions): respect table aliases, complete in JOINs (#388)
* add comment, rename * add filtering * feat(completions): recognize table aliases * non exhaustive * ok
1 parent 9cd19b0 commit 3b7971d

File tree

8 files changed

+318
-38
lines changed

8 files changed

+318
-38
lines changed

crates/pgt_completions/src/context.rs

+77-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum ClauseType {
1313
Select,
1414
Where,
1515
From,
16+
Join,
1617
Update,
1718
Delete,
1819
}
@@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType {
3334
"from" => Ok(Self::From),
3435
"update" => Ok(Self::Update),
3536
"delete" => Ok(Self::Delete),
37+
"join" => Ok(Self::Join),
3638
_ => {
3739
let message = format!("Unimplemented ClauseType: {}", value);
3840

@@ -106,14 +108,35 @@ pub(crate) struct CompletionContext<'a> {
106108
pub schema_cache: &'a SchemaCache,
107109
pub position: usize,
108110

109-
pub schema_name: Option<String>,
111+
/// If the cursor is on a node that uses dot notation
112+
/// to specify an alias or schema, this will hold the schema's or
113+
/// alias's name.
114+
///
115+
/// Here, `auth` is a schema name:
116+
/// ```sql
117+
/// select * from auth.users;
118+
/// ```
119+
///
120+
/// Here, `u` is an alias name:
121+
/// ```sql
122+
/// select
123+
/// *
124+
/// from
125+
/// auth.users u
126+
/// left join identities i
127+
/// on u.id = i.user_id;
128+
/// ```
129+
pub schema_or_alias_name: Option<String>,
110130
pub wrapping_clause_type: Option<ClauseType>,
111131

112132
pub wrapping_node_kind: Option<WrappingNode>,
113133

114134
pub is_invocation: bool,
115135
pub wrapping_statement_range: Option<tree_sitter::Range>,
116136

137+
/// Some incomplete statements can't be correctly parsed by TreeSitter.
138+
pub is_in_error_node: bool,
139+
117140
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
118141

119142
pub mentioned_table_aliases: HashMap<String, String>,
@@ -127,13 +150,14 @@ impl<'a> CompletionContext<'a> {
127150
schema_cache: params.schema,
128151
position: usize::from(params.position),
129152
node_under_cursor: None,
130-
schema_name: None,
153+
schema_or_alias_name: None,
131154
wrapping_clause_type: None,
132155
wrapping_node_kind: None,
133156
wrapping_statement_range: None,
134157
is_invocation: false,
135158
mentioned_relations: HashMap::new(),
136159
mentioned_table_aliases: HashMap::new(),
160+
is_in_error_node: false,
137161
};
138162

139163
ctx.gather_tree_context();
@@ -246,34 +270,77 @@ impl<'a> CompletionContext<'a> {
246270
self.wrapping_statement_range = Some(parent_node.range());
247271
}
248272
"invocation" => self.is_invocation = true,
249-
250273
_ => {}
251274
}
252275

276+
// try to gather context from the siblings if we're within an error node.
277+
if self.is_in_error_node {
278+
let mut next_sibling = current_node.next_named_sibling();
279+
while let Some(n) = next_sibling {
280+
if n.kind().starts_with("keyword_") {
281+
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
282+
NodeText::Original(txt) => Some(txt),
283+
NodeText::Replaced => None,
284+
}) {
285+
match txt {
286+
"where" | "update" | "select" | "delete" | "from" | "join" => {
287+
self.wrapping_clause_type = txt.try_into().ok();
288+
break;
289+
}
290+
_ => {}
291+
}
292+
};
293+
}
294+
next_sibling = n.next_named_sibling();
295+
}
296+
let mut prev_sibling = current_node.prev_named_sibling();
297+
while let Some(n) = prev_sibling {
298+
if n.kind().starts_with("keyword_") {
299+
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
300+
NodeText::Original(txt) => Some(txt),
301+
NodeText::Replaced => None,
302+
}) {
303+
match txt {
304+
"where" | "update" | "select" | "delete" | "from" | "join" => {
305+
self.wrapping_clause_type = txt.try_into().ok();
306+
break;
307+
}
308+
_ => {}
309+
}
310+
};
311+
}
312+
prev_sibling = n.prev_named_sibling();
313+
}
314+
}
315+
253316
match current_node_kind {
254-
"object_reference" => {
317+
"object_reference" | "field" => {
255318
let content = self.get_ts_node_content(current_node);
256319
if let Some(node_txt) = content {
257320
match node_txt {
258321
NodeText::Original(txt) => {
259322
let parts: Vec<&str> = txt.split('.').collect();
260323
if parts.len() == 2 {
261-
self.schema_name = Some(parts[0].to_string());
324+
self.schema_or_alias_name = Some(parts[0].to_string());
262325
}
263326
}
264327
NodeText::Replaced => {}
265328
}
266329
}
267330
}
268331

269-
"where" | "update" | "select" | "delete" | "from" => {
332+
"where" | "update" | "select" | "delete" | "from" | "join" => {
270333
self.wrapping_clause_type = current_node_kind.try_into().ok();
271334
}
272335

273336
"relation" | "binary_expression" | "assignment" => {
274337
self.wrapping_node_kind = current_node_kind.try_into().ok();
275338
}
276339

340+
"ERROR" => {
341+
self.is_in_error_node = true;
342+
}
343+
277344
_ => {}
278345
}
279346

@@ -380,7 +447,10 @@ mod tests {
380447

381448
let ctx = CompletionContext::new(&params);
382449

383-
assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string()));
450+
assert_eq!(
451+
ctx.schema_or_alias_name,
452+
expected_schema.map(|f| f.to_string())
453+
);
384454
}
385455
}
386456

crates/pgt_completions/src/providers/columns.rs

+108-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
2828
mod tests {
2929
use crate::{
3030
CompletionItem, CompletionItemKind, complete,
31-
test_helper::{CURSOR_POS, InputQuery, get_test_deps, get_test_params},
31+
test_helper::{
32+
CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps,
33+
get_test_params,
34+
},
3235
};
3336

3437
struct TestCase {
@@ -168,9 +171,9 @@ mod tests {
168171
("name", "Table: public.users"),
169172
("narrator", "Table: public.audio_books"),
170173
("narrator_id", "Table: private.audio_books"),
174+
("id", "Table: public.audio_books"),
171175
("name", "Schema: pg_catalog"),
172176
("nameconcatoid", "Schema: pg_catalog"),
173-
("nameeq", "Schema: pg_catalog"),
174177
]
175178
.into_iter()
176179
.map(|(label, schema)| LabelAndDesc {
@@ -325,4 +328,107 @@ mod tests {
325328
);
326329
}
327330
}
331+
332+
#[tokio::test]
333+
async fn filters_out_by_aliases() {
334+
let setup = r#"
335+
create schema auth;
336+
337+
create table auth.users (
338+
uid serial primary key,
339+
name text not null,
340+
email text unique not null
341+
);
342+
343+
create table auth.posts (
344+
pid serial primary key,
345+
user_id int not null references auth.users(uid),
346+
title text not null,
347+
content text,
348+
created_at timestamp default now()
349+
);
350+
"#;
351+
352+
// test in SELECT clause
353+
assert_complete_results(
354+
format!(
355+
"select u.id, p.{} from auth.users u join auth.posts p on u.id = p.user_id;",
356+
CURSOR_POS
357+
)
358+
.as_str(),
359+
vec![
360+
CompletionAssertion::LabelNotExists("uid".to_string()),
361+
CompletionAssertion::LabelNotExists("name".to_string()),
362+
CompletionAssertion::LabelNotExists("email".to_string()),
363+
CompletionAssertion::Label("content".to_string()),
364+
CompletionAssertion::Label("created_at".to_string()),
365+
CompletionAssertion::Label("pid".to_string()),
366+
CompletionAssertion::Label("title".to_string()),
367+
CompletionAssertion::Label("user_id".to_string()),
368+
],
369+
setup,
370+
)
371+
.await;
372+
373+
// test in JOIN clause
374+
assert_complete_results(
375+
format!(
376+
"select u.id, p.content from auth.users u join auth.posts p on u.id = p.{};",
377+
CURSOR_POS
378+
)
379+
.as_str(),
380+
vec![
381+
CompletionAssertion::LabelNotExists("uid".to_string()),
382+
CompletionAssertion::LabelNotExists("name".to_string()),
383+
CompletionAssertion::LabelNotExists("email".to_string()),
384+
// primary keys are preferred
385+
CompletionAssertion::Label("pid".to_string()),
386+
CompletionAssertion::Label("content".to_string()),
387+
CompletionAssertion::Label("created_at".to_string()),
388+
CompletionAssertion::Label("title".to_string()),
389+
CompletionAssertion::Label("user_id".to_string()),
390+
],
391+
setup,
392+
)
393+
.await;
394+
}
395+
396+
#[tokio::test]
397+
async fn does_not_complete_cols_in_join_clauses() {
398+
let setup = r#"
399+
create schema auth;
400+
401+
create table auth.users (
402+
uid serial primary key,
403+
name text not null,
404+
email text unique not null
405+
);
406+
407+
create table auth.posts (
408+
pid serial primary key,
409+
user_id int not null references auth.users(uid),
410+
title text not null,
411+
content text,
412+
created_at timestamp default now()
413+
);
414+
"#;
415+
416+
/*
417+
* We are not in the "ON" part of the JOIN clause, so we should not complete columns.
418+
*/
419+
assert_complete_results(
420+
format!(
421+
"select u.id, p.content from auth.users u join auth.{}",
422+
CURSOR_POS
423+
)
424+
.as_str(),
425+
vec![
426+
CompletionAssertion::KindNotExists(CompletionItemKind::Column),
427+
CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table),
428+
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
429+
],
430+
setup,
431+
)
432+
.await;
433+
}
328434
}

crates/pgt_completions/src/providers/helper.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub(crate) fn get_completion_text_with_schema(
77
item_name: &str,
88
item_schema_name: &str,
99
) -> Option<CompletionText> {
10-
if item_schema_name == "public" || ctx.schema_name.is_some() {
10+
if item_schema_name == "public" || ctx.schema_or_alias_name.is_some() {
1111
None
1212
} else {
1313
let node = ctx.node_under_cursor.unwrap();

crates/pgt_completions/src/providers/schemas.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ mod tests {
5959
"private".to_string(),
6060
CompletionItemKind::Schema,
6161
),
62+
// users table still preferred over system schemas
63+
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
6264
CompletionAssertion::LabelAndKind(
6365
"information_schema".to_string(),
6466
CompletionItemKind::Schema,
@@ -71,7 +73,6 @@ mod tests {
7173
"pg_toast".to_string(),
7274
CompletionItemKind::Schema,
7375
),
74-
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
7576
],
7677
setup,
7778
)

crates/pgt_completions/src/providers/tables.rs

+33
Original file line numberDiff line numberDiff line change
@@ -273,4 +273,37 @@ mod tests {
273273
)
274274
.await;
275275
}
276+
277+
#[tokio::test]
278+
async fn suggests_tables_in_join() {
279+
let setup = r#"
280+
create schema auth;
281+
282+
create table auth.users (
283+
uid serial primary key,
284+
name text not null,
285+
email text unique not null
286+
);
287+
288+
create table auth.posts (
289+
pid serial primary key,
290+
user_id int not null references auth.users(uid),
291+
title text not null,
292+
content text,
293+
created_at timestamp default now()
294+
);
295+
"#;
296+
297+
assert_complete_results(
298+
format!("select * from auth.users u join {}", CURSOR_POS).as_str(),
299+
vec![
300+
CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema),
301+
CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema),
302+
CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join
303+
CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table),
304+
],
305+
setup,
306+
)
307+
.await;
308+
}
276309
}

0 commit comments

Comments
 (0)