Skip to content

Commit 5945c1b

Browse files
fix(completions): complete right columns right after JOIN ON (#390)
1 parent 8abc44d commit 5945c1b

File tree

4 files changed

+197
-110
lines changed

4 files changed

+197
-110
lines changed

crates/pgt_completions/src/context.rs

+100-78
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ use pgt_treesitter_queries::{
99
use crate::sanitization::SanitizedCompletionParams;
1010

1111
#[derive(Debug, PartialEq, Eq)]
12-
pub enum ClauseType {
12+
pub enum WrappingClause<'a> {
1313
Select,
1414
Where,
1515
From,
16-
Join,
16+
Join {
17+
on_node: Option<tree_sitter::Node<'a>>,
18+
},
1719
Update,
1820
Delete,
1921
}
@@ -24,38 +26,6 @@ pub(crate) enum NodeText<'a> {
2426
Original(&'a str),
2527
}
2628

27-
impl TryFrom<&str> for ClauseType {
28-
type Error = String;
29-
30-
fn try_from(value: &str) -> Result<Self, Self::Error> {
31-
match value {
32-
"select" => Ok(Self::Select),
33-
"where" => Ok(Self::Where),
34-
"from" => Ok(Self::From),
35-
"update" => Ok(Self::Update),
36-
"delete" => Ok(Self::Delete),
37-
"join" => Ok(Self::Join),
38-
_ => {
39-
let message = format!("Unimplemented ClauseType: {}", value);
40-
41-
// Err on tests, so we notice that we're lacking an implementation immediately.
42-
if cfg!(test) {
43-
panic!("{}", message);
44-
}
45-
46-
Err(message)
47-
}
48-
}
49-
}
50-
}
51-
52-
impl TryFrom<String> for ClauseType {
53-
type Error = String;
54-
fn try_from(value: String) -> Result<Self, Self::Error> {
55-
Self::try_from(value.as_str())
56-
}
57-
}
58-
5929
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
6030
/// That gives us a lot of insight for completions.
6131
/// Other nodes, such as the "relation" node, gives us less but still
@@ -127,7 +97,7 @@ pub(crate) struct CompletionContext<'a> {
12797
/// on u.id = i.user_id;
12898
/// ```
12999
pub schema_or_alias_name: Option<String>,
130-
pub wrapping_clause_type: Option<ClauseType>,
100+
pub wrapping_clause_type: Option<WrappingClause<'a>>,
131101

132102
pub wrapping_node_kind: Option<WrappingNode>,
133103

@@ -266,7 +236,9 @@ impl<'a> CompletionContext<'a> {
266236

267237
match parent_node_kind {
268238
"statement" | "subquery" => {
269-
self.wrapping_clause_type = current_node_kind.try_into().ok();
239+
self.wrapping_clause_type =
240+
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);
241+
270242
self.wrapping_statement_range = Some(parent_node.range());
271243
}
272244
"invocation" => self.is_invocation = true,
@@ -277,39 +249,21 @@ impl<'a> CompletionContext<'a> {
277249
if self.is_in_error_node {
278250
let mut next_sibling = current_node.next_named_sibling();
279251
while let Some(n) = next_sibling {
280-
if n.kind().starts_with("keyword_") {
281-
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
282-
NodeText::Original(txt) => Some(txt),
283-
NodeText::Replaced => None,
284-
}) {
285-
match txt {
286-
"where" | "update" | "select" | "delete" | "from" | "join" => {
287-
self.wrapping_clause_type = txt.try_into().ok();
288-
break;
289-
}
290-
_ => {}
291-
}
292-
};
252+
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
253+
self.wrapping_clause_type = Some(clause_type);
254+
break;
255+
} else {
256+
next_sibling = n.next_named_sibling();
293257
}
294-
next_sibling = n.next_named_sibling();
295258
}
296259
let mut prev_sibling = current_node.prev_named_sibling();
297260
while let Some(n) = prev_sibling {
298-
if n.kind().starts_with("keyword_") {
299-
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
300-
NodeText::Original(txt) => Some(txt),
301-
NodeText::Replaced => None,
302-
}) {
303-
match txt {
304-
"where" | "update" | "select" | "delete" | "from" | "join" => {
305-
self.wrapping_clause_type = txt.try_into().ok();
306-
break;
307-
}
308-
_ => {}
309-
}
310-
};
261+
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
262+
self.wrapping_clause_type = Some(clause_type);
263+
break;
264+
} else {
265+
prev_sibling = n.prev_named_sibling();
311266
}
312-
prev_sibling = n.prev_named_sibling();
313267
}
314268
}
315269

