@@ -5,6 +5,18 @@ use pgt_text_size::TextRange;
5
5
6
6
use super :: statement_identifier:: StatementId ;
7
7
8
+ #[ derive( Debug , Clone ) ]
9
+ pub struct SQLFunctionArgs {
10
+ pub name : Option < String > ,
11
+ pub type_ : ( Option < String > , String ) ,
12
+ }
13
+
14
+ #[ derive( Debug , Clone ) ]
15
+ pub struct SQLFunctionSignature {
16
+ pub name : ( Option < String > , String ) ,
17
+ pub args : Vec < SQLFunctionArgs > ,
18
+ }
19
+
8
20
#[ derive( Debug , Clone ) ]
9
21
pub struct SQLFunctionBody {
10
22
pub range : TextRange ,
@@ -13,11 +25,33 @@ pub struct SQLFunctionBody {
13
25
14
26
pub struct SQLFunctionBodyStore {
15
27
db : DashMap < StatementId , Option < Arc < SQLFunctionBody > > > ,
28
+ sig_db : DashMap < StatementId , Option < Arc < SQLFunctionSignature > > > ,
16
29
}
17
30
18
31
impl SQLFunctionBodyStore {
19
32
pub fn new ( ) -> SQLFunctionBodyStore {
20
- SQLFunctionBodyStore { db : DashMap :: new ( ) }
33
+ SQLFunctionBodyStore {
34
+ db : DashMap :: new ( ) ,
35
+ sig_db : DashMap :: new ( ) ,
36
+ }
37
+ }
38
+
39
+ pub fn get_function_signature (
40
+ & self ,
41
+ statement : & StatementId ,
42
+ ast : & pgt_query_ext:: NodeEnum ,
43
+ ) -> Option < Arc < SQLFunctionSignature > > {
44
+ // First check if we already have this statement cached
45
+ if let Some ( existing) = self . sig_db . get ( statement) . map ( |x| x. clone ( ) ) {
46
+ return existing;
47
+ }
48
+
49
+ // If not cached, try to extract it from the AST
50
+ let fn_sig = get_sql_fn_signature ( ast) . map ( Arc :: new) ;
51
+
52
+ // Cache the result and return it
53
+ self . sig_db . insert ( statement. clone ( ) , fn_sig. clone ( ) ) ;
54
+ fn_sig
21
55
}
22
56
23
57
pub fn get_function_body (
@@ -48,6 +82,48 @@ impl SQLFunctionBodyStore {
48
82
}
49
83
}
50
84
85
+ /// Extracts SQL function signature from a CreateFunctionStmt node.
86
+ fn get_sql_fn_signature ( ast : & pgt_query_ext:: NodeEnum ) -> Option < SQLFunctionSignature > {
87
+ let create_fn = match ast {
88
+ pgt_query_ext:: NodeEnum :: CreateFunctionStmt ( cf) => cf,
89
+ _ => return None ,
90
+ } ;
91
+
92
+ println ! ( "create_fn: {:?}" , create_fn) ;
93
+
94
+ // Extract language from function options
95
+ let language = find_option_value ( create_fn, "language" ) ?;
96
+
97
+ // Only process SQL functions
98
+ if language != "sql" {
99
+ return None ;
100
+ }
101
+
102
+ let fn_name = parse_name ( & create_fn. funcname ) ?;
103
+
104
+ // we return None if anything is not expected
105
+ let mut fn_args = Vec :: new ( ) ;
106
+ for arg in & create_fn. parameters {
107
+ if let Some ( pgt_query_ext:: NodeEnum :: FunctionParameter ( node) ) = & arg. node {
108
+ let arg_name = ( !node. name . is_empty ( ) ) . then_some ( node. name . clone ( ) ) ;
109
+
110
+ let type_name = parse_name ( & node. arg_type . as_ref ( ) . unwrap ( ) . names ) ?;
111
+
112
+ fn_args. push ( SQLFunctionArgs {
113
+ name : arg_name,
114
+ type_ : type_name,
115
+ } ) ;
116
+ } else {
117
+ return None ;
118
+ }
119
+ }
120
+
121
+ Some ( SQLFunctionSignature {
122
+ name : fn_name,
123
+ args : fn_args,
124
+ } )
125
+ }
126
+
51
127
/// Extracts SQL function body and its text range from a CreateFunctionStmt node.
52
128
/// Returns None if the function is not an SQL function or if the body can't be found.
53
129
fn get_sql_fn ( ast : & pgt_query_ext:: NodeEnum , content : & str ) -> Option < SQLFunctionBody > {
@@ -56,6 +132,8 @@ fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option<SQLFunctio
56
132
_ => return None ,
57
133
} ;
58
134
135
+ println ! ( "create_fn: {:?}" , create_fn) ;
136
+
59
137
// Extract language from function options
60
138
let language = find_option_value ( create_fn, "language" ) ?;
61
139
@@ -120,3 +198,19 @@ fn find_option_value(
120
198
}
121
199
} )
122
200
}
201
+
202
+ fn parse_name ( nodes : & Vec < pgt_query_ext:: protobuf:: Node > ) -> Option < ( Option < String > , String ) > {
203
+ let names = nodes
204
+ . iter ( )
205
+ . map ( |n| match & n. node {
206
+ Some ( pgt_query_ext:: NodeEnum :: String ( s) ) => Some ( s. sval . clone ( ) ) ,
207
+ _ => None ,
208
+ } )
209
+ . collect :: < Vec < _ > > ( ) ;
210
+
211
+ match names. as_slice ( ) {
212
+ [ Some ( schema) , Some ( name) ] => Some ( ( Some ( schema. clone ( ) ) , name. clone ( ) ) ) ,
213
+ [ Some ( name) ] => Some ( ( None , name. clone ( ) ) ) ,
214
+ _ => None ,
215
+ }
216
+ }
0 commit comments