Skip to content

Commit 9b39d25

Browse files
committed
Start fixing tests
1 parent 9e50301 commit 9b39d25

File tree

5 files changed

+137
-155
lines changed

5 files changed

+137
-155
lines changed

src/ert/ensemble_evaluator/evaluator.py

+42-9
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@ class EventSentinel:
1111
pass
1212

1313

14+
class UserCancelled(Exception):
15+
pass
16+
17+
1418
import zmq.asyncio
1519

1620
from _ert.events import (
1721
EEEvent,
1822
EESnapshot,
1923
EESnapshotUpdate,
20-
EETerminated,
2124
EEUserCancel,
2225
EEUserDone,
2326
EnsembleCancelled,
@@ -57,6 +60,7 @@ def __init__(
5760
self,
5861
ensemble: Ensemble,
5962
config: EvaluatorServerConfig,
63+
send_to_brm: Callable[[Event], None],
6064
) -> None:
6165
self._config: EvaluatorServerConfig = config
6266
self._ensemble: Ensemble = ensemble
@@ -66,7 +70,7 @@ def __init__(
6670

6771
self._ee_tasks: list[asyncio.Task[None]] = []
6872
self._server_done: asyncio.Event = asyncio.Event()
69-
73+
self._running_loop = asyncio.get_running_loop()
7074
# batching section
7175
self._batch_processing_queue: asyncio.Queue[
7276
list[tuple[EVENT_HANDLER, Event]]
@@ -78,19 +82,37 @@ def __init__(
7882
self._dispatchers_connected: set[bytes] = set()
7983
self._dispatchers_empty: asyncio.Event = asyncio.Event()
8084
self._dispatchers_empty.set()
81-
self._monitor_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue()
8285
current_snapshot_dict = self._ensemble.snapshot.to_dict()
86+
self._send_to_brm = send_to_brm
8387
event: Event = EESnapshot(
8488
snapshot=current_snapshot_dict,
8589
ensemble=self.ensemble.id_,
8690
)
87-
self._monitor_queue.put_nowait(event)
91+
self._send_to_brm(event)
92+
self._monitoring_result: asyncio.Future[bool] = asyncio.Future()
8893

8994
async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None:
9095
event = EESnapshotUpdate(
9196
snapshot=snapshot_update_event.to_dict(), ensemble=self._ensemble.id_
9297
)
93-
await self._monitor_queue.put(event)
98+
self._send_to_brm(event)
99+
if event.snapshot.get(ids.STATUS) in {
100+
ENSEMBLE_STATE_STOPPED,
101+
ENSEMBLE_STATE_FAILED,
102+
}:
103+
print("Ensemble was stopped")
104+
logger.debug("observed evaluation stopped event, signal done")
105+
logger.debug("monitor informing server monitor is done...")
106+
107+
done_event = EEUserDone()
108+
await self.handle_client_event(done_event)
109+
logger.debug("monitor informed server monitor is done")
110+
111+
if event.snapshot.get(ids.STATUS) == ENSEMBLE_STATE_CANCELLED:
112+
logger.debug("observed evaluation cancelled event, exit drainer")
113+
self._monitoring_result.set_exception(
114+
UserCancelled("Experiment cancelled by user during evaluation")
115+
)
94116

95117
async def _process_event_buffer(self) -> None:
96118
while True:
@@ -193,16 +215,24 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
193215
def ensemble(self) -> Ensemble:
194216
return self._ensemble
195217

218+
def cancel_gracefully(self) -> None:
219+
cancel_event = EEUserCancel()
220+
self._ee_tasks.append(
221+
self._running_loop.create_task(self.handle_client_event(cancel_event))
222+
)
223+
196224
async def handle_client_event(self, event: EEEvent) -> None:
197225
if type(event) is EEUserCancel:
198226
print("EE GOT CANCEL EVENT")
199227
logger.debug("Client asked to cancel.")
200228
await self._signal_cancel()
229+
self._monitoring_result.set_result(False)
201230
# self._clients_empty.set()
202231
elif type(event) is EEUserDone:
203232
print("EE GOT USER DONE EVENT")
204233
logger.debug("Client signalled done.")
205234
self.stop()
235+
self._monitoring_result.set_result(True)
206236
# self._clients_empty.set()
207237

208238
async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None:
@@ -249,7 +279,7 @@ async def listen_for_messages(self) -> None:
249279

250280
async def forward_checksum(self, event: Event) -> None:
251281
# clients still need to receive events via ws
252-
await self._monitor_queue.put(event)
282+
self._send_to_brm(event)
253283
await self._manifest_queue.put(event)
254284

255285
async def _server(self) -> None:
@@ -288,9 +318,12 @@ async def _server(self) -> None:
288318
await self._events.join()
289319
await self._complete_batch.wait()
290320
await self._batch_processing_queue.join()
291-
event = EETerminated(ensemble=self._ensemble.id_)
292-
await self._monitor_queue.put(event)
293-
print("PUT EETerminated")
321+
try:
322+
await asyncio.wait_for(self._monitoring_result, timeout=5)
323+
except TimeoutError:
324+
logger.warning(
325+
"Not all clients were disconnected when closing zmq server!"
326+
)
294327
logger.debug("Async server exiting.")
295328
finally:
296329
try:

src/ert/run_models/base_run_model.py

+30-102
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,13 @@
1616
from contextlib import contextmanager
1717
from pathlib import Path
1818
from queue import SimpleQueue
19-
from typing import TYPE_CHECKING, Any, Protocol, cast
19+
from typing import TYPE_CHECKING, Any, Protocol
2020

2121
import numpy as np
2222

2323
from _ert.events import (
2424
EESnapshot,
2525
EESnapshotUpdate,
26-
EETerminated,
27-
EEUserCancel,
28-
EEUserDone,
2926
Event,
3027
)
3128
from ert.analysis import ErtAnalysisError, smoother_update
@@ -51,13 +48,9 @@
5148
EvaluatorServerConfig,
5249
Realization,
5350
)
54-
from ert.ensemble_evaluator.evaluator import EventSentinel
55-
from ert.ensemble_evaluator.identifiers import STATUS
51+
from ert.ensemble_evaluator.evaluator import UserCancelled
5652
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot
5753
from ert.ensemble_evaluator.state import (
58-
ENSEMBLE_STATE_CANCELLED,
59-
ENSEMBLE_STATE_FAILED,
60-
ENSEMBLE_STATE_STOPPED,
6154
REALIZATION_STATE_FAILED,
6255
REALIZATION_STATE_FINISHED,
6356
)
@@ -105,10 +98,6 @@ def delete_runpath(run_path: str) -> None:
10598
shutil.rmtree(run_path)
10699

107100

108-
class UserCancelled(Exception):
109-
pass
110-
111-
112101
class _LogAggregration(logging.Handler):
113102
def __init__(self, messages: MutableSequence[str]) -> None:
114103
self.messages = messages
@@ -214,12 +203,13 @@ def __init__(
214203
)
215204
self._iter_snapshot: dict[int, EnsembleSnapshot] = {}
216205
self._status_queue = status_queue
217-
self._end_queue: SimpleQueue[str] = SimpleQueue()
206+
218207
# This holds state about the run model
219208
self.minimum_required_realizations = minimum_required_realizations
220209
self.active_realizations = copy.copy(active_realizations)
221210
self.start_iteration = start_iteration
222211
self.restart = False
212+
self._cancelled = False
223213

224214
@property
225215
def api(self) -> BaseRunModelAPI:
@@ -238,7 +228,7 @@ def reports_dir(self, experiment_name: str) -> str:
238228

239229
def log_at_startup(self) -> None:
240230
keys_to_drop = [
241-
"_end_queue",
231+
"_cancelled",
242232
"_queue_config",
243233
"_status_queue",
244234
"_storage",
@@ -324,7 +314,8 @@ def ensemble_size(self) -> int:
324314
return len(self._initial_realizations_mask)
325315

326316
def cancel(self) -> None:
327-
self._end_queue.put("END")
317+
self._cancelled = True
318+
self._ensemble_evaluator.cancel_gracefully()
328319

329320
def has_failed_realizations(self) -> bool:
330321
return any(self._create_mask_from_failed_realizations())
@@ -552,113 +543,43 @@ def send_snapshot_event(self, event: Event, iteration: int) -> None:
552543
)
553544
)
554545

555-
async def run_monitor(
556-
self,
557-
iteration: int,
558-
evaluator: EnsembleEvaluator,
559-
) -> bool:
560-
try:
561-
heartbeat_interval_: float | None = 0.1
562-
receiver_timeout: float = 60.0
563-
closetracker_received: bool = False
564-
565-
while True:
566-
try:
567-
event = await asyncio.wait_for(
568-
evaluator._monitor_queue.get(), timeout=heartbeat_interval_
569-
)
570-
evaluator._monitor_queue.task_done()
571-
except TimeoutError:
572-
if closetracker_received:
573-
logger.error("Evaluator did not send the TERMINATED event!")
574-
break
575-
event = None
576-
if isinstance(event, EventSentinel):
577-
closetracker_received = True
578-
heartbeat_interval_ = receiver_timeout
579-
print("JONAK - received sentinel")
580-
continue
581-
if type(event) in {
582-
EESnapshot,
583-
EESnapshotUpdate,
584-
}:
585-
event = cast(EESnapshot | EESnapshotUpdate, event)
586-
587-
self.send_snapshot_event(event, iteration)
588-
589-
if event.snapshot.get(STATUS) in {
590-
ENSEMBLE_STATE_STOPPED,
591-
ENSEMBLE_STATE_FAILED,
592-
}:
593-
print("Ensemble was stopped")
594-
logger.debug("observed evaluation stopped event, signal done")
595-
logger.debug("monitor informing server monitor is done...")
596-
597-
done_event = EEUserDone()
598-
await evaluator.handle_client_event(done_event)
599-
logger.debug("monitor informed server monitor is done")
600-
await evaluator._monitor_queue.put(EventSentinel())
601-
602-
if event.snapshot.get(STATUS) == ENSEMBLE_STATE_CANCELLED:
603-
logger.debug(
604-
"observed evaluation cancelled event, exit drainer"
605-
)
606-
raise UserCancelled(
607-
"Experiment cancelled by user during evaluation"
608-
)
609-
elif type(event) is EETerminated:
610-
logger.debug("got terminated event")
611-
break
612-
613-
if not self._end_queue.empty():
614-
print("RUN MODEL WAS CANCELLED")
615-
logger.debug("Run model canceled - during evaluation")
616-
self._end_queue.get()
617-
logger.debug("monitor asking server to cancel...")
618-
cancel_event = EEUserCancel()
619-
await evaluator.handle_client_event(cancel_event)
620-
await evaluator._monitor_queue.put(EventSentinel())
621-
logger.debug("monitor asked server to cancel")
622-
logger.debug("Run model canceled - during evaluation - cancel sent")
623-
except UserCancelled:
624-
raise
625-
except Exception as e:
626-
logger.exception(f"unexpected error: {e}")
627-
# We really don't know what happened... shut down
628-
# the thread and get out of here. The monitor has
629-
# been stopped by the ctx-mgr
630-
return False
631-
632-
return True
633-
634546
async def run_ensemble_evaluator_async(
635547
self,
636548
run_args: list[RunArg],
637549
ensemble: Ensemble,
638550
ee_config: EvaluatorServerConfig,
639551
) -> list[int]:
640-
if not self._end_queue.empty():
552+
if self._cancelled:
641553
logger.debug("Run model canceled - pre evaluation")
642-
self._end_queue.get()
554+
self._cancelled = False
643555
raise UserCancelled("Experiment cancelled by user in pre evaluation")
644556

645557
ee_ensemble = self._build_ensemble(run_args, ensemble.experiment_id)
646-
evaluator = EnsembleEvaluator(ee_ensemble, ee_config)
558+
evaluator = EnsembleEvaluator(
559+
ee_ensemble,
560+
ee_config,
561+
send_to_brm=functools.partial(
562+
self.send_snapshot_event, iteration=ensemble.iteration
563+
),
564+
)
565+
self._ensemble_evaluator = evaluator
647566
evaluator_task = asyncio.create_task(
648567
evaluator.run_and_get_successful_realizations()
649568
)
650569
await evaluator._server_started
651-
if not (await self.run_monitor(ensemble.iteration, evaluator)):
570+
571+
if not (await self._wait_for_evaluator_result(evaluator._monitoring_result)):
652572
await evaluator_task
653573
return []
654574

655575
logger.debug("observed that model was finished, waiting tasks completion...")
656576
# The model has finished, we indicate this by sending a DONE
657577
logger.debug("tasks complete")
658578

659-
if not self._end_queue.empty():
579+
if self._cancelled:
580+
print("YEGA!")
660581
logger.debug("Run model canceled - post evaluation")
661-
self._end_queue.get()
582+
self._cancelled = False
662583
try:
663584
await evaluator_task
664585
except Exception as e:
@@ -668,7 +589,8 @@ async def run_ensemble_evaluator_async(
668589
) from e
669590
print("RAISING USER_CANCELLED")
670591
raise UserCancelled("Experiment cancelled by user in post evaluation")
671-
592+
else:
593+
print("HELL NAH!")
672594
await evaluator_task
673595
ensemble.refresh_ensemble_state()
674596

@@ -849,6 +771,12 @@ def _evaluate_and_postprocess(
849771

850772
return num_successful_realizations
851773

774+
async def _wait_for_evaluator_result(
775+
self, monitoring_future: asyncio.Future[bool]
776+
) -> bool:
777+
"""This helper function is here for the sake of mocking in tests."""
778+
return await monitoring_future
779+
852780

853781
class UpdateRunModel(BaseRunModel):
854782
def __init__(

0 commit comments

Comments
 (0)