@@ -330,7 +284,8 @@ impl<'a> CompletionContext<'a> {
330284
}
331285

332286
"where" | "update" | "select" | "delete" | "from" | "join" => {
333-
self.wrapping_clause_type = current_node_kind.try_into().ok();
287+
self.wrapping_clause_type =
288+
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);
334289
}
335290

336291
"relation" | "binary_expression" | "assignment" => {
@@ -353,12 +308,67 @@ impl<'a> CompletionContext<'a> {
353308
cursor.goto_first_child_for_byte(self.position);
354309
self.gather_context_from_node(cursor, current_node);
355310
}
311+
312+
fn get_wrapping_clause_from_keyword_node(
313+
&self,
314+
node: tree_sitter::Node<'a>,
315+
) -> Option<WrappingClause<'a>> {
316+
if node.kind().starts_with("keyword_") {
317+
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
318+
NodeText::Original(txt) => Some(txt),
319+
NodeText::Replaced => None,
320+
}) {
321+
match txt {
322+
"where" => return Some(WrappingClause::Where),
323+
"update" => return Some(WrappingClause::Update),
324+
"select" => return Some(WrappingClause::Select),
325+
"delete" => return Some(WrappingClause::Delete),
326+
"from" => return Some(WrappingClause::From),
327+
"join" => {
328+
// TODO: not sure if we can infer it here.
329+
return Some(WrappingClause::Join { on_node: None });
330+
}
331+
_ => {}
332+
}
333+
};
334+
}
335+
336+
None
337+
}
338+
339+
fn get_wrapping_clause_from_current_node(
340+
&self,
341+
node: tree_sitter::Node<'a>,
342+
cursor: &mut tree_sitter::TreeCursor<'a>,
343+
) -> Option<WrappingClause<'a>> {
344+
match node.kind() {
345+
"where" => Some(WrappingClause::Where),
346+
"update" => Some(WrappingClause::Update),
347+
"select" => Some(WrappingClause::Select),
348+
"delete" => Some(WrappingClause::Delete),
349+
"from" => Some(WrappingClause::From),
350+
"join" => {
351+
// sadly, we need to manually iterate over the children –
352+
// `node.child_by_field_id(..)` does not work as expected
353+
let mut on_node = None;
354+
for child in node.children(cursor) {
355+
// 28 is the id for "keyword_on"
356+
if child.kind_id() == 28 {
357+
on_node = Some(child);
358+
}
359+
}
360+
cursor.goto_parent();
361+
Some(WrappingClause::Join { on_node })
362+
}
363+
_ => None,
364+
}
365+
}
356366
}
357367

358368
#[cfg(test)]
359369
mod tests {
360370
use crate::{
361-
context::{ClauseType, CompletionContext, NodeText},
371+
context::{CompletionContext, NodeText, WrappingClause},
362372
sanitization::SanitizedCompletionParams,
363373
test_helper::{CURSOR_POS, get_text_and_position},
364374
};
@@ -375,29 +385,41 @@ mod tests {
375385
#[test]
376386
fn identifies_clauses() {
377387
let test_cases = vec![
378-
(format!("Select {}* from users;", CURSOR_POS), "select"),
379-
(format!("Select * from u{};", CURSOR_POS), "from"),
388+
(
389+
format!("Select {}* from users;", CURSOR_POS),
390+
WrappingClause::Select,
391+
),
392+
(
393+
format!("Select * from u{};", CURSOR_POS),
394+
WrappingClause::From,
395+
),
380396
(
381397
format!("Select {}* from users where n = 1;", CURSOR_POS),
382-
"select",
398+
WrappingClause::Select,
383399
),
384400
(
385401
format!("Select * from users where {}n = 1;", CURSOR_POS),
386-
"where",
402+
WrappingClause::Where,
387403
),
388404
(
389405
format!("update users set u{} = 1 where n = 2;", CURSOR_POS),
390-
"update",
406+
WrappingClause::Update,
391407
),
392408
(
393409
format!("update users set u = 1 where n{} = 2;", CURSOR_POS),
394-
"where",
410+
WrappingClause::Where,
411+
),
412+
(
413+
format!("delete{} from users;", CURSOR_POS),
414+
WrappingClause::Delete,
415+
),
416+
(
417+
format!("delete from {}users;", CURSOR_POS),
418+
WrappingClause::From,
395419
),
396-
(format!("delete{} from users;", CURSOR_POS), "delete"),
397-
(format!("delete from {}users;", CURSOR_POS), "from"),
398420
(
399421
format!("select name, age, location from public.u{}sers", CURSOR_POS),
400-
"from",
422+
WrappingClause::From,
401423
),
402424
];
403425

@@ -415,7 +437,7 @@ mod tests {
415437

416438
let ctx = CompletionContext::new(&params);
417439

418-
assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok());
440+
assert_eq!(ctx.wrapping_clause_type, Some(expected_clause));
419441
}
420442
}
421443

@@ -518,7 +540,7 @@ mod tests {
518540

519541
assert_eq!(
520542
ctx.wrapping_clause_type,
521-
Some(crate::context::ClauseType::Select)
543+
Some(crate::context::WrappingClause::Select)
522544
);
523545
}
524546
}
@@ -596,6 +618,6 @@ mod tests {
596618
ctx.get_ts_node_content(node),
597619
Some(NodeText::Original("fro"))
598620
);
599-
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
621+
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
600622
}
601623
}

