Skip to content

Commit f3c7074

Browse files
committed
Implement async generator
Summary: A generator that allows asynchronous generation of points. Uses a different process to generate points. This doesn't fix the situation where modeling fitting takes a long time. Test Plan: New test
1 parent 972c8e1 commit f3c7074

File tree

3 files changed

+416
-0
lines changed

3 files changed

+416
-0
lines changed

aepsych/generators/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..config import Config
1111
from .acqf_grid_search_generator import AcqfGridSearchGenerator
1212
from .acqf_thompson_sampler_generator import AcqfThompsonSamplerGenerator
13+
from .async_generator import AsyncGenerator
1314
from .epsilon_greedy_generator import EpsilonGreedyGenerator
1415
from .manual_generator import ManualGenerator, SampleAroundPointsGenerator
1516
from .optimize_acqf_generator import OptimizeAcqfGenerator
@@ -27,6 +28,7 @@
2728
"IntensityAwareSemiPGenerator",
2829
"AcqfThompsonSamplerGenerator",
2930
"AcqfGridSearchGenerator",
31+
"AsyncGenerator",
3032
]
3133

3234
Config.register_module(sys.modules[__name__])

aepsych/generators/async_generator.py

+246
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and its affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import dataclasses
9+
import os
10+
import time
11+
from concurrent import futures
12+
from multiprocessing import get_context
13+
from typing import Dict, List, Optional
14+
15+
import numpy as np
16+
17+
import torch
18+
from aepsych.generators.base import AEPsychGenerator
19+
from aepsych.models.model_protocol import ModelProtocol
20+
from aepsych.utils_logging import getLogger
21+
22+
logger = getLogger()
23+
24+
25+
@dataclasses.dataclass
26+
class AsyncPoint:
27+
"""Dataclass to keep track of asynchronously generated points."""
28+
29+
point: torch.Tensor
30+
generator_name: str
31+
gen_time: float
32+
fixed_features: Optional[Dict[int, float]] = None
33+
model: Optional[ModelProtocol] = None
34+
data: Optional[torch.Tensor] = dataclasses.field(init=False, default=None)
35+
36+
def __post_init__(self):
37+
if self.model is not None:
38+
self.data = self.model.train_inputs[0]
39+
self.model = None
40+
else:
41+
self.data = None
42+
43+
@property
44+
def data_len(self) -> int:
45+
"""Return the length of the data tensor."""
46+
return self.data.shape[0] if self.data is not None else 0
47+
48+
49+
class AsyncGenerator(AEPsychGenerator):
50+
"""Generator that holds two generators. The primary generator will always
51+
be sent to a different process to handle and if it cannot return within a
52+
timeout, the backup generator will be used instead. In the case of timeout,
53+
the other process will continue to run until the generator is called again.
54+
"""
55+
56+
def __init__(
57+
self,
58+
generator: AEPsychGenerator,
59+
backup_generator: AEPsychGenerator,
60+
timeout: float = 2.0,
61+
data_diff_limit: Optional[int] = None,
62+
n_pregen: int = 1,
63+
) -> None:
64+
"""Initialize an asynchronous generator. This holds two generators. The
65+
primary generator will always be sent to a different process to handle
66+
and if it cannot return within a timeout, the backup generator will be
67+
used instead. In the case of timeout, the other process will continue to
68+
run until the generator is called again.
69+
70+
WARNING: Whenever the gen() is called, a new processes will be
71+
forked from the main one. This means that the generators will have the
72+
exact same state (including internal RNG seeds). While we do reseed the
73+
new process, any seeds within an object (like the seed inside the
74+
SobolGenerator) will not be modified and thus can potentially generate
75+
exactly the same points. This should be fine for OptimizeAcqfGenerators.
76+
77+
Args:
78+
generator (AEPsychGenerator): The primary generator to use.
79+
backup_generator (AEPsychGenerator): The backup generator to use if
80+
the primary times out.
81+
timeout (float): The timeout for the primary generator. Defaults to
82+
2.0.
83+
data_diff_limit (int, optional): The maximum difference in data
84+
length between the model and the point to accept. If not set,
85+
there would not be any limit.
86+
n_pregen (int, optional): The number of points to pre-generate.
87+
Defaults to 1.
88+
"""
89+
self.generator = generator
90+
self.backup_generator = backup_generator
91+
self.timeout = timeout
92+
self.data_diff_limit = data_diff_limit or np.inf
93+
self.n_pregen = n_pregen
94+
self.executor: Optional[futures.ProcessPoolExecutor] = None
95+
self.futures: List[futures.Future] = []
96+
97+
# Populate generator class attributes based on main generator
98+
self._requires_model = self.generator._requires_model
99+
self.stimuli_per_trial = self.generator.stimuli_per_trial
100+
self.max_asks = self.generator.max_asks
101+
self.dim = self.generator.dim
102+
103+
def gen(
104+
self,
105+
num_points: int,
106+
model: Optional[ModelProtocol] = None,
107+
fixed_features: Optional[Dict[int, float]] = None,
108+
timeout: Optional[float] = None,
109+
**kwargs,
110+
) -> torch.Tensor:
111+
"""Get a point from the generator. When called, it will check if there
112+
are any points being generated by the primary generator and if so, wait
113+
for it to finish. If the timeout is reached, the backup generator will
114+
be used instead. Whenever there is a timeout, the primary generator will
115+
continue to work and the next time gen() is called, it will be checked
116+
again.
117+
118+
Args:
119+
num_points (int): The number of points to generate.
120+
model (ModelProtocol, optional): The model to use for generating
121+
points. Defaults to None.
122+
fixed_features (Dict[int, float], optional): The fixed features to
123+
use for generating points. Defaults to None.
124+
timeout (float, optional): The timeout for the primary generator.
125+
If not set, defaults to the class timeout.
126+
**kwargs: Additional keyword arguments to pass to the generator.
127+
128+
Returns:
129+
torch.Tensor: The generated point.
130+
"""
131+
if self.executor is None: # Initialize the executor
132+
self.executor = futures.ProcessPoolExecutor(
133+
max_workers=self.n_pregen,
134+
mp_context=get_context("spawn"),
135+
initializer=self._set_process_seed,
136+
)
137+
138+
# We keep adding futures until we have enough
139+
while len(self.futures) < self.n_pregen:
140+
self.futures.append(
141+
self.executor.submit(
142+
self._gen,
143+
num_points,
144+
model,
145+
fixed_features=fixed_features,
146+
**kwargs,
147+
)
148+
)
149+
150+
try:
151+
# We return the first future that finished
152+
timeout = timeout or self.timeout
153+
for future in futures.as_completed(self.futures, timeout=timeout):
154+
try:
155+
result = future.result()
156+
157+
# Check if fixed features match
158+
if result.fixed_features != fixed_features:
159+
# Throw it out and wait for next
160+
# Heuristic to never allow a bunch of fixed to hold us back
161+
logger.info(
162+
"AsyncGenerator found mismatched fixed features, skipping."
163+
)
164+
self.futures.remove(future)
165+
continue
166+
167+
if model is not None:
168+
# Check if the data used to generate is close enough
169+
if (
170+
result.data_len - model.train_inputs[0].shape[0]
171+
<= self.data_diff_limit
172+
):
173+
self.futures.remove(future)
174+
return result.point
175+
else:
176+
logger.info(
177+
"AsyncGenerator found a point that was generated with data that is too different, skipping."
178+
)
179+
self.futures.remove(future)
180+
else:
181+
self.futures.remove(future)
182+
return result.point
183+
184+
except (futures.CancelledError, futures.process.BrokenProcessPool) as e:
185+
logger.error("Generator job failed")
186+
logger.error(e)
187+
self.futures.remove(future)
188+
continue
189+
190+
# All futures resolved but we still have no point, so we use backup
191+
return self.backup_generator.gen(
192+
num_points=num_points,
193+
model=model,
194+
fixed_features=fixed_features,
195+
**kwargs,
196+
)
197+
198+
except futures.TimeoutError: # Timeout backup
199+
logger.info("Main generator timed out, using backup generator.")
200+
return self.backup_generator.gen(
201+
num_points=num_points,
202+
model=model,
203+
fixed_features=fixed_features,
204+
**kwargs,
205+
)
206+
207+
@staticmethod
208+
def _set_process_seed():
209+
# Set the random seed of numpy and pytorch based on pid and time
210+
seed = os.getpid() + int(time.time())
211+
torch.manual_seed(seed)
212+
np.random.seed(seed)
213+
214+
def _gen(
215+
self,
216+
num_points: int,
217+
model: Optional[ModelProtocol] = None,
218+
fixed_features: Optional[Dict[int, float]] = None,
219+
**kwargs,
220+
) -> AsyncPoint:
221+
# Wrapper to pass the generator to the executor and return a async
222+
# point, must be static as we don't want to pickle self.
223+
start = time.time()
224+
point = self.generator.gen(num_points, model, fixed_features, **kwargs)
225+
end = time.time()
226+
async_point = AsyncPoint(
227+
point=point,
228+
gen_time=end - start,
229+
generator_name=self.generator.__class__.__name__,
230+
model=model,
231+
fixed_features=fixed_features,
232+
)
233+
234+
return async_point
235+
236+
def __del__(self):
237+
# To shutdown executor on deletion
238+
if self.executor is not None:
239+
self.executor.shutdown(wait=True, cancel_futures=True)
240+
241+
def __getstate__(self):
242+
# Need to blank exectutor/futures to be able to pickle
243+
state = self.__dict__.copy()
244+
state["executor"] = None
245+
state["futures"] = []
246+
return state

0 commit comments

Comments
 (0)