Skip to content

Commit 953ea27

Browse files
committed
Restructure CLI.instance.solutions to cancel of all coroutines
1 parent c649cc8 commit 953ea27

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

src/minizinc/CLI/instance.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, cast
2020

2121
import minizinc
22-
from minizinc.error import parse_error
22+
from minizinc.error import MiniZincError, parse_error
2323
from minizinc.instance import Instance
2424
from minizinc.json import (
2525
MZNJSONDecoder,
@@ -345,20 +345,21 @@ async def solutions(
345345
with self.files() as files, self._solver.configuration() as solver:
346346
assert self.output_type is not None
347347
cmd.extend(files)
348-
# Run the MiniZinc process
349-
proc = await self._driver.create_process(cmd, solver=solver)
350-
assert isinstance(proc.stderr, asyncio.StreamReader)
351-
assert isinstance(proc.stdout, asyncio.StreamReader)
352-
353-
# Python 3.7+: replace with asyncio.create_task
354-
read_stderr = asyncio.ensure_future(_read_all(proc.stderr))
355348

356349
status = Status.UNKNOWN
350+
last_status = Status.UNKNOWN
357351
code = 0
358-
remainder: bytes = b""
359352
statistics: Dict[str, Any] = {}
360353

361354
try:
355+
# Run the MiniZinc process
356+
proc = await self._driver.create_process(cmd, solver=solver)
357+
assert isinstance(proc.stderr, asyncio.StreamReader)
358+
assert isinstance(proc.stdout, asyncio.StreamReader)
359+
360+
# Python 3.7+: replace with asyncio.create_task
361+
read_stderr = asyncio.ensure_future(_read_all(proc.stderr))
362+
362363
if self._driver.parsed_version >= (2, 6, 0):
363364
async for obj in decode_async_json_stream(
364365
proc.stdout, cls=MZNJSONDecoder, enum_map=self._enum_map
@@ -372,6 +373,7 @@ async def solutions(
372373
if status == Status.UNKNOWN:
373374
status = Status.SATISFIED
374375
yield Result(status, solution, statistics)
376+
last_status = status
375377
solution = None
376378
statistics = {}
377379
else:
@@ -391,16 +393,8 @@ async def solutions(
391393
# Read remaining text in buffer
392394
code = await proc.wait()
393395
remainder = err.partial
394-
except asyncio.CancelledError as e:
395-
# Process was cancelled by the user.
396-
# Terminate process and read remaining output
397-
proc.terminate()
398-
remainder = await _read_all(proc.stdout)
399396

400-
if isinstance(e, asyncio.CancelledError):
401-
raise
402-
finally:
403-
# parse the remaining statistics
397+
# Parse and output the remaining statistics and status messages
404398
if self._driver.parsed_version >= (2, 6, 0):
405399
for obj in decode_json_stream(
406400
remainder, cls=MZNJSONDecoder, enum_map=self._enum_map
@@ -416,11 +410,6 @@ async def solutions(
416410
yield Result(status, solution, statistics)
417411
solution = None
418412
statistics = {}
419-
if (
420-
status not in [Status.UNKNOWN, Status.SATISFIED]
421-
or statistics != {}
422-
):
423-
yield Result(status, None, statistics)
424413
else:
425414
for res in filter(None, remainder.split(SEPARATOR)):
426415
new_status = Status.from_output(res, method)
@@ -433,17 +422,26 @@ async def solutions(
433422
self._field_renames,
434423
)
435424
yield Result(status, solution, statistics)
436-
437-
# Raise error if required
438-
stderr = None
439-
if code != 0 or status == Status.ERROR:
440-
stderr = await read_stderr
441-
raise parse_error(stderr)
442-
443-
if debug_output is not None:
444-
if stderr is None:
445-
stderr = await read_stderr
446-
debug_output.write_bytes(stderr)
425+
except (asyncio.CancelledError, MiniZincError, Exception):
426+
# Process was cancelled by the user, a MiniZincError occurred, or
427+
# an unexpected Python exception occurred
428+
# First, terminate the process
429+
proc.terminate()
430+
_ = await proc.wait()
431+
# Then, reraise the error that occurred
432+
raise
433+
if self._driver.parsed_version >= (2, 6, 0) and (
434+
status != last_status or statistics != {}
435+
):
436+
yield Result(status, None, statistics)
437+
438+
# Raise error if required
439+
stderr = await read_stderr
440+
if code != 0 or status == Status.ERROR:
441+
raise parse_error(stderr)
442+
443+
if debug_output is not None:
444+
debug_output.write_bytes(stderr)
447445

448446
@contextlib.contextmanager
449447
def flat(

0 commit comments

Comments
 (0)