crates/pgt_completions/src/providers/columns.rs

+53
Original file line numberDiff line numberDiff line change
@@ -431,4 +431,57 @@ mod tests {
431431
)
432432
.await;
433433
}
434+
435+
#[tokio::test]
436+
async fn completes_in_join_on_clause() {
437+
let setup = r#"
438+
create schema auth;
439+
440+
create table auth.users (
441+
uid serial primary key,
442+
name text not null,
443+
email text unique not null
444+
);
445+
446+
create table auth.posts (
447+
pid serial primary key,
448+
user_id int not null references auth.users(uid),
449+
title text not null,
450+
content text,
451+
created_at timestamp default now()
452+
);
453+
"#;
454+
455+
assert_complete_results(
456+
format!(
457+
"select u.id, auth.posts.content from auth.users u join auth.posts on u.{}",
458+
CURSOR_POS
459+
)
460+
.as_str(),
461+
vec![
462+
CompletionAssertion::KindNotExists(CompletionItemKind::Table),
463+
CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column),
464+
CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column),
465+
CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column),
466+
],
467+
setup,
468+
)
469+
.await;
470+
471+
assert_complete_results(
472+
format!(
473+
"select u.id, p.content from auth.users u join auth.posts p on p.user_id = u.{}",
474+
CURSOR_POS
475+
)
476+
.as_str(),
477+
vec![
478+
CompletionAssertion::KindNotExists(CompletionItemKind::Table),
479+
CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column),
480+
CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column),
481+
CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column),
482+
],
483+
setup,
484+
)
485+
.await;
486+
}
434487
}

crates/pgt_completions/src/relevance/filtering.rs

+16-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::context::{ClauseType, CompletionContext, WrappingNode};
1+
use crate::context::{CompletionContext, WrappingClause};
22

33
use super::CompletionRelevanceData;
44

@@ -50,31 +50,36 @@ impl CompletionFilter<'_> {
5050

5151
fn check_clause(&self, ctx: &CompletionContext) -> Option<()> {
5252
let clause = ctx.wrapping_clause_type.as_ref();
53-
let wrapping_node = ctx.wrapping_node_kind.as_ref();
5453

5554
match self.data {
5655
CompletionRelevanceData::Table(_) => {
57-
let in_select_clause = clause.is_some_and(|c| c == &ClauseType::Select);
58-
let in_where_clause = clause.is_some_and(|c| c == &ClauseType::Where);
56+
let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select);
57+
let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where);
5958

6059
if in_select_clause || in_where_clause {
6160
return None;
6261
};
6362
}
6463
CompletionRelevanceData::Column(_) => {
65-
let in_from_clause = clause.is_some_and(|c| c == &ClauseType::From);
64+
let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From);
6665
if in_from_clause {
6766
return None;
6867
}
6968

70-
// We can complete columns in JOIN cluases, but only if we are in the
71-
// "ON u.id = posts.user_id" part.
72-
let in_join_clause = clause.is_some_and(|c| c == &ClauseType::Join);
69+
// We can complete columns in JOIN cluases, but only if we are after the
70+
// ON node in the "ON u.id = posts.user_id" part.
71+
let in_join_clause_before_on_node = clause.is_some_and(|c| match c {
72+
// we are in a JOIN, but definitely not after an ON
73+
WrappingClause::Join { on_node: None } => true,
7374

74-
let in_comparison_clause =
75-
wrapping_node.is_some_and(|n| n == &WrappingNode::BinaryExpression);
75+
WrappingClause::Join { on_node: Some(on) } => ctx
76+
.node_under_cursor
77+
.is_some_and(|n| n.end_byte() < on.start_byte()),
7678

77-
if in_join_clause && !in_comparison_clause {
79+
_ => false,
80+
});
81+
82+
if in_join_clause_before_on_node {
7883
return None;
7984
}
8085
}

0 commit comments

Comments
 (0)