Skip to content

Commit d25a07a

Browse files
committed
Allow replacing notebooks variables
1 parent e4b29e2 commit d25a07a

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

tests/utils_for_testbook.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,21 @@
6060
def wrap_testbook(
6161
notebook_name: str,
6262
timeout_seconds: float = 10,
63-
replacements: list[tuple[str, str]] | None = None,
63+
replacements_regex: list[tuple[str, str]] | None = None,
64+
replacements_variables: list[tuple[str, str]] | None = None,
6465
) -> Callable:
6566
def inner_decorator(func: Callable) -> Any:
6667
_patch_testbook()
6768

6869
notebook_path = resolve_notebook_path(notebook_name)
6970

70-
with NotebookReplace(notebook_path, replacements):
71+
with NotebookReplace(
72+
notebook_path, replacements_regex, replacements_variables
73+
) as nr:
7174
for decorator in [
7275
_build_patch_testbook_client_decorator(notebook_name),
7376
testbook(notebook_path, execute=True, timeout=timeout_seconds),
74-
(lambda x: x) if replacements else qmod_compare_decorator,
77+
(lambda x: x) if nr.replacements else qmod_compare_decorator,
7578
_build_cd_decorator(notebook_path),
7679
_build_skip_decorator(notebook_path),
7780
]:
@@ -101,15 +104,34 @@ def inner(*args: Any, **kwargs: Any) -> Any:
101104
@dataclass
102105
class NotebookReplace:
103106
file_path: str
104-
replacements: list[tuple[str, str]] | None
107+
replacements_regex: list[tuple[str, str]] | None
108+
replacements_variables: list[tuple[str, str]] | None
105109

106110
def __post_init__(self):
107111
self.was_file_copied = False
108112

113+
self.replacements = self._group_replacements()
114+
109115
@property
110116
def file_path_copied(self):
111117
return self.file_path + FILE_COPY_SUFFIX
112118

119+
def _group_replacements(self) -> list[tuple[str, str]]:
120+
replacements = []
121+
122+
if self.replacements_regex:
123+
replacements.extend(self.replacements_regex)
124+
125+
if self.replacements_variables:
126+
replacements.extend(
127+
[
128+
(f"({variable}\\s*=\\s*)", f"\\1 {new_value} # ")
129+
for variable, new_value in self.replacements_variables
130+
]
131+
)
132+
133+
return replacements
134+
113135
def __enter__(self):
114136
if not self.replacements:
115137
return
@@ -118,11 +140,11 @@ def __enter__(self):
118140
self.was_file_copied = True
119141

120142
used_replacements = self._replace_notebook_content()
121-
122-
# verify all replacements were used
123143
assert (
124144
self.replacements == used_replacements
125-
), f"Not all replacements given were used. The onces used are: {used_replacements}. The unused are {[r for r in replacements if r not in used_replacements]}"
145+
), f"Not all replacements given were used. The onces used are: {used_replacements}. The unused are {[r for r in self.replacements if r not in used_replacements]}"
146+
147+
return self
126148

127149
def __exit__(self, *args, **kwargs):
128150
if not self.replacements:

0 commit comments

Comments
 (0)