Skip to content

Commit 61f22b0

Browse files
committed
Strategy methods act on copies of models if needed
Summary: To support multi client server, strategy methods will not act on copies of models to avoid changing tensor gradients between two threads. Test Plan: New test
1 parent 18f8464 commit 61f22b0

File tree

3 files changed

+81
-18
lines changed

3 files changed

+81
-18
lines changed

aepsych/server/server.py

+3
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
self.host = host
5454
self.port = port
5555
self.max_workers = max_workers
56+
self.clients_connected = 0
5657
self.db: db.Database = db.Database(database_path)
5758
self.is_performing_replay = False
5859
self.exit_server_loop = False
@@ -323,6 +324,7 @@ async def handle_client(self, reader, writer):
323324
"""
324325
addr = writer.get_extra_info("peername")
325326
logger.info(f"Connected to {addr}")
327+
self.clients_connected += 1
326328

327329
try:
328330
while True:
@@ -361,6 +363,7 @@ async def handle_client(self, reader, writer):
361363
logger.info(f"Connection closed for {addr}")
362364
writer.close()
363365
await writer.wait_closed()
366+
self.clients_connected -= 1
364367

365368
def handle_request(self, message: Dict[str, Any]) -> Union[Dict[str, Any], str]:
366369
"""Given a message, dispatch the correct handler and return the result.

aepsych/strategy/strategy.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import warnings
11+
from copy import deepcopy
1112
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
1213

1314
import numpy as np
@@ -56,6 +57,7 @@ def __init__(
5657
name: str = "",
5758
run_indefinitely: bool = False,
5859
transforms: ChainedInputTransform = ChainedInputTransform(**{}),
60+
copy_model: bool = False,
5961
) -> None:
6062
"""Initialize the strategy object.
6163
@@ -90,6 +92,9 @@ def __init__(
9092
should be defined in raw parameter space for initialization. However,
9193
if the lb/ub attribute are access from an initialized Strategy object,
9294
it will be returned in transformed space.
95+
copy_model (bool): Whether to do any model-related methods on a
96+
copy or the original. Used for multi-client strategies. Defaults
97+
to False.
9398
"""
9499
self.is_finished = False
95100

