Skip to content

Commit 32c62e5

Browse files
authored
feat(ingest/mssql): improve stored procedure splitting (datahub-project#12563)
1 parent 03bce47 commit 32c62e5

File tree

2 files changed

+277
-122
lines changed

2 files changed

+277
-122
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import re
22
from enum import Enum
3-
from typing import Generator, List, Tuple
3+
from typing import Iterator, List, Tuple
4+
5+
SELECT_KEYWORD = "SELECT"
6+
CASE_KEYWORD = "CASE"
7+
END_KEYWORD = "END"
48

59
CONTROL_FLOW_KEYWORDS = [
610
"GO",
@@ -9,18 +13,36 @@
913
"BEGIN",
1014
r"END\w+TRY",
1115
r"END\w+CATCH",
12-
"END",
16+
# This isn't strictly correct, but we assume that IF | (condition) | (block) should all be split up
17+
# This mainly ensures that IF statements don't get tacked onto the previous statement incorrectly
18+
"IF",
19+
# For things like CASE, END does not mean the end of a statement.
20+
# We have special handling for this.
21+
END_KEYWORD,
22+
# "ELSE",  # ELSE is also valid inside CASE, so we can't use it here.
1323
]
1424

1525
# There's an exception to this rule, which is when the statement
16-
# is preceeded by a CTE.
17-
FORCE_NEW_STATEMENT_KEYWORDS = [
26+
# is preceded by a CTE. For those, we have to check if the character
27+
# before this is a ")".
28+
NEW_STATEMENT_KEYWORDS = [
1829
# SELECT is used inside queries as well, so we can't include it here.
30+
"CREATE",
1931
"INSERT",
2032
"UPDATE",
2133
"DELETE",
2234
"MERGE",
2335
]
36+
STRICT_NEW_STATEMENT_KEYWORDS = [
37+
# For these keywords, a SELECT following it does indicate a new statement.
38+
"DROP",
39+
"TRUNCATE",
40+
]
41+
42+
43+
class _AlreadyIncremented(Exception):
    """
    Signal that the parser position has already been advanced.

    Raised by the NORMAL-state handler after it has moved the cursor past a
    keyword itself, so the main loop must skip its usual one-character step.
    Using exceptions for control flow isn't great - but the code is clearer
    so it's fine.
    """
2446

2547

2648
class ParserState(Enum):
@@ -30,134 +52,199 @@ class ParserState(Enum):
3052
MULTILINE_COMMENT = 4
3153

3254

33-
def _is_keyword_at_position(sql: str, pos: int, keyword: str) -> bool:
34-
"""
35-
Check if a keyword exists at the given position using regex word boundaries.
36-
"""
37-
if pos + len(keyword) > len(sql):
38-
return False
55+
class _StatementSplitter:
    """
    Character-by-character splitter that breaks a T-SQL script into statements.

    The splitter walks ``sql`` once, tracking whether it is inside a string
    literal, a ``--`` comment or a ``/* */`` comment (see ``ParserState``).
    Keywords are only recognised in the NORMAL state. Use ``process()`` to
    obtain the statements; an instance is single-use because the cursor and
    buffers are not reset afterwards.
    """

    def __init__(self, sql: str):
        """
        Args:
            sql: The full T-SQL text to split.
        """
        self.sql = sql

        # Main parser state.
        self.i = 0  # index of the character currently being examined
        self.state = ParserState.NORMAL
        self.current_statement: List[str] = []  # pieces of the statement being built

        # Additional parser state.

        # If we see a SELECT, should we start a new statement?
        # If we previously saw a drop/truncate/etc, a SELECT does mean a new statement.
        # But if we're in a select/create/etc, a select could just be a subquery.
        self.does_select_mean_new_statement = False

        # The END keyword terminates CASE and BEGIN blocks.
        # We need to match the CASE statements with END blocks to determine
        # what a given END is closing.
        self.current_case_statements = 0

    def _is_keyword_at_position(self, pos: int, keyword: str) -> bool:
        """
        Check if a keyword exists at the given position using regex word boundaries.

        The comparison is case-insensitive. ``keyword`` is matched literally: it
        is passed through ``re.escape``, so regex metacharacters in it are not
        interpreted.
        NOTE(review): this means keyword entries written as regexes (for example
        ``r"END\\w+TRY"``) only match that literal text - confirm whether
        pattern matching was intended.

        Args:
            pos: Index into ``self.sql`` at which the keyword must start.
            keyword: The keyword text to look for.

        Returns:
            True if ``keyword`` starts at ``pos`` and both ends fall on a word
            boundary, False otherwise.
        """
        sql = self.sql

        if pos + len(keyword) > len(sql):
            return False

        # If we're not at a word boundary, we can't generate a keyword.
        if pos > 0:
            around = sql[pos - 1 : pos + 1]
            if not (re.match(r"\w\W", around) or re.match(r"\W\w", around)):
                return False

        # Match in place with an explicit start offset rather than slicing
        # sql[pos:], which would copy the tail of the input for every character
        # and keyword. Pattern.match() anchors at ``pos``, so no "^" is needed,
        # and the trailing \b sees the same following character either way.
        pattern = re.compile(rf"{re.escape(keyword)}\b", re.IGNORECASE)
        return pattern.match(sql, pos) is not None

    def _look_ahead_for_keywords(self, keywords: List[str]) -> Tuple[bool, str, int]:
        """
        Look ahead for SQL keywords at the current position.

        Args:
            keywords: Candidate keywords, tried in order; the first hit wins.

        Returns:
            ``(True, keyword, len(keyword))`` for the first keyword found at
            ``self.i``, or ``(False, "", 0)`` if none matches.
        """
        for keyword in keywords:
            if self._is_keyword_at_position(self.i, keyword):
                return True, keyword, len(keyword)
        return False, "", 0

    def _yield_if_complete(self) -> Iterator[str]:
        """
        Emit the buffered statement, if any, and reset per-statement state.

        Yields:
            The buffered text with surrounding whitespace stripped, when it is
            non-empty.
        """
        statement = "".join(self.current_statement).strip()
        if statement:
            yield statement
            # Subtle - a whitespace-only buffer is intentionally not cleared;
            # it is carried over and merged into the next statement.
            self.current_statement.clear()

        # Reset current_statement-specific state.
        self.does_select_mean_new_statement = False
        # A non-zero counter here means a CASE was seen without a matching END.
        # Reset it unconditionally so the imbalance cannot leak into the next
        # statement. (This previously dropped into the debugger, which would
        # stall a non-interactive run.)
        self.current_case_statements = 0

    def process(self) -> Iterator[str]:
        """
        Run the splitter over the whole input.

        Yields:
            Each statement in source order, plus control-flow keywords as
            separate items. Nothing is yielded for empty or whitespace-only
            input.
        """
        if not self.sql or not self.sql.strip():
            return

        prev_real_char = "\0"  # the most recent non-whitespace, non-comment character
        while self.i < len(self.sql):
            c = self.sql[self.i]
            next_char = self.sql[self.i + 1] if self.i < len(self.sql) - 1 else "\0"

            if self.state == ParserState.NORMAL:
                if c == "'":
                    self.state = ParserState.STRING
                    self.current_statement.append(c)
                    prev_real_char = c
                elif c == "-" and next_char == "-":
                    self.state = ParserState.COMMENT
                    self.current_statement.append(c)
                    self.current_statement.append(next_char)
                    self.i += 1  # consume the second "-" as well
                elif c == "/" and next_char == "*":
                    self.state = ParserState.MULTILINE_COMMENT
                    self.current_statement.append(c)
                    self.current_statement.append(next_char)
                    self.i += 1  # consume the "*" as well
                else:
                    # Keyword handling needs the real character *before* this
                    # one, so capture it before updating the tracker.
                    most_recent_real_char = prev_real_char
                    if not c.isspace():
                        prev_real_char = c

                    try:
                        yield from self._process_normal(
                            most_recent_real_char=most_recent_real_char
                        )
                    except _AlreadyIncremented:
                        # Skip the normal i += 1 step.
                        continue

            elif self.state == ParserState.STRING:
                self.current_statement.append(c)
                if c == "'" and next_char == "'":
                    # A doubled quote is an escaped quote; stay in the string.
                    self.current_statement.append(next_char)
                    self.i += 1
                elif c == "'":
                    self.state = ParserState.NORMAL

            elif self.state == ParserState.COMMENT:
                self.current_statement.append(c)
                if c == "\n":
                    self.state = ParserState.NORMAL

            elif self.state == ParserState.MULTILINE_COMMENT:
                self.current_statement.append(c)
                if c == "*" and next_char == "/":
                    self.current_statement.append(next_char)
                    self.i += 1
                    self.state = ParserState.NORMAL

            self.i += 1

        # Handle the last statement
        yield from self._yield_if_complete()

    def _process_normal(self, most_recent_real_char: str) -> Iterator[str]:
        """
        Handle the character at ``self.i`` while in the NORMAL state.

        Args:
            most_recent_real_char: The last non-whitespace, non-comment
                character before the current one ("\\0" if there was none).

        Yields:
            Any statement completed at this position and, for control-flow
            keywords, the keyword itself as a separate item.

        Raises:
            _AlreadyIncremented: When a keyword was consumed and ``self.i`` has
                already been moved past it.
        """
        c = self.sql[self.i]

        if self._is_keyword_at_position(self.i, CASE_KEYWORD):
            self.current_case_statements += 1

        is_control_keyword, keyword, keyword_len = self._look_ahead_for_keywords(
            keywords=CONTROL_FLOW_KEYWORDS
        )
        if (
            is_control_keyword
            and keyword == END_KEYWORD
            and self.current_case_statements > 0
        ):
            # If we're closing a CASE statement with END, we can just decrement
            # the counter and continue; the END stays part of the statement.
            self.current_case_statements -= 1
        elif is_control_keyword:
            # Yield current statement if any
            yield from self._yield_if_complete()
            # Yield keyword as its own statement
            yield keyword
            self.i += keyword_len
            # Set after _yield_if_complete(), which resets this flag.
            self.does_select_mean_new_statement = True
            raise _AlreadyIncremented()

        (
            is_strict_new_statement_keyword,
            keyword,
            keyword_len,
        ) = self._look_ahead_for_keywords(keywords=STRICT_NEW_STATEMENT_KEYWORDS)
        if is_strict_new_statement_keyword:
            yield from self._yield_if_complete()
            self.current_statement.append(keyword)
            self.i += keyword_len
            self.does_select_mean_new_statement = True
            raise _AlreadyIncremented()

        (
            is_force_new_statement_keyword,
            keyword,
            keyword_len,
        ) = self._look_ahead_for_keywords(
            keywords=(
                NEW_STATEMENT_KEYWORDS
                + ([SELECT_KEYWORD] if self.does_select_mean_new_statement else [])
            ),
        )
        if (
            is_force_new_statement_keyword and most_recent_real_char != ")"
        ):  # usually we'd have a close paren that closes a CTE
            # Force termination of current statement
            yield from self._yield_if_complete()

            self.current_statement.append(keyword)
            self.i += keyword_len
            raise _AlreadyIncremented()

        if c == ";":
            # The semicolon terminates the statement and is not kept.
            yield from self._yield_if_complete()
        else:
            self.current_statement.append(c)
242+
243+
244+
def split_statements(sql: str) -> Iterator[str]:
    """
    Break a T-SQL script into its individual statements.

    This is a thin generator wrapper: it builds a fresh ``_StatementSplitter``
    for ``sql`` and relays everything its ``process()`` method yields.
    """
    yield from _StatementSplitter(sql).process()

0 commit comments

Comments
 (0)