Skip to content
2 changes: 2 additions & 0 deletions configs/config_app_example.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# MAKE SURE TO RENAME THIS FILE TO config.toml AND PLACE IT IN THE ROOT OF THE PROJECT

node_name = "example_node"

# Router configuration
router_name = "router1"
router_address = "xx.xx.xx.xx" # Replace with actual IP address
Expand Down
6 changes: 6 additions & 0 deletions src/pqnstack/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
import httpx
from fastapi import Depends

from pqnstack.app.core.config import NodeState
from pqnstack.app.core.config import get_state


async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(timeout=600_000) as client:
yield client


ClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)]


StateDep = Annotated[NodeState, Depends(get_state)]
4 changes: 4 additions & 0 deletions src/pqnstack/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from fastapi import APIRouter

from pqnstack.app.api.routes import chsh
from pqnstack.app.api.routes import coordination
from pqnstack.app.api.routes import debug
from pqnstack.app.api.routes import qkd
from pqnstack.app.api.routes import rng
from pqnstack.app.api.routes import serial
Expand All @@ -12,3 +14,5 @@
api_router.include_router(timetagger.router)
api_router.include_router(rng.router)
api_router.include_router(serial.router)
api_router.include_router(coordination.router)
api_router.include_router(debug.router)
4 changes: 2 additions & 2 deletions src/pqnstack/app/api/routes/chsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from fastapi import status

from pqnstack.app.api.deps import ClientDep
from pqnstack.app.api.deps import StateDep
from pqnstack.app.core.config import settings
from pqnstack.app.core.config import state
from pqnstack.app.core.models import calculate_chsh_expectation_error
from pqnstack.network.client import Client

Expand Down Expand Up @@ -126,7 +126,7 @@ async def chsh(


@router.post("/request-angle-by-basis")
async def request_angle_by_basis(index: int, *, perp: bool = False) -> bool:
async def request_angle_by_basis(index: int, state: StateDep, *, perp: bool = False) -> bool:
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
hwp = client.get_device(settings.chsh_settings.request_hwp[0], settings.chsh_settings.request_hwp[1])
if hwp is None:
Expand Down
177 changes: 177 additions & 0 deletions src/pqnstack/app/api/routes/coordination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import asyncio
import logging

from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from fastapi import status
from pydantic import BaseModel

from pqnstack.app.api.deps import ClientDep
from pqnstack.app.api.deps import StateDep
from pqnstack.app.core.config import ask_user_for_follow_event
from pqnstack.app.core.config import settings
from pqnstack.app.core.config import user_replied_event

logger = logging.getLogger(__name__)


class FollowRequestResponse(BaseModel):
accepted: bool


class CollectFollowerResponse(BaseModel):
success: bool


class ResetCoordinationStateResponse(BaseModel):
message: str = "Coordination state reset successfully"


router = APIRouter(prefix="/coordination", tags=["coordination"])


# TODO: Send a disconnection message if I was following someone.
@router.post("/reset_coordination_state")
async def reset_coordination_state(state: StateDep) -> ResetCoordinationStateResponse:
"""Reset the coordination state of the node."""
state.leading = False
state.followers_address = ""
state.following = False
state.following_requested = False
state.following_requested_user_response = None
state.leaders_address = ""
state.leaders_name = ""
return ResetCoordinationStateResponse()


@router.post("/collect_follower")
async def collect_follower(address: str, state: StateDep, http_client: ClientDep) -> CollectFollowerResponse:
"""
Endpoint called by a leader node (this one) to request a follower node (other node) to follow it.

Returns
-------
CollectFollowerResponse indicating if the follower accepted the request.
"""
logger.info("Requesting client at %s to follow", address)

ret = await http_client.post(f"http://{address}/coordination/follow_requested?leaders_name={settings.node_name}")
if ret.status_code != status.HTTP_200_OK:
raise HTTPException(status_code=ret.status_code, detail=ret.text)

response_data = ret.json()
if response_data.get("accepted") is True:
state.leading = True
state.followers_address = address
logger.info("Successfully collected follower")
return CollectFollowerResponse(success=True)
if response_data.get("accepted") is False:
logger.info("Follower rejected follow request")
return CollectFollowerResponse(success=False)

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not collect follower for unknown reasons"
)


@router.post("/follow_requested")
async def follow_requested(request: Request, leaders_name: str, state: StateDep) -> FollowRequestResponse:
"""
Endpoint is called by a leader node (other node) to request this node to follow it.

Returns
-------
FollowRequestResponse indicating if the follow request is accepted.
"""
logger.debug("Requesting client at %s to follow", leaders_name)

if request.client is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Request lacks the clients host")
leaders_address = request.client.host

# Check if the client is ready to accept a follower request and that node is not already following someone.
if not state.client_listening_for_follower_requests or state.following:
logger.info(
"Request rejected because %s",
(
"client is not listening for requests"
if not state.client_listening_for_follower_requests
else "this node is already following someone"
),
)
return FollowRequestResponse(accepted=False)

state.following_requested = True
state.leaders_name = leaders_name
state.leaders_address = leaders_address
# Trigger the state change to get the websocket to send question to user
ask_user_for_follow_event.set()

logger.debug("Asking user to accept follow request from %s (%s)", leaders_name, leaders_address)
await user_replied_event.wait() # Wait for a state change event to see if user accepted
user_replied_event.clear() # Reset the event for the next change
if state.following_requested_user_response:
logger.debug("Follow request from %s accepted.", leaders_address)
state.following = True
state.leaders_name = leaders_name
state.leaders_address = leaders_address
return FollowRequestResponse(accepted=True)

