|
7 | 7 |
|
8 | 8 | import asyncio
|
9 | 9 | import json
|
10 |
| -import logging |
| 10 | +import time |
11 | 11 | import unittest
|
12 | 12 | import uuid
|
13 | 13 | from pathlib import Path
|
14 | 14 | from typing import Any, Dict
|
15 | 15 |
|
16 | 16 | import aepsych.server as server
|
17 | 17 | import aepsych.utils_logging as utils_logging
|
18 |
| -from aepsych.server.sockets import BAD_REQUEST |
19 | 18 |
|
20 | 19 | dummy_config = """
|
21 | 20 | [common]
|
@@ -87,7 +86,7 @@ async def asyncSetUp(self):
|
87 | 86 | self.port = 5555
|
88 | 87 |
|
89 | 88 | # setup logger
|
90 |
| - server.logger = utils_logging.getLogger("unittests") |
| 89 | + self.logger = utils_logging.getLogger("unittests") |
91 | 90 |
|
92 | 91 | # random datebase path name without dashes
|
93 | 92 | database_path = self.database_path
|
@@ -523,5 +522,109 @@ async def _mock_client2(request: Dict[str, Any]) -> Any:
|
523 | 522 | self.assertTrue(self.s.clients_connected == 2)
|
524 | 523 |
|
525 | 524 |
|
| 525 | +class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase): |
| 526 | + @property |
| 527 | + def database_path(self): |
| 528 | + return "./{}_test_server.db".format(str(uuid.uuid4().hex)) |
| 529 | + |
| 530 | + async def asyncSetUp(self): |
| 531 | + self.ip = "127.0.0.1" |
| 532 | + self.port = 5555 |
| 533 | + |
| 534 | + # setup logger |
| 535 | + self.logger = utils_logging.getLogger("unittests") |
| 536 | + |
| 537 | + # random datebase path name without dashes |
| 538 | + database_path = self.database_path |
| 539 | + self.s = server.AEPsychBackgroundServer( |
| 540 | + database_path=database_path, host=self.ip, port=self.port |
| 541 | + ) |
| 542 | + self.db_name = database_path.split("/")[1] |
| 543 | + self.db_path = database_path |
| 544 | + |
| 545 | + # Writer will be made in tests |
| 546 | + self.writer = None |
| 547 | + |
| 548 | + async def asyncTearDown(self): |
| 549 | + # Stops the client |
| 550 | + if self.writer is not None: |
| 551 | + self.writer.close() |
| 552 | + |
| 553 | + time.sleep(0.1) |
| 554 | + |
| 555 | + # cleanup the db |
| 556 | + db_path = Path(self.db_path) |
| 557 | + try: |
| 558 | + print(db_path) |
| 559 | + db_path.unlink() |
| 560 | + except PermissionError as e: |
| 561 | + print("Failed to deleted database: ", e) |
| 562 | + |
| 563 | + async def test_background_server(self): |
| 564 | + self.assertIsNone(self.s.background_process) |
| 565 | + self.s.start() |
| 566 | + self.assertTrue(self.s.background_process.is_alive()) |
| 567 | + |
| 568 | + # Give time for the server to start |
| 569 | + time.sleep(5) |
| 570 | + |
| 571 | + # Create a client |
| 572 | + reader, self.writer = await asyncio.open_connection(self.ip, self.port) |
| 573 | + |
| 574 | + async def _mock_client(request: Dict[str, Any]) -> Any: |
| 575 | + self.writer.write(json.dumps(request).encode()) |
| 576 | + await self.writer.drain() |
| 577 | + |
| 578 | + response = await reader.read(1024 * 512) |
| 579 | + return response.decode() |
| 580 | + |
| 581 | + setup_request = { |
| 582 | + "type": "setup", |
| 583 | + "version": "0.01", |
| 584 | + "message": {"config_str": dummy_config}, |
| 585 | + } |
| 586 | + ask_request = {"type": "ask", "message": ""} |
| 587 | + tell_request = { |
| 588 | + "type": "tell", |
| 589 | + "message": {"config": {"x": [0.5]}, "outcome": 1}, |
| 590 | + "extra_info": {}, |
| 591 | + } |
| 592 | + |
| 593 | + await _mock_client(setup_request) |
| 594 | + |
| 595 | + expected_x = [0, 1, 2, 3] |
| 596 | + expected_z = list(reversed(expected_x)) |
| 597 | + expected_y = [x % 2 for x in expected_x] |
| 598 | + i = 0 |
| 599 | + while True: |
| 600 | + response = await _mock_client(ask_request) |
| 601 | + response = json.loads(response) |
| 602 | + tell_request["message"]["config"]["x"] = [expected_x[i]] |
| 603 | + tell_request["message"]["config"]["z"] = [expected_z[i]] |
| 604 | + tell_request["message"]["outcome"] = expected_y[i] |
| 605 | + tell_request["extra_info"]["e1"] = 1 |
| 606 | + tell_request["extra_info"]["e2"] = 2 |
| 607 | + i = i + 1 |
| 608 | + await _mock_client(tell_request) |
| 609 | + |
| 610 | + if response["is_finished"]: |
| 611 | + break |
| 612 | + |
| 613 | + self.s.stop() |
| 614 | + self.assertIsNone(self.s.background_process) |
| 615 | + |
| 616 | + # Create a synchronous server to check db contents |
| 617 | + s = server.AEPsychServer(database_path=self.db_path) |
| 618 | + unique_id = s.db.get_master_records()[-1].unique_id |
| 619 | + out_df = s.get_dataframe_from_replay(unique_id) |
| 620 | + self.assertTrue((out_df.x == expected_x).all()) |
| 621 | + self.assertTrue((out_df.z == expected_z).all()) |
| 622 | + self.assertTrue((out_df.response == expected_y).all()) |
| 623 | + self.assertTrue((out_df.e1 == [1] * 4).all()) |
| 624 | + self.assertTrue((out_df.e2 == [2] * 4).all()) |
| 625 | + self.assertTrue("post_mean" in out_df.columns) |
| 626 | + self.assertTrue("post_var" in out_df.columns) |
| 627 | + |
| 628 | + |
526 | 629 | if __name__ == "__main__":
|
527 | 630 | unittest.main()
|
0 commit comments