Skip to content

Commit 7ec580d

Browse files
authored
Merge pull request #17 from caspervdw/fix-inplace-pipeline
API/FIX: Pipelines that act inplace - remove proc_func kwarg and attribute from `Slicerator`
2 parents 837e7d5 + 3dfbbed commit 7ec580d

File tree

2 files changed

+133
-26
lines changed

2 files changed

+133
-26
lines changed

slicerator.py

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import collections
66
import itertools
77
from functools import wraps
8+
from copy import copy
89

910

1011
# set version string using versioneer
@@ -24,7 +25,7 @@ def _iter_attr(obj):
2425

2526
class Slicerator(object):
2627
def __init__(self, ancestor, indices=None, length=None,
27-
propagate_attrs=None, proc_func=None):
28+
propagate_attrs=None):
2829
"""A generator that supports fancy indexing
2930
3031
When sliced using any iterable with a known length, it returns another
@@ -52,10 +53,6 @@ def __init__(self, ancestor, indices=None, length=None,
5253
that is, if `len(indices)` is invalid
5354
propagate_attrs : list of str, optional
5455
list of attributes to be propagated into Slicerator
55-
proc_func : function
56-
function that processes data returned by Slicerator. The function
57-
acts element-wise and is only evaluated when data is actually
58-
returned
5956
6057
Examples
6158
--------
@@ -111,10 +108,6 @@ def __init__(self, ancestor, indices=None, length=None,
111108
self._len = length
112109
self._ancestor = ancestor
113110
self._indices = indices
114-
if proc_func is None:
115-
self._proc_func = lambda x: x
116-
else:
117-
self._proc_func = proc_func
118111

119112
@classmethod
120113
def from_func(cls, func, length, propagate_attrs=None):
@@ -210,8 +203,7 @@ def indices(self):
210203
return indices
211204

212205
def _get(self, key):
213-
"Wrap ancestor's method in a processing function."
214-
return self._proc_func(self._ancestor[key])
206+
return self._ancestor[key]
215207

216208
def _map_index(self, key):
217209
if key < -self._len or key >= self._len:
@@ -225,7 +217,7 @@ def _map_index(self, key):
225217
return abs_key
226218

227219
def __repr__(self):
228-
msg = "Sliced and/or processed {0}. Original repr:\n".format(
220+
msg = "Sliced {0}. Original repr:\n".format(
229221
type(self._ancestor).__name__)
230222
old = '\n'.join(" " + ln for ln in repr(self._ancestor).split('\n'))
231223
return msg + old
@@ -247,7 +239,7 @@ def __getitem__(self, key):
247239
return (self[k] for k in rel_indices)
248240
indices = _index_generator(rel_indices, self.indices)
249241
return Slicerator(self._ancestor, indices, new_length,
250-
self._propagate_attrs, self._proc_func)
242+
self._propagate_attrs)
251243

252244
def __getattr__(self, name):
253245
# to avoid infinite recursion, always check if public field is there
@@ -263,12 +255,12 @@ def __getattr__(self, name):
263255
raise AttributeError
264256

265257
def __getstate__(self):
266-
# When serializing, return a list of the sliced and processed data
258+
# When serializing, return a list of the sliced data
267259
# Any exposed attrs are lost.
268-
return [self._get(key) for key in self.indices]
260+
return list(self)
269261

270262
def __setstate__(self, data_as_list):
271-
# When deserializing, restore the Slicerator
263+
# When deserializing, restore a Slicerator instance
272264
return self.__init__(data_as_list)
273265

274266

@@ -325,6 +317,10 @@ def key_to_indices(key, length):
325317
# allow negative indexing
326318
if -length < key < 0:
327319
return length + key, None
320+
elif 0 <= key < length:
321+
return key, None
322+
else:
323+
raise IndexError('index out of range')
328324

329325
# in all other case, just return the key and let user deal with the type.
330326
return key, None
@@ -364,18 +360,114 @@ def _index_generator(new_indices, old_indices):
364360
continue
365361

366362

363+
class Pipeline(object):
364+
def __init__(self, ancestor, proc_func, propagate_attrs=None):
365+
"""A class to support lazy function evaluation on an iterable.
366+
367+
When a ``Pipeline`` object is indexed, it returns an element of its
368+
ancestor modified with a process function.
369+
370+
Parameters
371+
----------
372+
ancestor : object
373+
proc_func : function
374+
function that processes data returned by Slicerator. The function
375+
acts element-wise and is only evaluated when data is actually
376+
returned
377+
378+
Example
379+
-------
380+
Construct the pipeline object that multiplies elements by two:
381+
>>> ancestor = [0, 1, 2, 3, 4]
382+
>>> times_two = Pipeline(ancestor, lambda x: 2*x)
383+
384+
Whenever the pipeline object is indexed, it takes the correct element
385+
from its ancestor, and then applies the process function.
386+
>>> times_two[3] # returns 6
387+
388+
See also
389+
--------
390+
pipeline
391+
"""
392+
# when list of propagated attributes are given explicitly,
393+
# take this list and ignore the class definition
394+
if propagate_attrs is not None:
395+
self._propagate_attrs = propagate_attrs
396+
else:
397+
# check propagated_attrs field from the ancestor definition
398+
self._propagate_attrs = []
399+
if hasattr(ancestor, '_propagate_attrs'):
400+
self._propagate_attrs += ancestor._propagate_attrs
401+
if hasattr(ancestor, 'propagate_attrs'):
402+
self._propagate_attrs += ancestor.propagate_attrs
403+
404+
# add methods having the _propagate flag
405+
for attr in _iter_attr(ancestor):
406+
if hasattr(attr, '_propagate_flag'):
407+
self._propagate_attrs.append(attr.__name__)
408+
409+
self._ancestor = ancestor
410+
self._proc_func = proc_func
411+
412+
def _get(self, key):
413+
# We need to copy here: else any _proc_func that acts inplace would
414+
# change the ancestor value.
415+
return self._proc_func(copy(self._ancestor[key]))
416+
417+
def __repr__(self):
418+
msg = "{0} processed through {1}. Original repr:\n".format(
419+
type(self._ancestor).__name__, self._proc_func.__name__)
420+
old = '\n'.join(" " + ln for ln in repr(self._ancestor).split('\n'))
421+
return msg + old
422+
423+
def __len__(self):
424+
return self._ancestor.__len__()
425+
426+
def __iter__(self):
427+
return (self._get(i) for i in range(len(self)))
428+
429+
def __getitem__(self, i):
430+
"""for data access"""
431+
indices, new_length = key_to_indices(i, len(self))
432+
if new_length is None:
433+
return self._get(indices)
434+
else:
435+
return Slicerator(self, indices, new_length, self._propagate_attrs)
436+
437+
def __getattr__(self, name):
438+
# to avoid infinite recursion, always check if public field is there
439+
if '_propagate_attrs' not in self.__dict__:
440+
self._propagate_attrs = []
441+
if name in self._propagate_attrs:
442+
return getattr(self._ancestor, name)
443+
raise AttributeError
444+
445+
def __getstate__(self):
446+
# When serializing, return a list of the processed data
447+
# Any exposed attrs are lost.
448+
return list(self)
449+
450+
def __setstate__(self, data_as_list):
451+
# When deserializing, restore the Pipeline
452+
return self.__init__(data_as_list, lambda x: x)
453+
454+
367455
def pipeline(func):
368-
"""Decorator to make function aware of Slicerator objects.
456+
"""Decorator to enable lazy evaluation of a function.
369457
370-
When the function is applied to a Slicerator, it
371-
returns another lazily-evaluated, Slicerator object.
458+
When the function is applied to a Slicerator or Pipeline object, it
459+
returns another lazily-evaluated, Pipeline object.
372460
373461
When the function is applied to any other object, it falls back on its
374-
normal behavhior.
462+
normal behavior.
375463
376464
Returns
377465
-------
378-
processed_images : Slicerator
466+
processed_images : Pipeline
467+
468+
See also
469+
--------
470+
Pipeline
379471
380472
Example
381473
-------
@@ -385,7 +477,7 @@ def pipeline(func):
385477
... return image[channel, :, :]
386478
...
387479
388-
Passing a Slicerator the function returns another Slicerator
480+
Passing a Slicerator the function returns a Pipeline
389481
that "lazily" applies the function when the images come out. Different
390482
functions can be applied to the same underlying images, creating
391483
independent objects.
@@ -406,10 +498,11 @@ def pipeline(func):
406498
"""
407499
@wraps(func)
408500
def process(obj, *args, **kwargs):
409-
if hasattr(obj, '_slicerator_flag') or isinstance(obj, Slicerator):
410-
def f(x):
501+
if hasattr(obj, '_slicerator_flag') or isinstance(obj, Slicerator) \
502+
or isinstance(obj, Pipeline):
503+
def proc_func(x):
411504
return func(x, *args, **kwargs)
412-
return Slicerator(obj, proc_func=f)
505+
return Pipeline(obj, proc_func)
413506
else:
414507
# Fall back on normal behavior of func, interpreting input
415508
# as a single image.
@@ -419,9 +512,10 @@ def f(x):
419512
process.__doc__ = ''
420513
process.__doc__ = ("This function has been made lazy. When passed\n"
421514
"a Slicerator, it will return a \n"
422-
"new Slicerator of the results. When passed \n"
515+
"Pipeline of the results. When passed \n"
423516
"any other objects, its behavior is "
424517
"unchanged.\n\n") + process.__doc__
518+
process.__name__ = func.__name__
425519
return process
426520

427521

tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ def _a_to_z(letter):
163163
else:
164164
return letter
165165

166+
@pipeline
167+
def append_zero_inplace(list_obj):
168+
list_obj.append(0)
169+
return list_obj
170+
171+
172+
def test_inplace_pipeline():
173+
n_mutable = Slicerator([list([i]) for i in range(10)])
174+
appended = append_zero_inplace(n_mutable)
175+
176+
assert_equal(appended[5], [5, 0]) # execute the function
177+
assert_equal(n_mutable[5], [5]) # check the original
178+
166179

167180
def test_pipeline_simple():
168181
capitalize = pipeline(_capitalize)

0 commit comments

Comments
 (0)