diff --git a/beacon_node/lighthouse_network/src/rpc/handler.rs b/beacon_node/lighthouse_network/src/rpc/handler.rs index 33c5521c3b..61eee78ea7 100644 --- a/beacon_node/lighthouse_network/src/rpc/handler.rs +++ b/beacon_node/lighthouse_network/src/rpc/handler.rs @@ -892,6 +892,20 @@ where ConnectionEvent::DialUpgradeError(DialUpgradeError { info, error }) => { self.on_dial_upgrade_error(info, error) } + ConnectionEvent::ListenUpgradeError(e) => { + if matches!(e.error.1, RPCError::InvalidData(_)) { + // Peer is not complying with the protocol. Notify the application and disconnect. + let inbound_substream_id = self.current_inbound_substream_id; + self.current_inbound_substream_id.0 += 1; + + self.events_out.push(HandlerEvent::Err(HandlerErr::Inbound { + id: inbound_substream_id, + proto: e.error.0, + error: e.error.1, + })); + self.shutdown(None); + } + } _ => { // NOTE: ConnectionEvent is a non exhaustive enum so updates should be based on // release notes more than compiler feedback diff --git a/beacon_node/lighthouse_network/src/rpc/protocol.rs b/beacon_node/lighthouse_network/src/rpc/protocol.rs index 820f50ac93..6e8d07d39a 100644 --- a/beacon_node/lighthouse_network/src/rpc/protocol.rs +++ b/beacon_node/lighthouse_network/src/rpc/protocol.rs @@ -652,7 +652,7 @@ where E: EthSpec, { type Output = InboundOutput; - type Error = RPCError; + type Error = (Protocol, RPCError); type Future = BoxFuture<'static, Result>; fn upgrade_inbound(self, socket: TSocket, protocol: ProtocolId) -> Self::Future { @@ -697,10 +697,12 @@ where ) .await { - Err(e) => Err(RPCError::from(e)), + Err(e) => Err((versioned_protocol.protocol(), RPCError::from(e))), Ok((Some(Ok(request)), stream)) => Ok((request, stream)), - Ok((Some(Err(e)), _)) => Err(e), - Ok((None, _)) => Err(RPCError::IncompleteStream), + Ok((Some(Err(e)), _)) => Err((versioned_protocol.protocol(), e)), + Ok((None, _)) => { + Err((versioned_protocol.protocol(), RPCError::IncompleteStream)) + } } } } diff --git a/beacon_node/lighthouse_network/tests/rpc_tests.rs b/beacon_node/lighthouse_network/tests/rpc_tests.rs index 72d7aa0074..fbacef0841 100644 --- a/beacon_node/lighthouse_network/tests/rpc_tests.rs +++ b/beacon_node/lighthouse_network/tests/rpc_tests.rs @@ -4,7 +4,10 @@ mod common; use common::{build_tracing_subscriber, Protocol}; use lighthouse_network::rpc::{methods::*, RequestType}; -use lighthouse_network::service::api_types::AppRequestId; +use lighthouse_network::service::api_types::{ + AppRequestId, BlobsByRangeRequestId, BlocksByRangeRequestId, ComponentsByRangeRequestId, + DataColumnsByRangeRequestId, RangeRequestId, SyncRequestId, +}; use lighthouse_network::{NetworkEvent, ReportSource, Response}; use ssz::Encode; use ssz_types::VariableList; @@ -1413,3 +1416,175 @@ fn test_active_requests() { } }) } + +#[test] +fn test_request_too_large_blocks_by_range() { + let spec = Arc::new(E::default_spec()); + + test_request_too_large( + AppRequestId::Sync(SyncRequestId::BlocksByRange(BlocksByRangeRequestId { + id: 1, + parent_request_id: ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + })), + RequestType::BlocksByRange(OldBlocksByRangeRequest::new( + 0, + spec.max_request_blocks(ForkName::Base) as u64 + 1, // exceeds the max request defined in the spec. + 1, + )), + // Due to the invalid request, the receiver does not respond and closes the stream. + // On the sender's side, the handler sends an end-of-stream to the application because the + // stream has been closed. Therefore, we expect `BlocksByRange(None)` in this test. + Some(Response::BlocksByRange(None)), + ); +} + +#[test] +fn test_request_too_large_blobs_by_range() { + let spec = Arc::new(E::default_spec()); + + let max_request_blobs_count = spec.max_request_blob_sidecars(ForkName::Base) as u64 + / spec.max_blobs_per_block_by_fork(ForkName::Base); + test_request_too_large( + AppRequestId::Sync(SyncRequestId::BlobsByRange(BlobsByRangeRequestId { + id: 1, + parent_request_id: ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + })), + RequestType::BlobsByRange(BlobsByRangeRequest { + start_slot: 0, + count: max_request_blobs_count + 1, // exceeds the max request defined in the spec. + }), + // Due to the invalid request, the receiver does not respond and closes the stream. + // On the sender's side, the handler sends an end-of-stream to the application because the + // stream has been closed. Therefore, we expect `BlobsByRange(None)` in this test. + Some(Response::BlobsByRange(None)), + ); +} + +#[test] +fn test_request_too_large_data_columns_by_range() { + let spec = Arc::new(E::default_spec()); + + test_request_too_large( + AppRequestId::Sync(SyncRequestId::DataColumnsByRange( + DataColumnsByRangeRequestId { + id: 1, + parent_request_id: ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + }, + )), + RequestType::DataColumnsByRange(DataColumnsByRangeRequest { + start_slot: 0, + count: 0, + // exceeds the max request defined in the spec. + columns: vec![0; spec.number_of_columns as usize + 1], + }), + None, + ); +} + +fn test_request_too_large( + app_request_id: AppRequestId, + request: RequestType, + expected_response: Option>, +) { + // Set up the logging. + let log_level = "debug"; + let enable_logging = false; + build_tracing_subscriber(log_level, enable_logging); + let rt = Arc::new(Runtime::new().unwrap()); + let spec = Arc::new(E::default_spec()); + + rt.block_on(async { + let (mut sender, mut receiver) = common::build_node_pair( + Arc::downgrade(&rt), + ForkName::Base, + spec, + Protocol::Tcp, + false, + None, + ) + .await; + + // Build the sender future + let sender_future = async { + let mut is_response_received = false; + let mut is_disconnected = false; + loop { + if (expected_response.is_none() + || (expected_response.is_some() && is_response_received)) + && is_disconnected + { + // End the test. + return; + } + + match sender.next_event().await { + NetworkEvent::PeerConnectedOutgoing(peer_id) => { + debug!(?request, %peer_id, "Sending RPC request"); + sender + .send_request(peer_id, app_request_id, request.clone()) + .unwrap(); + } + NetworkEvent::ResponseReceived { + app_request_id, + response, + .. + } => { + debug!(?app_request_id, ?response, "Received response"); + if let Some(r) = &expected_response { + assert_eq!(&response, r); + is_response_received = true; + } else { + unreachable!(); + } + } + NetworkEvent::RPCFailed { .. } => { + // This variant should be unreachable, as the receiver doesn't respond with an error when a request exceeds the limit. + unreachable!(); + } + NetworkEvent::PeerDisconnected(peer_id) => { + // The receiver should disconnect as a result of the invalid request. + debug!(%peer_id, "Peer disconnected"); + is_disconnected = true; + } + _ => {} + } + } + }; + + // Build the receiver future + let receiver_future = async { + loop { + if let NetworkEvent::RequestReceived { .. } = receiver.next_event().await { + // This event should be unreachable, as the handler drops the invalid request. + unreachable!(); + } + } + }; + + tokio::select! { + _ = sender_future => {} + _ = receiver_future => {} + _ = sleep(Duration::from_secs(30)) => { + panic!("Future timed out"); + } + } + }); +}