diff --git a/configs/config_app_example.toml b/configs/config_app_example.toml index 81c4a476..84c8a813 100644 --- a/configs/config_app_example.toml +++ b/configs/config_app_example.toml @@ -1,5 +1,7 @@ # MAKE SURE TO RENAME THIS FILE TO config.toml AND PLACE IT IN THE ROOT OF THE PROJECT +node_name = "example_node" + # Router configuration router_name = "router1" router_address = "xx.xx.xx.xx" # Replace with actual IP address diff --git a/src/pqnstack/app/api/deps.py b/src/pqnstack/app/api/deps.py index 610116ef..706a0f96 100644 --- a/src/pqnstack/app/api/deps.py +++ b/src/pqnstack/app/api/deps.py @@ -4,6 +4,9 @@ import httpx from fastapi import Depends +from pqnstack.app.core.config import NodeState +from pqnstack.app.core.config import get_state + async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]: async with httpx.AsyncClient(timeout=600_000) as client: @@ -11,3 +14,6 @@ async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]: ClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)] + + +StateDep = Annotated[NodeState, Depends(get_state)] diff --git a/src/pqnstack/app/api/main.py b/src/pqnstack/app/api/main.py index d50d7c12..175e0899 100644 --- a/src/pqnstack/app/api/main.py +++ b/src/pqnstack/app/api/main.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from pqnstack.app.api.routes import chsh +from pqnstack.app.api.routes import coordination +from pqnstack.app.api.routes import debug from pqnstack.app.api.routes import qkd from pqnstack.app.api.routes import rng from pqnstack.app.api.routes import serial @@ -12,3 +14,5 @@ api_router.include_router(timetagger.router) api_router.include_router(rng.router) api_router.include_router(serial.router) +api_router.include_router(coordination.router) +api_router.include_router(debug.router) diff --git a/src/pqnstack/app/api/routes/chsh.py b/src/pqnstack/app/api/routes/chsh.py index 375e84bb..3eeeb490 100644 --- a/src/pqnstack/app/api/routes/chsh.py +++ b/src/pqnstack/app/api/routes/chsh.py @@ -6,8 +6,8 @@ from fastapi import status from pqnstack.app.api.deps import ClientDep +from pqnstack.app.api.deps import StateDep from pqnstack.app.core.config import settings -from pqnstack.app.core.config import state from pqnstack.app.core.models import calculate_chsh_expectation_error from pqnstack.network.client import Client @@ -126,7 +126,7 @@ async def chsh( @router.post("/request-angle-by-basis") -async def request_angle_by_basis(index: int, *, perp: bool = False) -> bool: +async def request_angle_by_basis(index: int, state: StateDep, *, perp: bool = False) -> bool: client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000) hwp = client.get_device(settings.chsh_settings.request_hwp[0], settings.chsh_settings.request_hwp[1]) if hwp is None: diff --git a/src/pqnstack/app/api/routes/coordination.py b/src/pqnstack/app/api/routes/coordination.py new file mode 100644 index 00000000..35d77850 --- /dev/null +++ b/src/pqnstack/app/api/routes/coordination.py @@ -0,0 +1,177 @@ +import asyncio +import logging + +from fastapi import APIRouter +from fastapi import HTTPException +from fastapi import Request +from fastapi import WebSocket +from fastapi import WebSocketDisconnect +from fastapi import status +from pydantic import BaseModel + +from pqnstack.app.api.deps import ClientDep +from pqnstack.app.api.deps import StateDep +from pqnstack.app.core.config import ask_user_for_follow_event +from pqnstack.app.core.config import settings +from pqnstack.app.core.config import user_replied_event + +logger = logging.getLogger(__name__) + + +class FollowRequestResponse(BaseModel): + accepted: bool + + +class CollectFollowerResponse(BaseModel): + success: bool + + +class ResetCoordinationStateResponse(BaseModel): + message: str = "Coordination state reset successfully" + + +router = APIRouter(prefix="/coordination", tags=["coordination"]) + + +# TODO: Send a disconnection message if I was following someone. +@router.post("/reset_coordination_state") +async def reset_coordination_state(state: StateDep) -> ResetCoordinationStateResponse: + """Reset the coordination state of the node.""" + state.leading = False + state.followers_address = "" + state.following = False + state.following_requested = False + state.following_requested_user_response = None + state.leaders_address = "" + state.leaders_name = "" + return ResetCoordinationStateResponse() + + +@router.post("/collect_follower") +async def collect_follower(address: str, state: StateDep, http_client: ClientDep) -> CollectFollowerResponse: + """ + Endpoint called by a leader node (this one) to request a follower node (other node) to follow it. + + Returns + ------- + CollectFollowerResponse indicating if the follower accepted the request. + """ + logger.info("Requesting client at %s to follow", address) + + ret = await http_client.post(f"http://{address}/coordination/follow_requested?leaders_name={settings.node_name}") + if ret.status_code != status.HTTP_200_OK: + raise HTTPException(status_code=ret.status_code, detail=ret.text) + + response_data = ret.json() + if response_data.get("accepted") is True: + state.leading = True + state.followers_address = address + logger.info("Successfully collected follower") + return CollectFollowerResponse(success=True) + if response_data.get("accepted") is False: + logger.info("Follower rejected follow request") + return CollectFollowerResponse(success=False) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not collect follower for unknown reasons" + ) + + +@router.post("/follow_requested") +async def follow_requested(request: Request, leaders_name: str, state: StateDep) -> FollowRequestResponse: + """ + Endpoint is called by a leader node (other node) to request this node to follow it. + + Returns + ------- + FollowRequestResponse indicating if the follow request is accepted. + """ + logger.debug("Requesting client at %s to follow", leaders_name) + + if request.client is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Request lacks the clients host") + leaders_address = request.client.host + + # Check if the client is ready to accept a follower request and that node is not already following someone. + if not state.client_listening_for_follower_requests or state.following: + logger.info( + "Request rejected because %s", + ( + "client is not listening for requests" + if not state.client_listening_for_follower_requests + else "this node is already following someone" + ), + ) + return FollowRequestResponse(accepted=False) + + state.following_requested = True + state.leaders_name = leaders_name + state.leaders_address = leaders_address + # Trigger the state change to get the websocket to send question to user + ask_user_for_follow_event.set() + + logger.debug("Asking user to accept follow request from %s (%s)", leaders_name, leaders_address) + await user_replied_event.wait() # Wait for a state change event to see if user accepted + user_replied_event.clear() # Reset the event for the next change + if state.following_requested_user_response: + logger.debug("Follow request from %s accepted.", leaders_address) + state.following = True + state.leaders_name = leaders_name + state.leaders_address = leaders_address + return FollowRequestResponse(accepted=True) + + logger.debug("Follow request from %s rejected.", leaders_address) + # Clean up the state if user rejected + state.leaders_address = "" + state.leaders_name = "" + state.following_requested = False + return FollowRequestResponse(accepted=False) + + +@router.websocket("/follow_requested_alerts") +async def follow_requested_alert(websocket: WebSocket, state: StateDep) -> None: + """Websocket endpoint is used to alert the client when a follow request is received. It also handles the response from the client.""" + await websocket.accept() + logger.info("Client connected to websocket for multiplayer coordination.") + state.client_listening_for_follower_requests = True + + async def ask_user_for_follow_handler() -> None: + """Task that waits for the ask_user_for_follow_event event and sends a message to the client if a follow request is detected.""" + while True: + try: + await ask_user_for_follow_event.wait() # Wait for a state change event + if state.following_requested: + logger.debug("Websocket detected a follow request, asking user for response.") + if websocket.client_state.name == "CONNECTED": + await websocket.send_text( + f"Do you want to accept a connection from {state.leaders_name} ({state.leaders_address})?" + ) + else: + logger.debug("WebSocket not connected, cannot send message") + break + ask_user_for_follow_event.clear() # Reset the event for the next change + except Exception: + logger.exception("Error in ask_user_for_follow_handler") + break + + async def client_message_handler() -> None: + """Task that waits for a message from the client and handles the response. It also handles the case where the client disconnects.""" + try: + while True: + response = await websocket.receive_text() + state.following_requested_user_response = response.lower() in ["true", "yes", "y"] + state.following_requested = False + logger.debug("Websocket received a response from user: %s", state.following_requested_user_response) + user_replied_event.set() + except WebSocketDisconnect: + logger.info("Client disconnected from websocket for multiplayer coordination.") + state.client_listening_for_follower_requests = False + + state_change_task = asyncio.create_task(ask_user_for_follow_handler()) + client_message_task = asyncio.create_task(client_message_handler()) + + try: + await asyncio.gather(state_change_task, client_message_task) + finally: + state_change_task.cancel() + client_message_task.cancel() diff --git a/src/pqnstack/app/api/routes/debug.py b/src/pqnstack/app/api/routes/debug.py new file mode 100644 index 00000000..12768c32 --- /dev/null +++ b/src/pqnstack/app/api/routes/debug.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from pqnstack.app.api.deps import StateDep +from pqnstack.app.core.config import NodeState + +router = APIRouter(prefix="/debug", tags=["debug"]) + + +@router.get("/state") +async def get_state(state: StateDep) -> NodeState: + return state diff --git a/src/pqnstack/app/api/routes/qkd.py b/src/pqnstack/app/api/routes/qkd.py index f676b2ce..35a97695 100644 --- a/src/pqnstack/app/api/routes/qkd.py +++ b/src/pqnstack/app/api/routes/qkd.py @@ -7,8 +7,8 @@ from fastapi import status from pqnstack.app.api.deps import ClientDep +from pqnstack.app.api.deps import StateDep from pqnstack.app.core.config import settings -from pqnstack.app.core.config import state from pqnstack.constants import BasisBool from pqnstack.constants import QKDEncodingBasis from pqnstack.network.client import Client @@ -21,6 +21,7 @@ async def _qkd( follower_node_address: str, http_client: ClientDep, + state: StateDep, timetagger_address: str | None = None, ) -> list[int]: logger.debug("Starting QKD") @@ -106,6 +107,7 @@ def get_outcome(state: int, basis: int, choice: int, counts: int) -> int: async def qkd( follower_node_address: str, http_client: ClientDep, + state: StateDep, timetagger_address: str | None = None, ) -> list[int]: """Perform a QKD protocol with the given follower node.""" @@ -116,11 +118,11 @@ async def qkd( detail="QKD basis list is empty", ) - return await _qkd(follower_node_address, http_client, timetagger_address) + return await _qkd(follower_node_address, http_client, state, timetagger_address) @router.post("/single_bit") -async def request_qkd_single_pass() -> bool: +async def request_qkd_single_pass(state: StateDep) -> bool: client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000) hwp = client.get_device(settings.qkd_settings.request_hwp[0], settings.qkd_settings.request_hwp[1]) @@ -149,7 +151,7 @@ async def request_qkd_single_pass() -> bool: @router.post("/request_basis_list") -def request_qkd_basis_list(leader_basis_list: list[str]) -> list[str]: +def request_qkd_basis_list(leader_basis_list: list[str], state: StateDep) -> list[str]: """Return the list of basis angles for QKD.""" # Check that lengths match if len(leader_basis_list) != len(state.qkd_request_basis_list): diff --git a/src/pqnstack/app/core/config.py b/src/pqnstack/app/core/config.py index da78729f..b5aea03d 100644 --- a/src/pqnstack/app/core/config.py +++ b/src/pqnstack/app/core/config.py @@ -1,3 +1,4 @@ +import asyncio import logging from functools import lru_cache @@ -32,6 +33,7 @@ class QKDSettings(BaseModel): class Settings(BaseSettings): + node_name: str = "node1" router_name: str = "router1" router_address: str = "localhost" router_port: int = 5555 @@ -72,7 +74,28 @@ def get_settings() -> Settings: class NodeState(BaseModel): + # Coordination state + # FIXME: Make sure we are checking for the client_listening_for_follower_requests state everywhere. + client_listening_for_follower_requests: bool = False + + # Leader's state + leading: bool = False + followers_address: str = "" + + # Follower's state + following: bool = False + # Other node requested this node to follow it. + following_requested: bool = False + # User's response to the follow request. None if no response yet, True if accepted, False if rejected. + following_requested_user_response: bool | None = None + # The address of the leader this node is following. None if not following anyone. + leaders_address: str = "" + leaders_name: str = "" + + # CHSH state chsh_request_basis: list[float] = [22.5, 67.5] + + # QKD state # FIXME: Use enums for this qkd_basis_list: list[QKDEncodingBasis] = [ QKDEncodingBasis.DA, @@ -94,3 +117,9 @@ class NodeState(BaseModel): state = NodeState() +ask_user_for_follow_event = asyncio.Event() +user_replied_event = asyncio.Event() + + +def get_state() -> NodeState: + return state