Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions client/bqms_run/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ def __init__(
target_bind_generator: A function that format parameter in a way that it's detected as bind parameter in target dialect
"""
super().__init__(mapping)
# what if we have mapping and wrap it here.
pattern = f"([=\\(, \\[]?){pattern}"
self.pattern = re.compile(pattern, re.I)
self.source_bind_generator = source_bind_generator
self.target_bind_generator = target_bind_generator
Expand All @@ -245,18 +243,17 @@ def __init__(
self.value_stripper = re.compile(value_stripper)

def _substitution(self, path: Path, match: Match[str]) -> str:
prefix = match.group(1)
macro_name = match.group(2)
macro_name = match.group(1)
full_match = match.group(0)
if self.mapping and macro_name in self.mapping:
replacement = self.mapping[macro_name]
stripped_replacement = self.value_stripper.match(replacement).group(1)
if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")) and prefix:
generated = '{}{}'.format(prefix, self.source_bind_generator(self.mapping[macro_name]))
reverse_search = '{}{}'.format(prefix, self.target_bind_generator(self.mapping[macro_name]))
if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")):
generated = self.source_bind_generator(self.mapping[macro_name])
reverse_search = self.target_bind_generator(self.mapping[macro_name])
else:
generated = '{}{}'.format(prefix,self.mapping[macro_name])
reverse_search = '{}{}'.format(prefix,self.mapping[macro_name])
generated = self.mapping[macro_name]
reverse_search = self.mapping[macro_name]
else:
self.warn_log(
"Could not expand '{0}' as it is not "
Expand Down
69 changes: 69 additions & 0 deletions client/tests/unit/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MacroExpanderRouter,
PatternMacroExpander,
SimpleMacroExpander,
ParameterAwareMacroExpander
)


Expand Down Expand Up @@ -160,3 +161,71 @@ def test_basic_expand():
assert expanded == "abcdef alpha.bravo ghijkl"
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
assert un_expanded == input_text

def test_complex_1():
expander = MacroExpanderRouter(
{
"*.sql": ParameterAwareMacroExpander(
pattern="(\\[\\$\\w+\\])",
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
mapping={"[$table1]": "__bq__0__ABCDEF__bq__0__"},
source_bind_generator=lambda arg: "@"+arg,
target_bind_generator=lambda arg: arg
)
}
)

input_text="SELECT * FROM WXYZ.[$table1]"
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__"
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
assert un_expanded == input_text

def test_complex_2():
expander = MacroExpanderRouter(
{
"*.sql": ParameterAwareMacroExpander(
pattern="\\[\\$(\\w+)\\]",
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
mapping={
"table1": "__bq__0__ABCDEF__bq__0__",
"limit_val1": "__bq__1__5__bq__1__"
},
source_bind_generator=lambda arg: "@"+arg,
target_bind_generator=lambda arg: arg
)
}
)

input_text = "SELECT * FROM WXYZ.[$table1] LIMIT [$limit_val1]"
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT @__bq__1__5__bq__1__"
expanded = "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT __bq__1__5__bq__1__"
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
assert un_expanded == input_text

def test_unexpand_after_database_added():
expander = MacroExpanderRouter(
{
"*.sql": ParameterAwareMacroExpander(
pattern="(\\[\\$\\w+\\])",
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
mapping={
"[$table]": "__bq__0__table__bq__0__",
},
source_bind_generator=lambda arg: "@"+arg,
target_bind_generator=lambda arg: arg
)
}
)

input_text = "CREATE TABLE [$table](a INT64);"
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
assert expanded == "CREATE TABLE __bq__0__table__bq__0__(a INT64);"
expanded = "CREATE TABLE db_name.__bq__0__table__bq__0__(a INT64);"
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
expected_output = "CREATE TABLE db_name.[$table](a INT64);"
assert un_expanded == expected_output



Loading