1616import inspect
1717import re
1818import sys
19- from typing import List , Optional , Tuple
19+ from typing import List , Optional , Tuple , Union
2020
2121import sqlparse
2222from robot .api import logger
@@ -328,9 +328,7 @@ def execute_sql_script(
328328 else :
329329 statements_to_execute = self .split_sql_script (script_path , external_parser = external_parser )
330330 for statement in statements_to_execute :
331- proc_end_pattern = re .compile ("end(?!( if;| loop;| case;| while;| repeat;)).*;()?" )
332- line_ends_with_proc_end = re .compile (r"(\s|;)" + proc_end_pattern .pattern + "$" )
333- omit_semicolon = not line_ends_with_proc_end .search (statement .lower ())
331+ omit_semicolon = self ._omit_semicolon_needed (statement )
334332 self ._execute_sql (cur , statement , omit_semicolon , replace_robot_variables = replace_robot_variables )
335333 self ._commit_if_needed (db_connection , no_transaction )
336334 except Exception as e :
@@ -350,72 +348,82 @@ def split_sql_script(
350348 """
351349 with open (script_path , encoding = "UTF-8" ) as sql_file :
352350 logger .info ("Splitting script file into statements..." )
353- statements_to_execute = []
354- if external_parser :
355- split_statements = sqlparse .split (sql_file .read ())
356- for statement in split_statements :
357- statement_without_comments = sqlparse .format (statement , strip_comments = True )
358- if statement_without_comments :
359- statements_to_execute .append (statement_without_comments )
360- else :
361- current_statement = ""
362- inside_statements_group = False
363- proc_start_pattern = re .compile ("create( or replace)? (procedure|function){1}( )?" )
364- proc_end_pattern = re .compile ("end(?!( if;| loop;| case;| while;| repeat;)).*;()?" )
365- for line in sql_file :
366- line = line .strip ()
367- if line .startswith ("#" ) or line .startswith ("--" ) or line == "/" :
368- continue
369-
370- # check if the line matches the creating procedure regexp pattern
371- if proc_start_pattern .match (line .lower ()):
372- inside_statements_group = True
373- elif line .lower ().startswith ("begin" ):
374- inside_statements_group = True
375-
376- # semicolons inside the line? use them to separate statements
377- # ... but not if they are inside a begin/end block (aka. statements group)
378- sqlFragments = line .split (";" )
379- # no semicolons
380- if len (sqlFragments ) == 1 :
381- current_statement += line + " "
382- continue
351+ return self .split_sql_string (sql_file .read (), external_parser = external_parser )
352+
353+ def split_sql_string (self , sql_string : str , external_parser : bool = False ):
354+ if external_parser :
355+ return self ._split_statements_using_external_parser (sql_string )
356+ else :
357+ return self ._parse_sql_internally (sql_string .splitlines ())
358+
359+ def _parse_sql_internally (self , sql_file : List [str ]) -> list [str ]:
360+ statements_to_execute = []
361+ current_statement = ""
362+ inside_statements_group = False
363+ proc_start_pattern = re .compile ("create( or replace)? (procedure|function){1}( )?" )
364+ proc_end_pattern = re .compile ("end(?!( if;| loop;| case;| while;| repeat;)).*;()?" )
365+ for line in sql_file :
366+ line = line .strip ()
367+ if line .startswith ("#" ) or line .startswith ("--" ) or line == "/" :
368+ continue
369+
370+ # check if the line matches the creating procedure regexp pattern
371+ if proc_start_pattern .match (line .lower ()):
372+ inside_statements_group = True
373+ elif line .lower ().startswith ("begin" ):
374+ inside_statements_group = True
375+
376+ # semicolons inside the line? use them to separate statements
377+ # ... but not if they are inside a begin/end block (aka. statements group)
378+ sqlFragments = line .split (";" )
379+ # no semicolons
380+ if len (sqlFragments ) == 1 :
381+ current_statement += line + " "
382+ continue
383+ quotes = 0
384+ # "select * from person;" -> ["select..", ""]
385+ for sqlFragment in sqlFragments :
386+ if len (sqlFragment .strip ()) == 0 :
387+ continue
388+
389+ if inside_statements_group :
390+ # if statements inside a begin/end block have semicolns,
391+ # they must persist - even with oracle
392+ sqlFragment += "; "
393+
394+ if proc_end_pattern .match (sqlFragment .lower ()):
395+ inside_statements_group = False
396+ elif proc_start_pattern .match (sqlFragment .lower ()):
397+ inside_statements_group = True
398+ elif sqlFragment .lower ().startswith ("begin" ):
399+ inside_statements_group = True
400+
401+ # check if the semicolon is a part of the value (quoted string)
402+ quotes += sqlFragment .count ("'" )
403+ quotes -= sqlFragment .count ("\\ '" )
404+ inside_quoted_string = quotes % 2 != 0
405+ if inside_quoted_string :
406+ sqlFragment += ";" # restore the semicolon
407+
408+ current_statement += sqlFragment
409+ if not inside_statements_group and not inside_quoted_string :
410+ statements_to_execute .append (current_statement .strip ())
411+ current_statement = ""
383412 quotes = 0
384- # "select * from person;" -> ["select..", ""]
385- for sqlFragment in sqlFragments :
386- if len (sqlFragment .strip ()) == 0 :
387- continue
388-
389- if inside_statements_group :
390- # if statements inside a begin/end block have semicolns,
391- # they must persist - even with oracle
392- sqlFragment += "; "
393-
394- if proc_end_pattern .match (sqlFragment .lower ()):
395- inside_statements_group = False
396- elif proc_start_pattern .match (sqlFragment .lower ()):
397- inside_statements_group = True
398- elif sqlFragment .lower ().startswith ("begin" ):
399- inside_statements_group = True
400-
401- # check if the semicolon is a part of the value (quoted string)
402- quotes += sqlFragment .count ("'" )
403- quotes -= sqlFragment .count ("\\ '" )
404- inside_quoted_string = quotes % 2 != 0
405- if inside_quoted_string :
406- sqlFragment += ";" # restore the semicolon
407-
408- current_statement += sqlFragment
409- if not inside_statements_group and not inside_quoted_string :
410- statements_to_execute .append (current_statement .strip ())
411- current_statement = ""
412- quotes = 0
413-
414- current_statement = current_statement .strip ()
415- if len (current_statement ) != 0 :
416- statements_to_execute .append (current_statement )
417-
418- return statements_to_execute
413+
414+ current_statement = current_statement .strip ()
415+ if len (current_statement ) != 0 :
416+ statements_to_execute .append (current_statement )
417+ return statements_to_execute
418+
419+ def _split_statements_using_external_parser (self , sql_file_content : str ):
420+ statements_to_execute = []
421+ split_statements = sqlparse .split (sql_file_content )
422+ for statement in split_statements :
423+ statement_without_comments = sqlparse .format (statement , strip_comments = True )
424+ if statement_without_comments :
425+ statements_to_execute .append (statement_without_comments )
426+ return statements_to_execute
419427
420428 @renamed_args (
421429 mapping = {
@@ -436,6 +444,8 @@ def execute_sql_string(
436444 sqlString : Optional [str ] = None ,
437445 sansTran : Optional [bool ] = None ,
438446 omitTrailingSemicolon : Optional [bool ] = None ,
447+ split : bool = False ,
448+ external_parser : bool = False ,
439449 ):
440450 """
441451 Executes the ``sql_string`` as a single SQL command.
@@ -473,17 +483,30 @@ def execute_sql_string(
473483 cur = db_connection .client .cursor ()
474484 if omit_trailing_semicolon is None :
475485 omit_trailing_semicolon = db_connection .omit_trailing_semicolon
476- self ._execute_sql (
477- cur ,
478- sql_string ,
479- omit_trailing_semicolon = omit_trailing_semicolon ,
480- parameters = parameters ,
481- replace_robot_variables = replace_robot_variables ,
482- )
486+ if not split :
487+ self ._execute_sql (
488+ cur ,
489+ sql_string ,
490+ omit_trailing_semicolon = omit_trailing_semicolon ,
491+ parameters = parameters ,
492+ replace_robot_variables = replace_robot_variables ,
493+ )
494+ else :
495+ statements_to_execute = self .split_sql_string (sql_string , external_parser = external_parser )
496+ for statement in statements_to_execute :
497+ omit_semicolon = self ._omit_semicolon_needed (statement )
498+ self ._execute_sql (cur , statement , omit_semicolon , replace_robot_variables = replace_robot_variables )
499+
483500 self ._commit_if_needed (db_connection , no_transaction )
484501 except Exception as e :
485502 self ._rollback_and_raise (db_connection , no_transaction , e )
486503
504+ def _omit_semicolon_needed (self , statement : str ) -> bool :
505+ proc_end_pattern = re .compile ("end(?!( if;| loop;| case;| while;| repeat;)).*;()?" )
506+ line_ends_with_proc_end = re .compile (r"(\s|;)" + proc_end_pattern .pattern + "$" )
507+ omit_semicolon = not line_ends_with_proc_end .search (statement .lower ())
508+ return omit_semicolon
509+
487510 @renamed_args (mapping = {"spName" : "procedure_name" , "spParams" : "procedure_params" , "sansTran" : "no_transaction" })
488511 def call_stored_procedure (
489512 self ,
0 commit comments