Skip to content

Commit 1d39057

Browse files
committed
Merge pull request #98 from industrial-sloth/data_getter
Data getter
2 parents 94c3978 + 8999872 commit 1d39057

File tree

4 files changed

+401
-3
lines changed

4 files changed

+401
-3
lines changed

python/test/test_data.py

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true
2+
from numpy import array, array_equal
3+
4+
from thunder.rdds.data import Data
5+
from test_utils import PySparkTestCase
6+
7+
8+
class TestImagesGetters(PySparkTestCase):
9+
"""Test `get` and related methods on an Images-like Data object
10+
"""
11+
def setUp(self):
12+
super(TestImagesGetters, self).setUp()
13+
self.ary1 = array([[1, 2], [3, 4]], dtype='int16')
14+
self.ary2 = array([[5, 6], [7, 8]], dtype='int16')
15+
rdd = self.sc.parallelize([(0, self.ary1), (1, self.ary2)])
16+
self.images = Data(rdd, dtype='int16')
17+
18+
def test_getMissing(self):
19+
assert_is_none(self.images.get(-1))
20+
21+
def test_get(self):
22+
assert_true(array_equal(self.ary2, self.images.get(1)))
23+
24+
# keys are integers, ask for sequence
25+
assert_raises(ValueError, self.images.get, (1, 2))
26+
27+
def test_getMany(self):
28+
vals = self.images.getMany([0, -1, 1, 0])
29+
assert_equals(4, len(vals))
30+
assert_true(array_equal(self.ary1, vals[0]))
31+
assert_is_none(vals[1])
32+
assert_true(array_equal(self.ary2, vals[2]))
33+
assert_true(array_equal(self.ary1, vals[3]))
34+
35+
# keys are integers, ask for sequences:
36+
assert_raises(ValueError, self.images.get, [(0, 0)])
37+
assert_raises(ValueError, self.images.get, [0, (0, 0), 1, 0])
38+
39+
def test_getRanges(self):
40+
vals = self.images.getRange(slice(None))
41+
assert_equals(2, len(vals))
42+
assert_equals(0, vals[0][0])
43+
assert_equals(1, vals[1][0])
44+
assert_true(array_equal(self.ary1, vals[0][1]))
45+
assert_true(array_equal(self.ary2, vals[1][1]))
46+
47+
vals = self.images.getRange(slice(0, 1))
48+
assert_equals(1, len(vals))
49+
assert_equals(0, vals[0][0])
50+
assert_true(array_equal(self.ary1, vals[0][1]))
51+
52+
vals = self.images.getRange(slice(1))
53+
assert_equals(1, len(vals))
54+
assert_equals(0, vals[0][0])
55+
assert_true(array_equal(self.ary1, vals[0][1]))
56+
57+
vals = self.images.getRange(slice(1, 2))
58+
assert_equals(1, len(vals))
59+
assert_equals(1, vals[0][0])
60+
assert_true(array_equal(self.ary2, vals[0][1]))
61+
62+
vals = self.images.getRange(slice(2, 3))
63+
assert_equals(0, len(vals))
64+
65+
# keys are integers, ask for sequence
66+
assert_raises(ValueError, self.images.getRange, [slice(1), slice(1)])
67+
68+
# raise exception if 'step' specified:
69+
assert_raises(ValueError, self.images.getRange, slice(1, 2, 2))
70+
71+
def test_brackets(self):
72+
vals = self.images[1]
73+
assert_true(array_equal(self.ary2, vals))
74+
75+
vals = self.images[0:1]
76+
assert_equals(1, len(vals))
77+
assert_equals(0, vals[0][0])
78+
assert_true(array_equal(self.ary1, vals[0][1]))
79+
80+
vals = self.images[:]
81+
assert_equals(2, len(vals))
82+
assert_equals(0, vals[0][0])
83+
assert_equals(1, vals[1][0])
84+
assert_true(array_equal(self.ary1, vals[0][1]))
85+
assert_true(array_equal(self.ary2, vals[1][1]))
86+
87+
vals = self.images[1:4]
88+
assert_equals(1, len(vals))
89+
assert_equals(1, vals[0][0])
90+
assert_true(array_equal(self.ary2, vals[0][1]))
91+
92+
vals = self.images[1:]
93+
assert_equals(1, len(vals))
94+
assert_equals(1, vals[0][0])
95+
assert_true(array_equal(self.ary2, vals[0][1]))
96+
97+
vals = self.images[:1]
98+
assert_equals(1, len(vals))
99+
assert_equals(0, vals[0][0])
100+
assert_true(array_equal(self.ary1, vals[0][1]))
101+
102+
assert_is_none(self.images[2])
103+
104+
assert_equals([], self.images[2:3])
105+
106+
107+
class TestSeriesGetters(PySparkTestCase):
108+
"""Test `get` and related methods on a Series-like Data object
109+
"""
110+
def setUp(self):
111+
super(TestSeriesGetters, self).setUp()
112+
self.dataLocal = [
113+
((0, 0), array([1.0, 2.0, 3.0], dtype='float32')),
114+
((0, 1), array([2.0, 2.0, 4.0], dtype='float32')),
115+
((1, 0), array([4.0, 2.0, 1.0], dtype='float32')),
116+
((1, 1), array([3.0, 1.0, 1.0], dtype='float32'))
117+
]
118+
self.series = Data(self.sc.parallelize(self.dataLocal), dtype='float32')
119+
120+
def test_getMissing(self):
121+
assert_is_none(self.series.get((-1, -1)))
122+
123+
def test_get(self):
124+
expected = self.dataLocal[1][1]
125+
assert_true(array_equal(expected, self.series.get((0, 1))))
126+
127+
assert_raises(ValueError, self.series.get, 1) # keys are sequences, ask for integer
128+
assert_raises(ValueError, self.series.get, (1, 2, 3)) # key length mismatch
129+
130+
def test_getMany(self):
131+
vals = self.series.getMany([(0, 0), (17, 256), (1, 0), (0, 0)])
132+
assert_equals(4, len(vals))
133+
assert_true(array_equal(self.dataLocal[0][1], vals[0]))
134+
assert_is_none(vals[1])
135+
assert_true(array_equal(self.dataLocal[2][1], vals[2]))
136+
assert_true(array_equal(self.dataLocal[0][1], vals[3]))
137+
138+
assert_raises(ValueError, self.series.getMany, [1]) # keys are sequences, ask for integer
139+
assert_raises(ValueError, self.series.getMany, [(0, 0), 1, (1, 0), (0, 0)]) # asking for integer again
140+
141+
def test_getRanges(self):
142+
vals = self.series.getRange([slice(2), slice(2)])
143+
assert_equals(4, len(vals))
144+
assert_equals(self.dataLocal[0][0], vals[0][0])
145+
assert_equals(self.dataLocal[1][0], vals[1][0])
146+
assert_equals(self.dataLocal[2][0], vals[2][0])
147+
assert_equals(self.dataLocal[3][0], vals[3][0])
148+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
149+
assert_true(array_equal(self.dataLocal[1][1], vals[1][1]))
150+
assert_true(array_equal(self.dataLocal[2][1], vals[2][1]))
151+
assert_true(array_equal(self.dataLocal[3][1], vals[3][1]))
152+
153+
vals = self.series.getRange([slice(2), slice(1)])
154+
assert_equals(2, len(vals))
155+
assert_equals(self.dataLocal[0][0], vals[0][0])
156+
assert_equals(self.dataLocal[2][0], vals[1][0])
157+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
158+
assert_true(array_equal(self.dataLocal[2][1], vals[1][1]))
159+
160+
vals = self.series.getRange([slice(None), slice(1, 2)])
161+
assert_equals(2, len(vals))
162+
assert_equals(self.dataLocal[1][0], vals[0][0])
163+
assert_equals(self.dataLocal[3][0], vals[1][0])
164+
assert_true(array_equal(self.dataLocal[1][1], vals[0][1]))
165+
assert_true(array_equal(self.dataLocal[3][1], vals[1][1]))
166+
167+
vals = self.series.getRange([slice(None), slice(None)])
168+
assert_equals(4, len(vals))
169+
assert_equals(self.dataLocal[0][0], vals[0][0])
170+
assert_equals(self.dataLocal[1][0], vals[1][0])
171+
assert_equals(self.dataLocal[2][0], vals[2][0])
172+
assert_equals(self.dataLocal[3][0], vals[3][0])
173+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
174+
assert_true(array_equal(self.dataLocal[1][1], vals[1][1]))
175+
assert_true(array_equal(self.dataLocal[2][1], vals[2][1]))
176+
assert_true(array_equal(self.dataLocal[3][1], vals[3][1]))
177+
178+
vals = self.series.getRange([0, slice(None)])
179+
assert_equals(2, len(vals))
180+
assert_equals(self.dataLocal[0][0], vals[0][0])
181+
assert_equals(self.dataLocal[1][0], vals[1][0])
182+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
183+
assert_true(array_equal(self.dataLocal[1][1], vals[1][1]))
184+
185+
vals = self.series.getRange([0, 1])
186+
assert_equals(1, len(vals))
187+
assert_equals(self.dataLocal[1][0], vals[0][0])
188+
assert_true(array_equal(self.dataLocal[1][1], vals[0][1]))
189+
190+
vals = self.series.getRange([slice(2, 3), slice(None)])
191+
assert_equals(0, len(vals))
192+
193+
# keys are sequences, ask for single slice
194+
assert_raises(ValueError, self.series.getRange, slice(2, 3))
195+
196+
# ask for wrong number of slices
197+
assert_raises(ValueError, self.series.getRange, [slice(2, 3), slice(2, 3), slice(2, 3)])
198+
199+
# raise exception if 'step' specified:
200+
assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)])
201+
202+
def test_brackets(self):
203+
# returns just value; calls `get`
204+
vals = self.series[(1, 0)]
205+
assert_true(array_equal(self.dataLocal[2][1], vals))
206+
207+
# tuple isn't needed; returns just value, calls `get`
208+
vals = self.series[0, 1]
209+
assert_true(array_equal(self.dataLocal[1][1], vals))
210+
211+
# if slices are passed, calls `getRange`, returns keys and values
212+
vals = self.series[0:1, 1:2]
213+
assert_equals(1, len(vals))
214+
assert_equals(self.dataLocal[1][0], vals[0][0])
215+
assert_true(array_equal(self.dataLocal[1][1], vals[0][1]))
216+
217+
# if slice extends out of bounds, return only the elements that are in bounds
218+
vals = self.series[:4, :1]
219+
assert_equals(2, len(vals))
220+
assert_equals(self.dataLocal[0][0], vals[0][0])
221+
assert_equals(self.dataLocal[2][0], vals[1][0])
222+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
223+
assert_true(array_equal(self.dataLocal[2][1], vals[1][1]))
224+
225+
# empty slice works
226+
vals = self.series[:, 1:2]
227+
assert_equals(2, len(vals))
228+
assert_equals(self.dataLocal[1][0], vals[0][0])
229+
assert_equals(self.dataLocal[3][0], vals[1][0])
230+
assert_true(array_equal(self.dataLocal[1][1], vals[0][1]))
231+
assert_true(array_equal(self.dataLocal[3][1], vals[1][1]))
232+
233+
# multiple empty slices work
234+
vals = self.series[:, :]
235+
assert_equals(4, len(vals))
236+
assert_equals(self.dataLocal[0][0], vals[0][0])
237+
assert_equals(self.dataLocal[1][0], vals[1][0])
238+
assert_equals(self.dataLocal[2][0], vals[2][0])
239+
assert_equals(self.dataLocal[3][0], vals[3][0])
240+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
241+
assert_true(array_equal(self.dataLocal[1][1], vals[1][1]))
242+
assert_true(array_equal(self.dataLocal[2][1], vals[2][1]))
243+
assert_true(array_equal(self.dataLocal[3][1], vals[3][1]))
244+
245+
# mixing slices and individual indicies works:
246+
vals = self.series[0, :]
247+
assert_equals(2, len(vals))
248+
assert_equals(self.dataLocal[0][0], vals[0][0])
249+
assert_equals(self.dataLocal[1][0], vals[1][0])
250+
assert_true(array_equal(self.dataLocal[0][1], vals[0][1]))
251+
assert_true(array_equal(self.dataLocal[1][1], vals[1][1]))
252+
253+
# trying to getitem a key that doesn't exist returns None
254+
assert_is_none(self.series[(25, 17)])
255+
256+
# passing a range that is completely out of bounds returns []
257+
assert_equals([], self.series[2:3, :])

python/test/test_images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numpy import allclose, arange, array, array_equal, prod, squeeze, zeros
55
from numpy import dtype as dtypeFunc
66
import itertools
7-
from nose.tools import assert_equals, assert_raises, assert_true
7+
from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true
88
import unittest
99

1010
from thunder.rdds.fileio.imagesloader import ImagesLoader

python/test/test_series.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from numpy import allclose, amax, arange, array, array_equal
22
from numpy import dtype as dtypeFunc
3-
from nose.tools import assert_equals, assert_true, assert_raises
3+
from nose.tools import assert_equals, assert_is_none, assert_raises, assert_true
44

55
from thunder.rdds.series import Series
66
from test_utils import *
@@ -305,4 +305,3 @@ def setIndex(data, idx):
305305

306306
assert_raises(ValueError, setIndex, data, 5)
307307
assert_raises(ValueError, setIndex, data, [1, 2])
308-

0 commit comments

Comments
 (0)