Skip to content

Commit 13dd324

Browse files
Merge pull request #5 from analysiscenter/mask_extra_code_run_notebook
Add `mask_extra_code` option `into run_notebook`
2 parents 55c0002 + 85a1374 commit 13dd324

File tree

3 files changed

+205
-7
lines changed

3 files changed

+205
-7
lines changed

nbtools/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" Init file. """
2-
from .pylint_notebook import pylint_notebook
2+
#pylint: disable=wildcard-import
3+
from .core import *
34
from .run_notebook import run_notebook
4-
from .core import StringWithDisabledRepr, notebook_to_script
5+
from .pylint_notebook import pylint_notebook
56

67
__version__ = '0.9.7'

nbtools/core.py

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
""" !!. """
1+
""" Core utility functions to work with Jupyter Notebooks. """
22
#pylint: disable=import-outside-toplevel
33
import os
44
import re
55
import json
6+
import warnings
67

78

89

910
class StringWithDisabledRepr(str):
1011
""" String with disabled repr. Used to avoid cluttering repr from function outputs. """
1112
def __repr__(self):
12-
""" !!. """
13+
""" Shorten the repr of a string. """
1314
return f'<StringWithDisabledRepr at {hex(id(self))}. Use `str`/`print` explicitly!>'
1415

1516

@@ -70,9 +71,8 @@ def get_notebook_name():
7071
return os.path.splitext(get_notebook_path())[0].split('/')[-1]
7172

7273

