Skip to content

Commit da6631e

Browse files
wjsifyrestone
andauthored
[BACKPORT]Fix duplicate exceptions in log (#2723) (#2736)
Co-authored-by: Liu Bao <[email protected]>
1 parent 2662b26 commit da6631e

File tree

2 files changed

+72
-29
lines changed

2 files changed

+72
-29
lines changed

mars/deploy/oscar/session.py

+35-29
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def __await__(self):
124124
self._ensure_future()
125125
return self._future_local.aio_future.__await__()
126126

127+
def get_future(self):
128+
self._ensure_future()
129+
return self._future_local.aio_future
130+
127131

128132
warning_msg = """
129133
No session found, local session \
@@ -1670,6 +1674,27 @@ def __exit__(self, *_):
16701674
self.close()
16711675

16721676

1677+
async def _execute_with_progress(
1678+
execution_info: ExecutionInfo,
1679+
progress_bar: ProgressBar,
1680+
progress_update_interval: Union[int, float],
1681+
cancelled: asyncio.Event,
1682+
):
1683+
with progress_bar:
1684+
while not cancelled.is_set():
1685+
done, _pending = await asyncio.wait(
1686+
[execution_info.get_future()], timeout=progress_update_interval
1687+
)
1688+
if not done:
1689+
if not cancelled.is_set() and execution_info.progress() is not None:
1690+
progress_bar.update(execution_info.progress() * 100)
1691+
else:
1692+
# done
1693+
if not cancelled.is_set():
1694+
progress_bar.update(100)
1695+
break
1696+
1697+
16731698
async def _execute(
16741699
*tileables: Tuple[TileableType],
16751700
session: _IsolatedSession = None,
@@ -1691,39 +1716,20 @@ def _attach_session(future: asyncio.Future):
16911716
if wait:
16921717
progress_bar = ProgressBar(show_progress)
16931718
if progress_bar.show_progress:
1694-
with progress_bar:
1695-
while not cancelled.is_set():
1696-
try:
1697-
await asyncio.wait_for(
1698-
asyncio.shield(execution_info), progress_update_interval
1699-
)
1700-
# done
1701-
if not cancelled.is_set():
1702-
progress_bar.update(100)
1703-
break
1704-
except asyncio.TimeoutError:
1705-
# timeout
1706-
if (
1707-
not cancelled.is_set()
1708-
and execution_info.progress() is not None
1709-
):
1710-
progress_bar.update(execution_info.progress() * 100)
1711-
if cancelled.is_set():
1712-
# cancel execution
1713-
execution_info.cancel()
1714-
execution_info.remove_done_callback(_attach_session)
1715-
await execution_info
1719+
await _execute_with_progress(
1720+
execution_info, progress_bar, progress_update_interval, cancelled
1721+
)
17161722
else:
17171723
await asyncio.wait(
17181724
[execution_info, cancelled.wait()], return_when=asyncio.FIRST_COMPLETED
17191725
)
1720-
if cancelled.is_set():
1721-
execution_info.remove_done_callback(_attach_session)
1722-
execution_info.cancel()
1723-
else:
1724-
# set cancelled to avoid wait task leak
1725-
cancelled.set()
1726-
await execution_info
1726+
if cancelled.is_set():
1727+
execution_info.remove_done_callback(_attach_session)
1728+
execution_info.cancel()
1729+
else:
1730+
# set cancelled to avoid wait task leak
1731+
cancelled.set()
1732+
await execution_info
17271733
else:
17281734
return execution_info
17291735

mars/deploy/oscar/tests/test_local.py

+37
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ....storage import StorageLevel
3838
from ....services.storage import StorageAPI
3939
from ....tensor.arithmetic.add import TensorAdd
40+
from ....tests.core import mock
4041
from ..local import new_cluster
4142
from ..service import load_config
4243
from ..session import (
@@ -48,7 +49,10 @@
4849
fetch_infos,
4950
stop_server,
5051
AsyncSession,
52+
ExecutionInfo,
53+
Progress,
5154
_IsolatedWebSession,
55+
_execute_with_progress,
5256
)
5357
from .modules.utils import ( # noqa: F401; pylint: disable=unused-variable
5458
cleanup_third_party_modules_output,
@@ -574,3 +578,36 @@ def test_load_third_party_modules(cleanup_third_party_modules_output): # noqa:
574578

575579
session.stop_server()
576580
assert get_default_session() is None
581+
582+
583+
@mock.patch("asyncio.base_events.logger")
584+
def test_show_progress_raise_exception(m_log):
585+
loop = asyncio.get_event_loop()
586+
event = asyncio.Event()
587+
588+
class ProgressBar:
589+
def __init__(self, *args, **kwargs):
590+
pass
591+
592+
def __enter__(self):
593+
pass
594+
595+
def __exit__(self, *_):
596+
pass
597+
598+
def update(self, progress: float):
599+
pass
600+
601+
async def _exec():
602+
progress = Progress()
603+
execution_info = ExecutionInfo(
604+
asyncio.create_task(event.wait()), progress, loop
605+
)
606+
progress_bar = ProgressBar(True)
607+
cancel_event = asyncio.Event()
608+
loop.call_later(2, cancel_event.set)
609+
await _execute_with_progress(execution_info, progress_bar, 0.01, cancel_event)
610+
execution_info.get_future().set_exception(Exception("Expect Exception!!!"))
611+
612+
loop.run_until_complete(_exec())
613+
assert len(m_log.mock_calls) < 3

0 commit comments

Comments
 (0)