Skip to content

Commit 2e079aa

Browse files
authored
[140] Ensuring that lists and dictionaries are included in signatures (#141)
* work * adding test * removing old file * remove old comment
1 parent ab3afdb commit 2e079aa

File tree

10 files changed

+173
-57
lines changed

10 files changed

+173
-57
lines changed

dds/_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
FunctionInteractionsUtils,
3131
FunctionIndirectInteractionUtils,
3232
)
33-
from . import _options
33+
from ._config import get_option, extra_debug_option
3434

3535
_Out = TypeVar("_Out")
3636
_In = TypeVar("_In")
@@ -192,7 +192,7 @@ def _eval(
192192

193193
stages = _parse_stages(dds_stages)
194194

195-
extra_debug = dds_extra_debug or _options._dds_extra_debug
195+
extra_debug = dds_extra_debug or get_option(extra_debug_option)
196196

197197
if not _eval_ctx:
198198
# Not in an evaluation context, create one and introspect

dds/_config.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,46 @@ def validate(self, v: Any) -> None:
105105

106106
# Available options.
107107
#
108+
109+
accept_list_option = Option(
110+
key="accept_list",
111+
doc=(
112+
"Accepts lists as objects. If true, lists are then traversed and their content is included in the signature"
113+
" (default true)"
114+
),
115+
default=True,
116+
types=(bool,),
117+
check_func=(lambda v: True, "",),
118+
)
119+
120+
accept_dict_option = Option(
121+
key="accept_dict",
122+
doc=(
123+
"Accepts dictionaries as objects. If true, lists are then traversed and their content is included "
124+
"in the signature"
125+
" (default true)"
126+
),
127+
default=True,
128+
types=(bool,),
129+
check_func=(lambda v: True, "",),
130+
)
131+
132+
extra_debug_option = Option(
133+
key="extra_debug",
134+
doc=(
135+
"Prints and evaluates extra debugging information. This information requires extra roundtrips to the "
136+
"storage backend. It is disabled by default to assist with debugging, but it can be disabled if "
137+
"I/O with the storage backend is an issue."
138+
),
139+
default=True,
140+
types=(bool,),
141+
check_func=(lambda v: True, "",),
142+
)
143+
108144
_options: List[Option] = [
145+
extra_debug_option,
146+
accept_list_option,
147+
accept_dict_option,
109148
Option(
110149
key="hash.max_sequence_size",
111150
doc=(
@@ -154,7 +193,7 @@ def show_options():
154193
print(row_format.format("=" * 31, "=" * 14, "=" * 53))
155194

156195

157-
def get_option(key: str, default: Union[Any, None] = None) -> Any:
196+
def get_option(key: Union[Option, str], default: Union[Any, None] = None) -> Any:
158197
"""
159198
Retrieves the value of the specified option.
160199
@@ -173,6 +212,8 @@ def get_option(key: str, default: Union[Any, None] = None) -> Any:
173212
------
174213
DDSException : if no such option exists and the default is not provided
175214
"""
215+
if isinstance(key, Option):
216+
return get_option(key.key)
176217
_check_option(key)
177218
if default is None:
178219
default = _options_dict[key].default

dds/_options.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

dds/_retrieve_objects.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from .structures import DDSException, CanonicalPath, LocalDepPath, DDSErrorCode
2525
from .structures_utils import LocalDepPathUtils, CanonicalPathUtils
26+
from ._config import get_option, accept_list_option, accept_dict_option
2627

2728
_logger = logging.getLogger(__name__)
2829

@@ -55,6 +56,11 @@ def _is_authorized_type(tpe: Type[Any], gctx: EvalMainContext) -> bool:
5556
return True
5657
if tpe in (int, float, str, bytes, PurePosixPath, FunctionType, ModuleType):
5758
return True
59+
# Some specific structural types are more complex and can be user-controlled.
60+
if get_option(accept_list_option) and tpe in (list,):
61+
return True
62+
if get_option(accept_dict_option) and tpe in (dict,):
63+
return True
5864
if issubclass(tpe, object):
5965
mod = inspect.getmodule(tpe)
6066
if mod is None:
@@ -90,9 +96,11 @@ def retrieve_object(
9096
obj_key = (local_path, mod_path)
9197

9298
if obj_key in gctx.cached_objects:
93-
# _logger.debug(f"retrieve_object: found in cache: obj_key: {obj_key}")
99+
if debug:
100+
_logger.debug(f"retrieve_object: found in cache: obj_key: {obj_key}")
94101
return gctx.cached_objects[obj_key]
95-
# _logger.debug(f"retrieve_object: not found in cache: obj_key: {obj_key}")
102+
if debug:
103+
_logger.debug(f"retrieve_object: not found in cache: obj_key: {obj_key}")
96104

97105
fname = local_path.parts[0]
98106
sub_path = LocalDepPathUtils.tail(local_path)
@@ -112,21 +120,19 @@ def retrieve_object(
112120
# Looking into the globals (only if the scope is currently __main__ or __global__)
113121
mod_path = _mod_path(context_mod)
114122
if CanonicalPathUtils.head(mod_path) not in ("__main__", "__global__"):
115-
# _logger.debug(
116-
# f"Could not load name %s and not in global context (%s), skipping ",
117-
# fname,
118-
# mod_path,
119-
# )
123+
if debug:
124+
_logger.debug(
125+
f"Could not load name %s and not in global context (%s), skipping ",
126+
fname,
127+
mod_path,
128+
)
120129
return None
121130
else:
122-
# _logger.debug(
123-
# f"Could not load name %s, looking into the globals (mod_path: %s, %s)",
124-
# fname,
125-
# mod_path,
126-
# mod_path.get(0),
127-
# )
128131
pass
129-
# _logger.debug(f"Could not load name {fname}, looking into the globals")
132+
if debug:
133+
_logger.debug(
134+
f"Could not load name {fname}, looking into the globals"
135+
)
130136
if fname in gctx.start_globals:
131137
# _logger.debug(f"Found {fname} in start_globals")
132138
obj = gctx.start_globals[fname]
@@ -151,10 +157,11 @@ def retrieve_object(
151157
["__global__"] + [str(x) for x in local_path.parts]
152158
)
153159
if not gctx.is_authorized_path(obj_path):
154-
# _logger.debug(
155-
# f"Object[start_globals] {fname} of type {type(obj)} is not authorized (path),"
156-
# f" dropping path {obj_path}"
157-
# )
160+
if debug:
161+
_logger.debug(
162+
f"Object[start_globals] {fname} of type {type(obj)} is not authorized (path),"
163+
f" dropping path {obj_path}"
164+
)
158165
res = ExternalObject(obj_path)
159166
gctx.cached_objects[obj_key] = res
160167
return res
@@ -171,21 +178,24 @@ def retrieve_object(
171178
str,
172179
),
173180
):
174-
# _logger.debug(
175-
# f"Object[start_globals] {fname} ({type(obj)}) of path {obj_path} is authorized,"
176-
# )
181+
if debug:
182+
_logger.debug(
183+
f"Object[start_globals] {fname} ({type(obj)}) of path {obj_path} is authorized,"
184+
)
177185
res = AuthorizedObject(obj, obj_path)
178186
gctx.cached_objects[obj_key] = res
179187
return res
180188
else:
181-
# _logger.debug(
182-
# f"Object[start_globals] {fname} of type {type(obj)} is noft authorized (type), dropping path {obj_path}"
183-
# )
189+
if debug:
190+
_logger.debug(
191+
f"Object[start_globals] {fname} of type {type(obj)} is noft authorized (type), dropping path {obj_path}"
192+
)
184193
res = ExternalObject(obj_path)
185194
gctx.cached_objects[obj_key] = res
186195
return res
187196
else:
188-
# _logger.debug(f"{fname} not found in start_globals")
197+
if debug:
198+
_logger.debug(f"{fname} not found in start_globals")
189199
gctx.cached_objects[obj_key] = None
190200
return None
191201
res = cls._retrieve_object_rec(sub_path, loaded_mod, gctx)

dds/fun_args.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def check_len(x: Any) -> None:
8181
f"Object of type {type(x)} is a sequence of length {len(x)}. "
8282
f"Only sequences of length less than {max_sequence_size} are supported. "
8383
"This behaviour can be adjusted with the 'hash.max_sequence_size' option."
84-
f" Path hint: <{current_path()}>"
84+
f" Path hint: <{current_path()}>",
85+
DDSErrorCode.SEQUENCE_TOO_LONG,
8586
)
8687

8788
def _dds_hash(elt: Any, path_item: Union[int, str, None]) -> PyHash:
@@ -92,6 +93,14 @@ def _dds_hash(elt: Any, path_item: Union[int, str, None]) -> PyHash:
9293
trace.pop()
9394
return res
9495

96+
def _hash_dict_tuple(k: Any, v: Any) -> str:
97+
# Supposing for now that any dictionary key is well-behaved with respect to being converted to a string.
98+
if isinstance(k, str):
99+
n = k
100+
else:
101+
n = str(k)
102+
return _dds_hash(k, None) + "|" + _dds_hash(v, n)
103+
95104
def _dds_hash0(elt: Any) -> PyHash:
96105
if elt is None:
97106
# TODO: this is not robust to adversarial changes.
@@ -126,16 +135,16 @@ def _dds_hash0(elt: Any) -> PyHash:
126135
# majority of python interpreters out there).
127136
# Not going to check for obscure corner cases for now.
128137
check_len(elt)
129-
return _dds_hash(
130-
[name + "|" + _dds_hash(v, name) for (name, v) in elt.items()], None
131-
)
138+
return _dds_hash([_hash_dict_tuple(k, v) for (k, v) in elt.items()], None)
132139
if dataclasses.is_dataclass(elt):
133140
names: List[str] = [f.name for f in dataclasses.fields(elt)]
134141
# TODO: this is not entirely accurate. The error message will show a 'list' type, but it is actually
135142
# a dataclass.
136143
check_len(names)
137144
vals = [_dds_hash(getattr(elt, n), n) for n in names]
138-
return _dds_hash([name + "|" + h for (name, h) in zip(names, vals)], None)
145+
return _dds_hash(
146+
[_hash_dict_tuple(name, h) for (name, h) in zip(names, vals)], None
147+
)
139148
if isinstance(
140149
elt,
141150
(

dds/introspect.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -436,25 +436,29 @@ def __init__(
436436

437437
def visit_Name(self, node: ast.Name, debug: bool = False) -> Any:
438438
local_dep_path = LocalDepPath(PurePosixPath(node.id))
439-
# _logger.debug(
440-
# "ExternalVarsVisitor:visit_Name: id: %s local_dep_path:%s",
441-
# node.id,
442-
# local_dep_path,
443-
# )
439+
if debug:
440+
_logger.debug(
441+
"ExternalVarsVisitor:visit_Name: id: %s local_dep_path:%s",
442+
node.id,
443+
local_dep_path,
444+
)
444445
if not isinstance(node.ctx, ast.Load):
445-
# _logger.debug(
446-
# "ExternalVarsVisitor:visit_Name: id: %s skipping ctx: %s",
447-
# node.id,
448-
# node.ctx,
449-
# )
446+
if debug:
447+
_logger.debug(
448+
"ExternalVarsVisitor:visit_Name: id: %s skipping ctx: %s",
449+
node.id,
450+
node.ctx,
451+
)
450452
return
451453
# If it is a var that is already part of the function, do not introspect
452454
if len(local_dep_path.parts) == 1:
453455
v = str(local_dep_path)
454456
if v in self._local_vars:
455-
# _logger.debug(
456-
# "ExternalVarsVisitor:visit_Name: id: %s skipping, in vars", node.id
457-
# )
457+
if debug:
458+
_logger.debug(
459+
"ExternalVarsVisitor:visit_Name: id: %s skipping, in vars",
460+
node.id,
461+
)
458462
return
459463
if local_dep_path in self.vars or local_dep_path in self._rejected_paths:
460464
return
@@ -658,7 +662,7 @@ def inspect_fun(
658662
vdeps.visit(n)
659663
ext_deps = sorted(vdeps.vars.values(), key=lambda ed: ed.local_path)
660664
if debug:
661-
_logger.debug(f"inspect_fun: ext_deps: %s", ext_deps)
665+
_logger.debug("inspect_fun: ext_deps: %s", ext_deps)
662666

663667
# The variables that are hashable: authorized variables outside of the function
664668
sig_variables: List[Tuple[LocalDepPath, PyHash]] = [

dds/structures.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,19 @@ class DDSErrorCode(IntEnum):
5858
ARG_IN_DATA_FUNCTION = 15
5959
OVERLAPPING_PATH = 16
6060
UNKNOWN_OPTION = 17
61+
SEQUENCE_TOO_LONG = 18
6162

6263

6364
class DDSException(BaseException):
6465
"""
6566
The base exception for all the exceptions generated in DDS.
6667
"""
6768

69+
error_code: Optional[DDSErrorCode]
70+
6871
def __init__(self, message: str, error_code: Optional[DDSErrorCode] = None):
6972
super(DDSException, self).__init__(message)
70-
self.error_code: Optional[DDSErrorCode] = error_code
73+
self.error_code = error_code
7174

7275

7376
class EvalContext(NamedTuple):

dds_tests/test_refs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def fun_9():
219219

220220

221221
@pytest.mark.usefixtures("cleandir")
222-
def test_9():
222+
def test_gh140_biglists():
223223
""" Using big objects throws an error """
224224
# TODO: more comprehensive test on lists. They are still seen as external dependencies
225-
assert dds.eval(fun_9) == test_9_len
225+
with pytest.raises(dds.DDSException) as e:
226+
dds.eval(fun_9)
227+
assert e.value.error_code == dds.structures.DDSErrorCode.SEQUENCE_TOO_LONG

dds_tests/test_structures.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import dds
2+
import pytest
3+
from .utils import cleandir, Counter
4+
5+
_ = cleandir
6+
7+
_c = Counter()
8+
9+
l = [1]
10+
11+
12+
@dds.data_function("/p")
13+
def f1():
14+
_c.increment()
15+
return len(l)
16+
17+
18+
@pytest.mark.usefixtures("cleandir")
19+
def test_gh140_list():
20+
global l
21+
_c.reset()
22+
f1()
23+
assert _c.value == 1
24+
l[0] = 2
25+
f1()
26+
assert _c.value == 2
27+
l[0] = 1
28+
f1()
29+
assert _c.value == 2
30+
31+
32+
d = {0: 1}
33+
34+
35+
@dds.data_function("/p")
36+
def f2():
37+
_c.increment()
38+
return len(d)
39+
40+
41+
@pytest.mark.usefixtures("cleandir")
42+
def test_gh140_dict():
43+
_c.reset()
44+
f2()
45+
assert _c.value == 1
46+
d[0] = 2
47+
f2()
48+
assert _c.value == 2
49+
d[0] = 1
50+
f2()
51+
assert _c.value == 2

dds_tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __init__(self):
6464
def increment(self):
6565
self.value += 1
6666

67+
def reset(self):
68+
self.value = 0
69+
6770

6871
def unreachable():
6972
# Will trigger a failure in the parsing

0 commit comments

Comments
 (0)