diff --git a/.gitignore b/.gitignore index ea8c4bf7f..19c012f2b 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ /target +.idea + diff --git a/Cargo.lock b/Cargo.lock index 451d6fa91..7d22258fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1026,15 +1026,6 @@ dependencies = [ "rustc_version", ] -[[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -2796,7 +2787,7 @@ dependencies = [ "pin-project-lite 0.2.16", "quinn-proto", "quinn-udp", - "rustc-hash", + "rustc-hash 1.1.0", "rustls 0.20.9", "thiserror 1.0.69", "tokio", @@ -2813,7 +2804,7 @@ dependencies = [ "bytes", "rand 0.8.5", "ring 0.16.20", - "rustc-hash", + "rustc-hash 1.1.0", "rustls 0.20.9", "slab", "thiserror 1.0.69", @@ -3049,6 +3040,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -3187,17 +3184,17 @@ dependencies = [ [[package]] name = "sctp-proto" -version = "0.3.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4dea4fe3384a24652f065296ac333c810dfd0c5b39b98a2214762c16aaadc3c" +checksum = "423139d8cca3021b9d800f084a711ba2d23b508ae71b33dba167f11ca33e54c7" dependencies = [ "bytes", "crc", - "fxhash", "log", - "rand 0.8.5", + "rand 0.9.2", + "rustc-hash 2.1.1", "slab", - "thiserror 1.0.69", + "thiserror 2.0.17", ] [[package]] @@ -3480,9 +3477,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "str0m" -version = "0.9.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7f9fdffeb677e1d5d2cf84993f865143680b8e5eece3396fa488d43b34c1245" +checksum = "26890ff5b60e33eb8bedcf44792fc459c8f348ecbf2658edb19477571e547ac2" dependencies = [ "combine", "crc", @@ -3495,7 +3492,6 @@ dependencies = [ "sctp-proto", "serde", "sha1", - "thiserror 1.0.69", "tracing", ] diff --git a/Cargo.toml b/Cargo.toml index a4eba5881..02587ca3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,7 +65,7 @@ rcgen = { version = "0.14.5", optional = true } # End of Quic related dependencies. # WebRTC related dependencies. WebRTC is an experimental feature flag. The dependencies must be updated. -str0m = { version = "0.9.0", optional = true } +str0m = { version = "0.11.1", optional = true } # End of WebRTC related dependencies. # Fuzzing related dependencies. diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 0a0970b4d..86c22647c 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -22,18 +22,19 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, - error::{self, Error, ParseError}, + error::{self, Error, ParseError, SubstreamError}, multistream_select::{ + drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, - ProtocolError, + ProtocolError, PROTO_MULTISTREAM_1_0, }, Negotiated, NegotiationError, Version, }, types::protocol::ProtocolName, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::prelude::*; use std::{ convert::TryFrom as _, @@ -357,24 +358,59 @@ impl WebRtcDialerState { &mut self, payload: Vec, ) -> Result { - let Message::Protocols(protocols) = - Message::decode(payload.into()).map_err(|_| ParseError::InvalidData)? - else { - return Err(crate::error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - )); + // All multistream-select messages are length-prefixed. Since this code path is not using + // multistream_select::protocol::MessageIO, we need to decode and remove the length here. + let remaining: &[u8] = &payload; + let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| { + tracing::debug!( + target: LOG_TARGET, + ?error, + message = ?payload, + "Failed to decode length-prefix in multistream message"); + error::NegotiationError::ParseError(ParseError::InvalidData) + })?; + + let len_size = remaining.len() - tail.len(); + let bytes = Bytes::from(payload); + let payload = bytes.slice(len_size..len_size + len); + let remaining = bytes.slice(len_size + len..); + let message = Message::decode(payload); + + tracing::trace!( + target: LOG_TARGET, + ?message, + "Decoded message while registering response", + ); + + let mut protocols = match message { + Ok(Message::Header(HeaderLine::V1)) => { + vec![PROTO_MULTISTREAM_1_0] + } + Ok(Message::Protocol(protocol)) => vec![protocol], + Ok(Message::Protocols(protocols)) => protocols, + Ok(Message::NotAvailable) => + return match &self.state { + HandshakeState::WaitingProtocol => Err( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + ), + _ => Err(error::NegotiationError::StateMismatch), + }, + Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch), + Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), }; + match drain_trailing_protocols(remaining) { + Ok(protos) => protocols.extend(protos), + Err(error) => return Err(error), + } + let mut protocol_iter = protocols.into_iter(); loop { match (&self.state, protocol_iter.next()) { (HandshakeState::WaitingResponse, None) => return Err(crate::error::NegotiationError::StateMismatch), (HandshakeState::WaitingResponse, Some(protocol)) => { - let header = Protocol::try_from(&b"/multistream/1.0.0"[..]) - .expect("valid multitstream-select header"); - - if protocol == header { + if protocol == PROTO_MULTISTREAM_1_0 { self.state = HandshakeState::WaitingProtocol; } else { return Err(crate::error::NegotiationError::MultistreamSelectError( @@ -383,6 +419,10 @@ impl WebRtcDialerState { } } (HandshakeState::WaitingProtocol, Some(protocol)) => { + if protocol == PROTO_MULTISTREAM_1_0 { + return Err(crate::error::NegotiationError::StateMismatch); + } + if self.protocol.as_bytes() == protocol.as_ref() { return Ok(HandshakeResult::Succeeded(self.protocol.clone())); } @@ -408,10 +448,9 @@ impl WebRtcDialerState { #[cfg(test)] mod tests { use super::*; - use crate::multistream_select::listener_select_proto; + use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0}; + use bytes::BufMut; use std::time::Duration; - use tokio::net::{TcpListener, TcpStream}; - #[tokio::test] async fn select_proto_basic() { async fn run(version: Version) { @@ -755,23 +794,18 @@ mod tests { fn propose() { let (mut dialer_state, message) = WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); - let message = bytes::BytesMut::from(&message[..]).freeze(); - let Message::Protocols(protocols) = Message::decode(message).unwrap() else { - panic!("invalid message type"); - }; + let mut bytes = BytesMut::with_capacity(32); + bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); - assert_eq!(protocols.len(), 2); - assert_eq!( - protocols[0], - Protocol::try_from(&b"/multistream/1.0.0"[..]) - .expect("valid multitstream-select header") - ); - assert_eq!( - protocols[1], - Protocol::try_from(&b"/13371338/proto/1"[..]) - .expect("valid multitstream-select header") - ); + let proto = Protocol::try_from(&b"/13371338/proto/1"[..]).expect("valid protocol name"); + bytes.put_u8((proto.as_ref().len() + 1) as u8); // + 1 for \n + let _ = Message::Protocol(proto).encode(&mut bytes).unwrap(); + + let expected_message = bytes.freeze().to_vec(); + + assert_eq!(message, expected_message); } #[test] @@ -781,33 +815,29 @@ mod tests { vec![ProtocolName::from("/sup/proto/1")], ) .unwrap(); - let message = bytes::BytesMut::from(&message[..]).freeze(); - let Message::Protocols(protocols) = Message::decode(message).unwrap() else { - panic!("invalid message type"); - }; + let mut bytes = BytesMut::with_capacity(32); + bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); - assert_eq!(protocols.len(), 3); - assert_eq!( - protocols[0], - Protocol::try_from(&b"/multistream/1.0.0"[..]) - .expect("valid multitstream-select header") - ); - assert_eq!( - protocols[1], - Protocol::try_from(&b"/13371338/proto/1"[..]) - .expect("valid multitstream-select header") - ); - assert_eq!( - protocols[2], - Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid multitstream-select header") - ); + let proto1 = Protocol::try_from(&b"/13371338/proto/1"[..]).expect("valid protocol name"); + bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n + let _ = Message::Protocol(proto1).encode(&mut bytes).unwrap(); + + let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); + bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n + let _ = Message::Protocol(proto2).encode(&mut bytes).unwrap(); + + let expected_message = bytes.freeze().to_vec(); + + assert_eq!(message, expected_message); } #[test] - fn register_response_invalid_message() { - // send only header line + fn register_response_header_only() { let mut bytes = BytesMut::with_capacity(32); + bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let message = Message::Header(HeaderLine::V1); message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); @@ -815,7 +845,8 @@ mod tests { WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); match dialer_state.register_response(bytes.freeze().to_vec()) { - Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} + Ok(HandshakeResult::NotReady) => {} + Err(err) => panic!("unexpected error: {:?}", err), event => panic!("invalid event: {event:?}"), } } @@ -823,17 +854,20 @@ mod tests { #[test] fn header_line_missing() { // header line missing - let mut bytes = BytesMut::with_capacity(256); - let message = Message::Protocols(vec![ - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ]); - message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + let proto = b"/13371338/proto/1"; + let mut bytes = BytesMut::with_capacity(proto.len() + 2); + bytes.put_u8((proto.len() + 1) as u8); + + let response = Message::Protocol(Protocol::try_from(&proto[..]).unwrap()) + .encode(&mut bytes) + .expect("valid message encodes"); + + let response = bytes.freeze().to_vec(); let (mut dialer_state, _message) = WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); - match dialer_state.register_response(bytes.freeze().to_vec()) { + match dialer_state.register_response(response) { Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} event => panic!("invalid event: {event:?}"), } diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index f574005fe..6d73efcb7 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -25,9 +25,10 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error}, multistream_select::{ + drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, - ProtocolError, + ProtocolError, PROTO_MULTISTREAM_1_0, }, Negotiated, NegotiationError, }, @@ -349,22 +350,14 @@ pub enum ListenerSelectResult { /// response and the negotiated protocol. If parsing fails or no match is found, return an error. pub fn webrtc_listener_negotiate<'a>( supported_protocols: &'a mut impl Iterator, - payload: Bytes, + mut payload: Bytes, ) -> crate::Result { - let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)? - else { - return Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - )); - }; + let protocols = drain_trailing_protocols(payload)?; + let mut protocol_iter = protocols.into_iter(); // skip the multistream-select header because it's not part of user protocols but verify it's // present - let mut protocol_iter = protocols.into_iter(); - let header = - Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header"); - - if protocol_iter.next() != Some(header) { + if protocol_iter.next() != Some(PROTO_MULTISTREAM_1_0) { return Err(Error::NegotiationError( error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), )); @@ -402,6 +395,8 @@ pub fn webrtc_listener_negotiate<'a>( #[cfg(test)] mod tests { use super::*; + use crate::error; + use bytes::BufMut; #[test] fn webrtc_listener_negotiate_works() { @@ -437,6 +432,21 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; + // The invalid message is really two multistream-select messages inside one `WebRtcMessage`: + // 1. the multistream-select header + // 2. an "ls response" message (that does not contain another header) + // + // This is invalid for two reasons: + // 1. It is malformed. Either the header is followed by one or more `Message::Protocol` + // instances or the header is part of the "ls response". + // 2. This sequence of messages is not spec compliant. A listener receives one of the + // following on an inbound substream: + // - a multistream-select header followed by a `Message::Protocol` instance + // - a multistream-select header followed by an "ls" message (<\n>) + // + // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be + // `InvalidData` because the message is malformed or `StateMismatch` because the message is + // not expected at this point in the protocol. let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![ Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), @@ -445,7 +455,13 @@ mod tests { .freeze(); match webrtc_listener_negotiate(&mut local_protocols.iter(), message) { - Err(error) => assert!(std::matches!(error, Error::InvalidData)), + Err(error) => assert!(std::matches!( + error, + // something has gone off the rails here... + Error::NegotiationError(error::NegotiationError::ParseError( + error::ParseError::InvalidData + )), + )), _ => panic!("invalid event"), } } @@ -468,9 +484,9 @@ mod tests { match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { Err(error) => assert!(std::matches!( error, - Error::NegotiationError(error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed - )) + Error::NegotiationError(error::NegotiationError::ParseError( + error::ParseError::InvalidData + )), )), event => panic!("invalid event: {event:?}"), } @@ -488,11 +504,15 @@ mod tests { // header line missing let mut bytes = BytesMut::with_capacity(256); - let message = Message::Protocols(vec![ - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ]); - message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + vec![&b"/13371338/proto/1"[..], &b"/sup/proto/1"[..]] + .into_iter() + .for_each(|proto| { + bytes.put_u8((proto.len() + 1) as u8); + + Message::Protocol(Protocol::try_from(proto).unwrap()) + .encode(&mut bytes) + .unwrap(); + }); match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { Err(error) => assert!(std::matches!( diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index b28093b3a..f195b1f3d 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -75,6 +75,7 @@ mod listener_select; mod negotiated; mod protocol; +use crate::error::{self, ParseError}; pub use crate::multistream_select::{ dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState}, listener_select::{ @@ -82,9 +83,13 @@ pub use crate::multistream_select::{ ListenerSelectResult, }, negotiated::{Negotiated, NegotiatedComplete, NegotiationError}, - protocol::{HeaderLine, Message, Protocol, ProtocolError}, + protocol::{HeaderLine, Message, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0}, }; +use bytes::Bytes; + +const LOG_TARGET: &str = "litep2p::multistream-select"; + /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { @@ -132,3 +137,63 @@ impl Default for Version { Version::V1 } } + +// This function is only used in the WebRTC transport. It expects one or more multistream-select +// messages in `remaining` and returns a list of protocols that were decoded from them. +fn drain_trailing_protocols( + mut remaining: Bytes, +) -> Result, error::NegotiationError> { + let mut protocols = vec![]; + + loop { + if remaining.is_empty() { + break; + } + + let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { + tracing::debug!( + target: LOG_TARGET, + ?error, + message = ?remaining, + "Failed to decode length-prefix in multistream message"); + error::NegotiationError::ParseError(ParseError::InvalidData) + })?; + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + message = ?tail, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); + } + + let len_size = remaining.len() - tail.len(); + let payload = remaining.slice(len_size..len_size + len); + let res = Message::decode(payload); + + match res { + Ok(Message::Header(HeaderLine::V1)) => protocols.push(PROTO_MULTISTREAM_1_0), + Ok(Message::Protocol(protocol)) => protocols.push(protocol), + Ok(Message::Protocols(_)) => + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + message = ?tail[..len], + "Failed to decode multistream message", + ); + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); + } + _ => return Err(error::NegotiationError::StateMismatch), + } + + remaining = remaining.slice(len_size + len..); + } + + Ok(protocols) +} diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index b24b31da2..71775df9a 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -54,6 +54,8 @@ pub const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n"; const MSG_PROTOCOL_NA: &[u8] = b"na\n"; /// The encoded form of a multistream-select 'ls' message. const MSG_LS: &[u8] = b"ls\n"; +/// A Protocol instance for the `/multistream/1.0.0` header line. +pub const PROTO_MULTISTREAM_1_0: Protocol = Protocol(Bytes::from_static(b"/multistream/1.0.0")); /// Logging target. const LOG_TARGET: &str = "litep2p::multistream-select"; @@ -230,7 +232,7 @@ impl Message { /// /// # Note /// -/// This is implementation is not compliant with the multistream-select protocol spec. +/// This implementation may not be compliant with the multistream-select protocol spec. /// The only purpose of this was to get the `multistream-select` protocol working with smoldot. pub fn webrtc_encode_multistream_message( messages: impl IntoIterator, @@ -249,9 +251,6 @@ pub fn webrtc_encode_multistream_message( header.append(&mut proto_bytes); } - // For the `Message::Protocols` to be interpreted correctly, it must be followed by a newline. - header.push(b'\n'); - Ok(BytesMut::from(&header[..])) } diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index 6c1e57462..c47e0024d 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -382,6 +382,7 @@ impl WebRtcConnection { target: LOG_TARGET, peer = ?self.peer, ?channel_id, + data_len = ?data.len(), "handle opening outbound substream", ); @@ -396,7 +397,7 @@ impl WebRtcConnection { target: LOG_TARGET, peer = ?self.peer, ?channel_id, - "multisteam-select handshake not ready", + "multistream-select handshake not ready", ); self.channels.insert( diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs index 90fcfcc8c..ba6e2af7a 100644 --- a/src/transport/webrtc/opening.rs +++ b/src/transport/webrtc/opening.rs @@ -176,14 +176,14 @@ impl OpeningWebRtcConnection { .rtc .direct_api() .remote_dtls_fingerprint() - .clone() - .expect("fingerprint to exist"); + .expect("fingerprint to exist") + .clone(); Self::fingerprint_to_bytes(&fingerprint) } /// Get local fingerprint as bytes. fn local_fingerprint(&mut self) -> Vec { - Self::fingerprint_to_bytes(&self.rtc.direct_api().local_dtls_fingerprint()) + Self::fingerprint_to_bytes(self.rtc.direct_api().local_dtls_fingerprint()) } /// Convert `Fingerprint` to bytes. @@ -268,8 +268,8 @@ impl OpeningWebRtcConnection { .rtc .direct_api() .remote_dtls_fingerprint() - .clone() .expect("fingerprint to exist") + .clone() .bytes; const MULTIHASH_SHA256_CODE: u64 = 0x12;