diff --git a/src/dependencies.py b/src/dependencies.py index 1b467df6..9588d441 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -22,6 +22,7 @@ from src.services.user_service import UserService from src.services.solver_service import SolverService from src.services.structure_service import StructureService +from project_lock_manager import ProjectQueueManager from src.config import config from src.database import get_connection_string_and_token, build_connection_url @@ -47,6 +48,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]: raise e +queue_manager = None +async def get_project_lock_manager() -> ProjectQueueManager: + global queue_manager + if queue_manager is None: + queue_manager = ProjectQueueManager() + return queue_manager + async def get_project_service() -> ProjectService: return ProjectService() diff --git a/src/project_lock_manager.py b/src/project_lock_manager.py new file mode 100644 index 00000000..1bacd6d7 --- /dev/null +++ b/src/project_lock_manager.py @@ -0,0 +1,21 @@ +import uuid +import asyncio + +class ProjectQueueManager: + project_locks: dict[uuid.UUID, asyncio.Lock] = {} + + def add_project_lock(self, scenario_id: uuid.UUID): + self.project_locks[scenario_id] = asyncio.Lock() + + def aquire_project_lock(self, project_id: uuid.UUID) -> asyncio.Lock: + lock = self.project_locks.get(project_id) + if lock is None: + self.add_project_lock(project_id) + lock = self.project_locks.get(project_id) + if lock is None: + raise Exception("Scenario lock could not be aquired") + return lock + + + + diff --git a/src/routes/edge_routes.py b/src/routes/edge_routes.py index c315968f..368a0ba8 100644 --- a/src/routes/edge_routes.py +++ b/src/routes/edge_routes.py @@ -1,6 +1,7 @@ import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query +import asyncio from sqlalchemy.ext.asyncio import AsyncSession from src.dtos.edge_dtos import ( @@ -10,22 +11,26 @@ from src.services.edge_service import EdgeService from src.dependencies import get_edge_service from src.constants import SwaggerDocumentationConstants -from src.dependencies import get_db - +from src.dependencies import get_db, get_project_lock_manager, ProjectQueueManager router = APIRouter(tags=["edges"]) +# example projectid +project_id = uuid.UUID("0c0e7dd2-e683-4e14-bd5b-92d588d72f93") @router.post("/edges") async def create_edges( dtos: list[EdgeIncomingDto], edge_service: EdgeService = Depends(get_edge_service), session: AsyncSession = Depends(get_db), + lock_manager: ProjectQueueManager = Depends(get_project_lock_manager), ) -> list[EdgeOutgoingDto]: try: - result = list(await edge_service.create(session, dtos)) - await session.commit() - return result + async with lock_manager.aquire_project_lock(project_id): + await asyncio.sleep(5) + result = list(await edge_service.create(session, dtos)) + await session.commit() + return result except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -66,22 +71,29 @@ async def delete_edge( id: uuid.UUID, edge_service: EdgeService = Depends(get_edge_service), session: AsyncSession = Depends(get_db), + lock_manager: ProjectQueueManager = Depends(get_project_lock_manager), ): try: - await edge_service.delete(session, [id]) - await session.commit() + async with lock_manager.aquire_project_lock(project_id): + await asyncio.sleep(5) + await edge_service.delete(session, [id]) + await session.commit() except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.delete("/edges") + async def delete_edges( + ids: list[uuid.UUID] = Query([]), edge_service: EdgeService = Depends(get_edge_service), session: AsyncSession = Depends(get_db), + lock_manager: ProjectQueueManager = Depends(get_project_lock_manager), ): try: - await edge_service.delete(session, ids) - await session.commit() + async with lock_manager.aquire_project_lock(project_id): + await edge_service.delete(session, ids) + await session.commit() except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -90,10 +102,12 @@ async def update_edges( dtos: list[EdgeIncomingDto], edge_service: EdgeService = Depends(get_edge_service), session: AsyncSession = Depends(get_db), + lock_manager: ProjectQueueManager = Depends(get_project_lock_manager), ) -> list[EdgeOutgoingDto]: try: - result = list(await edge_service.update(session, dtos)) - await session.commit() - return result + async with lock_manager.aquire_project_lock(project_id): + result = list(await edge_service.update(session, dtos)) + await session.commit() + return result except Exception as e: raise HTTPException(status_code=500, detail=str(e))