Skip to content

Commit 94c3978

Browse files
committed
Merge pull request #96 from broxtronix/serializable
Add Serializable decorator
2 parents 1913a7b + 3882a7c commit 94c3978

File tree

3 files changed

+407
-0
lines changed

3 files changed

+407
-0
lines changed

python/test/test_decorators.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import unittest
2+
from pyspark import SparkContext
3+
4+
class TestSerializableDecorator(unittest.TestCase):
5+
6+
def testSerializableDecorator(self):
7+
from thunder.utils.decorators import serializable
8+
from numpy import array, all
9+
from datetime import datetime
10+
11+
@serializable
12+
class Visitor(object):
13+
def __init__(self, ip_addr = None, agent = None, referrer = None):
14+
self.ip = ip_addr
15+
self.ua = agent
16+
self.referrer= referrer
17+
self.testDict = {'a': 10, 'b': "string", 'c': [1, 2, 3]}
18+
self.testVec = array([1,2,3])
19+
self.testArray = array([[1,2,3],[4,5,6.]])
20+
self.time = datetime.now()
21+
self.testComplex = complex(3,2)
22+
23+
def __str__(self):
24+
return str(self.ip) + " " + str(self.ua) + " " + str(self.referrer) + " " + str(self.time)
25+
26+
def test_method(self):
27+
return True
28+
29+
# Run the test. Build an object, serialize it, and recover it.
30+
31+
# Create a new object
32+
origVisitor = Visitor('192.168', 'UA-1', 'http://www.google.com')
33+
34+
# Serialize the object
35+
pickled_visitor = origVisitor.serialize(numpyStorage='ascii')
36+
37+
# Restore object
38+
recovVisitor = Visitor.deserialize(pickled_visitor)
39+
40+
# Check that the object was reconstructed successfully
41+
assert(origVisitor.ip == recovVisitor.ip)
42+
assert(origVisitor.ua == recovVisitor.ua)
43+
assert(origVisitor.referrer == recovVisitor.referrer)
44+
assert(origVisitor.testComplex == recovVisitor.testComplex)
45+
for key in origVisitor.testDict.keys():
46+
assert(origVisitor.testDict[key] == recovVisitor.testDict[key])
47+
48+
assert(all(origVisitor.testVec == recovVisitor.testVec))
49+
assert(all(origVisitor.testArray == recovVisitor.testArray))
50+
51+
def testSerializeWithSlots(self):
52+
'''
53+
Check to make sure that classes that use slots can be serialized / deserialized.
54+
'''
55+
56+
from thunder.utils.decorators import serializable
57+
58+
@serializable
59+
class Foo(object):
60+
__slots__ = ['bar']
61+
62+
foo = Foo()
63+
foo.bar = 'a'
64+
testJson = foo.serialize() # boom
65+
foo2 = Foo.deserialize(testJson)
66+
assert(foo.bar == foo2.bar)
67+
68+
def testNotSerializable(self):
69+
'''
70+
Unit test to make sure exceptions are thrown if the object contains an
71+
unserializable data type.
72+
'''
73+
74+
from thunder.utils.decorators import serializable
75+
from numpy import array, all
76+
from datetime import datetime
77+
78+
class SomeOtherClass(object):
79+
def __init__(self):
80+
someVariable = 3
81+
82+
@serializable
83+
class Visitor(object):
84+
def __init__(self):
85+
self.refrerenceToUnserializableClass = [ SomeOtherClass() ]
86+
87+
origVisitor = Visitor()
88+
89+
# Serialize the object
90+
try:
91+
pickled_visitor = origVisitor.serialize() # This should fail
92+
assert(False) # The @serializable wrapped class should have thrown an exception, but didn't!
93+
except(TypeError):
94+
pass # If the exception was thrown and caught, the test has passed
95+
96+
97+
98+
def testNamedTupleSerializable(self):
99+
'''
100+
Unit test to make sure exceptions are thrown if the object contains an
101+
unserializable data type.
102+
'''
103+
104+
from thunder.utils.decorators import serializable
105+
from collections import namedtuple
106+
107+
@serializable
108+
class Foo(object):
109+
def __init__(self):
110+
self.nt = namedtuple('FooTuple', 'bar')
111+
112+
foo = Foo()
113+
foo.nt.bar = "baz"
114+
115+
testJson = foo.serialize()
116+
foo2 = Foo.deserialize(testJson)
117+
assert(foo.nt.bar == foo2.nt.bar)
118+

python/test/test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ def elementwiseStdev(arys):
5151
combined = vstack([ary.ravel() for ary in arys])
5252
stdAry = std(combined, axis=0)
5353
return stdAry.reshape(arys[0].shape)
54+
55+

0 commit comments

Comments
 (0)