Skip to content

Commit 710656f

Browse files
committed
Refactor version detection logic
1 parent 8e55e13 commit 710656f

6 files changed

Lines changed: 69 additions & 54 deletions

File tree

sqlalchemy_mptt/events.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sqlalchemy.sql import func
2020
from sqlalchemy.orm.base import NO_VALUE
2121

22-
from sqlalchemy_mptt.sqlalchemy_compat import case, select
22+
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
2323

2424

2525
def _insert_subtree(
@@ -62,7 +62,7 @@ def _insert_subtree(
6262
.where(table.c.tree_id == parent_tree_id)
6363
.values(
6464
rgt=table.c.rgt + node_size,
65-
lft=case(
65+
lft=compat_layer.case(
6666
(table.c.lft > left_sibling['lft'], table.c.lft + node_size),
6767
else_=table.c.lft
6868
)
@@ -89,7 +89,7 @@ def mptt_before_insert(mapper, connection, instance):
8989
instance.right = 2
9090
instance.level = instance.get_default_level()
9191
tree_id = connection.scalar(
92-
select(
92+
compat_layer.select(
9393
func.max(table.c.tree_id) + 1
9494
)
9595
) or 1
@@ -99,7 +99,7 @@ def mptt_before_insert(mapper, connection, instance):
9999
parent_pos_right,
100100
parent_tree_id,
101101
parent_level) = connection.execute(
102-
select(
102+
compat_layer.select(
103103
table.c.lft,
104104
table.c.rgt,
105105
table.c.tree_id,
@@ -115,11 +115,11 @@ def mptt_before_insert(mapper, connection, instance):
115115
.where(table.c.rgt >= parent_pos_right)
116116
.where(table.c.tree_id == parent_tree_id)
117117
.values(
118-
lft=case(
118+
lft=compat_layer.case(
119119
(table.c.lft > parent_pos_right, table.c.lft + 2),
120120
else_=table.c.lft
121121
),
122-
rgt=case(
122+
rgt=compat_layer.case(
123123
(table.c.rgt >= parent_pos_right, table.c.rgt + 2),
124124
else_=table.c.rgt
125125
)
@@ -139,7 +139,7 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
139139
db_pk = instance.get_pk_column()
140140
table_pk = getattr(table.c, db_pk.name)
141141
lft, rgt = connection.execute(
142-
select(
142+
compat_layer.select(
143143
table.c.lft,
144144
table.c.rgt
145145
).where(
@@ -174,11 +174,11 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
174174
.where(table.c.rgt > rgt)
175175
.where(table.c.tree_id == tree_id)
176176
.values(
177-
lft=case(
177+
lft=compat_layer.case(
178178
(table.c.lft > lft, table.c.lft - delta),
179179
else_=table.c.lft
180180
),
181-
rgt=case(
181+
rgt=compat_layer.case(
182182
(table.c.rgt >= rgt, table.c.rgt - delta),
183183
else_=table.c.rgt
184184
)
@@ -210,7 +210,7 @@ def mptt_before_update(mapper, connection, instance):
210210
right_sibling_level,
211211
right_sibling_tree_id
212212
) = connection.execute(
213-
select(
213+
compat_layer.select(
214214
table.c.lft,
215215
table.c.rgt,
216216
table.c.parent_id,
@@ -221,7 +221,7 @@ def mptt_before_update(mapper, connection, instance):
221221
)
222222
).fetchone()
223223
current_lvl_nodes = connection.execute(
224-
select(
224+
compat_layer.select(
225225
table.c.lft,
226226
table.c.rgt,
227227
table.c.parent_id,
@@ -259,7 +259,7 @@ def mptt_before_update(mapper, connection, instance):
259259
left_sibling_parent,
260260
left_sibling_tree_id
261261
) = connection.execute(
262-
select(
262+
compat_layer.select(
263263
table.c.lft,
264264
table.c.rgt,
265265
table.c.parent_id,
@@ -282,7 +282,7 @@ def mptt_before_update(mapper, connection, instance):
282282
ORDER BY left_key
283283
"""
284284
subtree = connection.execute(
285-
select(table_pk)
285+
compat_layer.select(table_pk)
286286
.where(
287287
and_(
288288
table.c.lft >= instance.left,
@@ -306,7 +306,7 @@ def mptt_before_update(mapper, connection, instance):
306306
node_parent_id,
307307
node_level
308308
) = connection.execute(
309-
select(
309+
compat_layer.select(
310310
table.c.lft,
311311
table.c.rgt,
312312
table.c.tree_id,
@@ -334,7 +334,7 @@ def mptt_before_update(mapper, connection, instance):
334334
parent_tree_id,
335335
parent_level
336336
) = connection.execute(
337-
select(
337+
compat_layer.select(
338338
table_pk,
339339
table.c.rgt,
340340
table.c.lft,
@@ -362,7 +362,7 @@ def mptt_before_update(mapper, connection, instance):
362362
parent_tree_id,
363363
parent_level
364364
) = connection.execute(
365-
select(
365+
compat_layer.select(
366366
table_pk,
367367
table.c.rgt,
368368
table.c.lft,
@@ -414,7 +414,7 @@ def mptt_before_update(mapper, connection, instance):
414414
# if just insert
415415
else:
416416
tree_id = connection.scalar(
417-
select(
417+
compat_layer.select(
418418
func.max(table.c.tree_id) + 1
419419
)
420420
)

sqlalchemy_mptt/sqlalchemy_compat.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,49 @@
77
import sqlalchemy as sa
88

99

10-
if sa.__version__ < '1.4':
11-
from sqlalchemy.ext.declarative import declarative_base
12-
else:
13-
from sqlalchemy.orm import declarative_base
10+
class LegacySQLAlchemyAPI:
11+
"""A class to provide compatibility for legacy SQLAlchemy versions (1.0 - 1.3)."""
1412

13+
@staticmethod
14+
def declarative_base(*args, **kwargs):
15+
from sqlalchemy.ext.declarative import declarative_base
16+
return declarative_base(*args, **kwargs)
1517

16-
def select(*args, **kwargs):
17-
"""Compatibility function for select."""
18-
if sa.__version__ < '1.4':
18+
@staticmethod
19+
def select(*args, **kwargs):
1920
return sa.select(args, **kwargs)
20-
else:
21-
return sa.select(*args, **kwargs)
22-
2321

24-
def case(*args, **kwargs):
25-
"""Compatibility function for case."""
26-
if sa.__version__ < '1.4':
22+
@staticmethod
23+
def case(*args, **kwargs):
2724
return sa.case(args, **kwargs)
28-
else:
29-
return sa.case(*args, **kwargs)
30-
3125

32-
def get(session, model, id):
33-
"""Compatibility function for getting an object by ID."""
34-
if sa.__version__ < '1.4':
26+
@staticmethod
27+
def get(session, model, id):
3528
return session.query(model).get(id)
36-
else:
29+
30+
31+
class ModernSQLAlchemyAPI:
32+
"""A class to provide compatibility for modern SQLAlchemy versions (1.4+)."""
33+
34+
@staticmethod
35+
def declarative_base(*args, **kwargs):
36+
from sqlalchemy.orm import declarative_base
37+
return declarative_base(*args, **kwargs)
38+
39+
@staticmethod
40+
def select(*args, **kwargs):
41+
return sa.select(*args, **kwargs)
42+
43+
@staticmethod
44+
def case(*args, **kwargs):
45+
return sa.case(*args, **kwargs)
46+
47+
@staticmethod
48+
def get(session, model, id):
3749
return session.get(model, id)
3850

3951

40-
__all__ = ["case", "declarative_base", "select"]
52+
if sa.__version__ < '1.4':
53+
compat_layer = LegacySQLAlchemyAPI()
54+
else:
55+
compat_layer = ModernSQLAlchemyAPI()

sqlalchemy_mptt/tests/test_events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
from sqlalchemy_mptt import mptt_sessionmaker
2020

2121
from sqlalchemy_mptt.mixins import BaseNestedSets
22-
from sqlalchemy_mptt.sqlalchemy_compat import declarative_base
22+
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
2323
from sqlalchemy_mptt.tests import TreeTestingMixin
2424

2525

26-
Base = declarative_base()
26+
Base = compat_layer.declarative_base()
2727

2828

2929
class Tree(Base, BaseNestedSets):

sqlalchemy_mptt/tests/test_inheritance.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from sqlalchemy.orm import sessionmaker
55

66
from sqlalchemy_mptt.mixins import BaseNestedSets
7-
from sqlalchemy_mptt.sqlalchemy_compat import declarative_base, get
7+
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
88
from sqlalchemy_mptt.tests import TreeTestingMixin, failures_expected_on
99

1010

11-
Base = declarative_base()
11+
Base = compat_layer.declarative_base()
1212

1313

1414
class GenericTree(Base, BaseNestedSets):
@@ -60,15 +60,15 @@ def test_create_generic(self):
6060
self.session.add(GenericTree(ppk=1))
6161
self.session.commit()
6262

63-
tree = get(self.session, GenericTree, 1)
63+
tree = compat_layer.get(self.session, GenericTree, 1)
6464
self.assertEqual(tree.ppk, 1)
6565
self.assertEqual(tree.tree_id, 1)
6666

6767
def test_create_spec(self):
6868
self.session.add(SpecializedTree(ppk=1))
6969
self.session.commit()
7070

71-
tree = get(self.session, SpecializedTree, 1)
71+
tree = compat_layer.get(self.session, SpecializedTree, 1)
7272
self.assertEqual(tree.ppk, 1)
7373
self.assertEqual(tree.tree_id, 1)
7474

@@ -84,21 +84,21 @@ def test_create_delete(self):
8484
self.session.add(parent)
8585
self.session.commit()
8686

87-
tree = get(self.session, SpecializedTree, 1)
87+
tree = compat_layer.get(self.session, SpecializedTree, 1)
8888
self.assertEqual(tree.ppk, 1)
8989
self.assertEqual(tree.tree_id, 1)
9090

9191
self.session.delete(child1)
9292
self.session.commit()
9393

94-
self.assertEqual(None, get(self.session, SpecializedTree, 2))
94+
self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 2))
9595

9696
self.session.delete(child2)
9797
self.session.commit()
9898

99-
self.assertEqual(None, get(self.session, SpecializedTree, 3))
100-
self.assertEqual(None, get(self.session, SpecializedTree, 4))
101-
self.assertEqual(None, get(self.session, SpecializedTree, 5))
99+
self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 3))
100+
self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 4))
101+
self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 5))
102102

103103

104104
class TestGenericTree(TreeTestingMixin, unittest.TestCase):
@@ -116,7 +116,7 @@ def test_rebuild(self):
116116
super().test_rebuild()
117117

118118

119-
Base2 = declarative_base()
119+
Base2 = compat_layer.declarative_base()
120120

121121

122122
class BaseInheritance(Base2):

sqlalchemy_mptt/tests/test_mixins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
from sqlalchemy import Column, Integer
1515

1616
from sqlalchemy_mptt.mixins import BaseNestedSets
17-
from sqlalchemy_mptt.sqlalchemy_compat import declarative_base
17+
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
1818

1919

20-
Base = declarative_base()
20+
Base = compat_layer.declarative_base()
2121

2222

2323
class Tree2(Base, BaseNestedSets):

sqlalchemy_mptt/tests/test_stateful.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from sqlalchemy.orm import joinedload, sessionmaker
1212

1313
from sqlalchemy_mptt import BaseNestedSets, mptt_sessionmaker
14-
from sqlalchemy_mptt.sqlalchemy_compat import declarative_base
14+
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
1515

1616

17-
Base = declarative_base()
17+
Base = compat_layer.declarative_base()
1818

1919

2020
class Tree(Base, BaseNestedSets):

0 commit comments

Comments
 (0)