16
16
from contextlib import contextmanager
17
17
from pathlib import Path
18
18
from queue import SimpleQueue
19
- from typing import TYPE_CHECKING , Any , Protocol , cast
19
+ from typing import TYPE_CHECKING , Any , Protocol
20
20
21
21
import numpy as np
22
22
23
23
from _ert .events import (
24
24
EESnapshot ,
25
25
EESnapshotUpdate ,
26
- EETerminated ,
27
- EEUserCancel ,
28
- EEUserDone ,
29
26
Event ,
30
27
)
31
28
from ert .analysis import ErtAnalysisError , smoother_update
51
48
EvaluatorServerConfig ,
52
49
Realization ,
53
50
)
54
- from ert .ensemble_evaluator .evaluator import EventSentinel
55
- from ert .ensemble_evaluator .identifiers import STATUS
51
+ from ert .ensemble_evaluator .evaluator import UserCancelled
56
52
from ert .ensemble_evaluator .snapshot import EnsembleSnapshot
57
53
from ert .ensemble_evaluator .state import (
58
- ENSEMBLE_STATE_CANCELLED ,
59
- ENSEMBLE_STATE_FAILED ,
60
- ENSEMBLE_STATE_STOPPED ,
61
54
REALIZATION_STATE_FAILED ,
62
55
REALIZATION_STATE_FINISHED ,
63
56
)
@@ -105,10 +98,6 @@ def delete_runpath(run_path: str) -> None:
105
98
shutil .rmtree (run_path )
106
99
107
100
108
- class UserCancelled (Exception ):
109
- pass
110
-
111
-
112
101
class _LogAggregration (logging .Handler ):
113
102
def __init__ (self , messages : MutableSequence [str ]) -> None :
114
103
self .messages = messages
@@ -214,12 +203,13 @@ def __init__(
214
203
)
215
204
self ._iter_snapshot : dict [int , EnsembleSnapshot ] = {}
216
205
self ._status_queue = status_queue
217
- self . _end_queue : SimpleQueue [ str ] = SimpleQueue ()
206
+
218
207
# This holds state about the run model
219
208
self .minimum_required_realizations = minimum_required_realizations
220
209
self .active_realizations = copy .copy (active_realizations )
221
210
self .start_iteration = start_iteration
222
211
self .restart = False
212
+ self ._cancelled = False
223
213
224
214
@property
225
215
def api (self ) -> BaseRunModelAPI :
@@ -238,7 +228,7 @@ def reports_dir(self, experiment_name: str) -> str:
238
228
239
229
def log_at_startup (self ) -> None :
240
230
keys_to_drop = [
241
- "_end_queue " ,
231
+ "_cancelled " ,
242
232
"_queue_config" ,
243
233
"_status_queue" ,
244
234
"_storage" ,
@@ -324,7 +314,8 @@ def ensemble_size(self) -> int:
324
314
return len (self ._initial_realizations_mask )
325
315
326
316
def cancel (self ) -> None :
327
- self ._end_queue .put ("END" )
317
+ self ._cancelled = True
318
+ self ._ensemble_evaluator .cancel_gracefully ()
328
319
329
320
def has_failed_realizations (self ) -> bool :
330
321
return any (self ._create_mask_from_failed_realizations ())
@@ -552,113 +543,43 @@ def send_snapshot_event(self, event: Event, iteration: int) -> None:
552
543
)
553
544
)
554
545
555
- async def run_monitor (
556
- self ,
557
- iteration : int ,
558
- evaluator : EnsembleEvaluator ,
559
- ) -> bool :
560
- try :
561
- heartbeat_interval_ : float | None = 0.1
562
- receiver_timeout : float = 60.0
563
- closetracker_received : bool = False
564
-
565
- while True :
566
- try :
567
- event = await asyncio .wait_for (
568
- evaluator ._monitor_queue .get (), timeout = heartbeat_interval_
569
- )
570
- evaluator ._monitor_queue .task_done ()
571
- except TimeoutError :
572
- if closetracker_received :
573
- logger .error ("Evaluator did not send the TERMINATED event!" )
574
- break
575
- event = None
576
- if isinstance (event , EventSentinel ):
577
- closetracker_received = True
578
- heartbeat_interval_ = receiver_timeout
579
- print ("JONAK - received sentinel" )
580
- continue
581
- if type (event ) in {
582
- EESnapshot ,
583
- EESnapshotUpdate ,
584
- }:
585
- event = cast (EESnapshot | EESnapshotUpdate , event )
586
-
587
- self .send_snapshot_event (event , iteration )
588
-
589
- if event .snapshot .get (STATUS ) in {
590
- ENSEMBLE_STATE_STOPPED ,
591
- ENSEMBLE_STATE_FAILED ,
592
- }:
593
- print ("Ensemble was stopped" )
594
- logger .debug ("observed evaluation stopped event, signal done" )
595
- logger .debug ("monitor informing server monitor is done..." )
596
-
597
- done_event = EEUserDone ()
598
- await evaluator .handle_client_event (done_event )
599
- logger .debug ("monitor informed server monitor is done" )
600
- await evaluator ._monitor_queue .put (EventSentinel ())
601
-
602
- if event .snapshot .get (STATUS ) == ENSEMBLE_STATE_CANCELLED :
603
- logger .debug (
604
- "observed evaluation cancelled event, exit drainer"
605
- )
606
- raise UserCancelled (
607
- "Experiment cancelled by user during evaluation"
608
- )
609
- elif type (event ) is EETerminated :
610
- logger .debug ("got terminated event" )
611
- break
612
-
613
- if not self ._end_queue .empty ():
614
- print ("RUN MODEL WAS CANCELLED" )
615
- logger .debug ("Run model canceled - during evaluation" )
616
- self ._end_queue .get ()
617
- logger .debug ("monitor asking server to cancel..." )
618
- cancel_event = EEUserCancel ()
619
- await evaluator .handle_client_event (cancel_event )
620
- await evaluator ._monitor_queue .put (EventSentinel ())
621
- logger .debug ("monitor asked server to cancel" )
622
- logger .debug ("Run model canceled - during evaluation - cancel sent" )
623
- except UserCancelled :
624
- raise
625
- except Exception as e :
626
- logger .exception (f"unexpected error: { e } " )
627
- # We really don't know what happened... shut down
628
- # the thread and get out of here. The monitor has
629
- # been stopped by the ctx-mgr
630
- return False
631
-
632
- return True
633
-
634
546
async def run_ensemble_evaluator_async (
635
547
self ,
636
548
run_args : list [RunArg ],
637
549
ensemble : Ensemble ,
638
550
ee_config : EvaluatorServerConfig ,
639
551
) -> list [int ]:
640
- if not self ._end_queue . empty () :
552
+ if self ._cancelled :
641
553
logger .debug ("Run model canceled - pre evaluation" )
642
- self ._end_queue . get ()
554
+ self ._cancelled = False
643
555
raise UserCancelled ("Experiment cancelled by user in pre evaluation" )
644
556
645
557
ee_ensemble = self ._build_ensemble (run_args , ensemble .experiment_id )
646
- evaluator = EnsembleEvaluator (ee_ensemble , ee_config )
558
+ evaluator = EnsembleEvaluator (
559
+ ee_ensemble ,
560
+ ee_config ,
561
+ send_to_brm = functools .partial (
562
+ self .send_snapshot_event , iteration = ensemble .iteration
563
+ ),
564
+ )
565
+ self ._ensemble_evaluator = evaluator
647
566
evaluator_task = asyncio .create_task (
648
567
evaluator .run_and_get_successful_realizations ()
649
568
)
650
569
await evaluator ._server_started
651
- if not (await self .run_monitor (ensemble .iteration , evaluator )):
570
+
571
+ if not (await self ._wait_for_evaluator_result (evaluator ._monitoring_result )):
652
572
await evaluator_task
653
573
return []
654
574
655
575
logger .debug ("observed that model was finished, waiting tasks completion..." )
656
576
# The model has finished, we indicate this by sending a DONE
657
577
logger .debug ("tasks complete" )
658
578
659
- if not self ._end_queue .empty ():
579
+ if self ._cancelled :
580
+ print ("YEGA!" )
660
581
logger .debug ("Run model canceled - post evaluation" )
661
- self ._end_queue . get ()
582
+ self ._cancelled = False
662
583
try :
663
584
await evaluator_task
664
585
except Exception as e :
@@ -668,7 +589,8 @@ async def run_ensemble_evaluator_async(
668
589
) from e
669
590
print ("RAISING USER_CANCELLED" )
670
591
raise UserCancelled ("Experiment cancelled by user in post evaluation" )
671
-
592
+ else :
593
+ print ("HELL NAH!" )
672
594
await evaluator_task
673
595
ensemble .refresh_ensemble_state ()
674
596
@@ -849,6 +771,12 @@ def _evaluate_and_postprocess(
849
771
850
772
return num_successful_realizations
851
773
774
+ async def _wait_for_evaluator_result (
775
+ self , monitoring_future : asyncio .Future [bool ]
776
+ ) -> bool :
777
+ """This helper function is here for the sake of mocking in tests."""
778
+ return await monitoring_future
779
+
852
780
853
781
class UpdateRunModel (BaseRunModel ):
854
782
def __init__ (
0 commit comments