Skip to content

feat(completions): respect table aliases, complete in JOINs #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum ClauseType {
Select,
Where,
From,
Join,
Update,
Delete,
}
Expand All @@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType {
"from" => Ok(Self::From),
"update" => Ok(Self::Update),
"delete" => Ok(Self::Delete),
"join" => Ok(Self::Join),
_ => {
let message = format!("Unimplemented ClauseType: {}", value);

Expand Down Expand Up @@ -106,14 +108,35 @@ pub(crate) struct CompletionContext<'a> {
pub schema_cache: &'a SchemaCache,
pub position: usize,

pub schema_name: Option<String>,
/// If the cursor is on a node that uses dot notation
/// to specify an alias or schema, this will hold the schema's or
/// alias's name.
///
/// Here, `auth` is a schema name:
/// ```sql
/// select * from auth.users;
/// ```
///
/// Here, `u` is an alias name:
/// ```sql
/// select
/// *
/// from
/// auth.users u
/// left join identities i
/// on u.id = i.user_id;
/// ```
pub schema_or_alias_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,

pub wrapping_node_kind: Option<WrappingNode>,

pub is_invocation: bool,
pub wrapping_statement_range: Option<tree_sitter::Range>,

/// Some incomplete statements can't be correctly parsed by TreeSitter.
pub is_in_error_node: bool,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,

pub mentioned_table_aliases: HashMap<String, String>,
Expand All @@ -127,13 +150,14 @@ impl<'a> CompletionContext<'a> {
schema_cache: params.schema,
position: usize::from(params.position),
node_under_cursor: None,
schema_name: None,
schema_or_alias_name: None,
wrapping_clause_type: None,
wrapping_node_kind: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
mentioned_table_aliases: HashMap::new(),
is_in_error_node: false,
};

ctx.gather_tree_context();
Expand Down Expand Up @@ -246,34 +270,77 @@ impl<'a> CompletionContext<'a> {
self.wrapping_statement_range = Some(parent_node.range());
}
"invocation" => self.is_invocation = true,

_ => {}
}

// try to gather context from the siblings if we're within an error node.
if self.is_in_error_node {
let mut next_sibling = current_node.next_named_sibling();
while let Some(n) = next_sibling {
if n.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = txt.try_into().ok();
break;
}
_ => {}
}
};
}
next_sibling = n.next_named_sibling();
}
let mut prev_sibling = current_node.prev_named_sibling();
while let Some(n) = prev_sibling {
if n.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = txt.try_into().ok();
break;
}
_ => {}
}
};
}
prev_sibling = n.prev_named_sibling();
}
}

match current_node_kind {
"object_reference" => {
"object_reference" | "field" => {
let content = self.get_ts_node_content(current_node);
if let Some(node_txt) = content {
match node_txt {
NodeText::Original(txt) => {
let parts: Vec<&str> = txt.split('.').collect();
if parts.len() == 2 {
self.schema_name = Some(parts[0].to_string());
self.schema_or_alias_name = Some(parts[0].to_string());
}
}
NodeText::Replaced => {}
}
}
}

"where" | "update" | "select" | "delete" | "from" => {
"where" | "update" | "select" | "delete" | "from" | "join" => {
self.wrapping_clause_type = current_node_kind.try_into().ok();
}

"relation" | "binary_expression" | "assignment" => {
self.wrapping_node_kind = current_node_kind.try_into().ok();
}

"ERROR" => {
self.is_in_error_node = true;
}

_ => {}
}

Expand Down Expand Up @@ -380,7 +447,10 @@ mod tests {

let ctx = CompletionContext::new(&params);

assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string()));
assert_eq!(
ctx.schema_or_alias_name,
expected_schema.map(|f| f.to_string())
);
}
}

