From e24192d690703e1ba29ba84f927f02f0c892c38d Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 1/8] rewrite server with asyncio Summary: Rewrite the server to use asyncio and be multithreaded. Allows handling multiple clients. Test Plan: Tests updated. Later test also tests multiple db writes simultaneously. --- aepsych/database/db.py | 7 +- aepsych/server/server.py | 568 +++++++++--------- clients/python/tests/test_client.py | 2 - tests/models/test_pairwise_probit.py | 2 - .../message_handlers/test_ask_handlers.py | 4 +- .../server/message_handlers/test_can_model.py | 4 +- .../message_handlers/test_handle_exit.py | 26 +- .../test_handle_finish_strategy.py | 4 +- .../test_handle_get_config.py | 4 +- .../message_handlers/test_query_handlers.py | 4 +- .../message_handlers/test_tell_handlers.py | 4 +- tests/server/test_server.py | 252 ++++---- tests/test_datafetcher.py | 7 +- tests/test_db.py | 5 - 14 files changed, 460 insertions(+), 433 deletions(-) diff --git a/aepsych/database/db.py b/aepsych/database/db.py index c9c9cc65d..421b859b2 100644 --- a/aepsych/database/db.py +++ b/aepsych/database/db.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import datetime +import io import json import logging import os @@ -440,12 +441,14 @@ def record_outcome( self._session.add(outcome_entry) self._session.commit() - def record_strat(self, master_table: tables.DBMasterTable, strat: Strategy) -> None: + def record_strat( + self, master_table: tables.DBMasterTable, strat: io.BytesIO + ) -> None: """Record a strategy in the database. Args: master_table (tables.DBMasterTable): The master table. - strat (Strategy): The strategy. + strat (BytesIO): The strategy in buffer form. 
""" strat_entry = tables.DbStratTable() strat_entry.strat = strat diff --git a/aepsych/server/server.py b/aepsych/server/server.py index d0a16ba92..30ca22335 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. and its affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import argparse +import asyncio +import concurrent import io +import json import logging import os -import sys -import threading import traceback import warnings -from typing import Dict, Union +from typing import Any, Dict, List, Optional, Union -import aepsych.database.db as db -import aepsych.utils_logging as utils_logging import dill import numpy as np +import pandas as pd import torch -from aepsych import version +from aepsych import utils_logging, version +from aepsych.config import Config +from aepsych.database import db +from aepsych.database.tables import DBMasterTable from aepsych.server.message_handlers import MESSAGE_MAP -from aepsych.server.message_handlers.handle_ask import ask from aepsych.server.message_handlers.handle_setup import configure from aepsych.server.replay import ( get_dataframe_from_replay, @@ -30,12 +31,9 @@ get_strats_from_replay, replay, ) -from aepsych.server.sockets import BAD_REQUEST, DummySocket, PySocket -from aepsych.utils import promote_0d +from aepsych.strategy import SequentialStrategy, Strategy logger = utils_logging.getLogger(logging.INFO) -DEFAULT_DESC = "default description" -DEFAULT_NAME = "default name" def get_next_filename(folder, fname, ext): @@ -44,191 +42,84 @@ def get_next_filename(folder, fname, ext): return f"{folder}/{fname}_{n + 1}.{ext}" -class AEPsychServer(object): - def __init__(self, socket=None, database_path=None): - """Server for doing black box optimization using gaussian processes. 
- Keyword Arguments: - socket -- socket object that implements `send` and `receive` for json - messages (default: DummySocket()). - TODO actually make an abstract interface to subclass from here - """ - if socket is None: - self.socket = DummySocket() - else: - self.socket = socket - self.db = None +class AEPsychServer: + def __init__( + self, + host: str = "0.0.0.0", + port: int = 5555, + database_path: str = "./databases/default.db", + max_workers: Optional[int] = None, + ): + self.host = host + self.port = port + self.max_workers = max_workers + self.db: db.Database = db.Database(database_path) self.is_performing_replay = False self.exit_server_loop = False self._db_raw_record = None - self.db: db.Database = db.Database(database_path) self.skip_computations = False self.strat_names = None self.extensions = None + self._strats: List[SequentialStrategy] = [] + self._parnames: List[List[str]] = [] + self._configs: List[Config] = [] + self._master_records: List[DBMasterTable] = [] + self.strat_id = -1 + self.outcome_names: List[str] = [] if self.db.is_update_required(): self.db.perform_updates() - self._strats = [] - self._parnames = [] - self._configs = [] - self._master_records = [] - self.strat_id = -1 - self._pregen_asks = [] - self.enable_pregen = False - self.outcome_names = [] - - self.debug = False - self.receive_thread = threading.Thread( - target=self._receive_send, args=(self.exit_server_loop,), daemon=True - ) - - self.queue = [] - - def cleanup(self): - """Close the socket and terminate connection to the server. - - Returns: - None - """ - self.socket.close() - - def _receive_send(self, is_exiting: bool) -> None: - """Receive messages from the client. - - Args: - is_exiting (bool): True to terminate reception of new messages from the client, False otherwise. 
- - Returns: - None - """ - while True: - request = self.socket.receive(is_exiting) - if request != BAD_REQUEST: - self.queue.append(request) - if self.exit_server_loop: - break - logger.info("Terminated input thread") - - def _handle_queue(self) -> None: - """Handles the queue of messages received by the server. - - Returns: - None - """ - if self.queue: - request = self.queue.pop(0) - try: - result = self.handle_request(request) - except Exception as e: - error_message = f"Request '{request}' raised error '{e}'!" - result = f"server_error, {error_message}" - logger.error(f"{error_message}! Full traceback follows:") - logger.error(traceback.format_exc()) - self.socket.send(result) - else: - if self.can_pregen_ask and (len(self._pregen_asks) == 0): - self._pregen_asks.append(ask(self)) - - def serve(self) -> None: - """Run the server. Note that all configuration outside of socket type and port - happens via messages from the client. The server simply forwards messages from - the client to its `setup`, `ask` and `tell` methods, and responds with either - acknowledgment or other response as needed. To understand the server API, see - the docs on the methods in this class. - - Returns: - None - - Raises: - RuntimeError: if a request from a client has no request type - RuntimeError: if a request from a client has no known request type - TODO make things a little more robust to bad messages from client; this - requires resetting the req/rep queue status. 
- - """ - logger.info("Server up, waiting for connections!") - logger.info("Ctrl-C to quit!") - # yeah we're not sanitizing input at all - - # Start the method to accept a client connection - self.socket.accept_client() - self.receive_thread.start() - while True: - self._handle_queue() - if self.exit_server_loop: - break - # Close the socket and terminate with code 0 - self.cleanup() - sys.exit(0) - - def _unpack_strat_buffer(self, strat_buffer): - if isinstance(strat_buffer, io.BytesIO): - strat = torch.load(strat_buffer, pickle_module=dill) - strat_buffer.seek(0) - elif isinstance(strat_buffer, bytes): - warnings.warn( - "Strat buffer is not in bytes format!" - + " This is a deprecated format, loading using dill.loads.", - DeprecationWarning, - ) - strat = dill.loads(strat_buffer) - else: - raise RuntimeError("Trying to load strat in unknown format!") - return strat - - ### Properties that are set on a per-strat basis + #### Properties #### @property - def strat(self): + def strat(self) -> Optional[SequentialStrategy]: if self.strat_id == -1: return None else: return self._strats[self.strat_id] @strat.setter - def strat(self, s): + def strat(self, s: SequentialStrategy): self._strats.append(s) @property - def config(self): + def config(self) -> Optional[Config]: if self.strat_id == -1: return None else: return self._configs[self.strat_id] @config.setter - def config(self, s): + def config(self, s: Config): self._configs.append(s) @property - def parnames(self): + def parnames(self) -> List[str]: if self.strat_id == -1: return [] else: return self._parnames[self.strat_id] @parnames.setter - def parnames(self, s): + def parnames(self, s: List[str]): self._parnames.append(s) @property - def _db_master_record(self): + def _db_master_record(self) -> Optional[DBMasterTable]: if self.strat_id == -1: return None else: return self._master_records[self.strat_id] @_db_master_record.setter - def _db_master_record(self, s): + def _db_master_record(self, s: DBMasterTable): 
self._master_records.append(s) @property - def n_strats(self): + def n_strats(self) -> int: return len(self._strats) - @property - def can_pregen_ask(self): - return self.strat is not None and self.enable_pregen - + #### Methods to handle parameter configs #### def _tensor_to_config(self, next_x): stim_per_trial = self.strat.stimuli_per_trial dim = self.strat.dim @@ -280,8 +171,11 @@ def _config_to_tensor(self, config): return x - def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): + def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]) -> Dict[int, Any]: # Given a dictionary of fixed parameters, turn the parameters names into indices + if self.strat is None: + raise ValueError("No strategy is set, cannot convert fixed parameters.") + dummy = np.zeros(len(self.parnames)).astype("O") for key, value in fixed.items(): idx = self.parnames.index(key) @@ -297,14 +191,211 @@ def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): return fixed_features - def __getstate__(self): - # nuke the socket since it's not pickleble - state = self.__dict__.copy() - del state["socket"] - del state["db"] - return state + #### Methods to handle replay #### + def replay(self, uuid_to_replay: int, skip_computations: bool = False) -> None: + """Replay an experiment with a specific unique ID. This will leave the + server state at the end of the replay. + + Args: + uuid_to_replay (int): Unique ID of the experiment to replay. This is + the primary key of the experiment's master table. + skip_computations (bool): If True, skip computations during the replay. + Defaults to False. + """ + return replay(self, uuid_to_replay, skip_computations) + + def get_strats_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> List[Strategy]: + """Replay an experiment then return the strategies from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. 
If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + List[Union[SequentialStrategy, Strategy]]: List of strategies from + the replay. + """ + return get_strats_from_replay(self, uuid_of_replay, force_replay) + + def get_strat_from_replay( + self, uuid_of_replay: Optional[int] = None, strat_id: int = -1 + ) -> Strategy: + """Replay an experiment then return a strategy from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + strat_id (int): ID of the strategy to return. Defaults to -1, which + returns the last strategy. + + Returns: + Strategy: The strategy from the replay. + """ + return get_strat_from_replay(self, uuid_of_replay, strat_id) + + def get_dataframe_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> pd.DataFrame: + """Replay an experiment then return the dataframe from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + pd.DataFrame: Dataframe from the replay. + """ + return get_dataframe_from_replay(self, uuid_of_replay, force_replay) + + def _unpack_strat_buffer(self, strat_buffer): + # Unpacks a strategy buffer from the database. + if isinstance(strat_buffer, io.BytesIO): + strat = torch.load(strat_buffer, pickle_module=dill) + strat_buffer.seek(0) + elif isinstance(strat_buffer, bytes): + warnings.warn( + "Strat buffer is not in bytes format!" 
+ + " This is a deprecated format, loading using dill.loads.", + DeprecationWarning, + ) + strat = dill.loads(strat_buffer) + else: + raise RuntimeError("Trying to load strat in unknown format!") + return strat - def write_strats(self, termination_type): + #### Method to handle async server #### + def start_blocking(self) -> None: + """Starts the server in a blocking state in the main thread. Used by the + command line interface to start the server for a client in another + process or machine.""" + asyncio.run(self.serve()) + + def start_background(self): + """Starts the server in a background thread. Used for scripts where the + client and server are in the same process.""" + raise NotImplementedError + + async def serve(self) -> None: + """Serves the server on the set IP and port. This creates a coroutine + for asyncio to handle requests asyncronously. + """ + self.server = await asyncio.start_server( + self.handle_client, self.host, self.port + ) + self.loop = asyncio.get_running_loop() + pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) + self.loop.set_default_executor(pool) + + async with self.server: + logging.info(f"Serving on {self.host}:{self.port}") + try: + await self.server.serve_forever() + except asyncio.CancelledError: + raise + except KeyboardInterrupt: + exception_type = "CTRL+C" + dump_type = "dump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + except RuntimeError as e: + exception_type = "RuntimeError" + dump_type = "crashdump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + raise RuntimeError(e) + + async def handle_client(self, reader, writer): + """Coroutine for handling a client connection. This will read messages + from the connected client and dispatch a task to handle the request on + another thread such that its blocking state does not block the server. + This coroutine will end if the client closes the connection. 
+ + Args: + reader: asyncio.StreamReader: The stream reader for the client. + writer: asyncio.StreamWriter: The stream writer for the client. + """ + addr = writer.get_extra_info("peername") + logger.info(f"Connected to {addr}") + + try: + while True: + if self.exit_server_loop: + self.server.close() + break + rcv = await reader.read(1024 * 512) + try: + message = json.loads(rcv) + except UnicodeDecodeError as e: + logger.error(f"Malformed message: {rcv}") + logger.error(traceback.format_exc()) + result = {"error": str(e)} + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + continue + + future = self.loop.run_in_executor(None, self.handle_request, message) + try: + result = await future + except Exception as e: + logger.error(f"Error handling message: {message}") + logger.error(traceback.format_exc()) + # Some exceptions turned into string are meaningless, so we use repr + result = {"error": e.__repr__()} + if isinstance(result, dict): + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + else: + writer.write(str(result).encode()) + + await writer.drain() + except asyncio.CancelledError: + pass + finally: + logger.info(f"Connection closed for {addr}") + writer.close() + await writer.wait_closed() + + def handle_request(self, message: Dict[str, Any]) -> Union[Dict[str, Any], str]: + """Given a message, dispatch the correct handler and return the result. + + Args: + message (Dict[str, Any]): The message to handle. + + Returns: + Union[Dict[str, Any], str]: The result of handling the message. 
+ """ + type_ = message["type"] + result = MESSAGE_MAP[type_](self, message) + return result + + def _simplify_arrays(self, message): + # Simplify arrays for encoding and sending a message to the client + return { + k: ( + v.tolist() + if type(v) == np.ndarray + else self._simplify_arrays(v) + if type(v) is dict + else v + ) + for k, v in message.items() + } + + #### Methods to handle exiting #### + def write_strats(self, termination_type: str) -> None: + """Pickle the stats and records them into the database. + + Args: + termination_type (str): The type of termination. This only affects + the log message. + """ if self._db_master_record is not None and self.strat is not None: logger.info(f"Dumping strats to DB due to {termination_type}.") for strat in self._strats: @@ -313,77 +404,28 @@ def write_strats(self, termination_type): buffer.seek(0) self.db.record_strat(master_table=self._db_master_record, strat=buffer) - def generate_debug_info(self, exception_type, dumptype): + def generate_debug_info(self, exception_type: str, dumptype: str) -> None: + """Generate a debug info file for the server. This will pickle the server + and save it to a file. + + Args: + exception_type (str): The type of exception that caused the server + to terminate. This only affects the log message. + dump_type (str): The type of dump. This only affects the log file. + """ fname = get_next_filename(".", dumptype, "pkl") logger.exception(f"Got {exception_type}, exiting! Server dump in {fname}") dill.dump(self, open(fname, "wb")) - def handle_request(self, request): - if "type" not in request.keys(): - raise RuntimeError(f"Request {request} contains no request type!") - else: - type = request["type"] - if type in MESSAGE_MAP.keys(): - logger.info(f"Received msg [{type}]") - ret_val = MESSAGE_MAP[type](self, request) - return ret_val - - else: - exception_message = ( - f"unknown type: {type}. 
Allowed types [{MESSAGE_MAP.keys()}]" - ) - - raise RuntimeError(exception_message) - - def replay(self, uuid_to_replay, skip_computations=False): - return replay(self, uuid_to_replay, skip_computations) - - def get_strats_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_strats_from_replay(self, uuid_of_replay, force_replay) - - def get_strat_from_replay(self, uuid_of_replay=None, strat_id=-1): - return get_strat_from_replay(self, uuid_of_replay, strat_id) - - def get_dataframe_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_dataframe_from_replay(self, uuid_of_replay, force_replay) - - -#! THIS IS WHAT START THE SERVER -def startServerAndRun( - server_class, socket=None, database_path=None, config_path=None, id_of_replay=None -): - server = server_class(socket=socket, database_path=database_path) - try: - if config_path is not None: - with open(config_path) as f: - config_str = f.read() - configure(server, config_str=config_str) - - if socket is not None: - if id_of_replay is not None: - server.replay(id_of_replay, skip_computations=True) - server.serve() - else: - if config_path is not None: - logger.info( - "You have passed in a config path but this is a replay. If there's a config in the database it will be used instead of the passed in config path." - ) - server.replay(id_of_replay) - except KeyboardInterrupt: - exception_type = "CTRL+C" - dump_type = "dump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - except RuntimeError as e: - exception_type = "RuntimeError" - dump_type = "crashdump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - raise RuntimeError(e) + def __getstate__(self): + # Called when the server is pickled, we can't pickle the DB. 
+ state = self.__dict__.copy() + del state["db"] + return state def parse_argument(): - parser = argparse.ArgumentParser(description="AEPsych Server!") + parser = argparse.ArgumentParser(description="AEPsych Server") parser.add_argument( "--port", metavar="N", type=int, default=5555, help="port to serve on" ) @@ -415,72 +457,54 @@ def parse_argument(): "--db", type=str, help="The database to use if not the default (./databases/default.db).", - default=None, - ) - - parser.add_argument( - "-r", "--replay", type=str, help="Unique id of the experiment to replay." + default="./databases/default.db", ) parser.add_argument( - "-m", "--resume", action="store_true", help="Resume server after replay." + "-r", + "--resume", + type=str, + help="Unique id of the experiment to replay and resume the server from.", ) args = parser.parse_args() return args -def start_server(server_class, args): - logger.info("Starting the AEPsychServer") +def main(): + logger = utils_logging.getLogger() + logger.info("Starting AEPsychServer") logger.info(f"AEPsych Version: {version.__version__}") - try: - if "db" in args and args.db is not None: - database_path = args.db - if "replay" in args and args.replay is not None: - logger.info(f"Attempting to replay {args.replay}") - if args.resume is True: - sock = PySocket(port=args.port) - logger.info(f"Will resume {args.replay}") - else: - sock = None - startServerAndRun( - server_class, - socket=sock, - database_path=database_path, - uuid_of_replay=args.replay, - config_path=args.stratconfig, - ) - else: - logger.info(f"Setting the database path {database_path}") - sock = PySocket(port=args.port) - startServerAndRun( - server_class, - database_path=database_path, - socket=sock, - config_path=args.stratconfig, - ) - else: - sock = PySocket(port=args.port) - startServerAndRun(server_class, socket=sock, config_path=args.stratconfig) - except (KeyboardInterrupt, SystemExit): - logger.exception("Got Ctrl+C, exiting!") - sys.exit() - except RuntimeError as 
e: - fname = get_next_filename(".", "dump", "pkl") - logger.exception(f"CRASHING!! dump in {fname}") - raise RuntimeError(e) - - -def main(server_class=AEPsychServer): args = parse_argument() if args.logs: # overide logger path log_path = args.logs - logger = utils_logging.getLogger(logging.DEBUG, log_path) - logger.info(f"Saving logs to path: {log_path}") - start_server(server_class, args) + logger = utils_logging.getLogger(log_path) + logger.info(f"Saving logs to path: {log_path}") + + server = AEPsychServer( + host=args.ip, + port=args.port, + database_path=args.db, + ) + + if args.stratconfig is not None and args.resume is not None: + raise ValueError( + "Cannot configure the server with a config file and a resume from a replay at the same time." + ) + + elif args.stratconfig is not None: + configure(server, config_str=args.stratconfig) + + elif args.resume is not None: + if args.db is None: + raise ValueError("Cannot resume from a replay if no database is given.") + server.replay(args.resume, skip_computations=True) + + # Starts the server in a blocking state + server.start_blocking() if __name__ == "__main__": - main(AEPsychServer) + main() diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 494e5d72d..96425b481 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -34,8 +34,6 @@ def setUp(self): ) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py index 9b2b6c138..359c040f3 100644 --- a/tests/models/test_pairwise_probit.py +++ b/tests/models/test_pairwise_probit.py @@ -501,8 +501,6 @@ def setUp(self): self.s = server.AEPsychServer(database_path=database_path) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/server/message_handlers/test_ask_handlers.py 
b/tests/server/message_handlers/test_ask_handlers.py index 9d3fa5c6b..30d773f90 100644 --- a/tests/server/message_handlers/test_ask_handlers.py +++ b/tests/server/message_handlers/test_ask_handlers.py @@ -8,7 +8,7 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase dummy_config = """ [common] @@ -69,7 +69,7 @@ """ -class AskHandlerTestCase(BaseServerTestCase): +class AskHandlerTestCase(AsyncServerTestBase): def test_handle_ask(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_can_model.py b/tests/server/message_handlers/test_can_model.py index 01b3c4b8a..6f5b090e0 100644 --- a/tests/server/message_handlers/test_can_model.py +++ b/tests/server/message_handlers/test_can_model.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class StratCanModelTestCase(BaseServerTestCase): +class StratCanModelTestCase(AsyncServerTestBase): def test_strat_can_model(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_exit.py b/tests/server/message_handlers/test_handle_exit.py index 5b548bc25..bd8ff89b7 100644 --- a/tests/server/message_handlers/test_handle_exit.py +++ b/tests/server/message_handlers/test_handle_exit.py @@ -5,24 +5,30 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import asyncio import unittest -from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): - def test_handle_exit(self): +class HandleExitTestCase(AsyncServerTestBase): + async def test_handle_exit(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + + await self.mock_client(setup_request) + request = {} request["type"] = "exit" - self.s.socket.accept_client = MagicMock() - self.s.socket.receive = MagicMock(return_value=request) - self.s.dump = MagicMock() + await self.mock_client(request) - with self.assertRaises(SystemExit) as cm: - self.s.serve() + with self.assertRaises(ConnectionRefusedError): + await asyncio.open_connection(self.s.host, self.s.port) - self.assertEqual(cm.exception.code, 0) + self.assertTrue(self.s.exit_server_loop) if __name__ == "__main__": diff --git a/tests/server/message_handlers/test_handle_finish_strategy.py b/tests/server/message_handlers/test_handle_finish_strategy.py index 9efffdb20..729421bdb 100644 --- a/tests/server/message_handlers/test_handle_finish_strategy.py +++ b/tests/server/message_handlers/test_handle_finish_strategy.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class ResumeTestCase(BaseServerTestCase): +class ResumeTestCase(AsyncServerTestBase): def test_handle_finish_strategy(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_get_config.py b/tests/server/message_handlers/test_handle_get_config.py index d79c0697f..b173d22e5 100644 --- a/tests/server/message_handlers/test_handle_get_config.py +++ b/tests/server/message_handlers/test_handle_get_config.py @@ -9,10 +9,10 @@ from aepsych.config import Config -from ..test_server import BaseServerTestCase, dummy_config +from 
..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): +class HandleExitTestCase(AsyncServerTestBase): def test_get_config(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_query_handlers.py b/tests/server/message_handlers/test_query_handlers.py index 2f8eaff2a..3ff9f618e 100644 --- a/tests/server/message_handlers/test_query_handlers.py +++ b/tests/server/message_handlers/test_query_handlers.py @@ -7,12 +7,12 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase # Smoke test to make sure nothing breaks. This should really be combined with # the individual query tests -class QueryHandlerTestCase(BaseServerTestCase): +class QueryHandlerTestCase(AsyncServerTestBase): def test_strat_query(self): # Annoying and complex model and output shapes config_str = """ diff --git a/tests/server/message_handlers/test_tell_handlers.py b/tests/server/message_handlers/test_tell_handlers.py index 4128b4ed6..7f68e84f5 100644 --- a/tests/server/message_handlers/test_tell_handlers.py +++ b/tests/server/message_handlers/test_tell_handlers.py @@ -9,10 +9,10 @@ import unittest from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class MessageHandlerTellTests(BaseServerTestCase): +class MessageHandlerTellTests(AsyncServerTestBase): def test_tell(self): setup_request = { "type": "setup", diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 57d59da50..6c5617896 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -5,14 +5,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import asyncio import json import logging -import select import time import unittest import uuid from pathlib import Path -from unittest.mock import MagicMock +from typing import Any, Dict import aepsych.server as server import aepsych.utils_logging as utils_logging @@ -77,28 +77,52 @@ """ -class BaseServerTestCase(unittest.TestCase): - # so that this can be overridden for tests that require specific databases. +class AsyncServerTestBase(unittest.IsolatedAsyncioTestCase): @property def database_path(self): return "./{}_test_server.db".format(str(uuid.uuid4().hex)) - def setUp(self): + async def asyncSetUp(self): + self.ip = "127.0.0.1" + self.port = 5555 + # setup logger server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random port - socket = server.sockets.PySocket(port=0) + # random datebase path name without dashes database_path = self.database_path - self.s = server.AEPsychServer(socket=socket, database_path=database_path) + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) self.db_name = database_path.split("/")[1] self.db_path = database_path - def tearDown(self): - self.s.cleanup() + try: + self.server_task = asyncio.create_task(self.s.serve()) + except OSError: + # Try 0.0.0.0 after waiting + time.sleep(5) + self.ip = "0.0.0.0" + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) + self.server_task = asyncio.create_task(self.s.serve()) + await asyncio.sleep(0.1) + + self.reader, self.writer = await asyncio.open_connection(self.ip, self.port) + + async def asyncTearDown(self): + # Stops the client + self.writer.close() - # sleep to ensure db is closed - time.sleep(0.2) + # Stops the server + self.server_task.cancel() + try: + await self.server_task + except asyncio.CancelledError: + pass + + await asyncio.sleep(0.2) # cleanup the db if self.s.db is not None: @@ -107,46 +131,18 @@ def tearDown(self): except PermissionError as e: print("Failed to deleted 
database: ", e) - def dummy_create_setup(self, server, request=None): - request = request or {"test": "test request"} - server._db_master_record = server.db.record_setup( - description="default description", name="default name", request=request - ) + async def mock_client(self, request: Dict[str, Any]) -> Any: + self.writer.write(json.dumps(request).encode()) + await self.writer.drain() + response = await self.reader.read(1024 * 512) + return response.decode() -class ServerTestCase(BaseServerTestCase): - def test_final_strat_serialization(self): - setup_request = { - "type": "setup", - "version": "0.01", - "message": {"config_str": dummy_config}, - } - ask_request = {"type": "ask", "message": ""} - tell_request = { - "type": "tell", - "message": {"config": {"x": [0.5]}, "outcome": 1}, - } - self.s.handle_request(setup_request) - while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) - unique_id = self.s.db.get_master_records()[-1].unique_id - stored_strat = self.s.get_strat_from_replay(unique_id) - # just some spot checks that the strat's the same - # same data. 
We do this twice to make sure buffers are - # in a good state and we can load twice without crashing - for _ in range(2): - stored_strat = self.s.get_strat_from_replay(unique_id) - self.assertTrue((stored_strat.x == self.s.strat.x).all()) - self.assertTrue((stored_strat.y == self.s.strat.y).all()) - # same lengthscale and outputscale - self.assertEqual( - stored_strat.model.covar_module.lengthscale, - self.s.strat.model.covar_module.lengthscale, - ) +class AsyncServerTestCase(AsyncServerTestBase): + """Server functions are all async""" - def test_pandadf_dump_single(self): + async def test_pandadf_dump_single(self): setup_request = { "type": "setup", "version": "0.01", @@ -158,20 +154,22 @@ def test_pandadf_dump_single(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + + await self.mock_client(setup_request) + expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -183,7 +181,38 @@ def test_pandadf_dump_single(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_multistrat(self): + async def test_final_strat_serialization(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": 
{"config": {"x": [0.5]}, "outcome": 1}, + } + await self.mock_client(setup_request) + while not self.s.strat.finished: + await self.mock_client(ask_request) + await self.mock_client(tell_request) + + unique_id = self.s.db.get_master_records()[-1].unique_id + stored_strat = self.s.get_strat_from_replay(unique_id) + # just some spot checks that the strat's the same + # same data. We do this twice to make sure buffers are + # in a good state and we can load twice without crashing + for _ in range(2): + stored_strat = self.s.get_strat_from_replay(unique_id) + self.assertTrue((stored_strat.x == self.s.strat.x).all()) + self.assertTrue((stored_strat.y == self.s.strat.y).all()) + # same lengthscale and outputscale + self.assertEqual( + stored_strat.model.covar_module.lengthscale, + self.s.strat.model.covar_module.lengthscale, + ) + + async def test_pandadf_dump_multistrat(self): setup_request = { "type": "setup", "version": "0.01", @@ -199,16 +228,16 @@ def test_pandadf_dump_multistrat(self): expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -221,7 +250,7 @@ def test_pandadf_dump_multistrat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_flat(self): + async def test_pandadf_dump_flat(self): """ This test handles the case where the config values 
are flat scalars and not lists @@ -237,20 +266,20 @@ def test_pandadf_dump_flat(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = expected_x[i] tell_request["message"]["config"]["z"] = expected_z[i] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -262,52 +291,7 @@ def test_pandadf_dump_flat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_receive(self): - """test_receive - verifies the receive is working when server receives unexpected messages""" - - message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message - message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message - message3 = {"message": {"target": "test request"}} # valid message - message_list = [message1, message2, json.dumps(message3)] - - self.s.socket.conn = MagicMock() - - for i, message in enumerate(message_list): - select.select = MagicMock(return_value=[[self.s.socket.conn], [], []]) - self.s.socket.conn.recv = MagicMock(return_value=message) - if i != 2: - self.assertEqual(self.s.socket.receive(False), BAD_REQUEST) - else: - self.assertEqual(self.s.socket.receive(False), message3) - - def test_error_handling(self): - # double brace escapes, single brace to substitute, so we end up with 3 braces - request = f"{{{BAD_REQUEST}}}" - - 
expected_error = f"server_error, Request '{request}' raised error ''str' object has no attribute 'keys''!" - - self.s.socket.accept_client = MagicMock() - - self.s.socket.receive = MagicMock(return_value=request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - self.s.socket.send.assert_called_once_with(expected_error) - - def test_queue(self): - """Test to see that the queue is being handled correctly""" - - self.s.socket.accept_client = MagicMock() - ask_request = {"type": "ask", "message": ""} - self.s.socket.receive = MagicMock(return_value=ask_request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - assert len(self.s.queue) == 0 - - def test_replay(self): + async def test_replay(self): exp_config = """ [common] lb = [0] @@ -341,15 +325,14 @@ def test_replay(self): } exit_request = {"message": "", "type": "exit"} - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) - self.s.handle_request(exit_request) + await self.mock_client(exit_request) - socket = server.sockets.PySocket(port=0) - serv = server.AEPsychServer(socket=socket, database_path=self.db_path) + serv = server.AEPsychServer(database_path=self.db_path) exp_ids = [rec.unique_id for rec in serv.db.get_master_records()] serv.replay(exp_ids[-1], skip_computations=True) @@ -359,7 +342,7 @@ def test_replay(self): self.assertTrue(strat.finished) self.assertTrue(strat.x.shape[0] == 4) - def test_string_parameter(self): + async def test_string_parameter(self): string_config = """ [common] parnames = [x, y, z] @@ -405,16 +388,17 @@ def test_string_parameter(self): "type": "tell", "message": {"config": {"x": [0.5], "y": ["blue"], "z": [50]}, "outcome": 1}, 
} - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - response = self.s.handle_request(ask_request) + response = await self.mock_client(ask_request) + response = json.loads(response) self.assertTrue(response["config"]["y"][0] == "blue") - self.s.handle_request(tell_request) + await self.mock_client(tell_request) self.assertTrue(len(self.s.strat.lb) == 2) self.assertTrue(len(self.s.strat.ub) == 2) - def test_metadata(self): + async def test_metadata(self): setup_request = { "type": "setup", "version": "0.01", @@ -425,10 +409,10 @@ def test_metadata(self): "type": "tell", "message": {"config": {"x": [0.5]}, "outcome": 1}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) master_record = self.s.db.get_master_records()[-1] extra_metadata = json.loads(master_record.extra_metadata) @@ -443,7 +427,7 @@ def test_metadata(self): self.assertTrue(extra_metadata["extra"] == "data that is arbitrary") self.assertTrue("experiment_id" not in extra_metadata) - def test_extension_server(self): + async def test_extension_server(self): extension_path = Path(__file__).parent.parent.parent extension_path = extension_path / "extensions_example" / "new_objects.py" @@ -470,8 +454,8 @@ def test_extension_server(self): "message": {"config_str": config_str}, } - with self.assertLogs(level=logging.INFO) as logs: - self.s.handle_request(setup_request) + with self.assertLogs() as logs: + await self.mock_client(setup_request) outputs = ";".join(logs.output) self.assertTrue(str(extension_path) in outputs) @@ -481,6 +465,30 @@ def test_extension_server(self): self.assertTrue(one == 1) self.assertTrue(strat.generator._base_obj.__class__.__name__ == "OnesGenerator") + async def test_receive(self): + """test_receive - verifies 
the receive is working when server receives unexpected messages""" + + message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message + message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message + message3 = {"message": {"target": "test request"}} # valid message + message_list = [message1, message2, message3] + + for i, message in enumerate(message_list): + if isinstance(message, dict): + send = json.dumps(message).encode() + else: + send = message + self.writer.write(send) + await self.writer.drain() + + response = await self.reader.read(1024 * 512) + response = response.decode() + response = json.loads(response) + if i != 2: + self.assertTrue("error" in response) # Very generic error for malformed + else: + self.assertTrue("KeyError" in response["error"]) # Specific error + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datafetcher.py b/tests/test_datafetcher.py index ffe3f1980..62761ffa6 100644 --- a/tests/test_datafetcher.py +++ b/tests/test_datafetcher.py @@ -98,9 +98,6 @@ def setUp(self): # setup logger server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random port - socket = server.sockets.PySocket(port=0) - database_path = Path(__file__).parent / "test_databases" / "1000_outcome.db" dst_db_path = Path("./{}.db".format(str(uuid.uuid4().hex))) @@ -109,7 +106,7 @@ def setUp(self): time.sleep(0.1) self.assertTrue(dst_db_path.is_file()) - self.s = server.AEPsychServer(socket=socket, database_path=dst_db_path) + self.s = server.AEPsychServer(database_path=dst_db_path) setup_message = { "type": "setup", @@ -125,8 +122,6 @@ def setUp(self): def tearDown(self): time.sleep(0.1) - - self.s.cleanup() self.s.db.delete_db() def test_create_from_config(self): diff --git a/tests/test_db.py b/tests/test_db.py index ec44dca58..0dbb98d1e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -31,11 +31,6 @@ def tearDown(self): time.sleep(0.1) self._database.delete_db() - def 
test_db_create(self): - engine = self._database.get_engine() - self.assertIsNotNone(engine) - self.assertIsNotNone(self._database._engine) - def test_record_setup_basic(self): master_table = self._database.record_setup( description="test description", From da8251813f8497392a4076fd318a361543c50a1d Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 2/8] Add support for async db operations Summary: Use scoped sessions that is shared between threads to allow async db operations. Test Plan: Tests should pass --- aepsych/database/db.py | 223 +++++++++++++++++-------------------- aepsych/database/tables.py | 87 ++++++++++----- tests/test_db.py | 53 ++++----- 3 files changed, 179 insertions(+), 184 deletions(-) diff --git a/aepsych/database/db.py b/aepsych/database/db.py index 421b859b2..f92fe3246 100644 --- a/aepsych/database/db.py +++ b/aepsych/database/db.py @@ -20,7 +20,7 @@ from aepsych.config import Config from aepsych.strategy import Strategy from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm.session import close_all_sessions logger = logging.getLogger() @@ -46,33 +46,22 @@ def __init__(self, db_path: Optional[str] = None, update: bool = True) -> None: else: logger.info(f"No DB found at {db_path}, creating a new DB!") - self._engine = self.get_engine() + self._full_db_path = Path(self._db_dir) + self._full_db_path.mkdir(parents=True, exist_ok=True) + self._full_db_path = self._full_db_path.joinpath(self._db_name) - if update and self.is_update_required(): - self.perform_updates() - - def get_engine(self) -> sessionmaker: - """Get the engine for the database. - - Returns: - sessionmaker: The sessionmaker object for the database. 
- """ - if not hasattr(self, "_engine") or self._engine is None: - self._full_db_path = Path(self._db_dir) - self._full_db_path.mkdir(parents=True, exist_ok=True) - self._full_db_path = self._full_db_path.joinpath(self._db_name) - - self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}") + self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}") - # create the table metadata and tables - tables.Base.metadata.create_all(self._engine) + # create the table metadata and tables + tables.Base.metadata.create_all(self._engine) - # create an ongoing session to be used. Provides a conduit - # to the db so the instantiated objects work properly. - Session = sessionmaker(bind=self.get_engine()) - self._session = Session() + # Create a session to be start and closed on each use + self.session = scoped_session( + sessionmaker(bind=self._engine, expire_on_commit=False) + ) - return self._engine + if update and self.is_update_required(): + self.perform_updates() def delete_db(self) -> None: """Delete the database.""" @@ -107,21 +96,6 @@ def perform_updates(self) -> None: tables.DbParamTable.update(self._engine) tables.DbOutcomeTable.update(self._engine) - @contextmanager - def session_scope(self): - """Provide a transactional scope around a series of operations.""" - Session = sessionmaker(bind=self.get_engine()) - session = Session() - try: - yield session - session.commit() - except Exception as err: - logger.error(f"db session use failed: {err}") - session.rollback() - raise - finally: - session.close() - # @retry(stop_max_attempt_number=8, wait_exponential_multiplier=1.8) def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]: """Execute an arbitrary query written in sql. @@ -133,7 +107,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]: Returns: List[Any]: The results of the query. 
""" - with self.session_scope() as session: + with self.session() as session: return session.execute(query, vals).all() def get_master_records(self) -> List[tables.DBMasterTable]: @@ -142,7 +116,8 @@ def get_master_records(self) -> List[tables.DBMasterTable]: Returns: List[tables.DBMasterTable]: The list of master records. """ - records = self._session.query(tables.DBMasterTable).all() + with self.session() as session: + records = session.query(tables.DBMasterTable).all() return records def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]: @@ -154,11 +129,12 @@ def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]: Returns: tables.DBMasterTable or None: The master record or None if it doesn't exist. """ - records = ( - self._session.query(tables.DBMasterTable) - .filter(tables.DBMasterTable.unique_id == master_id) - .all() - ) + with self.session() as session: + records = ( + session.query(tables.DBMasterTable) + .filter(tables.DBMasterTable.unique_id == master_id) + .all() + ) if 0 < len(records): return records[0] @@ -260,11 +236,7 @@ def get_params_for(self, master_id: int) -> List[List[tables.DbParamTable]]: raw_record = self.get_raw_for(master_id) if raw_record is not None: - return [ - rec.children_param - for rec in self.get_raw_for(master_id) - if rec is not None - ] + return [raw.children_param for raw in raw_record] return [] @@ -283,14 +255,19 @@ def get_outcomes_for(self, master_id: int) -> List[List[tables.DbParamTable]]: raw_record = self.get_raw_for(master_id) if raw_record is not None: - return [ - rec.children_outcome - for rec in self.get_raw_for(master_id) - if rec is not None - ] + return [raw.children_outcome for raw in raw_record] return [] + @staticmethod + def _add_commit(session, obj): + # Helps guarantee duplicated objects across session can still be written + merged = session.merge(obj) + session.add(merged) + session.commit() + session.refresh(merged) + return merged + def record_setup( 
self, description: str = None, @@ -313,34 +290,36 @@ def record_setup( Returns: str: The experiment id. """ - self.get_engine() - - master_table = tables.DBMasterTable() - master_table.experiment_description = description - master_table.experiment_name = name - master_table.experiment_id = exp_id if exp_id is not None else str(uuid.uuid4()) - master_table.participant_id = ( - par_id if par_id is not None else str(uuid.uuid4()) - ) - master_table.extra_metadata = extra_metadata - self._session.add(master_table) + with self.session() as session: + master_table = tables.DBMasterTable() + master_table.experiment_description = description + master_table.experiment_name = name + master_table.experiment_id = ( + exp_id if exp_id is not None else str(uuid.uuid4()) + ) + master_table.participant_id = ( + par_id if par_id is not None else str(uuid.uuid4()) + ) + master_table.extra_metadata = extra_metadata - logger.debug(f"record_setup = [{master_table}]") + master_table = self._add_commit(session, master_table) - record = tables.DbReplayTable() - record.message_type = "setup" - record.message_contents = request + logger.debug(f"record_setup = [{master_table}]") - if request is not None and "extra_info" in request: - record.extra_info = request["extra_info"] + record = tables.DbReplayTable() + record.message_type = "setup" + record.message_contents = request - record.timestamp = datetime.datetime.now() - record.parent = master_table - logger.debug(f"record_setup = [{record}]") + if request is not None and "extra_info" in request: + record.extra_info = request["extra_info"] - self._session.add(record) - self._session.commit() + record.timestamp = datetime.datetime.now() + record.parent = master_table + logger.debug(f"record_setup = [{record}]") + self._add_commit(session, record) + + master_table # return the master table if it has a link to the list of child rows # tis needs to be passed into all future calls to link properly return master_table @@ -355,19 +334,19 @@ def 
record_message( type (str): The type of the message. request (Dict[str, Any]): The request. """ - # create a linked setup table - record = tables.DbReplayTable() - record.message_type = type - record.message_contents = request + with self.session() as session: + # create a linked setup table + record = tables.DbReplayTable() + record.message_type = type + record.message_contents = request - if "extra_info" in request: - record.extra_info = request["extra_info"] + if "extra_info" in request: + record.extra_info = request["extra_info"] - record.timestamp = datetime.datetime.now() - record.parent = master_table + record.timestamp = datetime.datetime.now() + record.parent = master_table - self._session.add(record) - self._session.commit() + self._add_commit(session, record) def record_raw( self, @@ -387,19 +366,19 @@ def record_raw( Returns: tables.DbRawTable: The raw entry. """ - raw_entry = tables.DbRawTable() - raw_entry.model_data = model_data + with self.session() as session: + raw_entry = tables.DbRawTable() + raw_entry.model_data = model_data - if timestamp is None: - raw_entry.timestamp = datetime.datetime.now() - else: - raw_entry.timestamp = timestamp - raw_entry.parent = master_table + if timestamp is None: + raw_entry.timestamp = datetime.datetime.now() + else: + raw_entry.timestamp = timestamp + raw_entry.parent = master_table - raw_entry.extra_data = json.dumps(extra_data) + raw_entry.extra_data = json.dumps(extra_data) - self._session.add(raw_entry) - self._session.commit() + raw_entry = self._add_commit(session, raw_entry) return raw_entry @@ -413,14 +392,14 @@ def record_param( param_name (str): The parameter name. param_value (str): The parameter value. 
""" - param_entry = tables.DbParamTable() - param_entry.param_name = param_name - param_entry.param_value = param_value + with self.session() as session: + param_entry = tables.DbParamTable() + param_entry.param_name = param_name + param_entry.param_value = param_value - param_entry.parent = raw_table + param_entry.parent = raw_table - self._session.add(param_entry) - self._session.commit() + self._add_commit(session, param_entry) def record_outcome( self, raw_table: tables.DbRawTable, outcome_name: str, outcome_value: float @@ -432,14 +411,14 @@ def record_outcome( outcome_name (str): The outcome name. outcome_value (float): The outcome value. """ - outcome_entry = tables.DbOutcomeTable() - outcome_entry.outcome_name = outcome_name - outcome_entry.outcome_value = outcome_value + with self.session() as session: + outcome_entry = tables.DbOutcomeTable() + outcome_entry.outcome_name = outcome_name + outcome_entry.outcome_value = outcome_value - outcome_entry.parent = raw_table + outcome_entry.parent = raw_table - self._session.add(outcome_entry) - self._session.commit() + self._add_commit(session, outcome_entry) def record_strat( self, master_table: tables.DBMasterTable, strat: io.BytesIO @@ -450,13 +429,13 @@ def record_strat( master_table (tables.DBMasterTable): The master table. strat (BytesIO): The strategy in buffer form. """ - strat_entry = tables.DbStratTable() - strat_entry.strat = strat - strat_entry.timestamp = datetime.datetime.now() - strat_entry.parent = master_table + with self.session() as session: + strat_entry = tables.DbStratTable() + strat_entry.strat = strat + strat_entry.timestamp = datetime.datetime.now() + strat_entry.parent = master_table - self._session.add(strat_entry) - self._session.commit() + self._add_commit(session, strat_entry) def record_config(self, master_table: tables.DBMasterTable, config: Config) -> None: """Record a config in the database. 
@@ -465,13 +444,13 @@ def record_config(self, master_table: tables.DBMasterTable, config: Config) -> N master_table (tables.DBMasterTable): The master table. config (Config): The config. """ - config_entry = tables.DbConfigTable() - config_entry.config = config - config_entry.timestamp = datetime.datetime.now() - config_entry.parent = master_table + with self.session() as session: + config_entry = tables.DbConfigTable() + config_entry.config = config + config_entry.timestamp = datetime.datetime.now() + config_entry.parent = master_table - self._session.add(config_entry) - self._session.commit() + self._add_commit(session, config_entry) def summarize_experiments(self) -> pd.DataFrame: """Provides a summary of the experiments contained in the database as a pandas dataframe. diff --git a/aepsych/database/tables.py b/aepsych/database/tables.py index ca0087516..d38038b77 100644 --- a/aepsych/database/tables.py +++ b/aepsych/database/tables.py @@ -49,10 +49,18 @@ class DBMasterTable(Base): extra_metadata = Column(String(4096)) # JSON-formatted metadata - children_replay = relationship("DbReplayTable", back_populates="parent") - children_strat = relationship("DbStratTable", back_populates="parent") - children_config = relationship("DbConfigTable", back_populates="parent") - children_raw = relationship("DbRawTable", back_populates="parent") + children_replay = relationship( + "DbReplayTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_strat = relationship( + "DbStratTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_config = relationship( + "DbConfigTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_raw = relationship( + "DbRawTable", lazy="selectin", join_depth=1, back_populates="parent" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DBMasterTable": @@ -185,7 +193,9 @@ class DbReplayTable(Base): extra_info = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, 
ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_replay") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_replay" + ) __mapper_args__ = {} @@ -297,7 +307,9 @@ class DbStratTable(Base): strat = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_strat") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_strat" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbStratTable": @@ -356,7 +368,9 @@ class DbConfigTable(Base): config = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_config") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_config" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbConfigTable": @@ -420,9 +434,15 @@ class DbRawTable(Base): extra_data = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_raw") - children_param = relationship("DbParamTable", back_populates="parent") - children_outcome = relationship("DbOutcomeTable", back_populates="parent") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_raw" + ) + children_param = relationship( + "DbParamTable", lazy="joined", join_depth=1, back_populates="parent" + ) + children_outcome = relationship( + "DbOutcomeTable", lazy="joined", join_depth=1, back_populates="parent" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbRawTable": @@ -527,6 +547,9 @@ def update(db: Any, engine: Engine) -> None: param_value=float(param_value), ) + # Refresh the raw + db_raw_record = 
db.get_raw_for(master_table.unique_id)[-1] + if isinstance(outcomes, Iterable) and type(outcomes) != str: for j, outcome_value in enumerate(outcomes): if ( @@ -551,23 +574,25 @@ def update(db: Any, engine: Engine) -> None: outcome_value=float(outcomes), ) else: # Raws are already in, so we just need to update it - for master_table in db.get_master_records(): - unique_id = master_table.unique_id - raws = db.get_raw_for(unique_id) - tells = [ - message - for message in db.get_replay_for(unique_id) - if message.message_type == "tell" - ] - - if len(raws) == len(tells): - for raw, tell in zip(raws, tells): - if tell.extra_info is not None and len(tell.extra_info) > 0: - raw.extra_data = tell.extra_info - else: - logger.warning( - f"Tried to update raw table for experiment unique ID {unique_id}, but the number of tells and raws were not the same." - ) + with db.session() as session: + for master_table in db.get_master_records(): + unique_id = master_table.unique_id + raws = db.get_raw_for(unique_id) + tells = [ + message + for message in db.get_replay_for(unique_id) + if message.message_type == "tell" + ] + + if len(raws) == len(tells): + for raw, tell in zip(raws, tells): + if tell.extra_info is not None and len(tell.extra_info) > 0: + raw.extra_data = tell.extra_info + db._add_commit(session, raw) + else: + logger.warning( + f"Tried to update raw table for experiment unique ID {unique_id}, but the number of tells and raws were not the same." 
+ ) @staticmethod def requires_update(engine: Engine) -> bool: @@ -654,7 +679,9 @@ class DbParamTable(Base): param_value = Column(String(50)) iteration_id = Column(Integer, ForeignKey("raw_data.unique_id")) - parent = relationship("DbRawTable", back_populates="children_param") + parent = relationship( + "DbRawTable", lazy="immediate", join_depth=1, back_populates="children_param" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbParamTable": @@ -720,7 +747,9 @@ class DbOutcomeTable(Base): outcome_value = Column(Float) iteration_id = Column(Integer, ForeignKey("raw_data.unique_id")) - parent = relationship("DbRawTable", back_populates="children_outcome") + parent = relationship( + "DbRawTable", lazy="immediate", join_depth=1, back_populates="children_outcome" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbOutcomeTable": diff --git a/tests/test_db.py b/tests/test_db.py index 0dbb98d1e..e5f2cf2d3 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -110,7 +110,8 @@ def test_update_db(self): name="test name", request={"test": "this is a test request"}, ) - test_database._session.rollback() + with test_database.session() as session: + session.rollback() test_database.perform_updates() # retry adding rows @@ -169,38 +170,32 @@ def test_update_db_with_raw_data_tables(self): outcome_dict_expected[i]["outcome_1"] = outcomes[i - 1][1] # Check that the number of entries in each table is correct - n_iterations = ( - test_database.get_engine() - .execute("SELECT COUNT(*) FROM raw_data") - .fetchone()[0] - ) + n_iterations = test_database.session.execute( + "SELECT COUNT(*) FROM raw_data" + ).fetchone()[0] self.assertEqual(n_iterations, 7) - n_params = ( - test_database.get_engine() - .execute("SELECT COUNT(*) FROM param_data") - .fetchone()[0] - ) + n_params = test_database.session.execute( + "SELECT COUNT(*) FROM param_data" + ).fetchone()[0] self.assertEqual(n_params, 28) - n_outcomes = ( - test_database.get_engine() - .execute("SELECT 
COUNT(*) FROM outcome_data") - .fetchone()[0] - ) + n_outcomes = test_database.session.execute( + "SELECT COUNT(*) FROM outcome_data" + ).fetchone()[0] self.assertEqual(n_outcomes, 14) # Check that the data is correct - param_data = ( - test_database.get_engine().execute("SELECT * FROM param_data").fetchall() - ) + param_data = test_database.session.execute( + "SELECT * FROM param_data" + ).fetchall() param_dict = {x: {} for x in range(1, 8)} for param in param_data: param_dict[param.iteration_id][param.param_name] = float(param.param_value) self.assertEqual(param_dict, param_dict_expected) - outcome_data = ( - test_database.get_engine().execute("SELECT * FROM outcome_data").fetchall() - ) + outcome_data = test_database.session.execute( + "SELECT * FROM outcome_data" + ).fetchall() outcome_dict = {x: {} for x in range(1, 8)} for outcome in outcome_data: outcome_dict[outcome.iteration_id][outcome.outcome_name] = ( @@ -210,13 +205,9 @@ def test_update_db_with_raw_data_tables(self): self.assertEqual(outcome_dict, outcome_dict_expected) # Check if we have the extra_data column - pragma = ( - test_database.get_engine() - .execute( - "SELECT * FROM pragma_table_info('raw_data') WHERE name='extra_data'" - ) - .fetchall() - ) + pragma = test_database.session.execute( + "SELECT * FROM pragma_table_info('raw_data') WHERE name='extra_data'" + ).fetchall() self.assertTrue(len(pragma) == 1) # Make sure that update is no longer required @@ -240,10 +231,6 @@ def test_update_db_with_raw_extra_data(self): # open the new db test_database = db.Database(db_path=dst_db_path.as_posix(), update=False) - replay_tells = [ - row for row in test_database.get_replay_for(1) if row.message_type == "tell" - ] - # Make sure that update is required self.assertTrue(test_database.is_update_required()) From 25787b42d08b29a14ee2bd13767c027ce34405e1 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 3/8] Remove ability to modify logging levels Summary: 
Logging levels were used inconsistently and would cause problems with multithreaded servers. We remove this ability. Test Plan: Tests updated. --- aepsych/benchmark/pathos_benchmark.py | 2 +- aepsych/server/message_handlers/handle_ask.py | 2 +- aepsych/server/message_handlers/handle_can_model.py | 2 +- aepsych/server/message_handlers/handle_exit.py | 2 +- aepsych/server/message_handlers/handle_get_config.py | 2 +- aepsych/server/message_handlers/handle_info.py | 2 +- aepsych/server/message_handlers/handle_params.py | 2 +- aepsych/server/message_handlers/handle_query.py | 2 +- aepsych/server/message_handlers/handle_resume.py | 2 +- aepsych/server/message_handlers/handle_setup.py | 2 +- aepsych/server/message_handlers/handle_tell.py | 2 +- aepsych/server/replay.py | 2 +- aepsych/server/server.py | 2 +- aepsych/server/sockets.py | 2 +- aepsych/server/utils.py | 2 +- aepsych/utils_logging.py | 10 +++++++--- tests/models/test_pairwise_probit.py | 2 +- tests/server/test_server.py | 2 +- tests/test_datafetcher.py | 2 +- 19 files changed, 25 insertions(+), 21 deletions(-) diff --git a/aepsych/benchmark/pathos_benchmark.py b/aepsych/benchmark/pathos_benchmark.py index f5e4a6643..ab8ae90ae 100644 --- a/aepsych/benchmark/pathos_benchmark.py +++ b/aepsych/benchmark/pathos_benchmark.py @@ -25,7 +25,7 @@ ctx._force_start_method("spawn") # fixes problems with CUDA and fork -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() class PathosBenchmark(Benchmark): diff --git a/aepsych/server/message_handlers/handle_ask.py b/aepsych/server/message_handlers/handle_ask.py index 7f1a21389..0255a072e 100644 --- a/aepsych/server/message_handlers/handle_ask.py +++ b/aepsych/server/message_handlers/handle_ask.py @@ -10,7 +10,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_ask(server, request): diff --git a/aepsych/server/message_handlers/handle_can_model.py 
b/aepsych/server/message_handlers/handle_can_model.py index 32fa2fb18..096c30f75 100644 --- a/aepsych/server/message_handlers/handle_can_model.py +++ b/aepsych/server/message_handlers/handle_can_model.py @@ -9,7 +9,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_can_model(server, request): diff --git a/aepsych/server/message_handlers/handle_exit.py b/aepsych/server/message_handlers/handle_exit.py index b654558b3..c73f6ad13 100644 --- a/aepsych/server/message_handlers/handle_exit.py +++ b/aepsych/server/message_handlers/handle_exit.py @@ -9,7 +9,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_exit(server, request): diff --git a/aepsych/server/message_handlers/handle_get_config.py b/aepsych/server/message_handlers/handle_get_config.py index 1a347dbae..0f186860a 100644 --- a/aepsych/server/message_handlers/handle_get_config.py +++ b/aepsych/server/message_handlers/handle_get_config.py @@ -8,7 +8,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_get_config(server, request): diff --git a/aepsych/server/message_handlers/handle_info.py b/aepsych/server/message_handlers/handle_info.py index 910aac720..f4e07ceee 100644 --- a/aepsych/server/message_handlers/handle_info.py +++ b/aepsych/server/message_handlers/handle_info.py @@ -10,7 +10,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_info(server, request: Dict[str, Any]) -> Dict[str, Any]: diff --git a/aepsych/server/message_handlers/handle_params.py b/aepsych/server/message_handlers/handle_params.py index 525ae2daf..1c64b7f3e 100644 --- a/aepsych/server/message_handlers/handle_params.py +++ b/aepsych/server/message_handlers/handle_params.py 
@@ -9,7 +9,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_params(server, request): diff --git a/aepsych/server/message_handlers/handle_query.py b/aepsych/server/message_handlers/handle_query.py index 2ba9f4d83..194779ad0 100644 --- a/aepsych/server/message_handlers/handle_query.py +++ b/aepsych/server/message_handlers/handle_query.py @@ -11,7 +11,7 @@ import numpy as np import torch -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_query(server, request): diff --git a/aepsych/server/message_handlers/handle_resume.py b/aepsych/server/message_handlers/handle_resume.py index 4da5a14ca..315c2b41d 100644 --- a/aepsych/server/message_handlers/handle_resume.py +++ b/aepsych/server/message_handlers/handle_resume.py @@ -9,7 +9,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_resume(server, request): diff --git a/aepsych/server/message_handlers/handle_setup.py b/aepsych/server/message_handlers/handle_setup.py index c7a7df2d7..1dc9bd557 100644 --- a/aepsych/server/message_handlers/handle_setup.py +++ b/aepsych/server/message_handlers/handle_setup.py @@ -14,7 +14,7 @@ from aepsych.strategy import SequentialStrategy from aepsych.version import __version__ -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def _configure(server, config): diff --git a/aepsych/server/message_handlers/handle_tell.py b/aepsych/server/message_handlers/handle_tell.py index 3b22e33fe..972a19484 100644 --- a/aepsych/server/message_handlers/handle_tell.py +++ b/aepsych/server/message_handlers/handle_tell.py @@ -15,7 +15,7 @@ import pandas as pd import torch -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() DEFAULT_DESC = "default description" DEFAULT_NAME = "default name" diff --git 
a/aepsych/server/replay.py b/aepsych/server/replay.py index d338900cc..9fda596fe 100644 --- a/aepsych/server/replay.py +++ b/aepsych/server/replay.py @@ -12,7 +12,7 @@ import pandas as pd from aepsych.server.message_handlers.handle_tell import flatten_tell_record -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def replay(server, uuid_to_replay, skip_computations=False): diff --git a/aepsych/server/server.py b/aepsych/server/server.py index 30ca22335..a1c90cc8c 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -33,7 +33,7 @@ ) from aepsych.strategy import SequentialStrategy, Strategy -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def get_next_filename(folder, fname, ext): diff --git a/aepsych/server/sockets.py b/aepsych/server/sockets.py index 12b3e640d..8aaf79257 100644 --- a/aepsych/server/sockets.py +++ b/aepsych/server/sockets.py @@ -14,7 +14,7 @@ import aepsych.utils_logging as utils_logging import numpy as np -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() BAD_REQUEST = "bad request" diff --git a/aepsych/server/utils.py b/aepsych/server/utils.py index 27e6db063..027ddd30f 100644 --- a/aepsych/server/utils.py +++ b/aepsych/server/utils.py @@ -13,7 +13,7 @@ import aepsych.database.db as db import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def get_next_filename(folder, fname, ext): diff --git a/aepsych/utils_logging.py b/aepsych/utils_logging.py index 9a5aef693..9c0eb37b5 100644 --- a/aepsych/utils_logging.py +++ b/aepsych/utils_logging.py @@ -35,7 +35,7 @@ def format(self, record): return formatter.format(record) -def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: +def getLogger(log_path: str = "logs") -> logging.Logger: """Get a logger with the specified level and log path. 
Args: @@ -53,7 +53,7 @@ def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: "formatters": {"standard": {"()": ColorFormatter}}, "handlers": { "default": { - "level": level, + "level": logging.INFO, "class": "logging.StreamHandler", "formatter": "standard", }, @@ -65,7 +65,11 @@ def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: }, }, "loggers": { - "": {"handlers": ["default", "file"], "level": level, "propagate": False}, + "": { + "handlers": ["default", "file"], + "level": logging.DEBUG, + "propagate": False, + }, }, } diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py index 359c040f3..60e22adef 100644 --- a/tests/models/test_pairwise_probit.py +++ b/tests/models/test_pairwise_probit.py @@ -495,7 +495,7 @@ def test_hyperparam_consistency(self): class PairwiseProbitModelServerTest(unittest.TestCase): def setUp(self): # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + server.logger = utils_logging.getLogger("logs") # random datebase path name without dashes database_path = "./{}.db".format(str(uuid.uuid4().hex)) self.s = server.AEPsychServer(database_path=database_path) diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 6c5617896..e5bb62bb0 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -87,7 +87,7 @@ async def asyncSetUp(self): self.port = 5555 # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + server.logger = utils_logging.getLogger("unittests") # random datebase path name without dashes database_path = self.database_path diff --git a/tests/test_datafetcher.py b/tests/test_datafetcher.py index 62761ffa6..af655ef36 100644 --- a/tests/test_datafetcher.py +++ b/tests/test_datafetcher.py @@ -96,7 +96,7 @@ def pre_seed_config( def setUp(self): # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + server.logger = 
utils_logging.getLogger("logs") database_path = Path(__file__).parent / "test_databases" / "1000_outcome.db" From de0c46024180d723046129dd8871a1f36b0d0362 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 4/8] Move message handlers to info logging level Summary: All message handlers will emit a info logging message to support better live tracking of server functions. Test Plan: Tests should pass --- aepsych/database/db.py | 2 +- aepsych/server/message_handlers/handle_ask.py | 2 +- aepsych/server/message_handlers/handle_can_model.py | 2 +- aepsych/server/message_handlers/handle_info.py | 2 +- aepsych/server/message_handlers/handle_params.py | 2 +- aepsych/server/message_handlers/handle_query.py | 2 +- aepsych/server/message_handlers/handle_resume.py | 2 +- aepsych/server/message_handlers/handle_setup.py | 2 +- aepsych/server/message_handlers/handle_tell.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/aepsych/database/db.py b/aepsych/database/db.py index f92fe3246..9bf16bd05 100644 --- a/aepsych/database/db.py +++ b/aepsych/database/db.py @@ -315,7 +315,7 @@ def record_setup( record.timestamp = datetime.datetime.now() record.parent = master_table - logger.debug(f"record_setup = [{record}]") + logger.debug(f"replay_record_setup = [{record}]") self._add_commit(session, record) diff --git a/aepsych/server/message_handlers/handle_ask.py b/aepsych/server/message_handlers/handle_ask.py index 0255a072e..ae2ec4e80 100644 --- a/aepsych/server/message_handlers/handle_ask.py +++ b/aepsych/server/message_handlers/handle_ask.py @@ -18,7 +18,7 @@ def handle_ask(server, request): "config" -- dictionary with config (keys are strings, values are floats) "is_finished" -- bool, true if the strat is finished """ - logger.debug("got ask message!") + logger.info("got ask message!") if server._pregen_asks: params = server._pregen_asks.pop() else: diff --git a/aepsych/server/message_handlers/handle_can_model.py 
b/aepsych/server/message_handlers/handle_can_model.py index 096c30f75..94e9de6d4 100644 --- a/aepsych/server/message_handlers/handle_can_model.py +++ b/aepsych/server/message_handlers/handle_can_model.py @@ -15,7 +15,7 @@ def handle_can_model(server, request): # Check if the strategy has finished initialization; i.e., # if it has a model and data to fit (strat.can_fit) - logger.debug("got can_model message!") + logger.info("got can_model message!") if not server.is_performing_replay: server.db.record_message( master_table=server._db_master_record, type="can_model", request=request diff --git a/aepsych/server/message_handlers/handle_info.py b/aepsych/server/message_handlers/handle_info.py index f4e07ceee..99251ac82 100644 --- a/aepsych/server/message_handlers/handle_info.py +++ b/aepsych/server/message_handlers/handle_info.py @@ -22,7 +22,7 @@ def handle_info(server, request: Dict[str, Any]) -> Dict[str, Any]: Returns: Dict[str, Any]: Returns dictionary containing the current state of the experiment """ - logger.debug("got info message!") + logger.info("got info message!") ret_val = info(server) diff --git a/aepsych/server/message_handlers/handle_params.py b/aepsych/server/message_handlers/handle_params.py index 1c64b7f3e..ad4b7181d 100644 --- a/aepsych/server/message_handlers/handle_params.py +++ b/aepsych/server/message_handlers/handle_params.py @@ -13,7 +13,7 @@ def handle_params(server, request): - logger.debug("got parameters message!") + logger.info("got parameters message!") if not server.is_performing_replay: server.db.record_message( master_table=server._db_master_record, type="parameters", request=request diff --git a/aepsych/server/message_handlers/handle_query.py b/aepsych/server/message_handlers/handle_query.py index 194779ad0..65263e0b4 100644 --- a/aepsych/server/message_handlers/handle_query.py +++ b/aepsych/server/message_handlers/handle_query.py @@ -15,7 +15,7 @@ def handle_query(server, request): - logger.debug("got query message!") + 
logger.info("got query message!")
     if not server.is_performing_replay:
         server.db.record_message(
             master_table=server._db_master_record, type="query", request=request
diff --git a/aepsych/server/message_handlers/handle_resume.py b/aepsych/server/message_handlers/handle_resume.py
index 315c2b41d..e7196648d 100644
--- a/aepsych/server/message_handlers/handle_resume.py
+++ b/aepsych/server/message_handlers/handle_resume.py
@@ -13,7 +13,7 @@
 
 
 def handle_resume(server, request):
-    logger.debug("got resume message!")
+    logger.info("got resume message!")
     strat_id = int(request["message"]["strat_id"])
     server.strat_id = strat_id
     if not server.is_performing_replay:
diff --git a/aepsych/server/message_handlers/handle_setup.py b/aepsych/server/message_handlers/handle_setup.py
index 1dc9bd557..e8bc8c991 100644
--- a/aepsych/server/message_handlers/handle_setup.py
+++ b/aepsych/server/message_handlers/handle_setup.py
@@ -68,7 +68,7 @@ def configure(server, config=None, **config_args):
 
 
 def handle_setup(server, request):
-    logger.debug("got setup message!")
+    logger.info("got setup message!")
     ### make a temporary config object to derive parameters because server handles config after table
     if (
         "config_str" in request["message"].keys()
diff --git a/aepsych/server/message_handlers/handle_tell.py b/aepsych/server/message_handlers/handle_tell.py
index 972a19484..1903e41e6 100644
--- a/aepsych/server/message_handlers/handle_tell.py
+++ b/aepsych/server/message_handlers/handle_tell.py
@@ -21,7 +21,7 @@
 
 
 def handle_tell(server, request):
-    logger.debug("got tell message!")
+    logger.info("got tell message!")
 
     if not server.is_performing_replay:
         server.db.record_message(

From 9619f2d93285ecc22f904031f90d6060aae01f5e Mon Sep 17 00:00:00 2001
From: Jason Chow
Date: Mon, 10 Feb 2025 01:58:53 -0600
Subject: [PATCH 5/8] Strategy methods act on copies of models if needed

Summary: To support multi client server, strategy methods will now act on copies of models to avoid changing tensor gradients
between two threads. Test Plan: New test --- aepsych/server/server.py | 3 +++ aepsych/strategy/strategy.py | 43 ++++++++++++++++++++++++------------ tests/server/test_server.py | 43 ++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 14 deletions(-) diff --git a/aepsych/server/server.py b/aepsych/server/server.py index a1c90cc8c..aa9c40245 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -53,6 +53,7 @@ def __init__( self.host = host self.port = port self.max_workers = max_workers + self.clients_connected = 0 self.db: db.Database = db.Database(database_path) self.is_performing_replay = False self.exit_server_loop = False @@ -323,6 +324,7 @@ async def handle_client(self, reader, writer): """ addr = writer.get_extra_info("peername") logger.info(f"Connected to {addr}") + self.clients_connected += 1 try: while True: @@ -361,6 +363,7 @@ async def handle_client(self, reader, writer): logger.info(f"Connection closed for {addr}") writer.close() await writer.wait_closed() + self.clients_connected -= 1 def handle_request(self, message: Dict[str, Any]) -> Union[Dict[str, Any], str]: """Given a message, dispatch the correct handler and return the result. diff --git a/aepsych/strategy/strategy.py b/aepsych/strategy/strategy.py index 097a2f44c..182b55e04 100644 --- a/aepsych/strategy/strategy.py +++ b/aepsych/strategy/strategy.py @@ -8,6 +8,7 @@ from __future__ import annotations import warnings +from copy import deepcopy from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy as np @@ -56,6 +57,7 @@ def __init__( name: str = "", run_indefinitely: bool = False, transforms: ChainedInputTransform = ChainedInputTransform(**{}), + copy_model: bool = False, ) -> None: """Initialize the strategy object. @@ -90,6 +92,9 @@ def __init__( should be defined in raw parameter space for initialization. However, if the lb/ub attribute are access from an initialized Strategy object, it will be returned in transformed space. 
+ copy_model (bool): Whether to do any model-related methods on a + copy or the original. Used for multi-client strategies. Defaults + to False. """ self.is_finished = False @@ -160,6 +165,7 @@ def __init__( self.min_total_outcome_occurrences = min_total_outcome_occurrences self.max_asks = max_asks or generator.max_asks self.keep_most_recent = keep_most_recent + self.copy_model = copy_model self.transforms = transforms if self.transforms is not None: @@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor: self.model.to(self.generator_device) # type: ignore self._count = self._count + num_points - points = self.generator.gen(num_points, self.model, **kwargs) + model = deepcopy(self.model) if self.copy_model else self.model + points = self.generator.gen(num_points, model, **kwargs) if original_device is not None: self.model.to(original_device) # type: ignore @@ -295,9 +302,9 @@ def get_max( self.model is not None ), "model is None! Cannot get the max without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_max( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -324,9 +331,9 @@ def get_min( self.model is not None ), "model is None! Cannot get the min without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_min( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -358,9 +365,9 @@ def inv_query( self.model is not None ), "model is None! Cannot get the inv_query without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = inv_query( - model=self.model, + model=model, y=y, bounds=self.bounds, locked_dims=constraints, @@ -385,7 +392,8 @@ def predict( """ assert self.model is not None, "model is None! 
Cannot predict without a model!" self.model.to(self.model_device) - return self.model.predict(x=x, probability_space=probability_space) + model = deepcopy(self.model) if self.copy_model else self.model + return model.predict(x=x, probability_space=probability_space) @ensure_model_is_fresh def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor: @@ -400,7 +408,8 @@ def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor: """ assert self.model is not None, "model is None! Cannot sample without a model!" self.model.to(self.model_device) - return self.model.sample(x, num_samples=num_samples) + model = deepcopy(self.model) if self.copy_model else self.model + return model.sample(x, num_samples=num_samples) def finish(self) -> None: """Finish the strategy.""" @@ -442,7 +451,8 @@ def finished(self) -> bool: assert ( self.model is not None ), "model is None! Cannot predict without a model!" - fmean, _ = self.model.predict(self.eval_grid, probability_space=True) + model = deepcopy(self.model) if self.copy_model else self.model + fmean, _ = model.predict(self.eval_grid, probability_space=True) meets_post_range = bool( ((fmean.max() - fmean.min()) >= self.min_post_range).item() ) @@ -504,9 +514,10 @@ def fit(self) -> None: """Fit the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.fit( # type: ignore + model.fit( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -516,11 +527,12 @@ def fit(self) -> None: ) else: try: - self.model.fit(self.x, self.y) # type: ignore + model.fit(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" 
) + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) @@ -528,9 +540,10 @@ def update(self) -> None: """Update the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.update( # type: ignore + model.update( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -540,11 +553,13 @@ def update(self) -> None: ) else: try: - self.model.update(self.x, self.y) # type: ignore + model.update(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" ) + + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) diff --git a/tests/server/test_server.py b/tests/server/test_server.py index e5bb62bb0..1b7d6bf13 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -46,6 +46,7 @@ generator = OptimizeAcqfGenerator model = GPClassificationModel min_total_outcome_occurrences = 0 +copy_model = True [OptimizeAcqfGenerator] acqf = MCPosteriorVariance @@ -489,6 +490,48 @@ async def test_receive(self): else: self.assertTrue("KeyError" in response["error"]) # Specific error + async def test_multi_client(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + "extra_info": {}, + } + + await self.mock_client(setup_request) + + # Create second client + reader2, writer2 = await asyncio.open_connection(self.ip, self.port) + + async def _mock_client2(request: Dict[str, Any]) -> Any: + writer2.write(json.dumps(request).encode()) + await writer2.drain() + + response = await reader2.read(1024 * 512) + return 
response.decode()
+        for _ in range(2):  # 2 loops should do it as we have 2 clients
+            tasks = [
+                asyncio.create_task(self.mock_client(ask_request)),
+                asyncio.create_task(_mock_client2(ask_request)),
+            ]
+            await asyncio.gather(*tasks)
+
+            tasks = [
+                asyncio.create_task(self.mock_client(tell_request)),
+                asyncio.create_task(_mock_client2(tell_request)),
+            ]
+            await asyncio.gather(*tasks)
+
+        self.assertTrue(self.s.strat.finished)
+        self.assertTrue(self.s.strat.x.numel() == 4)
+        self.assertTrue(self.s.clients_connected == 2)
+
 
 if __name__ == "__main__":
     unittest.main()

From c4446fca7fd6aa6a19313e90a994b0c336e26162 Mon Sep 17 00:00:00 2001
From: Jason Chow
Date: Mon, 10 Feb 2025 01:58:53 -0600
Subject: [PATCH 6/8] Add utility class to support background server

Summary: A subclass of the aepsych server with methods specifically to run the server in a background process. This will be used to ensure that even within the same main script, the server will run like an actual server and does not do anything sneaky like bypassing the async queue.

Test Plan: New test
---
 aepsych/server/__init__.py  |   4 +-
 aepsych/server/server.py    |  68 ++++++++++++++++++---
 tests/server/test_server.py | 115 +++++++++++++++++++++++++++++++++++-
 3 files changed, 175 insertions(+), 12 deletions(-)

diff --git a/aepsych/server/__init__.py b/aepsych/server/__init__.py
index ae3552278..7783362c9 100644
--- a/aepsych/server/__init__.py
+++ b/aepsych/server/__init__.py
@@ -5,6 +5,6 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from .server import AEPsychServer +from .server import AEPsychBackgroundServer, AEPsychServer -__all__ = ["AEPsychServer"] +__all__ = ["AEPsychServer", "AEPsychBackgroundServer"] diff --git a/aepsych/server/server.py b/aepsych/server/server.py index aa9c40245..8ff711a48 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -32,6 +32,7 @@ replay, ) from aepsych.strategy import SequentialStrategy, Strategy +from multiprocess import Process logger = utils_logging.getLogger() @@ -48,11 +49,9 @@ def __init__( host: str = "0.0.0.0", port: int = 5555, database_path: str = "./databases/default.db", - max_workers: Optional[int] = None, ): self.host = host self.port = port - self.max_workers = max_workers self.clients_connected = 0 self.db: db.Database = db.Database(database_path) self.is_performing_replay = False @@ -278,11 +277,6 @@ def start_blocking(self) -> None: process or machine.""" asyncio.run(self.serve()) - def start_background(self): - """Starts the server in a background thread. Used for scripts where the - client and server are in the same process.""" - raise NotImplementedError - async def serve(self) -> None: """Serves the server on the set IP and port. This creates a coroutine for asyncio to handle requests asyncronously. @@ -291,7 +285,7 @@ async def serve(self) -> None: self.handle_client, self.host, self.port ) self.loop = asyncio.get_running_loop() - pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) + pool = concurrent.futures.ThreadPoolExecutor() self.loop.set_default_executor(pool) async with self.server: @@ -427,6 +421,64 @@ def __getstate__(self): return state +class AEPsychBackgroundServer(AEPsychServer): + """A class to handle the server in a background thread. Unlike the normal + AEPsychServer, this does not create the db right away until the server is + started. When starting this server, it'll be sent to another process, a db + will be initialized, then the server will be served. 
This server should then + be interacted with by the main thread via a client.""" + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 5555, + database_path: str = "./databases/default.db", + ): + self.host = host + self.port = port + self.database_path = database_path + self.clients_connected = 0 + self.is_performing_replay = False + self.exit_server_loop = False + self._db_raw_record = None + self.skip_computations = False + self.background_process = None + self.strat_names = None + self.extensions = None + self._strats = [] + self._parnames = [] + self._configs = [] + self._master_records = [] + self.strat_id = -1 + self.outcome_names = [] + + def _start_server(self) -> None: + self.db: db.Database = db.Database(self.database_path) + if self.db.is_update_required(): + self.db.perform_updates() + + super().start_blocking() + + def start(self): + """Starts the server in a background thread. Used by the client to start + the server for a client in another process or machine.""" + self.background_process = Process(target=self._start_server, daemon=True) + self.background_process.start() + + def stop(self): + """Stops the server and closes the background process.""" + self.exit_server_loop = True + self.background_process.terminate() + self.background_process.join() + self.background_process.close() + self.background_process = None + + def __getstate__(self): + # Override parent's __getstate__ to not worry about the db + state = self.__dict__.copy() + return state + + def parse_argument(): parser = argparse.ArgumentParser(description="AEPsych Server") parser.add_argument( diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 1b7d6bf13..4345f4ef7 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -16,7 +16,6 @@ import aepsych.server as server import aepsych.utils_logging as utils_logging -from aepsych.server.sockets import BAD_REQUEST dummy_config = """ [common] @@ -88,7 +87,7 @@ async def asyncSetUp(self): 
self.port = 5555 # setup logger - server.logger = utils_logging.getLogger("unittests") + self.logger = utils_logging.getLogger("unittests") # random datebase path name without dashes database_path = self.database_path @@ -533,5 +532,117 @@ async def _mock_client2(request: Dict[str, Any]) -> Any: self.assertTrue(self.s.clients_connected == 2) +class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase): + @property + def database_path(self): + return "./{}_test_server.db".format(str(uuid.uuid4().hex)) + + async def asyncSetUp(self): + self.ip = "127.0.0.1" + self.port = 5555 + + # setup logger + self.logger = utils_logging.getLogger("unittests") + + # random datebase path name without dashes + database_path = self.database_path + self.s = server.AEPsychBackgroundServer( + database_path=database_path, host=self.ip, port=self.port + ) + self.db_name = database_path.split("/")[1] + self.db_path = database_path + + # Writer will be made in tests + self.writer = None + + async def asyncTearDown(self): + # Stops the client + if self.writer is not None: + self.writer.close() + + time.sleep(0.1) + + # cleanup the db + db_path = Path(self.db_path) + try: + print(db_path) + db_path.unlink() + except PermissionError as e: + print("Failed to deleted database: ", e) + + async def test_background_server(self): + self.assertIsNone(self.s.background_process) + self.s.start() + self.assertTrue(self.s.background_process.is_alive()) + + # Make a client + try_again = True + attempts = 0 + while try_again: + try_again = False + attempts += 1 + try: + reader, self.writer = await asyncio.open_connection(self.ip, self.port) + except ConnectionRefusedError: + if attempts > 10: + raise ConnectionRefusedError + try_again = True + time.sleep(1) + + async def _mock_client(request: Dict[str, Any]) -> Any: + self.writer.write(json.dumps(request).encode()) + await self.writer.drain() + + response = await reader.read(1024 * 512) + return response.decode() + + setup_request = { + "type": 
"setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + "extra_info": {}, + } + + await _mock_client(setup_request) + + expected_x = [0, 1, 2, 3] + expected_z = list(reversed(expected_x)) + expected_y = [x % 2 for x in expected_x] + i = 0 + while True: + response = await _mock_client(ask_request) + response = json.loads(response) + tell_request["message"]["config"]["x"] = [expected_x[i]] + tell_request["message"]["config"]["z"] = [expected_z[i]] + tell_request["message"]["outcome"] = expected_y[i] + tell_request["extra_info"]["e1"] = 1 + tell_request["extra_info"]["e2"] = 2 + i = i + 1 + await _mock_client(tell_request) + + if response["is_finished"]: + break + + self.s.stop() + self.assertIsNone(self.s.background_process) + + # Create a synchronous server to check db contents + s = server.AEPsychServer(database_path=self.db_path) + unique_id = s.db.get_master_records()[-1].unique_id + out_df = s.get_dataframe_from_replay(unique_id) + self.assertTrue((out_df.x == expected_x).all()) + self.assertTrue((out_df.z == expected_z).all()) + self.assertTrue((out_df.response == expected_y).all()) + self.assertTrue((out_df.e1 == [1] * 4).all()) + self.assertTrue((out_df.e2 == [2] * 4).all()) + self.assertTrue("post_mean" in out_df.columns) + self.assertTrue("post_var" in out_df.columns) + + if __name__ == "__main__": unittest.main() From c9935a1dd8fc29b18b41f9f88ef9d67494174a2a Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 7/8] Add reduce special method to transform wrappers Summary: Add the ___reduce__ method to transform wrappers (model/generator). 
This allows transform wrapped objects to be pickled (with original pickle module, not just dill) Test Plan: tests should pass --- aepsych/transforms/parameters.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index ce1b54a12..eabe73ca7 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -477,6 +477,10 @@ def eval(self): f"{self._base_obj.__class__.__name__} has no attribute 'eval'" ) + def __reduce__(self): + # Helps pickle work (not dill) + return (ParameterTransformedGenerator, (self._base_obj, self.transforms)) + @classmethod def get_config_options( cls, @@ -725,6 +729,10 @@ def eval(self): f"{self._base_obj.__class__.__name__} has no attribute 'eval'" ) + def __reduce__(self): + # Helps pickle work (not dill) + return (ParameterTransformedModel, (self._base_obj, self.transforms)) + @classmethod def get_config_options( cls, From 7deb7570d1fdbfc9cba02063d2df17629a88b3b4 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 10 Feb 2025 01:58:53 -0600 Subject: [PATCH 8/8] Implement async generator Summary: A generator that allows asynchronous generation of points. Uses a different process to generate points. This doesn't fix the situation where modeling fitting takes a long time. 
Test Plan: New test --- aepsych/generators/__init__.py | 2 + aepsych/generators/async_generator.py | 246 +++++++++++++++++++++++ tests/generators/test_async_generator.py | 168 ++++++++++++++++ 3 files changed, 416 insertions(+) create mode 100644 aepsych/generators/async_generator.py create mode 100644 tests/generators/test_async_generator.py diff --git a/aepsych/generators/__init__.py b/aepsych/generators/__init__.py index c8d4089a4..e90b93256 100644 --- a/aepsych/generators/__init__.py +++ b/aepsych/generators/__init__.py @@ -10,6 +10,7 @@ from ..config import Config from .acqf_grid_search_generator import AcqfGridSearchGenerator from .acqf_thompson_sampler_generator import AcqfThompsonSamplerGenerator +from .async_generator import AsyncGenerator from .epsilon_greedy_generator import EpsilonGreedyGenerator from .manual_generator import ManualGenerator, SampleAroundPointsGenerator from .optimize_acqf_generator import OptimizeAcqfGenerator @@ -27,6 +28,7 @@ "IntensityAwareSemiPGenerator", "AcqfThompsonSamplerGenerator", "AcqfGridSearchGenerator", + "AsyncGenerator", ] Config.register_module(sys.modules[__name__]) diff --git a/aepsych/generators/async_generator.py b/aepsych/generators/async_generator.py new file mode 100644 index 000000000..fb4489fab --- /dev/null +++ b/aepsych/generators/async_generator.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import dataclasses +import os +import time +from concurrent import futures +from multiprocessing import get_context +from typing import Dict, List, Optional + +import numpy as np + +import torch +from aepsych.generators.base import AEPsychGenerator +from aepsych.models.model_protocol import ModelProtocol +from aepsych.utils_logging import getLogger + +logger = getLogger() + + +@dataclasses.dataclass +class AsyncPoint: + """Dataclass to keep track of asynchronously generated points.""" + + point: torch.Tensor + generator_name: str + gen_time: float + fixed_features: Optional[Dict[int, float]] = None + model: Optional[ModelProtocol] = None + data: Optional[torch.Tensor] = dataclasses.field(init=False, default=None) + + def __post_init__(self): + if self.model is not None: + self.data = self.model.train_inputs[0] + self.model = None + else: + self.data = None + + @property + def data_len(self) -> int: + """Return the length of the data tensor.""" + return self.data.shape[0] if self.data is not None else 0 + + +class AsyncGenerator(AEPsychGenerator): + """Generator that holds two generators. The primary generator will always + be sent to a different process to handle and if it cannot return within a + timeout, the backup generator will be used instead. In the case of timeout, + the other process will continue to run until the generator is called again. + """ + + def __init__( + self, + generator: AEPsychGenerator, + backup_generator: AEPsychGenerator, + timeout: float = 2.0, + data_diff_limit: Optional[int] = None, + n_pregen: int = 1, + ) -> None: + """Initialize an asynchronous generator. This holds two generators. The + primary generator will always be sent to a different process to handle + and if it cannot return within a timeout, the backup generator will be + used instead. In the case of timeout, the other process will continue to + run until the generator is called again. 
+ + WARNING: Whenever the gen() is called, a new processes will be + forked from the main one. This means that the generators will have the + exact same state (including internal RNG seeds). While we do reseed the + new process, any seeds within an object (like the seed inside the + SobolGenerator) will not be modified and thus can potentially generate + exactly the same points. This should be fine for OptimizeAcqfGenerators. + + Args: + generator (AEPsychGenerator): The primary generator to use. + backup_generator (AEPsychGenerator): The backup generator to use if + the primary times out. + timeout (float): The timeout for the primary generator. Defaults to + 2.0. + data_diff_limit (int, optional): The maximum difference in data + length between the model and the point to accept. If not set, + there would not be any limit. + n_pregen (int, optional): The number of points to pre-generate. + Defaults to 1. + """ + self.generator = generator + self.backup_generator = backup_generator + self.timeout = timeout + self.data_diff_limit = data_diff_limit or np.inf + self.n_pregen = n_pregen + self.executor: Optional[futures.ProcessPoolExecutor] = None + self.futures: List[futures.Future] = [] + + # Populate generator class attributes based on main generator + self._requires_model = self.generator._requires_model + self.stimuli_per_trial = self.generator.stimuli_per_trial + self.max_asks = self.generator.max_asks + self.dim = self.generator.dim + + def gen( + self, + num_points: int, + model: Optional[ModelProtocol] = None, + fixed_features: Optional[Dict[int, float]] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> torch.Tensor: + """Get a point from the generator. When called, it will check if there + are any points being generated by the primary generator and if so, wait + for it to finish. If the timeout is reached, the backup generator will + be used instead. 
Whenever there is a timeout, the primary generator will + continue to work and the next time gen() is called, it will be checked + again. + + Args: + num_points (int): The number of points to generate. + model (ModelProtocol, optional): The model to use for generating + points. Defaults to None. + fixed_features (Dict[int, float], optional): The fixed features to + use for generating points. Defaults to None. + timeout (float, optional): The timeout for the primary generator. + If not set, defaults to the class timeout. + **kwargs: Additional keyword arguments to pass to the generator. + + Returns: + torch.Tensor: The generated point. + """ + if self.executor is None: # Initialize the executor + self.executor = futures.ProcessPoolExecutor( + max_workers=self.n_pregen, + mp_context=get_context("spawn"), + initializer=self._set_process_seed, + ) + + # We keep adding futures until we have enough + while len(self.futures) < self.n_pregen: + self.futures.append( + self.executor.submit( + self._gen, + num_points, + model, + fixed_features=fixed_features, + **kwargs, + ) + ) + + try: + # We return the first future that finished + timeout = timeout or self.timeout + for future in futures.as_completed(self.futures, timeout=timeout): + try: + result = future.result() + + # Check if fixed features match + if result.fixed_features != fixed_features: + # Throw it out and wait for next + # Heuristic to never allow a bunch of fixed to hold us back + logger.info( + "AsyncGenerator found mismatched fixed features, skipping." + ) + self.futures.remove(future) + continue + + if model is not None: + # Check if the data used to generate is close enough + if ( + result.data_len - model.train_inputs[0].shape[0] + <= self.data_diff_limit + ): + self.futures.remove(future) + return result.point + else: + logger.info( + "AsyncGenerator found a point that was generated with data that is too different, skipping." 
+ ) + self.futures.remove(future) + else: + self.futures.remove(future) + return result.point + + except (futures.CancelledError, futures.process.BrokenProcessPool) as e: + logger.error("Generator job failed") + logger.error(e) + self.futures.remove(future) + continue + + # All futures resolved but we still have no point, so we use backup + return self.backup_generator.gen( + num_points=num_points, + model=model, + fixed_features=fixed_features, + **kwargs, + ) + + except futures.TimeoutError: # Timeout backup + logger.info("Main generator timed out, using backup generator.") + return self.backup_generator.gen( + num_points=num_points, + model=model, + fixed_features=fixed_features, + **kwargs, + ) + + @staticmethod + def _set_process_seed(): + # Set the random seed of numpy and pytorch based on pid and time + seed = os.getpid() + int(time.time()) + torch.manual_seed(seed) + np.random.seed(seed) + + def _gen( + self, + num_points: int, + model: Optional[ModelProtocol] = None, + fixed_features: Optional[Dict[int, float]] = None, + **kwargs, + ) -> AsyncPoint: + # Wrapper to pass the generator to the executor and return a async + # point, must be static as we don't want to pickle self. 
+ start = time.time() + point = self.generator.gen(num_points, model, fixed_features, **kwargs) + end = time.time() + async_point = AsyncPoint( + point=point, + gen_time=end - start, + generator_name=self.generator.__class__.__name__, + model=model, + fixed_features=fixed_features, + ) + + return async_point + + def __del__(self): + # To shutdown executor on deletion + if self.executor is not None: + self.executor.shutdown(wait=True, cancel_futures=True) + + def __getstate__(self): + # Need to blank exectutor/futures to be able to pickle + state = self.__dict__.copy() + state["executor"] = None + state["futures"] = [] + return state diff --git a/tests/generators/test_async_generator.py b/tests/generators/test_async_generator.py new file mode 100644 index 000000000..8658b36cc --- /dev/null +++ b/tests/generators/test_async_generator.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import time +import unittest +import uuid + +import numpy as np +import torch +from aepsych.acquisition import MCLevelSetEstimation +from aepsych.config import Config +from aepsych.generators import AsyncGenerator, OptimizeAcqfGenerator, SobolGenerator +from aepsych.models import GPClassificationModel +from aepsych.server import AEPsychBackgroundServer, AEPsychServer +from aepsych.strategy import SequentialStrategy +from aepsych_client import AEPsychClient +from sklearn.datasets import make_classification + + +class TestAsyncGenerator(unittest.TestCase): + def test_timeout_fallback(self): + seed = 1 + torch.manual_seed(seed) + np.random.seed(seed) + timeout = 0.1 + + X, y = make_classification( + n_samples=100, + n_features=8, + n_redundant=3, + n_informative=5, + random_state=1, + n_clusters_per_class=4, + ) + X, y = torch.Tensor(X), torch.Tensor(y) + lb = -3 * torch.ones(8) + ub = 3 * torch.ones(8) + inducing_size = 10 + + model = GPClassificationModel( + dim=8, + inducing_size=inducing_size, + ) + + model.fit(X, y) + primary_gen = OptimizeAcqfGenerator( + acqf=MCLevelSetEstimation, + lb=lb, + ub=ub, + ) + # Bounds are higher than real bounds to test fallback + fallback_gen = SobolGenerator(lb=10 * torch.ones(8), ub=20 * torch.ones(8)) + generator = AsyncGenerator(primary_gen, fallback_gen, timeout=timeout) + + # A dummy call to make sure the generator is initialized + points = generator.gen(1, model) + + start = time.time() + points = generator.gen(1, model) + end = time.time() + fallback_time = end - start + + # Really loose test due to CI slowness + self.assertTrue(fallback_time < 0.3, f"fallback_time: {fallback_time}") + self.assertTrue(torch.all(points > 5)) + + # Wait a bit so that the primary generator is ready + time.sleep(10) + + start = time.time() + points = generator.gen(1, model) + end = time.time() + gen_time = end - start + self.assertTrue(gen_time < 0.01) + self.assertTrue(torch.all(points < 5)) + + # Delete it in case it is still running + del 
generator + + def test_async_pickle(self): + db_path = "./{}_test_server.db".format(str(uuid.uuid4().hex)) + # Create a server + server = AEPsychBackgroundServer(host="127.0.0.1", database_path=db_path) + server.start() + time.sleep(1) + + # Make a client + try_again = True + attempts = 0 + while try_again: + try_again = False + attempts += 1 + try: + client = AEPsychClient(ip=server.host, port=server.port) + except ConnectionRefusedError: + if attempts > 10: + raise ConnectionRefusedError + try_again = True + time.sleep(1) + + n_init = 10 + n_opt = 5 + lower_bound = 1 + upper_bound = 100 + target = 0.75 + + config_str = f""" + [common] + parnames = [signal1] + stimuli_per_trial = 1 + outcome_types = [binary] + target = {target} + strategy_names = [init_strat, opt_strat] + + [signal1] + par_type = continuous + lower_bound = {lower_bound} + upper_bound = {upper_bound} + + [init_strat] + generator = SobolGenerator + min_asks = {n_init} + + [SobolGenerator] + seed = 1 + + [opt_strat] + generator = AsyncGenerator + model = GPClassificationModel + min_asks = {n_opt} + + [AsyncGenerator] + generator = OptimizeAcqfGenerator + backup_generator = SobolGenerator + timeout = 1 + + [OptimizeAcqfGenerator] + acqf = MCLevelSetEstimation + """ + client.configure(config_str=config_str) + + finished = False + signals = [] + responses = [] + while not finished: + ask = client.ask() + outcome = int(np.random.rand() < (ask["config"]["signal1"][0] / 100)) + client.tell(ask["config"], outcome) + + finished = ask["is_finished"] + signals.append(ask["config"]["signal1"][0]) + responses.append(outcome) + + server.stop() + + server = AEPsychServer(database_path=db_path) + unique_id = server.db.get_master_records()[-1].unique_id + out_df = server.get_dataframe_from_replay(unique_id) + + self.assertTrue((out_df["response"] == responses).all()) + self.assertTrue((out_df["signal1"] == signals).all()) + + server.db.delete_db() + time.sleep(0.1)