logger.debug("Follow request from %s rejected.", leaders_address)
# Clean up the state if user rejected
state.leaders_address = ""
state.leaders_name = ""
state.following_requested = False
return FollowRequestResponse(accepted=False)


@router.websocket("/follow_requested_alerts")
async def follow_requested_alert(websocket: WebSocket, state: StateDep) -> None:
"""Websocket endpoint is used to alert the client when a follow request is received. It also handles the response from the client."""
await websocket.accept()
logger.info("Client connected to websocket for multiplayer coordination.")
state.client_listening_for_follower_requests = True

async def ask_user_for_follow_handler() -> None:
"""Task that waits for the ask_user_for_follow_event event and sends a message to the client if a follow request is detected."""
while True:
try:
await ask_user_for_follow_event.wait() # Wait for a state change event
if state.following_requested:
logger.debug("Websocket detected a follow request, asking user for response.")
if websocket.client_state.name == "CONNECTED":
await websocket.send_text(
f"Do you want to accept a connection from {state.leaders_name} ({state.leaders_address})?"
)
else:
logger.debug("WebSocket not connected, cannot send message")
break
ask_user_for_follow_event.clear() # Reset the event for the next change
except Exception:
logger.exception("Error in ask_user_for_follow_handler")
break

async def client_message_handler() -> None:
"""Task that waits for a message from the client and handles the response. It also handles the case where the client disconnects."""
try:
while True:
response = await websocket.receive_text()
state.following_requested_user_response = response.lower() in ["true", "yes", "y"]
state.following_requested = False
logger.debug("Websocket received a response from user: %s", state.following_requested_user_response)
user_replied_event.set()
except WebSocketDisconnect:
logger.info("Client disconnected from websocket for multiplayer coordination.")
state.client_listening_for_follower_requests = False

state_change_task = asyncio.create_task(ask_user_for_follow_handler())
client_message_task = asyncio.create_task(client_message_handler())

try:
await asyncio.gather(state_change_task, client_message_task)
finally:
state_change_task.cancel()
client_message_task.cancel()
11 changes: 11 additions & 0 deletions src/pqnstack/app/api/routes/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import APIRouter

from pqnstack.app.api.deps import StateDep
from pqnstack.app.core.config import NodeState

router = APIRouter(prefix="/debug", tags=["debug"])


@router.get("/state")
async def get_state(state: StateDep) -> NodeState:
return state
10 changes: 6 additions & 4 deletions src/pqnstack/app/api/routes/qkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from fastapi import status

from pqnstack.app.api.deps import ClientDep
from pqnstack.app.api.deps import StateDep
from pqnstack.app.core.config import settings
from pqnstack.app.core.config import state
from pqnstack.constants import BasisBool
from pqnstack.constants import QKDEncodingBasis
from pqnstack.network.client import Client
Expand All @@ -21,6 +21,7 @@
async def _qkd(
follower_node_address: str,
http_client: ClientDep,
state: StateDep,
timetagger_address: str | None = None,
) -> list[int]:
logger.debug("Starting QKD")
Expand Down Expand Up @@ -106,6 +107,7 @@ def get_outcome(state: int, basis: int, choice: int, counts: int) -> int:
async def qkd(
follower_node_address: str,
http_client: ClientDep,
state: StateDep,
timetagger_address: str | None = None,
) -> list[int]:
"""Perform a QKD protocol with the given follower node."""
Expand All @@ -116,11 +118,11 @@ async def qkd(
detail="QKD basis list is empty",
)

return await _qkd(follower_node_address, http_client, timetagger_address)
return await _qkd(follower_node_address, http_client, state, timetagger_address)


@router.post("/single_bit")
async def request_qkd_single_pass() -> bool:
async def request_qkd_single_pass(state: StateDep) -> bool:
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
hwp = client.get_device(settings.qkd_settings.request_hwp[0], settings.qkd_settings.request_hwp[1])

Expand Down Expand Up @@ -149,7 +151,7 @@ async def request_qkd_single_pass() -> bool:


@router.post("/request_basis_list")
def request_qkd_basis_list(leader_basis_list: list[str]) -> list[str]:
def request_qkd_basis_list(leader_basis_list: list[str], state: StateDep) -> list[str]:
"""Return the list of basis angles for QKD."""
# Check that lengths match
if len(leader_basis_list) != len(state.qkd_request_basis_list):
Expand Down
29 changes: 29 additions & 0 deletions src/pqnstack/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from functools import lru_cache

Expand Down Expand Up @@ -32,6 +33,7 @@ class QKDSettings(BaseModel):


class Settings(BaseSettings):
node_name: str = "node1"
router_name: str = "router1"
router_address: str = "localhost"
router_port: int = 5555
Expand Down Expand Up @@ -72,7 +74,28 @@ def get_settings() -> Settings:


class NodeState(BaseModel):
# Coordination state
# FIXME: Make sure we are checking for the client_listening_for_follower_requests state everywhere.
client_listening_for_follower_requests: bool = False

# Leader's state
leading: bool = False
followers_address: str = ""

# Follower's state
following: bool = False
# Other node requested this node to follow it.
following_requested: bool = False
# User's response to the follow request. None if no response yet, True if accepted, False if rejected.
following_requested_user_response: bool | None = None
# The address of the leader this node is following. None if not following anyone.
leaders_address: str = ""
leaders_name: str = ""

# CHSH state
chsh_request_basis: list[float] = [22.5, 67.5]

# QKD state
# FIXME: Use enums for this
qkd_basis_list: list[QKDEncodingBasis] = [
QKDEncodingBasis.DA,
Expand All @@ -94,3 +117,9 @@ class NodeState(BaseModel):


state = NodeState()
ask_user_for_follow_event = asyncio.Event()
user_replied_event = asyncio.Event()


def get_state() -> NodeState:
return state