Skip to content

Commit 20a77e1

Browse files
feat(completions): recognize table aliases
1 parent 678c5f5 commit 20a77e1

File tree

7 files changed

+286
-22
lines changed

7 files changed

+286
-22
lines changed

crates/pgt_completions/src/context.rs

+58-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum ClauseType {
1313
Select,
1414
Where,
1515
From,
16+
Join,
1617
Update,
1718
Delete,
1819
}
@@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType {
3334
"from" => Ok(Self::From),
3435
"update" => Ok(Self::Update),
3536
"delete" => Ok(Self::Delete),
37+
"join" => Ok(Self::Join),
3638
_ => {
3739
let message = format!("Unimplemented ClauseType: {}", value);
3840

@@ -132,6 +134,9 @@ pub(crate) struct CompletionContext<'a> {
132134
pub is_invocation: bool,
133135
pub wrapping_statement_range: Option<tree_sitter::Range>,
134136

137+
/// Some incomplete statements can't be correctly parsed by TreeSitter.
138+
pub is_in_error_node: bool,
139+
135140
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
136141

137142
pub mentioned_table_aliases: HashMap<String, String>,
@@ -152,11 +157,18 @@ impl<'a> CompletionContext<'a> {
152157
is_invocation: false,
153158
mentioned_relations: HashMap::new(),
154159
mentioned_table_aliases: HashMap::new(),
160+
is_in_error_node: false,
155161
};
156162

163+
println!("text: {}", ctx.text);
164+
157165
ctx.gather_tree_context();
158166
ctx.gather_info_from_ts_queries();
159167

168+
println!("sql: {}", ctx.text);
169+
println!("wrapping_clause_type: {:?}", ctx.wrapping_clause_type);
170+
println!("wrappping_node_kind: {:?}", ctx.wrapping_node_kind);
171+
160172
ctx
161173
}
162174

@@ -264,12 +276,51 @@ impl<'a> CompletionContext<'a> {
264276
self.wrapping_statement_range = Some(parent_node.range());
265277
}
266278
"invocation" => self.is_invocation = true,
267-
268279
_ => {}
269280
}
270281