@@ -160,6 +165,7 @@ def __init__(
160165
self.min_total_outcome_occurrences = min_total_outcome_occurrences
161166
self.max_asks = max_asks or generator.max_asks
162167
self.keep_most_recent = keep_most_recent
168+
self.copy_model = copy_model
163169

164170
self.transforms = transforms
165171
if self.transforms is not None:
@@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
267273
self.model.to(self.generator_device) # type: ignore
268274

269275
self._count = self._count + num_points
270-
points = self.generator.gen(num_points, self.model, **kwargs)
276+
model = deepcopy(self.model) if self.copy_model else self.model
277+
points = self.generator.gen(num_points, model, **kwargs)
271278

272279
if original_device is not None:
273280
self.model.to(original_device) # type: ignore
@@ -295,9 +302,9 @@ def get_max(
295302
self.model is not None
296303
), "model is None! Cannot get the max without a model!"
297304
self.model.to(self.model_device)
298-
305+
model = deepcopy(self.model) if self.copy_model else self.model
299306
val, arg = get_max(
300-
self.model,
307+
model,
301308
self.bounds,
302309
locked_dims=constraints,
303310
probability_space=probability_space,
@@ -324,9 +331,9 @@ def get_min(
324331
self.model is not None
325332
), "model is None! Cannot get the min without a model!"
326333
self.model.to(self.model_device)
327-
334+
model = deepcopy(self.model) if self.copy_model else self.model
328335
val, arg = get_min(
329-
self.model,
336+
model,
330337
self.bounds,
331338
locked_dims=constraints,
332339
probability_space=probability_space,
@@ -358,9 +365,9 @@ def inv_query(
358365
self.model is not None
359366
), "model is None! Cannot get the inv_query without a model!"
360367
self.model.to(self.model_device)
361-
368+
model = deepcopy(self.model) if self.copy_model else self.model
362369
val, arg = inv_query(
363-
model=self.model,
370+
model=model,
364371
y=y,
365372
bounds=self.bounds,
366373
locked_dims=constraints,
@@ -385,7 +392,8 @@ def predict(
385392
"""
386393
assert self.model is not None, "model is None! Cannot predict without a model!"
387394
self.model.to(self.model_device)
388-
return self.model.predict(x=x, probability_space=probability_space)
395+
model = deepcopy(self.model) if self.copy_model else self.model
396+
return model.predict(x=x, probability_space=probability_space)
389397

390398
@ensure_model_is_fresh
391399
def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor:
@@ -400,7 +408,8 @@ def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor:
400408
"""
401409
assert self.model is not None, "model is None! Cannot sample without a model!"
402410
self.model.to(self.model_device)
403-
return self.model.sample(x, num_samples=num_samples)
411+
model = deepcopy(self.model) if self.copy_model else self.model
412+
return model.sample(x, num_samples=num_samples)
404413

405414
def finish(self) -> None:
406415
"""Finish the strategy."""
@@ -442,7 +451,8 @@ def finished(self) -> bool:
442451
assert (
443452
self.model is not None
444453
), "model is None! Cannot predict without a model!"
445-
fmean, _ = self.model.predict(self.eval_grid, probability_space=True)
454+
model = deepcopy(self.model) if self.copy_model else self.model
455+
fmean, _ = model.predict(self.eval_grid, probability_space=True)
446456
meets_post_range = bool(
447457
((fmean.max() - fmean.min()) >= self.min_post_range).item()
448458
)
@@ -504,9 +514,10 @@ def fit(self) -> None:
504514
"""Fit the model."""
505515
if self.can_fit:
506516
self.model.to(self.model_device) # type: ignore
517+
model = deepcopy(self.model) if self.copy_model else self.model
507518
if self.keep_most_recent is not None:
508519
try:
509-
self.model.fit( # type: ignore
520+
model.fit( # type: ignore
510521
self.x[-self.keep_most_recent :], # type: ignore
511522
self.y[-self.keep_most_recent :], # type: ignore
512523
)
@@ -516,21 +527,23 @@ def fit(self) -> None:
516527
)
517528
else:
518529
try:
519-
self.model.fit(self.x, self.y) # type: ignore
530+
model.fit(self.x, self.y) # type: ignore
520531
except ModelFittingError:
521532
logger.warning(
522533
"Failed to fit model! Predictions may not be accurate!"
523534
)
535+
self.model = model
524536
else:
525537
warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning)
526538

527539
def update(self) -> None:
528540
"""Update the model."""
529541
if self.can_fit:
530542
self.model.to(self.model_device) # type: ignore
543+
model = deepcopy(self.model) if self.copy_model else self.model
531544
if self.keep_most_recent is not None:
532545
try:
533-
self.model.update( # type: ignore
546+
model.update( # type: ignore
534547
self.x[-self.keep_most_recent :], # type: ignore
535548
self.y[-self.keep_most_recent :], # type: ignore
536549
)
@@ -540,11 +553,13 @@ def update(self) -> None:
540553
)
541554
else:
542555
try:
543-
self.model.update(self.x, self.y) # type: ignore
556+
model.update(self.x, self.y) # type: ignore
544557
except ModelFittingError:
545558
logger.warning(
546559
"Failed to fit model! Predictions may not be accurate!"
547560
)
561+
562+
self.model = model
548563
else:
549564
warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning)
550565

tests/server/test_server.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
generator = OptimizeAcqfGenerator
4646
model = GPClassificationModel
4747
min_total_outcome_occurrences = 0
48+
copy_model = True
4849
4950
[OptimizeAcqfGenerator]
5051
acqf = MCPosteriorVariance
@@ -82,15 +83,17 @@ def database_path(self):
8283
return "./{}_test_server.db".format(str(uuid.uuid4().hex))
8384

8485
async def asyncSetUp(self):
85-
ip = "127.0.0.1"
86-
port = 5555
86+
self.ip = "127.0.0.1"
87+
self.port = 5555
8788

8889
# setup logger
8990
server.logger = utils_logging.getLogger("unittests")
9091

9192
# random datebase path name without dashes
9293
database_path = self.database_path
93-
self.s = server.AEPsychServer(database_path=database_path, host=ip, port=port)
94+
self.s = server.AEPsychServer(
95+
database_path=database_path, host=self.ip, port=self.port
96+
)
9497
self.db_name = database_path.split("/")[1]
9598
self.db_path = database_path
9699

@@ -106,7 +109,7 @@ async def asyncSetUp(self):
106109
self.server_task = asyncio.create_task(self.s.serve())
107110
await asyncio.sleep(0.1)
108111

109-
self.reader, self.writer = await asyncio.open_connection(ip, port)
112+
self.reader, self.writer = await asyncio.open_connection(self.ip, self.port)
110113

111114
async def asyncTearDown(self):
112115
# Stops the client
@@ -486,6 +489,48 @@ async def test_receive(self):
486489
else:
487490
self.assertTrue("KeyError" in response["error"]) # Specific error
488491

492+
async def test_multi_client(self):
493+
setup_request = {
494+
"type": "setup",
495+
"version": "0.01",
496+
"message": {"config_str": dummy_config},
497+
}
498+
ask_request = {"type": "ask", "message": ""}
499+
tell_request = {
500+
"type": "tell",
501+
"message": {"config": {"x": [0.5]}, "outcome": 1},
502+
"extra_info": {},
503+
}
504+
505+
await self.mock_client(setup_request)
506+
507+
# Create second client
508+
reader2, writer2 = await asyncio.open_connection(self.ip, self.port)
509+
510+
async def _mock_client2(request: Dict[str, Any]) -> Any:
511+
writer2.write(json.dumps(request).encode())
512+
await writer2.drain()
513+
514+
response = await reader2.read(1024 * 512)
515+
return response.decode()
516+
517+
for _ in range(2): # 2 loops should do it as we have 2 clients
518+
tasks = [
519+
asyncio.create_task(self.mock_client(ask_request)),
520+
asyncio.create_task(_mock_client2(ask_request)),
521+
]
522+
await asyncio.gather(*tasks)
523+
524+
tasks = [
525+
asyncio.create_task(self.mock_client(tell_request)),
526+
asyncio.create_task(_mock_client2(tell_request)),
527+
]
528+
await asyncio.gather(*tasks)
529+
530+
self.assertTrue(self.s.strat.finished)
531+
self.assertTrue(self.s.strat.x.numel() == 4)
532+
self.assertTrue(self.s.clients_connected == 2)
533+
489534

490535
if __name__ == "__main__":
491536
unittest.main()

0 commit comments

Comments
 (0)