8
8
from __future__ import annotations
9
9
10
10
import warnings
11
+ from copy import deepcopy
11
12
from typing import Any , Dict , List , Mapping , Optional , Tuple , Union
12
13
13
14
import numpy as np
@@ -56,6 +57,7 @@ def __init__(
56
57
name : str = "" ,
57
58
run_indefinitely : bool = False ,
58
59
transforms : ChainedInputTransform = ChainedInputTransform (** {}),
60
+ copy_model : bool = False ,
59
61
) -> None :
60
62
"""Initialize the strategy object.
61
63
@@ -90,6 +92,9 @@ def __init__(
90
92
should be defined in raw parameter space for initialization. However,
91
93
if the lb/ub attribute are access from an initialized Strategy object,
92
94
it will be returned in transformed space.
95
+ copy_model (bool): Whether to do any model-related methods on a
96
+ copy or the original. Used for multi-client strategies. Defaults
97
+ to False.
93
98
"""
94
99
self .is_finished = False
95
100
@@ -160,6 +165,7 @@ def __init__(
160
165
self .min_total_outcome_occurrences = min_total_outcome_occurrences
161
166
self .max_asks = max_asks or generator .max_asks
162
167
self .keep_most_recent = keep_most_recent
168
+ self .copy_model = copy_model
163
169
164
170
self .transforms = transforms
165
171
if self .transforms is not None :
@@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
267
273
self .model .to (self .generator_device ) # type: ignore
268
274
269
275
self ._count = self ._count + num_points
270
- points = self .generator .gen (num_points , self .model , ** kwargs )
276
+ model = deepcopy (self .model ) if self .copy_model else self .model
277
+ points = self .generator .gen (num_points , model , ** kwargs )
271
278
272
279
if original_device is not None :
273
280
self .model .to (original_device ) # type: ignore
@@ -295,9 +302,9 @@ def get_max(
295
302
self .model is not None
296
303
), "model is None! Cannot get the max without a model!"
297
304
self .model .to (self .model_device )
298
-
305
+ model = deepcopy ( self . model ) if self . copy_model else self . model
299
306
val , arg = get_max (
300
- self . model ,
307
+ model ,
301
308
self .bounds ,
302
309
locked_dims = constraints ,
303
310
probability_space = probability_space ,
@@ -324,9 +331,9 @@ def get_min(
324
331
self .model is not None
325
332
), "model is None! Cannot get the min without a model!"
326
333
self .model .to (self .model_device )
327
-
334
+ model = deepcopy ( self . model ) if self . copy_model else self . model
328
335
val , arg = get_min (
329
- self . model ,
336
+ model ,
330
337
self .bounds ,
331
338
locked_dims = constraints ,
332
339
probability_space = probability_space ,
@@ -358,9 +365,9 @@ def inv_query(
358
365
self .model is not None
359
366
), "model is None! Cannot get the inv_query without a model!"
360
367
self .model .to (self .model_device )
361
-
368
+ model = deepcopy ( self . model ) if self . copy_model else self . model
362
369
val , arg = inv_query (
363
- model = self . model ,
370
+ model = model ,
364
371
y = y ,
365
372
bounds = self .bounds ,
366
373
locked_dims = constraints ,
@@ -385,7 +392,8 @@ def predict(
385
392
"""
386
393
assert self .model is not None , "model is None! Cannot predict without a model!"
387
394
self .model .to (self .model_device )
388
- return self .model .predict (x = x , probability_space = probability_space )
395
+ model = deepcopy (self .model ) if self .copy_model else self .model
396
+ return model .predict (x = x , probability_space = probability_space )
389
397
390
398
@ensure_model_is_fresh
391
399
def sample (self , x : torch .Tensor , num_samples : int = 1000 ) -> torch .Tensor :
@@ -400,7 +408,8 @@ def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor:
400
408
"""
401
409
assert self .model is not None , "model is None! Cannot sample without a model!"
402
410
self .model .to (self .model_device )
403
- return self .model .sample (x , num_samples = num_samples )
411
+ model = deepcopy (self .model ) if self .copy_model else self .model
412
+ return model .sample (x , num_samples = num_samples )
404
413
405
414
def finish (self ) -> None :
406
415
"""Finish the strategy."""
@@ -442,7 +451,8 @@ def finished(self) -> bool:
442
451
assert (
443
452
self .model is not None
444
453
), "model is None! Cannot predict without a model!"
445
- fmean , _ = self .model .predict (self .eval_grid , probability_space = True )
454
+ model = deepcopy (self .model ) if self .copy_model else self .model
455
+ fmean , _ = model .predict (self .eval_grid , probability_space = True )
446
456
meets_post_range = bool (
447
457
((fmean .max () - fmean .min ()) >= self .min_post_range ).item ()
448
458
)
@@ -504,9 +514,10 @@ def fit(self) -> None:
504
514
"""Fit the model."""
505
515
if self .can_fit :
506
516
self .model .to (self .model_device ) # type: ignore
517
+ model = deepcopy (self .model ) if self .copy_model else self .model
507
518
if self .keep_most_recent is not None :
508
519
try :
509
- self . model .fit ( # type: ignore
520
+ model .fit ( # type: ignore
510
521
self .x [- self .keep_most_recent :], # type: ignore
511
522
self .y [- self .keep_most_recent :], # type: ignore
512
523
)
@@ -516,21 +527,23 @@ def fit(self) -> None:
516
527
)
517
528
else :
518
529
try :
519
- self . model .fit (self .x , self .y ) # type: ignore
530
+ model .fit (self .x , self .y ) # type: ignore
520
531
except ModelFittingError :
521
532
logger .warning (
522
533
"Failed to fit model! Predictions may not be accurate!"
523
534
)
535
+ self .model = model
524
536
else :
525
537
warnings .warn ("Cannot fit: no model has been initialized!" , RuntimeWarning )
526
538
527
539
def update (self ) -> None :
528
540
"""Update the model."""
529
541
if self .can_fit :
530
542
self .model .to (self .model_device ) # type: ignore
543
+ model = deepcopy (self .model ) if self .copy_model else self .model
531
544
if self .keep_most_recent is not None :
532
545
try :
533
- self . model .update ( # type: ignore
546
+ model .update ( # type: ignore
534
547
self .x [- self .keep_most_recent :], # type: ignore
535
548
self .y [- self .keep_most_recent :], # type: ignore
536
549
)
@@ -540,11 +553,13 @@ def update(self) -> None:
540
553
)
541
554
else :
542
555
try :
543
- self . model .update (self .x , self .y ) # type: ignore
556
+ model .update (self .x , self .y ) # type: ignore
544
557
except ModelFittingError :
545
558
logger .warning (
546
559
"Failed to fit model! Predictions may not be accurate!"
547
560
)
561
+
562
+ self .model = model
548
563
else :
549
564
warnings .warn ("Cannot fit: no model has been initialized!" , RuntimeWarning )
550
565
0 commit comments