diff --git a/client/bqms_run/macros.py b/client/bqms_run/macros.py index 5d181b5cf..fa16a3921 100644 --- a/client/bqms_run/macros.py +++ b/client/bqms_run/macros.py @@ -235,8 +235,6 @@ def __init__( target_bind_generator: A function that format parameter in a way that it's detected as bind parameter in target dialect """ super().__init__(mapping) - # what if we have mapping and wrap it here. - pattern = f"([=\\(, \\[]?){pattern}" self.pattern = re.compile(pattern, re.I) self.source_bind_generator = source_bind_generator self.target_bind_generator = target_bind_generator @@ -245,18 +243,17 @@ def __init__( self.value_stripper = re.compile(value_stripper) def _substitution(self, path: Path, match: Match[str]) -> str: - prefix = match.group(1) - macro_name = match.group(2) + macro_name = match.group(1) full_match = match.group(0) if self.mapping and macro_name in self.mapping: replacement = self.mapping[macro_name] stripped_replacement = self.value_stripper.match(replacement).group(1) - if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")) and prefix: - generated = '{}{}'.format(prefix, self.source_bind_generator(self.mapping[macro_name])) - reverse_search = '{}{}'.format(prefix, self.target_bind_generator(self.mapping[macro_name])) + if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")): + generated = self.source_bind_generator(self.mapping[macro_name]) + reverse_search = self.target_bind_generator(self.mapping[macro_name]) else: - generated = '{}{}'.format(prefix,self.mapping[macro_name]) - reverse_search = '{}{}'.format(prefix,self.mapping[macro_name]) + generated = self.mapping[macro_name] + reverse_search = self.mapping[macro_name] else: self.warn_log( "Could not expand '{0}' as it is not " diff --git a/client/tests/unit/test_macros.py b/client/tests/unit/test_macros.py index a5cda703a..c112eb707 100644 --- a/client/tests/unit/test_macros.py +++ b/client/tests/unit/test_macros.py @@ -20,6 +20,7 @@ MacroExpanderRouter, PatternMacroExpander, SimpleMacroExpander, + ParameterAwareMacroExpander ) @@ -160,3 +161,71 @@ def test_basic_expand(): assert expanded == "abcdef alpha.bravo ghijkl" un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded) assert un_expanded == input_text + +def test_complex_1(): + expander = MacroExpanderRouter( + { + "*.sql": ParameterAwareMacroExpander( + pattern="(\\[\\$\\w+\\])", + value_stripper='__bq__\\d+__(.*)__bq__\\d+__', + mapping={"[$table1]": "__bq__0__ABCDEF__bq__0__"}, + source_bind_generator=lambda arg: "@"+arg, + target_bind_generator=lambda arg: arg + ) + } + ) + + input_text="SELECT * FROM WXYZ.[$table1]" + expanded = expander.expand(pathlib.Path("abc.sql"), input_text) + assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__" + un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded) + assert un_expanded == input_text + +def test_complex_2(): + expander = MacroExpanderRouter( + { + "*.sql": ParameterAwareMacroExpander( + pattern="\\[\\$(\\w+)\\]", + value_stripper='__bq__\\d+__(.*)__bq__\\d+__', + mapping={ + "table1": "__bq__0__ABCDEF__bq__0__", + "limit_val1": "__bq__1__5__bq__1__" + }, + source_bind_generator=lambda arg: "@"+arg, + target_bind_generator=lambda arg: arg + ) + } + ) + + input_text = "SELECT * FROM WXYZ.[$table1] LIMIT [$limit_val1]" + expanded = expander.expand(pathlib.Path("abc.sql"), input_text) + assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT @__bq__1__5__bq__1__" + expanded = "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT __bq__1__5__bq__1__" + un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded) + assert un_expanded == input_text + +def test_unexpand_after_database_added(): + expander = MacroExpanderRouter( + { + "*.sql": ParameterAwareMacroExpander( + pattern="(\\[\\$\\w+\\])", + value_stripper='__bq__\\d+__(.*)__bq__\\d+__', + mapping={ + "[$table]": "__bq__0__table__bq__0__", + }, + source_bind_generator=lambda arg: "@"+arg, + target_bind_generator=lambda arg: arg + ) + } + ) + + input_text = "CREATE TABLE [$table](a INT64);" + expanded = expander.expand(pathlib.Path("abc.sql"), input_text) + assert expanded == "CREATE TABLE __bq__0__table__bq__0__(a INT64);" + expanded = "CREATE TABLE db_name.__bq__0__table__bq__0__(a INT64);" + un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded) + expected_output = "CREATE TABLE db_name.[$table](a INT64);" + assert un_expanded == expected_output + + +