Skip to content

fix(completions): complete right columns right after JOIN ON #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 100 additions & 78 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ use pgt_treesitter_queries::{
use crate::sanitization::SanitizedCompletionParams;

#[derive(Debug, PartialEq, Eq)]
pub enum ClauseType {
pub enum WrappingClause<'a> {
Select,
Where,
From,
Join,
Join {
on_node: Option<tree_sitter::Node<'a>>,
},
Update,
Delete,
}
Expand All @@ -24,38 +26,6 @@ pub(crate) enum NodeText<'a> {
Original(&'a str),
}

impl TryFrom<&str> for ClauseType {
type Error = String;

fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"select" => Ok(Self::Select),
"where" => Ok(Self::Where),
"from" => Ok(Self::From),
"update" => Ok(Self::Update),
"delete" => Ok(Self::Delete),
"join" => Ok(Self::Join),
_ => {
let message = format!("Unimplemented ClauseType: {}", value);

// Err on tests, so we notice that we're lacking an implementation immediately.
if cfg!(test) {
panic!("{}", message);
}

Err(message)
}
}
}
}

impl TryFrom<String> for ClauseType {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}

/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
/// That gives us a lot of insight for completions.
/// Other nodes, such as the "relation" node, gives us less but still
Expand Down Expand Up @@ -127,7 +97,7 @@ pub(crate) struct CompletionContext<'a> {
/// on u.id = i.user_id;
/// ```
pub schema_or_alias_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub wrapping_clause_type: Option<WrappingClause<'a>>,

pub wrapping_node_kind: Option<WrappingNode>,

Expand Down Expand Up @@ -266,7 +236,9 @@ impl<'a> CompletionContext<'a> {

match parent_node_kind {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node_kind.try_into().ok();
self.wrapping_clause_type =
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);

self.wrapping_statement_range = Some(parent_node.range());
}
"invocation" => self.is_invocation = true,
Expand All @@ -277,39 +249,21 @@ impl<'a> CompletionContext<'a> {
if self.is_in_error_node {
let mut next_sibling = current_node.next_named_sibling();
while let Some(n) = next_sibling {
if n.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = txt.try_into().ok();
break;
}
_ => {}
}
};
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
self.wrapping_clause_type = Some(clause_type);
break;
} else {
next_sibling = n.next_named_sibling();
}
next_sibling = n.next_named_sibling();
}
let mut prev_sibling = current_node.prev_named_sibling();
while let Some(n) = prev_sibling {
if n.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = txt.try_into().ok();
break;
}
_ => {}
}
};
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
self.wrapping_clause_type = Some(clause_type);
break;
} else {
prev_sibling = n.prev_named_sibling();
}
prev_sibling = n.prev_named_sibling();
}
}

Expand All @@ -330,7 +284,8 @@ impl<'a> CompletionContext<'a> {
}

"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = current_node_kind.try_into().ok();
self.wrapping_clause_type =
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);
}

"relation" | "binary_expression" | "assignment" => {
Expand All @@ -353,12 +308,67 @@ impl<'a> CompletionContext<'a> {
cursor.goto_first_child_for_byte(self.position);
self.gather_context_from_node(cursor, current_node);
}

fn get_wrapping_clause_from_keyword_node(
&self,
node: tree_sitter::Node<'a>,
) -> Option<WrappingClause<'a>> {
if node.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" => return Some(WrappingClause::Where),
"update" => return Some(WrappingClause::Update),
"select" => return Some(WrappingClause::Select),
"delete" => return Some(WrappingClause::Delete),
"from" => return Some(WrappingClause::From),
"join" => {
// TODO: not sure if we can infer it here.
return Some(WrappingClause::Join { on_node: None });
}
_ => {}
}
};
}

None
}

fn get_wrapping_clause_from_current_node(
&self,
node: tree_sitter::Node<'a>,
cursor: &mut tree_sitter::TreeCursor<'a>,
) -> Option<WrappingClause<'a>> {
match node.kind() {
"where" => Some(WrappingClause::Where),
"update" => Some(WrappingClause::Update),
"select" => Some(WrappingClause::Select),
"delete" => Some(WrappingClause::Delete),
"from" => Some(WrappingClause::From),
"join" => {
// sadly, we need to manually iterate over the children –
// `node.child_by_field_id(..)` does not work as expected
Comment on lines +351 to +352
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might wanna fix this upstream at some point

let mut on_node = None;
for child in node.children(cursor) {
// 28 is the id for "keyword_on"
if child.kind_id() == 28 {
on_node = Some(child);
}
}
cursor.goto_parent();
Some(WrappingClause::Join { on_node })
}
_ => None,
}
}
}

