Skip to content

Commit f0d95d2

Browse files
committed
Add utility class to support background server
Summary: A subclass of the aepsych server with methods specifically to run the server in a background process. This will be used to ensured that even within the same main script, the server will run like an actual server and does not do anything sneaky like bypassing the async queue. Test Plan: New test
1 parent 06e6fe6 commit f0d95d2

File tree

3 files changed

+168
-13
lines changed

3 files changed

+168
-13
lines changed

aepsych/server/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from .server import AEPsychServer
8+
from .server import AEPsychBackgroundServer, AEPsychServer
99

10-
__all__ = ["AEPsychServer"]
10+
__all__ = ["AEPsychServer", "AEPsychBackgroundServer"]

aepsych/server/server.py

+60-8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
replay,
3333
)
3434
from aepsych.strategy import SequentialStrategy, Strategy
35+
from multiprocess import Process
3536

3637
logger = utils_logging.getLogger()
3738

@@ -48,11 +49,9 @@ def __init__(
4849
host: str = "0.0.0.0",
4950
port: int = 5555,
5051
database_path: str = "./databases/default.db",
51-
max_workers: Optional[int] = None,
5252
):
5353
self.host = host
5454
self.port = port
55-
self.max_workers = max_workers
5655
self.clients_connected = 0
5756
self.db: db.Database = db.Database(database_path)
5857
self.is_performing_replay = False
@@ -278,11 +277,6 @@ def start_blocking(self) -> None:
278277
process or machine."""
279278
asyncio.run(self.serve())
280279

281-
def start_background(self):
282-
"""Starts the server in a background thread. Used for scripts where the
283-
client and server are in the same process."""
284-
raise NotImplementedError
285-
286280
async def serve(self) -> None:
287281
"""Serves the server on the set IP and port. This creates a coroutine
288282
for asyncio to handle requests asyncronously.
@@ -291,7 +285,7 @@ async def serve(self) -> None:
291285
self.handle_client, self.host, self.port
292286
)
293287
self.loop = asyncio.get_running_loop()
294-
pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers)
288+
pool = concurrent.futures.ThreadPoolExecutor()
295289
self.loop.set_default_executor(pool)
296290

297291
async with self.server:
@@ -427,6 +421,64 @@ def __getstate__(self):
427421
return state
428422

429423

