diff --git a/crates/ethportal-api/src/lib.rs b/crates/ethportal-api/src/lib.rs index d94097480..1d7ecc69a 100644 --- a/crates/ethportal-api/src/lib.rs +++ b/crates/ethportal-api/src/lib.rs @@ -47,5 +47,7 @@ pub use types::{ execution::{block_body::*, receipts::*}, node_id::*, portal::{RawContentKey, RawContentValue}, + protocol_info::*, + protocol_versions::*, }; pub use web3::{Web3ApiClient, Web3ApiServer}; diff --git a/crates/ethportal-api/src/types/mod.rs b/crates/ethportal-api/src/types/mod.rs index bf65de72d..e8c5e5111 100644 --- a/crates/ethportal-api/src/types/mod.rs +++ b/crates/ethportal-api/src/types/mod.rs @@ -15,6 +15,7 @@ pub mod node_id; pub mod ping_extensions; pub mod portal; pub mod portal_wire; +pub mod protocol_info; pub mod protocol_versions; pub mod query_trace; pub mod state_trie; diff --git a/crates/ethportal-api/src/types/protocol_info.rs b/crates/ethportal-api/src/types/protocol_info.rs new file mode 100644 index 000000000..c031bc9ff --- /dev/null +++ b/crates/ethportal-api/src/types/protocol_info.rs @@ -0,0 +1,190 @@ +use std::io::BufRead; + +use alloy_rlp::{Decodable, Encodable}; +use anyhow::ensure; + +use crate::ProtocolVersion; + +/// ENR key for Portal protocol info. +pub const ENR_PORTAL_KEY: &str = "p"; + +/// The information about active Portal Protocol. +/// +/// Current implementation follows the protocol version 2, specified in +/// [Portal Wire Protocol spec](https://github.com/ethereum/portal-network-specs/blob/dd7b7cbae96a1c54546263d8484f1aa01c5035b9/portal-wire-protocol.md#enr-record). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProtocolInfo { + min_protocol_version: ProtocolVersion, + max_protocol_version: ProtocolVersion, + chain_id: u64, +} + +impl ProtocolInfo { + pub fn new( + min_protocol_version: ProtocolVersion, + max_protocol_version: ProtocolVersion, + chain_id: u64, + ) -> anyhow::Result { + ensure!( + min_protocol_version <= max_protocol_version, + "Min version ({}) must be lower than Max version ({})", + *min_protocol_version, + *max_protocol_version, + ); + Ok(Self { + min_protocol_version, + max_protocol_version, + chain_id, + }) + } + + pub fn min_protocol_version(&self) -> ProtocolVersion { + self.min_protocol_version + } + + pub fn max_protocol_version(&self) -> ProtocolVersion { + self.max_protocol_version + } + + pub fn chain_id(&self) -> u64 { + self.chain_id + } + + pub fn supports(&self, protocol_version: ProtocolVersion) -> bool { + (self.min_protocol_version..=self.max_protocol_version).contains(&protocol_version) + } + + /// Returns the highest common protocol version, or `None` otherwise. + /// + /// It also returns `None` if other is not part of the same chain. + pub fn highest_common_protocol_version(&self, other: &ProtocolInfo) -> Option { + if self.chain_id != other.chain_id { + return None; + } + + let min = *self.min_protocol_version; + let max = *self.max_protocol_version; + (min..=max) + .rev() + .map(ProtocolVersion::from) + .find(|protocol_version| other.supports(*protocol_version)) + } + + fn rlp_payload_length(&self) -> usize { + Encodable::length(&*self.min_protocol_version) + + Encodable::length(&*self.max_protocol_version) + + Encodable::length(&self.chain_id) + } +} + +impl Encodable for ProtocolInfo { + fn length(&self) -> usize { + let payload_length = self.rlp_payload_length(); + payload_length + alloy_rlp::length_of_length(payload_length) + } + + fn encode(&self, out: &mut dyn bytes::BufMut) { + alloy_rlp::Header { + list: true, + payload_length: self.rlp_payload_length(), + } + .encode(out); + self.min_protocol_version.encode(out); + self.max_protocol_version.encode(out); + self.chain_id.encode(out); + } +} + +impl Decodable for ProtocolInfo { + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let alloy_rlp::Header { + list, + payload_length, + } = alloy_rlp::Header::decode(buf)?; + + if !list { + return Err(alloy_rlp::Error::UnexpectedString); + } + let started_len = buf.len(); + if started_len < payload_length { + return Err(alloy_rlp::Error::InputTooShort); + } + let min_protocol_version: u8 = Decodable::decode(buf)?; + let max_protocol_version: u8 = Decodable::decode(buf)?; + let chain_id = Decodable::decode(buf)?; + + let consumed = started_len - buf.len(); + + if consumed > payload_length { + // We shouldn't have consumed more than 'payload_length' + return Err(alloy_rlp::Error::ListLengthMismatch { + expected: payload_length, + got: consumed, + }); + } + if payload_length > consumed { + // Payload can be longer then consumed when peer upgraded the protocol version and + // added more fields but we didn't. + // In that case, we just read and ignore the rest of the payload. + buf.consume(payload_length - consumed); + } + + Self::new( + min_protocol_version.into(), + max_protocol_version.into(), + chain_id, + ) + .map_err(|_| alloy_rlp::Error::Custom("Decoded ProtocolInfo is invalid")) + } +} + +#[cfg(test)] +mod tests { + use alloy::{ + hex::FromHex, + primitives::{bytes, Bytes}, + }; + use rstest::rstest; + + use super::*; + + #[rstest] + #[case::only_v2("0xc3020201", ProtocolVersion::V2, ProtocolVersion::V2, 1)] + #[case::v0_to_v2("0xc3800201", ProtocolVersion::V0, ProtocolVersion::V2, 1)] + #[case::hoodi("0xc6020283088bb0", ProtocolVersion::V2, ProtocolVersion::V2, /* hoodi testnet */ 560048)] + fn encode_decode( + #[case] bytes: String, + #[case] min_protocol_version: ProtocolVersion, + #[case] max_protocol_version: ProtocolVersion, + #[case] chain_id: u64, + ) { + let bytes = Bytes::from_hex(bytes).unwrap(); + let protocol_info = + ProtocolInfo::new(min_protocol_version, max_protocol_version, chain_id).unwrap(); + + assert_eq!(alloy_rlp::encode(protocol_info), bytes.to_vec()); + + assert_eq!( + alloy_rlp::decode_exact::(bytes), + Ok(protocol_info), + ); + } + + /// Tests that decoding rlp bytes that includes unknown version (e.g. 3) and extra bytes works. + #[test] + fn unsupported_protocol() { + let bytes = bytes!("0xc602030182abcd"); + + let expected_protocol_info = ProtocolInfo::new( + ProtocolVersion::V2, + ProtocolVersion::UnspecifiedVersion(3), + 1, + ) + .unwrap(); + + assert_eq!( + alloy_rlp::decode_exact::(&bytes), + Ok(expected_protocol_info) + ); + } +} diff --git a/crates/ethportal-api/src/types/protocol_versions.rs b/crates/ethportal-api/src/types/protocol_versions.rs index f68fe3ef9..473c10400 100644 --- a/crates/ethportal-api/src/types/protocol_versions.rs +++ b/crates/ethportal-api/src/types/protocol_versions.rs @@ -17,10 +17,25 @@ pub enum ProtocolVersion { V0, /// Adds `accept codes` and varint size encoding for find content messages. V1, + /// Uses 'p' ENR key to indicate protocol version and chain id. + V2, /// Unspecified version is a version that we don't know about, but the other side does. UnspecifiedVersion(u8), } +impl Deref for ProtocolVersion { + type Target = u8; + + fn deref(&self) -> &Self::Target { + match self { + ProtocolVersion::V0 => &0, + ProtocolVersion::V1 => &1, + ProtocolVersion::V2 => &2, + ProtocolVersion::UnspecifiedVersion(version) => version, + } + } +} + impl ProtocolVersion { pub fn is_v1_enabled(&self) -> bool { self >= &ProtocolVersion::V1 @@ -29,11 +44,7 @@ impl ProtocolVersion { impl From for u8 { fn from(version: ProtocolVersion) -> u8 { - match version { - ProtocolVersion::V0 => 0, - ProtocolVersion::V1 => 1, - ProtocolVersion::UnspecifiedVersion(version) => version, - } + *version } } @@ -42,6 +53,7 @@ impl From for ProtocolVersion { match version { 0 => ProtocolVersion::V0, 1 => ProtocolVersion::V1, + 2 => ProtocolVersion::V2, version => ProtocolVersion::UnspecifiedVersion(version), } }