73-
7474
def notebook_to_script(path_script, path_notebook=None, ignore_markdown=True, return_info=False):
75-
""" !!. """
75+
""" Convert a notebook to a script. """
7676
import nbformat #pylint: disable=import-outside-toplevel
7777
path_notebook = path_notebook or get_notebook_path()
7878
if path_notebook is None:
@@ -113,3 +113,115 @@ def notebook_to_script(path_script, path_notebook=None, ignore_markdown=True, re
113113
return {'code': StringWithDisabledRepr(code),
114114
'cell_line_numbers': cell_line_numbers}
115115
return None
116+
117+
118+
119+
def get_available_gpus(n=1, min_free_memory=0.9, max_processes=2, verbose=False, raise_error=False):
120+
""" Select `n` gpus from available and free devices.
121+
122+
Parameters
123+
----------
124+
n : int, str
125+
If `max`, then use maximum number of available devices.
126+
If int, then number of devices to select.
127+
min_free_memory : float
128+
Minimum percentage of free memory on a device to consider it free.
129+
max_processes : int
130+
Maximum amount of computed processes on a device to consider it free.
131+
verbose : bool
132+
Whether to show individual device information.
133+
raise_error : bool
134+
Whether to raise an exception if not enough devices are available.
135+
136+
Returns
137+
-------
138+
List with indices of available GPUs
139+
"""
140+
try:
141+
import nvidia_smi
142+
except ImportError as exception:
143+
raise ImportError('Install Python interface for nvidia_smi') from exception
144+
145+
nvidia_smi.nvmlInit()
146+
n_devices = nvidia_smi.nvmlDeviceGetCount()
147+
148+
available_devices, memory_usage = [], []
149+
for i in range(n_devices):
150+
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
151+
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
152+
153+
fraction_free = info.free / info.total
154+
num_processes = len(nvidia_smi.nvmlDeviceGetComputeRunningProcesses(handle))
155+
156+
consider_available = (fraction_free > min_free_memory) & (num_processes <= max_processes)
157+
if consider_available:
158+
available_devices.append(i)
159+
memory_usage.append(fraction_free)
160+
161+
if verbose:
162+
print(f'Device {i} | Free memory: {fraction_free:4.2f} | '
163+
f'Number of running processes: {num_processes:>2} | Free: {consider_available}')
164+
165+
if isinstance(n, str) and n.startswith('max'):
166+
n = len(available_devices)
167+
168+
if len(available_devices) < n:
169+
msg = f'Not enough free devices: requested {n}, found {len(available_devices)}'
170+
if raise_error:
171+
raise ValueError(msg)
172+
warnings.warn(msg, RuntimeWarning)
173+
174+
# Argsort of `memory_usage` in a descending order
175+
indices = sorted(range(len(available_devices)), key=memory_usage.__getitem__, reverse=True)
176+
available_devices = [available_devices[i] for i in indices]
177+
return sorted(available_devices[:n])
178+
179+
def get_gpu_free_memory(index):
180+
""" Get free memory of a device. """
181+
try:
182+
import nvidia_smi
183+
except ImportError as exception:
184+
raise ImportError('Install Python interface for nvidia_smi') from exception
185+
186+
nvidia_smi.nvmlInit()
187+
nvidia_smi.nvmlDeviceGetCount()
188+
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(index)
189+
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
190+
191+
return info.free / info.total
192+
193+
def set_gpus(n=1, min_free_memory=0.9, max_processes=2, verbose=False, raise_error=False):
194+
""" Set the `CUDA_VISIBLE_DEVICES` variable to `n` available devices.
195+
196+
Parameters
197+
----------
198+
n : int, str
199+
If `max`, then use maximum number of available devices.
200+
If int, then number of devices to select.
201+
min_free_memory : float
202+
Minimum percentage of free memory on a device to consider it free.
203+
max_processes : int
204+
Maximum amount of computed processes on a device to consider it free.
205+
verbose : bool or int
206+
Whether to show individual device information.
207+
If 0 or False, then no information is displayed.
208+
If 1 or True, then display the value assigned to `CUDA_VISIBLE_DEVICES` variable.
209+
If 2, then display memory and process information for each device.
210+
raise_error : bool
211+
Whether to raise an exception if not enough devices are available.
212+
"""
213+
#pylint: disable=consider-iterating-dictionary
214+
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
215+
str_devices = os.environ["CUDA_VISIBLE_DEVICES"]
216+
warnings.warn(f'`CUDA_VISIBLE_DEVICES` is already set to "{str_devices}"!')
217+
return [int(d) for d in str_devices.split(',')]
218+
219+
devices = get_available_gpus(n=n, min_free_memory=min_free_memory, max_processes=max_processes,
220+
verbose=(verbose==2), raise_error=raise_error)
221+
str_devices = ','.join(str(i) for i in devices)
222+
os.environ['CUDA_VISIBLE_DEVICES'] = str_devices
223+
224+
newline = "\n" if verbose==2 else ""
225+
if verbose:
226+
print(f'{newline}`CUDA_VISIBLE_DEVICES` set to "{str_devices}"')
227+
return devices

nbtools/run_notebook.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
2929
locals().update(inputs)
3030
"""
31+
YAML_IMPORT = """import yaml"""
32+
INPUTS_DISPLAY = """print(yaml.dump(inputs))"""
33+
3134
INPUTS_CODE_CELL = dedent(INPUTS_CODE_CELL)
3235

3336
# Save notebook outputs
@@ -43,12 +46,15 @@
4346
with shelve.open(out_path_db) as notebook_db:
4447
notebook_db['outputs'] = output
4548
"""
49+
OUTPUTS_DISPLAY = """print(yaml.dump(output))"""
50+
4651
OUTPUTS_CODE_CELL = dedent(OUTPUTS_CODE_CELL)
4752

4853

4954
def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = './', execute_kwargs=None,
5055
out_path_db=None, out_path_ipynb=None, out_path_html=None, remove_db='always', add_timestamp=True,
51-
hide_code_cells=False, display_links=True, raise_exception=False, return_notebook=False):
56+
hide_code_cells=False, mask_extra_code=True, display_links=True,
57+
raise_exception=False, return_notebook=False):
5258
""" Execute a Jupyter Notebook programmatically.
5359
Heavily inspired by https://github.com/tritemio/nbrun.
5460
@@ -110,6 +116,9 @@ def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = '.
110116
Whether to add a cell with execution information at the beginning of the saved notebook.
111117
hide_code_cells : bool, optional
112118
Whether to hide the code cells in the saved notebook.
119+
mask_extra_code : bool, optional
120+
Whether to mask database reading and dumping code.
121+
For more, see :func:`~.mask_inputs_reading` and :func`~.mask_outputs_dumping` docstrings.
113122
display_links : bool, optional
114123
Whether to display links to the executed notebook and html at execution.
115124
raise_exception : bool, optional
@@ -182,6 +191,9 @@ def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = '.
182191
notebook_db.update(inputs)
183192

184193
code = CELL_INSERT_COMMENT + DB_CONNECT_CODE_CELL.format(repr(out_path_db)) + INPUTS_CODE_CELL
194+
if mask_extra_code:
195+
code += YAML_IMPORT + '\n' + INPUTS_DISPLAY
196+
185197
notebook['cells'].insert(inputs_pos, nbformat.v4.new_code_cell(code))
186198

187199
if outputs is not None:
@@ -191,6 +203,12 @@ def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = '.
191203
code = CELL_INSERT_COMMENT + \
192204
(DB_CONNECT_CODE_CELL.format(repr(out_path_db)) if not inputs else "") + \
193205
OUTPUTS_CODE_CELL.format(outputs)
206+
207+
if mask_extra_code:
208+
if inputs is None:
209+
code = YAML_IMPORT + '\n' + code
210+
code += OUTPUTS_DISPLAY
211+
194212
output_cell = nbformat.v4.new_code_cell(code)
195213
notebook['cells'].append(output_cell)
196214

@@ -240,13 +258,23 @@ def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = '.
240258
if outputs is not None:
241259
exec_res['outputs'] = outputs_values
242260

261+
# Notebook postprocessing: add timestamp, mask db reading/dumping code
243262
if add_timestamp:
244263
timestamp = (f"**Executed:** {time.ctime(start_time)}<br>"
245264
f"**Duration:** {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}<br>"
246265
f"**Autogenerated from:** [{path}]\n\n---")
247266
timestamp_cell = nbformat.v4.new_markdown_cell(timestamp)
248267
notebook['cells'].insert(0, timestamp_cell)
249268

269+
if mask_extra_code:
270+
if inputs is not None:
271+
pos = inputs_pos + 1 if add_timestamp else inputs_pos
272+
mask_inputs_reading(notebook=notebook, pos=pos)
273+
274+
if outputs is not None:
275+
pos = len(notebook['cells']) - 1
276+
mask_outputs_dumping(notebook=notebook, pos=pos)
277+
250278
# Save the executed notebook/HTML to disk
251279
if out_path_ipynb is not None:
252280
save_notebook(notebook=notebook, out_path_ipynb=out_path_ipynb, display_link=display_links)
@@ -257,6 +285,63 @@ def run_notebook(path, inputs=None, outputs=None, inputs_pos=1, working_dir = '.
257285
exec_res['notebook'] = notebook
258286
return exec_res
259287

288+
# Mask functions for database operations cells
289+
def mask_inputs_reading(notebook, pos):
290+
""" Replace database reading by variables initialization.
291+
292+
Result is a code cell with the following view:
293+
.. code-block:: python
294+
295+
varible_name_1 = varible_value_1
296+
varible_name_2 = varible_value_2
297+
...
298+
"""
299+
import nbformat
300+
301+
execution_count = notebook['cells'][pos]['execution_count']
302+
303+
code_mask = str(notebook['cells'][pos]['outputs'][0]['text']).split('\n')
304+
code_mask = [variable.replace(':', ' =', 1) for variable in code_mask]
305+
code_mask = '\n'.join(code_mask)[:-2]
306+
307+
cell_mask = nbformat.v4.new_code_cell(source=code_mask, execution_count=execution_count)
308+
notebook['cells'][pos] = cell_mask
309+
310+
def mask_outputs_dumping(notebook, pos):
311+
"""Replace database dumping by printing outputs.
312+
313+
Result is a code cell with the following view and corresponding output:
314+
.. code-block:: python
315+
316+
print(varible_name_1)
317+
print(varible_name_2)
318+
...
319+
"""
320+
import nbformat
321+
322+
execution_count = notebook['cells'][pos]['execution_count']
323+
outputs_variables = str(notebook['cells'][pos]['outputs'][0]['text']).split('\n')
324+
325+
code_mask = ''
326+
text_mask = ''
327+
328+
for variable in outputs_variables:
329+
separator_pos = int(variable.find(':'))
330+
331+
if separator_pos != -1:
332+
variable_name = variable[:separator_pos]
333+
variable_value = variable[separator_pos+2:] + '\n'
334+
335+
code_mask += f'print({variable_name})\n'
336+
text_mask += variable_value
337+
338+
code_mask = code_mask[:-1]
339+
outputs_mask = [nbformat.v4.new_output(text=text_mask, name='stdout', output_type='stream')]
340+
341+
cell_mask = nbformat.v4.new_code_cell(source=code_mask, execution_count=execution_count, outputs=outputs_mask)
342+
notebook['cells'][pos] = cell_mask
343+
344+
260345
# Save notebook functions
261346
def save_notebook(notebook, out_path_ipynb, display_link):
262347
""" Save an instance of :class:`nbformat.notebooknode.NotebookNode` as ipynb file."""

0 commit comments

Comments
 (0)