diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5415a82398c7..f73338582192 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,13 @@ repos: entry: ./activated.py --poetry poetry check language: system pass_filenames: false + - repo: local + hooks: + - id: shielding + name: Check for proper async cancellation shielding + entry: ./activated.py chia dev check shielding + language: system + pass_filenames: false - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 hooks: diff --git a/chia/cmds/check.py b/chia/cmds/check.py new file mode 100644 index 000000000000..9fa4e749d57c --- /dev/null +++ b/chia/cmds/check.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import click + +from chia.cmds.check_func import check_shielding + + +@click.group(name="check", help="Project checks such as might be run in CI") +def check_group() -> None: + pass + + +@check_group.command(name="shielding") +@click.option("--use-file-ignore/--no-file-ignore", default=True) +def shielding_command(use_file_ignore: bool) -> None: + count = check_shielding(use_file_ignore=use_file_ignore) + + message = f"{count} concerns found" + if count > 0: + raise click.ClickException(message) + else: + print(message) diff --git a/chia/cmds/check_func.py b/chia/cmds/check_func.py new file mode 100644 index 000000000000..bf72269665bb --- /dev/null +++ b/chia/cmds/check_func.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import pathlib +import re + + +def check_shielding(use_file_ignore: bool) -> int: + exclude = {"mozilla-ca"} + roots = [path.parent for path in sorted(pathlib.Path(".").glob("*/__init__.py")) if path.parent.name not in exclude] + + total_count = 0 + paths_with_concerns: set[pathlib.Path] = set() + for root in roots: + for path in sorted(root.glob("**/*.py")): + lines = path.read_text().splitlines() + + for line_index, line in enumerate(lines): + line_number = line_index + 1 + + this_match = re.search(r"^ *(async def [^(]*(close|stop)|(except|finally)\b)[^:]*:", line) + if this_match is not None: + next_line_index = line_index + 1 + if next_line_index < len(lines): + next_line = lines[line_index + 1] + + ignore_match = re.search(r"^ *# shielding not required: .{10,}", next_line) + if ignore_match is not None: + continue + + next_match = re.search(r"^ *with anyio.CancelScope\(shield=True\):", next_line) + else: + next_match = None + if next_match is None: + for def_line in reversed(lines[:line_index]): + def_match = re.search(r"^ *def", def_line) + if def_match is not None: + # not async, doesn't need to be shielded + break + + async_def_match = re.search(r"^ *async def", def_line) + if async_def_match is not None: + paths_with_concerns.add(path) + if not (use_file_ignore and path in hardcoded_file_ignore_list): + total_count += 1 + print(f"{path.as_posix()}:{line_number}: {line}") + break + + if use_file_ignore: + for path in hardcoded_file_ignore_list - paths_with_concerns: + total_count += 1 + source = pathlib.Path(__file__).relative_to(pathlib.Path(".").absolute()) + line_number = next( + line_number + for line_number, line in enumerate(source.read_text().splitlines(), start=1) + if line.find(f'"{path.as_posix()}"') >= 0 + ) + print(f"{source.as_posix()}:{line_number}: {path.as_posix()} unnecessarily in file ignore list") + + return total_count + + +# chia dev check shielding --no-file-ignore &| sed -nE 's/^([^:]*):[0-9]*:.*/ "\1",/p' | sort | uniq +hardcoded_file_ignore_list_strings = { + "benchmarks/utils.py", + "chia/cmds/check_wallet_db.py", + "chia/cmds/cmds_util.py", + "chia/cmds/coin_funcs.py", + "chia/cmds/dao_funcs.py", + "chia/cmds/farm_funcs.py", + "chia/cmds/passphrase_funcs.py", + "chia/cmds/peer_funcs.py", + "chia/cmds/plotnft_funcs.py", + "chia/cmds/rpc.py", + "chia/cmds/sim_funcs.py", + "chia/cmds/start_funcs.py", + "chia/cmds/wallet_funcs.py", + "chia/consensus/multiprocess_validation.py", + "chia/daemon/client.py", + "chia/daemon/keychain_proxy.py", + "chia/daemon/keychain_server.py", + "chia/daemon/server.py", + "chia/data_layer/data_layer.py", + "chia/data_layer/data_layer_wallet.py", + "chia/data_layer/data_store.py", + "chia/data_layer/download_data.py", + "chia/data_layer/s3_plugin_service.py", + "chia/farmer/farmer_api.py", + "chia/farmer/farmer.py", + "chia/full_node/block_height_map.py", + "chia/full_node/block_store.py", + "chia/full_node/full_node_api.py", + "chia/full_node/full_node.py", + "chia/full_node/mempool_manager.py", + "chia/full_node/mempool.py", + "chia/harvester/harvester_api.py", + "chia/harvester/harvester.py", + "chia/introducer/introducer.py", + "chia/plot_sync/receiver.py", + "chia/plot_sync/sender.py", + "chia/plotting/create_plots.py", + "chia/pools/pool_wallet.py", + "chia/rpc/crawler_rpc_api.py", + "chia/rpc/data_layer_rpc_api.py", + "chia/rpc/farmer_rpc_client.py", + "chia/rpc/full_node_rpc_api.py", + "chia/rpc/full_node_rpc_client.py", + "chia/rpc/rpc_client.py", + "chia/rpc/rpc_server.py", + "chia/rpc/util.py", + "chia/rpc/wallet_rpc_api.py", + "chia/rpc/wallet_rpc_client.py", + "chia/seeder/crawler.py", + "chia/seeder/crawl_store.py", + "chia/seeder/dns_server.py", + "chia/server/address_manager_store.py", + "chia/server/chia_policy.py", + "chia/server/node_discovery.py", + "chia/server/server.py", + "chia/server/signal_handlers.py", + "chia/server/start_service.py", + "chia/server/ws_connection.py", + "chia/simulator/block_tools.py", + "chia/simulator/setup_services.py", + "chia/_tests/blockchain/blockchain_test_utils.py", + "chia/_tests/clvm/test_singletons.py", + "chia/_tests/conftest.py", + "chia/_tests/core/data_layer/test_data_rpc.py", + "chia/_tests/core/data_layer/test_data_store.py", + "chia/_tests/core/full_node/full_sync/test_full_sync.py", + "chia/_tests/core/full_node/ram_db.py", + "chia/_tests/core/full_node/stores/test_coin_store.py", + "chia/_tests/core/mempool/test_mempool_manager.py", + "chia/_tests/core/server/flood.py", + "chia/_tests/core/server/serve.py", + "chia/_tests/core/server/test_event_loop.py", + "chia/_tests/core/server/test_loop.py", + "chia/_tests/core/services/test_services.py", + "chia/_tests/core/test_farmer_harvester_rpc.py", + "chia/_tests/core/test_full_node_rpc.py", + "chia/_tests/db/test_db_wrapper.py", + "chia/_tests/environments/wallet.py", + "chia/_tests/pools/test_pool_rpc.py", + "chia/_tests/pools/test_wallet_pool_store.py", + "chia/_tests/rpc/test_rpc_client.py", + "chia/_tests/rpc/test_rpc_server.py", + "chia/_tests/simulation/test_start_simulator.py", + "chia/_tests/util/blockchain.py", + "chia/_tests/util/misc.py", + "chia/_tests/util/spend_sim.py", + "chia/_tests/util/split_managers.py", + "chia/_tests/util/test_async_pool.py", + "chia/_tests/util/test_priority_mutex.py", + "chia/_tests/util/time_out_assert.py", + "chia/_tests/wallet/clawback/test_clawback_metadata.py", + "chia/_tests/wallet/dao_wallet/test_dao_wallets.py", + "chia/_tests/wallet/nft_wallet/test_nft_bulk_mint.py", + "chia/_tests/wallet/rpc/test_dl_wallet_rpc.py", + "chia/_tests/wallet/rpc/test_wallet_rpc.py", + "chia/_tests/wallet/sync/test_wallet_sync.py", + "chia/timelord/timelord_api.py", + "chia/timelord/timelord_launcher.py", + "chia/timelord/timelord.py", + "chia/types/eligible_coin_spends.py", + "chia/util/action_scope.py", + "chia/util/async_pool.py", + "chia/util/beta_metrics.py", + "chia/util/db_version.py", + "chia/util/db_wrapper.py", + "chia/util/files.py", + "chia/util/limited_semaphore.py", + "chia/util/network.py", + "chia/util/priority_mutex.py", + "chia/util/profiler.py", + "chia/wallet/cat_wallet/cat_wallet.py", + "chia/wallet/cat_wallet/dao_cat_wallet.py", + "chia/wallet/dao_wallet/dao_wallet.py", + "chia/wallet/did_wallet/did_wallet.py", + "chia/wallet/nft_wallet/nft_wallet.py", + "chia/wallet/notification_store.py", + "chia/wallet/trade_manager.py", + "chia/wallet/trading/trade_store.py", + "chia/wallet/vc_wallet/cr_cat_wallet.py", + "chia/wallet/vc_wallet/vc_wallet.py", + "chia/wallet/wallet_coin_store.py", + "chia/wallet/wallet_nft_store.py", + "chia/wallet/wallet_node_api.py", + "chia/wallet/wallet_node.py", + "chia/wallet/wallet_puzzle_store.py", + "chia/wallet/wallet_state_manager.py", + "chia/wallet/wallet_transaction_store.py", +} + +hardcoded_file_ignore_list = {pathlib.Path(path_string) for path_string in hardcoded_file_ignore_list_strings} diff --git a/chia/cmds/dev.py b/chia/cmds/dev.py index c23d05570d32..379c9f7910ab 100644 --- a/chia/cmds/dev.py +++ b/chia/cmds/dev.py @@ -2,6 +2,7 @@ import click +from chia.cmds.check import check_group from chia.cmds.installers import installers_group from chia.cmds.sim import sim_cmd @@ -14,3 +15,4 @@ def dev_cmd(ctx: click.Context) -> None: dev_cmd.add_command(sim_cmd) dev_cmd.add_command(installers_group) +dev_cmd.add_command(check_group) diff --git a/chia/consensus/blockchain.py b/chia/consensus/blockchain.py index 621927a1e136..6731f0a51368 100644 --- a/chia/consensus/blockchain.py +++ b/chia/consensus/blockchain.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import enum import logging @@ -10,6 +11,7 @@ from pathlib import Path from typing import TYPE_CHECKING, ClassVar, Optional, cast +import anyio from chia_rs import additions_and_removals, get_flags_for_height_and_constants from chia.consensus.block_body_validation import ForkInfo, validate_block_body @@ -419,20 +421,19 @@ async def add_block( self._peak_height = block_record.height except BaseException as e: - # depending on exactly when the failure of adding the block - # happened, we may not have added it to the block record cache - try: - self.remove_block_record(header_hash) - except KeyError: - pass - fork_info.rollback(header_hash, -1 if previous_peak_height is None else previous_peak_height) - self.block_store.rollback_cache_block(header_hash) - self._peak_height = previous_peak_height - log.error( - f"Error while adding block {header_hash} height {block.height}," - f" rolling back: {traceback.format_exc()} {e}" - ) - raise + with anyio.CancelScope(shield=True): + # depending on exactly when the failure of adding the block + # happened, we may not have added it to the block record cache + with contextlib.suppress(KeyError): + self.remove_block_record(header_hash) + fork_info.rollback(header_hash, -1 if previous_peak_height is None else previous_peak_height) + self.block_store.rollback_cache_block(header_hash) + self._peak_height = previous_peak_height + log.error( + f"Error while adding block {header_hash} height {block.height}," + f" rolling back: {traceback.format_exc()} {e}" + ) + raise # This is done outside the try-except in case it fails, since we do not want to revert anything if it does await self.__height_map.maybe_flush()