diff --git a/chia/cmds/init_funcs.py b/chia/cmds/init_funcs.py index c79aaa040fa3..e6ef1ec2a9ec 100644 --- a/chia/cmds/init_funcs.py +++ b/chia/cmds/init_funcs.py @@ -257,11 +257,12 @@ def init( def chia_version_number() -> Tuple[str, str, str, str]: - scm_full_version = __version__ - left_full_version = scm_full_version.split("+") + return get_version_numbers(__version__) - version = left_full_version[0].split(".") +def get_version_numbers(scm_full_version: str) -> Tuple[str, str, str, str]: + left_full_version = scm_full_version.split("+") + version = left_full_version[0].split(".") scm_major_version = version[0] scm_minor_version = version[1] if len(version) > 2: @@ -269,11 +270,9 @@ def chia_version_number() -> Tuple[str, str, str, str]: patch_release_number = smc_patch_version else: smc_patch_version = "" - major_release_number = scm_major_version minor_release_number = scm_minor_version dev_release_number = "" - # If this is a beta dev release - get which beta it is if "0b" in scm_minor_version: original_minor_ver_list = scm_minor_version.split("0b") @@ -294,13 +293,11 @@ def chia_version_number() -> Tuple[str, str, str, str]: minor_release_number = scm_minor_version patch_release_number = smc_patch_version dev_release_number = "" - install_release_number = major_release_number + "." + minor_release_number if len(patch_release_number) > 0: install_release_number += "." + patch_release_number if len(dev_release_number) > 0: install_release_number += dev_release_number - return major_release_number, minor_release_number, patch_release_number, dev_release_number diff --git a/chia/protocols/shared_protocol.py b/chia/protocols/shared_protocol.py index 4eb6ea19b45c..255328a133cb 100644 --- a/chia/protocols/shared_protocol.py +++ b/chia/protocols/shared_protocol.py @@ -49,3 +49,10 @@ class Handshake(Streamable): (uint16(Capability.RATE_LIMITS_V2.value), "1"), (uint16(Capability.NONE_RESPONSE.value), "1"), ] + +# capabilities to send to nodes prior to 1.6.2 +limitedcapabilties = [ + (uint16(Capability.BASE.value), "1"), + (uint16(Capability.BLOCK_HEADERS.value), "1"), + (uint16(Capability.RATE_LIMITS_V2.value), "1"), +] diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index 0082907105ca..c05593013825 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -3,6 +3,7 @@ import asyncio import contextlib import logging +import re import time import traceback from dataclasses import dataclass, field @@ -14,11 +15,11 @@ from aiohttp.web import WebSocketResponse from typing_extensions import Protocol, final -from chia.cmds.init_funcs import chia_full_version_str +from chia.cmds.init_funcs import chia_full_version_str, get_version_numbers from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.protocol_state_machine import message_requires_reply, message_response_ok from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS -from chia.protocols.shared_protocol import Capability, Handshake +from chia.protocols.shared_protocol import Capability, Handshake, limitedcapabilties from chia.server.capabilities import known_active_capabilities from chia.server.outbound_message import Message, NodeType, make_msg from chia.server.rate_limits import RateLimiter @@ -172,23 +173,23 @@ async def perform_handshake( server_port: int, local_type: NodeType, ) -> None: - outbound_handshake = make_msg( - ProtocolMessageTypes.handshake, - Handshake( - network_id, - protocol_version, - chia_full_version_str(), - uint16(server_port), - uint8(local_type.value), - self.local_capabilities_for_handshake, - ), - ) if self.is_outbound: + outbound_handshake = make_msg( + ProtocolMessageTypes.handshake, + Handshake( + network_id, + protocol_version, + chia_full_version_str(), + uint16(server_port), + uint8(local_type.value), + self.local_capabilities_for_handshake, + ), + ) await self._send_message(outbound_handshake) inbound_handshake_msg = await self._read_one_message() if inbound_handshake_msg is None: raise ProtocolError(Err.INVALID_HANDSHAKE) - inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data) + inbound_handshake: Handshake = Handshake.from_bytes(inbound_handshake_msg.data) # Handle case of invalid ProtocolMessageType try: @@ -229,7 +230,22 @@ async def perform_handshake( inbound_handshake = Handshake.from_bytes(message.data) if inbound_handshake.network_id != network_id: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) + + outbound_handshake = make_msg( + ProtocolMessageTypes.handshake, + Handshake( + network_id, + protocol_version, + chia_full_version_str(), + uint16(server_port), + uint8(local_type.value), + self.get_capabilties_for_version(inbound_handshake.software_version), + ), + ) + await self._send_message(outbound_handshake) + self.version = inbound_handshake.software_version + self.protocol_version = inbound_handshake.protocol_version self.peer_server_port = inbound_handshake.server_port self.connection_type = NodeType(inbound_handshake.node_type) # "1" means capability is enabled @@ -675,3 +691,19 @@ def get_peer_logging(self) -> PeerInfo: def has_capability(self, capability: Capability) -> bool: return capability in self.peer_capabilities + + # only send limitedcapabilties to peers before 1.6.2 + # see https://github.com/Chia-Network/chia-blockchain/commit/618f93b4c42b176659cc74c02a4dd711adc62052 + # for why that version was selected + def get_capabilties_for_version(self, software_version: str) -> list[tuple[uint16, str]]: + major, minor, patch, _ = get_version_numbers(software_version) + patch_number = 0 + try: + path_split = re.findall("[0-9]+", patch) # extract number from patch string + if len(path_split) > 1: + patch_number = path_split[0] + if (int(major), int(minor), int(patch_number)) < (1, 6, 2): + return limitedcapabilties + except Exception: + self.log.error(f"could not parse incoming version {software_version}, returning limited capabilities") + return self.local_capabilities_for_handshake diff --git a/tests/core/server/test_server.py b/tests/core/server/test_server.py index fba4736e9dcc..4e8042b88409 100644 --- a/tests/core/server/test_server.py +++ b/tests/core/server/test_server.py @@ -1,15 +1,20 @@ from __future__ import annotations +import logging from typing import Tuple import pytest from chia.full_node.full_node_api import FullNodeAPI +from chia.protocols.shared_protocol import Capability +from chia.server import ws_connection from chia.server.server import ChiaServer from chia.simulator.block_tools import BlockTools from chia.types.peer_info import PeerInfo from chia.util.ints import uint16 +log = logging.getLogger(__name__) + @pytest.mark.asyncio async def test_duplicate_client_connection( @@ -18,3 +23,39 @@ async def test_duplicate_client_connection( _, _, server_1, server_2, _ = two_nodes assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) assert not await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("version", ["1.6.1b657", "1.6.abcb657", "1.6.0abcb657", "1.5.0abcb657"]) +async def test_capabilities_back_comp_before_162( # type: ignore + monkeypatch, + two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools], + self_hostname: str, + version: str, +) -> None: + def getverold() -> str: + return version + + monkeypatch.setattr(ws_connection, "chia_full_version_str", getverold) + _, _, server_1, server_2, _ = two_nodes + assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) + con = server_2.get_connections()[0] + assert con.has_capability(Capability.NONE_RESPONSE) is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("version", ["1.6.2b657", "1.6.3b657"]) +async def test_capabilities_back_comp_fix_162_or_above( # type: ignore + monkeypatch, + two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools], + self_hostname: str, + version: str, +) -> None: + def getver162() -> str: + return version + + monkeypatch.setattr(ws_connection, "chia_full_version_str", getver162) + _, _, server_1, server_2, _ = two_nodes + assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) + con = server_2.get_connections()[0] + assert con.has_capability(Capability.NONE_RESPONSE) is True