Skip to content

Commit 1955db9

Browse files
author
Jeny Sadadia
committed
api.db: decouple model indexes from DB engine
Implement a dataclass `Index` to store model field names and constraints to create indexes. Instead of directly creating indexes from model method for specific DB engine (at the moment, MongoDB), implement a method to get indexes from models independent of DB engine i.e. list of `DatabaseModel.Index` class instances. Create MongoDB specific indexes in the database method `Database.create_indexes`. Signed-off-by: Jeny Sadadia <[email protected]>
1 parent eb358a4 commit 1955db9

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

api/db.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,12 @@ def _get_collection(self, model):
6262
async def create_indexes(self):
6363
"""Create indexes for models"""
6464
for model in self.COLLECTIONS:
65+
indexes = model.get_indexes()
66+
if not indexes:
67+
continue
6568
col = self._get_collection(model)
66-
model.create_indexes(col)
69+
for index in indexes:
70+
col.create_index(index.field, **index.attributes)
6771

6872
async def find_one(self, model, **kwargs):
6973
"""Find one object with matching attributes

api/models.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ class UserGroup(DatabaseModel):
6666
)
6767

6868
@classmethod
69-
def create_indexes(cls, collection):
70-
"""Create an index to bind unique constraint to group name"""
71-
collection.create_index("name", unique=True)
69+
def get_indexes(cls):
70+
"""Get an index to bind unique constraint to group name"""
71+
return [
72+
cls.Index('name', {'unique': True}),
73+
]
7274

7375

7476
class User(BeanieBaseUser, Document, # pylint: disable=too-many-ancestors
@@ -86,9 +88,11 @@ class Settings(BeanieBaseUser.Settings):
8688
name = "user"
8789

8890
@classmethod
89-
def create_indexes(cls, collection):
90-
"""Create an index to bind unique constraint to email"""
91-
collection.create_index("email", unique=True)
91+
def get_indexes(cls):
92+
"""Get indices"""
93+
return [
94+
cls.Index('email', {'unique': True}),
95+
]
9296

9397

9498
class UserRead(schemas.BaseUser[PydanticObjectId], ModelId):

0 commit comments

Comments
 (0)