Skip to content

Commit beeb661

Browse files
committed
fix: sql fn params
1 parent a5ba9cb commit beeb661

File tree

6 files changed

+172
-6
lines changed

6 files changed

+172
-6
lines changed

crates/pgt_typecheck/src/lib.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
mod diagnostics;
2+
mod typed_identifier;
23

3-
pub use diagnostics::TypecheckDiagnostic;
44
use diagnostics::create_type_error;
5+
pub use diagnostics::TypecheckDiagnostic;
56
use pgt_text_size::TextRange;
67
use sqlx::postgres::PgDatabaseError;
78
pub use sqlx::postgres::PgSeverity;
89
use sqlx::{Executor, PgPool};
10+
use typed_identifier::{apply_identifiers, TypedIdentifier};
911

1012
#[derive(Debug)]
1113
pub struct TypecheckParams<'a> {
1214
pub conn: &'a PgPool,
1315
pub sql: &'a str,
1416
pub ast: &'a pgt_query_ext::NodeEnum,
1517
pub tree: &'a tree_sitter::Tree,
18+
pub schema_cache: &'a pgt_schema_cache::SchemaCache,
19+
pub cst: &'a tree_sitter::Node<'a>,
20+
pub identifiers: Vec<TypedIdentifier>,
1621
}
1722

1823
#[derive(Debug, Clone)]
@@ -51,7 +56,14 @@ pub async fn check_sql(
5156
// each typecheck operation.
5257
conn.close_on_drop();
5358

54-
let res = conn.prepare(params.sql).await;
59+
let prepared = apply_identifiers(
60+
params.identifiers,
61+
params.schema_cache,
62+
params.cst,
63+
params.sql,
64+
);
65+
66+
let res = conn.prepare(prepared).await;
5567

5668
match res {
5769
Ok(_) => Ok(None),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#[derive(Debug)]
2+
pub struct TypedIdentifier {
3+
pub schema: Option<String>,
4+
pub relation: String,
5+
pub name: String,
6+
pub type_: String,
7+
}
8+
9+
/// Applies the identifiers to the SQL string by replacing them with their default values.
10+
pub fn apply_identifiers<'a>(
11+
identifiers: Vec<TypedIdentifier>,
12+
schema_cache: &'a pgt_schema_cache::SchemaCache,
13+
cst: &'a tree_sitter::Node<'a>,
14+
sql: &'a str,
15+
) -> &'a str {
16+
// TODO
17+
sql
18+
}

crates/pgt_workspace/src/workspace/server/document.rs

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ impl Document {
3434
}
3535
}
3636

37+
pub fn statement_content(&self, id: &StatementId) -> Option<&str> {
38+
self.positions
39+
.iter()
40+
.find(|(statement_id, _)| statement_id == id)
41+
.map(|(_, range)| &self.content[*range])
42+
}
43+
3744
/// Returns true if there is at least one fatal error in the diagnostics
3845
///
3946
/// A fatal error is a scan error that prevents the document from being used

crates/pgt_workspace/src/workspace/server/parsed_document.rs

+23-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use super::{
1212
change::StatementChange,
1313
document::{Document, StatementIterator},
1414
pg_query::PgQueryStore,
15-
sql_function::SQLFunctionBodyStore,
15+
sql_function::{SQLFunctionBodyStore, SQLFunctionSignature},
1616
statement_identifier::StatementId,
1717
tree_sitter::TreeSitterStore,
1818
};
@@ -274,6 +274,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper {
274274
String,
275275
Option<pgt_query_ext::NodeEnum>,
276276
Arc<tree_sitter::Tree>,
277+
Option<Arc<SQLFunctionSignature>>,
277278
);
278279

279280
fn map(
@@ -293,7 +294,26 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper {
293294

294295
let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned);
295296

296-
(id, range, content_owned, ast_option, cst_result)
297+
let sql_fn_sig = id
298+
.parent()
299+
.and_then(|root| {
300+
let c = parser.doc.statement_content(&root)?;
301+
Some((root, c))
302+
})
303+
.and_then(|(root, c)| {
304+
let ast_option = parser
305+
.ast_db
306+
.get_or_cache_ast(&root, c)
307+
.as_ref()
308+
.clone()
309+
.ok();
310+
311+
let ast_option = ast_option.as_ref()?;
312+
313+
parser.sql_fn_db.get_function_signature(&root, ast_option)
314+
});
315+
316+
(id, range, content_owned, ast_option, cst_result, sql_fn_sig)
297317
}
298318
}
299319

