Skip to content

Commit ca0cb6a

Browse files
committed
Add utility class to support background server
Summary: A subclass of the aepsych server with methods specifically to run the server in a background process. This will be used to ensured that even within the same main script, the server will run like an actual server and does not do anything sneaky like bypassing the async queue. Test Plan: New test
1 parent 61f22b0 commit ca0cb6a

File tree

3 files changed

+176
-13
lines changed

3 files changed

+176
-13
lines changed

aepsych/server/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from .server import AEPsychServer
8+
from .server import AEPsychBackgroundServer, AEPsychServer
99

10-
__all__ = ["AEPsychServer"]
10+
__all__ = ["AEPsychServer", "AEPsychBackgroundServer"]

aepsych/server/server.py

+60-8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
replay,
3333
)
3434
from aepsych.strategy import SequentialStrategy, Strategy
35+
from multiprocess import Process
3536

3637
logger = utils_logging.getLogger()
3738

@@ -48,11 +49,9 @@ def __init__(
4849
host: str = "0.0.0.0",
4950
port: int = 5555,
5051
database_path: str = "./databases/default.db",
51-
max_workers: Optional[int] = None,
5252
):
5353
self.host = host
5454
self.port = port
55-
self.max_workers = max_workers
5655
self.clients_connected = 0
5756
self.db: db.Database = db.Database(database_path)
5857
self.is_performing_replay = False
@@ -278,11 +277,6 @@ def start_blocking(self) -> None:
278277
process or machine."""
279278
asyncio.run(self.serve())
280279

281-
def start_background(self):
282-
"""Starts the server in a background thread. Used for scripts where the
283-
client and server are in the same process."""
284-
raise NotImplementedError
285-
286280
async def serve(self) -> None:
287281
"""Serves the server on the set IP and port. This creates a coroutine
288282
for asyncio to handle requests asyncronously.
@@ -291,7 +285,7 @@ async def serve(self) -> None:
291285
self.handle_client, self.host, self.port
292286
)
293287
self.loop = asyncio.get_running_loop()
294-
pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers)
288+
pool = concurrent.futures.ThreadPoolExecutor()
295289
self.loop.set_default_executor(pool)
296290

297291
async with self.server:
@@ -427,6 +421,64 @@ def __getstate__(self):
427421
return state
428422

429423

