Skip to content

Commit a703e66

Browse files
authored
Merge pull request #19 from takluyver/unparse-rework
Restore unmodified code in parameters cell
2 parents 907e37e + 344e537 commit a703e66

File tree

6 files changed

+221
-44
lines changed

6 files changed

+221
-44
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
python-version: [ 3.6, 3.7, 3.8, 3.9, ]
10+
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
1111
steps:
1212
- uses: actions/checkout@v2
1313

nbparameterise/code.py

+57-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import importlib
33
import re
4+
from warnings import warn
45

56
from nbconvert.preprocessors import ExecutePreprocessor
67

@@ -51,6 +52,19 @@ def get_driver_module(nb, override=None):
5152
assert kernel_name_re.match(module_name)
5253
return importlib.import_module('nbparameterise.code_drivers.%s' % module_name)
5354

55+
def extract_parameter_dict(nb, lang=None):
    """Return the notebook's parameters as a dict keyed by parameter name.

    This collects the Parameter objects produced by
    :func:`extract_parameters` for the same notebook and indexes them by
    their ``.name`` attribute. If two parameters share a name, the later one
    replaces the earlier one in the result.

    lang is passed straight through to :func:`extract_parameters`, where it
    may be used to override the kernel name embedded in the notebook.
    """
    return {param.name: param for param in extract_parameters(nb, lang)}
67+
5468
def extract_parameters(nb, lang=None):
5569
"""Returns a list of Parameter instances derived from the notebook.
5670
@@ -59,7 +73,7 @@ def extract_parameters(nb, lang=None):
5973
metadata; this will be attached as the .metadata instance on each one.
6074
6175
lang may be used to override the kernel name embedded in the notebook. For
62-
now, nbparameterise only handles 'python3' and 'python2'.
76+
now, nbparameterise only handles 'python'.
6377
"""
6478
drv = get_driver_module(nb, override=lang)
6579
params = list(drv.extract_definitions(first_code_cell(nb).source))
@@ -70,8 +84,8 @@ def extract_parameters(nb, lang=None):
7084

7185
return params
7286

