|
1 | 1 | # Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | | - |
4 | | -from typing import TYPE_CHECKING, Callable, TypeVar |
| 3 | +import asyncio |
| 4 | +import concurrent.futures |
| 5 | +from typing import TYPE_CHECKING, Callable, Optional, TypeVar |
5 | 6 |
|
6 | 7 | if TYPE_CHECKING: |
7 | 8 | from . import Connection |
|
15 | 16 | "ConnectionClosedError", |
16 | 17 | "_rewrite_exceptions", |
17 | 18 | "_translate_exceptions", |
| 19 | + "_allow_cancel", |
18 | 20 | ] |
19 | 21 |
|
20 | 22 |
|
@@ -95,9 +97,46 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): |
95 | 97 | raise ConnectionClosedError() from exc_val |
96 | 98 |
|
97 | 99 |
|
| 100 | +class AllowCancellation: |
| 101 | + def __init__(self, allow: Callable[[], bool]): |
| 102 | + self.allow = allow |
| 103 | + |
| 104 | + def __enter__(self): |
| 105 | + return self |
| 106 | + |
| 107 | + def __exit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: |
| 108 | + if exc_val is not None: |
| 109 | + # These two exceptions are actually the same type under the hood--however, nothing |
| 110 | + # that I can find in the Python docs suggest that this HAS to be the case. Either |
| 111 | + # way, suppress cancellation errors that happen when we are closed, because . |
| 112 | + return self.allow() and ( |
| 113 | + exc_type is asyncio.CancelledError or exc_type is concurrent.futures.CancelledError |
| 114 | + ) |
| 115 | + |
| 116 | + return None |
| 117 | + |
| 118 | + async def __aenter__(self): |
| 119 | + return self |
| 120 | + |
| 121 | + async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: |
| 122 | + if exc_val is not None: |
| 123 | + # These two exceptions are actually the same type under the hood--however, nothing |
| 124 | + # that I can find in the Python docs suggest that this HAS to be the case. Either |
| 125 | + # way, suppress cancellation errors that happen when we are closed, because . |
| 126 | + return self.allow() and ( |
| 127 | + exc_type is asyncio.CancelledError or exc_type is concurrent.futures.CancelledError |
| 128 | + ) |
| 129 | + |
| 130 | + return None |
| 131 | + |
| 132 | + |
98 | 133 | def _translate_exceptions(conn: "Connection") -> ExceptionTranslator: |
99 | 134 | """ |
100 | 135 | Return an (async) context manager that translates exceptions thrown from low-level gRPC calls |
101 | 136 | to high-level dazl exceptions. |
102 | 137 | """ |
103 | 138 | return ExceptionTranslator(conn) |
| 139 | + |
| 140 | + |
| 141 | +def _allow_cancel(allow: Callable[[], bool]) -> AllowCancellation: |
| 142 | + return AllowCancellation(allow) |
0 commit comments