424+
class AEPsychBackgroundServer(AEPsychServer):
425+
"""A class to handle the server in a background thread. Unlike the normal
426+
AEPsychServer, this does not create the db right away until the server is
427+
started. When starting this server, it'll be sent to another process, a db
428+
will be initialized, then the server will be served. This server should then
429+
be interacted with by the main thread via a client."""
430+
431+
def __init__(
432+
self,
433+
host: str = "0.0.0.0",
434+
port: int = 5555,
435+
database_path: str = "./databases/default.db",
436+
):
437+
self.host = host
438+
self.port = port
439+
self.database_path = database_path
440+
self.clients_connected = 0
441+
self.is_performing_replay = False
442+
self.exit_server_loop = False
443+
self._db_raw_record = None
444+
self.skip_computations = False
445+
self.background_process = None
446+
self.strat_names = None
447+
self.extensions = None
448+
self._strats = []
449+
self._parnames = []
450+
self._configs = []
451+
self._master_records = []
452+
self.strat_id = -1
453+
self.outcome_names = []
454+
455+
def _start_server(self) -> None:
456+
self.db: db.Database = db.Database(self.database_path)
457+
if self.db.is_update_required():
458+
self.db.perform_updates()
459+
460+
super().start_blocking()
461+
462+
def start(self):
463+
"""Starts the server in a background thread. Used by the client to start
464+
the server for a client in another process or machine."""
465+
self.background_process = Process(target=self._start_server, daemon=True)
466+
self.background_process.start()
467+
468+
def stop(self):
469+
"""Stops the server and closes the background process."""
470+
self.exit_server_loop = True
471+
self.background_process.terminate()
472+
self.background_process.join()
473+
self.background_process.close()
474+
self.background_process = None
475+
476+
def __getstate__(self):
477+
# Override parent's __getstate__ to not worry about the db
478+
state = self.__dict__.copy()
479+
return state
480+
481+
430482
def parse_argument():
431483
parser = argparse.ArgumentParser(description="AEPsych Server")
432484
parser.add_argument(

tests/server/test_server.py

+114-3
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
import asyncio
99
import json
10-
import logging
10+
import time
1111
import unittest
1212
import uuid
1313
from pathlib import Path
1414
from typing import Any, Dict
1515

1616
import aepsych.server as server
1717
import aepsych.utils_logging as utils_logging
18-
from aepsych.server.sockets import BAD_REQUEST
1918

2019
dummy_config = """
2120
[common]
@@ -87,7 +86,7 @@ async def asyncSetUp(self):
8786
self.port = 5555
8887

8988
# setup logger
90-
server.logger = utils_logging.getLogger("unittests")
89+
self.logger = utils_logging.getLogger("unittests")
9190

9291
# random datebase path name without dashes
9392
database_path = self.database_path
@@ -532,5 +531,117 @@ async def _mock_client2(request: Dict[str, Any]) -> Any:
532531
self.assertTrue(self.s.clients_connected == 2)
533532

534533

534+
class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase):
535+
@property
536+
def database_path(self):
537+
return "./{}_test_server.db".format(str(uuid.uuid4().hex))
538+
539+
async def asyncSetUp(self):
540+
self.ip = "127.0.0.1"
541+
self.port = 5555
542+
543+
# setup logger
544+
self.logger = utils_logging.getLogger("unittests")
545+
546+
# random datebase path name without dashes
547+
database_path = self.database_path
548+
self.s = server.AEPsychBackgroundServer(
549+
database_path=database_path, host=self.ip, port=self.port
550+
)
551+
self.db_name = database_path.split("/")[1]
552+
self.db_path = database_path
553+
554+
# Writer will be made in tests
555+
self.writer = None
556+
557+
async def asyncTearDown(self):
558+
# Stops the client
559+
if self.writer is not None:
560+
self.writer.close()
561+
562+
time.sleep(0.1)
563+
564+
# cleanup the db
565+
db_path = Path(self.db_path)
566+
try:
567+
print(db_path)
568+
db_path.unlink()
569+
except PermissionError as e:
570+
print("Failed to deleted database: ", e)
571+
572+
async def test_background_server(self):
573+
self.assertIsNone(self.s.background_process)
574+
self.s.start()
575+
self.assertTrue(self.s.background_process.is_alive())
576+
577+
# Make a client
578+
try_again = True
579+
attempts = 0
580+
while try_again:
581+
try_again = False
582+
attempts += 1
583+
try:
584+
reader, self.writer = await asyncio.open_connection(self.ip, self.port)
585+
except ConnectionRefusedError:
586+
if attempts > 10:
587+
raise ConnectionRefusedError
588+
try_again = True
589+
time.sleep(1)
590+
591+
async def _mock_client(request: Dict[str, Any]) -> Any:
592+
self.writer.write(json.dumps(request).encode())
593+
await self.writer.drain()
594+
595+
response = await reader.read(1024 * 512)
596+
return response.decode()
597+
598+
setup_request = {
599+
"type": "setup",
600+
"version": "0.01",
601+
"message": {"config_str": dummy_config},
602+
}
603+
ask_request = {"type": "ask", "message": ""}
604+
tell_request = {
605+
"type": "tell",
606+
"message": {"config": {"x": [0.5]}, "outcome": 1},
607+
"extra_info": {},
608+
}
609+
610+
await _mock_client(setup_request)
611+
612+
expected_x = [0, 1, 2, 3]
613+
expected_z = list(reversed(expected_x))
614+
expected_y = [x % 2 for x in expected_x]
615+
i = 0
616+
while True:
617+
response = await _mock_client(ask_request)
618+
response = json.loads(response)
619+
tell_request["message"]["config"]["x"] = [expected_x[i]]
620+
tell_request["message"]["config"]["z"] = [expected_z[i]]
621+
tell_request["message"]["outcome"] = expected_y[i]
622+
tell_request["extra_info"]["e1"] = 1
623+
tell_request["extra_info"]["e2"] = 2
624+
i = i + 1
625+
await _mock_client(tell_request)
626+
627+
if response["is_finished"]:
628+
break
629+
630+
self.s.stop()
631+
self.assertIsNone(self.s.background_process)
632+
633+
# Create a synchronous server to check db contents
634+
s = server.AEPsychServer(database_path=self.db_path)
635+
unique_id = s.db.get_master_records()[-1].unique_id
636+
out_df = s.get_dataframe_from_replay(unique_id)
637+
self.assertTrue((out_df.x == expected_x).all())
638+
self.assertTrue((out_df.z == expected_z).all())
639+
self.assertTrue((out_df.response == expected_y).all())
640+
self.assertTrue((out_df.e1 == [1] * 4).all())
641+
self.assertTrue((out_df.e2 == [2] * 4).all())
642+
self.assertTrue("post_mean" in out_df.columns)
643+
self.assertTrue("post_var" in out_df.columns)
644+
645+
535646
if __name__ == "__main__":
536647
unittest.main()

0 commit comments

Comments
 (0)