Expand Down
110 changes: 108 additions & 2 deletions crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
mod tests {
use crate::{
CompletionItem, CompletionItemKind, complete,
test_helper::{CURSOR_POS, InputQuery, get_test_deps, get_test_params},
test_helper::{
CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps,
get_test_params,
},
};

struct TestCase {
Expand Down Expand Up @@ -168,9 +171,9 @@ mod tests {
("name", "Table: public.users"),
("narrator", "Table: public.audio_books"),
("narrator_id", "Table: private.audio_books"),
("id", "Table: public.audio_books"),
("name", "Schema: pg_catalog"),
("nameconcatoid", "Schema: pg_catalog"),
("nameeq", "Schema: pg_catalog"),
]
.into_iter()
.map(|(label, schema)| LabelAndDesc {
Expand Down Expand Up @@ -325,4 +328,107 @@ mod tests {
);
}
}

#[tokio::test]
async fn filters_out_by_aliases() {
let setup = r#"
create schema auth;

create table auth.users (
uid serial primary key,
name text not null,
email text unique not null
);

create table auth.posts (
pid serial primary key,
user_id int not null references auth.users(uid),
title text not null,
content text,
created_at timestamp default now()
);
"#;

// test in SELECT clause
assert_complete_results(
format!(
"select u.id, p.{} from auth.users u join auth.posts p on u.id = p.user_id;",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::LabelNotExists("uid".to_string()),
CompletionAssertion::LabelNotExists("name".to_string()),
CompletionAssertion::LabelNotExists("email".to_string()),
CompletionAssertion::Label("content".to_string()),
CompletionAssertion::Label("created_at".to_string()),
CompletionAssertion::Label("pid".to_string()),
CompletionAssertion::Label("title".to_string()),
CompletionAssertion::Label("user_id".to_string()),
],
setup,
)
.await;

// test in JOIN clause
assert_complete_results(
format!(
"select u.id, p.content from auth.users u join auth.posts p on u.id = p.{};",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::LabelNotExists("uid".to_string()),
CompletionAssertion::LabelNotExists("name".to_string()),
CompletionAssertion::LabelNotExists("email".to_string()),
// primary keys are preferred
CompletionAssertion::Label("pid".to_string()),
CompletionAssertion::Label("content".to_string()),
CompletionAssertion::Label("created_at".to_string()),
CompletionAssertion::Label("title".to_string()),
CompletionAssertion::Label("user_id".to_string()),
],
setup,
)
.await;
}

#[tokio::test]
async fn does_not_complete_cols_in_join_clauses() {
let setup = r#"
create schema auth;

create table auth.users (
uid serial primary key,
name text not null,
email text unique not null
);

create table auth.posts (
pid serial primary key,
user_id int not null references auth.users(uid),
title text not null,
content text,
created_at timestamp default now()
);
"#;

/*
* We are not in the "ON" part of the JOIN clause, so we should not complete columns.
*/
assert_complete_results(
format!(
"select u.id, p.content from auth.users u join auth.{}",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::KindNotExists(CompletionItemKind::Column),
CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table),
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
],
setup,
)
.await;
}
}
2 changes: 1 addition & 1 deletion crates/pgt_completions/src/providers/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub(crate) fn get_completion_text_with_schema(
item_name: &str,
item_schema_name: &str,
) -> Option<CompletionText> {
if item_schema_name == "public" || ctx.schema_name.is_some() {
if item_schema_name == "public" || ctx.schema_or_alias_name.is_some() {
None
} else {
let node = ctx.node_under_cursor.unwrap();
Expand Down
3 changes: 2 additions & 1 deletion crates/pgt_completions/src/providers/schemas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ mod tests {
"private".to_string(),
CompletionItemKind::Schema,
),
// users table still preferred over system schemas
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
CompletionAssertion::LabelAndKind(
"information_schema".to_string(),
CompletionItemKind::Schema,
Expand All @@ -71,7 +73,6 @@ mod tests {
"pg_toast".to_string(),
CompletionItemKind::Schema,
),
CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table),
],
setup,
)
Expand Down
33 changes: 33 additions & 0 deletions crates/pgt_completions/src/providers/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,37 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn suggests_tables_in_join() {
let setup = r#"
create schema auth;

create table auth.users (
uid serial primary key,
name text not null,
email text unique not null
);

create table auth.posts (
pid serial primary key,
user_id int not null references auth.users(uid),
title text not null,
content text,
created_at timestamp default now()
);
"#;

assert_complete_results(
format!("select * from auth.users u join {}", CURSOR_POS).as_str(),
vec![
CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema),
CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema),
CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join
CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table),
],
setup,
)
.await;
}
}
Loading