diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index fb00aeaf..ed51c653 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -14,7 +14,7 @@ pub struct CompletionParams<'a> { pub position: TextSize, pub schema: &'a pgt_schema_cache::SchemaCache, pub text: String, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, } pub fn complete(params: CompletionParams) -> Vec { diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 8b12742d..775b8870 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -50,7 +50,7 @@ impl TryFrom for ClauseType { pub(crate) struct CompletionContext<'a> { pub ts_node: Option>, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, pub text: &'a str, pub schema_cache: &'a SchemaCache, pub position: usize, @@ -85,10 +85,7 @@ impl<'a> CompletionContext<'a> { } fn gather_info_from_ts_queries(&mut self) { - let tree = match self.tree.as_ref() { - None => return, - Some(t) => t, - }; + let tree = self.tree; let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; @@ -126,11 +123,7 @@ impl<'a> CompletionContext<'a> { } fn gather_tree_context(&mut self) { - if self.tree.is_none() { - return; - } - - let mut cursor = self.tree.as_ref().unwrap().root_node().walk(); + let mut cursor = self.tree.root_node().walk(); /* * The head node of any treesitter tree is always the "PROGRAM" node. @@ -262,7 +255,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -294,7 +287,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -328,7 +321,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -353,7 +346,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -381,7 +374,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -407,7 +400,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -432,7 +425,7 @@ mod tests { let params = crate::CompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: &tree, schema: &pgt_schema_cache::SchemaCache::default(), }; diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index a54aacbd..58e9baf7 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -70,7 +70,7 @@ pub(crate) fn get_test_params<'a>( CompletionParams { position: (position as u32).into(), schema: schema_cache, - tree: Some(tree), + tree, text, } } diff --git a/crates/pgt_lsp/src/handlers/code_actions.rs b/crates/pgt_lsp/src/handlers/code_actions.rs index 0d124cfc..a10bee03 100644 --- a/crates/pgt_lsp/src/handlers/code_actions.rs +++ b/crates/pgt_lsp/src/handlers/code_actions.rs @@ -43,7 +43,7 @@ pub fn get_actions( title: title.clone(), command: command_id, arguments: Some(vec![ - serde_json::Value::Number(stmt_id.into()), + serde_json::to_value(&stmt_id).unwrap(), serde_json::to_value(&url).unwrap(), ]), } @@ -81,17 +81,16 @@ pub async fn execute_command( match command.as_str() { "pgt.executeStatement" => { - let id: usize = serde_json::from_value(params.arguments[0].clone())?; + let statement_id = serde_json::from_value::( + params.arguments[0].clone(), + )?; let doc_url: lsp_types::Url = serde_json::from_value(params.arguments[1].clone())?; let path = session.file_path(&doc_url)?; let result = session .workspace - .execute_statement(ExecuteStatementParams { - statement_id: id, - path, - })?; + .execute_statement(ExecuteStatementParams { statement_id, path })?; /* * Updating all diagnostics: the changes caused by the statement execution diff --git a/crates/pgt_typecheck/src/diagnostics.rs b/crates/pgt_typecheck/src/diagnostics.rs index b443dcc9..8fd92da2 100644 --- a/crates/pgt_typecheck/src/diagnostics.rs +++ b/crates/pgt_typecheck/src/diagnostics.rs @@ -96,7 +96,7 @@ impl Advices for TypecheckAdvices { pub(crate) fn create_type_error( pg_err: &PgDatabaseError, - ts: Option<&tree_sitter::Tree>, + ts: &tree_sitter::Tree, ) -> TypecheckDiagnostic { let position = pg_err.position().and_then(|pos| match pos { sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1), @@ -104,16 +104,14 @@ pub(crate) fn create_type_error( }); let range = position.and_then(|pos| { - ts.and_then(|tree| { - tree.root_node() - .named_descendant_for_byte_range(pos, pos) - .map(|node| { - TextRange::new( - node.start_byte().try_into().unwrap(), - node.end_byte().try_into().unwrap(), - ) - }) - }) + ts.root_node() + .named_descendant_for_byte_range(pos, pos) + .map(|node| { + TextRange::new( + node.start_byte().try_into().unwrap(), + node.end_byte().try_into().unwrap(), + ) + }) }); let severity = match pg_err.severity() { diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 4554689c..9311bb8e 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -13,7 +13,7 @@ pub struct TypecheckParams<'a> { pub conn: &'a PgPool, pub sql: &'a str, pub ast: &'a pgt_query_ext::NodeEnum, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, } #[derive(Debug, Clone)] diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index d0e53b15..46daa8a1 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -21,14 +21,14 @@ async fn test(name: &str, query: &str, setup: &str) { .expect("Error loading sql language"); let root = pgt_query_ext::parse(query).unwrap(); - let tree = parser.parse(query, None); + let tree = parser.parse(query, None).unwrap(); let conn = &test_db; let result = check_sql(TypecheckParams { conn, sql: query, ast: &root, - tree: tree.as_ref(), + tree: &tree, }) .await; diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 4a503d5d..681ab95f 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -21,7 +21,7 @@ use crate::{ mod client; mod server; -pub(crate) use server::StatementId; +pub use server::StatementId; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 8dcbfb1d..27f5e8be 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -2,22 +2,24 @@ use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; use analyser::AnalyserVisitorBuilder; use async_helper::run_async; -use change::StatementChange; use dashmap::DashMap; use db_connection::DbConnection; -pub(crate) use document::StatementId; -use document::{Document, Statement}; +use document::Document; use futures::{StreamExt, stream}; -use pg_query::PgQueryStore; +use parsed_document::{ + AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, + GetCompletionsMapper, ParsedDocument, SyncDiagnosticsMapper, +}; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; -use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; +use pgt_diagnostics::{ + Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, +}; use pgt_fs::{ConfigName, PgTPath}; use pgt_typecheck::TypecheckParams; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; use tracing::info; -use tree_sitter::TreeSitterStore; use crate::{ WorkspaceError, @@ -38,14 +40,19 @@ use super::{ Workspace, }; +pub use statement_identifier::StatementId; + mod analyser; mod async_helper; mod change; mod db_connection; mod document; mod migration; +mod parsed_document; mod pg_query; mod schema_cache_manager; +mod sql_function; +mod statement_identifier; mod tree_sitter; pub(super) struct WorkspaceServer { @@ -55,11 +62,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - /// Stores the document (text content + version number) associated with a URL - documents: DashMap, - - tree_sitter: TreeSitterStore, - pg_query: PgQueryStore, + parsed_documents: DashMap, connection: RwLock, } @@ -81,9 +84,7 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - documents: DashMap::default(), - tree_sitter: TreeSitterStore::new(), - pg_query: PgQueryStore::new(), + parsed_documents: DashMap::default(), schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), } @@ -181,30 +182,21 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - let doc = Document::new(params.path.clone(), params.content, params.version); - - doc.iter_statements_with_text().for_each(|(stmt, content)| { - self.tree_sitter.add_statement(&stmt, content); - self.pg_query.add_statement(&stmt, content); - }); - - self.documents.insert(params.path, doc); + self.parsed_documents + .entry(params.path.clone()) + .or_insert_with(|| { + ParsedDocument::new(params.path.clone(), params.content, params.version) + }); Ok(()) } /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - let (_, doc) = self - .documents + self.parsed_documents .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; - for stmt in doc.iter_statements() { - self.tree_sitter.remove_statement(&stmt); - self.pg_query.remove_statement(&stmt); - } - Ok(()) } @@ -214,53 +206,16 @@ impl Workspace for WorkspaceServer { version = params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut doc = self - .documents - .entry(params.path.clone()) - .or_insert(Document::new( - params.path.clone(), - "".to_string(), - params.version, - )); - - for c in &doc.apply_file_change(¶ms) { - match c { - StatementChange::Added(added) => { - tracing::debug!( - "Adding statement: id:{:?}, path:{:?}, text:{:?}", - added.stmt.id, - added.stmt.path.as_os_str().to_str(), - added.text - ); - self.tree_sitter.add_statement(&added.stmt, &added.text); - self.pg_query.add_statement(&added.stmt, &added.text); - } - StatementChange::Deleted(s) => { - tracing::debug!( - "Deleting statement: id:{:?}, path:{:?}", - s.id, - s.path.as_os_str() - ); - self.tree_sitter.remove_statement(s); - self.pg_query.remove_statement(s); - } - StatementChange::Modified(s) => { - tracing::debug!( - "Modifying statement with id {:?} (new id {:?}) in {:?}. Range {:?}, Changed from '{:?}' to '{:?}', changed text: {:?}", - s.old_stmt.id, - s.new_stmt.id, - s.old_stmt.path.as_os_str().to_str(), - s.change_range, - s.old_stmt_text, - s.new_stmt_text, - s.change_text - ); + let mut parser = + self.parsed_documents + .entry(params.path.clone()) + .or_insert(ParsedDocument::new( + params.path.clone(), + "".to_string(), + params.version, + )); - self.tree_sitter.modify_statement(s); - self.pg_query.modify_statement(s); - } - } - } + parser.apply_change(params); Ok(()) } @@ -271,10 +226,10 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .documents + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - Ok(document.content.clone()) + Ok(document.get_document_content().to_string()) } fn is_path_ignored(&self, params: IsPathIgnoredParams) -> Result { @@ -285,17 +240,11 @@ impl Workspace for WorkspaceServer { &self, params: code_actions::CodeActionsParams, ) -> Result { - let doc = self - .documents + let parser = self + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let eligible_statements = doc - .iter_statements_with_text_and_range() - .filter(|(_, range, _)| range.contains(params.cursor_position)); - - let mut actions: Vec = vec![]; - let settings = self .settings .read() @@ -307,20 +256,26 @@ impl Workspace for WorkspaceServer { Some("Statement execution not allowed against database.".into()) }; - for (stmt, _, txt) in eligible_statements { - let title = format!( - "Execute Statement: {}...", - txt.chars().take(50).collect::() - ); - - actions.push(CodeAction { - title, - kind: CodeActionKind::Command(CommandAction { - category: CommandActionCategory::ExecuteStatement(stmt.id), - }), - disabled_reason: disabled_reason.clone(), - }); - } + let actions = parser + .iter_with_filter( + DefaultMapper, + CursorPositionFilter::new(params.cursor_position), + ) + .map(|(stmt, _, txt)| { + let title = format!( + "Execute Statement: {}...", + txt.chars().take(50).collect::() + ); + + CodeAction { + title, + kind: CodeActionKind::Command(CommandAction { + category: CommandActionCategory::ExecuteStatement(stmt), + }), + disabled_reason: disabled_reason.clone(), + } + }) + .collect(); Ok(CodeActionsResult { actions }) } @@ -329,31 +284,25 @@ impl Workspace for WorkspaceServer { &self, params: ExecuteStatementParams, ) -> Result { - let doc = self - .documents + let parser = self + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - if self - .pg_query - .get_ast(&Statement { - path: params.path, - id: params.statement_id, - }) - .is_none() - { + let stmt = parser.find(params.statement_id, ExecuteStatementMapper); + + if stmt.is_none() { return Ok(ExecuteStatementResult { - message: "Statement is invalid.".into(), + message: "Statement was not found in document.".into(), }); }; - let sql: String = match doc.get_txt(params.statement_id) { - Some(txt) => txt, - None => { - return Ok(ExecuteStatementResult { - message: "Statement was not found in document.".into(), - }); - } + let (_id, _range, content, ast) = stmt.unwrap(); + + if ast.is_none() { + return Ok(ExecuteStatementResult { + message: "Statement is invalid.".into(), + }); }; let conn = self.connection.read().unwrap(); @@ -366,7 +315,7 @@ impl Workspace for WorkspaceServer { } }; - let result = run_async(async move { pool.execute(sqlx::query(&sql)).await })??; + let result = run_async(async move { pool.execute(sqlx::query(&content)).await })??; Ok(ExecuteStatementResult { message: format!( @@ -380,13 +329,6 @@ impl Workspace for WorkspaceServer { &self, params: PullDiagnosticsParams, ) -> Result { - // get all statements form the requested document and pull diagnostics out of every - // source - let doc = self - .documents - .get(¶ms.path) - .ok_or(WorkspaceError::not_found())?; - let settings = self.settings(); // create analyser for this run @@ -410,7 +352,14 @@ impl Workspace for WorkspaceServer { filter, }); - let mut diagnostics: Vec = doc.diagnostics().to_vec(); + let parser = self + .parsed_documents + .get(¶ms.path) + .ok_or(WorkspaceError::not_found())?; + + let mut diagnostics: Vec = parser.document_diagnostics().to_vec(); + + // TODO: run this in parallel with rayon based on rayon.count() if let Some(pool) = self .connection @@ -418,29 +367,20 @@ impl Workspace for WorkspaceServer { .expect("DbConnection RwLock panicked") .get_pool() { - let typecheck_params: Vec<_> = doc - .iter_statements_with_text_and_range() - .map(|(stmt, range, text)| { - let ast = self.pg_query.get_ast(&stmt); - let tree = self.tree_sitter.get_parse_tree(&stmt); - (text.to_string(), ast, tree, *range) - }) - .collect(); - - // run diagnostics for each statement in parallel if its mostly i/o work let path_clone = params.path.clone(); + let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); let async_results = run_async(async move { - stream::iter(typecheck_params) - .map(|(text, ast, tree, range)| { + stream::iter(input) + .map(|(_id, range, content, ast, cst)| { let pool = pool.clone(); let path = path_clone.clone(); async move { if let Some(ast) = ast { pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: &text, + sql: &content, ast: &ast, - tree: tree.as_deref(), + tree: &cst, }) .await .map(|d| { @@ -464,45 +404,49 @@ impl Workspace for WorkspaceServer { } } - diagnostics.extend(doc.iter_statements_with_range().flat_map(|(stmt, r)| { - let mut stmt_diagnostics = self.pg_query.get_diagnostics(&stmt); + diagnostics.extend(parser.iter(SyncDiagnosticsMapper).flat_map( + |(_id, range, ast, diag)| { + let mut errors: Vec = vec![]; - let ast = self.pg_query.get_ast(&stmt); + if let Some(diag) = diag { + errors.push(diag.into()); + } - if let Some(ast) = ast { - stmt_diagnostics.extend( - analyser - .run(AnalyserContext { root: &ast }) - .into_iter() - .map(SDiagnostic::new) - .collect::>(), - ); - } + if let Some(ast) = ast { + errors.extend( + analyser + .run(AnalyserContext { root: &ast }) + .into_iter() + .map(Error::from) + .collect::>(), + ); + } - stmt_diagnostics - .into_iter() - .map(|d| { - let severity = d - .category() - .filter(|category| category.name().starts_with("lint/")) - .map_or_else( - || d.severity(), - |category| { - settings - .as_ref() - .get_severity_from_rule_code(category) - .unwrap_or(Severity::Warning) - }, - ); - - SDiagnostic::new( - d.with_file_path(params.path.as_path().display().to_string()) - .with_file_span(r) - .with_severity(severity), - ) - }) - .collect::>() - })); + errors + .into_iter() + .map(|d| { + let severity = d + .category() + .filter(|category| category.name().starts_with("lint/")) + .map_or_else( + || d.severity(), + |category| { + settings + .as_ref() + .get_severity_from_rule_code(category) + .unwrap_or(Severity::Warning) + }, + ); + + SDiagnostic::new( + d.with_file_path(params.path.as_path().display().to_string()) + .with_file_span(range) + .with_severity(severity), + ) + }) + .collect::>() + }, + )); let errors = diagnostics .iter() @@ -525,46 +469,35 @@ impl Workspace for WorkspaceServer { &self, params: GetCompletionsParams, ) -> Result { - let pool = match self.connection.read().unwrap().get_pool() { - Some(pool) => pool, - None => return Ok(CompletionsResult::default()), - }; - - let doc = self - .documents + let parser = self + .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let (statement, stmt_range, text) = match doc - .iter_statements_with_text_and_range() - .find(|(_, r, _)| r.contains(params.position)) - { - Some(s) => s, + let pool = match self.connection.read().unwrap().get_pool() { + Some(pool) => pool, None => return Ok(CompletionsResult::default()), }; - // `offset` is the position in the document, - // but we need the position within the *statement*. - let position = params.position - stmt_range.start(); - - let tree = self.tree_sitter.get_parse_tree(&statement); - - tracing::debug!( - "Found the statement. We're looking for position {:?}. Statement Range {:?} to {:?}. Statement: {:?}", - position, - stmt_range.start(), - stmt_range.end(), - text - ); - let schema_cache = self.schema_cache.load(pool)?; - let items = pgt_completions::complete(pgt_completions::CompletionParams { - position, - schema: schema_cache.as_ref(), - tree: tree.as_deref(), - text: text.to_string(), - }); + let items = parser + .iter_with_filter( + GetCompletionsMapper, + CursorPositionFilter::new(params.position), + ) + .flat_map(|(_id, range, content, cst)| { + // `offset` is the position in the document, + // but we need the position within the *statement*. + let position = params.position - range.start(); + pgt_completions::complete(pgt_completions::CompletionParams { + position, + schema: schema_cache.as_ref(), + tree: &cst, + text: content, + }) + }) + .collect(); Ok(CompletionsResult { items }) } diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index 38769e67..afe0eb64 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -3,27 +3,27 @@ use std::ops::{Add, Sub}; use crate::workspace::{ChangeFileParams, ChangeParams}; -use super::{Document, Statement, document}; +use super::{Document, document, statement_identifier::StatementId}; #[derive(Debug, PartialEq, Eq)] pub enum StatementChange { Added(AddedStatement), - Deleted(Statement), + Deleted(StatementId), Modified(ModifiedStatement), } #[derive(Debug, PartialEq, Eq)] pub struct AddedStatement { - pub stmt: Statement, + pub stmt: StatementId, pub text: String, } #[derive(Debug, PartialEq, Eq)] pub struct ModifiedStatement { - pub old_stmt: Statement, + pub old_stmt: StatementId, pub old_stmt_text: String, - pub new_stmt: Statement, + pub new_stmt: StatementId, pub new_stmt_text: String, pub change_range: TextRange, @@ -32,7 +32,7 @@ pub struct ModifiedStatement { impl StatementChange { #[allow(dead_code)] - pub fn statement(&self) -> &Statement { + pub fn statement(&self) -> &StatementId { match self { StatementChange::Added(stmt) => &stmt.stmt, StatementChange::Deleted(stmt) => stmt, @@ -78,12 +78,7 @@ impl Document { fn drain_positions(&mut self) -> Vec { self.positions .drain(..) - .map(|(id, _)| { - StatementChange::Deleted(Statement { - id, - path: self.path.clone(), - }) - }) + .map(|(id, _)| StatementChange::Deleted(id)) .collect() } @@ -109,28 +104,22 @@ impl Document { changes.extend(ranges.into_iter().map(|range| { let id = self.id_generator.next(); let text = self.content[range].to_string(); - self.positions.push((id, range)); + self.positions.push((id.clone(), range)); - StatementChange::Added(AddedStatement { - stmt: Statement { - path: self.path.clone(), - id, - }, - text, - }) + StatementChange::Added(AddedStatement { stmt: id, text }) })); changes } - fn insert_statement(&mut self, range: TextRange) -> usize { + fn insert_statement(&mut self, range: TextRange) -> StatementId { let pos = self .positions .binary_search_by(|(_, r)| r.start().cmp(&range.start())) .unwrap_err(); let new_id = self.id_generator.next(); - self.positions.insert(pos, (new_id, range)); + self.positions.insert(pos, (new_id.clone(), range)); new_id } @@ -272,25 +261,19 @@ impl Document { if new_ranges.len() == 1 { let affected_idx = affected_indices[0]; let new_range = new_ranges[0].add(affected_range.start()); - let (old_id, old_range) = self.positions[affected_idx]; + let (old_id, old_range) = self.positions[affected_idx].clone(); // move all statements after the affected range self.move_ranges(old_range.end(), change.diff_size(), change.is_addition()); let new_id = self.id_generator.next(); - self.positions[affected_idx] = (new_id, new_range); + self.positions[affected_idx] = (new_id.clone(), new_range); changed.push(StatementChange::Modified(ModifiedStatement { - old_stmt: Statement { - id: old_id, - path: self.path.clone(), - }, + old_stmt: old_id.clone(), old_stmt_text: previous_content[old_range].to_string(), - new_stmt: Statement { - id: new_id, - path: self.path.clone(), - }, + new_stmt: new_id, new_stmt_text: changed_content[new_ranges[0]].to_string(), // change must be relative to the statement change_text: change.text.clone(), @@ -324,24 +307,19 @@ impl Document { // delete and add new ones if let Some(next_index) = next_index { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[next_index].0, - path: self.path.clone(), - })); + changed.push(StatementChange::Deleted( + self.positions[next_index].0.clone(), + )); self.positions.remove(next_index); } for idx in affected_indices.iter().rev() { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[*idx].0, - path: self.path.clone(), - })); + changed.push(StatementChange::Deleted(self.positions[*idx].0.clone())); self.positions.remove(*idx); } if let Some(prev_index) = prev_index { - changed.push(StatementChange::Deleted(Statement { - id: self.positions[prev_index].0, - path: self.path.clone(), - })); + changed.push(StatementChange::Deleted( + self.positions[prev_index].0.clone(), + )); self.positions.remove(prev_index); } @@ -349,10 +327,7 @@ impl Document { let actual_range = range.add(full_affected_range.start()); let new_id = self.insert_statement(actual_range); changed.push(StatementChange::Added(AddedStatement { - stmt: Statement { - id: new_id, - path: self.path.clone(), - }, + stmt: new_id, text: new_content[actual_range].to_string(), })); }); @@ -429,11 +404,12 @@ fn get_affected(content: &str, range: TextRange) -> &str { #[cfg(test)] mod tests { + use super::*; use pgt_diagnostics::Diagnostic; use pgt_text_size::TextRange; - use crate::workspace::{ChangeFileParams, ChangeParams}; + use crate::workspace::{ChangeFileParams, ChangeParams, server::statement_identifier::root_id}; use pgt_fs::PgTPath; @@ -466,7 +442,7 @@ mod tests { fn open_doc_with_scan_error() { let input = "select id from users;\n\n\n\nselect 1443ddwwd33djwdkjw13331333333333;"; - let d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -477,7 +453,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect 1;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); assert!(!d.has_fatal_error()); @@ -515,7 +491,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect 1;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); assert!(!d.has_fatal_error()); @@ -553,7 +529,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select 1d;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -587,7 +563,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select 1d;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 0); assert!(d.has_fatal_error()); @@ -620,7 +596,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\n\n\n\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); @@ -658,7 +634,7 @@ mod tests { fn within_statements_2() { let path = PgTPath::new("test.sql"); let input = "alter table deal alter column value drop not null;\n"; - let mut d = Document::new(path.clone(), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 1); @@ -735,7 +711,7 @@ mod tests { fn julians_sample() { let path = PgTPath::new("test.sql"); let input = "select\n *\nfrom\n test;\n\nselect\n\nalter table test\n\ndrop column id;"; - let mut d = Document::new(path.clone(), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 4); @@ -817,7 +793,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); @@ -833,14 +809,13 @@ mod tests { let changed = d.apply_file_change(&change); assert_eq!(changed.len(), 4); - assert!(matches!( - changed[0], - StatementChange::Deleted(Statement { id: 1, .. }) - )); + assert!(matches!(changed[0], StatementChange::Deleted(_))); + assert_eq!(changed[0].statement().raw(), 1); assert!(matches!( changed[1], - StatementChange::Deleted(Statement { id: 0, .. }) + StatementChange::Deleted(StatementId::Root(_)) )); + assert_eq!(changed[1].statement().raw(), 0); assert!( matches!(&changed[2], StatementChange::Added(AddedStatement { stmt: _, text }) if text == "select id,test from users;") ); @@ -856,7 +831,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 1); @@ -881,7 +856,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from users;\nselect * from contacts;"; - let mut d = Document::new(PgTPath::new("test.sql"), input.to_string(), 0); + let mut d = Document::new(input.to_string(), 0); assert_eq!(d.positions.len(), 2); @@ -898,37 +873,27 @@ mod tests { assert_eq!(changed.len(), 4); - assert_eq!( + assert!(matches!( changed[0], - StatementChange::Deleted(Statement { - path: path.clone(), - id: 1 - }) - ); - assert_eq!( + StatementChange::Deleted(StatementId::Root(_)) + )); + assert_eq!(changed[0].statement().raw(), 1); + assert!(matches!( changed[1], - StatementChange::Deleted(Statement { - path: path.clone(), - id: 0 - }) - ); + StatementChange::Deleted(StatementId::Root(_)) + )); + assert_eq!(changed[1].statement().raw(), 0); assert_eq!( changed[2], StatementChange::Added(AddedStatement { - stmt: Statement { - path: path.clone(), - id: 2 - }, + stmt: StatementId::Root(root_id(2)), text: "select id,test from users".to_string() }) ); assert_eq!( changed[3], StatementChange::Added(AddedStatement { - stmt: Statement { - path: path.clone(), - id: 3 - }, + stmt: StatementId::Root(root_id(3)), text: "select 1;".to_string() }) ); @@ -943,7 +908,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "\n"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 0); @@ -983,7 +948,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from\nselect * from contacts;"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 2); @@ -1014,7 +979,7 @@ mod tests { fn apply_changes_replacement() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "".to_string(), 0); + let mut doc = Document::new("".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1136,7 +1101,6 @@ mod tests { let path = PgTPath::new("test.sql"); let mut doc = Document::new( - path.clone(), "-- Add new schema named \"private\"\nCREATE SCHEMA \"private\";".to_string(), 0, ); @@ -1156,12 +1120,8 @@ mod tests { doc.content, "- Add new schema named \"private\"\nCREATE SCHEMA \"private\";" ); - assert_eq!(changed.len(), 3); - assert!(matches!( - changed[0], - StatementChange::Deleted(Statement { id: 0, .. }) - )); + assert!(matches!(&changed[0], StatementChange::Deleted(_))); assert!(matches!( changed[1], StatementChange::Added(AddedStatement { .. }) @@ -1190,11 +1150,11 @@ mod tests { assert_eq!(changed_2.len(), 3); assert!(matches!( changed_2[0], - StatementChange::Deleted(Statement { .. }) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( changed_2[1], - StatementChange::Deleted(Statement { .. }) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( changed_2[2], @@ -1209,12 +1169,12 @@ mod tests { let input = "select id from users;\nselect * from contacts;"; let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), input.to_string(), 0); + let mut doc = Document::new(input.to_string(), 0); assert_eq!(doc.positions.len(), 2); - let stmt_1_range = doc.positions[0]; - let stmt_2_range = doc.positions[1]; + let stmt_1_range = doc.positions[0].clone(); + let stmt_2_range = doc.positions[1].clone(); let update_text = ",test"; @@ -1261,7 +1221,7 @@ mod tests { let path = PgTPath::new("test.sql"); let input = "select id from contacts;\n\nselect * from contacts;"; - let mut d = Document::new(path.clone(), input.to_string(), 1); + let mut d = Document::new(input.to_string(), 1); assert_eq!(d.positions.len(), 2); @@ -1310,7 +1270,7 @@ mod tests { assert!(matches!( changes[0], - StatementChange::Deleted(Statement { .. }) + StatementChange::Deleted(StatementId::Root(_)) )); assert!(matches!( @@ -1332,7 +1292,7 @@ mod tests { fn remove_trailing_whitespace() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from ".to_string(), 0); + let mut doc = Document::new("select * from ".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1378,7 +1338,7 @@ mod tests { fn remove_trailing_whitespace_and_last_char() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from ".to_string(), 0); + let mut doc = Document::new("select * from ".to_string(), 0); let change = ChangeFileParams { path: path.clone(), @@ -1424,7 +1384,7 @@ mod tests { fn remove_inbetween_whitespace() { let path = PgTPath::new("test.sql"); - let mut doc = Document::new(path.clone(), "select * from users".to_string(), 0); + let mut doc = Document::new("select * from users".to_string(), 0); let change = ChangeFileParams { path: path.clone(), diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 9ef8c234..f2c500cc 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -1,22 +1,11 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; -use pgt_fs::PgTPath; use pgt_text_size::{TextRange, TextSize}; -/// Global unique identifier for a statement -#[derive(Debug, Hash, Eq, PartialEq, Clone)] -pub(crate) struct Statement { - /// Path of the document - pub(crate) path: PgTPath, - /// Unique id within the document - pub(crate) id: StatementId, -} - -pub(crate) type StatementId = usize; +use super::statement_identifier::{StatementId, StatementIdGenerator}; type StatementPos = (StatementId, TextRange); pub(crate) struct Document { - pub(crate) path: PgTPath, pub(crate) content: String, pub(crate) version: i32, @@ -24,17 +13,16 @@ pub(crate) struct Document { /// List of statements sorted by range.start() pub(super) positions: Vec, - pub(super) id_generator: IdGenerator, + pub(super) id_generator: StatementIdGenerator, } impl Document { - pub(crate) fn new(path: PgTPath, content: String, version: i32) -> Self { - let mut id_generator = IdGenerator::new(); + pub(crate) fn new(content: String, version: i32) -> Self { + let mut id_generator = StatementIdGenerator::new(); let (ranges, diagnostics) = split_with_diagnostics(&content, None); Self { - path, positions: ranges .into_iter() .map(|range| (id_generator.next(), range)) @@ -42,15 +30,10 @@ impl Document { content, version, diagnostics, - id_generator, } } - pub fn diagnostics(&self) -> &[SDiagnostic] { - &self.diagnostics - } - /// Returns true if there is at least one fatal error in the diagnostics /// /// A fatal error is a scan error that prevents the document from being used @@ -60,74 +43,8 @@ impl Document { .any(|d| d.severity() == Severity::Fatal) } - pub fn iter_statements(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, _)| Statement { - id: *id, - path: self.path.clone(), - }) - } - - pub fn iter_statements_with_text(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - }; - let text = &self.content[range.start().into()..range.end().into()]; - (statement, text) - }) - } - - pub fn iter_statements_with_range(&self) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - }; - (statement, range) - }) - } - - pub fn iter_statements_with_text_and_range( - &self, - ) -> impl Iterator + '_ { - self.positions.iter().map(move |(id, range)| { - let statement = Statement { - id: *id, - path: self.path.clone(), - }; - ( - statement, - range, - &self.content[range.start().into()..range.end().into()], - ) - }) - } - - pub fn get_txt(&self, stmt_id: StatementId) -> Option { - self.positions - .iter() - .find(|pos| pos.0 == stmt_id) - .map(|(_, range)| { - let stmt = &self.content[range.start().into()..range.end().into()]; - stmt.to_owned() - }) - } -} - -pub(crate) struct IdGenerator { - pub(super) next_id: usize, -} - -impl IdGenerator { - fn new() -> Self { - Self { next_id: 0 } - } - - pub(super) fn next(&mut self) -> usize { - let id = self.next_id; - self.next_id += 1; - id + pub fn iter<'a>(&'a self) -> StatementIterator<'a> { + StatementIterator::new(self) } } @@ -165,3 +82,30 @@ pub(crate) fn split_with_diagnostics( ), } } + +pub struct StatementIterator<'a> { + document: &'a Document, + positions: std::slice::Iter<'a, StatementPos>, +} + +impl<'a> StatementIterator<'a> { + pub fn new(document: &'a Document) -> Self { + Self { + document, + positions: document.positions.iter(), + } + } +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = (StatementId, TextRange, &'a str); + + fn next(&mut self) -> Option { + self.positions.next().map(|(id, range)| { + let range = *range; + let doc = self.document; + let id = id.clone(); + (id, range, &doc.content[range]) + }) + } +} diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs new file mode 100644 index 00000000..a110fb1f --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -0,0 +1,374 @@ +use std::sync::Arc; + +use pgt_diagnostics::serde::Diagnostic as SDiagnostic; +use pgt_fs::PgTPath; +use pgt_query_ext::diagnostics::SyntaxDiagnostic; +use pgt_text_size::{TextRange, TextSize}; + +use crate::workspace::ChangeFileParams; + +use super::{ + change::StatementChange, + document::{Document, StatementIterator}, + pg_query::PgQueryStore, + sql_function::SQLFunctionBodyStore, + statement_identifier::StatementId, + tree_sitter::TreeSitterStore, +}; + +pub struct ParsedDocument { + #[allow(dead_code)] + path: PgTPath, + + doc: Document, + ast_db: PgQueryStore, + cst_db: TreeSitterStore, + sql_fn_db: SQLFunctionBodyStore, +} + +impl ParsedDocument { + pub fn new(path: PgTPath, content: String, version: i32) -> ParsedDocument { + let doc = Document::new(content, version); + + let cst_db = TreeSitterStore::new(); + let ast_db = PgQueryStore::new(); + let sql_fn_db = SQLFunctionBodyStore::new(); + + doc.iter().for_each(|(stmt, _, content)| { + cst_db.add_statement(&stmt, content); + }); + + ParsedDocument { + path, + doc, + ast_db, + cst_db, + sql_fn_db, + } + } + + /// Applies a change to the document and updates the CST and AST databases accordingly. + /// + /// Note that only tree-sitter cares about statement modifications vs remove + add. + /// Hence, we just clear the AST for the old statements and lazily load them when requested. + /// + /// * `params`: ChangeFileParams - The parameters for the change to be applied. + pub fn apply_change(&mut self, params: ChangeFileParams) { + for c in &self.doc.apply_file_change(¶ms) { + match c { + StatementChange::Added(added) => { + tracing::debug!( + "Adding statement: id:{:?}, text:{:?}", + added.stmt, + added.text + ); + self.cst_db.add_statement(&added.stmt, &added.text); + } + StatementChange::Deleted(s) => { + tracing::debug!("Deleting statement: id {:?}", s,); + self.cst_db.remove_statement(s); + self.ast_db.clear_statement(s); + self.sql_fn_db.clear_statement(s); + } + StatementChange::Modified(s) => { + tracing::debug!( + "Modifying statement with id {:?} (new id {:?}). Range {:?}, Changed from '{:?}' to '{:?}', changed text: {:?}", + s.old_stmt, + s.new_stmt, + s.change_range, + s.old_stmt_text, + s.new_stmt_text, + s.change_text + ); + + self.cst_db.modify_statement(s); + self.ast_db.clear_statement(&s.old_stmt); + self.sql_fn_db.clear_statement(&s.old_stmt); + } + } + } + } + + pub fn get_document_content(&self) -> &str { + &self.doc.content + } + + pub fn document_diagnostics(&self) -> &Vec { + &self.doc.diagnostics + } + + pub fn find<'a, M>(&'a self, id: StatementId, mapper: M) -> Option + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, IdFilter::new(id)).next() + } + + pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, NoFilter> + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, NoFilter) + } + + pub fn iter_with_filter<'a, M, F>(&'a self, mapper: M, filter: F) -> ParseIterator<'a, M, F> + where + M: StatementMapper<'a>, + F: StatementFilter<'a>, + { + ParseIterator::new(self, mapper, filter) + } + + #[allow(dead_code)] + pub fn count(&self) -> usize { + self.iter(DefaultMapper).count() + } +} + +pub trait StatementMapper<'a> { + type Output; + + fn map( + &self, + parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output; +} + +pub trait StatementFilter<'a> { + fn predicate(&self, id: &StatementId, range: &TextRange) -> bool; +} + +pub struct ParseIterator<'a, M, F> { + parser: &'a ParsedDocument, + statements: StatementIterator<'a>, + mapper: M, + filter: F, + pending_sub_statements: Vec<(StatementId, TextRange, String)>, +} + +impl<'a, M, F> ParseIterator<'a, M, F> { + pub fn new(parser: &'a ParsedDocument, mapper: M, filter: F) -> Self { + Self { + parser, + statements: parser.doc.iter(), + mapper, + filter, + pending_sub_statements: Vec::new(), + } + } +} + +impl<'a, M, F> Iterator for ParseIterator<'a, M, F> +where + M: StatementMapper<'a>, + F: StatementFilter<'a>, +{ + type Item = M::Output; + + fn next(&mut self) -> Option { + // First check if we have any pending sub-statements to process + if let Some((id, range, content)) = self.pending_sub_statements.pop() { + if self.filter.predicate(&id, &range) { + return Some(self.mapper.map(self.parser, id, range, &content)); + } + // If the sub-statement doesn't pass the filter, continue to the next item + return self.next(); + } + + // Process the next top-level statement + let next_statement = self.statements.next(); + + if let Some((root_id, range, content)) = next_statement { + // If we should include sub-statements and this statement has an AST + let content_owned = content.to_string(); + if let Ok(ast) = self + .parser + .ast_db + .get_or_cache_ast(&root_id, &content_owned) + .as_ref() + { + // Check if this is a SQL function definition with a body + if let Some(sub_statement) = + self.parser + .sql_fn_db + .get_function_body(&root_id, ast, &content_owned) + { + // Add sub-statements to our pending queue + self.pending_sub_statements.push(( + root_id.create_child(), + // adjust range to document + sub_statement.range + range.start(), + sub_statement.body.clone(), + )); + } + } + + // Return the current statement if it passes the filter + if self.filter.predicate(&root_id, &range) { + return Some(self.mapper.map(self.parser, root_id, range, content)); + } + + // If the current statement doesn't pass the filter, try the next one + return self.next(); + } + + None + } +} + +pub struct DefaultMapper; +impl<'a> StatementMapper<'a> for DefaultMapper { + type Output = (StatementId, TextRange, String); + + fn map( + &self, + _parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output { + (id, range, content.to_string()) + } +} + +pub struct ExecuteStatementMapper; +impl<'a> StatementMapper<'a> for ExecuteStatementMapper { + type Output = ( + StatementId, + TextRange, + String, + Option, + ); + + fn map( + &self, + parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output { + let ast_result = parser.ast_db.get_or_cache_ast(&id, content); + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + (id, range, content.to_string(), ast_option) + } +} + +pub struct AsyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { + type Output = ( + StatementId, + TextRange, + String, + Option, + Arc, + ); + + fn map( + &self, + parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output { + let content_owned = content.to_string(); + let ast_result = parser.ast_db.get_or_cache_ast(&id, &content_owned); + + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned); + + (id, range, content_owned, ast_option, cst_result) + } +} + +pub struct SyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { + type Output = ( + StatementId, + TextRange, + Option, + Option, + ); + + fn map( + &self, + parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output { + let ast_result = parser.ast_db.get_or_cache_ast(&id, content); + + let (ast_option, diagnostics) = match &*ast_result { + Ok(node) => (Some(node.clone()), None), + Err(diag) => (None, Some(diag.clone())), + }; + + (id, range, ast_option, diagnostics) + } +} + +pub struct GetCompletionsMapper; +impl<'a> StatementMapper<'a> for GetCompletionsMapper { + type Output = (StatementId, TextRange, String, Arc); + + fn map( + &self, + parser: &'a ParsedDocument, + id: StatementId, + range: TextRange, + content: &str, + ) -> Self::Output { + let cst_result = parser.cst_db.get_or_cache_tree(&id, content); + (id, range, content.to_string(), cst_result) + } +} + +pub struct NoFilter; +impl<'a> StatementFilter<'a> for NoFilter { + fn predicate(&self, _id: &StatementId, _range: &TextRange) -> bool { + true + } +} + +pub struct CursorPositionFilter { + pos: TextSize, +} + +impl CursorPositionFilter { + pub fn new(pos: TextSize) -> Self { + Self { pos } + } +} + +impl<'a> StatementFilter<'a> for CursorPositionFilter { + fn predicate(&self, _id: &StatementId, range: &TextRange) -> bool { + range.contains(self.pos) + } +} + +pub struct IdFilter { + id: StatementId, +} + +impl IdFilter { + pub fn new(id: StatementId) -> Self { + Self { id } + } +} + +impl<'a> StatementFilter<'a> for IdFilter { + fn predicate(&self, id: &StatementId, _range: &TextRange) -> bool { + *id == self.id + } +} diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 3ed452fc..e5c0cac8 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -1,52 +1,38 @@ use std::sync::Arc; use dashmap::DashMap; -use pgt_diagnostics::serde::Diagnostic as SDiagnostic; use pgt_query_ext::diagnostics::*; -use super::{change::ModifiedStatement, document::Statement}; +use super::statement_identifier::StatementId; pub struct PgQueryStore { - ast_db: DashMap>, - diagnostics: DashMap, + db: DashMap>>, } impl PgQueryStore { pub fn new() -> PgQueryStore { - PgQueryStore { - ast_db: DashMap::new(), - diagnostics: DashMap::new(), - } - } - - pub fn get_ast(&self, statement: &Statement) -> Option> { - self.ast_db.get(statement).map(|x| x.clone()) + PgQueryStore { db: DashMap::new() } } - pub fn add_statement(&self, statement: &Statement, content: &str) { - let r = pgt_query_ext::parse(content); - if let Ok(ast) = r { - self.ast_db.insert(statement.clone(), Arc::new(ast)); - } else { - tracing::info!("invalid statement, adding diagnostics."); - self.diagnostics - .insert(statement.clone(), SyntaxDiagnostic::from(r.unwrap_err())); + pub fn get_or_cache_ast( + &self, + statement: &StatementId, + content: &str, + ) -> Arc> { + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; } - } - pub fn remove_statement(&self, statement: &Statement) { - self.ast_db.remove(statement); - self.diagnostics.remove(statement); + let r = Arc::new(pgt_query_ext::parse(content).map_err(SyntaxDiagnostic::from)); + self.db.insert(statement.clone(), r.clone()); + r } - pub fn modify_statement(&self, change: &ModifiedStatement) { - self.remove_statement(&change.old_stmt); - self.add_statement(&change.new_stmt, &change.new_stmt_text); - } + pub fn clear_statement(&self, id: &StatementId) { + self.db.remove(id); - pub fn get_diagnostics(&self, stmt: &Statement) -> Vec { - self.diagnostics - .get(stmt) - .map_or_else(Vec::new, |err| vec![SDiagnostic::new(err.value().clone())]) + if let Some(child_id) = id.get_child_id() { + self.db.remove(&child_id); + } } } diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs new file mode 100644 index 00000000..3273466d --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use pgt_text_size::TextRange; + +use super::statement_identifier::StatementId; + +pub struct SQLFunctionBody { + pub range: TextRange, + pub body: String, +} + +pub struct SQLFunctionBodyStore { + db: DashMap>>, +} + +impl SQLFunctionBodyStore { + pub fn new() -> SQLFunctionBodyStore { + SQLFunctionBodyStore { db: DashMap::new() } + } + + pub fn get_function_body( + &self, + statement: &StatementId, + ast: &pgt_query_ext::NodeEnum, + content: &str, + ) -> Option> { + // First check if we already have this statement cached + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; + } + + // If not cached, try to extract it from the AST + let fn_body = get_sql_fn(ast, content).map(Arc::new); + + // Cache the result and return it + self.db.insert(statement.clone(), fn_body.clone()); + fn_body + } + + pub fn clear_statement(&self, id: &StatementId) { + self.db.remove(id); + + if let Some(child_id) = id.get_child_id() { + self.db.remove(&child_id); + } + } +} + +/// Extracts SQL function body and its text range from a CreateFunctionStmt node. +/// Returns None if the function is not an SQL function or if the body can't be found. +fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { + let create_fn = match ast { + pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, + _ => return None, + }; + + // Extract language from function options + let language = find_option_value(create_fn, "language")?; + + // Only process SQL functions + if language != "sql" { + return None; + } + + // Extract SQL body from function options + let sql_body = find_option_value(create_fn, "as")?; + + // Find the range of the SQL body in the content + let start = content.find(&sql_body)?; + let end = start + sql_body.len(); + + let range = TextRange::new(start.try_into().unwrap(), end.try_into().unwrap()); + + Some(SQLFunctionBody { + range, + body: sql_body.clone(), + }) +} + +/// Helper function to find a specific option value from function options +fn find_option_value( + create_fn: &pgt_query_ext::protobuf::CreateFunctionStmt, + option_name: &str, +) -> Option { + create_fn + .options + .iter() + .filter_map(|opt_wrapper| opt_wrapper.node.as_ref()) + .find_map(|opt| { + if let pgt_query_ext::NodeEnum::DefElem(def_elem) = opt { + if def_elem.defname == option_name { + def_elem + .arg + .iter() + .filter_map(|arg_wrapper| arg_wrapper.node.as_ref()) + .find_map(|arg| { + if let pgt_query_ext::NodeEnum::String(s) = arg { + Some(s.sval.clone()) + } else { + None + } + }) + } else { + None + } + } else { + None + } + }) +} diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs new file mode 100644 index 00000000..0739fb2f --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +pub struct RootId { + inner: usize, +} + +#[cfg(test)] +pub fn root_id(inner: usize) -> RootId { + RootId { inner } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +/// `StatementId` can represent IDs for nested statements. +/// +/// For example, an SQL function really consist of two statements; the function creation +/// and the body: +/// +/// ```sql +/// create or replace function get_product_name(product_id INT) -- the root statement +/// returns varchar as $$ +/// select * from … -- the child statement +/// $$ LANGUAGE plpgsql; +/// ``` +/// +/// For now, we only support SQL functions – no complex, nested statements. +/// +/// An SQL function only ever has ONE child, that's why the inner `RootId` of a `Root` +/// is the same as the one of its `Child`. +pub enum StatementId { + Root(RootId), + // StatementId is the same as the root id since we can only have a single sql function body per Root + Child(RootId), +} + +impl Default for StatementId { + fn default() -> Self { + StatementId::Root(RootId { inner: 0 }) + } +} + +impl StatementId { + pub fn raw(&self) -> usize { + match self { + StatementId::Root(s) => s.inner, + StatementId::Child(s) => s.inner, + } + } +} + +/// Helper struct to generate unique statement ids +pub struct StatementIdGenerator { + next_id: usize, +} + +impl StatementIdGenerator { + pub fn new() -> Self { + Self { next_id: 0 } + } + + pub fn next(&mut self) -> StatementId { + let id = self.next_id; + self.next_id += 1; + StatementId::Root(RootId { inner: id }) + } +} + +impl StatementId { + /// Use this to get the matching `StatementId::Child` for + /// a `StatementId::Root`. + /// If the `StatementId` was already a `Child`, this will return `None`. + /// It is not guaranteed that the `Root` actually has a `Child` statement in the workspace. + pub fn get_child_id(&self) -> Option { + match self { + StatementId::Root(id) => Some(StatementId::Child(RootId { inner: id.inner })), + StatementId::Child(_) => None, + } + } + + /// Use this if you need to create a matching `StatementId::Child` for `Root`. + /// You cannot create a `Child` of a `Child`. + pub fn create_child(&self) -> StatementId { + match self { + StatementId::Root(id) => StatementId::Child(RootId { inner: id.inner }), + StatementId::Child(_) => panic!("Cannot create child from a child statement id"), + } + } +} diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index 09cff74c..a8932535 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -1,14 +1,13 @@ -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex}; use dashmap::DashMap; use tree_sitter::InputEdit; -use super::{change::ModifiedStatement, document::Statement}; +use super::{change::ModifiedStatement, statement_identifier::StatementId}; pub struct TreeSitterStore { - db: DashMap>, - - parser: RwLock, + db: DashMap>, + parser: Mutex, } impl TreeSitterStore { @@ -20,24 +19,38 @@ impl TreeSitterStore { TreeSitterStore { db: DashMap::new(), - parser: RwLock::new(parser), + parser: Mutex::new(parser), } } - pub fn get_parse_tree(&self, statement: &Statement) -> Option> { - self.db.get(statement).map(|x| x.clone()) + pub fn get_or_cache_tree( + &self, + statement: &StatementId, + content: &str, + ) -> Arc { + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; + } + + let mut parser = self.parser.lock().expect("Failed to lock parser"); + let tree = Arc::new(parser.parse(content, None).unwrap()); + self.db.insert(statement.clone(), tree.clone()); + + tree } - pub fn add_statement(&self, statement: &Statement, content: &str) { - let mut guard = self.parser.write().expect("Error reading parser"); - // todo handle error - let tree = guard.parse(content, None).unwrap(); - drop(guard); + pub fn add_statement(&self, statement: &StatementId, content: &str) { + let mut parser = self.parser.lock().expect("Failed to lock parser"); + let tree = parser.parse(content, None).unwrap(); self.db.insert(statement.clone(), Arc::new(tree)); } - pub fn remove_statement(&self, statement: &Statement) { - self.db.remove(statement); + pub fn remove_statement(&self, id: &StatementId) { + self.db.remove(id); + + if let Some(child_id) = id.get_child_id() { + self.db.remove(&child_id); + } } pub fn modify_statement(&self, change: &ModifiedStatement) { @@ -61,18 +74,17 @@ impl TreeSitterStore { tree.edit(&edit); - let mut guard = self.parser.write().expect("Error reading parser"); + let mut parser = self.parser.lock().expect("Failed to lock parser"); // todo handle error self.db.insert( change.new_stmt.clone(), - Arc::new(guard.parse(&change.new_stmt_text, Some(&tree)).unwrap()), + Arc::new(parser.parse(&change.new_stmt_text, Some(&tree)).unwrap()), ); - drop(guard); } } // Converts character positions and replacement text into a tree-sitter InputEdit -fn edit_from_change( +pub(crate) fn edit_from_change( text: &str, start_char: usize, end_char: usize,