From 34a9687245d77f8bd024090d904cbe01b79ade9b Mon Sep 17 00:00:00 2001 From: psteinroe Date: Wed, 9 Apr 2025 11:38:21 +0200 Subject: [PATCH 01/20] refactor: parser --- crates/pgt_query_ext/src/lib.rs | 45 ++++ crates/pgt_workspace/src/workspace/server.rs | 104 ++++++++- .../src/workspace/server/TODO.md | 137 +++++++++++ .../src/workspace/server/change.rs | 98 +++----- .../src/workspace/server/document.rs | 120 +++------- .../src/workspace/server/parser.rs | 214 ++++++++++++++++++ .../src/workspace/server/pg_query.rs | 50 ++-- .../src/workspace/server/sql_function.rs | 111 +++++++++ .../workspace/server/statement_identifier.rs | 41 ++++ .../src/workspace/server/tree_sitter.rs | 36 +-- 10 files changed, 752 insertions(+), 204 deletions(-) create mode 100644 crates/pgt_workspace/src/workspace/server/TODO.md create mode 100644 crates/pgt_workspace/src/workspace/server/parser.rs create mode 100644 crates/pgt_workspace/src/workspace/server/sql_function.rs create mode 100644 crates/pgt_workspace/src/workspace/server/statement_identifier.rs diff --git a/crates/pgt_query_ext/src/lib.rs b/crates/pgt_query_ext/src/lib.rs index c1f5fb49..32eab8ef 100644 --- a/crates/pgt_query_ext/src/lib.rs +++ b/crates/pgt_query_ext/src/lib.rs @@ -30,3 +30,48 @@ pub fn parse(sql: &str) -> Result { .ok_or_else(|| Error::Parse("Unable to find root node".to_string())) })? 
} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sql_1() { + let input = "CREATE FUNCTION add(integer, integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + println!("{:#?}", parse(input).unwrap()); + // print after 42 + println!("{:#?}", &input[42..]); + } + + #[test] + fn test_sql_2() { + let input = "CREATE FUNCTION add() RETURNS integer + AS $sql$select 1 + 2;$sql$ + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + println!("{:#?}", parse(input).unwrap()); + // print after 58 + println!("{:#?}", &input[58..]); + } + + #[test] + fn test_plpsql() { + let input = "CREATE FUNCTION add(integer, integer) RETURNS integer + AS $s$ +begin + return $1 + $2; +end +$s$ + LANGUAGE plpgsql + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + println!("{:#?}", parse(input).unwrap()); + // print after 58 + println!("{:#?}", &input[58..]); + } +} diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 8dcbfb1d..8f464b1d 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -1,4 +1,9 @@ -use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; +use std::{ + fs, + panic::RefUnwindSafe, + path::Path, + sync::{Arc, Mutex, RwLock}, +}; use analyser::AnalyserVisitorBuilder; use async_helper::run_async; @@ -15,9 +20,11 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as use pgt_fs::{ConfigName, PgTPath}; use pgt_typecheck::TypecheckParams; use schema_cache_manager::SchemaCacheManager; +use sql_function::SQLFunctionBodyStore; use sqlx::Executor; use tracing::info; use tree_sitter::TreeSitterStore; +use tree_sitter_parser::TreeSitterParserStore; use crate::{ WorkspaceError, @@ -44,8 +51,11 @@ mod change; mod db_connection; mod document; mod migration; +mod parser; mod pg_query; mod schema_cache_manager; +mod sql_function; +mod statement_identifier; mod 
tree_sitter; pub(super) struct WorkspaceServer { @@ -60,6 +70,8 @@ pub(super) struct WorkspaceServer { tree_sitter: TreeSitterStore, pg_query: PgQueryStore, + sql_functions: SQLFunctionBodyStore, + ts_parser: TreeSitterParserStore, connection: RwLock, } @@ -86,6 +98,8 @@ impl WorkspaceServer { pg_query: PgQueryStore::new(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), + sql_functions: SQLFunctionBodyStore::new(), + ts_parser: TreeSitterParserStore::new(), } } @@ -183,10 +197,17 @@ impl Workspace for WorkspaceServer { fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { let doc = Document::new(params.path.clone(), params.content, params.version); + let doc_parser = self.ts_parser.get_parser(params.path.clone()); + let mut parser = doc_parser.lock().expect("Error locking parser"); doc.iter_statements_with_text().for_each(|(stmt, content)| { - self.tree_sitter.add_statement(&stmt, content); + self.tree_sitter.add_statement(&mut parser, &stmt, content); self.pg_query.add_statement(&stmt, content); + if let Some(ast) = self.pg_query.get_ast(&stmt) { + self.sql_functions + .add_statement(&mut parser, &ast, &stmt, content); + } }); + drop(parser); self.documents.insert(params.path, doc); @@ -203,8 +224,11 @@ impl Workspace for WorkspaceServer { for stmt in doc.iter_statements() { self.tree_sitter.remove_statement(&stmt); self.pg_query.remove_statement(&stmt); + self.sql_functions.remove_statement(&stmt); } + self.ts_parser.remove_parser(¶ms.path); + Ok(()) } @@ -223,6 +247,8 @@ impl Workspace for WorkspaceServer { params.version, )); + let doc_parser = self.ts_parser.get_parser(params.path.clone()); + let mut parser = doc_parser.lock().expect("Error locking parser"); for c in &doc.apply_file_change(¶ms) { match c { StatementChange::Added(added) => { @@ -232,8 +258,17 @@ impl Workspace for WorkspaceServer { added.stmt.path.as_os_str().to_str(), added.text ); - self.tree_sitter.add_statement(&added.stmt, &added.text); + 
self.tree_sitter + .add_statement(&mut parser, &added.stmt, &added.text); self.pg_query.add_statement(&added.stmt, &added.text); + if let Some(ast) = self.pg_query.get_ast(&added.stmt) { + self.sql_functions.add_statement( + &mut parser, + &ast, + &added.stmt, + &added.text, + ); + } } StatementChange::Deleted(s) => { tracing::debug!( @@ -243,6 +278,7 @@ impl Workspace for WorkspaceServer { ); self.tree_sitter.remove_statement(s); self.pg_query.remove_statement(s); + self.sql_functions.remove_statement(s); } StatementChange::Modified(s) => { tracing::debug!( @@ -256,11 +292,15 @@ impl Workspace for WorkspaceServer { s.change_text ); - self.tree_sitter.modify_statement(s); + self.tree_sitter.modify_statement(&mut parser, s); self.pg_query.modify_statement(s); + if let Some(ast) = self.pg_query.get_ast(&s.new_stmt) { + self.sql_functions.modify_statement(&mut parser, &ast, s); + } } } } + drop(parser); Ok(()) } @@ -420,10 +460,25 @@ impl Workspace for WorkspaceServer { { let typecheck_params: Vec<_> = doc .iter_statements_with_text_and_range() - .map(|(stmt, range, text)| { + .flat_map(|(stmt, range, text)| { let ast = self.pg_query.get_ast(&stmt); let tree = self.tree_sitter.get_parse_tree(&stmt); - (text.to_string(), ast, tree, *range) + + let mut res = vec![(text.to_string(), ast, tree, *range)]; + + if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) { + // fn_body range is within the statement -> adjust it to be relative to the + // document instead (as the other ranges) + let fn_range = fn_body.range + range.start(); + res.push(( + fn_body.body.clone(), + fn_body.ast.clone().map(Arc::new), + Some(Arc::new(fn_body.cst.clone())), + fn_range, + )); + } + + res }) .collect(); @@ -479,6 +534,27 @@ impl Workspace for WorkspaceServer { ); } + if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) { + if let Some(fn_diag) = &fn_body.syntax_diagnostics { + stmt_diagnostics.push(SDiagnostic::new( + fn_diag + .clone() + 
.with_file_path(params.path.as_path().display().to_string()) + .with_file_span(fn_body.range + r.start()), + )); + } + + if let Some(ast) = &fn_body.ast { + stmt_diagnostics.extend( + analyser + .run(AnalyserContext { root: ast }) + .into_iter() + .map(SDiagnostic::new) + .collect::>(), + ); + } + } + stmt_diagnostics .into_iter() .map(|d| { @@ -559,13 +635,27 @@ impl Workspace for WorkspaceServer { let schema_cache = self.schema_cache.load(pool)?; - let items = pgt_completions::complete(pgt_completions::CompletionParams { + let mut items = pgt_completions::complete(pgt_completions::CompletionParams { position, schema: schema_cache.as_ref(), tree: tree.as_deref(), text: text.to_string(), }); + if let Some(f) = self.sql_functions.get_function_body(&statement) { + let fn_text = f.body.clone(); + let fn_range = f.range + stmt_range.start(); + + items.extend(pgt_completions::complete( + pgt_completions::CompletionParams { + position: position - fn_range.start(), + schema: schema_cache.as_ref(), + tree: Some(&f.cst), + text: fn_text, + }, + )); + } + Ok(CompletionsResult { items }) } } diff --git a/crates/pgt_workspace/src/workspace/server/TODO.md b/crates/pgt_workspace/src/workspace/server/TODO.md new file mode 100644 index 00000000..ff3da8ca --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/TODO.md @@ -0,0 +1,137 @@ +1. 
Statement Iterator + +```rust +pub struct StatementIterator<'a> { + document: &'a Document, + positions: std::slice::Iter<'a, StatementPos>, + include_text: bool, + include_range: bool, +} + +impl<'a> StatementIterator<'a> { + fn new(document: &'a Document) -> Self { + Self { + document, + positions: document.positions.iter(), + include_text: false, + include_range: false, + } + } + + pub fn with_text(mut self) -> Self { + self.include_text = true; + self + } + + pub fn with_range(mut self) -> Self { + self.include_range = true; + self + } +} + +pub enum StatementData<'a> { + Statement(Statement), + WithText(Statement, &'a str), + WithRange(Statement, &'a TextRange), + WithTextAndRange(Statement, &'a TextRange, &'a str), +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = StatementData<'a>; + + fn next(&mut self) -> Option { + self.positions.next().map(|(id, range)| { + let statement = Statement { + id: *id, + path: self.document.path.clone(), + }; + + match (self.include_text, self.include_range) { + (false, false) => StatementData::Statement(statement), + (true, false) => { + let text = &self.document.content[range.start().into()..range.end().into()]; + StatementData::WithText(statement, text) + }, + (false, true) => StatementData::WithRange(statement, range), + (true, true) => { + let text = &self.document.content[range.start().into()..range.end().into()]; + StatementData::WithTextAndRange(statement, range, text) + }, + } + }) + } +} +pub struct StatementIterator<'a> { + document: &'a Document, + positions: std::slice::Iter<'a, StatementPos>, + include_text: bool, + include_range: bool, +} + +impl<'a> StatementIterator<'a> { + fn new(document: &'a Document) -> Self { + Self { + document, + positions: document.positions.iter(), + include_text: false, + include_range: false, + } + } + + pub fn with_text(mut self) -> Self { + self.include_text = true; + self + } + + pub fn with_range(mut self) -> Self { + self.include_range = true; + self + } +} + 
+pub enum StatementData<'a> { + Statement(Statement), + WithText(Statement, &'a str), + WithRange(Statement, &'a TextRange), + WithTextAndRange(Statement, &'a TextRange, &'a str), + // with ast + // with cst + // include substatements +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = StatementData<'a>; + + fn next(&mut self) -> Option { + self.positions.next().map(|(id, range)| { + let statement = Statement { + id: *id, + path: self.document.path.clone(), + }; + + match (self.include_text, self.include_range) { + (false, false) => StatementData::Statement(statement), + (true, false) => { + let text = &self.document.content[range.start().into()..range.end().into()]; + StatementData::WithText(statement, text) + }, + (false, true) => StatementData::WithRange(statement, range), + (true, true) => { + let text = &self.document.content[range.start().into()..range.end().into()]; + StatementData::WithTextAndRange(statement, range, text) + }, + } + }) + } +} +``` + +2. Parser +- one instance per document +- hold ts parser +- has "inner" document +- holds parse results +- exposes unified api to find statements with data +- reason for putting this together is that we dont want to manually fiddle with sub statements client side +-> i want to do doc.statements() and get all statements WITH sub statements. 
+ diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index e31e4178..0b19e5db 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -3,27 +3,27 @@ use std::ops::{Add, Sub}; use crate::workspace::{ChangeFileParams, ChangeParams}; -use super::{Document, Statement, document}; +use super::{Document, Statement, StatementId, document, statement_identifier::StatementId}; #[derive(Debug, PartialEq, Eq)] pub enum StatementChange { Added(AddedStatement), - Deleted(Statement), + Deleted(StatementId), Modified(ModifiedStatement), } #[derive(Debug, PartialEq, Eq)] pub struct AddedStatement { - pub stmt: Statement, + pub stmt: StatementId, pub text: String, } #[derive(Debug, PartialEq, Eq)] pub struct ModifiedStatement { - pub old_stmt: Statement, + pub old_stmt: StatementId, pub old_stmt_text: String, - pub new_stmt: Statement, + pub new_stmt: StatementId, pub new_stmt_text: String, pub change_range: TextRange, @@ -78,12 +78,7 @@ impl Document { fn drain_positions(&mut self) -> Vec { self.positions .drain(..) 
- .map(|(id, _)| { - StatementChange::Deleted(Statement { - id, - path: self.path.clone(), - }) - }) + .map(|(id, _)| StatementChange::Deleted(id)) .collect() } @@ -109,28 +104,22 @@ impl Document { changes.extend(ranges.into_iter().map(|range| { let id = self.id_generator.next(); let text = self.content[range].to_string(); - self.positions.push((id, range)); + self.positions.push((id.clone(), range)); - StatementChange::Added(AddedStatement { - stmt: Statement { - path: self.path.clone(), - id, - }, - text, - }) + StatementChange::Added(AddedStatement { stmt: id, text }) })); changes } - fn insert_statement(&mut self, range: TextRange) -> usize { + fn insert_statement(&mut self, range: TextRange) -> StatementId { let pos = self .positions .binary_search_by(|(_, r)| r.start().cmp(&range.start())) .unwrap_err(); let new_id = self.id_generator.next(); - self.positions.insert(pos, (new_id, range)); + self.positions.insert(pos, (new_id.clone(), range)); new_id } @@ -279,16 +268,10 @@ impl Document { self.positions[affected_idx] = (new_id, new_range); changed.push(StatementChange::Modified(ModifiedStatement { - old_stmt: Statement { - id: old_id, - path: self.path.clone(), - }, + old_stmt: old_id, old_stmt_text: self.content[old_range].to_string(), - new_stmt: Statement { - id: new_id, - path: self.path.clone(), - }, + new_stmt: new_id, new_stmt_text: changed_content[new_ranges[0]].to_string(), // change must be relative to the statement change_text: change.text.clone(), @@ -322,24 +305,15 @@ impl Document { // delete and add new ones if let Some(next_index) = next_index { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[next_index].0, - path: self.path.clone(), - })); + changed.push(StatementChange::Deleted(self.positions[next_index].0)); self.positions.remove(next_index); } for idx in affected_indices.iter().rev() { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[*idx].0, - path: self.path.clone(), - })); + 
changed.push(StatementChange::Deleted(self.positions[*idx].0)); self.positions.remove(*idx); } if let Some(prev_index) = prev_index { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[prev_index].0, - path: self.path.clone(), - })); + changed.push(StatementChange::Deleted(self.positions[prev_index].0)); self.positions.remove(prev_index); } @@ -347,10 +321,7 @@ impl Document { let actual_range = range.add(full_affected_range.start()); let new_id = self.insert_statement(actual_range); changed.push(StatementChange::Added(AddedStatement { - stmt: Statement { - id: new_id, - path: self.path.clone(), - }, + stmt: new_id, text: new_content[actual_range].to_string(), })); }); @@ -464,7 +435,7 @@ mod tests { fn open_doc_with_scan_error() { let input = "select id from users;\n\n\n\nselect 1443ddwwd33djwdkjw13331333333333;"; - let d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -475,7 +446,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect 1;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); assert!(!d.has_fatal_error()); @@ -513,7 +484,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect 1;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); assert!(!d.has_fatal_error()); @@ -551,7 +522,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select 1d;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -585,7 +556,7 @@ mod tests { let path = 
PgTPath::new("test.sql"); let input = "select 1d;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -618,7 +589,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); @@ -656,7 +627,7 @@ mod tests { fn within_statements_2() { let path = PgTPath::new("test.sql"); let input = "alter table deal alter column value drop not null;\n"; - let mut d = Document::new(path.clone(), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 1); @@ -733,7 +704,7 @@ mod tests { fn julians_sample() { let path = PgTPath::new("test.sql"); let input = "select\n *\nfrom\n test;\n\nselect\n\nalter table test\n\ndrop column id;"; - let mut d = Document::new(path.clone(), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 4); @@ -815,7 +786,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); @@ -854,7 +825,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 1); @@ -879,7 +850,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); 
assert_eq!(d.positions.len(), 2); @@ -941,7 +912,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "\n"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 0); @@ -981,7 +952,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from\nselect * from contacts;"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 2); @@ -1012,7 +983,7 @@ mod tests { fn apply_changes_replacement() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "".to_string(), 0); + let mut doc = Document::new("".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1134,7 +1105,6 @@ mod tests { let path = PgTPath::new("test.sql"); let mut doc = Document::new( - path.clone(), "-- Add new schema named \"private\"\nCREATE SCHEMA \"private\";".to_string(), 0, ); @@ -1207,12 +1177,12 @@ mod tests { let input = "select id from users;\nselect * from contacts;"; let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), input.to_string(), 0); + let mut doc = Document::new(input.to_string(), 0); assert_eq!(doc.positions.len(), 2); - let stmt_1_range = doc.positions[0]; - let stmt_2_range = doc.positions[1]; + let stmt_1_range = doc.positions[0].clone(); + let stmt_2_range = doc.positions[1].clone(); let update_text = ",test"; @@ -1259,7 +1229,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from contacts;\n\nselect * from contacts;"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 2); diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 9ef8c234..ae51d27d 100644 --- 
a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -1,22 +1,11 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; -use pgt_fs::PgTPath; use pgt_text_size::{TextRange, TextSize}; -/// Global unique identifier for a statement -#[derive(Debug, Hash, Eq, PartialEq, Clone)] -pub(crate) struct Statement { - /// Path of the document - pub(crate) path: PgTPath, - /// Unique id within the document - pub(crate) id: StatementId, -} - -pub(crate) type StatementId = usize; +use super::statement_identifier::{IdGenerator, StatementId}; type StatementPos = (StatementId, TextRange); pub(crate) struct Document { - pub(crate) path: PgTPath, pub(crate) content: String, pub(crate) version: i32, @@ -28,13 +17,12 @@ pub(crate) struct Document { } impl Document { - pub(crate) fn new(path: PgTPath, content: String, version: i32) -> Self { + pub(crate) fn new(content: String, version: i32) -> Self { let mut id_generator = IdGenerator::new(); let (ranges, diagnostics) = split_with_diagnostics(&content, None); Self { - path, positions: ranges .into_iter() .map(|range| (id_generator.next(), range)) @@ -42,7 +30,6 @@ impl Document { content, version, diagnostics, - id_generator, } } @@ -60,74 +47,8 @@ impl Document { .any(|d| d.severity() == Severity::Fatal) } - pub fn iter_statements(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, _)| Statement { - id: *id, - path: self.path.clone(), - }) - } - - pub fn iter_statements_with_text(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - }; - let text = &self.content[range.start().into()..range.end().into()]; - (statement, text) - }) - } - - pub fn iter_statements_with_range(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - 
}; - (statement, range) - }) - } - - pub fn iter_statements_with_text_and_range( - &self, - ) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - }; - ( - statement, - range, - &self.content[range.start().into()..range.end().into()], - ) - }) - } - - pub fn get_txt(&self, stmt_id: StatementId) -> Option { - self.positions - .iter() - .find(|pos| pos.0 == stmt_id) - .map(|(_, range)| { - let stmt = &self.content[range.start().into()..range.end().into()]; - stmt.to_owned() - }) - } -} - -pub(crate) struct IdGenerator { - pub(super) next_id: usize, -} - -impl IdGenerator { - fn new() -> Self { - Self { next_id: 0 } - } - - pub(super) fn next(&mut self) -> usize { - let id = self.next_id; - self.next_id += 1; - id + pub fn iter<'a>(&'a self) -> StatementIterator<'a> { + StatementIterator::new(self) } } @@ -165,3 +86,36 @@ pub(crate) fn split_with_diagnostics( ), } } + +pub trait StatementMapper<'a> { + type Output; + + fn map(&self, doc: &'a Document, id: &'a StatementId, range: &'a TextRange) -> Self::Output; +} + +pub struct StatementIterator<'a> { + document: &'a Document, + positions: std::slice::Iter<'a, StatementPos>, +} + +impl<'a> StatementIterator<'a> { + pub fn new(document: &'a Document) -> Self { + Self { + document, + positions: document.positions.iter(), + } + } +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = (StatementId, TextRange, &'a str); + + fn next(&mut self) -> Option { + self.positions.next().map(|(id, range)| { + let range = range.clone(); + let doc = self.document; + let id = id.clone(); + (id, range, &doc.content[range.clone()]) + }) + } +} diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs new file mode 100644 index 00000000..282ef65a --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -0,0 +1,214 @@ +use std::sync::Arc; + +use pgt_fs::PgTPath; 
+use pgt_text_size::TextRange; + +use crate::workspace::ChangeFileParams; + +use super::{ + StatementId, + change::StatementChange, + document::{Document, StatementIterator}, + pg_query::PgQueryStore, + sql_function::SQLFunctionBodyStore, + tree_sitter::TreeSitterStore, +}; + +pub struct Parser { + path: PgTPath, + + doc: Document, + ast_db: PgQueryStore, + cst_db: TreeSitterStore, + sql_fn_db: SQLFunctionBodyStore, +} + +impl Parser { + pub fn new(path: PgTPath, content: String, version: i32) -> Parser { + let doc = Document::new(content, version); + + let cst_db = TreeSitterStore::new(); + let ast_db = PgQueryStore::new(); + let sql_fn_db = SQLFunctionBodyStore::new(); + + doc.iter().for_each(|(stmt, _, content)| { + cst_db.add_statement(&stmt, content); + }); + + Parser { + path, + doc, + ast_db, + cst_db, + sql_fn_db, + } + } + + /// Applies a change to the document and updates the CST and AST databases accordingly. + /// + /// Note that only tree-sitter cares about statement modifications vs remove + add. + /// Hence, we just clear the AST for the old statements and lazily load them when requested. + /// + /// * `params`: ChangeFileParams - The parameters for the change to be applied. + pub fn apply_change(&mut self, params: ChangeFileParams) { + for c in &self.doc.apply_file_change(¶ms) { + match c { + StatementChange::Added(added) => { + tracing::debug!( + "Adding statement: id:{:?}, text:{:?}", + added.stmt, + added.text + ); + self.cst_db.add_statement(&added.stmt, &added.text); + } + StatementChange::Deleted(s) => { + tracing::debug!("Deleting statement: id {:?}", s,); + self.cst_db.remove_statement(s); + self.ast_db.clear_statement(s); + self.sql_fn_db.clear_statement(s); + } + StatementChange::Modified(s) => { + tracing::debug!( + "Modifying statement with id {:?} (new id {:?}). 
Range {:?}, Changed from '{:?}' to '{:?}', changed text: {:?}", + s.old_stmt, + s.new_stmt, + s.change_range, + s.old_stmt_text, + s.new_stmt_text, + s.change_text + ); + + self.cst_db.modify_statement(s); + self.ast_db.clear_statement(&s.old_stmt); + self.sql_fn_db.clear_statement(&s.old_stmt); + } + } + } + } + + pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, DefaultFilter> + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, DefaultFilter) + } + + pub fn iter_with_filter<'a, M, F>(&'a self, mapper: M, filter: F) -> ParseIterator<'a, M, F> + where + M: StatementMapper<'a>, + F: StatementFilter<'a>, + { + ParseIterator::new(self, mapper, filter) + } +} + +pub trait StatementMapper<'a> { + type Output; + + fn map( + &self, + parser: &'a Parser, + id: StatementId, + range: &TextRange, + content: &str, + ) -> Self::Output; +} + +pub trait StatementFilter<'a> { + fn apply(&self, range: &TextRange) -> bool; +} + +pub struct ParseIterator<'a, M, F> { + parser: &'a Parser, + statements: StatementIterator<'a>, + mapper: M, + filter: F, + pending_sub_statements: Vec<(StatementId, TextRange, &'a str)>, +} + +impl<'a, M, F> ParseIterator<'a, M, F> { + pub fn new(parser: &'a Parser, mapper: M, filter: F) -> Self { + Self { + parser, + statements: parser.doc.iter(), + mapper, + filter, + pending_sub_statements: Vec::new(), + } + } +} + +impl<'a, M, F> Iterator for ParseIterator<'a, M, F> +where + M: StatementMapper<'a>, + F: StatementFilter<'a>, +{ + type Item = M::Output; + + fn next(&mut self) -> Option { + // First check if we have any pending sub-statements to process + if let Some((id, range, content)) = self.pending_sub_statements.pop() { + if self.filter.apply(&range) { + return Some(self.mapper.map(self.parser, id, &range, &content)); + } + // If the sub-statement doesn't pass the filter, continue to the next item + return self.next(); + } + + // Process the next top-level statement + let next_statement = self.statements.next(); + + if 
let Some((root_id, range, content)) = next_statement { + // If we should include sub-statements and this statement has an AST + if let Ok(ast) = *self.parser.ast_db.load_parse(&root_id, &content) { + // Check if this is a SQL function definition with a body + if let Some(sub_statement) = self + .parser + .sql_fn_db + .get_function_body(&root_id, &ast, content) + { + // Add sub-statements to our pending queue + self.pending_sub_statements.push(( + root_id.create_child(), + // adjust range to document + sub_statement.range + range.start(), + &sub_statement.body, + )); + } + } + + // Return the current statement if it passes the filter + if self.filter.apply(&range) { + return Some(self.mapper.map(self.parser, root_id, &range, content)); + } + + // If the current statement doesn't pass the filter, try the next one + return self.next(); + } + + None + } +} + +struct WithAst; +impl<'a> StatementMapper<'a> for WithAst { + type Output = (StatementId, TextRange, Option>); + + fn map( + &self, + parser: &'a Parser, + id: StatementId, + range: &TextRange, + _content: &str, + ) -> Self::Output { + let ast = parser.ast_db.get_ast(&id); + (id, *range, ast) + } +} + +struct DefaultFilter; +impl<'a> StatementFilter<'a> for DefaultFilter { + fn apply(&self, _range: &TextRange) -> bool { + true + } +} diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 3ed452fc..c95fc927 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -1,52 +1,38 @@ use std::sync::Arc; use dashmap::DashMap; -use pgt_diagnostics::serde::Diagnostic as SDiagnostic; use pgt_query_ext::diagnostics::*; -use super::{change::ModifiedStatement, document::Statement}; +use super::statement_identifier::StatementId; pub struct PgQueryStore { - ast_db: DashMap>, - diagnostics: DashMap, + db: DashMap>>, } impl PgQueryStore { pub fn new() -> PgQueryStore { - PgQueryStore { - 
ast_db: DashMap::new(), - diagnostics: DashMap::new(), - } - } - - pub fn get_ast(&self, statement: &Statement) -> Option> { - self.ast_db.get(statement).map(|x| x.clone()) + PgQueryStore { db: DashMap::new() } } - pub fn add_statement(&self, statement: &Statement, content: &str) { - let r = pgt_query_ext::parse(content); - if let Ok(ast) = r { - self.ast_db.insert(statement.clone(), Arc::new(ast)); - } else { - tracing::info!("invalid statement, adding diagnostics."); - self.diagnostics - .insert(statement.clone(), SyntaxDiagnostic::from(r.unwrap_err())); + pub fn load_parse( + &self, + statement: &StatementId, + content: &str, + ) -> Arc> { + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; } - } - pub fn remove_statement(&self, statement: &Statement) { - self.ast_db.remove(statement); - self.diagnostics.remove(statement); + let r = Arc::new(pgt_query_ext::parse(content).map_err(SyntaxDiagnostic::from)); + self.db.insert(statement.clone(), r.clone()); + r } - pub fn modify_statement(&self, change: &ModifiedStatement) { - self.remove_statement(&change.old_stmt); - self.add_statement(&change.new_stmt, &change.new_stmt_text); - } + pub fn clear_statement(&self, id: &StatementId) { + self.db.remove(id); - pub fn get_diagnostics(&self, stmt: &Statement) -> Vec { - self.diagnostics - .get(stmt) - .map_or_else(Vec::new, |err| vec![SDiagnostic::new(err.value().clone())]) + if let Some(child_id) = id.as_child() { + self.db.remove(&child_id); + } } } diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs new file mode 100644 index 00000000..3378c563 --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use pgt_text_size::TextRange; + +use super::StatementId; + +pub struct SQLFunctionBody { + pub range: TextRange, + pub body: String, +} + +pub struct SQLFunctionBodyStore { + 
db: DashMap>>, +} + +impl SQLFunctionBodyStore { + pub fn new() -> SQLFunctionBodyStore { + SQLFunctionBodyStore { db: DashMap::new() } + } + + pub fn get_function_body( + &self, + statement: &StatementId, + ast: &pgt_query_ext::NodeEnum, + content: &str, + ) -> Option> { + // First check if we already have this statement cached + if let Some(entry) = self.db.get(statement) { + return entry; + } + + // If not cached, try to extract it from the AST + let fn_body = get_sql_fn(ast, content).map(Arc::new); + + // Cache the result and return it + self.db.insert(*statement, fn_body.clone()); + fn_body + } + + pub fn clear_statement(&self, id: &StatementId) { + self.db.remove(id); + + if let Some(child_id) = id.as_child() { + self.db.remove(&child_id); + } + } +} + +/// Extracts SQL function body and its text range from a CreateFunctionStmt node. +/// Returns None if the function is not an SQL function or if the body can't be found. +fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { + let create_fn = match ast { + pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, + _ => return None, + }; + + // Extract language from function options + let language = find_option_value(create_fn, "language")?; + + // Only process SQL functions + if language != "sql" { + return None; + } + + // Extract SQL body from function options + let sql_body = find_option_value(create_fn, "as")?; + + // Find the range of the SQL body in the content + let start = content.find(&sql_body)?; + let end = start + sql_body.len(); + + let range = TextRange::new(start.try_into().unwrap(), end.try_into().unwrap()); + + Some(SQLFunctionBody { + range, + body: sql_body.clone(), + }) +} + +/// Helper function to find a specific option value from function options +fn find_option_value( + create_fn: &pgt_query_ext::protobuf::CreateFunctionStmt, + option_name: &str, +) -> Option { + create_fn + .options + .iter() + .filter_map(|opt_wrapper| opt_wrapper.node.as_ref()) + .find_map(|opt| { + if 
let pgt_query_ext::NodeEnum::DefElem(def_elem) = opt { + if def_elem.defname == option_name { + def_elem + .arg + .iter() + .filter_map(|arg_wrapper| arg_wrapper.node.as_ref()) + .find_map(|arg| { + if let pgt_query_ext::NodeEnum::String(s) = arg { + Some(s.sval.clone()) + } else { + None + } + }) + } else { + None + } + } else { + None + } + }) +} diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs new file mode 100644 index 00000000..e1c47c4b --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -0,0 +1,41 @@ +pub type RootId = usize; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum StatementId { + Root(RootId), + // StatementId is the same as the root id since we can only have a single sql function body per Root + Child(RootId), +} + +/// Helper struct to generate unique statement ids +pub struct IdGenerator { + next_id: RootId, +} + +impl IdGenerator { + pub fn new() -> Self { + Self { next_id: 0 } + } + + pub fn next(&mut self) -> StatementId { + let id = self.next_id; + self.next_id += 1; + StatementId::Root(id) + } +} + +impl StatementId { + pub fn as_child(&self) -> Option { + match self { + StatementId::Root(id) => Some(StatementId::Child(*id)), + StatementId::Child(_) => None, + } + } + + pub fn create_child(&self) -> StatementId { + match self { + StatementId::Root(id) => StatementId::Child(*id), + StatementId::Child(_) => panic!("Cannot create child from a child statement id"), + } + } +} diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index 09cff74c..e1f15de8 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -1,14 +1,13 @@ -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex}; use dashmap::DashMap; use tree_sitter::InputEdit; -use 
super::{change::ModifiedStatement, document::Statement}; +use super::{change::ModifiedStatement, statement_identifier::StatementId}; pub struct TreeSitterStore { - db: DashMap>, - - parser: RwLock, + db: DashMap>, + parser: Mutex, } impl TreeSitterStore { @@ -20,24 +19,26 @@ impl TreeSitterStore { TreeSitterStore { db: DashMap::new(), - parser: RwLock::new(parser), + parser: Mutex::new(parser), } } - pub fn get_parse_tree(&self, statement: &Statement) -> Option> { + pub fn get_parse(&self, statement: &StatementId) -> Option> { self.db.get(statement).map(|x| x.clone()) } - pub fn add_statement(&self, statement: &Statement, content: &str) { - let mut guard = self.parser.write().expect("Error reading parser"); - // todo handle error - let tree = guard.parse(content, None).unwrap(); - drop(guard); + pub fn add_statement(&self, statement: &StatementId, content: &str) { + let mut parser = self.parser.lock().expect("Failed to lock parser"); + let tree = parser.parse(content, None).unwrap(); self.db.insert(statement.clone(), Arc::new(tree)); } - pub fn remove_statement(&self, statement: &Statement) { - self.db.remove(statement); + pub fn remove_statement(&self, id: &StatementId) { + self.db.remove(id); + + if let Some(child_id) = id.as_child() { + self.db.remove(&child_id); + } } pub fn modify_statement(&self, change: &ModifiedStatement) { @@ -61,18 +62,17 @@ impl TreeSitterStore { tree.edit(&edit); - let mut guard = self.parser.write().expect("Error reading parser"); + let mut parser = self.parser.lock().expect("Failed to lock parser"); // todo handle error self.db.insert( change.new_stmt.clone(), - Arc::new(guard.parse(&change.new_stmt_text, Some(&tree)).unwrap()), + Arc::new(parser.parse(&change.new_stmt_text, Some(&tree)).unwrap()), ); - drop(guard); } } // Converts character positions and replacement text into a tree-sitter InputEdit -fn edit_from_change( +pub(crate) fn edit_from_change( text: &str, start_char: usize, end_char: usize, From 
230c03ec5cc3f47861a224bde70aa090a8f47df1 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Wed, 9 Apr 2025 11:39:45 +0200 Subject: [PATCH 02/20] cleanup --- crates/pgt_query_ext/src/lib.rs | 45 ------ .../src/workspace/server/TODO.md | 137 ------------------ 2 files changed, 182 deletions(-) delete mode 100644 crates/pgt_workspace/src/workspace/server/TODO.md diff --git a/crates/pgt_query_ext/src/lib.rs b/crates/pgt_query_ext/src/lib.rs index 32eab8ef..c1f5fb49 100644 --- a/crates/pgt_query_ext/src/lib.rs +++ b/crates/pgt_query_ext/src/lib.rs @@ -30,48 +30,3 @@ pub fn parse(sql: &str) -> Result { .ok_or_else(|| Error::Parse("Unable to find root node".to_string())) })? } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sql_1() { - let input = "CREATE FUNCTION add(integer, integer) RETURNS integer - AS 'select $1 + $2;' - LANGUAGE SQL - IMMUTABLE - RETURNS NULL ON NULL INPUT;"; - println!("{:#?}", parse(input).unwrap()); - // print after 42 - println!("{:#?}", &input[42..]); - } - - #[test] - fn test_sql_2() { - let input = "CREATE FUNCTION add() RETURNS integer - AS $sql$select 1 + 2;$sql$ - LANGUAGE SQL - IMMUTABLE - RETURNS NULL ON NULL INPUT;"; - println!("{:#?}", parse(input).unwrap()); - // print after 58 - println!("{:#?}", &input[58..]); - } - - #[test] - fn test_plpsql() { - let input = "CREATE FUNCTION add(integer, integer) RETURNS integer - AS $s$ -begin - return $1 + $2; -end -$s$ - LANGUAGE plpgsql - IMMUTABLE - RETURNS NULL ON NULL INPUT;"; - println!("{:#?}", parse(input).unwrap()); - // print after 58 - println!("{:#?}", &input[58..]); - } -} diff --git a/crates/pgt_workspace/src/workspace/server/TODO.md b/crates/pgt_workspace/src/workspace/server/TODO.md deleted file mode 100644 index ff3da8ca..00000000 --- a/crates/pgt_workspace/src/workspace/server/TODO.md +++ /dev/null @@ -1,137 +0,0 @@ -1. 
Statement Iterator - -```rust -pub struct StatementIterator<'a> { - document: &'a Document, - positions: std::slice::Iter<'a, StatementPos>, - include_text: bool, - include_range: bool, -} - -impl<'a> StatementIterator<'a> { - fn new(document: &'a Document) -> Self { - Self { - document, - positions: document.positions.iter(), - include_text: false, - include_range: false, - } - } - - pub fn with_text(mut self) -> Self { - self.include_text = true; - self - } - - pub fn with_range(mut self) -> Self { - self.include_range = true; - self - } -} - -pub enum StatementData<'a> { - Statement(Statement), - WithText(Statement, &'a str), - WithRange(Statement, &'a TextRange), - WithTextAndRange(Statement, &'a TextRange, &'a str), -} - -impl<'a> Iterator for StatementIterator<'a> { - type Item = StatementData<'a>; - - fn next(&mut self) -> Option { - self.positions.next().map(|(id, range)| { - let statement = Statement { - id: *id, - path: self.document.path.clone(), - }; - - match (self.include_text, self.include_range) { - (false, false) => StatementData::Statement(statement), - (true, false) => { - let text = &self.document.content[range.start().into()..range.end().into()]; - StatementData::WithText(statement, text) - }, - (false, true) => StatementData::WithRange(statement, range), - (true, true) => { - let text = &self.document.content[range.start().into()..range.end().into()]; - StatementData::WithTextAndRange(statement, range, text) - }, - } - }) - } -} -pub struct StatementIterator<'a> { - document: &'a Document, - positions: std::slice::Iter<'a, StatementPos>, - include_text: bool, - include_range: bool, -} - -impl<'a> StatementIterator<'a> { - fn new(document: &'a Document) -> Self { - Self { - document, - positions: document.positions.iter(), - include_text: false, - include_range: false, - } - } - - pub fn with_text(mut self) -> Self { - self.include_text = true; - self - } - - pub fn with_range(mut self) -> Self { - self.include_range = true; - self - } -} - 
-pub enum StatementData<'a> { - Statement(Statement), - WithText(Statement, &'a str), - WithRange(Statement, &'a TextRange), - WithTextAndRange(Statement, &'a TextRange, &'a str), - // with ast - // with cst - // include substatements -} - -impl<'a> Iterator for StatementIterator<'a> { - type Item = StatementData<'a>; - - fn next(&mut self) -> Option { - self.positions.next().map(|(id, range)| { - let statement = Statement { - id: *id, - path: self.document.path.clone(), - }; - - match (self.include_text, self.include_range) { - (false, false) => StatementData::Statement(statement), - (true, false) => { - let text = &self.document.content[range.start().into()..range.end().into()]; - StatementData::WithText(statement, text) - }, - (false, true) => StatementData::WithRange(statement, range), - (true, true) => { - let text = &self.document.content[range.start().into()..range.end().into()]; - StatementData::WithTextAndRange(statement, range, text) - }, - } - }) - } -} -``` - -2. Parser -- one instance per document -- hold ts parser -- has "inner" document -- holds parse results -- exposes unified api to find statements with data -- reason for putting this together is that we dont want to manually fiddle with sub statements client side --> i want to do doc.statements() and get all statements WITH sub statements. 
- From 773e54281e83703ae7b7de82892bb792ad9041de Mon Sep 17 00:00:00 2001 From: psteinroe Date: Thu, 10 Apr 2025 12:02:37 +0200 Subject: [PATCH 03/20] finish --- crates/pgt_completions/src/complete.rs | 2 +- crates/pgt_completions/src/context.rs | 27 +- crates/pgt_completions/src/test_helper.rs | 2 +- crates/pgt_typecheck/src/diagnostics.rs | 20 +- crates/pgt_typecheck/src/lib.rs | 2 +- crates/pgt_workspace/src/workspace/server.rs | 409 ++++++------------ .../src/workspace/server/parser.rs | 183 +++++++- .../src/workspace/server/tree_sitter.rs | 12 +- 8 files changed, 323 insertions(+), 334 deletions(-) diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index fb00aeaf..ed51c653 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -14,7 +14,7 @@ pub struct CompletionParams<'a> { pub position: TextSize, pub schema: &'a pgt_schema_cache::SchemaCache, pub text: String, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, } pub fn complete(params: CompletionParams) -> Vec { diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 8b12742d..775b8870 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -50,7 +50,7 @@ impl TryFrom for ClauseType { pub(crate) struct CompletionContext<'a> { pub ts_node: Option>, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, pub text: &'a str, pub schema_cache: &'a SchemaCache, pub position: usize, @@ -85,10 +85,7 @@ impl<'a> CompletionContext<'a> { } fn gather_info_from_ts_queries(&mut self) { - let tree = match self.tree.as_ref() { - None => return, - Some(t) => t, - }; + let tree = self.tree; let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; @@ -126,11 +123,7 @@ impl<'a> CompletionContext<'a> { } fn gather_tree_context(&mut self) { - if self.tree.is_none() { - return; - } - - let mut 
cursor = self.tree.as_ref().unwrap().root_node().walk(); + let mut cursor = self.tree.root_node().walk(); /* * The head node of any treesitter tree is always the "PROGRAM" node. @@ -262,7 +255,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -294,7 +287,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -328,7 +321,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -353,7 +346,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -381,7 +374,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -407,7 +400,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -432,7 +425,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index a54aacbd..58e9baf7 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -70,7 +70,7 @@ pub(crate) fn get_test_params<'a>( CompletionParams { position: (position as u32).into(), schema: schema_cache, - tree: Some(tree), + tree, text, } } diff 
--git a/crates/pgt_typecheck/src/diagnostics.rs b/crates/pgt_typecheck/src/diagnostics.rs index b443dcc9..8fd92da2 100644 --- a/crates/pgt_typecheck/src/diagnostics.rs +++ b/crates/pgt_typecheck/src/diagnostics.rs @@ -96,7 +96,7 @@ impl Advices for TypecheckAdvices { pub(crate) fn create_type_error( pg_err: &PgDatabaseError, - ts: Option<&tree_sitter::Tree>, + ts: &tree_sitter::Tree, ) -> TypecheckDiagnostic { let position = pg_err.position().and_then(|pos| match pos { sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1), @@ -104,16 +104,14 @@ pub(crate) fn create_type_error( }); let range = position.and_then(|pos| { - ts.and_then(|tree| { - tree.root_node() - .named_descendant_for_byte_range(pos, pos) - .map(|node| { - TextRange::new( - node.start_byte().try_into().unwrap(), - node.end_byte().try_into().unwrap(), - ) - }) - }) + ts.root_node() + .named_descendant_for_byte_range(pos, pos) + .map(|node| { + TextRange::new( + node.start_byte().try_into().unwrap(), + node.end_byte().try_into().unwrap(), + ) + }) }); let severity = match pg_err.severity() { diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 4554689c..9311bb8e 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -13,7 +13,7 @@ pub struct TypecheckParams<'a> { pub conn: &'a PgPool, pub sql: &'a str, pub ast: &'a pgt_query_ext::NodeEnum, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, } #[derive(Debug, Clone)] diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 8f464b1d..1d25a58d 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -1,30 +1,25 @@ -use std::{ - fs, - panic::RefUnwindSafe, - path::Path, - sync::{Arc, Mutex, RwLock}, -}; +use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; use analyser::AnalyserVisitorBuilder; use async_helper::run_async; -use 
change::StatementChange; use dashmap::DashMap; use db_connection::DbConnection; -pub(crate) use document::StatementId; -use document::{Document, Statement}; +use document::Document; use futures::{StreamExt, stream}; -use pg_query::PgQueryStore; +use parser::{ + AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, + GetCompletionsMapper, Parser, SyncDiagnosticsMapper, +}; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; -use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; +use pgt_diagnostics::{ + Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, +}; use pgt_fs::{ConfigName, PgTPath}; use pgt_typecheck::TypecheckParams; use schema_cache_manager::SchemaCacheManager; -use sql_function::SQLFunctionBodyStore; use sqlx::Executor; use tracing::info; -use tree_sitter::TreeSitterStore; -use tree_sitter_parser::TreeSitterParserStore; use crate::{ WorkspaceError, @@ -65,13 +60,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - /// Stores the document (text content + version number) associated with a URL - documents: DashMap, - - tree_sitter: TreeSitterStore, - pg_query: PgQueryStore, - sql_functions: SQLFunctionBodyStore, - ts_parser: TreeSitterParserStore, + parsers: DashMap, connection: RwLock, } @@ -93,13 +82,9 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - documents: DashMap::default(), - tree_sitter: TreeSitterStore::new(), - pg_query: PgQueryStore::new(), + parsers: DashMap::default(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), - sql_functions: SQLFunctionBodyStore::new(), - ts_parser: TreeSitterParserStore::new(), } } @@ -195,40 +180,19 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", 
skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - let doc = Document::new(params.path.clone(), params.content, params.version); - - let doc_parser = self.ts_parser.get_parser(params.path.clone()); - let mut parser = doc_parser.lock().expect("Error locking parser"); - doc.iter_statements_with_text().for_each(|(stmt, content)| { - self.tree_sitter.add_statement(&mut parser, &stmt, content); - self.pg_query.add_statement(&stmt, content); - if let Some(ast) = self.pg_query.get_ast(&stmt) { - self.sql_functions - .add_statement(&mut parser, &ast, &stmt, content); - } - }); - drop(parser); - - self.documents.insert(params.path, doc); + self.parsers + .entry(params.path.clone()) + .or_insert_with(|| Parser::new(params.path.clone(), params.content, params.version)); Ok(()) } /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - let (_, doc) = self - .documents + self.parsers .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; - for stmt in doc.iter_statements() { - self.tree_sitter.remove_statement(&stmt); - self.pg_query.remove_statement(&stmt); - self.sql_functions.remove_statement(&stmt); - } - - self.ts_parser.remove_parser(¶ms.path); - Ok(()) } @@ -238,69 +202,16 @@ impl Workspace for WorkspaceServer { version = params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut doc = self - .documents + let mut parser = self + .parsers .entry(params.path.clone()) - .or_insert(Document::new( + .or_insert(Parser::new( params.path.clone(), "".to_string(), params.version, )); - let doc_parser = self.ts_parser.get_parser(params.path.clone()); - let mut parser = doc_parser.lock().expect("Error locking parser"); - for c in &doc.apply_file_change(¶ms) { - match c { - StatementChange::Added(added) => { - tracing::debug!( - "Adding statement: 
id:{:?}, path:{:?}, text:{:?}", - added.stmt.id, - added.stmt.path.as_os_str().to_str(), - added.text - ); - self.tree_sitter - .add_statement(&mut parser, &added.stmt, &added.text); - self.pg_query.add_statement(&added.stmt, &added.text); - if let Some(ast) = self.pg_query.get_ast(&added.stmt) { - self.sql_functions.add_statement( - &mut parser, - &ast, - &added.stmt, - &added.text, - ); - } - } - StatementChange::Deleted(s) => { - tracing::debug!( - "Deleting statement: id:{:?}, path:{:?}", - s.id, - s.path.as_os_str() - ); - self.tree_sitter.remove_statement(s); - self.pg_query.remove_statement(s); - self.sql_functions.remove_statement(s); - } - StatementChange::Modified(s) => { - tracing::debug!( - "Modifying statement with id {:?} (new id {:?}) in {:?}. Range {:?}, Changed from '{:?}' to '{:?}', changed text: {:?}", - s.old_stmt.id, - s.new_stmt.id, - s.old_stmt.path.as_os_str().to_str(), - s.change_range, - s.old_stmt_text, - s.new_stmt_text, - s.change_text - ); - - self.tree_sitter.modify_statement(&mut parser, s); - self.pg_query.modify_statement(s); - if let Some(ast) = self.pg_query.get_ast(&s.new_stmt) { - self.sql_functions.modify_statement(&mut parser, &ast, s); - } - } - } - } - drop(parser); + parser.apply_change(params); Ok(()) } @@ -311,10 +222,10 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .documents + .parsers .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - Ok(document.content.clone()) + Ok(document.get_document_content().to_string()) } fn is_path_ignored(&self, params: IsPathIgnoredParams) -> Result { @@ -325,17 +236,11 @@ impl Workspace for WorkspaceServer { &self, params: code_actions::CodeActionsParams, ) -> Result { - let doc = self - .documents + let parser = self + .parsers .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let eligible_statements = doc - .iter_statements_with_text_and_range() - .filter(|(_, range, _)| 
range.contains(params.cursor_position)); - - let mut actions: Vec = vec![]; - let settings = self .settings .read() @@ -347,20 +252,26 @@ impl Workspace for WorkspaceServer { Some("Statement execution not allowed against database.".into()) }; - for (stmt, _, txt) in eligible_statements { - let title = format!( - "Execute Statement: {}...", - txt.chars().take(50).collect::() - ); - - actions.push(CodeAction { - title, - kind: CodeActionKind::Command(CommandAction { - category: CommandActionCategory::ExecuteStatement(stmt.id), - }), - disabled_reason: disabled_reason.clone(), - }); - } + let actions = parser + .iter_with_filter( + DefaultMapper, + CursorPositionFilter::new(params.cursor_position), + ) + .map(|(stmt, _, txt)| { + let title = format!( + "Execute Statement: {}...", + txt.chars().take(50).collect::() + ); + + CodeAction { + title, + kind: CodeActionKind::Command(CommandAction { + category: CommandActionCategory::ExecuteStatement(stmt), + }), + disabled_reason: disabled_reason.clone(), + } + }) + .collect(); Ok(CodeActionsResult { actions }) } @@ -369,31 +280,25 @@ impl Workspace for WorkspaceServer { &self, params: ExecuteStatementParams, ) -> Result { - let doc = self - .documents + let parser = self + .parsers .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - if self - .pg_query - .get_ast(&Statement { - path: params.path, - id: params.statement_id, - }) - .is_none() - { + let stmt = parser.find(params.statement_id, ExecuteStatementMapper); + + if stmt.is_none() { return Ok(ExecuteStatementResult { - message: "Statement is invalid.".into(), + message: "Statement was not found in document.".into(), }); }; - let sql: String = match doc.get_txt(params.statement_id) { - Some(txt) => txt, - None => { - return Ok(ExecuteStatementResult { - message: "Statement was not found in document.".into(), - }); - } + let (id, range, content, ast) = stmt.unwrap(); + + if ast.is_none() { + return Ok(ExecuteStatementResult { + message: "Statement is invalid.".into(), 
+ }); }; let conn = self.connection.read().unwrap(); @@ -406,7 +311,7 @@ impl Workspace for WorkspaceServer { } }; - let result = run_async(async move { pool.execute(sqlx::query(&sql)).await })??; + let result = run_async(async move { pool.execute(sqlx::query(&content)).await })??; Ok(ExecuteStatementResult { message: format!( @@ -420,13 +325,6 @@ impl Workspace for WorkspaceServer { &self, params: PullDiagnosticsParams, ) -> Result { - // get all statements form the requested document and pull diagnostics out of every - // source - let doc = self - .documents - .get(¶ms.path) - .ok_or(WorkspaceError::not_found())?; - let settings = self.settings(); // create analyser for this run @@ -450,7 +348,14 @@ impl Workspace for WorkspaceServer { filter, }); - let mut diagnostics: Vec = doc.diagnostics().to_vec(); + let parser = self + .parsers + .get(¶ms.path) + .ok_or(WorkspaceError::not_found())?; + + let mut diagnostics: Vec = parser.document_diagnostics().to_vec(); + + // TODO: run this in parallel with rayon based on rayon.count() if let Some(pool) = self .connection @@ -458,44 +363,20 @@ impl Workspace for WorkspaceServer { .expect("DbConnection RwLock panicked") .get_pool() { - let typecheck_params: Vec<_> = doc - .iter_statements_with_text_and_range() - .flat_map(|(stmt, range, text)| { - let ast = self.pg_query.get_ast(&stmt); - let tree = self.tree_sitter.get_parse_tree(&stmt); - - let mut res = vec![(text.to_string(), ast, tree, *range)]; - - if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) { - // fn_body range is within the statement -> adjust it to be relative to the - // document instead (as the other ranges) - let fn_range = fn_body.range + range.start(); - res.push(( - fn_body.body.clone(), - fn_body.ast.clone().map(Arc::new), - Some(Arc::new(fn_body.cst.clone())), - fn_range, - )); - } - - res - }) - .collect(); - - // run diagnostics for each statement in parallel if its mostly i/o work let path_clone = params.path.clone(); + let input = 
parser.iter(AsyncDiagnosticsMapper).collect::>(); let async_results = run_async(async move { - stream::iter(typecheck_params) - .map(|(text, ast, tree, range)| { + stream::iter(input) + .map(|(_id, range, content, ast, cst)| { let pool = pool.clone(); let path = path_clone.clone(); async move { if let Some(ast) = ast { pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: &text, + sql: &content, ast: &ast, - tree: tree.as_deref(), + tree: &cst, }) .await .map(|d| { @@ -519,66 +400,49 @@ impl Workspace for WorkspaceServer { } } - diagnostics.extend(doc.iter_statements_with_range().flat_map(|(stmt, r)| { - let mut stmt_diagnostics = self.pg_query.get_diagnostics(&stmt); - - let ast = self.pg_query.get_ast(&stmt); - - if let Some(ast) = ast { - stmt_diagnostics.extend( - analyser - .run(AnalyserContext { root: &ast }) - .into_iter() - .map(SDiagnostic::new) - .collect::>(), - ); - } + diagnostics.extend(parser.iter(SyncDiagnosticsMapper).flat_map( + |(_id, range, ast, diag)| { + let mut errors: Vec = vec![]; - if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) { - if let Some(fn_diag) = &fn_body.syntax_diagnostics { - stmt_diagnostics.push(SDiagnostic::new( - fn_diag - .clone() - .with_file_path(params.path.as_path().display().to_string()) - .with_file_span(fn_body.range + r.start()), - )); + if let Some(diag) = diag { + errors.push(diag.into()); } - if let Some(ast) = &fn_body.ast { - stmt_diagnostics.extend( + if let Some(ast) = ast { + errors.extend( analyser - .run(AnalyserContext { root: ast }) + .run(AnalyserContext { root: &ast }) .into_iter() - .map(SDiagnostic::new) - .collect::>(), + .map(Error::from) + .collect::>(), ); } - } - stmt_diagnostics - .into_iter() - .map(|d| { - let severity = d - .category() - .filter(|category| category.name().starts_with("lint/")) - .map_or_else( - || d.severity(), - |category| { - settings - .as_ref() - .get_severity_from_rule_code(category) - .unwrap_or(Severity::Warning) - }, - ); - - 
SDiagnostic::new( - d.with_file_path(params.path.as_path().display().to_string()) - .with_file_span(r) - .with_severity(severity), - ) - }) - .collect::>() - })); + errors + .into_iter() + .map(|d| { + let severity = d + .category() + .filter(|category| category.name().starts_with("lint/")) + .map_or_else( + || d.severity(), + |category| { + settings + .as_ref() + .get_severity_from_rule_code(category) + .unwrap_or(Severity::Warning) + }, + ); + + SDiagnostic::new( + d.with_file_path(params.path.as_path().display().to_string()) + .with_file_span(range) + .with_severity(severity), + ) + }) + .collect::>() + }, + )); let errors = diagnostics .iter() @@ -601,60 +465,35 @@ impl Workspace for WorkspaceServer { &self, params: GetCompletionsParams, ) -> Result { - let pool = match self.connection.read().unwrap().get_pool() { - Some(pool) => pool, - None => return Ok(CompletionsResult::default()), - }; - - let doc = self - .documents + let parser = self + .parsers .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let (statement, stmt_range, text) = match doc - .iter_statements_with_text_and_range() - .find(|(_, r, _)| r.contains(params.position)) - { - Some(s) => s, + let pool = match self.connection.read().unwrap().get_pool() { + Some(pool) => pool, None => return Ok(CompletionsResult::default()), }; - // `offset` is the position in the document, - // but we need the position within the *statement*. - let position = params.position - stmt_range.start(); - - let tree = self.tree_sitter.get_parse_tree(&statement); - - tracing::debug!( - "Found the statement. We're looking for position {:?}. Statement Range {:?} to {:?}. 
Statement: {:?}", - position, - stmt_range.start(), - stmt_range.end(), - text - ); - let schema_cache = self.schema_cache.load(pool)?; - let mut items = pgt_completions::complete(pgt_completions::CompletionParams { - position, - schema: schema_cache.as_ref(), - tree: tree.as_deref(), - text: text.to_string(), - }); - - if let Some(f) = self.sql_functions.get_function_body(&statement) { - let fn_text = f.body.clone(); - let fn_range = f.range + stmt_range.start(); - - items.extend(pgt_completions::complete( - pgt_completions::CompletionParams { - position: position - fn_range.start(), + let items = parser + .iter_with_filter( + GetCompletionsMapper, + CursorPositionFilter::new(params.position), + ) + .flat_map(|(_id, range, content, cst)| { + // `offset` is the position in the document, + // but we need the position within the *statement*. + let position = params.position - range.start(); + pgt_completions::complete(pgt_completions::CompletionParams { + position, schema: schema_cache.as_ref(), - tree: Some(&f.cst), - text: fn_text, - }, - )); - } + tree: &cst, + text: content.to_string(), + }) + }) + .collect(); Ok(CompletionsResult { items }) } diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index 282ef65a..633dac59 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -1,16 +1,18 @@ use std::sync::Arc; +use pgt_diagnostics::serde::Diagnostic as SDiagnostic; use pgt_fs::PgTPath; -use pgt_text_size::TextRange; +use pgt_query_ext::diagnostics::SyntaxDiagnostic; +use pgt_text_size::{TextRange, TextSize}; use crate::workspace::ChangeFileParams; use super::{ - StatementId, change::StatementChange, document::{Document, StatementIterator}, pg_query::PgQueryStore, sql_function::SQLFunctionBodyStore, + statement_identifier::StatementId, tree_sitter::TreeSitterStore, }; @@ -86,6 +88,21 @@ impl Parser { } } + pub fn 
get_document_content(&self) -> &str { + &self.doc.content + } + + pub fn document_diagnostics(&self) -> &Vec { + &self.doc.diagnostics + } + + pub fn find<'a, M>(&'a self, id: StatementId, mapper: M) -> Option + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, IdFilter::new(id)).next() + } + pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, DefaultFilter> where M: StatementMapper<'a>, @@ -100,6 +117,10 @@ impl Parser { { ParseIterator::new(self, mapper, filter) } + + pub fn count<'a>(&'a self) -> usize { + self.iter(DefaultMapper).count() + } } pub trait StatementMapper<'a> { @@ -109,13 +130,13 @@ pub trait StatementMapper<'a> { &self, parser: &'a Parser, id: StatementId, - range: &TextRange, - content: &str, + range: &'a TextRange, + content: &'a str, ) -> Self::Output; } pub trait StatementFilter<'a> { - fn apply(&self, range: &TextRange) -> bool; + fn apply(&self, id: StatementId, range: &TextRange) -> bool; } pub struct ParseIterator<'a, M, F> { @@ -148,7 +169,7 @@ where fn next(&mut self) -> Option { // First check if we have any pending sub-statements to process if let Some((id, range, content)) = self.pending_sub_statements.pop() { - if self.filter.apply(&range) { + if self.filter.apply(id, &range) { return Some(self.mapper.map(self.parser, id, &range, &content)); } // If the sub-statement doesn't pass the filter, continue to the next item @@ -178,7 +199,7 @@ where } // Return the current statement if it passes the filter - if self.filter.apply(&range) { + if self.filter.apply(root_id, &range) { return Some(self.mapper.map(self.parser, root_id, &range, content)); } @@ -190,25 +211,155 @@ where } } -struct WithAst; -impl<'a> StatementMapper<'a> for WithAst { - type Output = (StatementId, TextRange, Option>); +pub struct DefaultMapper; +impl<'a> StatementMapper<'a> for DefaultMapper { + type Output = (StatementId, &'a TextRange, &'a str); + + fn map( + &self, + _parser: &'a Parser, + id: StatementId, + range: &'a TextRange, + 
content: &'a str, + ) -> Self::Output { + (id, range, content) + } +} + +pub struct ExecuteStatementMapper; +impl<'a> StatementMapper<'a> for ExecuteStatementMapper { + type Output = ( + StatementId, + &'a TextRange, + &'a str, + Option, + ); + + fn map( + &self, + parser: &'a Parser, + id: StatementId, + range: &'a TextRange, + content: &'a str, + ) -> Self::Output { + let ast_result = parser.ast_db.load_parse(&id, content); + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + (id, range, content, ast_option) + } +} + +pub struct AsyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { + type Output = ( + StatementId, + TextRange, + String, + Option, + Arc, + ); + + fn map( + &self, + parser: &'a Parser, + id: StatementId, + range: &'a TextRange, + content: &'a str, + ) -> Self::Output { + let ast_result = parser.ast_db.load_parse(&id, content); + + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + let cst_result = parser.cst_db.load_parse(&id, content); + + (id, *range, content.to_string(), ast_option, cst_result) + } +} + +pub struct SyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { + type Output = ( + StatementId, + &'a TextRange, + Option, + Option, + ); + + fn map( + &self, + parser: &'a Parser, + id: StatementId, + range: &'a TextRange, + content: &'a str, + ) -> Self::Output { + let ast_result = parser.ast_db.load_parse(&id, content); + + let (ast_option, diagnostics) = match &*ast_result { + Ok(node) => (Some(node.clone()), None), + Err(diag) => (None, Some(diag.clone())), + }; + + (id, range, ast_option, diagnostics) + } +} + +pub struct GetCompletionsMapper; +impl<'a> StatementMapper<'a> for GetCompletionsMapper { + type Output = (StatementId, &'a TextRange, &'a str, Arc); fn map( &self, parser: &'a Parser, id: StatementId, - range: &TextRange, - _content: &str, + range: &'a TextRange, + content: &'a 
str, ) -> Self::Output { - let ast = parser.ast_db.get_ast(&id); - (id, *range, ast) + let cst_result = parser.cst_db.load_parse(&id, content); + (id, range, content, cst_result) } } -struct DefaultFilter; +pub struct DefaultFilter; impl<'a> StatementFilter<'a> for DefaultFilter { - fn apply(&self, _range: &TextRange) -> bool { + fn apply(&self, _id: StatementId, _range: &TextRange) -> bool { true } } + +pub struct CursorPositionFilter { + pos: TextSize, +} + +impl CursorPositionFilter { + pub fn new(pos: TextSize) -> Self { + Self { pos } + } +} + +impl<'a> StatementFilter<'a> for CursorPositionFilter { + fn apply(&self, _id: StatementId, range: &TextRange) -> bool { + range.contains(self.pos) + } +} + +pub struct IdFilter { + id: StatementId, +} + +impl IdFilter { + pub fn new(id: StatementId) -> Self { + Self { id } + } +} + +impl<'a> StatementFilter<'a> for IdFilter { + fn apply(&self, id: StatementId, _range: &TextRange) -> bool { + id == self.id + } +} diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index e1f15de8..fdca5bdd 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -23,8 +23,16 @@ impl TreeSitterStore { } } - pub fn get_parse(&self, statement: &StatementId) -> Option> { - self.db.get(statement).map(|x| x.clone()) + pub fn load_parse(&self, statement: &StatementId, content: &str) -> Arc { + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; + } + + let mut parser = self.parser.lock().expect("Failed to lock parser"); + let tree = Arc::new(parser.parse(content, None).unwrap()); + self.db.insert(statement.clone(), tree.clone()); + + tree } pub fn add_statement(&self, statement: &StatementId, content: &str) { From 944198f9c0c64442581174e9607ba97f17979d22 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Thu, 10 Apr 2025 12:14:00 +0200 Subject: [PATCH 04/20] 
progress on maping borrow checker happy --- .../src/workspace/server/change.rs | 20 +++++---- .../src/workspace/server/parser.rs | 42 +++++++++---------- .../src/workspace/server/sql_function.rs | 8 ++-- .../workspace/server/statement_identifier.rs | 13 ++++++ 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 0b19e5db..0760cf95 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -3,7 +3,7 @@ use std::ops::{Add, Sub}; use crate::workspace::{ChangeFileParams, ChangeParams}; -use super::{Document, Statement, StatementId, document, statement_identifier::StatementId}; +use super::{Document, document, statement_identifier::StatementId}; #[derive(Debug, PartialEq, Eq)] pub enum StatementChange { @@ -32,7 +32,7 @@ pub struct ModifiedStatement { impl StatementChange { #[allow(dead_code)] - pub fn statement(&self) -> &Statement { + pub fn statement(&self) -> &StatementId { match self { StatementChange::Added(stmt) => &stmt.stmt, StatementChange::Deleted(stmt) => stmt, @@ -259,16 +259,16 @@ impl Document { if new_ranges.len() == 1 { let affected_idx = affected_indices[0]; let new_range = new_ranges[0].add(affected_range.start()); - let (old_id, old_range) = self.positions[affected_idx]; + let (old_id, old_range) = self.positions[affected_idx].clone(); // move all statements after the afffected range self.move_ranges(old_range.end(), change.diff_size(), change.is_addition()); let new_id = self.id_generator.next(); - self.positions[affected_idx] = (new_id, new_range); + self.positions[affected_idx] = (new_id.clone(), new_range); changed.push(StatementChange::Modified(ModifiedStatement { - old_stmt: old_id, + old_stmt: old_id.clone(), old_stmt_text: self.content[old_range].to_string(), new_stmt: new_id, @@ -305,15 +305,19 @@ impl Document { // delete and add new ones if let 
Some(next_index) = next_index { - changed.push(StatementChange::Deleted(self.positions[next_index].0)); + changed.push(StatementChange::Deleted( + self.positions[next_index].0.clone(), + )); self.positions.remove(next_index); } for idx in affected_indices.iter().rev() { - changed.push(StatementChange::Deleted(self.positions[*idx].0)); + changed.push(StatementChange::Deleted(self.positions[*idx].0.clone())); self.positions.remove(*idx); } if let Some(prev_index) = prev_index { - changed.push(StatementChange::Deleted(self.positions[prev_index].0)); + changed.push(StatementChange::Deleted( + self.positions[prev_index].0.clone(), + )); self.positions.remove(prev_index); } diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index 633dac59..39d06da6 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -130,13 +130,13 @@ pub trait StatementMapper<'a> { &self, parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output; } pub trait StatementFilter<'a> { - fn apply(&self, id: StatementId, range: &TextRange) -> bool; + fn apply(&self, id: &StatementId, range: &TextRange) -> bool; } pub struct ParseIterator<'a, M, F> { @@ -169,8 +169,8 @@ where fn next(&mut self) -> Option { // First check if we have any pending sub-statements to process if let Some((id, range, content)) = self.pending_sub_statements.pop() { - if self.filter.apply(id, &range) { - return Some(self.mapper.map(self.parser, id, &range, &content)); + if self.filter.apply(&id, &range) { + return Some(self.mapper.map(self.parser, id, range, &content)); } // If the sub-statement doesn't pass the filter, continue to the next item return self.next(); @@ -181,7 +181,7 @@ where if let Some((root_id, range, content)) = next_statement { // If we should include sub-statements and this statement has an AST - if let Ok(ast) = 
*self.parser.ast_db.load_parse(&root_id, &content) { + if let Ok(ast) = self.parser.ast_db.load_parse(&root_id, &content).as_ref() { // Check if this is a SQL function definition with a body if let Some(sub_statement) = self .parser @@ -199,8 +199,8 @@ where } // Return the current statement if it passes the filter - if self.filter.apply(root_id, &range) { - return Some(self.mapper.map(self.parser, root_id, &range, content)); + if self.filter.apply(&root_id, &range) { + return Some(self.mapper.map(self.parser, root_id, range, content)); } // If the current statement doesn't pass the filter, try the next one @@ -213,13 +213,13 @@ where pub struct DefaultMapper; impl<'a> StatementMapper<'a> for DefaultMapper { - type Output = (StatementId, &'a TextRange, &'a str); + type Output = (StatementId, TextRange, &'a str); fn map( &self, _parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output { (id, range, content) @@ -230,7 +230,7 @@ pub struct ExecuteStatementMapper; impl<'a> StatementMapper<'a> for ExecuteStatementMapper { type Output = ( StatementId, - &'a TextRange, + TextRange, &'a str, Option, ); @@ -239,7 +239,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { &self, parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output { let ast_result = parser.ast_db.load_parse(&id, content); @@ -266,7 +266,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { &self, parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output { let ast_result = parser.ast_db.load_parse(&id, content); @@ -278,7 +278,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { let cst_result = parser.cst_db.load_parse(&id, content); - (id, *range, content.to_string(), ast_option, cst_result) + (id, range, content.to_string(), ast_option, cst_result) } } @@ -286,7 +286,7 @@ pub struct 
SyncDiagnosticsMapper; impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { type Output = ( StatementId, - &'a TextRange, + TextRange, Option, Option, ); @@ -295,7 +295,7 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { &self, parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output { let ast_result = parser.ast_db.load_parse(&id, content); @@ -311,13 +311,13 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { pub struct GetCompletionsMapper; impl<'a> StatementMapper<'a> for GetCompletionsMapper { - type Output = (StatementId, &'a TextRange, &'a str, Arc); + type Output = (StatementId, TextRange, &'a str, Arc); fn map( &self, parser: &'a Parser, id: StatementId, - range: &'a TextRange, + range: TextRange, content: &'a str, ) -> Self::Output { let cst_result = parser.cst_db.load_parse(&id, content); @@ -327,7 +327,7 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { pub struct DefaultFilter; impl<'a> StatementFilter<'a> for DefaultFilter { - fn apply(&self, _id: StatementId, _range: &TextRange) -> bool { + fn apply(&self, _id: &StatementId, _range: &TextRange) -> bool { true } } @@ -343,7 +343,7 @@ impl CursorPositionFilter { } impl<'a> StatementFilter<'a> for CursorPositionFilter { - fn apply(&self, _id: StatementId, range: &TextRange) -> bool { + fn apply(&self, _id: &StatementId, range: &TextRange) -> bool { range.contains(self.pos) } } @@ -359,7 +359,7 @@ impl IdFilter { } impl<'a> StatementFilter<'a> for IdFilter { - fn apply(&self, id: StatementId, _range: &TextRange) -> bool { - id == self.id + fn apply(&self, id: &StatementId, _range: &TextRange) -> bool { + *id == self.id } } diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 3378c563..952063e3 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ 
-3,7 +3,7 @@ use std::sync::Arc; use dashmap::DashMap; use pgt_text_size::TextRange; -use super::StatementId; +use super::statement_identifier::StatementId; pub struct SQLFunctionBody { pub range: TextRange, @@ -26,15 +26,15 @@ impl SQLFunctionBodyStore { content: &str, ) -> Option> { // First check if we already have this statement cached - if let Some(entry) = self.db.get(statement) { - return entry; + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; } // If not cached, try to extract it from the AST let fn_body = get_sql_fn(ast, content).map(Arc::new); // Cache the result and return it - self.db.insert(*statement, fn_body.clone()); + self.db.insert(statement.clone(), fn_body.clone()); fn_body } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index e1c47c4b..1926664d 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + pub type RootId = usize; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -39,3 +41,14 @@ impl StatementId { } } } + +impl Deref for StatementId { + type Target = RootId; + + fn deref(&self) -> &Self::Target { + match self { + StatementId::Root(id) => id, + StatementId::Child(id) => id, + } + } +} From 3771c3183481399f140f2d38fe927d025880417b Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Apr 2025 12:51:56 +0200 Subject: [PATCH 05/20] make the lord of ownership happy --- crates/pgt_lsp/src/handlers/code_actions.rs | 6 +-- crates/pgt_typecheck/tests/diagnostics.rs | 4 +- crates/pgt_workspace/src/workspace.rs | 2 +- crates/pgt_workspace/src/workspace/server.rs | 6 ++- .../src/workspace/server/change.rs | 32 +++++--------- .../src/workspace/server/document.rs | 10 ----- .../src/workspace/server/parser.rs | 42 ++++++++++--------- .../workspace/server/statement_identifier.rs | 
10 ++++- 8 files changed, 52 insertions(+), 60 deletions(-) diff --git a/crates/pgt_lsp/src/handlers/code_actions.rs b/crates/pgt_lsp/src/handlers/code_actions.rs index 0d124cfc..cf1ff9bd 100644 --- a/crates/pgt_lsp/src/handlers/code_actions.rs +++ b/crates/pgt_lsp/src/handlers/code_actions.rs @@ -43,7 +43,7 @@ pub fn get_actions( title: title.clone(), command: command_id, arguments: Some(vec![ - serde_json::Value::Number(stmt_id.into()), + serde_json::to_value(&stmt_id).unwrap(), serde_json::to_value(&url).unwrap(), ]), } @@ -81,7 +81,7 @@ pub async fn execute_command( match command.as_str() { "pgt.executeStatement" => { - let id: usize = serde_json::from_value(params.arguments[0].clone())?; + let statement_id = serde_json::from_value::(params.arguments[0].clone())?; let doc_url: lsp_types::Url = serde_json::from_value(params.arguments[1].clone())?; let path = session.file_path(&doc_url)?; @@ -89,7 +89,7 @@ pub async fn execute_command( let result = session .workspace .execute_statement(ExecuteStatementParams { - statement_id: id, + statement_id, path, })?; diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index d0e53b15..46daa8a1 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -21,14 +21,14 @@ async fn test(name: &str, query: &str, setup: &str) { .expect("Error loading sql language"); let root = pgt_query_ext::parse(query).unwrap(); - let tree = parser.parse(query, None); + let tree = parser.parse(query, None).unwrap(); let conn = &test_db; let result = check_sql(TypecheckParams { conn, sql: query, ast: &root, - tree: tree.as_ref(), + tree: &tree, }) .await; diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 4a503d5d..681ab95f 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -21,7 +21,7 @@ use crate::{ mod client; mod server; -pub(crate) use server::StatementId; 
+pub use server::StatementId; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 1d25a58d..a770d33a 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -40,6 +40,8 @@ use super::{ Workspace, }; +pub use statement_identifier::StatementId; + mod analyser; mod async_helper; mod change; @@ -293,7 +295,7 @@ impl Workspace for WorkspaceServer { }); }; - let (id, range, content, ast) = stmt.unwrap(); + let (_id, _range, content, ast) = stmt.unwrap(); if ast.is_none() { return Ok(ExecuteStatementResult { @@ -490,7 +492,7 @@ impl Workspace for WorkspaceServer { position, schema: schema_cache.as_ref(), tree: &cst, - text: content.to_string(), + text: content, }) }) .collect(); diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 0760cf95..0b09a3ca 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -808,11 +808,11 @@ mod tests { assert_eq!(changed.len(), 4); assert!(matches!( changed[0], - StatementChange::Deleted(Statement { id: 1, .. }) + StatementChange::Deleted(StatementId::Root(1)) )); assert!(matches!( changed[1], - StatementChange::Deleted(Statement { id: 0, .. 
}) + StatementChange::Deleted(StatementId::Root(0)) )); assert!( matches!(&changed[2], StatementChange::Added(AddedStatement { stmt: _, text }) if text == "select id,test from users;") @@ -873,35 +873,23 @@ mod tests { assert_eq!( changed[0], - StatementChange::Deleted(Statement { - path: path.clone(), - id: 1 - }) + StatementChange::Deleted(StatementId::Root(1)) ); assert_eq!( changed[1], - StatementChange::Deleted(Statement { - path: path.clone(), - id: 0 - }) + StatementChange::Deleted(StatementId::Root(0)) ); assert_eq!( changed[2], StatementChange::Added(AddedStatement { - stmt: Statement { - path: path.clone(), - id: 2 - }, + stmt: StatementId::Root(2), text: "select id,test from users".to_string() }) ); assert_eq!( changed[3], StatementChange::Added(AddedStatement { - stmt: Statement { - path: path.clone(), - id: 3 - }, + stmt: StatementId::Root(3), text: "select 1;".to_string() }) ); @@ -1132,7 +1120,7 @@ mod tests { assert_eq!(changed.len(), 3); assert!(matches!( changed[0], - StatementChange::Deleted(Statement { id: 0, .. }) + StatementChange::Deleted(StatementId::Root(0)) )); assert!(matches!( changed[1], @@ -1162,11 +1150,11 @@ mod tests { assert_eq!(changed_2.len(), 3); assert!(matches!( changed_2[0], - StatementChange::Deleted(Statement { .. }) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( changed_2[1], - StatementChange::Deleted(Statement { .. }) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( changed_2[2], @@ -1282,7 +1270,7 @@ mod tests { assert!(matches!( changes[0], - StatementChange::Deleted(Statement { .. 
}) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index ae51d27d..ba2c7242 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -34,10 +34,6 @@ impl Document { } } - pub fn diagnostics(&self) -> &[SDiagnostic] { - &self.diagnostics - } - /// Returns true if there is at least one fatal error in the diagnostics /// /// A fatal error is a scan error that prevents the document from being used @@ -87,12 +83,6 @@ pub(crate) fn split_with_diagnostics( } } -pub trait StatementMapper<'a> { - type Output; - - fn map(&self, doc: &'a Document, id: &'a StatementId, range: &'a TextRange) -> Self::Output; -} - pub struct StatementIterator<'a> { document: &'a Document, positions: std::slice::Iter<'a, StatementPos>, diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index 39d06da6..082f9042 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -17,6 +17,7 @@ use super::{ }; pub struct Parser { + #[allow(dead_code)] path: PgTPath, doc: Document, @@ -118,6 +119,7 @@ impl Parser { ParseIterator::new(self, mapper, filter) } + #[allow(dead_code)] pub fn count<'a>(&'a self) -> usize { self.iter(DefaultMapper).count() } @@ -131,7 +133,7 @@ pub trait StatementMapper<'a> { parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output; } @@ -144,7 +146,7 @@ pub struct ParseIterator<'a, M, F> { statements: StatementIterator<'a>, mapper: M, filter: F, - pending_sub_statements: Vec<(StatementId, TextRange, &'a str)>, + pending_sub_statements: Vec<(StatementId, TextRange, String)>, } impl<'a, M, F> ParseIterator<'a, M, F> { @@ -181,19 +183,20 @@ where if let Some((root_id, range, content)) = 
next_statement { // If we should include sub-statements and this statement has an AST - if let Ok(ast) = self.parser.ast_db.load_parse(&root_id, &content).as_ref() { + let content_owned = content.to_string(); + if let Ok(ast) = self.parser.ast_db.load_parse(&root_id, &content_owned).as_ref() { // Check if this is a SQL function definition with a body if let Some(sub_statement) = self .parser .sql_fn_db - .get_function_body(&root_id, &ast, content) + .get_function_body(&root_id, &ast, &content_owned) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( root_id.create_child(), // adjust range to document sub_statement.range + range.start(), - &sub_statement.body, + sub_statement.body.clone(), )); } } @@ -213,16 +216,16 @@ where pub struct DefaultMapper; impl<'a> StatementMapper<'a> for DefaultMapper { - type Output = (StatementId, TextRange, &'a str); + type Output = (StatementId, TextRange, String); fn map( &self, _parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output { - (id, range, content) + (id, range, content.to_string()) } } @@ -231,7 +234,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { type Output = ( StatementId, TextRange, - &'a str, + String, Option, ); @@ -240,7 +243,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output { let ast_result = parser.ast_db.load_parse(&id, content); let ast_option = match &*ast_result { @@ -248,7 +251,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { Err(_) => None, }; - (id, range, content, ast_option) + (id, range, content.to_string(), ast_option) } } @@ -267,18 +270,19 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output { - let ast_result = parser.ast_db.load_parse(&id, content); + 
let content_owned = content.to_string(); + let ast_result = parser.ast_db.load_parse(&id, &content_owned); let ast_option = match &*ast_result { Ok(node) => Some(node.clone()), Err(_) => None, }; - let cst_result = parser.cst_db.load_parse(&id, content); + let cst_result = parser.cst_db.load_parse(&id, &content_owned); - (id, range, content.to_string(), ast_option, cst_result) + (id, range, content_owned, ast_option, cst_result) } } @@ -296,7 +300,7 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output { let ast_result = parser.ast_db.load_parse(&id, content); @@ -311,17 +315,17 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { pub struct GetCompletionsMapper; impl<'a> StatementMapper<'a> for GetCompletionsMapper { - type Output = (StatementId, TextRange, &'a str, Arc); + type Output = (StatementId, TextRange, String, Arc); fn map( &self, parser: &'a Parser, id: StatementId, range: TextRange, - content: &'a str, + content: &str, ) -> Self::Output { let cst_result = parser.cst_db.load_parse(&id, content); - (id, range, content, cst_result) + (id, range, content.to_string(), cst_result) } } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 1926664d..916aea85 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,14 +1,22 @@ use std::ops::Deref; +use serde::{Deserialize, Serialize}; pub type RootId = usize; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] pub enum StatementId { Root(RootId), // StatementId is the same as the root id since we can only have a single sql function body per Root Child(RootId), } +impl 
Default for StatementId { + fn default() -> Self { + StatementId::Root(0) + } +} + /// Helper struct to generate unique statement ids pub struct IdGenerator { next_id: RootId, From b08b0a59067be628b300737d4ab44bad96b29379 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Apr 2025 12:52:08 +0200 Subject: [PATCH 06/20] format --- crates/pgt_lsp/src/handlers/code_actions.rs | 9 ++++----- .../pgt_workspace/src/workspace/server/change.rs | 10 ++-------- .../pgt_workspace/src/workspace/server/parser.rs | 15 ++++++++++----- .../src/workspace/server/statement_identifier.rs | 2 +- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/crates/pgt_lsp/src/handlers/code_actions.rs b/crates/pgt_lsp/src/handlers/code_actions.rs index cf1ff9bd..a10bee03 100644 --- a/crates/pgt_lsp/src/handlers/code_actions.rs +++ b/crates/pgt_lsp/src/handlers/code_actions.rs @@ -81,17 +81,16 @@ pub async fn execute_command( match command.as_str() { "pgt.executeStatement" => { - let statement_id = serde_json::from_value::(params.arguments[0].clone())?; + let statement_id = serde_json::from_value::( + params.arguments[0].clone(), + )?; let doc_url: lsp_types::Url = serde_json::from_value(params.arguments[1].clone())?; let path = session.file_path(&doc_url)?; let result = session .workspace - .execute_statement(ExecuteStatementParams { - statement_id, - path, - })?; + .execute_statement(ExecuteStatementParams { statement_id, path })?; /* * Updating all diagnostics: the changes caused by the statement execution diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 0b09a3ca..f286c4ff 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -871,14 +871,8 @@ mod tests { assert_eq!(changed.len(), 4); - assert_eq!( - changed[0], - StatementChange::Deleted(StatementId::Root(1)) - ); - assert_eq!( - changed[1], - 
StatementChange::Deleted(StatementId::Root(0)) - ); + assert_eq!(changed[0], StatementChange::Deleted(StatementId::Root(1))); + assert_eq!(changed[1], StatementChange::Deleted(StatementId::Root(0))); assert_eq!( changed[2], StatementChange::Added(AddedStatement { diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index 082f9042..d75210a2 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -184,12 +184,17 @@ where if let Some((root_id, range, content)) = next_statement { // If we should include sub-statements and this statement has an AST let content_owned = content.to_string(); - if let Ok(ast) = self.parser.ast_db.load_parse(&root_id, &content_owned).as_ref() { + if let Ok(ast) = self + .parser + .ast_db + .load_parse(&root_id, &content_owned) + .as_ref() + { // Check if this is a SQL function definition with a body - if let Some(sub_statement) = self - .parser - .sql_fn_db - .get_function_body(&root_id, &ast, &content_owned) + if let Some(sub_statement) = + self.parser + .sql_fn_db + .get_function_body(&root_id, &ast, &content_owned) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 916aea85..b416bfda 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,5 +1,5 @@ -use std::ops::Deref; use serde::{Deserialize, Serialize}; +use std::ops::Deref; pub type RootId = usize; From b9d6143f3359cbc3f919846d86ffcb366ee42922 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Apr 2025 12:52:57 +0200 Subject: [PATCH 07/20] fix: lint --- crates/pgt_workspace/src/workspace/server/document.rs | 4 ++-- crates/pgt_workspace/src/workspace/server/parser.rs | 4 
++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index ba2c7242..523fcad5 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -102,10 +102,10 @@ impl<'a> Iterator for StatementIterator<'a> { fn next(&mut self) -> Option { self.positions.next().map(|(id, range)| { - let range = range.clone(); + let range = *range; let doc = self.document; let id = id.clone(); - (id, range, &doc.content[range.clone()]) + (id, range, &doc.content[range]) }) } } diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index d75210a2..38169077 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -120,7 +120,7 @@ impl Parser { } #[allow(dead_code)] - pub fn count<'a>(&'a self) -> usize { + pub fn count(&self) -> usize { self.iter(DefaultMapper).count() } } @@ -194,7 +194,7 @@ where if let Some(sub_statement) = self.parser .sql_fn_db - .get_function_body(&root_id, &ast, &content_owned) + .get_function_body(&root_id, ast, &content_owned) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( From e9244ed77e0096ca001f0e75a429969342211e1c Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:00:57 +0200 Subject: [PATCH 08/20] rename get_or_cache_ast --- crates/pgt_workspace/src/workspace/server/parser.rs | 8 ++++---- crates/pgt_workspace/src/workspace/server/pg_query.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parser.rs index 38169077..83c4ab77 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parser.rs @@ -187,7 +187,7 @@ where if let 
Ok(ast) = self .parser .ast_db - .load_parse(&root_id, &content_owned) + .get_or_cache_ast(&root_id, &content_owned) .as_ref() { // Check if this is a SQL function definition with a body @@ -250,7 +250,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { range: TextRange, content: &str, ) -> Self::Output { - let ast_result = parser.ast_db.load_parse(&id, content); + let ast_result = parser.ast_db.get_or_cache_ast(&id, content); let ast_option = match &*ast_result { Ok(node) => Some(node.clone()), Err(_) => None, @@ -278,7 +278,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { content: &str, ) -> Self::Output { let content_owned = content.to_string(); - let ast_result = parser.ast_db.load_parse(&id, &content_owned); + let ast_result = parser.ast_db.get_or_cache_ast(&id, &content_owned); let ast_option = match &*ast_result { Ok(node) => Some(node.clone()), @@ -307,7 +307,7 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { range: TextRange, content: &str, ) -> Self::Output { - let ast_result = parser.ast_db.load_parse(&id, content); + let ast_result = parser.ast_db.get_or_cache_ast(&id, content); let (ast_option, diagnostics) = match &*ast_result { Ok(node) => (Some(node.clone()), None), diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index c95fc927..572ad716 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -14,7 +14,7 @@ impl PgQueryStore { PgQueryStore { db: DashMap::new() } } - pub fn load_parse( + pub fn get_or_cache_ast( &self, statement: &StatementId, content: &str, From 3f36ffcd9a268f9624ffc469e94a91cd2db89772 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:02:17 +0200 Subject: [PATCH 09/20] renamings --- crates/pgt_workspace/src/workspace/server.rs | 16 +++++----- .../{parser.rs => parsed_statements.rs} | 32 +++++++++---------- 2 files changed, 24 
insertions(+), 24 deletions(-) rename crates/pgt_workspace/src/workspace/server/{parser.rs => parsed_statements.rs} (94%) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index a770d33a..66a5a91c 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -6,9 +6,9 @@ use dashmap::DashMap; use db_connection::DbConnection; use document::Document; use futures::{StreamExt, stream}; -use parser::{ +use parsed_statements::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, - GetCompletionsMapper, Parser, SyncDiagnosticsMapper, + GetCompletionsMapper, ParsedStatements, SyncDiagnosticsMapper, }; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; @@ -48,7 +48,7 @@ mod change; mod db_connection; mod document; mod migration; -mod parser; +mod parsed_statements; mod pg_query; mod schema_cache_manager; mod sql_function; @@ -62,7 +62,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - parsers: DashMap, + parsers: DashMap, connection: RwLock, } @@ -182,9 +182,9 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - self.parsers - .entry(params.path.clone()) - .or_insert_with(|| Parser::new(params.path.clone(), params.content, params.version)); + self.parsers.entry(params.path.clone()).or_insert_with(|| { + ParsedStatements::new(params.path.clone(), params.content, params.version) + }); Ok(()) } @@ -207,7 +207,7 @@ impl Workspace for WorkspaceServer { let mut parser = self .parsers .entry(params.path.clone()) - .or_insert(Parser::new( + .or_insert(ParsedStatements::new( params.path.clone(), 
"".to_string(), params.version, diff --git a/crates/pgt_workspace/src/workspace/server/parser.rs b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs similarity index 94% rename from crates/pgt_workspace/src/workspace/server/parser.rs rename to crates/pgt_workspace/src/workspace/server/parsed_statements.rs index 83c4ab77..804c7808 100644 --- a/crates/pgt_workspace/src/workspace/server/parser.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs @@ -16,7 +16,7 @@ use super::{ tree_sitter::TreeSitterStore, }; -pub struct Parser { +pub struct ParsedStatements { #[allow(dead_code)] path: PgTPath, @@ -26,8 +26,8 @@ pub struct Parser { sql_fn_db: SQLFunctionBodyStore, } -impl Parser { - pub fn new(path: PgTPath, content: String, version: i32) -> Parser { +impl ParsedStatements { + pub fn new(path: PgTPath, content: String, version: i32) -> ParsedStatements { let doc = Document::new(content, version); let cst_db = TreeSitterStore::new(); @@ -38,7 +38,7 @@ impl Parser { cst_db.add_statement(&stmt, content); }); - Parser { + ParsedStatements { path, doc, ast_db, @@ -104,11 +104,11 @@ impl Parser { self.iter_with_filter(mapper, IdFilter::new(id)).next() } - pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, DefaultFilter> + pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, NoFilter> where M: StatementMapper<'a>, { - self.iter_with_filter(mapper, DefaultFilter) + self.iter_with_filter(mapper, NoFilter) } pub fn iter_with_filter<'a, M, F>(&'a self, mapper: M, filter: F) -> ParseIterator<'a, M, F> @@ -130,7 +130,7 @@ pub trait StatementMapper<'a> { fn map( &self, - parser: &'a Parser, + parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -142,7 +142,7 @@ pub trait StatementFilter<'a> { } pub struct ParseIterator<'a, M, F> { - parser: &'a Parser, + parser: &'a ParsedStatements, statements: StatementIterator<'a>, mapper: M, filter: F, @@ -150,7 +150,7 @@ pub struct ParseIterator<'a, M, F> { } 
impl<'a, M, F> ParseIterator<'a, M, F> { - pub fn new(parser: &'a Parser, mapper: M, filter: F) -> Self { + pub fn new(parser: &'a ParsedStatements, mapper: M, filter: F) -> Self { Self { parser, statements: parser.doc.iter(), @@ -225,7 +225,7 @@ impl<'a> StatementMapper<'a> for DefaultMapper { fn map( &self, - _parser: &'a Parser, + _parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -245,7 +245,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { fn map( &self, - parser: &'a Parser, + parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -272,7 +272,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { fn map( &self, - parser: &'a Parser, + parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -302,7 +302,7 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { fn map( &self, - parser: &'a Parser, + parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -324,7 +324,7 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { fn map( &self, - parser: &'a Parser, + parser: &'a ParsedStatements, id: StatementId, range: TextRange, content: &str, @@ -334,8 +334,8 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { } } -pub struct DefaultFilter; -impl<'a> StatementFilter<'a> for DefaultFilter { +pub struct NoFilter; +impl<'a> StatementFilter<'a> for NoFilter { fn apply(&self, _id: &StatementId, _range: &TextRange) -> bool { true } From dc7afd658c190c4d0d89d5ed92faafb5e36c347e Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:02:46 +0200 Subject: [PATCH 10/20] predicate, a la rust std lib --- .../src/workspace/server/parsed_statements.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs index 804c7808..f7781a11 100644 --- 
a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs @@ -138,7 +138,7 @@ pub trait StatementMapper<'a> { } pub trait StatementFilter<'a> { - fn apply(&self, id: &StatementId, range: &TextRange) -> bool; + fn predicate(&self, id: &StatementId, range: &TextRange) -> bool; } pub struct ParseIterator<'a, M, F> { @@ -171,7 +171,7 @@ where fn next(&mut self) -> Option { // First check if we have any pending sub-statements to process if let Some((id, range, content)) = self.pending_sub_statements.pop() { - if self.filter.apply(&id, &range) { + if self.filter.predicate(&id, &range) { return Some(self.mapper.map(self.parser, id, range, &content)); } // If the sub-statement doesn't pass the filter, continue to the next item @@ -207,7 +207,7 @@ where } // Return the current statement if it passes the filter - if self.filter.apply(&root_id, &range) { + if self.filter.predicate(&root_id, &range) { return Some(self.mapper.map(self.parser, root_id, range, content)); } @@ -336,7 +336,7 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { pub struct NoFilter; impl<'a> StatementFilter<'a> for NoFilter { - fn apply(&self, _id: &StatementId, _range: &TextRange) -> bool { + fn predicate(&self, _id: &StatementId, _range: &TextRange) -> bool { true } } @@ -352,7 +352,7 @@ impl CursorPositionFilter { } impl<'a> StatementFilter<'a> for CursorPositionFilter { - fn apply(&self, _id: &StatementId, range: &TextRange) -> bool { + fn predicate(&self, _id: &StatementId, range: &TextRange) -> bool { range.contains(self.pos) } } @@ -368,7 +368,7 @@ impl IdFilter { } impl<'a> StatementFilter<'a> for IdFilter { - fn apply(&self, id: &StatementId, _range: &TextRange) -> bool { + fn predicate(&self, id: &StatementId, _range: &TextRange) -> bool { *id == self.id } } From b08de29e6faf0784a0c4b255f12cdbff47a2c746 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:03:14 +0200 Subject: [PATCH 11/20] get or 
cache tree --- .../pgt_workspace/src/workspace/server/parsed_statements.rs | 4 ++-- crates/pgt_workspace/src/workspace/server/tree_sitter.rs | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs index f7781a11..bb923402 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_statements.rs @@ -285,7 +285,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { Err(_) => None, }; - let cst_result = parser.cst_db.load_parse(&id, &content_owned); + let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned); (id, range, content_owned, ast_option, cst_result) } @@ -329,7 +329,7 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { range: TextRange, content: &str, ) -> Self::Output { - let cst_result = parser.cst_db.load_parse(&id, content); + let cst_result = parser.cst_db.get_or_cache_tree(&id, content); (id, range, content.to_string(), cst_result) } } diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index fdca5bdd..782e8abe 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -23,7 +23,11 @@ impl TreeSitterStore { } } - pub fn load_parse(&self, statement: &StatementId, content: &str) -> Arc { + pub fn get_or_cache_tree( + &self, + statement: &StatementId, + content: &str, + ) -> Arc { if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { return existing; } From 6d772cf2b552d7f017dcbbe40b064c98af422acb Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:03:54 +0200 Subject: [PATCH 12/20] ok --- crates/pgt_workspace/src/workspace/server.rs | 40 ++++++++++---------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git 
a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 66a5a91c..28a728b1 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -62,7 +62,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - parsers: DashMap, + parsed_documents: DashMap, connection: RwLock, } @@ -84,7 +84,7 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - parsers: DashMap::default(), + parsed_documents: DashMap::default(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), } @@ -182,16 +182,18 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - self.parsers.entry(params.path.clone()).or_insert_with(|| { - ParsedStatements::new(params.path.clone(), params.content, params.version) - }); + self.parsed_documents + .entry(params.path.clone()) + .or_insert_with(|| { + ParsedStatements::new(params.path.clone(), params.content, params.version) + }); Ok(()) } /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - self.parsers + self.parsed_documents .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; @@ -204,14 +206,14 @@ impl Workspace for WorkspaceServer { version = params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut parser = self - .parsers - .entry(params.path.clone()) - .or_insert(ParsedStatements::new( - params.path.clone(), - "".to_string(), - params.version, - )); + let mut parser = + self.parsed_documents + .entry(params.path.clone()) + .or_insert(ParsedStatements::new( + params.path.clone(), + 
"".to_string(), + params.version, + )); parser.apply_change(params); @@ -224,7 +226,7 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .parsers + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; Ok(document.get_document_content().to_string()) @@ -239,7 +241,7 @@ impl Workspace for WorkspaceServer { params: code_actions::CodeActionsParams, ) -> Result { let parser = self - .parsers + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -283,7 +285,7 @@ impl Workspace for WorkspaceServer { params: ExecuteStatementParams, ) -> Result { let parser = self - .parsers + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -351,7 +353,7 @@ impl Workspace for WorkspaceServer { }); let parser = self - .parsers + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -468,7 +470,7 @@ impl Workspace for WorkspaceServer { params: GetCompletionsParams, ) -> Result { let parser = self - .parsers + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; From 0cb5c90096a644278c8efab11acd111232227db4 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:05:26 +0200 Subject: [PATCH 13/20] cant be any clearer --- crates/pgt_workspace/src/workspace/server.rs | 34 ++++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 28a728b1..2343620c 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -62,7 +62,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - parsed_documents: DashMap, + parsed_stmts_by_path: DashMap, connection: RwLock, } @@ -84,7 +84,7 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - 
parsed_documents: DashMap::default(), + parsed_stmts_by_path: DashMap::default(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), } @@ -182,7 +182,7 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - self.parsed_documents + self.parsed_stmts_by_path .entry(params.path.clone()) .or_insert_with(|| { ParsedStatements::new(params.path.clone(), params.content, params.version) @@ -193,7 +193,7 @@ impl Workspace for WorkspaceServer { /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - self.parsed_documents + self.parsed_stmts_by_path .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; @@ -206,14 +206,14 @@ impl Workspace for WorkspaceServer { version = params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut parser = - self.parsed_documents - .entry(params.path.clone()) - .or_insert(ParsedStatements::new( - params.path.clone(), - "".to_string(), - params.version, - )); + let mut parser = self + .parsed_stmts_by_path + .entry(params.path.clone()) + .or_insert(ParsedStatements::new( + params.path.clone(), + "".to_string(), + params.version, + )); parser.apply_change(params); @@ -226,7 +226,7 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .parsed_documents + .parsed_stmts_by_path .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; Ok(document.get_document_content().to_string()) @@ -241,7 +241,7 @@ impl Workspace for WorkspaceServer { params: code_actions::CodeActionsParams, ) -> Result { let parser = self - .parsed_documents + .parsed_stmts_by_path .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -285,7 
+285,7 @@ impl Workspace for WorkspaceServer { params: ExecuteStatementParams, ) -> Result { let parser = self - .parsed_documents + .parsed_stmts_by_path .get(&params.path) .ok_or(WorkspaceError::not_found())?; @@ -353,7 +353,7 @@ impl Workspace for WorkspaceServer { }); let parser = self - .parsed_documents + .parsed_stmts_by_path .get(&params.path) .ok_or(WorkspaceError::not_found())?; @@ -470,7 +470,7 @@ impl Workspace for WorkspaceServer { params: GetCompletionsParams, ) -> Result { let parser = self - .parsed_documents + .parsed_stmts_by_path .get(&params.path) .ok_or(WorkspaceError::not_found())?; From 10ec65ed31919e87bcbfe30b696c31f8e5d56c93 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:12:17 +0200 Subject: [PATCH 14/20] comment stmt identifier --- .../src/workspace/server/statement_identifier.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index b416bfda..5a5df939 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -5,6 +5,22 @@ pub type RootId = usize; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +/// `StatementId` can represent IDs for nested statements. +/// +/// For example, an SQL function really consists of two statements; the function creation +/// and the body: +/// +/// ```sql +/// create or replace function get_product_name(product_id INT) -- the root statement +/// returns varchar as $$ +/// select * from … -- the child statement +/// $$ LANGUAGE plpgsql; +/// ``` +/// +/// For now, we only support SQL functions – no complex, nested statements. +/// +/// An SQL function only ever has ONE child, which is why the inner `RootId` of a `Root` +/// is the same as the one of its `Child`. 
pub enum StatementId { Root(RootId), // StatementId is the same as the root id since we can only have a single sql function body per Root From c8858c47c88ba3021e1bf4501e09633506212d06 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 20:19:19 +0200 Subject: [PATCH 15/20] more comments --- crates/pgt_workspace/src/workspace/server/pg_query.rs | 2 +- crates/pgt_workspace/src/workspace/server/sql_function.rs | 2 +- .../src/workspace/server/statement_identifier.rs | 8 +++++++- crates/pgt_workspace/src/workspace/server/tree_sitter.rs | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 572ad716..e5c0cac8 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -31,7 +31,7 @@ impl PgQueryStore { pub fn clear_statement(&self, id: &StatementId) { self.db.remove(id); - if let Some(child_id) = id.as_child() { + if let Some(child_id) = id.get_child_id() { self.db.remove(&child_id); } } diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 952063e3..3273466d 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -41,7 +41,7 @@ impl SQLFunctionBodyStore { pub fn clear_statement(&self, id: &StatementId) { self.db.remove(id); - if let Some(child_id) = id.as_child() { + if let Some(child_id) = id.get_child_id() { self.db.remove(&child_id); } } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 5a5df939..a5cc306c 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -51,13 +51,19 @@ impl IdGenerator { } impl 
StatementId { - pub fn as_child(&self) -> Option { + /// Use this to get the matching `StatementId::Child` for + /// a `StatementId::Root`. + /// If the `StatementId` was already a `Child`, this will return `None`. + /// It is not guaranteed that the `Root` actually has a `Child` statement in the workspace. + pub fn get_child_id(&self) -> Option { match self { StatementId::Root(id) => Some(StatementId::Child(*id)), StatementId::Child(_) => None, } } + /// Use this if you need to create a matching `StatementId::Child` for `Root`. + /// You cannot create a `Child` of a `Child`. pub fn create_child(&self) -> StatementId { match self { StatementId::Root(id) => StatementId::Child(*id), diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index 782e8abe..a8932535 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -48,7 +48,7 @@ impl TreeSitterStore { pub fn remove_statement(&self, id: &StatementId) { self.db.remove(id); - if let Some(child_id) = id.as_child() { + if let Some(child_id) = id.get_child_id() { self.db.remove(&child_id); } } From 682b7ee576cbd313ce06f079319ac704b41439fe Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 22:22:17 +0200 Subject: [PATCH 16/20] rootid --- .../src/workspace/server/change.rs | 38 +++++++++------- .../src/workspace/server/document.rs | 6 +-- .../workspace/server/statement_identifier.rs | 45 +++++++++++-------- 3 files changed, 52 insertions(+), 37 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index f286c4ff..31ffe769 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -402,11 +402,16 @@ fn get_affected(content: &str, range: TextRange) -> &str { #[cfg(test)] mod tests { + use std::ops::Deref; + use super::*; 
use pgt_diagnostics::Diagnostic; use pgt_text_size::TextRange; - use crate::workspace::{ChangeFileParams, ChangeParams}; + use crate::workspace::{ + ChangeFileParams, ChangeParams, + server::statement_identifier::{RootId, StatementIdGenerator, root_id}, + }; use pgt_fs::PgTPath; @@ -806,14 +811,13 @@ mod tests { let changed = d.apply_file_change(&change); assert_eq!(changed.len(), 4); - assert!(matches!( - changed[0], - StatementChange::Deleted(StatementId::Root(1)) - )); + assert!(matches!(changed[0], StatementChange::Deleted(_))); + assert_eq!(changed[0].statement().raw(), 1); assert!(matches!( changed[1], - StatementChange::Deleted(StatementId::Root(0)) + StatementChange::Deleted(StatementId::Root(_)) )); + assert_eq!(changed[1].statement().raw(), 0); assert!( matches!(&changed[2], StatementChange::Added(AddedStatement { stmt: _, text }) if text == "select id,test from users;") ); @@ -871,19 +875,27 @@ mod tests { assert_eq!(changed.len(), 4); - assert_eq!(changed[0], StatementChange::Deleted(StatementId::Root(1))); - assert_eq!(changed[1], StatementChange::Deleted(StatementId::Root(0))); + assert!(matches!( + changed[0], + StatementChange::Deleted(StatementId::Root(_)) + )); + assert_eq!(changed[0].statement().raw(), 1); + assert!(matches!( + changed[1], + StatementChange::Deleted(StatementId::Root(_)) + )); + assert_eq!(changed[1].statement().raw(), 0); assert_eq!( changed[2], StatementChange::Added(AddedStatement { - stmt: StatementId::Root(2), + stmt: StatementId::Root(root_id(2)), text: "select id,test from users".to_string() }) ); assert_eq!( changed[3], StatementChange::Added(AddedStatement { - stmt: StatementId::Root(3), + stmt: StatementId::Root(root_id(3)), text: "select 1;".to_string() }) ); @@ -1110,12 +1122,8 @@ mod tests { doc.content, "- Add new schema named \"private\"\nCREATE SCHEMA \"private\";" ); - assert_eq!(changed.len(), 3); - assert!(matches!( - changed[0], - StatementChange::Deleted(StatementId::Root(0)) - )); + 
assert!(matches!(&changed[0], StatementChange::Deleted(_))); assert!(matches!( changed[1], StatementChange::Added(AddedStatement { .. }) diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 523fcad5..f2c500cc 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -1,7 +1,7 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; use pgt_text_size::{TextRange, TextSize}; -use super::statement_identifier::{IdGenerator, StatementId}; +use super::statement_identifier::{StatementId, StatementIdGenerator}; type StatementPos = (StatementId, TextRange); @@ -13,12 +13,12 @@ pub(crate) struct Document { /// List of statements sorted by range.start() pub(super) positions: Vec, - pub(super) id_generator: IdGenerator, + pub(super) id_generator: StatementIdGenerator, } impl Document { pub(crate) fn new(content: String, version: i32) -> Self { - let mut id_generator = IdGenerator::new(); + let mut id_generator = StatementIdGenerator::new(); let (ranges, diagnostics) = split_with_diagnostics(&content, None); diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index a5cc306c..03e3683e 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,7 +1,16 @@ use serde::{Deserialize, Serialize}; use std::ops::Deref; -pub type RootId = usize; +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +pub struct RootId { + inner: usize, +} + +#[cfg(test)] +pub fn root_id(inner: usize) -> RootId { + RootId { inner } +} #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[cfg_attr(feature = "schema", 
derive(schemars::JsonSchema))] @@ -29,16 +38,25 @@ pub enum StatementId { impl Default for StatementId { fn default() -> Self { - StatementId::Root(0) + StatementId::Root(RootId { inner: 0 }) + } +} + +impl StatementId { + pub fn raw(&self) -> usize { + match self { + StatementId::Root(s) => s.inner, + StatementId::Child(s) => s.inner, + } } } /// Helper struct to generate unique statement ids -pub struct IdGenerator { - next_id: RootId, +pub struct StatementIdGenerator { + next_id: usize, } -impl IdGenerator { +impl StatementIdGenerator { pub fn new() -> Self { Self { next_id: 0 } } @@ -46,7 +64,7 @@ impl IdGenerator { pub fn next(&mut self) -> StatementId { let id = self.next_id; self.next_id += 1; - StatementId::Root(id) + StatementId::Root(RootId { inner: id }) } } @@ -57,7 +75,7 @@ impl StatementId { /// It is not guaranteed that the `Root` actually has a `Child` statement in the workspace. pub fn get_child_id(&self) -> Option { match self { - StatementId::Root(id) => Some(StatementId::Child(*id)), + StatementId::Root(id) => Some(StatementId::Child(RootId { inner: id.inner })), StatementId::Child(_) => None, } } @@ -66,19 +84,8 @@ impl StatementId { /// You cannot create a `Child` of a `Child`. 
pub fn create_child(&self) -> StatementId { match self { - StatementId::Root(id) => StatementId::Child(*id), + StatementId::Root(id) => StatementId::Child(RootId { inner: id.inner }), StatementId::Child(_) => panic!("Cannot create child from a child statement id"), } } } - -impl Deref for StatementId { - type Target = RootId; - - fn deref(&self) -> &Self::Target { - match self { - StatementId::Root(id) => id, - StatementId::Child(id) => id, - } - } -} From f3f0181238d9d4df304927a3f67f93cdaf4382f8 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 22:23:51 +0200 Subject: [PATCH 17/20] parsed_document --- crates/pgt_workspace/src/workspace/server.rs | 24 +++++++++---------- ...arsed_statements.rs => parsed_document.rs} | 24 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) rename crates/pgt_workspace/src/workspace/server/{parsed_statements.rs => parsed_document.rs} (95%) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 2343620c..b180e6d9 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -6,9 +6,9 @@ use dashmap::DashMap; use db_connection::DbConnection; use document::Document; use futures::{StreamExt, stream}; -use parsed_statements::{ +use parsed_document::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, - GetCompletionsMapper, ParsedStatements, SyncDiagnosticsMapper, + GetCompletionsMapper, ParsedDocument, SyncDiagnosticsMapper, }; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; @@ -48,7 +48,7 @@ mod change; mod db_connection; mod document; mod migration; -mod parsed_statements; +mod parsed_document; mod pg_query; mod schema_cache_manager; mod sql_function; @@ -62,7 +62,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - 
parsed_stmts_by_path: DashMap, + parsed_documents: DashMap, connection: RwLock, } @@ -84,7 +84,7 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - parsed_stmts_by_path: DashMap::default(), + parsed_documents: DashMap::default(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), } @@ -185,7 +185,7 @@ impl Workspace for WorkspaceServer { self.parsed_stmts_by_path .entry(params.path.clone()) .or_insert_with(|| { - ParsedStatements::new(params.path.clone(), params.content, params.version) + ParsedDocument::new(params.path.clone(), params.content, params.version) }); Ok(()) @@ -193,7 +193,7 @@ impl Workspace for WorkspaceServer { /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - self.parsed_stmts_by_path + self.parsed_documents .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; @@ -209,7 +209,7 @@ impl Workspace for WorkspaceServer { let mut parser = self .parsed_stmts_by_path .entry(params.path.clone()) - .or_insert(ParsedStatements::new( + .or_insert(ParsedDocument::new( params.path.clone(), "".to_string(), params.version, @@ -226,7 +226,7 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .parsed_stmts_by_path + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; Ok(document.get_document_content().to_string()) @@ -241,7 +241,7 @@ impl Workspace for WorkspaceServer { params: code_actions::CodeActionsParams, ) -> Result { let parser = self - .parsed_stmts_by_path + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -285,7 +285,7 @@ impl Workspace for WorkspaceServer { params: ExecuteStatementParams, ) -> Result { let parser = self - .parsed_stmts_by_path + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -353,7 +353,7 @@ impl Workspace for WorkspaceServer { }); let parser = 
self - .parsed_stmts_by_path + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; diff --git a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs similarity index 95% rename from crates/pgt_workspace/src/workspace/server/parsed_statements.rs rename to crates/pgt_workspace/src/workspace/server/parsed_document.rs index bb923402..a110fb1f 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_statements.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -16,7 +16,7 @@ use super::{ tree_sitter::TreeSitterStore, }; -pub struct ParsedStatements { +pub struct ParsedDocument { #[allow(dead_code)] path: PgTPath, @@ -26,8 +26,8 @@ pub struct ParsedStatements { sql_fn_db: SQLFunctionBodyStore, } -impl ParsedStatements { - pub fn new(path: PgTPath, content: String, version: i32) -> ParsedStatements { +impl ParsedDocument { + pub fn new(path: PgTPath, content: String, version: i32) -> ParsedDocument { let doc = Document::new(content, version); let cst_db = TreeSitterStore::new(); @@ -38,7 +38,7 @@ impl ParsedStatements { cst_db.add_statement(&stmt, content); }); - ParsedStatements { + ParsedDocument { path, doc, ast_db, @@ -130,7 +130,7 @@ pub trait StatementMapper<'a> { fn map( &self, - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -142,7 +142,7 @@ pub trait StatementFilter<'a> { } pub struct ParseIterator<'a, M, F> { - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, statements: StatementIterator<'a>, mapper: M, filter: F, @@ -150,7 +150,7 @@ pub struct ParseIterator<'a, M, F> { } impl<'a, M, F> ParseIterator<'a, M, F> { - pub fn new(parser: &'a ParsedStatements, mapper: M, filter: F) -> Self { + pub fn new(parser: &'a ParsedDocument, mapper: M, filter: F) -> Self { Self { parser, statements: parser.doc.iter(), @@ -225,7 +225,7 @@ impl<'a> StatementMapper<'a> for DefaultMapper 
{ fn map( &self, - _parser: &'a ParsedStatements, + _parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -245,7 +245,7 @@ impl<'a> StatementMapper<'a> for ExecuteStatementMapper { fn map( &self, - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -272,7 +272,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { fn map( &self, - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -302,7 +302,7 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { fn map( &self, - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -324,7 +324,7 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { fn map( &self, - parser: &'a ParsedStatements, + parser: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, From d77a3bb6e6f1f5a747710b8aaa7ccd60feb8657b Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 22:25:08 +0200 Subject: [PATCH 18/20] whoops --- crates/pgt_workspace/src/workspace/server.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index b180e6d9..27f5e8be 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -182,7 +182,7 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - self.parsed_stmts_by_path + self.parsed_documents .entry(params.path.clone()) .or_insert_with(|| { ParsedDocument::new(params.path.clone(), params.content, params.version) @@ -206,14 +206,14 @@ impl Workspace for WorkspaceServer { version = 
params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut parser = self - .parsed_stmts_by_path - .entry(params.path.clone()) - .or_insert(ParsedDocument::new( - params.path.clone(), - "".to_string(), - params.version, - )); + let mut parser = + self.parsed_documents + .entry(params.path.clone()) + .or_insert(ParsedDocument::new( + params.path.clone(), + "".to_string(), + params.version, + )); parser.apply_change(params); @@ -470,7 +470,7 @@ impl Workspace for WorkspaceServer { params: GetCompletionsParams, ) -> Result { let parser = self - .parsed_stmts_by_path + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; From 19937a34eebd784afc4bb2872bfd09e64b2a483c Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 22:26:38 +0200 Subject: [PATCH 19/20] ok --- crates/pgt_workspace/src/workspace/server/change.rs | 6 +----- .../src/workspace/server/statement_identifier.rs | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 31ffe769..d68538bf 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -402,16 +402,12 @@ fn get_affected(content: &str, range: TextRange) -> &str { #[cfg(test)] mod tests { - use std::ops::Deref; use super::*; use pgt_diagnostics::Diagnostic; use pgt_text_size::TextRange; - use crate::workspace::{ - ChangeFileParams, ChangeParams, - server::statement_identifier::{RootId, StatementIdGenerator, root_id}, - }; + use crate::workspace::{ChangeFileParams, ChangeParams, server::statement_identifier::root_id}; use pgt_fs::PgTPath; diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 03e3683e..0739fb2f 100644 --- 
a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::ops::Deref; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] From 9f7569f9a423278da62a568bd63e44d6f2df5488 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 22:28:26 +0200 Subject: [PATCH 20/20] fix --- crates/pgt_workspace/src/workspace/server/change.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 31086634..afe0eb64 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -1292,7 +1292,7 @@ mod tests { fn remove_trailing_whitespace() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from ".to_string(), 0); + let mut doc = Document::new("select * from ".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1338,7 +1338,7 @@ mod tests { fn remove_trailing_whitespace_and_last_char() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from ".to_string(), 0); + let mut doc = Document::new("select * from ".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1384,7 +1384,7 @@ mod tests { fn remove_inbetween_whitespace() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from users".to_string(), 0); + let mut doc = Document::new("select * from users".to_string(), 0); let change = ChangeFileParams { path: path.clone(),