Skip to content

patch capabilities for old clients #14105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
11 changes: 4 additions & 7 deletions chia/cmds/init_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,23 +257,22 @@ def init(


def chia_version_number() -> Tuple[str, str, str, str]:
scm_full_version = __version__
left_full_version = scm_full_version.split("+")
return get_version_numbers(__version__)

version = left_full_version[0].split(".")

def get_version_numbers(scm_full_version: str) -> Tuple[str, str, str, str]:
left_full_version = scm_full_version.split("+")
version = left_full_version[0].split(".")
scm_major_version = version[0]
scm_minor_version = version[1]
if len(version) > 2:
smc_patch_version = version[2]
patch_release_number = smc_patch_version
else:
smc_patch_version = ""

major_release_number = scm_major_version
minor_release_number = scm_minor_version
dev_release_number = ""

# If this is a beta dev release - get which beta it is
if "0b" in scm_minor_version:
original_minor_ver_list = scm_minor_version.split("0b")
Expand All @@ -294,13 +293,11 @@ def chia_version_number() -> Tuple[str, str, str, str]:
minor_release_number = scm_minor_version
patch_release_number = smc_patch_version
dev_release_number = ""

install_release_number = major_release_number + "." + minor_release_number
if len(patch_release_number) > 0:
install_release_number += "." + patch_release_number
if len(dev_release_number) > 0:
install_release_number += dev_release_number

return major_release_number, minor_release_number, patch_release_number, dev_release_number


Expand Down
7 changes: 7 additions & 0 deletions chia/protocols/shared_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ class Handshake(Streamable):
(uint16(Capability.RATE_LIMITS_V2.value), "1"),
(uint16(Capability.NONE_RESPONSE.value), "1"),
]

# capabilities to send to nodes prior to 1.6.2
limitedcapabilties = [
(uint16(Capability.BASE.value), "1"),
(uint16(Capability.BLOCK_HEADERS.value), "1"),
(uint16(Capability.RATE_LIMITS_V2.value), "1"),
]
60 changes: 46 additions & 14 deletions chia/server/ws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import contextlib
import logging
import re
import time
import traceback
from dataclasses import dataclass, field
Expand All @@ -14,11 +15,11 @@
from aiohttp.web import WebSocketResponse
from typing_extensions import Protocol, final

from chia.cmds.init_funcs import chia_full_version_str
from chia.cmds.init_funcs import chia_full_version_str, get_version_numbers
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_requires_reply, message_response_ok
from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.protocols.shared_protocol import Capability, Handshake, limitedcapabilties
from chia.server.capabilities import known_active_capabilities
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
Expand Down Expand Up @@ -172,23 +173,23 @@ async def perform_handshake(
server_port: int,
local_type: NodeType,
) -> None:
outbound_handshake = make_msg(
ProtocolMessageTypes.handshake,
Handshake(
network_id,
protocol_version,
chia_full_version_str(),
uint16(server_port),
uint8(local_type.value),
self.local_capabilities_for_handshake,
),
)
if self.is_outbound:
outbound_handshake = make_msg(
ProtocolMessageTypes.handshake,
Handshake(
network_id,
protocol_version,
chia_full_version_str(),
uint16(server_port),
uint8(local_type.value),
self.local_capabilities_for_handshake,
),
)
await self._send_message(outbound_handshake)
inbound_handshake_msg = await self._read_one_message()
if inbound_handshake_msg is None:
raise ProtocolError(Err.INVALID_HANDSHAKE)
inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data)
inbound_handshake: Handshake = Handshake.from_bytes(inbound_handshake_msg.data)

# Handle case of invalid ProtocolMessageType
try:
Expand Down Expand Up @@ -229,7 +230,22 @@ async def perform_handshake(
inbound_handshake = Handshake.from_bytes(message.data)
if inbound_handshake.network_id != network_id:
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)

outbound_handshake = make_msg(
ProtocolMessageTypes.handshake,
Handshake(
network_id,
protocol_version,
chia_full_version_str(),
uint16(server_port),
uint8(local_type.value),
self.get_capabilties_for_version(inbound_handshake.software_version),
),
)

await self._send_message(outbound_handshake)
self.version = inbound_handshake.software_version
self.protocol_version = inbound_handshake.protocol_version
self.peer_server_port = inbound_handshake.server_port
self.connection_type = NodeType(inbound_handshake.node_type)
# "1" means capability is enabled
Expand Down Expand Up @@ -675,3 +691,19 @@ def get_peer_logging(self) -> PeerInfo:

def has_capability(self, capability: Capability) -> bool:
return capability in self.peer_capabilities

# only send limitedcapabilties to peers before 1.6.2
# see https://github.com/Chia-Network/chia-blockchain/commit/618f93b4c42b176659cc74c02a4dd711adc62052
# for why that version was selected
def get_capabilties_for_version(self, software_version: str) -> list[tuple[uint16, str]]:
major, minor, patch, _ = get_version_numbers(software_version)
patch_number = 0
try:
path_split = re.findall("[0-9]+", patch) # extract number from patch string
if len(path_split) > 1:
patch_number = path_split[0]
if (int(major), int(minor), int(patch_number)) < (1, 6, 2):
return limitedcapabilties
except Exception:
self.log.error(f"could not parse incoming version {software_version}, returning limited capabilities")
return self.local_capabilities_for_handshake
41 changes: 41 additions & 0 deletions tests/core/server/test_server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

import logging
from typing import Tuple

import pytest

from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.shared_protocol import Capability
from chia.server import ws_connection
from chia.server.server import ChiaServer
from chia.simulator.block_tools import BlockTools
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16

log = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_duplicate_client_connection(
Expand All @@ -18,3 +23,39 @@ async def test_duplicate_client_connection(
_, _, server_1, server_2, _ = two_nodes
assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
assert not await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)


@pytest.mark.asyncio
@pytest.mark.parametrize("version", ["1.6.1b657", "1.6.abcb657", "1.6.0abcb657", "1.5.0abcb657"])
async def test_capabilities_back_comp_before_162( # type: ignore
monkeypatch,
two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools],
self_hostname: str,
version: str,
) -> None:
def getverold() -> str:
return version

monkeypatch.setattr(ws_connection, "chia_full_version_str", getverold)
_, _, server_1, server_2, _ = two_nodes
assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
con = server_2.get_connections()[0]
assert con.has_capability(Capability.NONE_RESPONSE) is False


@pytest.mark.asyncio
@pytest.mark.parametrize("version", ["1.6.2b657", "1.6.3b657"])
async def test_capabilities_back_comp_fix_162_or_above( # type: ignore
monkeypatch,
two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools],
self_hostname: str,
version: str,
) -> None:
def getver162() -> str:
return version

monkeypatch.setattr(ws_connection, "chia_full_version_str", getver162)
_, _, server_1, server_2, _ = two_nodes
assert await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
con = server_2.get_connections()[0]
assert con.has_capability(Capability.NONE_RESPONSE) is True