424+
class AEPsychBackgroundServer(AEPsychServer):
425+
"""A class to handle the server in a background thread. Unlike the normal
426+
AEPsychServer, this does not create the db right away until the server is
427+
started. When starting this server, it'll be sent to another process, a db
428+
will be initialized, then the server will be served. This server should then
429+
be interacted with by the main thread via a client."""
430+
431+
def __init__(
432+
self,
433+
host: str = "0.0.0.0",
434+
port: int = 5555,
435+
database_path: str = "./databases/default.db",
436+
):
437+
self.host = host
438+
self.port = port
439+
self.database_path = database_path
440+
self.clients_connected = 0
441+
self.is_performing_replay = False
442+
self.exit_server_loop = False
443+
self._db_raw_record = None
444+
self.skip_computations = False
445+
self.background_process = None
446+
self.strat_names = None
447+
self.extensions = None
448+
self._strats = []
449+
self._parnames = []
450+
self._configs = []
451+
self._master_records = []
452+
self.strat_id = -1
453+
self.outcome_names = []
454+
455+
def _start_server(self) -> None:
456+
self.db: db.Database = db.Database(self.database_path)
457+
if self.db.is_update_required():
458+
self.db.perform_updates()
459+
460+
super().start_blocking()
461+
462+
def start(self):
463+
"""Starts the server in a background thread. Used by the client to start
464+
the server for a client in another process or machine."""
465+
self.background_process = Process(target=self._start_server, daemon=True)
466+
self.background_process.start()
467+
468+
def stop(self):
469+
"""Stops the server and closes the background process."""
470+
self.exit_server_loop = True
471+
self.background_process.terminate()
472+
self.background_process.join()
473+
self.background_process.close()
474+
self.background_process = None
475+
476+
def __getstate__(self):
477+
# Override parent's __getstate__ to not worry about the db
478+
state = self.__dict__.copy()
479+
return state
480+
481+
430482
def parse_argument():
431483
parser = argparse.ArgumentParser(description="AEPsych Server")
432484
parser.add_argument(

tests/server/test_server.py

+106-3
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
import asyncio
99
import json
10-
import logging
10+
import time
1111
import unittest
1212
import uuid
1313
from pathlib import Path
1414
from typing import Any, Dict
1515

1616
import aepsych.server as server
1717
import aepsych.utils_logging as utils_logging
18-
from aepsych.server.sockets import BAD_REQUEST
1918

2019
dummy_config = """
2120
[common]
@@ -87,7 +86,7 @@ async def asyncSetUp(self):
8786
self.port = 5555
8887

8988
# setup logger
90-
server.logger = utils_logging.getLogger("unittests")
89+
self.logger = utils_logging.getLogger("unittests")
9190

9291
# random datebase path name without dashes
9392
database_path = self.database_path
@@ -523,5 +522,109 @@ async def _mock_client2(request: Dict[str, Any]) -> Any:
523522
self.assertTrue(self.s.clients_connected == 2)
524523

525524

525+
class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase):
526+
@property
527+
def database_path(self):
528+
return "./{}_test_server.db".format(str(uuid.uuid4().hex))
529+
530+
async def asyncSetUp(self):
531+
self.ip = "127.0.0.1"
532+
self.port = 5555
533+
534+
# setup logger
535+
self.logger = utils_logging.getLogger("unittests")
536+
537+
# random datebase path name without dashes
538+
database_path = self.database_path
539+
self.s = server.AEPsychBackgroundServer(
540+
database_path=database_path, host=self.ip, port=self.port
541+
)
542+
self.db_name = database_path.split("/")[1]
543+
self.db_path = database_path
544+
545+
# Writer will be made in tests
546+
self.writer = None
547+
548+
async def asyncTearDown(self):
549+
# Stops the client
550+
if self.writer is not None:
551+
self.writer.close()
552+
553+
time.sleep(0.1)
554+
555+
# cleanup the db
556+
db_path = Path(self.db_path)
557+
try:
558+
print(db_path)
559+
db_path.unlink()
560+
except PermissionError as e:
561+
print("Failed to deleted database: ", e)
562+
563+
async def test_background_server(self):
564+
self.assertIsNone(self.s.background_process)
565+
self.s.start()
566+
self.assertTrue(self.s.background_process.is_alive())
567+
568+
# Give time for the server to start
569+
time.sleep(5)
570+
571+
# Create a client
572+
reader, self.writer = await asyncio.open_connection(self.ip, self.port)
573+
574+
async def _mock_client(request: Dict[str, Any]) -> Any:
575+
self.writer.write(json.dumps(request).encode())
576+
await self.writer.drain()
577+
578+
response = await reader.read(1024 * 512)
579+
return response.decode()
580+
581+
setup_request = {
582+
"type": "setup",
583+
"version": "0.01",
584+
"message": {"config_str": dummy_config},
585+
}
586+
ask_request = {"type": "ask", "message": ""}
587+
tell_request = {
588+
"type": "tell",
589+
"message": {"config": {"x": [0.5]}, "outcome": 1},
590+
"extra_info": {},
591+
}
592+
593+
await _mock_client(setup_request)
594+
595+
expected_x = [0, 1, 2, 3]
596+
expected_z = list(reversed(expected_x))
597+
expected_y = [x % 2 for x in expected_x]
598+
i = 0
599+
while True:
600+
response = await _mock_client(ask_request)
601+
response = json.loads(response)
602+
tell_request["message"]["config"]["x"] = [expected_x[i]]
603+
tell_request["message"]["config"]["z"] = [expected_z[i]]
604+
tell_request["message"]["outcome"] = expected_y[i]
605+
tell_request["extra_info"]["e1"] = 1
606+
tell_request["extra_info"]["e2"] = 2
607+
i = i + 1
608+
await _mock_client(tell_request)
609+
610+
if response["is_finished"]:
611+
break
612+
613+
self.s.stop()
614+
self.assertIsNone(self.s.background_process)
615+
616+
# Create a synchronous server to check db contents
617+
s = server.AEPsychServer(database_path=self.db_path)
618+
unique_id = s.db.get_master_records()[-1].unique_id
619+
out_df = s.get_dataframe_from_replay(unique_id)
620+
self.assertTrue((out_df.x == expected_x).all())
621+
self.assertTrue((out_df.z == expected_z).all())
622+
self.assertTrue((out_df.response == expected_y).all())
623+
self.assertTrue((out_df.e1 == [1] * 4).all())
624+
self.assertTrue((out_df.e2 == [2] * 4).all())
625+
self.assertTrue("post_mean" in out_df.columns)
626+
self.assertTrue("post_var" in out_df.columns)
627+
628+
526629
if __name__ == "__main__":
527630
unittest.main()

0 commit comments

Comments
 (0)