Skip to content

Commit b13a8ea

Browse files
author
Frank Stenzhorn
committed
feature: add support to split sql strings
1 parent b1f72e7 commit b13a8ea

File tree

1 file changed

+99
-76
lines changed

1 file changed

+99
-76
lines changed

src/DatabaseLibrary/query.py

Lines changed: 99 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import inspect
1717
import re
1818
import sys
19-
from typing import List, Optional, Tuple
19+
from typing import List, Optional, Tuple, Union
2020

2121
import sqlparse
2222
from robot.api import logger
@@ -328,9 +328,7 @@ def execute_sql_script(
328328
else:
329329
statements_to_execute = self.split_sql_script(script_path, external_parser=external_parser)
330330
for statement in statements_to_execute:
331-
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
332-
line_ends_with_proc_end = re.compile(r"(\s|;)" + proc_end_pattern.pattern + "$")
333-
omit_semicolon = not line_ends_with_proc_end.search(statement.lower())
331+
omit_semicolon = self._omit_semicolon_needed(statement)
334332
self._execute_sql(cur, statement, omit_semicolon, replace_robot_variables=replace_robot_variables)
335333
self._commit_if_needed(db_connection, no_transaction)
336334
except Exception as e:
@@ -350,72 +348,82 @@ def split_sql_script(
350348
"""
351349
with open(script_path, encoding="UTF-8") as sql_file:
352350
logger.info("Splitting script file into statements...")
353-
statements_to_execute = []
354-
if external_parser:
355-
split_statements = sqlparse.split(sql_file.read())
356-
for statement in split_statements:
357-
statement_without_comments = sqlparse.format(statement, strip_comments=True)
358-
if statement_without_comments:
359-
statements_to_execute.append(statement_without_comments)
360-
else:
361-
current_statement = ""
362-
inside_statements_group = False
363-
proc_start_pattern = re.compile("create( or replace)? (procedure|function){1}( )?")
364-
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
365-
for line in sql_file:
366-
line = line.strip()
367-
if line.startswith("#") or line.startswith("--") or line == "/":
368-
continue
369-
370-
# check if the line matches the creating procedure regexp pattern
371-
if proc_start_pattern.match(line.lower()):
372-
inside_statements_group = True
373-
elif line.lower().startswith("begin"):
374-
inside_statements_group = True
375-
376-
# semicolons inside the line? use them to separate statements
377-
# ... but not if they are inside a begin/end block (aka. statements group)
378-
sqlFragments = line.split(";")
379-
# no semicolons
380-
if len(sqlFragments) == 1:
381-
current_statement += line + " "
382-
continue
351+
return self.split_sql_string(sql_file.read(), external_parser=external_parser)
352+
353+
def split_sql_string(self, sql_string: str, external_parser: bool = False):
354+
if external_parser:
355+
return self._split_statements_using_external_parser(sql_string)
356+
else:
357+
return self._parse_sql_internally(sql_string.splitlines())
358+
359+
def _parse_sql_internally(self, sql_file: List[str]) -> list[str]:
360+
statements_to_execute = []
361+
current_statement = ""
362+
inside_statements_group = False
363+
proc_start_pattern = re.compile("create( or replace)? (procedure|function){1}( )?")
364+
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
365+
for line in sql_file:
366+
line = line.strip()
367+
if line.startswith("#") or line.startswith("--") or line == "/":
368+
continue
369+
370+
# check if the line matches the creating procedure regexp pattern
371+
if proc_start_pattern.match(line.lower()):
372+
inside_statements_group = True
373+
elif line.lower().startswith("begin"):
374+
inside_statements_group = True
375+
376+
# semicolons inside the line? use them to separate statements
377+
# ... but not if they are inside a begin/end block (aka. statements group)
378+
sqlFragments = line.split(";")
379+
# no semicolons
380+
if len(sqlFragments) == 1:
381+
current_statement += line + " "
382+
continue
383+
quotes = 0
384+
# "select * from person;" -> ["select..", ""]
385+
for sqlFragment in sqlFragments:
386+
if len(sqlFragment.strip()) == 0:
387+
continue
388+
389+
if inside_statements_group:
390+
# if statements inside a begin/end block have semicolns,
391+
# they must persist - even with oracle
392+
sqlFragment += "; "
393+
394+
if proc_end_pattern.match(sqlFragment.lower()):
395+
inside_statements_group = False
396+
elif proc_start_pattern.match(sqlFragment.lower()):
397+
inside_statements_group = True
398+
elif sqlFragment.lower().startswith("begin"):
399+
inside_statements_group = True
400+
401+
# check if the semicolon is a part of the value (quoted string)
402+
quotes += sqlFragment.count("'")
403+
quotes -= sqlFragment.count("\\'")
404+
inside_quoted_string = quotes % 2 != 0
405+
if inside_quoted_string:
406+
sqlFragment += ";" # restore the semicolon
407+
408+
current_statement += sqlFragment
409+
if not inside_statements_group and not inside_quoted_string:
410+
statements_to_execute.append(current_statement.strip())
411+
current_statement = ""
383412
quotes = 0
384-
# "select * from person;" -> ["select..", ""]
385-
for sqlFragment in sqlFragments:
386-
if len(sqlFragment.strip()) == 0:
387-
continue
388-
389-
if inside_statements_group:
390-
# if statements inside a begin/end block have semicolns,
391-
# they must persist - even with oracle
392-
sqlFragment += "; "
393-
394-
if proc_end_pattern.match(sqlFragment.lower()):
395-
inside_statements_group = False
396-
elif proc_start_pattern.match(sqlFragment.lower()):
397-
inside_statements_group = True
398-
elif sqlFragment.lower().startswith("begin"):
399-
inside_statements_group = True
400-
401-
# check if the semicolon is a part of the value (quoted string)
402-
quotes += sqlFragment.count("'")
403-
quotes -= sqlFragment.count("\\'")
404-
inside_quoted_string = quotes % 2 != 0
405-
if inside_quoted_string:
406-
sqlFragment += ";" # restore the semicolon
407-
408-
current_statement += sqlFragment
409-
if not inside_statements_group and not inside_quoted_string:
410-
statements_to_execute.append(current_statement.strip())
411-
current_statement = ""
412-
quotes = 0
413-
414-
current_statement = current_statement.strip()
415-
if len(current_statement) != 0:
416-
statements_to_execute.append(current_statement)
417-
418-
return statements_to_execute
413+
414+
current_statement = current_statement.strip()
415+
if len(current_statement) != 0:
416+
statements_to_execute.append(current_statement)
417+
return statements_to_execute
418+
419+
def _split_statements_using_external_parser(self, sql_file_content: str):
420+
statements_to_execute = []
421+
split_statements = sqlparse.split(sql_file_content)
422+
for statement in split_statements:
423+
statement_without_comments = sqlparse.format(statement, strip_comments=True)
424+
if statement_without_comments:
425+
statements_to_execute.append(statement_without_comments)
426+
return statements_to_execute
419427

420428
@renamed_args(
421429
mapping={
@@ -436,6 +444,8 @@ def execute_sql_string(
436444
sqlString: Optional[str] = None,
437445
sansTran: Optional[bool] = None,
438446
omitTrailingSemicolon: Optional[bool] = None,
447+
split: bool = False,
448+
external_parser: bool = False,
439449
):
440450
"""
441451
Executes the ``sql_string`` as a single SQL command.
@@ -473,17 +483,30 @@ def execute_sql_string(
473483
cur = db_connection.client.cursor()
474484
if omit_trailing_semicolon is None:
475485
omit_trailing_semicolon = db_connection.omit_trailing_semicolon
476-
self._execute_sql(
477-
cur,
478-
sql_string,
479-
omit_trailing_semicolon=omit_trailing_semicolon,
480-
parameters=parameters,
481-
replace_robot_variables=replace_robot_variables,
482-
)
486+
if not split:
487+
self._execute_sql(
488+
cur,
489+
sql_string,
490+
omit_trailing_semicolon=omit_trailing_semicolon,
491+
parameters=parameters,
492+
replace_robot_variables=replace_robot_variables,
493+
)
494+
else:
495+
statements_to_execute = self.split_sql_string(sql_string, external_parser=external_parser)
496+
for statement in statements_to_execute:
497+
omit_semicolon = self._omit_semicolon_needed(statement)
498+
self._execute_sql(cur, statement, omit_semicolon, replace_robot_variables=replace_robot_variables)
499+
483500
self._commit_if_needed(db_connection, no_transaction)
484501
except Exception as e:
485502
self._rollback_and_raise(db_connection, no_transaction, e)
486503

504+
def _omit_semicolon_needed(self, statement: str) -> bool:
505+
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
506+
line_ends_with_proc_end = re.compile(r"(\s|;)" + proc_end_pattern.pattern + "$")
507+
omit_semicolon = not line_ends_with_proc_end.search(statement.lower())
508+
return omit_semicolon
509+
487510
@renamed_args(mapping={"spName": "procedure_name", "spParams": "procedure_params", "sansTran": "no_transaction"})
488511
def call_stored_procedure(
489512
self,

0 commit comments

Comments
 (0)