@@ -413,7 +433,7 @@ mod tests {
413433

414434
#[test]
415435
fn sql_function_body() {
416-
let input = "CREATE FUNCTION add(integer, integer) RETURNS integer
436+
let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer
417437
AS 'select $1 + $2;'
418438
LANGUAGE SQL
419439
IMMUTABLE

crates/pgt_workspace/src/workspace/server/sql_function.rs

+95-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,18 @@ use pgt_text_size::TextRange;
55

66
use super::statement_identifier::StatementId;
77

8+
#[derive(Debug, Clone)]
9+
pub struct SQLFunctionArgs {
10+
pub name: Option<String>,
11+
pub type_: (Option<String>, String),
12+
}
13+
14+
#[derive(Debug, Clone)]
15+
pub struct SQLFunctionSignature {
16+
pub name: (Option<String>, String),
17+
pub args: Vec<SQLFunctionArgs>,
18+
}
19+
820
#[derive(Debug, Clone)]
921
pub struct SQLFunctionBody {
1022
pub range: TextRange,
@@ -13,11 +25,33 @@ pub struct SQLFunctionBody {
1325

1426
pub struct SQLFunctionBodyStore {
1527
db: DashMap<StatementId, Option<Arc<SQLFunctionBody>>>,
28+
sig_db: DashMap<StatementId, Option<Arc<SQLFunctionSignature>>>,
1629
}
1730

1831
impl SQLFunctionBodyStore {
1932
pub fn new() -> SQLFunctionBodyStore {
20-
SQLFunctionBodyStore { db: DashMap::new() }
33+
SQLFunctionBodyStore {
34+
db: DashMap::new(),
35+
sig_db: DashMap::new(),
36+
}
37+
}
38+
39+
pub fn get_function_signature(
40+
&self,
41+
statement: &StatementId,
42+
ast: &pgt_query_ext::NodeEnum,
43+
) -> Option<Arc<SQLFunctionSignature>> {
44+
// First check if we already have this statement cached
45+
if let Some(existing) = self.sig_db.get(statement).map(|x| x.clone()) {
46+
return existing;
47+
}
48+
49+
// If not cached, try to extract it from the AST
50+
let fn_sig = get_sql_fn_signature(ast).map(Arc::new);
51+
52+
// Cache the result and return it
53+
self.sig_db.insert(statement.clone(), fn_sig.clone());
54+
fn_sig
2155
}
2256

2357
pub fn get_function_body(
@@ -48,6 +82,48 @@ impl SQLFunctionBodyStore {
4882
}
4983
}
5084

85+
/// Extracts SQL function signature from a CreateFunctionStmt node.
86+
fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option<SQLFunctionSignature> {
87+
let create_fn = match ast {
88+
pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf,
89+
_ => return None,
90+
};
91+
92+
println!("create_fn: {:?}", create_fn);
93+
94+
// Extract language from function options
95+
let language = find_option_value(create_fn, "language")?;
96+
97+
// Only process SQL functions
98+
if language != "sql" {
99+
return None;
100+
}
101+
102+
let fn_name = parse_name(&create_fn.funcname)?;
103+
104+
// we return None if anything is not expected
105+
let mut fn_args = Vec::new();
106+
for arg in &create_fn.parameters {
107+
if let Some(pgt_query_ext::NodeEnum::FunctionParameter(node)) = &arg.node {
108+
let arg_name = (!node.name.is_empty()).then_some(node.name.clone());
109+
110+
let type_name = parse_name(&node.arg_type.as_ref().unwrap().names)?;
111+
112+
fn_args.push(SQLFunctionArgs {
113+
name: arg_name,
114+
type_: type_name,
115+
});
116+
} else {
117+
return None;
118+
}
119+
}
120+
121+
Some(SQLFunctionSignature {
122+
name: fn_name,
123+
args: fn_args,
124+
})
125+
}
126+
51127
/// Extracts SQL function body and its text range from a CreateFunctionStmt node.
52128
/// Returns None if the function is not an SQL function or if the body can't be found.
53129
fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option<SQLFunctionBody> {
@@ -56,6 +132,8 @@ fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option<SQLFunctio
56132
_ => return None,
57133
};
58134

135+
println!("create_fn: {:?}", create_fn);
136+
59137
// Extract language from function options
60138
let language = find_option_value(create_fn, "language")?;
61139

@@ -120,3 +198,19 @@ fn find_option_value(
120198
}
121199
})
122200
}
201+
202+
fn parse_name(nodes: &Vec<pgt_query_ext::protobuf::Node>) -> Option<(Option<String>, String)> {
203+
let names = nodes
204+
.iter()
205+
.map(|n| match &n.node {
206+
Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()),
207+
_ => None,
208+
})
209+
.collect::<Vec<_>>();
210+
211+
match names.as_slice() {
212+
[Some(schema), Some(name)] => Some((Some(schema.clone()), name.clone())),
213+
[Some(name)] => Some((None, name.clone())),
214+
_ => None,
215+
}
216+
}

crates/pgt_workspace/src/workspace/server/statement_identifier.rs

+15
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ impl StatementId {
5757
StatementId::Child(s) => s.inner,
5858
}
5959
}
60+
61+
pub fn is_root(&self) -> bool {
62+
matches!(self, StatementId::Root(_))
63+
}
64+
65+
pub fn is_child(&self) -> bool {
66+
matches!(self, StatementId::Child(_))
67+
}
68+
69+
pub fn parent(&self) -> Option<StatementId> {
70+
match self {
71+
StatementId::Root(_) => None,
72+
StatementId::Child(id) => Some(StatementId::Root(id.clone())),
73+
}
74+
}
6075
}
6176

6277
/// Helper struct to generate unique statement ids

0 commit comments

Comments
 (0)