#[cfg(test)]
mod tests {
use crate::{
context::{ClauseType, CompletionContext, NodeText},
context::{CompletionContext, NodeText, WrappingClause},
sanitization::SanitizedCompletionParams,
test_helper::{CURSOR_POS, get_text_and_position},
};
Expand All @@ -375,29 +385,41 @@ mod tests {
#[test]
fn identifies_clauses() {
let test_cases = vec![
(format!("Select {}* from users;", CURSOR_POS), "select"),
(format!("Select * from u{};", CURSOR_POS), "from"),
(
format!("Select {}* from users;", CURSOR_POS),
WrappingClause::Select,
),
(
format!("Select * from u{};", CURSOR_POS),
WrappingClause::From,
),
(
format!("Select {}* from users where n = 1;", CURSOR_POS),
"select",
WrappingClause::Select,
),
(
format!("Select * from users where {}n = 1;", CURSOR_POS),
"where",
WrappingClause::Where,
),
(
format!("update users set u{} = 1 where n = 2;", CURSOR_POS),
"update",
WrappingClause::Update,
),
(
format!("update users set u = 1 where n{} = 2;", CURSOR_POS),
"where",
WrappingClause::Where,
),
(
format!("delete{} from users;", CURSOR_POS),
WrappingClause::Delete,
),
(
format!("delete from {}users;", CURSOR_POS),
WrappingClause::From,
),
(format!("delete{} from users;", CURSOR_POS), "delete"),
(format!("delete from {}users;", CURSOR_POS), "from"),
(
format!("select name, age, location from public.u{}sers", CURSOR_POS),
"from",
WrappingClause::From,
),
];

Expand All @@ -415,7 +437,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok());
assert_eq!(ctx.wrapping_clause_type, Some(expected_clause));
}
}

Expand Down Expand Up @@ -518,7 +540,7 @@ mod tests {

assert_eq!(
ctx.wrapping_clause_type,
Some(crate::context::ClauseType::Select)
Some(crate::context::WrappingClause::Select)
);
}
}
Expand Down Expand Up @@ -596,6 +618,6 @@ mod tests {
ctx.get_ts_node_content(node),
Some(NodeText::Original("fro"))
);
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
}
}
53 changes: 53 additions & 0 deletions crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,57 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn completes_in_join_on_clause() {
let setup = r#"
create schema auth;

create table auth.users (
uid serial primary key,
name text not null,
email text unique not null
);

create table auth.posts (
pid serial primary key,
user_id int not null references auth.users(uid),
title text not null,
content text,
created_at timestamp default now()
);
"#;

assert_complete_results(
format!(
"select u.id, auth.posts.content from auth.users u join auth.posts on u.{}",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::KindNotExists(CompletionItemKind::Table),
CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column),
CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column),
CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column),
],
setup,
)
.await;

assert_complete_results(
format!(
"select u.id, p.content from auth.users u join auth.posts p on p.user_id = u.{}",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::KindNotExists(CompletionItemKind::Table),
CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column),
CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column),
CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column),
],
setup,
)
.await;
}
}
27 changes: 16 additions & 11 deletions crates/pgt_completions/src/relevance/filtering.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::context::{ClauseType, CompletionContext, WrappingNode};
use crate::context::{CompletionContext, WrappingClause};

use super::CompletionRelevanceData;

Expand Down Expand Up @@ -50,31 +50,36 @@ impl CompletionFilter<'_> {

fn check_clause(&self, ctx: &CompletionContext) -> Option<()> {
let clause = ctx.wrapping_clause_type.as_ref();
let wrapping_node = ctx.wrapping_node_kind.as_ref();

match self.data {
CompletionRelevanceData::Table(_) => {
let in_select_clause = clause.is_some_and(|c| c == &ClauseType::Select);
let in_where_clause = clause.is_some_and(|c| c == &ClauseType::Where);
let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select);
let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where);

if in_select_clause || in_where_clause {
return None;
};
}
CompletionRelevanceData::Column(_) => {
let in_from_clause = clause.is_some_and(|c| c == &ClauseType::From);
let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From);
if in_from_clause {
return None;
}

// We can complete columns in JOIN cluases, but only if we are in the
// "ON u.id = posts.user_id" part.
let in_join_clause = clause.is_some_and(|c| c == &ClauseType::Join);
// We can complete columns in JOIN cluases, but only if we are after the
// ON node in the "ON u.id = posts.user_id" part.
let in_join_clause_before_on_node = clause.is_some_and(|c| match c {
// we are in a JOIN, but definitely not after an ON
WrappingClause::Join { on_node: None } => true,

let in_comparison_clause =
wrapping_node.is_some_and(|n| n == &WrappingNode::BinaryExpression);
WrappingClause::Join { on_node: Some(on) } => ctx
.node_under_cursor
.is_some_and(|n| n.end_byte() < on.start_byte()),

if in_join_clause && !in_comparison_clause {
_ => false,
});

if in_join_clause_before_on_node {
return None;
}
}
Expand Down
Loading
Loading