Skip to content

Commit dd4ad27

Browse files
ok
1 parent dfd40e7 commit dd4ad27

File tree

7 files changed

+247
-23
lines changed

7 files changed

+247
-23
lines changed

crates/pgt_completions/src/context.rs

+62-16
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl TryFrom<&str> for ClauseType {
3030
match value {
3131
"select" => Ok(Self::Select),
3232
"where" => Ok(Self::Where),
33-
"from" | "keyword_from" => Ok(Self::From),
33+
"from" => Ok(Self::From),
3434
"update" => Ok(Self::Update),
3535
"delete" => Ok(Self::Delete),
3636
_ => {
@@ -49,8 +49,52 @@ impl TryFrom<&str> for ClauseType {
4949

5050
impl TryFrom<String> for ClauseType {
5151
type Error = String;
52-
fn try_from(value: String) -> Result<ClauseType, Self::Error> {
53-
ClauseType::try_from(value.as_str())
52+
fn try_from(value: String) -> Result<Self, Self::Error> {
53+
Self::try_from(value.as_str())
54+
}
55+
}
56+
57+
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
58+
/// That gives us a lot of insight for completions.
59+
/// Other nodes, such as the "relation" node, gives us less but still
60+
/// relevant information.
61+
/// `WrappingNode` maps to such nodes.
62+
///
63+
/// Note: This is not the direct parent of the `node_under_cursor`, but the closest
64+
/// *relevant* parent.
65+
#[derive(Debug, PartialEq, Eq)]
66+
pub enum WrappingNode {
67+
Relation,
68+
BinaryExpression,
69+
Assignment,
70+
}
71+
72+
impl TryFrom<&str> for WrappingNode {
73+
type Error = String;
74+
75+
fn try_from(value: &str) -> Result<Self, Self::Error> {
76+
match value {
77+
"relation" => Ok(Self::Relation),
78+
"assignment" => Ok(Self::Assignment),
79+
"binary_expression" => Ok(Self::BinaryExpression),
80+
_ => {
81+
let message = format!("Unimplemented Relation: {}", value);
82+
83+
// Err on tests, so we notice that we're lacking an implementation immediately.
84+
if cfg!(test) {
85+
panic!("{}", message);
86+
}
87+
88+
Err(message)
89+
}
90+
}
91+
}
92+
}
93+
94+
impl TryFrom<String> for WrappingNode {
95+
type Error = String;
96+
fn try_from(value: String) -> Result<Self, Self::Error> {
97+
Self::try_from(value.as_str())
5498
}
5599
}
56100

@@ -64,6 +108,9 @@ pub(crate) struct CompletionContext<'a> {
64108

65109
pub schema_name: Option<String>,
66110
pub wrapping_clause_type: Option<ClauseType>,
111+
112+
pub wrapping_node_kind: Option<WrappingNode>,
113+
67114
pub is_invocation: bool,
68115
pub wrapping_statement_range: Option<tree_sitter::Range>,
69116

@@ -80,6 +127,7 @@ impl<'a> CompletionContext<'a> {
80127
node_under_cursor: None,
81128
schema_name: None,
82129
wrapping_clause_type: None,
130+
wrapping_node_kind: None,
83131
wrapping_statement_range: None,
84132
is_invocation: false,
85133
mentioned_relations: HashMap::new(),
@@ -163,23 +211,26 @@ impl<'a> CompletionContext<'a> {
163211
) {
164212
let current_node = cursor.node();
165213

214+
let parent_node_kind = parent_node.kind();
215+
let current_node_kind = current_node.kind();
216+
166217
// prevent infinite recursion – this can happen if we only have a PROGRAM node
167-
if current_node.kind() == parent_node.kind() {
218+
if current_node_kind == parent_node_kind {
168219
self.node_under_cursor = Some(current_node);
169220
return;
170221
}
171222

172-
match parent_node.kind() {
223+
match parent_node_kind {
173224
"statement" | "subquery" => {
174-
self.wrapping_clause_type = current_node.kind().try_into().ok();
225+
self.wrapping_clause_type = current_node_kind.try_into().ok();
175226
self.wrapping_statement_range = Some(parent_node.range());
176227
}
177228
"invocation" => self.is_invocation = true,
178229

179230
_ => {}
180231
}
181232

182-
match current_node.kind() {
233+
match current_node_kind {
183234
"object_reference" => {
184235
let content = self.get_ts_node_content(current_node);
185236
if let Some(node_txt) = content {
@@ -195,13 +246,12 @@ impl<'a> CompletionContext<'a> {
195246
}
196247
}
197248

198-
// in Treesitter, the Where clause is nested inside other clauses
199-
"where" => {
200-
self.wrapping_clause_type = "where".try_into().ok();
249+
"where" | "update" | "select" | "delete" | "from" => {
250+
self.wrapping_clause_type = current_node_kind.try_into().ok();
201251
}
202252

203-
"keyword_from" => {
204-
self.wrapping_clause_type = "keyword_from".try_into().ok();
253+
"relation" | "binary_expression" | "assignment" => {
254+
self.wrapping_node_kind = current_node_kind.try_into().ok();
205255
}
206256

207257
_ => {}
@@ -406,10 +456,6 @@ mod tests {
406456
ctx.get_ts_node_content(node),
407457
Some(NodeText::Original("from"))
408458
);
409-
assert_eq!(
410-
ctx.wrapping_clause_type,
411-
Some(crate::context::ClauseType::From)
412-
);
413459
}
414460

415461
#[test]

crates/pgt_completions/src/providers/helper.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ pub(crate) fn get_completion_text_with_schema(
77
item_name: &str,
88
item_schema_name: &str,
99
) -> Option<CompletionText> {
10-
if item_schema_name == "public" {
11-
None
12-
} else if ctx.schema_name.is_some() {
10+
if item_schema_name == "public" || ctx.schema_name.is_some() {
1311
None
1412
} else {
1513
let node = ctx.node_under_cursor.unwrap();

crates/pgt_completions/src/providers/tables.rs

+96-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ mod tests {
3131

3232
use crate::{
3333
CompletionItem, CompletionItemKind, complete,
34-
test_helper::{CURSOR_POS, get_test_deps, get_test_params},
34+
test_helper::{
35+
CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results,
36+
get_test_deps, get_test_params,
37+
},
3538
};
3639

3740
#[tokio::test]
@@ -178,4 +181,96 @@ mod tests {
178181
assert_eq!(label, "coos");
179182
assert_eq!(kind, CompletionItemKind::Table);
180183
}
184+
185+
#[tokio::test]
186+
async fn suggests_tables_in_update() {
187+
let setup = r#"
188+
create table coos (
189+
id serial primary key,
190+
name text
191+
);
192+
"#;
193+
194+
assert_complete_results(
195+
format!("update {}", CURSOR_POS).as_str(),
196+
vec![CompletionAssertion::LabelAndKind(
197+
"public".into(),
198+
CompletionItemKind::Schema,
199+
)],
200+
setup,
201+
)
202+
.await;
203+
204+
assert_complete_results(
205+
format!("update public.{}", CURSOR_POS).as_str(),
206+
vec![CompletionAssertion::LabelAndKind(
207+
"coos".into(),
208+
CompletionItemKind::Table,
209+
)],
210+
setup,
211+
)
212+
.await;
213+
214+
assert_no_complete_results(format!("update public.coos {}", CURSOR_POS).as_str(), setup)
215+
.await;
216+
217+
assert_complete_results(
218+
format!("update coos set {}", CURSOR_POS).as_str(),
219+
vec![
220+
CompletionAssertion::Label("id".into()),
221+
CompletionAssertion::Label("name".into()),
222+
],
223+
setup,
224+
)
225+
.await;
226+
227+
assert_complete_results(
228+
format!("update coos set name = 'cool' where {}", CURSOR_POS).as_str(),
229+
vec![
230+
CompletionAssertion::Label("id".into()),
231+
CompletionAssertion::Label("name".into()),
232+
],
233+
setup,
234+
)
235+
.await;
236+
}
237+
238+
#[tokio::test]
239+
async fn suggests_tables_in_delete() {
240+
let setup = r#"
241+
create table coos (
242+
id serial primary key,
243+
name text
244+
);
245+
"#;
246+
247+
assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), setup).await;
248+
249+
assert_complete_results(
250+
format!("delete from {}", CURSOR_POS).as_str(),
251+
vec![
252+
CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema),
253+
CompletionAssertion::LabelAndKind("coos".into(), CompletionItemKind::Table),
254+
],
255+
setup,
256+
)
257+
.await;
258+
259+
assert_complete_results(
260+
format!("delete from public.{}", CURSOR_POS).as_str(),
261+
vec![CompletionAssertion::Label("coos".into())],
262+
setup,
263+
)
264+
.await;
265+
266+
assert_complete_results(
267+
format!("delete from public.coos where {}", CURSOR_POS).as_str(),
268+
vec![
269+
CompletionAssertion::Label("id".into()),
270+
CompletionAssertion::Label("name".into()),
271+
],
272+
setup,
273+
)
274+
.await;
275+
}
181276
}

crates/pgt_completions/src/relevance/filtering.rs

+10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ impl CompletionFilter<'_> {
3535
return None;
3636
}
3737

38+
// No autocompetions if there are two identifiers without a separator.
39+
if ctx.node_under_cursor.is_some_and(|n| {
40+
n.prev_sibling().is_some_and(|p| {
41+
(p.kind() == "identifier" || p.kind() == "object_reference")
42+
&& n.kind() == "identifier"
43+
})
44+
}) {
45+
return None;
46+
}
47+
3848
Some(())
3949
}
4050

crates/pgt_completions/src/relevance/scoring.rs

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::context::{ClauseType, CompletionContext, NodeText};
1+
use crate::context::{ClauseType, CompletionContext, NodeText, WrappingNode};
22

33
use super::CompletionRelevanceData;
44

@@ -28,6 +28,7 @@ impl CompletionScore<'_> {
2828
self.check_matches_query_input(ctx);
2929
self.check_is_invocation(ctx);
3030
self.check_matching_clause_type(ctx);
31+
self.check_matching_wrapping_node(ctx);
3132
self.check_relations_in_stmt(ctx);
3233
}
3334

@@ -96,6 +97,36 @@ impl CompletionScore<'_> {
9697
}
9798
}
9899

100+
fn check_matching_wrapping_node(&mut self, ctx: &CompletionContext) {
101+
let wrapping_node = match ctx.wrapping_node_kind.as_ref() {
102+
None => return,
103+
Some(wn) => wn,
104+
};
105+
106+
let has_schema = ctx.schema_name.is_some();
107+
108+
self.score += match self.data {
109+
CompletionRelevanceData::Table(_) => match wrapping_node {
110+
WrappingNode::Relation => 15,
111+
WrappingNode::BinaryExpression => 5,
112+
_ => -50,
113+
},
114+
CompletionRelevanceData::Function(_) => match wrapping_node {
115+
WrappingNode::Relation => 10,
116+
_ => -50,
117+
},
118+
CompletionRelevanceData::Column(_) => match wrapping_node {
119+
WrappingNode::BinaryExpression => 15,
120+
WrappingNode::Assignment => 15,
121+
_ => -15,
122+
},
123+
CompletionRelevanceData::Schema(_) => match wrapping_node {
124+
WrappingNode::Relation if !has_schema => 5,
125+
_ => -50,
126+
},
127+
}
128+
}
129+
99130
fn check_is_invocation(&mut self, ctx: &CompletionContext) {
100131
self.score += match self.data {
101132
CompletionRelevanceData::Function(_) if ctx.is_invocation => 30,

crates/pgt_completions/src/test_helper.rs

+46-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use pgt_schema_cache::SchemaCache;
22
use pgt_test_utils::test_database::get_new_test_db;
33
use sqlx::Executor;
44

5-
use crate::CompletionParams;
5+
use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete};
66

77
pub static CURSOR_POS: char = '€';
88

@@ -139,3 +139,48 @@ mod tests {
139139
}
140140
}
141141
}
142+
143+
#[derive(Debug, PartialEq, Eq)]
144+
pub(crate) enum CompletionAssertion {
145+
Label(String),
146+
LabelAndKind(String, CompletionItemKind),
147+
}
148+
149+
impl CompletionAssertion {
150+
fn assert_eq(self, item: CompletionItem) {
151+
match self {
152+
CompletionAssertion::Label(label) => {
153+
assert_eq!(item.label, label);
154+
}
155+
CompletionAssertion::LabelAndKind(label, kind) => {
156+
assert_eq!(item.label, label);
157+
assert_eq!(item.kind, kind);
158+
}
159+
}
160+
}
161+
}
162+
163+
pub(crate) async fn assert_complete_results(
164+
query: &str,
165+
assertions: Vec<CompletionAssertion>,
166+
setup: &str,
167+
) {
168+
let (tree, cache) = get_test_deps(setup, query.into()).await;
169+
let params = get_test_params(&tree, &cache, query.into());
170+
let items = complete(params);
171+
172+
assertions
173+
.into_iter()
174+
.zip(items.into_iter())
175+
.for_each(|(assertion, result)| {
176+
assertion.assert_eq(result);
177+
});
178+
}
179+
180+
pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) {
181+
let (tree, cache) = get_test_deps(setup, query.into()).await;
182+
let params = get_test_params(&tree, &cache, query.into());
183+
let items = complete(params);
184+
185+
assert_eq!(items.len(), 0)
186+
}

crates/pgt_workspace/src/features/code_actions.rs

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ pub struct CommandAction {
4646

4747
#[derive(Debug, serde::Serialize, serde::Deserialize, strum::EnumIter)]
4848
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
49-
5049
pub enum CommandActionCategory {
5150
ExecuteStatement(StatementId),
5251
}

0 commit comments

Comments
 (0)