1
1
import re
2
2
from enum import Enum
3
- from typing import Generator , List , Tuple
3
+ from typing import Iterator , List , Tuple
4
+
5
+ SELECT_KEYWORD = "SELECT"
6
+ CASE_KEYWORD = "CASE"
7
+ END_KEYWORD = "END"
4
8
5
9
CONTROL_FLOW_KEYWORDS = [
6
10
"GO" ,
9
13
"BEGIN" ,
10
14
r"END\w+TRY" ,
11
15
r"END\w+CATCH" ,
12
- "END" ,
16
+ # This isn't strictly correct, but we assume that IF | (condition) | (block) should all be split up
17
+ # This mainly ensures that IF statements don't get tacked onto the previous statement incorrectly
18
+ "IF" ,
19
+ # For things like CASE, END does not mean the end of a statement.
20
+ # We have special handling for this.
21
+ END_KEYWORD ,
22
+ # "ELSE", # else is also valid in CASE, so we we can't use it here.
13
23
]
14
24
15
25
# There's an exception to this rule, which is when the statement
16
- # is preceeded by a CTE.
17
- FORCE_NEW_STATEMENT_KEYWORDS = [
26
+ # is preceded by a CTE. For those, we have to check if the character
27
+ # before this is a ")".
28
+ NEW_STATEMENT_KEYWORDS = [
18
29
# SELECT is used inside queries as well, so we can't include it here.
30
+ "CREATE" ,
19
31
"INSERT" ,
20
32
"UPDATE" ,
21
33
"DELETE" ,
22
34
"MERGE" ,
23
35
]
36
+ STRICT_NEW_STATEMENT_KEYWORDS = [
37
+ # For these keywords, a SELECT following it does indicate a new statement.
38
+ "DROP" ,
39
+ "TRUNCATE" ,
40
+ ]
41
+
42
+
43
+ class _AlreadyIncremented (Exception ):
44
+ # Using exceptions for control flow isn't great - but the code is clearer so it's fine.
45
+ pass
24
46
25
47
26
48
class ParserState (Enum ):
@@ -30,134 +52,199 @@ class ParserState(Enum):
30
52
MULTILINE_COMMENT = 4
31
53
32
54
33
- def _is_keyword_at_position (sql : str , pos : int , keyword : str ) -> bool :
34
- """
35
- Check if a keyword exists at the given position using regex word boundaries.
36
- """
37
- if pos + len (keyword ) > len (sql ):
38
- return False
55
+ class _StatementSplitter :
56
+ def __init__ (self , sql : str ):
57
+ self .sql = sql
39
58
40
- # If we're not at a word boundary, we can't generate a keyword.
41
- if pos > 0 and not (
42
- bool (re .match (r"\w\W" , sql [pos - 1 : pos + 1 ]))
43
- or bool (re .match (r"\W\w" , sql [pos - 1 : pos + 1 ]))
44
- ):
45
- return False
59
+ # Main parser state.
60
+ self .i = 0
61
+ self .state = ParserState .NORMAL
62
+ self .current_statement : List [str ] = []
46
63
47
- pattern = rf"^{ re .escape (keyword )} \b"
48
- match = re .match (pattern , sql [pos :], re .IGNORECASE )
49
- return bool (match )
64
+ # Additional parser state.
50
65
66
+ # If we see a SELECT, should we start a new statement?
67
+ # If we previously saw a drop/truncate/etc, a SELECT does mean a new statement.
68
+ # But if we're in a select/create/etc, a select could just be a subquery.
69
+ self .does_select_mean_new_statement = False
51
70
52
- def _look_ahead_for_keywords (
53
- sql : str , pos : int , keywords : List [str ]
54
- ) -> Tuple [bool , str , int ]:
55
- """
56
- Look ahead for SQL keywords at the current position.
57
- """
71
+ # The END keyword terminates CASE and BEGIN blocks.
72
+ # We need to match the CASE statements with END blocks to determine
73
+ # what a given END is closing.
74
+ self .current_case_statements = 0
58
75
59
- for keyword in keywords :
60
- if _is_keyword_at_position (sql , pos , keyword ):
61
- return True , keyword , len (keyword )
62
- return False , "" , 0
76
+ def _is_keyword_at_position (self , pos : int , keyword : str ) -> bool :
77
+ """
78
+ Check if a keyword exists at the given position using regex word boundaries.
79
+ """
80
+ sql = self .sql
63
81
82
+ if pos + len (keyword ) > len (sql ):
83
+ return False
64
84
65
- def split_statements (sql : str ) -> Generator [str , None , None ]:
66
- """
67
- Split T-SQL code into individual statements, handling various SQL constructs.
68
- """
69
- if not sql or not sql .strip ():
70
- return
85
+ # If we're not at a word boundary, we can't generate a keyword.
86
+ if pos > 0 and not (
87
+ bool (re .match (r"\w\W" , sql [pos - 1 : pos + 1 ]))
88
+ or bool (re .match (r"\W\w" , sql [pos - 1 : pos + 1 ]))
89
+ ):
90
+ return False
91
+
92
+ pattern = rf"^{ re .escape (keyword )} \b"
93
+ match = re .match (pattern , sql [pos :], re .IGNORECASE )
94
+ return bool (match )
71
95
72
- current_statement : List [str ] = []
73
- state = ParserState .NORMAL
74
- i = 0
96
+ def _look_ahead_for_keywords (self , keywords : List [str ]) -> Tuple [bool , str , int ]:
97
+ """
98
+ Look ahead for SQL keywords at the current position.
99
+ """
75
100
76
- def yield_if_complete () -> Generator [str , None , None ]:
77
- statement = "" .join (current_statement ).strip ()
101
+ for keyword in keywords :
102
+ if self ._is_keyword_at_position (self .i , keyword ):
103
+ return True , keyword , len (keyword )
104
+ return False , "" , 0
105
+
106
+ def _yield_if_complete (self ) -> Iterator [str ]:
107
+ statement = "" .join (self .current_statement ).strip ()
78
108
if statement :
109
+ # Subtle - to avoid losing full whitespace, they get merged into the next statement.
79
110
yield statement
80
- current_statement .clear ()
81
-
82
- prev_real_char = "\0 " # the most recent non-whitespace, non-comment character
83
- while i < len (sql ):
84
- c = sql [i ]
85
- next_char = sql [i + 1 ] if i < len (sql ) - 1 else "\0 "
86
-
87
- if state == ParserState .NORMAL :
88
- if c == "'" :
89
- state = ParserState .STRING
90
- current_statement .append (c )
91
- prev_real_char = c
92
- elif c == "-" and next_char == "-" :
93
- state = ParserState .COMMENT
94
- current_statement .append (c )
95
- current_statement .append (next_char )
96
- i += 1
97
- elif c == "/" and next_char == "*" :
98
- state = ParserState .MULTILINE_COMMENT
99
- current_statement .append (c )
100
- current_statement .append (next_char )
101
- i += 1
102
- else :
103
- most_recent_real_char = prev_real_char
104
- if not c .isspace ():
111
+ self .current_statement .clear ()
112
+
113
+ # Reset current_statement-specific state.
114
+ self .does_select_mean_new_statement = False
115
+ if self .current_case_statements != 0 :
116
+ breakpoint ()
117
+ self .current_case_statements = 0
118
+
119
+ def process (self ) -> Iterator [str ]:
120
+ if not self .sql or not self .sql .strip ():
121
+ return
122
+
123
+ prev_real_char = "\0 " # the most recent non-whitespace, non-comment character
124
+ while self .i < len (self .sql ):
125
+ c = self .sql [self .i ]
126
+ next_char = self .sql [self .i + 1 ] if self .i < len (self .sql ) - 1 else "\0 "
127
+
128
+ if self .state == ParserState .NORMAL :
129
+ if c == "'" :
130
+ self .state = ParserState .STRING
131
+ self .current_statement .append (c )
105
132
prev_real_char = c
106
-
107
- is_control_keyword , keyword , keyword_len = _look_ahead_for_keywords (
108
- sql , i , keywords = CONTROL_FLOW_KEYWORDS
109
- )
110
- if is_control_keyword :
111
- # Yield current statement if any
112
- yield from yield_if_complete ()
113
- # Yield keyword as its own statement
114
- yield keyword
115
- i += keyword_len
116
- continue
117
-
118
- (
119
- is_force_new_statement_keyword ,
120
- keyword ,
121
- keyword_len ,
122
- ) = _look_ahead_for_keywords (
123
- sql , i , keywords = FORCE_NEW_STATEMENT_KEYWORDS
124
- )
125
- if (
126
- is_force_new_statement_keyword and most_recent_real_char != ")"
127
- ): # usually we'd have a close paren that closes a CTE
128
- # Force termination of current statement
129
- yield from yield_if_complete ()
130
-
131
- current_statement .append (keyword )
132
- i += keyword_len
133
- continue
134
-
135
- elif c == ";" :
136
- yield from yield_if_complete ()
133
+ elif c == "-" and next_char == "-" :
134
+ self .state = ParserState .COMMENT
135
+ self .current_statement .append (c )
136
+ self .current_statement .append (next_char )
137
+ self .i += 1
138
+ elif c == "/" and next_char == "*" :
139
+ self .state = ParserState .MULTILINE_COMMENT
140
+ self .current_statement .append (c )
141
+ self .current_statement .append (next_char )
142
+ self .i += 1
137
143
else :
138
- current_statement .append (c )
139
-
140
- elif state == ParserState .STRING :
141
- current_statement .append (c )
142
- if c == "'" and next_char == "'" :
143
- current_statement .append (next_char )
144
- i += 1
145
- elif c == "'" :
146
- state = ParserState .NORMAL
147
-
148
- elif state == ParserState .COMMENT :
149
- current_statement .append (c )
150
- if c == "\n " :
151
- state = ParserState .NORMAL
152
-
153
- elif state == ParserState .MULTILINE_COMMENT :
154
- current_statement .append (c )
155
- if c == "*" and next_char == "/" :
156
- current_statement .append (next_char )
157
- i += 1
158
- state = ParserState .NORMAL
159
-
160
- i += 1
161
-
162
- # Handle the last statement
163
- yield from yield_if_complete ()
144
+ most_recent_real_char = prev_real_char
145
+ if not c .isspace ():
146
+ prev_real_char = c
147
+
148
+ try :
149
+ yield from self ._process_normal (
150
+ most_recent_real_char = most_recent_real_char
151
+ )
152
+ except _AlreadyIncremented :
153
+ # Skip the normal i += 1 step.
154
+ continue
155
+
156
+ elif self .state == ParserState .STRING :
157
+ self .current_statement .append (c )
158
+ if c == "'" and next_char == "'" :
159
+ self .current_statement .append (next_char )
160
+ self .i += 1
161
+ elif c == "'" :
162
+ self .state = ParserState .NORMAL
163
+
164
+ elif self .state == ParserState .COMMENT :
165
+ self .current_statement .append (c )
166
+ if c == "\n " :
167
+ self .state = ParserState .NORMAL
168
+
169
+ elif self .state == ParserState .MULTILINE_COMMENT :
170
+ self .current_statement .append (c )
171
+ if c == "*" and next_char == "/" :
172
+ self .current_statement .append (next_char )
173
+ self .i += 1
174
+ self .state = ParserState .NORMAL
175
+
176
+ self .i += 1
177
+
178
+ # Handle the last statement
179
+ yield from self ._yield_if_complete ()
180
+
181
+ def _process_normal (self , most_recent_real_char : str ) -> Iterator [str ]:
182
+ c = self .sql [self .i ]
183
+
184
+ if self ._is_keyword_at_position (self .i , CASE_KEYWORD ):
185
+ self .current_case_statements += 1
186
+
187
+ is_control_keyword , keyword , keyword_len = self ._look_ahead_for_keywords (
188
+ keywords = CONTROL_FLOW_KEYWORDS
189
+ )
190
+ if (
191
+ is_control_keyword
192
+ and keyword == END_KEYWORD
193
+ and self .current_case_statements > 0
194
+ ):
195
+ # If we're closing a CASE statement with END, we can just decrement the counter and continue.
196
+ self .current_case_statements -= 1
197
+ elif is_control_keyword :
198
+ # Yield current statement if any
199
+ yield from self ._yield_if_complete ()
200
+ # Yield keyword as its own statement
201
+ yield keyword
202
+ self .i += keyword_len
203
+ self .does_select_mean_new_statement = True
204
+ raise _AlreadyIncremented ()
205
+
206
+ (
207
+ is_strict_new_statement_keyword ,
208
+ keyword ,
209
+ keyword_len ,
210
+ ) = self ._look_ahead_for_keywords (keywords = STRICT_NEW_STATEMENT_KEYWORDS )
211
+ if is_strict_new_statement_keyword :
212
+ yield from self ._yield_if_complete ()
213
+ self .current_statement .append (keyword )
214
+ self .i += keyword_len
215
+ self .does_select_mean_new_statement = True
216
+ raise _AlreadyIncremented ()
217
+
218
+ (
219
+ is_force_new_statement_keyword ,
220
+ keyword ,
221
+ keyword_len ,
222
+ ) = self ._look_ahead_for_keywords (
223
+ keywords = (
224
+ NEW_STATEMENT_KEYWORDS
225
+ + ([SELECT_KEYWORD ] if self .does_select_mean_new_statement else [])
226
+ ),
227
+ )
228
+ if (
229
+ is_force_new_statement_keyword and most_recent_real_char != ")"
230
+ ): # usually we'd have a close paren that closes a CTE
231
+ # Force termination of current statement
232
+ yield from self ._yield_if_complete ()
233
+
234
+ self .current_statement .append (keyword )
235
+ self .i += keyword_len
236
+ raise _AlreadyIncremented ()
237
+
238
+ if c == ";" :
239
+ yield from self ._yield_if_complete ()
240
+ else :
241
+ self .current_statement .append (c )
242
+
243
+
244
+ def split_statements (sql : str ) -> Iterator [str ]:
245
+ """
246
+ Split T-SQL code into individual statements, handling various SQL constructs.
247
+ """
248
+
249
+ splitter = _StatementSplitter (sql )
250
+ yield from splitter .process ()
0 commit comments