Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,9 +2566,15 @@ async def _scatter(
unpack = False
if isinstance(data, Iterator):
data = list(data)
if isinstance(data, (set, frozenset)):
if type(data) in (set, frozenset):
data = list(data)
if not isinstance(data, (dict, list, tuple, set, frozenset)):
if type(data) not in (dict, list, tuple, set, frozenset):
# Note: exact-type checks (not isinstance) so that subclasses of
# builtin collections (e.g. a namedtuple, or scikit-learn's Bunch)
# are scattered as a single opaque value rather than being unpacked
# into their items. This preserves their exact type on the worker;
# an isinstance check would silently downgrade a dict subclass to a
# plain dict (and similarly for list/set/tuple subclasses).
unpack = True
data = [data]
if isinstance(data, (list, tuple)):
Expand Down Expand Up @@ -2640,7 +2646,7 @@ async def _scatter(
n = None if broadcast is True else broadcast
await self._replicate(list(out.values()), workers=workers, n=n)

if issubclass(input_type, (list, tuple, set, frozenset)):
if input_type in (list, tuple, set, frozenset):
out = input_type(out[k] for k in names)

if unpack:
Expand Down
41 changes: 41 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,47 @@ async def test_scatter_types(c, s, a, b):
s.validate_state()


@gen_cluster(client=True)
async def test_scatter_collection_subclass(c, s, a, b):
# Subclasses of builtin collections must be scattered as a single opaque
# value (one Future) with their exact type preserved on the worker, rather
# than being unpacked into their items like the exact builtin collections
# are. Otherwise a dict subclass would silently arrive as a plain dict.
# See https://github.com/scikit-learn/scikit-learn/issues/34005
class Bunch(dict):
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(key)

class MyList(list):
pass

class MySet(set):
pass

Point = namedtuple("Point", ["x", "y"])

for obj in [
Bunch(a=1, b=2),
MyList([1, 2, 3]),
MySet({1, 2, 3}),
Point(1, 2),
]:
future = await c.scatter(obj)
assert isinstance(future, Future)
result = await future
assert type(result) is type(obj)
assert result == obj
s.validate_state()

# Attribute access (the scikit-learn metadata-routing failure mode) keeps
# working after a round-trip through a worker.
future = await c.scatter(Bunch(transform=10))
assert (await c.submit(lambda b: b.transform, future)) == 10


@gen_cluster(client=True)
async def test_scatter_non_list(c, s, a, b):
x = await c.scatter(1)
Expand Down
Loading