Skip to content

Commit 54b0a97

Browse files
committed
Add support for volatile collection attributes that don't throw "Phantom object appeared/disappeared" exceptions
1 parent 6fb0fc3 commit 54b0a97

File tree

3 files changed

+84
-13
lines changed

3 files changed

+84
-13
lines changed

pony/orm/core.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,7 +2140,7 @@ def _init_(attr, entity, name):
21402140
if attr.py_type == float:
21412141
if attr.is_pk: throw(TypeError, 'PrimaryKey attribute %s cannot be of type float' % attr)
21422142
elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr)
2143-
if attr.is_volatile and (attr.is_pk or attr.is_collection): throw(TypeError,
2143+
if attr.is_volatile and attr.is_pk: throw(TypeError,
21442144
'%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr))
21452145

21462146
if attr.interleave is not None:
@@ -2150,6 +2150,8 @@ def _init_(attr, entity, name):
21502150
'`interleave` option value should be True, False or None. Got: %r' % attr.interleave)
21512151
def linked(attr):
21522152
reverse = attr.reverse
2153+
if reverse.is_volatile:
2154+
attr.is_volatile = True
21532155
if attr.cascade_delete is None:
21542156
attr.cascade_delete = attr.is_collection and reverse.is_required
21552157
elif attr.cascade_delete:
@@ -2867,7 +2869,7 @@ def prefetch_load_all(attr, objects):
28672869
else:
28682870
phantoms = setdata2 - items
28692871
if setdata2.added: phantoms -= setdata2.added
2870-
if phantoms: throw(UnrepeatableReadError,
2872+
if phantoms and not attr.is_volatile: throw(UnrepeatableReadError,
28712873
'Phantom object %s disappeared from collection %s.%s'
28722874
% (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name))
28732875
items -= setdata2
@@ -2889,7 +2891,8 @@ def load(attr, obj, items=None):
28892891
assert obj._status_ not in del_statuses
28902892
setdata = obj._vals_.get(attr)
28912893
if setdata is None: setdata = obj._vals_[attr] = SetData()
2892-
elif setdata.is_fully_loaded: return setdata
2894+
elif setdata.is_fully_loaded and not attr.is_volatile:
2895+
return setdata
28932896
entity = attr.entity
28942897
reverse = attr.reverse
28952898
rentity = reverse.entity
@@ -2968,7 +2971,7 @@ def load(attr, obj, items=None):
29682971
else:
29692972
phantoms = setdata2 - items
29702973
if setdata2.added: phantoms -= setdata2.added
2971-
if phantoms: throw(UnrepeatableReadError,
2974+
if phantoms and not attr.is_volatile: throw(UnrepeatableReadError,
29722975
'Phantom object %s disappeared from collection %s.%s'
29732976
% (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name))
29742977
items -= setdata2
@@ -3125,7 +3128,7 @@ def db_reverse_add(attr, objects, item):
31253128
for obj in objects:
31263129
setdata = obj._vals_.get(attr)
31273130
if setdata is None: setdata = obj._vals_[attr] = SetData()
3128-
elif setdata.is_fully_loaded: throw(UnrepeatableReadError,
3131+
elif setdata.is_fully_loaded and not attr.is_volatile: throw(UnrepeatableReadError,
31293132
'Phantom object %s appeared in collection %s.%s' % (safe_repr(item), safe_repr(obj), attr.name))
31303133
setdata.add(item)
31313134
def reverse_remove(attr, objects, item, undo_funcs):

pony/orm/tests/test_diagram_keys.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,6 @@ def test_volatile_pk(self):
159159
class Entity1(db.Entity):
160160
a = PrimaryKey(int, volatile=True)
161161

162-
@raises_exception(TypeError, 'Set attribute Entity1.b cannot be volatile')
163-
def test_volatile_set(self):
164-
db = self.db = Database(**db_params)
165-
class Entity1(db.Entity):
166-
a = PrimaryKey(int)
167-
b = Set('Entity2', volatile=True)
168-
169162
@raises_exception(TypeError, 'Volatile attribute Entity1.b cannot be part of primary key')
170163
def test_volatile_composite_pk(self):
171164
db = self.db = Database(**db_params)

pony/orm/tests/test_volatile.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pony.orm.tests import setup_database, teardown_database
66

77

8-
class TestVolatile(unittest.TestCase):
8+
class TestVolatile1(unittest.TestCase):
99
def setUp(self):
1010
db = self.db = Database()
1111

@@ -48,5 +48,80 @@ def test_2(self):
4848
item.flush()
4949
self.assertEqual(item.index, 1)
5050

51+
52+
class TestVolatile2(unittest.TestCase):
53+
def setUp(self):
54+
db = self.db = Database()
55+
56+
class Group(db.Entity):
57+
number = PrimaryKey(int)
58+
students = Set('Student', volatile=True)
59+
60+
class Student(db.Entity):
61+
id = PrimaryKey(int)
62+
name = Required(str)
63+
group = Required('Group')
64+
courses = Set('Course')
65+
66+
class Course(db.Entity):
67+
id = PrimaryKey(int)
68+
name = Required(str)
69+
students = Set('Student', volatile=True)
70+
71+
setup_database(db)
72+
73+
with db_session:
74+
g1 = Group(number=123)
75+
s1 = Student(id=1, name='A', group=g1)
76+
s2 = Student(id=2, name='B', group=g1)
77+
c1 = Course(id=1, name='C1', students=[s1, s2])
78+
c2 = Course(id=2, name='C1', students=[s1])
79+
80+
self.Group = Group
81+
self.Student = Student
82+
self.Course = Course
83+
84+
def tearDown(self):
85+
teardown_database(self.db)
86+
87+
def test_1(self):
88+
self.assertTrue(self.Group.students.is_volatile)
89+
self.assertTrue(self.Student.group.is_volatile)
90+
self.assertTrue(self.Student.courses.is_volatile)
91+
self.assertTrue(self.Course.students.is_volatile)
92+
93+
def test_2(self):
94+
with db_session:
95+
g1 = self.Group[123]
96+
students = set(s.id for s in g1.students)
97+
self.assertEqual(students, {1, 2})
98+
self.db.execute('''insert into student values(3, 'C', 123)''')
99+
g1.students.load()
100+
students = set(s.id for s in g1.students)
101+
self.assertEqual(students, {1, 2, 3})
102+
103+
def test_3(self):
104+
with db_session:
105+
g1 = self.Group[123]
106+
students = set(s.id for s in g1.students)
107+
self.assertEqual(students, {1, 2})
108+
self.db.execute("insert into student values(3, 'C', 123)")
109+
s3 = self.Student[3]
110+
students = set(s.id for s in g1.students)
111+
self.assertEqual(students, {1, 2, 3})
112+
113+
def test_4(self):
114+
with db_session:
115+
c1 = self.Course[1]
116+
students = set(s.id for s in c1.students)
117+
self.assertEqual(students, {1, 2})
118+
self.db.execute("insert into student values(3, 'C', 123)")
119+
attr = self.Student.courses
120+
self.db.execute("insert into %s values(1, 3)" % attr.table)
121+
c1.students.load()
122+
students = set(s.id for s in c1.students)
123+
self.assertEqual(students, {1, 2, 3})
124+
125+
51126
if __name__ == '__main__':
52127
unittest.main()

0 commit comments

Comments
 (0)