Skip to content

Commit c106d40

Browse files
committed
Fix un-expansion of macros when the spacing around its use is changed by translation
1 parent 8980efb commit c106d40

File tree

2 files changed

+75
-9
lines changed

2 files changed

+75
-9
lines changed

client/bqms_run/macros.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,6 @@ def __init__(
235235
target_bind_generator: A function that format parameter in a way that it's detected as bind parameter in target dialect
236236
"""
237237
super().__init__(mapping)
238-
# what if we have mapping and wrap it here.
239-
pattern = f"([=\\(, \\[]?){pattern}"
240238
self.pattern = re.compile(pattern, re.I)
241239
self.source_bind_generator = source_bind_generator
242240
self.target_bind_generator = target_bind_generator
@@ -245,18 +243,17 @@ def __init__(
245243
self.value_stripper = re.compile(value_stripper)
246244

247245
def _substitution(self, path: Path, match: Match[str]) -> str:
248-
prefix = match.group(1)
249-
macro_name = match.group(2)
246+
macro_name = match.group(1)
250247
full_match = match.group(0)
251248
if self.mapping and macro_name in self.mapping:
252249
replacement = self.mapping[macro_name]
253250
stripped_replacement = self.value_stripper.match(replacement).group(1)
254-
if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")) and prefix:
255-
generated = '{}{}'.format(prefix, self.source_bind_generator(self.mapping[macro_name]))
256-
reverse_search = '{}{}'.format(prefix, self.target_bind_generator(self.mapping[macro_name]))
251+
if (stripped_replacement.isnumeric() or stripped_replacement.lower() in ("true", "false")):
252+
generated = self.source_bind_generator(self.mapping[macro_name])
253+
reverse_search = self.target_bind_generator(self.mapping[macro_name])
257254
else:
258-
generated = '{}{}'.format(prefix,self.mapping[macro_name])
259-
reverse_search = '{}{}'.format(prefix,self.mapping[macro_name])
255+
generated = self.mapping[macro_name]
256+
reverse_search = self.mapping[macro_name]
260257
else:
261258
self.warn_log(
262259
"Could not expand '{0}' as it is not "

client/tests/unit/test_macros.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MacroExpanderRouter,
2121
PatternMacroExpander,
2222
SimpleMacroExpander,
23+
ParameterAwareMacroExpander
2324
)
2425

2526

@@ -160,3 +161,71 @@ def test_basic_expand():
160161
assert expanded == "abcdef alpha.bravo ghijkl"
161162
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
162163
assert un_expanded == input_text
164+
165+
def test_complex_1():
166+
expander = MacroExpanderRouter(
167+
{
168+
"*.sql": ParameterAwareMacroExpander(
169+
pattern="(\\[\\$\\w+\\])",
170+
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
171+
mapping={"[$table1]": "__bq__0__ABCDEF__bq__0__"},
172+
source_bind_generator=lambda arg: "@"+arg,
173+
target_bind_generator=lambda arg: arg
174+
)
175+
}
176+
)
177+
178+
input_text="SELECT * FROM WXYZ.[$table1]"
179+
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
180+
assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__"
181+
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
182+
assert un_expanded == input_text
183+
184+
def test_complex_2():
185+
expander = MacroExpanderRouter(
186+
{
187+
"*.sql": ParameterAwareMacroExpander(
188+
pattern="\\[\\$(\\w+)\\]",
189+
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
190+
mapping={
191+
"table1": "__bq__0__ABCDEF__bq__0__",
192+
"limit_val1": "__bq__1__5__bq__1__"
193+
},
194+
source_bind_generator=lambda arg: "@"+arg,
195+
target_bind_generator=lambda arg: arg
196+
)
197+
}
198+
)
199+
200+
input_text = "SELECT * FROM WXYZ.[$table1] LIMIT [$limit_val1]"
201+
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
202+
assert expanded == "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT @__bq__1__5__bq__1__"
203+
expanded = "SELECT * FROM WXYZ.__bq__0__ABCDEF__bq__0__ LIMIT __bq__1__5__bq__1__"
204+
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
205+
assert un_expanded == input_text
206+
207+
def test_unexpand_after_database_added():
208+
expander = MacroExpanderRouter(
209+
{
210+
"*.sql": ParameterAwareMacroExpander(
211+
pattern="(\\[\\$\\w+\\])",
212+
value_stripper='__bq__\\d+__(.*)__bq__\\d+__',
213+
mapping={
214+
"[$table]": "__bq__0__table__bq__0__",
215+
},
216+
source_bind_generator=lambda arg: "@"+arg,
217+
target_bind_generator=lambda arg: arg
218+
)
219+
}
220+
)
221+
222+
input_text = "CREATE TABLE [$table](a INT64);"
223+
expanded = expander.expand(pathlib.Path("abc.sql"), input_text)
224+
assert expanded == "CREATE TABLE __bq__0__table__bq__0__(a INT64);"
225+
expanded = "CREATE TABLE db_name.__bq__0__table__bq__0__(a INT64);"
226+
un_expanded = expander.un_expand(pathlib.Path("abc.sql"), expanded)
227+
expected_output = "CREATE TABLE db_name.[$table](a INT64);"
228+
assert un_expanded == expected_output
229+
230+
231+

0 commit comments

Comments
 (0)