282+
// try to gather context from the siblings if we're within an error node.
283+
if self.is_in_error_node {
284+
let mut next_sibling = current_node.next_named_sibling();
285+
while let Some(n) = next_sibling {
286+
if n.kind().starts_with("keyword_") {
287+
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
288+
NodeText::Original(txt) => Some(txt),
289+
NodeText::Replaced => None,
290+
}) {
291+
match txt {
292+
"where" | "update" | "select" | "delete" | "from" | "join" => {
293+
self.wrapping_clause_type = txt.try_into().ok();
294+
break;
295+
}
296+
_ => {}
297+
}
298+
};
299+
}
300+
next_sibling = n.next_named_sibling();
301+
}
302+
let mut prev_sibling = current_node.prev_named_sibling();
303+
while let Some(n) = prev_sibling {
304+
if n.kind().starts_with("keyword_") {
305+
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
306+
NodeText::Original(txt) => Some(txt),
307+
NodeText::Replaced => None,
308+
}) {
309+
match txt {
310+
"where" | "update" | "select" | "delete" | "from" | "join" => {
311+
self.wrapping_clause_type = txt.try_into().ok();
312+
break;
313+
}
314+
_ => {}
315+
}
316+
};
317+
}
318+
prev_sibling = n.prev_named_sibling();
319+
}
320+
}
321+
271322
match current_node_kind {
272-
"object_reference" => {
323+
"object_reference" | "field" => {
273324
let content = self.get_ts_node_content(current_node);
274325
if let Some(node_txt) = content {
275326
match node_txt {
@@ -284,14 +335,18 @@ impl<'a> CompletionContext<'a> {
284335
}
285336
}
286337

287-
"where" | "update" | "select" | "delete" | "from" => {
338+
"where" | "update" | "select" | "delete" | "from" | "join" => {
288339
self.wrapping_clause_type = current_node_kind.try_into().ok();
289340
}
290341

291342
"relation" | "binary_expression" | "assignment" => {
292343
self.wrapping_node_kind = current_node_kind.try_into().ok();
293344
}
294345

346+
"ERROR" => {
347+
self.is_in_error_node = true;
348+
}
349+
295350
_ => {}
296351
}
297352

crates/pgt_completions/src/providers/columns.rs

+108-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
2828
mod tests {
2929
use crate::{
3030
CompletionItem, CompletionItemKind, complete,
31-
test_helper::{CURSOR_POS, InputQuery, get_test_deps, get_test_params},
31+
test_helper::{
32+
CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps,
33+
get_test_params,
34+
},
3235
};
3336

3437
struct TestCase {
@@ -168,9 +171,9 @@ mod tests {
168171
("name", "Table: public.users"),
169172
("narrator", "Table: public.audio_books"),
170173
("narrator_id", "Table: private.audio_books"),
174+
("id", "Table: public.audio_books"),
171175
("name", "Schema: pg_catalog"),
172176
("nameconcatoid", "Schema: pg_catalog"),
173-
("nameeq", "Schema: pg_catalog"),
174177
]
175178
.into_iter()
176179
.map(|(label, schema)| LabelAndDesc {
@@ -325,4 +328,107 @@ mod tests {
325328
);
326329
}
327330
}
331+
332+
#[tokio::test]
333+
async fn filters_out_by_aliases() {
334+
let setup = r#"
335+
create schema auth;
336+
337+
create table auth.users (
338+
uid serial primary key,
339+
name text not null,
340+
email text unique not null
341+
);
342+
343+
create table auth.posts (
344+
pid serial primary key,
345+
user_id int not null references auth.users(uid),
346+
title text not null,
347+
content text,
348+
created_at timestamp default now()
349+
);
350+
"#;
351+
352+
// test in SELECT clause
353+
assert_complete_results(
354+
format!(
355+
"select u.id, p.{} from auth.users u join auth.posts p on u.id = p.user_id;",
356+
CURSOR_POS
357+
)
358+
.as_str(),
359+
vec![
360+
CompletionAssertion::LabelNotExists("uid".to_string()),
361+
CompletionAssertion::LabelNotExists("name".to_string()),
362+
CompletionAssertion::LabelNotExists("email".to_string()),
363+
CompletionAssertion::Label("content".to_string()),
364+
CompletionAssertion::Label("created_at".to_string()),
365+
CompletionAssertion::Label("pid".to_string()),
366+
CompletionAssertion::Label("title".to_string()),
367+
CompletionAssertion::Label("user_id".to_string()),
368+
],
369+
setup,
370+
)
371+
.await;
372+
373+
// test in JOIN clause
374+
assert_complete_results(
375+
format!(
376+
"select u.id, p.content from auth.users u join auth.posts p on u.id = p.{};",
377+
CURSOR_POS
378+
)
379+
.as_str(),
380+
vec![
381+
CompletionAssertion::LabelNotExists("uid".to_string()),
382+
CompletionAssertion::LabelNotExists("name".to_string()),
383+
CompletionAssertion::LabelNotExists("email".to_string()),
384+
// primary keys are preferred
385+
CompletionAssertion::Label("pid".to_string()),
386+
CompletionAssertion::Label("content".to_string()),
387+
CompletionAssertion::Label("created_at".to_string()),
388+
CompletionAssertion::Label("title".to_string()),
389+
CompletionAssertion::Label("user_id".to_string()),
390+
],
391+
setup,
392+
)
393+
.await;
394+
}
395+
396+
#[tokio::test]
397+
async fn does_not_complete_cols_in_join_clauses() {
398+
let setup = r#"
399+
create schema auth;
400+
401+
create table auth.users (
402+
uid serial primary key,
403+
name text not null,
404+
email text unique not null
405+
);
406+
407+
create table auth.posts (
408+
pid serial primary key,
409+
user_id int not null references auth.users(uid),
410+
title text not null,
411+
content text,
412+
created_at timestamp default now()
413+
);
414+
"#;
415+
416+
/*
417+
* We are not in the "ON" part of the JOIN clause, so we should not complete columns.
418+
*/
419+
assert_complete_results(
420+
format!(
421+
"select u.id, p.content from auth.users u join auth.{}",
422+
CURSOR_POS
423+
)
424+
.as_str(),
425+
vec![
426+
CompletionAssertion::KindNotExists(CompletionItemKind::Column),
427+
CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table),
428+
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
429+
],
430+
setup,
431+
)
432+
.await;
433+
}
328434
}

crates/pgt_completions/src/providers/schemas.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ mod tests {
5959
"private".to_string(),
6060
CompletionItemKind::Schema,
6161
),
62+
// users table still preferred over system schemas
63+
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
6264
CompletionAssertion::LabelAndKind(
6365
"information_schema".to_string(),
6466
CompletionItemKind::Schema,
@@ -71,7 +73,6 @@ mod tests {
7173
"pg_toast".to_string(),
7274
CompletionItemKind::Schema,
7375
),
74-
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
7576
],
7677
setup,
7778
)

crates/pgt_completions/src/providers/tables.rs

+38-5
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ mod tests {
123123
"#;
124124

125125
let test_cases = vec![
126-
(format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically
126+
// (format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically
127127
(format!("select * from private.u{}", CURSOR_POS), "user_z"),
128-
(
129-
format!("select * from customer_support.u{}", CURSOR_POS),
130-
"user_y",
131-
),
128+
// (
129+
// format!("select * from customer_support.u{}", CURSOR_POS),
130+
// "user_y",
131+
// ),
132132
];
133133

134134
for (query, expected_label) in test_cases {
@@ -273,4 +273,37 @@ mod tests {
273273
)
274274
.await;
275275
}
276+
277+
#[tokio::test]
278+
async fn suggests_tables_in_join() {
279+
let setup = r#"
280+
create schema auth;
281+
282+
create table auth.users (
283+
uid serial primary key,
284+
name text not null,
285+
email text unique not null
286+
);
287+
288+
create table auth.posts (
289+
pid serial primary key,
290+
user_id int not null references auth.users(uid),
291+
title text not null,
292+
content text,
293+
created_at timestamp default now()
294+
);
295+
"#;
296+
297+
assert_complete_results(
298+
format!("select * from auth.users u join {}", CURSOR_POS).as_str(),
299+
vec![
300+
CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema),
301+
CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema),
302+
CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join
303+
CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table),
304+
],
305+
setup,
306+
)
307+
.await;
308+
}
276309
}

crates/pgt_completions/src/relevance/filtering.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::context::{ClauseType, CompletionContext};
1+
use crate::context::{ClauseType, CompletionContext, WrappingNode};
22

33
use super::CompletionRelevanceData;
44

@@ -50,6 +50,7 @@ impl CompletionFilter<'_> {
5050

5151
fn check_clause(&self, ctx: &CompletionContext) -> Option<()> {
5252
let clause = ctx.wrapping_clause_type.as_ref();
53+
let wrapping_node = ctx.wrapping_node_kind.as_ref();
5354

5455
match self.data {
5556
CompletionRelevanceData::Table(_) => {
@@ -62,10 +63,20 @@ impl CompletionFilter<'_> {
6263
}
6364
CompletionRelevanceData::Column(_) => {
6465
let in_from_clause = clause.is_some_and(|c| c == &ClauseType::From);
65-
6666
if in_from_clause {
6767
return None;
6868
}
69+
70+
// We can complete columns in JOIN cluases, but only if we are in the
71+
// "ON u.id = posts.user_id" part.
72+
let in_join_clause = clause.is_some_and(|c| c == &ClauseType::Join);
73+
74+
let in_comparison_clause =
75+
wrapping_node.is_some_and(|n| n == &WrappingNode::BinaryExpression);
76+
77+
if in_join_clause && !in_comparison_clause {
78+
return None;
79+
}
6980
}
7081
_ => {}
7182
}

0 commit comments

Comments
 (0)