73-
def parameter_values(params, **kwargs):
74-
"""Return a copy of the parameter list, substituting values from kwargs.
87+
def parameter_values(params, new_values=None, new='ignore', **kwargs):
88+
"""Return a new parameter list/dict, substituting values from kwargs.
7589
7690
Usage example::
7791
@@ -81,20 +95,42 @@ def parameter_values(params, **kwargs):
8195
)
8296
8397
Any parameters not supplied will keep their original value.
98+
Names not already in params are ignored by default, but can be added with
99+
``new='add'`` or cause an error with ``new='error'``.
100+
101+
This can be used with either a dict from :func:`extract_parameter_dict`
102+
or a list from :func:`extract_parameters`. It will return the corresponding
103+
container type.
84104
"""
85-
res = []
86-
for p in params:
87-
if p.name in kwargs:
88-
res.append(p.with_value(kwargs[p.name]))
89-
else:
90-
res.append(p)
105+
if new not in {'ignore', 'add', 'error'}:
106+
raise ValueError("new= must be one of 'ignore'/'add'/'error'")
107+
new_values = (new_values or {}).copy()
108+
new_values.update(kwargs)
109+
110+
if isinstance(params, dict):
111+
new_list = parameter_values(params.values(), new_values, new=new)
112+
return {p.name: p for p in new_list}
113+
114+
res = [p.with_value(new_values[p.name]) if p.name in new_values else p
115+
for p in params]
116+
117+
new_keys = set(new_values) - {p.name for p in params}
118+
if new == 'error':
119+
if new_keys:
120+
raise KeyError(f"Unexpected keys: {sorted(new_keys)}")
121+
elif new == 'add':
122+
for k in new_keys:
123+
value = new_values[k]
124+
res.append(Parameter(k, type(value), value))
125+
91126
return res
92127

93128
def replace_definitions(nb, values, execute=False, execute_resources=None,
94129
lang=None, *, comments=True):
95130
"""Return a copy of nb with the first code cell defining the given parameters.
96131
97-
values should be a list of Parameter objects (as returned by extract_parameters),
132+
values should be a dict (from :func:`extract_parameter_dict`) or a list
133+
(from :func:`extract_parameters`) of :class:`Parameter` objects,
98134
with their .value attribute set to the desired value.
99135
100136
If execute is True, the notebook is executed with the new values.
@@ -104,13 +140,20 @@ def replace_definitions(nb, values, execute=False, execute_resources=None,
104140
105141
lang may be used to override the kernel name embedded in the notebook. For
106142
now, nbparameterise only handles 'python'.
107-
108-
If comment is True, comments attached to the parameters will be included
109-
in the replaced code, on the same line as the definition.
110143
"""
144+
if isinstance(values, list):
145+
values = {p.name: p for p in values}
146+
147+
if not comments:
148+
warn("comments=False is now ignored", stacklevel=2)
149+
111150
nb = copy.deepcopy(nb)
151+
params_cell = first_code_cell(nb)
152+
112153
drv = get_driver_module(nb, override=lang)
113-
first_code_cell(nb).source = drv.build_definitions(values, comments=comments)
154+
params_cell.source = drv.build_definitions(
155+
values, prev_code=params_cell.source
156+
)
114157
if execute:
115158
resources = execute_resources or {}
116159
nb, resources = ExecutePreprocessor().preprocess(nb, resources)

nbparameterise/code_drivers/python.py

+68-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import ast
2+
import sys
23

34
import astcheck
4-
import astsearch
5+
6+
try:
7+
from ast import unparse
8+
except ImportError:
9+
from astunparse import unparse
510

611
from io import StringIO
712
import tokenize
@@ -41,7 +46,7 @@ def check_fillable_node(node, path):
4146

4247
raise astcheck.ASTMismatch(path, node, 'number, string, boolean, list or dict')
4348

44-
definition_pattern = ast.Assign(targets=[ast.Name()], value=check_fillable_node)
49+
definition_pattern = astcheck.single_assign(target=ast.Name(), value=check_fillable_node)
4550

4651
def type_and_value(node, comments={}):
4752
comment = comments.get(node.lineno, None)
@@ -77,20 +82,66 @@ def extract_comments(cell: str):
7782
comments[rowcol[0]] = tstr
7883
return comments
7984

80-
def extract_definitions(cell):
85+
def find_assignments(cell):
    """Yield ``(name, statement)`` pairs for parameter assignments in *cell*.

    *cell* is parsed as Python source, and each top-level statement that
    matches ``definition_pattern`` is yielded together with the name of the
    variable it assigns.
    """
    # Only module-level statements are inspected: nested scopes (function
    # bodies and so on) are deliberately not walked.
    for node in ast.parse(cell).body:
        if not astcheck.is_ast_like(node, definition_pattern):
            continue
        # An annotated assignment has a single .target; a plain assignment
        # keeps its targets in a list.
        target = node.target if isinstance(node, ast.AnnAssign) else node.targets[0]
        yield target.id, node
96+
97+
def extract_definitions(cell):
8298
comments = extract_comments(cell)
83-
for assign in astsearch.ASTPatternFinder(definition_pattern).scan_ast(cell_ast):
84-
typ, val, comment = type_and_value(assign.value, comments)
85-
yield Parameter(assign.targets[0].id, typ, val, comment=comment)
86-
87-
def build_definitions(inputs, comments=True):
88-
defs = []
89-
for param in inputs:
90-
s = f"{param.name} = {param.value!r}"
91-
if comments and param.comment:
92-
comment = param.comment if param.comment.startswith('#') \
93-
else '# ' + param.comment.lstrip()
94-
s += f" {comment}"
95-
defs.append(s)
96-
return "\n".join(defs)
99+
for name, stmt in find_assignments(cell):
100+
typ, val, comment = type_and_value(stmt.value, comments)
101+
yield Parameter(name, typ, val, comment=comment)
102+
103+
104+
def build_definitions(params: dict, prev_code):
    """Rebuild the parameters cell code with modified parameter values.

    The existing code structure and comments are preserved: only the value
    expressions of the assignments found by :func:`find_assignments` are
    replaced. This relies on the ``end_lineno``/``end_col_offset`` attributes
    of AST nodes (available from Python 3.8).

    params maps parameter names to objects whose ``.value`` attribute holds
    the value to substitute; prev_code is the current source of the cell.

    Assignments in prev_code whose name is not in params keep their existing
    value. Parameters in params which are not assigned in prev_code are
    appended as new assignments at the end, in the order they appear in
    params.

    Returns the new source code as a string.
    """
    # [end_]col_offset count UTF-8 bytes, so we encode the code here and decode
    # again after slicing.
    # Stick None in to allow 1-based line indexing
    old_lines = [None] + prev_code.encode().splitlines(keepends=True)
    from_line, from_col = 1, 0
    vars_used = set()
    output = []
    for name, stmt in find_assignments(prev_code):
        if name not in params:
            continue  # Leave the existing value

        vars_used.add(name)

        vn = stmt.value
        if vn.lineno == from_line:  # Same line as last value we replaced
            output.append(old_lines[from_line][from_col:vn.col_offset].decode())
        else:  # On a new line
            output.append(old_lines[from_line][from_col:].decode())
            output.extend(
                line.decode() for line in old_lines[from_line + 1:vn.lineno]
            )
            output.append(old_lines[vn.lineno][:vn.col_offset].decode())
        from_line, from_col = vn.end_lineno, vn.end_col_offset

        # Substitute in the new value for the variable
        output.append(repr(params[name].value))

    # Copy across any remaining code to the end of the cell. An empty cell
    # has no line 1 at all, so there is nothing to copy in that case.
    if from_line < len(old_lines):
        output.append(old_lines[from_line][from_col:].decode())
        output.extend(line.decode() for line in old_lines[from_line + 1:])

    # Add in any variables for which we have a value but weren't in the code.
    # Iterate over params (not a set difference) so the order of the added
    # lines is deterministic and follows the order of params.
    unused_vars = [name for name in params if name not in vars_used]
    if unused_vars:
        if prev_code:
            # Separate the additions from the existing code
            output.append('\n\n')
        for name in unused_vars:
            output.append(f"{name} = {params[name].value!r}\n")

    return ''.join(output)
147+

pyproject.toml

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[build-system]
2-
requires = ["flit_core >=3.2,<3.3"]
2+
requires = ["flit_core >=3.2,<4"]
33
build-backend = "flit_core.buildapi"
44

55
[project]
@@ -12,8 +12,12 @@ classifiers = [
1212
"Framework :: Jupyter",
1313
]
1414
readme = "README.rst"
15-
dependencies = ["nbconvert", "astsearch"]
16-
requires-python = ">=3.6"
15+
dependencies = [
16+
"nbconvert",
17+
"astcheck >=0.3",
18+
"astunparse; python_version < '3.9'",
19+
]
20+
requires-python = ">=3.8"
1721
dynamic = ['version', 'description']
1822

1923
[project.urls]

tests/sample.ipynb

+19-8
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,25 @@
1414
}
1515
],
1616
"source": [
17-
"a = \"Some text\"\n",
18-
"b = 12\n",
19-
"b2 = -7\n",
20-
"c = 14.0\n",
17+
"a = \"\"\"\\\n",
18+
"Some text\"\"\"\n",
19+
"b: int = 12\n",
20+
"b2 = -7; c = 14.0\n",
2121
"d = False # comment:bool\n",
2222
"e = [0, 1.0, True, \"text\", [0, 1]]\n",
23-
"f = {0: 0, \"item\": True, \"dict\": {0: \"text\"}} # comment:dict\n",
24-
"print(\"This should be ignored\")"
23+
"f = { # comment:dict\n",
24+
" 0: 0,\n",
25+
" \"item\": True,\n",
26+
" \"dict\": {0: \"text\"}\n",
27+
"}\n",
28+
"print(\"This should be ignored\")\n",
29+
"\n",
30+
"café = \"καφές\" # Not only ASCII\n",
31+
"other_assignment = b ** 2 # Not recognised as a parameter\n",
32+
"\n",
33+
"def func():\n",
34+
" not_a_parameter = 5 # Should not be found\n",
35+
" return c"
2536
]
2637
},
2738
{
@@ -48,7 +59,7 @@
4859
"metadata": {
4960
"celltoolbar": "Edit Metadata",
5061
"kernelspec": {
51-
"display_name": "Python 3",
62+
"display_name": "Python 3 (ipykernel)",
5263
"language": "python",
5364
"name": "python3"
5465
},
@@ -62,7 +73,7 @@
6273
"name": "python",
6374
"nbconvert_exporter": "python",
6475
"pygments_lexer": "ipython3",
65-
"version": "3.9.2"
76+
"version": "3.11.1"
6677
},
6778
"parameterise": {
6879
"c": {

0 commit comments

Comments
 (0)