Skip to content

Commit 9f6cdb3

Browse files
committed
FIX keep inherited from builtin containers' types
1 parent db3cb43 commit 9f6cdb3

2 files changed

Lines changed: 50 additions & 3 deletions

File tree

distributed/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2566,9 +2566,15 @@ async def _scatter(
25662566
unpack = False
25672567
if isinstance(data, Iterator):
25682568
data = list(data)
2569-
if isinstance(data, (set, frozenset)):
2569+
if type(data) in (set, frozenset):
25702570
data = list(data)
2571-
if not isinstance(data, (dict, list, tuple, set, frozenset)):
2571+
if type(data) not in (dict, list, tuple, set, frozenset):
2572+
# Note: exact-type checks (not isinstance) so that subclasses of
2573+
# builtin collections (e.g. a namedtuple, or scikit-learn's Bunch)
2574+
# are scattered as a single opaque value rather than being unpacked
2575+
# into their items. This preserves their exact type on the worker;
2576+
# an isinstance check would silently downgrade a dict subclass to a
2577+
# plain dict (and similarly for list/set/tuple subclasses).
25722578
unpack = True
25732579
data = [data]
25742580
if isinstance(data, (list, tuple)):
@@ -2640,7 +2646,7 @@ async def _scatter(
26402646
n = None if broadcast is True else broadcast
26412647
await self._replicate(list(out.values()), workers=workers, n=n)
26422648

2643-
if issubclass(input_type, (list, tuple, set, frozenset)):
2649+
if input_type in (list, tuple, set, frozenset):
26442650
out = input_type(out[k] for k in names)
26452651

26462652
if unpack:

distributed/tests/test_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,47 @@ async def test_scatter_types(c, s, a, b):
11241124
s.validate_state()
11251125

11261126

1127+
@gen_cluster(client=True)
1128+
async def test_scatter_collection_subclass(c, s, a, b):
1129+
# Subclasses of builtin collections must be scattered as a single opaque
1130+
# value (one Future) with their exact type preserved on the worker, rather
1131+
# than being unpacked into their items like the exact builtin collections
1132+
# are. Otherwise a dict subclass would silently arrive as a plain dict.
1133+
# See https://github.com/scikit-learn/scikit-learn/issues/34005
1134+
class Bunch(dict):
1135+
def __getattr__(self, key):
1136+
try:
1137+
return self[key]
1138+
except KeyError:
1139+
raise AttributeError(key)
1140+
1141+
class MyList(list):
1142+
pass
1143+
1144+
class MySet(set):
1145+
pass
1146+
1147+
Point = namedtuple("Point", ["x", "y"])
1148+
1149+
for obj in [
1150+
Bunch(a=1, b=2),
1151+
MyList([1, 2, 3]),
1152+
MySet({1, 2, 3}),
1153+
Point(1, 2),
1154+
]:
1155+
future = await c.scatter(obj)
1156+
assert isinstance(future, Future)
1157+
result = await future
1158+
assert type(result) is type(obj)
1159+
assert result == obj
1160+
s.validate_state()
1161+
1162+
# Attribute access (the scikit-learn metadata-routing failure mode) keeps
1163+
# working after a round-trip through a worker.
1164+
future = await c.scatter(Bunch(transform=10))
1165+
assert (await c.submit(lambda b: b.transform, future)) == 10
1166+
1167+
11271168
@gen_cluster(client=True)
11281169
async def test_scatter_non_list(c, s, a, b):
11291170
x = await c.scatter(1)

0 commit comments

Comments
 (0)