Skip to content

Commit bf3d96c

Browse files
arairga
1 parent 121431d commit bf3d96c

File tree

8 files changed

+110
-98
lines changed

8 files changed

+110
-98
lines changed

crates/pgt_completions/src/context/context.rs

-13
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ impl<'a> NodeUnderCursor<'a> {
6767
}
6868
}
6969

70-
pub fn range(&self) -> TextRange {
71-
let start: u32 = self.start_byte().try_into().unwrap();
72-
let end: u32 = self.end_byte().try_into().unwrap();
73-
74-
TextRange::new(start.into(), end.into())
75-
}
76-
7770
pub fn kind(&self) -> &str {
7871
match self {
7972
NodeUnderCursor::TsNode(node) => node.kind(),
@@ -216,12 +209,6 @@ impl<'a> CompletionContext<'a> {
216209
ctx.gather_info_from_ts_queries();
217210
}
218211

219-
tracing::warn!("SQL: {}", ctx.text);
220-
tracing::warn!("Position: {}", ctx.position);
221-
tracing::warn!("Node: {:#?}", ctx.node_under_cursor);
222-
tracing::warn!("Relations: {:#?}", ctx.mentioned_relations);
223-
tracing::warn!("Clause: {:#?}", ctx.wrapping_clause_type);
224-
225212
ctx
226213
}
227214

crates/pgt_completions/src/context/policy_parser.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,11 @@ pub(crate) struct PolicyParser {
110110

111111
impl PolicyParser {
112112
pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext {
113+
let trimmed = sql.trim();
113114
assert!(
114-
sql.starts_with("create policy")
115-
|| sql.starts_with("drop policy")
116-
|| sql.starts_with("alter policy"),
115+
trimmed.starts_with("create policy")
116+
|| trimmed.starts_with("drop policy")
117+
|| trimmed.starts_with("alter policy"),
117118
"PolicyParser should only be used for policy statements. Developer error!"
118119
);
119120

crates/pgt_completions/src/providers/policies.rs

+57-10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use super::helper::get_range_to_replace;
1010
pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) {
1111
let available_policies = &ctx.schema_cache.policies;
1212

13+
let has_quotes = ctx
14+
.get_node_under_cursor_content()
15+
.is_some_and(|c| c.starts_with('"') && c.ends_with('"'));
16+
1317
for pol in available_policies {
1418
let relevance = CompletionRelevanceData::Policy(pol);
1519

@@ -19,10 +23,14 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi
1923
filter: CompletionFilter::from(relevance),
2024
description: format!("{}", pol.table_name),
2125
kind: CompletionItemKind::Policy,
22-
completion_text: Some(CompletionText {
23-
text: format!("\"{}\"", pol.name),
24-
range: get_range_to_replace(ctx),
25-
}),
26+
completion_text: if !has_quotes {
27+
Some(CompletionText {
28+
text: format!("\"{}\"", pol.name),
29+
range: get_range_to_replace(ctx),
30+
})
31+
} else {
32+
None
33+
},
2634
};
2735

2836
builder.add_item(item);
@@ -31,30 +39,69 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi
3139

3240
#[cfg(test)]
3341
mod tests {
34-
use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results};
42+
use crate::{
43+
complete,
44+
test_helper::{
45+
CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_params,
46+
test_against_connection_string,
47+
},
48+
};
3549

3650
#[tokio::test]
3751
async fn completes_within_quotation_marks() {
3852
let setup = r#"
39-
create table users (
53+
create schema private;
54+
55+
create table private.users (
4056
id serial primary key,
4157
email text
4258
);
4359
44-
create policy "should never have access" on users
60+
create policy "read for public users disallowed" on private.users
4561
as restrictive
46-
for all
62+
for select
4763
to public
4864
using (false);
65+
66+
create policy "write for public users allowed" on private.users
67+
as restrictive
68+
for insert
69+
to public
70+
with check (true);
4971
"#;
5072

5173
assert_complete_results(
52-
format!("alter policy \"{}\" on users;", CURSOR_POS).as_str(),
74+
format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(),
75+
vec![
76+
CompletionAssertion::Label("read for public users disallowed".into()),
77+
CompletionAssertion::Label("write for public users allowed".into()),
78+
],
79+
setup,
80+
)
81+
.await;
82+
83+
assert_complete_results(
84+
format!("alter policy \"w{}\" on private.users;", CURSOR_POS).as_str(),
5385
vec![CompletionAssertion::Label(
54-
"should never have access".into(),
86+
"write for public users allowed".into(),
5587
)],
5688
setup,
5789
)
5890
.await;
5991
}
92+
93+
#[tokio::test]
94+
async fn sb_test() {
95+
let input = format!("alter policy \"u{}\" on public.fcm_tokens;", CURSOR_POS);
96+
97+
let (tree, cache) = test_against_connection_string(
98+
"postgresql://postgres:[email protected]:54322/postgres",
99+
input.as_str().into(),
100+
)
101+
.await;
102+
103+
let result = complete(get_test_params(&tree, &cache, input.as_str().into()));
104+
105+
println!("{:#?}", result);
106+
}
60107
}

crates/pgt_completions/src/relevance/filtering.rs

+5-6
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,8 @@ impl CompletionFilter<'_> {
7171

7272
match self.data {
7373
CompletionRelevanceData::Table(_) => {
74-
if in_clause(WrappingClause::Select)
75-
|| in_clause(WrappingClause::Where)
76-
|| in_clause(WrappingClause::PolicyName)
74+
if in_clause(WrappingClause::Select) || in_clause(WrappingClause::Where)
75+
// || in_clause(WrappingClause::PolicyName)
7776
{
7877
return None;
7978
};
@@ -107,9 +106,9 @@ impl CompletionFilter<'_> {
107106
}
108107
}
109108
_ => {
110-
if in_clause(WrappingClause::PolicyName) {
111-
return None;
112-
}
109+
// if in_clause(WrappingClause::PolicyName) {
110+
// return None;
111+
// }
113112
}
114113
}
115114

crates/pgt_completions/src/relevance/scoring.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,23 @@ impl CompletionScore<'_> {
3636

3737
fn check_matches_query_input(&mut self, ctx: &CompletionContext) {
3838
let content = match ctx.get_node_under_cursor_content() {
39-
Some(c) => c,
39+
Some(c) => c.replace('"', ""),
4040
None => return,
4141
};
4242

4343
let name = match self.data {
44-
CompletionRelevanceData::Function(f) => f.name.as_str(),
45-
CompletionRelevanceData::Table(t) => t.name.as_str(),
46-
CompletionRelevanceData::Column(c) => c.name.as_str(),
47-
CompletionRelevanceData::Schema(s) => s.name.as_str(),
48-
CompletionRelevanceData::Policy(p) => p.name.as_str(),
44+
CompletionRelevanceData::Function(f) => f.name.as_str().to_ascii_lowercase(),
45+
CompletionRelevanceData::Table(t) => t.name.as_str().to_ascii_lowercase(),
46+
CompletionRelevanceData::Column(c) => c.name.as_str().to_ascii_lowercase(),
47+
CompletionRelevanceData::Schema(s) => s.name.as_str().to_ascii_lowercase(),
48+
CompletionRelevanceData::Policy(p) => p.name.as_str().to_ascii_lowercase(),
4949
};
5050

5151
let fz_matcher = SkimMatcherV2::default();
5252

53-
if let Some(score) = fz_matcher.fuzzy_match(name, content.as_str()) {
53+
if let Some(score) =
54+
fz_matcher.fuzzy_match(name.as_str(), content.to_ascii_lowercase().as_str())
55+
{
5456
let scorei32: i32 = score
5557
.try_into()
5658
.expect("The length of the input exceeds i32 capacity");

crates/pgt_completions/src/sanitization.rs

+26-57
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ where
4545
'larger: 'smaller,
4646
{
4747
fn from(params: CompletionParams<'larger>) -> Self {
48-
if cursor_inbetween_nodes(params.tree, params.position)
49-
|| cursor_prepared_to_write_token_after_last_node(params.tree, params.position)
48+
if cursor_inbetween_nodes(&params.text, params.position)
49+
|| cursor_prepared_to_write_token_after_last_node(&params.text, params.position)
5050
|| cursor_before_semicolon(params.tree, params.position)
5151
|| cursor_on_a_dot(&params.text, params.position)
5252
|| cursor_between_double_quotes(&params.text, params.position)
@@ -125,37 +125,17 @@ where
125125
/// select |from users; -- cursor "touches" from node. returns false.
126126
/// select | from users; -- cursor is between select and from nodes. returns true.
127127
/// ```
128-
fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool {
129-
let mut cursor = tree.walk();
130-
let mut leaf_node = tree.root_node();
131-
132-
let byte = position.into();
133-
134-
// if the cursor escapes the root node, it can't be between nodes.
135-
if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() {
136-
return false;
137-
}
128+
fn cursor_inbetween_nodes(sql: &str, position: TextSize) -> bool {
129+
let position: usize = position.into();
130+
let mut chars = sql.chars();
138131

139-
/*
140-
* Get closer and closer to the leaf node, until
141-
* a) there is no more child *for the node* or
142-
* b) there is no more child *under the cursor*.
143-
*/
144-
loop {
145-
let child_idx = cursor.goto_first_child_for_byte(position.into());
146-
if child_idx.is_none() {
147-
break;
148-
}
149-
leaf_node = cursor.node();
150-
}
132+
let previous_whitespace = chars
133+
.nth(position - 1)
134+
.is_some_and(|c| c.is_ascii_whitespace());
151135

152-
let cursor_on_leafnode = byte >= leaf_node.start_byte() && leaf_node.end_byte() >= byte;
136+
let current_whitespace = chars.next().is_some_and(|c| c.is_ascii_whitespace());
153137

154-
/*
155-
* The cursor is inbetween nodes if it is not within the range
156-
* of a leaf node.
157-
*/
158-
!cursor_on_leafnode
138+
previous_whitespace && current_whitespace
159139
}
160140

161141
/// Checks if the cursor is positioned after the last node,
@@ -166,12 +146,9 @@ fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool
166146
/// select * from| -- user still needs to type a space
167147
/// select * from | -- too far off.
168148
/// ```
169-
fn cursor_prepared_to_write_token_after_last_node(
170-
tree: &tree_sitter::Tree,
171-
position: TextSize,
172-
) -> bool {
149+
fn cursor_prepared_to_write_token_after_last_node(sql: &str, position: TextSize) -> bool {
173150
let cursor_pos: usize = position.into();
174-
cursor_pos == tree.root_node().end_byte() + 1
151+
cursor_pos == sql.len() + 1
175152
}
176153

177154
fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool {
@@ -243,58 +220,44 @@ mod tests {
243220
// note: two spaces between select and from.
244221
let input = "select from users;";
245222

246-
let mut parser = tree_sitter::Parser::new();
247-
parser
248-
.set_language(tree_sitter_sql::language())
249-
.expect("Error loading sql language");
250-
251-
let tree = parser.parse(input, None).unwrap();
252-
253223
// select | from users; <-- just right, one space after select token, one space before from
254-
assert!(cursor_inbetween_nodes(&tree, TextSize::new(7)));
224+
assert!(cursor_inbetween_nodes(input, TextSize::new(7)));
255225

256226
// select| from users; <-- still on select token
257-
assert!(!cursor_inbetween_nodes(&tree, TextSize::new(6)));
227+
assert!(!cursor_inbetween_nodes(input, TextSize::new(6)));
258228

259229
// select |from users; <-- already on from token
260-
assert!(!cursor_inbetween_nodes(&tree, TextSize::new(8)));
230+
assert!(!cursor_inbetween_nodes(input, TextSize::new(8)));
261231

262232
// select from users;|
263-
assert!(!cursor_inbetween_nodes(&tree, TextSize::new(19)));
233+
assert!(!cursor_inbetween_nodes(input, TextSize::new(19)));
264234
}
265235

266236
#[test]
267237
fn test_cursor_after_nodes() {
268238
let input = "select * from";
269239

270-
let mut parser = tree_sitter::Parser::new();
271-
parser
272-
.set_language(tree_sitter_sql::language())
273-
.expect("Error loading sql language");
274-
275-
let tree = parser.parse(input, None).unwrap();
276-
277240
// select * from| <-- still on previous token
278241
assert!(!cursor_prepared_to_write_token_after_last_node(
279-
&tree,
242+
input,
280243
TextSize::new(13)
281244
));
282245

283246
// select * from | <-- too far off, two spaces afterward
284247
assert!(!cursor_prepared_to_write_token_after_last_node(
285-
&tree,
248+
input,
286249
TextSize::new(15)
287250
));
288251

289252
// select * |from <-- it's within
290253
assert!(!cursor_prepared_to_write_token_after_last_node(
291-
&tree,
254+
input,
292255
TextSize::new(9)
293256
));
294257

295258
// select * from | <-- just right
296259
assert!(cursor_prepared_to_write_token_after_last_node(
297-
&tree,
260+
input,
298261
TextSize::new(14)
299262
));
300263
}
@@ -353,5 +316,11 @@ mod tests {
353316

354317
// select * from "|" <-- between quotations
355318
assert!(cursor_between_double_quotes(input, TextSize::new(15)));
319+
320+
// select * from "r|" <-- between quotations, but there's
321+
// a letter inside
322+
let input = "select * from \"r\"";
323+
324+
assert!(!cursor_between_double_quotes(input, TextSize::new(16)));
356325
}
357326
}

crates/pgt_lsp/src/handlers/completions.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ pub fn get_completions(
5454
})
5555
.collect();
5656

57+
tracing::warn!("{:#?}", items);
58+
5759
Ok(lsp_types::CompletionResponse::Array(items))
5860
}
5961

@@ -65,6 +67,6 @@ fn to_lsp_types_completion_item_kind(
6567
pgt_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS,
6668
pgt_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD,
6769
pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS,
68-
pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::VALUE,
70+
pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::CONSTANT,
6971
}
7072
}

crates/pgt_workspace/src/workspace/server.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -481,15 +481,20 @@ impl Workspace for WorkspaceServer {
481481
Some(pool) => pool,
482482
None => {
483483
tracing::debug!("No connection to database. Skipping completions.");
484+
tracing::warn!("No connection to database.");
484485
return Ok(CompletionsResult::default());
485486
}
486487
};
487488

488489
let schema_cache = self.schema_cache.load(pool)?;
489490

490491
match get_statement_for_completions(&parsed_doc, params.position) {
491-
None => Ok(CompletionsResult::default()),
492+
None => {
493+
tracing::warn!("No statement found.");
494+
Ok(CompletionsResult::default())
495+
}
492496
Some((_id, range, content, cst)) => {
497+
tracing::warn!("found matching statement, content: {}", content);
493498
let position = params.position - range.start();
494499

495500
let items = pgt_completions::complete(pgt_completions::CompletionParams {

0 commit comments

Comments
 (0)