Skip to content

Commit d5afd12

Browse files
add key type checking to Data get methods
Check that the requested key has the appropriate arity for the RDD's keys, based on a call to rdd.first() to get an example key.
1 parent 2639102 commit d5afd12

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

python/test/test_data.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def test_getMissing(self):
2121
def test_get(self):
2222
assert_true(array_equal(self.ary2, self.images.get(1)))
2323

24+
# keys are integers, ask for sequence
25+
assert_raises(ValueError, self.images.get, (1, 2))
26+
2427
def test_getMany(self):
2528
vals = self.images.getMany([0, -1, 1, 0])
2629
assert_equals(4, len(vals))
@@ -29,6 +32,10 @@ def test_getMany(self):
2932
assert_true(array_equal(self.ary2, vals[2]))
3033
assert_true(array_equal(self.ary1, vals[3]))
3134

35+
# keys are integers, ask for sequences:
36+
assert_raises(ValueError, self.images.get, [(0, 0)])
37+
assert_raises(ValueError, self.images.get, [0, (0, 0), 1, 0])
38+
3239
def test_getRanges(self):
3340
vals = self.images.getRange(slice(None))
3441
assert_equals(2, len(vals))
@@ -55,6 +62,9 @@ def test_getRanges(self):
5562
vals = self.images.getRange(slice(2, 3))
5663
assert_equals(0, len(vals))
5764

65+
# keys are integers, ask for sequence
66+
assert_raises(ValueError, self.images.getRange, [slice(1), slice(1)])
67+
5868
# raise exception if 'step' specified:
5969
assert_raises(ValueError, self.images.getRange, slice(1, 2, 2))
6070

@@ -108,12 +118,15 @@ def setUp(self):
108118
self.series = Data(self.sc.parallelize(self.dataLocal), dtype='float32')
109119

110120
def test_getMissing(self):
111-
assert_is_none(self.series.get(-1))
121+
assert_is_none(self.series.get((-1, -1)))
112122

113123
def test_get(self):
114124
expected = self.dataLocal[1][1]
115125
assert_true(array_equal(expected, self.series.get((0, 1))))
116126

127+
assert_raises(ValueError, self.series.get, 1) # keys are sequences, ask for integer
128+
assert_raises(ValueError, self.series.get, (1, 2, 3)) # key length mismatch
129+
117130
def test_getMany(self):
118131
vals = self.series.getMany([(0, 0), (17, 256), (1, 0), (0, 0)])
119132
assert_equals(4, len(vals))
@@ -122,6 +135,9 @@ def test_getMany(self):
122135
assert_true(array_equal(self.dataLocal[2][1], vals[2]))
123136
assert_true(array_equal(self.dataLocal[0][1], vals[3]))
124137

138+
assert_raises(ValueError, self.series.getMany, [1]) # keys are sequences, ask for integer
139+
assert_raises(ValueError, self.series.getMany, [(0, 0), 1, (1, 0), (0, 0)]) # asking for integer again
140+
125141
def test_getRanges(self):
126142
vals = self.series.getRange([slice(2), slice(2)])
127143
assert_equals(4, len(vals))
@@ -174,6 +190,12 @@ def test_getRanges(self):
174190
vals = self.series.getRange([slice(2, 3), slice(None)])
175191
assert_equals(0, len(vals))
176192

193+
# keys are sequences, ask for single slice
194+
assert_raises(ValueError, self.series.getRange, slice(2, 3))
195+
196+
# ask for wrong number of slices
197+
assert_raises(ValueError, self.series.getRange, [slice(2, 3), slice(2, 3), slice(2, 3)])
198+
177199
# raise exception if 'step' specified:
178200
assert_raises(ValueError, self.series.getRange, [slice(0, 4, 2), slice(2, 3)])
179201

python/thunder/rdds/data.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,32 @@ def take(self, *args, **kwargs):
111111
"""
112112
return self.rdd.take(*args, **kwargs)
113113

114+
@staticmethod
def __getKeyTypeCheck(actualKey, keySpec):
    """Checks that a requested key specifier has the same arity as an example key

    A value is treated as a sequence when it has an `__iter__` attribute. If the example key is a
    sequence, the specifier must be iterable and contain the same number of items; if the example
    key is not a sequence, the specifier must not be one either.

    actualKey: an example key to compare against (callers in this file pass the key of the first record)
    keySpec: the key, or sequence of per-dimension keys or slices, requested by the caller

    Returns None when the specifier is compatible; raises ValueError otherwise.
    """
    if not hasattr(actualKey, "__iter__"):
        # scalar keys: any sequence-like specifier is a mismatch
        if hasattr(keySpec, "__iter__"):
            raise ValueError("Key specifier '%s' appears to be a sequence type, " % str(keySpec) +
                             "but actual keys are not (first key: '%s')" % str(actualKey))
        return

    # Only the length computation is guarded, so that a TypeError can only mean
    # "the specifier is not iterable" and is never confused with another failure.
    try:
        if hasattr(keySpec, "__len__"):
            specLen = len(keySpec)
        else:
            # Iterables without __len__ are counted item by item; note that this
            # consumes a one-shot iterator. Non-iterables raise TypeError here.
            specLen = sum(1 for _ in keySpec)
    except TypeError:
        raise ValueError("Key specifier '%s' appears not to be a sequence type, but actual keys are " %
                         str(keySpec) + "sequences (first key: '%s')" % str(actualKey))

    if specLen != len(actualKey):
        raise ValueError("Length of key specifier '%s' does not match length of first key '%s'" %
                         (str(keySpec), str(actualKey)))
130+
114131
def get(self, key):
115132
"""Returns a single value matching the passed key, or None if no matching keys found
116133
117134
If multiple records are found with keys matching the passed key, a sequence of all matching
118135
values will be returned. (This is not expected as a normal occurrence, but could happen with
119136
some user-created rdds.)
120137
"""
138+
firstKey = self.first()[0]
139+
Data.__getKeyTypeCheck(firstKey, key)
121140
filteredVals = self.rdd.filter(lambda (k, v): k == key).values().collect()
122141
if len(filteredVals) == 1:
123142
return filteredVals[0]
@@ -135,6 +154,9 @@ def getMany(self, keys):
135154
If multiple values are found, the corresponding sequence element will be a sequence containing all
136155
matching values.
137156
"""
157+
firstKey = self.first()[0]
158+
for key in keys:
159+
Data.__getKeyTypeCheck(firstKey, key)
138160
keySet = frozenset(keys)
139161
filteredRecs = self.rdd.filter(lambda (k, _): k in keySet).collect()
140162
sortingDict = {}
@@ -197,7 +219,9 @@ def multiSlicesPredicate(kv):
197219
return False
198220
return True
199221

200-
if not hasattr(sliceOrSlices, '__len__'):
222+
firstKey = self.first()[0]
223+
Data.__getKeyTypeCheck(firstKey, sliceOrSlices)
224+
if not hasattr(sliceOrSlices, '__iter__'):
201225
# make my func the pFunc; http://en.wikipedia.org/wiki/P._Funk_%28Wants_to_Get_Funked_Up%29
202226
pFunc = singleSlicePredicate
203227
if hasattr(sliceOrSlices, 'step') and sliceOrSlices.step is not None:

0 commit comments

Comments
 (0)