Skip to content

Commit f9d993c

Browse files
committed
Add Stream node constructor for sub-classing #442
1 parent 7453431 commit f9d993c

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

streamz/core.py

+20
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,22 @@ def __str__(self):
119119

120120
class APIRegisterMixin(object):
121121

122+
def _new_node(self, cls, args, kwargs):
123+
""" Constructor for downstream nodes.
124+
125+
Examples
126+
--------
127+
To provide inheritance through nodes :
128+
129+
>>> class MyStream(Stream):
130+
>>>
131+
>>> def _new_node(self, cls, args, kwargs):
132+
>>> if not issubclass(cls, MyStream):
133+
>>> cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
134+
>>> return cls(*args, **kwargs)
135+
"""
136+
return cls(*args, **kwargs)
137+
122138
@classmethod
123139
def register_api(cls, modifier=identity, attribute_name=None):
124140
""" Add callable to Stream API
@@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None):
158174
def _(func):
159175
@functools.wraps(func)
160176
def wrapped(*args, **kwargs):
177+
if identity is not staticmethod and args:
178+
self = args[0]
179+
if isinstance(self, APIRegisterMixin):
180+
return self._new_node(func, args, kwargs)
161181
return func(*args, **kwargs)
162182
name = attribute_name if attribute_name else func.__name__
163183
setattr(cls, name, modifier(wrapped))

streamz/tests/test_core.py

+29
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,35 @@ class foo(NewStream):
13671367
assert not hasattr(Stream(), 'foo')
13681368

13691369

1370+
def test_subclass_node():
1371+
1372+
def add(x) : return x + 1
1373+
1374+
class MyStream(Stream):
1375+
def _new_node(self, cls, args, kwargs):
1376+
if not issubclass(cls, MyStream):
1377+
cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
1378+
return cls(*args, **kwargs)
1379+
1380+
@MyStream.register_api()
1381+
class foo(sz.sinks.sink):
1382+
pass
1383+
1384+
stream = MyStream()
1385+
lst = list()
1386+
1387+
node = stream.map(add)
1388+
assert isinstance(node, sz.core.map)
1389+
assert isinstance(node, MyStream)
1390+
1391+
node = node.foo(lst.append)
1392+
assert isinstance(node, sz.sinks.sink)
1393+
assert isinstance(node, MyStream)
1394+
1395+
stream.emit(100)
1396+
assert lst == [ 101 ]
1397+
1398+
13701399
@gen_test()
13711400
def test_latest():
13721401
source = Stream(asynchronous=True)

0 commit comments

Comments
 (0)