diff --git a/rtc-interceptor/Cargo.toml b/rtc-interceptor/Cargo.toml index f7c77409..3497c12b 100644 --- a/rtc-interceptor/Cargo.toml +++ b/rtc-interceptor/Cargo.toml @@ -20,5 +20,6 @@ rtp.workspace = true rtcp.workspace = true rand.workspace = true log.workspace = true +bytes.workspace = true [dev-dependencies] \ No newline at end of file diff --git a/rtc-interceptor/src/gcc/mod.rs b/rtc-interceptor/src/gcc/mod.rs new file mode 100644 index 00000000..2c90b932 --- /dev/null +++ b/rtc-interceptor/src/gcc/mod.rs @@ -0,0 +1,673 @@ +//! GCC (Google Congestion Control) sender-side bandwidth estimator. +//! +//! Reads TWCC feedback packets generated by the remote [`TwccReceiverInterceptor`] and +//! estimates available send bandwidth using a trendline delay filter and AIMD rate +//! controller, following RFC 8698 / draft-ietf-rmcat-gcc-02. +//! +//! # Chain position +//! +//! GCC must be **inner** to [`TwccSenderInterceptor`] so that `handle_write` sees +//! packets after the TWCC sequence number has already been stamped. +//! Add GCC first (innermost), then TwccSender: +//! +//! ```text +//! Registry::new() +//! .with(gcc_builder.build()) // inner +//! .with(TwccSenderBuilder::new()...) // outer — stamps TWCC seq, then calls GCC +//! ``` +//! +//! `handle_write` flows outer → inner: TwccSender stamps first, then GCC reads it. +//! `handle_read` (TWCC feedback) also flows outer → inner: TwccSender passes through, +//! then GCC processes the feedback. +//! +//! # Usage +//! +//! ```ignore +//! use rtc_interceptor::{Registry, TwccSenderBuilder, GccInterceptorBuilder}; +//! +//! let (gcc_builder, gcc_handle) = GccInterceptorBuilder::new(); +//! let chain = Registry::new() +//! .with(gcc_builder.build()) // GCC inner +//! .with(TwccSenderBuilder::new().build()) // TwccSender outer +//! .build(); +//! +//! // In the application polling loop: +//! if let Some(bps) = gcc_handle.target_bitrate_bps() { +//! encoder.set_bitrate(bps); +//! } +//! ``` +//! +//! [`TwccSenderInterceptor`]: crate::TwccSenderInterceptor +//! [`TwccReceiverInterceptor`]: crate::TwccReceiverInterceptor + +mod rate_controller; +mod trendline; + +use crate::stream_info::StreamInfo; +use crate::{Interceptor, Packet, TaggedPacket, interceptor}; +use log::trace; +use rate_controller::AimdRateController; +use rtcp::transport_feedbacks::transport_layer_cc::{ + PacketStatusChunk, SymbolTypeTcc, TransportLayerCc, +}; +use rtp::extension::transport_cc_extension::TransportCcExtension; +use shared::error::Error; +use shared::marshal::Unmarshal; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::{Arc, Mutex}; +use std::time::Instant; +use trendline::TrendlineFilter; + +/// Shared state written by [`GccInterceptor`] and read by [`GccHandle`]. +pub struct GccShared { + /// Latest estimated available send bandwidth, in bits per second. + /// `None` until the first TWCC feedback has been processed. + pub target_bitrate_bps: Option, +} + +/// Read-only handle to the latest GCC bandwidth estimate. +/// +/// Clone-able and thread-safe — share freely across application threads. +#[derive(Clone)] +pub struct GccHandle { + inner: Arc>, +} + +impl GccHandle { + /// Returns the latest estimated available send bandwidth in bps, or `None` + /// if no TWCC feedback has been processed yet. + pub fn target_bitrate_bps(&self) -> Option { + self.inner + .lock() + .unwrap_or_else(|e| { + log::warn!("gcc: shared mutex poisoned, recovering last estimate"); + e.into_inner() + }) + .target_bitrate_bps + } +} + +/// Builder for [`GccInterceptor`]. +/// +/// Create with [`GccInterceptorBuilder::new`], which returns `(builder, handle)`. +pub struct GccInterceptorBuilder

{ + min_bitrate_bps: u32, + max_bitrate_bps: u32, + shared: Arc>, + _phantom: PhantomData

, +} + +impl

GccInterceptorBuilder

{ + /// Create a new builder and the associated [`GccHandle`]. + /// + /// The handle can be shared with the application so it can read the + /// latest bitrate estimate each polling tick. + pub fn new() -> (Self, GccHandle) { + let shared = Arc::new(Mutex::new(GccShared { + target_bitrate_bps: None, + })); + let handle = GccHandle { + inner: shared.clone(), + }; + let builder = Self { + min_bitrate_bps: 30_000, + max_bitrate_bps: 2_500_000, + shared, + _phantom: PhantomData, + }; + (builder, handle) + } + + /// Minimum send rate floor (default 30 kbps). + pub fn with_min_bitrate(mut self, bps: u32) -> Self { + self.min_bitrate_bps = bps; + self + } + + /// Maximum send rate ceiling (default 2.5 Mbps). + pub fn with_max_bitrate(mut self, bps: u32) -> Self { + self.max_bitrate_bps = bps; + self + } + + /// Build the interceptor factory closure. + pub fn build(self) -> impl FnOnce(P) -> GccInterceptor

{ + move |inner| GccInterceptor { + inner, + local_streams: HashMap::new(), + send_times: HashMap::new(), + start_time: Instant::now(), + trendline: TrendlineFilter::new(), + rate_controller: AimdRateController::new( + self.min_bitrate_bps as f64, + self.max_bitrate_bps as f64, + ), + shared: self.shared, + } + } +} + +/// Sender-side GCC bandwidth estimator interceptor. +/// +/// Intercepts outgoing RTP to record send-times (keyed by TWCC transport-wide +/// sequence number), then processes incoming TWCC feedback RTCP packets to +/// update the trendline delay filter and AIMD rate controller. The resulting +/// bitrate estimate is written to the [`GccHandle`] shared state. +#[derive(Interceptor)] +pub struct GccInterceptor

{ + #[next] + inner: P, + + /// SSRC → TWCC header extension ID, populated in `bind_local_stream`. + local_streams: HashMap, + + /// Send-time log: transport-wide seq → send_time_ms_from_start. + /// + /// Ring-buffer semantics: entries more than 512 sequence numbers behind the + /// current head are pruned to bound memory usage. + send_times: HashMap, + + /// Monotonic baseline for converting `Instant` values to milliseconds. + start_time: Instant, + + trendline: TrendlineFilter, + rate_controller: AimdRateController, + + /// Shared state — written here, read via [`GccHandle`]. + shared: Arc>, +} + +#[interceptor] +impl GccInterceptor

{ + /// Snoop on outgoing RTP: record send_time_ms keyed by TWCC seq. + /// + /// Must be inner to [`TwccSenderInterceptor`] so the sequence number has + /// already been stamped by the time this method is called. + #[overrides] + fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> { + if let Packet::Rtp(ref rtp) = msg.message + && let Some(&ext_id) = self.local_streams.get(&rtp.header.ssrc) + && let Some(ext_bytes) = rtp.header.get_extension(ext_id) + { + let mut buf = ext_bytes.clone(); + if let Ok(tcc_ext) = TransportCcExtension::unmarshal(&mut buf) { + let seq = tcc_ext.transport_sequence; + let send_ms = msg.now.duration_since(self.start_time).as_secs_f64() * 1000.0; + self.send_times.insert(seq, send_ms); + + // Prune entries > 512 sequence numbers behind the current head. + if self.send_times.len() > 512 { + self.send_times.retain(|&s, _| seq.wrapping_sub(s) <= 512); + } + } + } + self.inner.handle_write(msg) + } + + /// Intercept incoming TWCC feedback RTCP and update the bandwidth estimate. + #[overrides] + fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> { + if let Packet::Rtcp(ref rtcp_pkts) = msg.message { + let now = msg.now; + for fb in rtcp_pkts + .iter() + .filter_map(|pkt| pkt.as_any().downcast_ref::()) + { + self.process_feedback(fb, now); + } + } + self.inner.handle_read(msg) + } + + /// Register a local stream and record its TWCC extension ID. + #[overrides] + fn bind_local_stream(&mut self, info: &StreamInfo) { + use crate::twcc::stream_supports_twcc; + if let Some(ext_id) = stream_supports_twcc(info) + && ext_id != 0 + { + self.local_streams.insert(info.ssrc, ext_id); + } + self.inner.bind_local_stream(info); + } + + /// Remove a local stream. + #[overrides] + fn unbind_local_stream(&mut self, info: &StreamInfo) { + self.local_streams.remove(&info.ssrc); + self.inner.unbind_local_stream(info); + } +} + +impl GccInterceptor

{ + /// Process one TWCC feedback packet: reconstruct timing pairs, update + /// trendline + rate controller, publish new estimate to shared state. + fn process_feedback(&mut self, fb: &TransportLayerCc, now: Instant) { + let pairs = self.extract_send_recv_pairs(fb); + if pairs.is_empty() { + return; + } + + let loss = loss_fraction(fb); + + for (send_ms, recv_ms) in pairs { + self.trendline.update(send_ms, recv_ms); + } + + let signal = self.trendline.signal(); + let estimate = self.rate_controller.update(signal, loss, now); + + trace!( + "gcc: signal={:?} loss={:.1}% estimate={:.0}bps", + signal, + loss * 100.0, + estimate + ); + + self.shared + .lock() + .unwrap_or_else(|e| { + log::warn!("gcc: shared mutex poisoned, recovering to publish estimate"); + e.into_inner() + }) + .target_bitrate_bps = Some(estimate as u32); + } + + /// Reconstruct `(send_time_ms, recv_time_ms)` pairs from a TWCC feedback packet. + /// + /// Only packets that are in the local `send_times` log AND were marked as + /// received in the feedback are included. + fn extract_send_recv_pairs(&self, fb: &TransportLayerCc) -> Vec<(f64, f64)> { + // Base receive time: reference_time is 24-bit in 64 ms units → microseconds. + let base_recv_us = (fb.reference_time as i64) * 64_000; + let mut recv_us = base_recv_us; + let mut delta_idx = 0usize; + let mut seq = fb.base_sequence_number; + let mut result = Vec::new(); + + for chunk in &fb.packet_chunks { + // Iterate chunk contents inline to avoid allocating a Vec per chunk. + match chunk { + PacketStatusChunk::RunLengthChunk(rlc) => { + for _ in 0..rlc.run_length { + self.process_status( + rlc.packet_status_symbol, + fb, + &mut recv_us, + &mut delta_idx, + seq, + &mut result, + ); + seq = seq.wrapping_add(1); + } + } + PacketStatusChunk::StatusVectorChunk(svc) => { + for &status in &svc.symbol_list { + self.process_status( + status, + fb, + &mut recv_us, + &mut delta_idx, + seq, + &mut result, + ); + seq = seq.wrapping_add(1); + } + } + } + } + + result + } + + /// Helper: process one packet status symbol during send/recv pair extraction. + #[inline] + fn process_status( + &self, + status: SymbolTypeTcc, + fb: &TransportLayerCc, + recv_us: &mut i64, + delta_idx: &mut usize, + seq: u16, + result: &mut Vec<(f64, f64)>, + ) { + if matches!( + status, + SymbolTypeTcc::PacketReceivedSmallDelta | SymbolTypeTcc::PacketReceivedLargeDelta + ) && *delta_idx < fb.recv_deltas.len() + { + *recv_us += fb.recv_deltas[*delta_idx].delta; + *delta_idx += 1; + + if let Some(&send_ms) = self.send_times.get(&seq) { + let recv_ms = *recv_us as f64 / 1000.0; + result.push((send_ms, recv_ms)); + } + } + } +} + +/// Compute the fraction of packets not received in a TWCC feedback packet. +fn loss_fraction(fb: &TransportLayerCc) -> f64 { + let mut total = 0u32; + let mut received = 0u32; + + for chunk in &fb.packet_chunks { + match chunk { + PacketStatusChunk::RunLengthChunk(rlc) => { + total += rlc.run_length as u32; + if rlc.packet_status_symbol != SymbolTypeTcc::PacketNotReceived { + received += rlc.run_length as u32; + } + } + PacketStatusChunk::StatusVectorChunk(svc) => { + for sym in &svc.symbol_list { + total += 1; + if *sym != SymbolTypeTcc::PacketNotReceived { + received += 1; + } + } + } + } + } + + if total == 0 { + 0.0 + } else { + 1.0 - received as f64 / total as f64 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{RTPHeaderExtension, Registry, twcc::TRANSPORT_CC_URI}; + use rtcp::transport_feedbacks::transport_layer_cc::{ + RecvDelta, RunLengthChunk, StatusChunkTypeTcc, TYPE_TCC_DELTA_SCALE_FACTOR, + }; + use sansio::Protocol; + use shared::{TransportContext, TransportMessage}; + use std::time::Duration; + + fn stream_info_with_twcc(ssrc: u32) -> StreamInfo { + StreamInfo { + ssrc, + rtp_header_extensions: vec![RTPHeaderExtension { + uri: TRANSPORT_CC_URI.to_string(), + id: 5, + }], + ..Default::default() + } + } + + fn make_rtp(ssrc: u32, rtp_seq: u16, now: Instant) -> TaggedPacket { + TransportMessage { + now, + transport: TransportContext::default(), + message: Packet::Rtp(rtp::Packet { + header: rtp::header::Header { + ssrc, + sequence_number: rtp_seq, + ..Default::default() + }, + ..Default::default() + }), + } + } + + fn make_twcc_feedback( + base_seq: u16, + recv_times_us: &[i64], // per-packet receive times in microseconds (absolute) + now: Instant, + ) -> TaggedPacket { + let base_recv_us = recv_times_us.first().copied().unwrap_or(0); + // reference_time is 24-bit in 64ms units. + let reference_time = (base_recv_us / 64_000) as u32; + let base_ref_us = reference_time as i64 * 64_000; + + let mut deltas = Vec::new(); + let mut prev_us = base_ref_us; + for &t in recv_times_us { + let delta_us = t - prev_us; + deltas.push(RecvDelta { + type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, + delta: delta_us, + }); + prev_us = t; + } + + let chunk = PacketStatusChunk::RunLengthChunk(RunLengthChunk { + type_tcc: StatusChunkTypeTcc::RunLengthChunk, + packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, + run_length: recv_times_us.len() as u16, + }); + + let fb = TransportLayerCc { + sender_ssrc: 0, + media_ssrc: 0, + base_sequence_number: base_seq, + packet_status_count: recv_times_us.len() as u16, + reference_time, + fb_pkt_count: 0, + packet_chunks: vec![chunk], + recv_deltas: deltas, + }; + + TransportMessage { + now, + transport: TransportContext::default(), + message: Packet::Rtcp(vec![Box::new(fb)]), + } + } + + /// Build a chain with TwccSender (outer) → GccInterceptor (inner). + /// + /// GCC must be inner so `handle_write` sees TWCC-stamped packets. + fn make_chain() -> ( + impl Protocol + crate::Interceptor, + GccHandle, + ) { + use crate::TwccSenderBuilder; + let (gcc_builder, handle) = GccInterceptorBuilder::new(); + let chain = Registry::new() + .with(gcc_builder.build()) // GCC inner + .with(TwccSenderBuilder::new().build()) // TwccSender outer + .build(); + (chain, handle) + } + + #[test] + fn test_handle_returns_estimate_after_feedback() { + let (mut chain, handle) = make_chain(); + let ssrc = 1001u32; + chain.bind_local_stream(&stream_info_with_twcc(ssrc)); + + let base = Instant::now(); + // Send 5 packets; TwccSender will stamp seq 0..4. + for i in 0..5u16 { + chain + .handle_write(make_rtp( + ssrc, + i, + base + Duration::from_millis(i as u64 * 33), + )) + .unwrap(); + } + + // Feed back receive times: 33ms inter-packet spacing, 20ms constant one-way delay. + // Units: microseconds. 20ms = 20_000µs; spacing 33ms = 33_000µs. + let recv_us: Vec = (0..5).map(|i: i64| i * 33_000 + 20_000).collect(); + + chain + .handle_read(make_twcc_feedback( + 0, + &recv_us, + base + Duration::from_millis(200), + )) + .unwrap(); + + assert!( + handle.target_bitrate_bps().is_some(), + "should have an estimate after feedback" + ); + } + + #[test] + fn test_no_estimate_before_feedback() { + let (mut chain, handle) = make_chain(); + let ssrc = 2001u32; + chain.bind_local_stream(&stream_info_with_twcc(ssrc)); + chain + .handle_write(make_rtp(ssrc, 0, Instant::now())) + .unwrap(); + // No feedback yet. + assert!(handle.target_bitrate_bps().is_none()); + } + + #[test] + fn test_non_twcc_rtcp_passes_through() { + let (mut chain, handle) = make_chain(); + let rtcp_pkt = TransportMessage { + now: Instant::now(), + transport: TransportContext::default(), + message: Packet::Rtcp(vec![Box::new(rtcp::receiver_report::ReceiverReport { + ssrc: 42, + ..Default::default() + })]), + }; + // Should not panic; non-TWCC RTCP is forwarded. + chain.handle_read(rtcp_pkt).unwrap(); + assert!(handle.target_bitrate_bps().is_none()); + } + + #[test] + fn test_unbind_removes_stream() { + let (mut chain, _handle) = make_chain(); + let ssrc = 3001u32; + let info = stream_info_with_twcc(ssrc); + chain.bind_local_stream(&info); + chain.unbind_remote_stream(&info); // unbind (remote stream - no-op here) + chain.unbind_local_stream(&info); + // No panic — stream was removed cleanly. + } + + #[test] + fn test_send_time_not_recorded_without_twcc_ext() { + // RTP sent on an SSRC not bound with TWCC should not populate send_times. + let (mut chain, _handle) = make_chain(); + let ssrc = 4001u32; + // Bind WITHOUT TWCC ext. + let info = StreamInfo { + ssrc, + rtp_header_extensions: vec![], + ..Default::default() + }; + chain.bind_local_stream(&info); + // Just verify no panic and no estimate without TWCC feedback. + chain + .handle_write(make_rtp(ssrc, 1, Instant::now())) + .unwrap(); + } + + #[test] + fn test_loss_fraction_all_received() { + let chunk = PacketStatusChunk::RunLengthChunk(RunLengthChunk { + type_tcc: StatusChunkTypeTcc::RunLengthChunk, + packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, + run_length: 10, + }); + let fb = TransportLayerCc { + packet_chunks: vec![chunk], + recv_deltas: (0..10) + .map(|_| RecvDelta { + type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, + delta: 33_000 * TYPE_TCC_DELTA_SCALE_FACTOR, + }) + .collect(), + ..Default::default() + }; + assert_eq!(loss_fraction(&fb), 0.0); + } + + #[test] + fn test_loss_fraction_empty_chunks() { + let fb = TransportLayerCc { + packet_chunks: vec![], + ..Default::default() + }; + assert_eq!(loss_fraction(&fb), 0.0); + } + + #[test] + fn test_loss_fraction_status_vector_mixed() { + use rtcp::transport_feedbacks::transport_layer_cc::{StatusVectorChunk, SymbolSizeTypeTcc}; + // 3 received + 2 lost = 60% received, 40% loss. + let chunk = PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { + type_tcc: StatusChunkTypeTcc::StatusVectorChunk, + symbol_size: SymbolSizeTypeTcc::TwoBit, + symbol_list: vec![ + SymbolTypeTcc::PacketReceivedSmallDelta, + SymbolTypeTcc::PacketNotReceived, + SymbolTypeTcc::PacketReceivedLargeDelta, + SymbolTypeTcc::PacketNotReceived, + SymbolTypeTcc::PacketReceivedSmallDelta, + ], + }); + let fb = TransportLayerCc { + packet_chunks: vec![chunk], + recv_deltas: vec![ + RecvDelta { + type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, + delta: 10_000, + }, + RecvDelta { + type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, + delta: 20_000, + }, + RecvDelta { + type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, + delta: 10_000, + }, + ], + ..Default::default() + }; + let loss = loss_fraction(&fb); + assert!((loss - 0.4).abs() < 1e-9, "expected 40% loss, got {loss}"); + } + + #[test] + fn test_builder_custom_bitrate_bounds() { + let (builder, _handle) = GccInterceptorBuilder::<()>::new(); + let builder = builder + .with_min_bitrate(100_000) + .with_max_bitrate(5_000_000); + assert_eq!(builder.min_bitrate_bps, 100_000); + assert_eq!(builder.max_bitrate_bps, 5_000_000); + } + + #[test] + fn test_gcc_handle_clone_shares_state() { + let (_builder, handle) = GccInterceptorBuilder::<()>::new(); + let handle2 = handle.clone(); + assert!(handle.target_bitrate_bps().is_none()); + assert!(handle2.target_bitrate_bps().is_none()); + // Write through the shared mutex. + handle.inner.lock().unwrap().target_bitrate_bps = Some(500_000); + assert_eq!(handle2.target_bitrate_bps(), Some(500_000)); + } + + #[test] + fn test_loss_fraction_all_lost() { + let chunk = PacketStatusChunk::RunLengthChunk(RunLengthChunk { + type_tcc: StatusChunkTypeTcc::RunLengthChunk, + packet_status_symbol: SymbolTypeTcc::PacketNotReceived, + run_length: 5, + }); + let fb = TransportLayerCc { + packet_chunks: vec![chunk], + ..Default::default() + }; + assert_eq!(loss_fraction(&fb), 1.0); + } +} diff --git a/rtc-interceptor/src/gcc/rate_controller.rs b/rtc-interceptor/src/gcc/rate_controller.rs new file mode 100644 index 00000000..923f4417 --- /dev/null +++ b/rtc-interceptor/src/gcc/rate_controller.rs @@ -0,0 +1,214 @@ +use super::trendline::OveruseSignal; +use std::time::{Duration, Instant}; + +/// Internal state of the AIMD rate controller. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// Holding the current estimate after a recent decrease. + Hold, + /// Multiplicatively increasing toward available bandwidth. + Increase, + /// Multiplicatively decreasing in response to overuse. + Decrease, +} + +/// AIMD rate controller. +/// +/// Translates delay-based overuse signals and packet-loss fractions into a +/// target send bitrate, following RFC 8698 / draft-ietf-rmcat-gcc-02 §5.5. +/// +/// # Rate adaptation rules +/// +/// | Signal | Action | +/// |------------|-----------------------------------------------------| +/// | Overusing | Decrease: `estimate × 0.85` | +/// | Normal | Increase: `estimate × 1.08^dt_sec` | +/// | Underusing | Increase (same as Normal) | +/// | Loss > 10% | Additional decrease: `estimate × (1 - 0.5 × loss)` | +/// +/// After a decrease the controller enters `Hold` for [`HOLD_DURATION`] before +/// switching back to `Increase`, preventing rapid oscillation. +pub(crate) struct AimdRateController { + estimate_bps: f64, + min_bps: f64, + max_bps: f64, + state: State, + last_update: Option, + last_decrease: Option, +} + +/// How long to hold the estimate after a multiplicative decrease. +const HOLD_DURATION: Duration = Duration::from_millis(250); + +impl AimdRateController { + pub(crate) fn new(min_bps: f64, max_bps: f64) -> Self { + // Start at min or 300 kbps, whichever is larger, capped by max. + let initial = min_bps.max(300_000.0).min(max_bps); + Self { + estimate_bps: initial, + min_bps, + max_bps, + state: State::Hold, + last_update: None, + last_decrease: None, + } + } + + /// Update the rate estimate and return the new target in bps. + /// + /// # Parameters + /// - `signal`: delay-based overuse signal from the trendline filter + /// - `loss_fraction`: fraction of lost packets in the latest feedback window (`0.0`–`1.0`) + /// - `now`: current wall-clock time for computing the elapsed interval + pub(crate) fn update( + &mut self, + signal: OveruseSignal, + loss_fraction: f64, + now: Instant, + ) -> f64 { + let dt_s = self + .last_update + .map(|t| now.duration_since(t).as_secs_f64().min(1.0)) + .unwrap_or(0.1); + self.last_update = Some(now); + + // State machine transitions. + self.state = match signal { + OveruseSignal::Overusing => State::Decrease, + OveruseSignal::Normal | OveruseSignal::Underusing => { + match self.state { + State::Decrease | State::Hold => { + // After a decrease, hold briefly before increasing. + let held_long_enough = self + .last_decrease + .map(|t| now.duration_since(t) >= HOLD_DURATION) + .unwrap_or(true); + if held_long_enough { + State::Increase + } else { + State::Hold + } + } + State::Increase => State::Increase, + } + } + }; + + // Apply the rate adjustment. + match self.state { + State::Increase => { + // Multiplicative increase: 8 % per second. + let alpha = 1.08f64.powf(dt_s); + self.estimate_bps = (self.estimate_bps * alpha).min(self.max_bps); + } + State::Decrease => { + // Multiplicative decrease × 0.85; immediately enter Hold. + self.estimate_bps = (self.estimate_bps * 0.85).max(self.min_bps); + self.last_decrease = Some(now); + self.state = State::Hold; + } + State::Hold => {} + } + + // Loss-based adjustment (applied additively after delay-based control). + // > 10 % loss → reduce further; < 2 % → no-op (delay controller allows increase). + if loss_fraction > 0.1 { + let factor = 1.0 - 0.5 * loss_fraction; + self.estimate_bps = (self.estimate_bps * factor).max(self.min_bps); + } + + self.estimate_bps + } + + pub(crate) fn estimate_bps(&self) -> f64 { + self.estimate_bps + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_increase_on_normal() { + let mut ctrl = AimdRateController::new(30_000.0, 2_500_000.0); + // Seed with one update so last_update is set. + let t0 = Instant::now(); + ctrl.update(OveruseSignal::Normal, 0.0, t0); + let after = ctrl.update(OveruseSignal::Normal, 0.0, t0 + Duration::from_secs(1)); + // After 1 second of Normal, rate should have increased. + assert!(after > 300_000.0, "rate should increase on Normal: {after}"); + } + + #[test] + fn test_decrease_on_overuse() { + let mut ctrl = AimdRateController::new(30_000.0, 2_500_000.0); + let t0 = Instant::now(); + // Warm up to 1 Mbps. + ctrl.estimate_bps = 1_000_000.0; + let after = ctrl.update(OveruseSignal::Overusing, 0.0, t0); + assert!( + after < 1_000_000.0, + "rate should decrease on Overuse: {after}" + ); + assert!( + (after - 850_000.0).abs() < 5_000.0, + "decrease should be ~×0.85: {after}" + ); + } + + #[test] + fn test_hold_after_decrease() { + let mut ctrl = AimdRateController::new(30_000.0, 2_500_000.0); + let t0 = Instant::now(); + ctrl.estimate_bps = 1_000_000.0; + ctrl.update(OveruseSignal::Overusing, 0.0, t0); + let after_hold = ctrl.update(OveruseSignal::Normal, 0.0, t0 + Duration::from_millis(100)); + // Still in hold window — rate must not increase. + assert!( + (after_hold - ctrl.estimate_bps).abs() < 1.0, + "should be holding: {after_hold}" + ); + } + + #[test] + fn test_loss_reduces_rate() { + let mut ctrl = AimdRateController::new(30_000.0, 2_500_000.0); + let t0 = Instant::now(); + ctrl.estimate_bps = 1_000_000.0; + let after = ctrl.update(OveruseSignal::Normal, 0.15, t0); // 15 % loss + assert!(after < 1_000_000.0, "loss should reduce rate: {after}"); + } + + #[test] + fn test_clamped_at_min() { + let mut ctrl = AimdRateController::new(100_000.0, 2_500_000.0); + ctrl.estimate_bps = 101_000.0; + let t0 = Instant::now(); + // Multiple overuse decreases should not go below min. + for _ in 0..20 { + ctrl.update(OveruseSignal::Overusing, 0.5, t0); + } + assert!( + ctrl.estimate_bps >= 100_000.0, + "should not fall below min: {}", + ctrl.estimate_bps + ); + } + + #[test] + fn test_clamped_at_max() { + let mut ctrl = AimdRateController::new(30_000.0, 500_000.0); + let t0 = Instant::now(); + ctrl.estimate_bps = 490_000.0; + // Many seconds of Normal should not exceed max. + for i in 0..20u64 { + ctrl.update(OveruseSignal::Normal, 0.0, t0 + Duration::from_secs(i)); + } + assert!( + ctrl.estimate_bps <= 500_000.0, + "should not exceed max: {}", + ctrl.estimate_bps + ); + } +} diff --git a/rtc-interceptor/src/gcc/trendline.rs b/rtc-interceptor/src/gcc/trendline.rs new file mode 100644 index 00000000..435e6e42 --- /dev/null +++ b/rtc-interceptor/src/gcc/trendline.rs @@ -0,0 +1,225 @@ +use std::collections::VecDeque; + +/// Signal from the overuse detector. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum OveruseSignal { + /// Queuing delay is stable — safe to hold or increase send rate. + Normal, + /// Queuing delay is growing — reduce send rate. + Overusing, + /// Queuing delay is shrinking — network has headroom. + Underusing, +} + +/// Trendline delay filter (Chrome-style, RFC 8698 §5.4). +/// +/// Maintains a sliding window of `(send_time_ms, smoothed_accumulated_delay_ms)` +/// samples and fits a linear regression to detect whether one-way queuing delay +/// is trending upward (overuse), downward (underuse), or stable. +/// +/// # Algorithm summary +/// +/// For each consecutive received packet pair (i-1, i): +/// ```text +/// gradient = (recv_i - recv_{i-1}) - (send_i - send_{i-1}) +/// smoothed += (1 - smoothing) * (gradient - smoothed) +/// accumulated += smoothed +/// window.push_back((send_i, accumulated)) +/// ``` +/// A linear regression of `accumulated_delay ~ slope * send_time` over the window +/// gives the trend. `modified_trend = slope * window_len`. Compared against a +/// dynamic threshold `T_hat` (maintained via Kalman-like update) this yields the +/// three-state overuse signal. +pub(crate) struct TrendlineFilter { + /// Sliding window: (send_time_ms, accumulated_smoothed_delay_ms). + window: VecDeque<(f64, f64)>, + /// Maximum window size (default 20). + window_size: usize, + /// Previous send time (ms) for computing inter-packet gradients. + prev_send_ms: Option, + /// Previous receive time (ms) for computing inter-packet gradients. + prev_recv_ms: Option, + /// Running sum of smoothed delay gradients. + accumulated_delay: f64, + /// Exponential moving average of the delay gradient. + smoothed_delay: f64, + /// EMA smoothing coefficient (higher = smoother/slower). Default 0.9. + smoothing: f64, + /// Dynamic detection threshold (ms). Bounded to [6, 600]. + threshold: f64, +} + +impl Default for TrendlineFilter { + fn default() -> Self { + Self { + window: VecDeque::new(), + window_size: 20, + prev_send_ms: None, + prev_recv_ms: None, + accumulated_delay: 0.0, + smoothed_delay: 0.0, + smoothing: 0.9, + threshold: 12.5, + } + } +} + +impl TrendlineFilter { + pub(crate) fn new() -> Self { + Self::default() + } + + /// Feed a new (send_time_ms, recv_time_ms) pair into the filter. + /// + /// `send_time_ms` must be monotonically increasing within a session (the + /// baseline can be arbitrary — only differences matter for the regression). + /// `recv_time_ms` is the arrival time in the same ms units. + pub(crate) fn update(&mut self, send_ms: f64, recv_ms: f64) { + if let (Some(ps), Some(pr)) = (self.prev_send_ms, self.prev_recv_ms) { + let gradient = (recv_ms - pr) - (send_ms - ps); + self.smoothed_delay += (1.0 - self.smoothing) * (gradient - self.smoothed_delay); + self.accumulated_delay += self.smoothed_delay; + } + self.prev_send_ms = Some(send_ms); + self.prev_recv_ms = Some(recv_ms); + + self.window.push_back((send_ms, self.accumulated_delay)); + if self.window.len() > self.window_size { + self.window.pop_front(); + } + + // Update the dynamic threshold (simplified Kalman-like adaptation). + let modified_trend = self.slope() * self.window.len() as f64; + let gamma = modified_trend.abs(); + const K_UP: f64 = 0.0087; + const K_DOWN: f64 = 0.039; + if gamma > self.threshold { + self.threshold += K_UP * (gamma - self.threshold); + } else { + self.threshold -= K_DOWN * self.threshold; + } + self.threshold = self.threshold.clamp(6.0, 600.0); + } + + /// Current overuse signal based on the latest trendline slope. + pub(crate) fn signal(&self) -> OveruseSignal { + if self.window.len() < 2 { + return OveruseSignal::Normal; + } + let modified_trend = self.slope() * self.window.len() as f64; + if modified_trend > self.threshold { + OveruseSignal::Overusing + } else if modified_trend < -self.threshold { + OveruseSignal::Underusing + } else { + OveruseSignal::Normal + } + } + + /// Ordinary least-squares slope of `accumulated_delay ~ k * send_time`. + fn slope(&self) -> f64 { + let n = self.window.len() as f64; + if n < 2.0 { + return 0.0; + } + let mut sx = 0.0f64; + let mut sy = 0.0f64; + let mut sxx = 0.0f64; + let mut sxy = 0.0f64; + for &(x, y) in &self.window { + sx += x; + sy += y; + sxx += x * x; + sxy += x * y; + } + let denom = n * sxx - sx * sx; + if denom.abs() < 1e-10 { + return 0.0; + } + (n * sxy - sx * sy) / denom + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_signal_normal_constant_delay() { + let mut f = TrendlineFilter::new(); + // Packets arrive with exactly the same one-way delay — no trend. + let base_send = 0.0f64; + let base_recv = 20.0f64; // constant 20 ms delay + for i in 0..25u32 { + let send = base_send + i as f64 * 33.0; + let recv = base_recv + i as f64 * 33.0; + f.update(send, recv); + } + assert_eq!(f.signal(), OveruseSignal::Normal); + } + + #[test] + fn test_signal_overuse_increasing_delay() { + let mut f = TrendlineFilter::new(); + // Each packet adds 50 ms of extra queuing delay — unmistakable overuse. + // delay_gradient per step = 50ms >> threshold (12.5ms). + for i in 0..25u32 { + let send = i as f64 * 33.0; + let recv = i as f64 * 33.0 + i as f64 * 50.0; // cumulative +50ms/packet + f.update(send, recv); + } + assert_eq!(f.signal(), OveruseSignal::Overusing); + } + + #[test] + fn test_signal_underuse_decreasing_delay() { + let mut f = TrendlineFilter::new(); + // Delay decreases by 50ms per packet: gradient = -50ms >> -threshold. + // Start with 1200ms to stay non-negative across 24 packets. + for i in 0..25u32 { + let send = i as f64 * 33.0; + let recv = send + (1200.0 - i as f64 * 50.0).max(0.0); + f.update(send, recv); + } + assert_eq!(f.signal(), OveruseSignal::Underusing); + } + + #[test] + fn test_window_size_capped() { + let mut f = TrendlineFilter::new(); + for i in 0..50u32 { + f.update(i as f64 * 33.0, i as f64 * 33.0 + 20.0); + } + assert!(f.window.len() <= f.window_size); + } + + #[test] + fn test_threshold_adapts() { + let mut f = TrendlineFilter::new(); + // Feed packets with 50ms/step gradient — strong sustained overuse. + // The dynamic threshold adapts (via k_up / k_down update rules) and must + // remain strictly between the hard bounds [6, 600] ms. + for i in 0..25u32 { + let send = i as f64 * 33.0; + let recv = send + i as f64 * 50.0; + f.update(send, recv); + } + // Threshold must stay within its hard bounds — never hit the floor or ceiling. + assert!( + f.threshold > 6.0, + "threshold should be above floor: {}", + f.threshold + ); + assert!( + f.threshold < 600.0, + "threshold should be below ceiling: {}", + f.threshold + ); + // With sustained high gradient the filter must detect overuse. + assert_eq!( + f.signal(), + OveruseSignal::Overusing, + "should be overusing with 50ms/packet delay growth" + ); + } +} diff --git a/rtc-interceptor/src/jitter_buffer/mod.rs b/rtc-interceptor/src/jitter_buffer/mod.rs new file mode 100644 index 00000000..eccfd60d --- /dev/null +++ b/rtc-interceptor/src/jitter_buffer/mod.rs @@ -0,0 +1,446 @@ +//! Jitter Buffer Interceptor +//! +//! A receiver-side interceptor that buffers incoming RTP packets and releases +//! them in sequence order after an adaptive playout delay. +//! +//! # Algorithm +//! +//! The target playout delay adapts to observed interarrival jitter using the +//! RFC 3550 §A.8 formula: `target = clamp(jitter / clock_rate × 3, min, max)`. +//! The ×3 factor covers ~99.7% of the jitter spread under a Gaussian model. +//! +//! If the sender includes a `playout-delay` RTP header extension, its +//! `min_delay` and `max_delay` values (in 10 ms increments) are applied as +//! bounds on the adaptive target. +//! +//! # Placement in the interceptor chain +//! +//! The jitter buffer should be the **outermost** interceptor so that all inner +//! interceptors (NACK generator, receiver-report, TWCC) still observe every +//! packet in its eventually-correct order: +//! +//! ```text +//! JitterBuffer → NackGenerator → ReceiverReport → TwccReceiver → Noop +//! ``` +//! +//! # Usage +//! +//! ```ignore +//! use rtc_interceptor::{Registry, JitterBufferBuilder}; +//! use std::time::Duration; +//! +//! let chain = Registry::new() +//! .with(JitterBufferBuilder::new() +//! .with_min_delay(Duration::from_millis(20)) +//! .with_max_delay(Duration::from_millis(500)) +//! .with_initial_delay(Duration::from_millis(50)) +//! .build()) +//! .build(); +//! ``` + +use crate::stream_info::StreamInfo; +use crate::{Interceptor, Packet, TaggedPacket, interceptor}; +use log::error; +use shared::error::Error; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::time::{Duration, Instant}; + +mod stream; +use stream::{JitterBufferStream, PLAYOUT_DELAY_URI}; + +/// Builder for [`JitterBufferInterceptor`]. +pub struct JitterBufferBuilder

{ + min_delay: Duration, + max_delay: Duration, + initial_delay: Duration, + _phantom: PhantomData

, +} + +impl

Default for JitterBufferBuilder

{ + fn default() -> Self { + Self { + min_delay: Duration::from_millis(20), + max_delay: Duration::from_millis(500), + initial_delay: Duration::from_millis(50), + _phantom: PhantomData, + } + } +} + +impl

JitterBufferBuilder

{ + pub fn new() -> Self { + Self::default() + } + + /// Minimum playout delay floor (default 20 ms). + pub fn with_min_delay(mut self, d: Duration) -> Self { + self.min_delay = d; + self + } + + /// Maximum playout delay / force-release ceiling (default 500 ms). + pub fn with_max_delay(mut self, d: Duration) -> Self { + self.max_delay = d; + self + } + + /// Starting target delay before enough packets have been seen to estimate jitter + /// (default 50 ms). + pub fn with_initial_delay(mut self, d: Duration) -> Self { + self.initial_delay = d; + self + } + + /// Build the interceptor factory closure. + pub fn build(self) -> impl FnOnce(P) -> JitterBufferInterceptor

{ + move |inner| JitterBufferInterceptor { + inner, + min_delay: self.min_delay, + max_delay: self.max_delay, + initial_delay: self.initial_delay, + streams: HashMap::new(), + last_now: None, + } + } +} + +/// Receiver-side jitter buffer interceptor. +/// +/// Buffers incoming RTP packets per SSRC and releases them in sequence order +/// after an adaptive playout delay. RTCP packets and packets from unbound +/// SSRCs are forwarded immediately without buffering. +#[derive(Interceptor)] +pub struct JitterBufferInterceptor

{ + #[next] + inner: P, + + min_delay: Duration, + max_delay: Duration, + initial_delay: Duration, + + /// Per-SSRC jitter buffer state, created in `bind_remote_stream`. + streams: HashMap, + + /// Monotonic timestamp tracked from `handle_read` / `handle_timeout` calls, + /// used by `poll_read` instead of `Instant::now()` to avoid wall-clock dependency. + last_now: Option, +} + +#[interceptor] +impl JitterBufferInterceptor

{ + /// Buffer incoming RTP for tracked SSRCs; pass everything else through immediately. + #[overrides] + fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> { + // Track the latest timestamp for use by poll_read. + self.last_now = Some(msg.now); + if let Packet::Rtp(ref rtp) = msg.message + && let Some(stream) = self.streams.get_mut(&rtp.header.ssrc) + { + // insert() returns false for already-released sequences or duplicates; drop those. + stream.insert(msg.now, msg); + return Ok(()); + } + // RTCP, or RTP from an unbound SSRC → forward without delay. + self.inner.handle_read(msg) + } + + /// Flush ready buffered packets into the inner chain, then poll the inner chain. + /// + /// Uses the latest timestamp seen from `handle_read`/`handle_timeout` rather + /// than `Instant::now()`, so the interceptor stays deterministic and avoids + /// panics when buffered arrivals are in the future relative to wall-clock time. + #[overrides] + fn poll_read(&mut self) -> Option { + if let Some(now) = self.last_now { + self.drain_ready(now); + } + self.inner.poll_read() + } + + /// Drain ready packets on each timer tick so buffers don't stall between app polls. + #[overrides] + fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> { + self.last_now = Some(now); + self.drain_ready(now); + self.inner.handle_timeout(now) + } + + /// Return the earliest scheduled release time so the driver wakes at the right moment. + #[overrides] + fn poll_timeout(&mut self) -> Option { + let buf_eto = self + .streams + .values() + .filter_map(|s| s.next_wake_time()) + .min(); + let inner_eto = self.inner.poll_timeout(); + match (buf_eto, inner_eto) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, b) => b, + } + } + + /// Create a per-SSRC buffer when a remote stream is bound. + #[overrides] + fn bind_remote_stream(&mut self, info: &StreamInfo) { + let ext_id = info + .rtp_header_extensions + .iter() + .find(|e| e.uri == PLAYOUT_DELAY_URI) + .map(|e| e.id as u8); + + self.streams.insert( + info.ssrc, + JitterBufferStream::new( + info.clock_rate, + ext_id, + self.initial_delay, + self.min_delay, + self.max_delay, + ), + ); + self.inner.bind_remote_stream(info); + } + + /// Drop the per-SSRC buffer when a remote stream is unbound. + #[overrides] + fn unbind_remote_stream(&mut self, info: &StreamInfo) { + self.streams.remove(&info.ssrc); + self.inner.unbind_remote_stream(info); + } +} + +impl JitterBufferInterceptor

{ + /// Collect ready packets from all streams and inject them into the inner chain. + /// + /// We collect first to satisfy the borrow checker: `streams` and `inner` + /// are separate fields but both require `&mut self`. + fn drain_ready(&mut self, now: Instant) { + let mut ready = Vec::new(); + for stream in self.streams.values_mut() { + while let Some(pkt) = stream.pop_ready(now) { + ready.push(pkt); + } + } + for pkt in ready { + if let Err(e) = self.inner.handle_read(pkt) { + error!("jitter_buffer: inner.handle_read error: {}", e); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream_info::RTPHeaderExtension; + use crate::{Registry, stream_info::StreamInfo}; + use sansio::Protocol; + use shared::{TransportContext, TransportMessage}; + + fn make_stream_info(ssrc: u32, clock_rate: u32) -> StreamInfo { + StreamInfo { + ssrc, + clock_rate, + ..Default::default() + } + } + + fn make_rtp_at(ssrc: u32, seq: u16, ts: u32, now: Instant) -> TaggedPacket { + TransportMessage { + now, + transport: TransportContext::default(), + message: Packet::Rtp(rtp::Packet { + header: rtp::header::Header { + ssrc, + sequence_number: seq, + timestamp: ts, + ..Default::default() + }, + ..Default::default() + }), + } + } + + fn make_rtcp(ssrc: u32) -> TaggedPacket { + TransportMessage { + now: Instant::now(), + transport: TransportContext::default(), + message: Packet::Rtcp(vec![Box::new(rtcp::receiver_report::ReceiverReport { + ssrc, + ..Default::default() + })]), + } + } + + /// Build a chain with a short initial delay for testing. + fn make_chain( + initial_ms: u64, + max_ms: u64, + ) -> impl Protocol + crate::Interceptor { + Registry::new() + .with( + JitterBufferBuilder::new() + .with_min_delay(Duration::from_millis(initial_ms)) + .with_max_delay(Duration::from_millis(max_ms)) + .with_initial_delay(Duration::from_millis(initial_ms)) + .build(), + ) + .build() + } + + #[test] + fn test_in_order_packets_released_after_delay() { + let mut chain = make_chain(50, 500); + let ssrc = 1111; + chain.bind_remote_stream(&make_stream_info(ssrc, 90000)); + + let base = Instant::now(); + for i in 0..3u16 { + chain + .handle_read(make_rtp_at(ssrc, i + 1, i as u32 * 3000, base)) + .unwrap(); + } + + // Before delay has elapsed — nothing ready. + chain.handle_timeout(base).unwrap(); + assert!(chain.poll_read().is_none()); + + // After delay has elapsed — packets should be available. + chain + .handle_timeout(base + Duration::from_millis(100)) + .unwrap(); + let mut released = 0u16; + while chain.poll_read().is_some() { + released += 1; + } + assert_eq!(released, 3); + } + + #[test] + fn test_out_of_order_reordered() { + let mut chain = make_chain(50, 500); + let ssrc = 2222; + chain.bind_remote_stream(&make_stream_info(ssrc, 90000)); + + let base = Instant::now(); + // Arrive as seq 1, 3, 2. + chain.handle_read(make_rtp_at(ssrc, 1, 0, base)).unwrap(); + chain.handle_read(make_rtp_at(ssrc, 3, 6000, base)).unwrap(); + chain.handle_read(make_rtp_at(ssrc, 2, 3000, base)).unwrap(); + + // Release all after the delay. + chain + .handle_timeout(base + Duration::from_millis(100)) + .unwrap(); + + let mut seqs = Vec::new(); + while let Some(pkt) = chain.poll_read() { + if let Packet::Rtp(rtp) = pkt.message { + seqs.push(rtp.header.sequence_number); + } + } + // Must come out in sequence order. + assert_eq!(seqs, vec![1, 2, 3]); + } + + #[test] + fn test_force_release_at_max_delay() { + let initial_ms = 50u64; + let max_ms = 200u64; + let mut chain = make_chain(initial_ms, max_ms); + let ssrc = 3333; + chain.bind_remote_stream(&make_stream_info(ssrc, 90000)); + + let base = Instant::now(); + // Insert seq 1; seq 2 never arrives. + chain.handle_read(make_rtp_at(ssrc, 1, 0, base)).unwrap(); + + // At max_delay + 1ms: seq 1 must be force-released even without seq 2. + let force_time = base + Duration::from_millis(max_ms + 1); + chain.handle_timeout(force_time).unwrap(); + assert!( + chain.poll_read().is_some(), + "seq 1 should be force-released" + ); + } + + #[test] + fn test_rtcp_passes_through_immediately() { + let mut chain = make_chain(50, 500); + let ssrc = 4444; + chain.bind_remote_stream(&make_stream_info(ssrc, 90000)); + + chain.handle_read(make_rtcp(ssrc)).unwrap(); + // RTCP bypasses the buffer and should be visible to the inner chain. + // (The noop inner doesn't surface it, but the call must not hang or panic.) + // Verify by checking that poll_read doesn't return a buffered item. + chain.handle_timeout(Instant::now()).unwrap(); + assert!(chain.poll_read().is_none()); + } + + #[test] + fn test_unbind_clears_buffer() { + let initial_ms = 50u64; + let mut chain = make_chain(initial_ms, 500); + let ssrc = 5555; + let info = make_stream_info(ssrc, 90000); + chain.bind_remote_stream(&info); + + let base = Instant::now(); + chain.handle_read(make_rtp_at(ssrc, 1, 0, base)).unwrap(); + + // Unbind before the delay expires. + chain.unbind_remote_stream(&info); + + // After the delay, nothing is released (buffer was dropped). + chain + .handle_timeout(base + Duration::from_millis(100)) + .unwrap(); + assert!(chain.poll_read().is_none()); + } + + #[test] + fn test_unbound_ssrc_passes_through() { + let mut chain = make_chain(50, 500); + // Do NOT bind any stream. + let ssrc = 6666; + let base = Instant::now(); + + // Packet from an unbound SSRC must not be buffered — forwarded immediately. + chain.handle_read(make_rtp_at(ssrc, 1, 0, base)).unwrap(); + // handle_timeout at exactly base (no delay passed) should not hold the packet back. + chain.handle_timeout(base).unwrap(); + // The packet should be immediately readable from the inner chain rather than buffered. + assert!( + chain.poll_read().is_some(), + "unbound SSRC packets should pass through immediately" + ); + } + + #[test] + fn test_poll_timeout_returns_buffer_wake_time() { + let initial_ms = 50u64; + let mut chain = make_chain(initial_ms, 500); + let ssrc = 7777; + chain.bind_remote_stream(&make_stream_info(ssrc, 90000)); + + let base = Instant::now(); + chain.handle_read(make_rtp_at(ssrc, 1, 0, base)).unwrap(); + + let wake = chain.poll_timeout(); + assert!( + wake.is_some(), + "should have a wake time after buffering a packet" + ); + // Wake time should be approximately base + initial_delay. + let wake = wake.unwrap(); + assert!(wake > base, "wake time must be in the future"); + assert!( + wake <= base + Duration::from_millis(initial_ms + 10), + "wake time should be close to initial_delay" + ); + } +} diff --git a/rtc-interceptor/src/jitter_buffer/stream.rs b/rtc-interceptor/src/jitter_buffer/stream.rs new file mode 100644 index 00000000..3878ca13 --- /dev/null +++ b/rtc-interceptor/src/jitter_buffer/stream.rs @@ -0,0 +1,448 @@ +use crate::{Packet, TaggedPacket}; +use std::collections::VecDeque; +use std::time::{Duration, Instant}; + +/// RTP header extension URI for the playout-delay extension. +/// +pub(crate) const PLAYOUT_DELAY_URI: &str = + "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay"; + +/// Per-SSRC jitter buffer state. +/// +/// Buffers incoming RTP packets in sequence order and releases them after +/// an adaptive playout delay computed from the RFC 3550 §A.8 jitter formula. +pub(crate) struct JitterBufferStream { + /// RTP clock rate (e.g. 90 000 for video, 48 000 for Opus audio). + clock_rate: u32, + /// One-byte header-extension ID for playout-delay, if negotiated. + playout_delay_ext_id: Option, + /// Sorted packet buffer: (seq, arrival_time, scheduled_release, packet). + buffer: VecDeque<(u16, Instant, Instant, TaggedPacket)>, + /// Last sequence number released to the application (guards against late arrivals). + last_released: Option, + // --- RFC 3550 §A.8 adaptive delay state --- + last_rtp_ts: u32, + last_arrival: Instant, + jitter: f64, // running estimate in RTP clock units + started: bool, + // --- Configuration --- + pub(crate) target_delay: Duration, + min_delay: Duration, + max_delay: Duration, +} + +impl JitterBufferStream { + pub(crate) fn new( + clock_rate: u32, + playout_delay_ext_id: Option, + initial_delay: Duration, + min_delay: Duration, + max_delay: Duration, + ) -> Self { + Self { + clock_rate, + playout_delay_ext_id, + buffer: VecDeque::new(), + last_released: None, + last_rtp_ts: 0, + last_arrival: Instant::now(), + jitter: 0.0, + started: false, + target_delay: initial_delay, + min_delay, + max_delay, + } + } + + /// Returns `true` if sequence number `a` is strictly after `b` under u16 wrapping. + #[inline] + fn seq_is_after(a: u16, b: u16) -> bool { + a != b && a.wrapping_sub(b) < 0x8000 + } + + /// Update the jitter estimate from a new packet and compute its scheduled release time. + /// + /// Jitter is only updated for packets that advance the RTP timestamp (i.e. the RTP + /// timestamp difference is in the forward half of the u32 space, matching the same + /// wrapping arithmetic used for sequence numbers). Out-of-order or duplicate + /// timestamps are accepted into the buffer but do not corrupt the jitter estimate. + fn compute_release(&mut self, now: Instant, rtp_ts: u32) -> Instant { + if self.started { + let rtp_diff = rtp_ts.wrapping_sub(self.last_rtp_ts); + // Only update for forward-advancing RTP timestamps (rtp_diff in (0, 2^31)). + if rtp_diff > 0 + && rtp_diff < 0x8000_0000 + && self.clock_rate > 0 + && let Some(arrival_diff) = now.checked_duration_since(self.last_arrival) + { + let arrival_diff = arrival_diff.as_secs_f64(); + let d = (arrival_diff * self.clock_rate as f64 - rtp_diff as f64).abs(); + // RFC 3550 §A.8: J(i) = J(i-1) + (|D(i,i-1)| - J(i-1)) / 16 + self.jitter += (d - self.jitter) / 16.0; + + // target = clamp(jitter_seconds × 3, min, max) + let jitter_secs = self.jitter / self.clock_rate as f64 * 3.0; + self.target_delay = Duration::from_secs_f64(jitter_secs) + .max(self.min_delay) + .min(self.max_delay); + + self.last_rtp_ts = rtp_ts; + self.last_arrival = now; + } + } else { + self.started = true; + self.last_rtp_ts = rtp_ts; + self.last_arrival = now; + } + now + self.target_delay + } + + /// Parse a playout-delay RTP extension (3 bytes, 12-bit min + 12-bit max in 10 ms units). + fn parse_playout_delay(data: &[u8]) -> Option<(Duration, Duration)> { + if data.len() < 3 { + return None; + } + let min_raw = ((data[0] as u16) << 4) | ((data[1] as u16) >> 4); + let max_raw = (((data[1] as u16) & 0x0F) << 8) | (data[2] as u16); + Some(( + Duration::from_millis(min_raw as u64 * 10), + Duration::from_millis(max_raw as u64 * 10), + )) + } + + /// Insert a packet into the buffer in sequence order. + /// + /// Returns `false` if the packet is a late duplicate (already past `last_released`). + pub(crate) fn insert(&mut self, now: Instant, pkt: TaggedPacket) -> bool { + let (seq, rtp_ts) = match &pkt.message { + Packet::Rtp(rtp) => (rtp.header.sequence_number, rtp.header.timestamp), + _ => return false, + }; + + // Reject packets at or before the last released sequence. + if let Some(last) = self.last_released + && !Self::seq_is_after(seq, last) + { + return false; + } + + // Compute the release time (this also updates target_delay via jitter estimate). + let release = self.compute_release(now, rtp_ts); + + // Apply playout-delay extension hints from the sender for this packet only. + // We compute effective bounds without permanently mutating the configured + // min/max so that subsequent packets with different (or absent) hints are + // not permanently clamped. + let release = if let (Packet::Rtp(rtp), Some(ext_id)) = + (&pkt.message, self.playout_delay_ext_id) + && let Some(ext_bytes) = rtp.header.get_extension(ext_id) + && let Some((sender_min, sender_max)) = Self::parse_playout_delay(ext_bytes.as_ref()) + { + // Sender's min raises our floor; sender's max lowers our ceiling. + let effective_min = self.min_delay.max(sender_min); + let effective_max = self.max_delay.min(sender_max.max(effective_min)); + let clamped_delay = self.target_delay.max(effective_min).min(effective_max); + now + clamped_delay + } else { + release + }; + + // Reject duplicate sequence numbers already in the buffer. + if self.buffer.iter().any(|(s, _, _, _)| *s == seq) { + return false; + } + + // Insert at the correct sorted position (ascending sequence order). + let pos = self + .buffer + .iter() + .position(|(s, _, _, _)| Self::seq_is_after(*s, seq)) + .unwrap_or(self.buffer.len()); + self.buffer.insert(pos, (seq, now, release, pkt)); + true + } + + /// Return the head packet if it is ready for release, or `None` if not yet. + /// + /// A packet is ready when `now >= release_time` or it has been held for `>= max_delay`. + pub(crate) fn pop_ready(&mut self, now: Instant) -> Option { + if let Some(&(_, arrival, release, _)) = self.buffer.front() { + let ready = now >= release || now.duration_since(arrival) >= self.max_delay; + if ready { + let (seq, _, _, pkt) = self.buffer.pop_front().unwrap(); + self.last_released = Some(seq); + return Some(pkt); + } + } + None + } + + /// Earliest instant at which the driver should wake up to service this stream. + pub(crate) fn next_wake_time(&self) -> Option { + self.buffer.front().map(|(_, arrival, release, _)| { + let force_release = *arrival + self.max_delay; + (*release).min(force_release) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use shared::{TransportContext, TransportMessage}; + + fn make_rtp(ssrc: u32, seq: u16, ts: u32) -> TaggedPacket { + TransportMessage { + now: Instant::now(), + transport: TransportContext::default(), + message: Packet::Rtp(rtp::Packet { + header: rtp::header::Header { + ssrc, + sequence_number: seq, + timestamp: ts, + ..Default::default() + }, + ..Default::default() + }), + } + } + + #[test] + fn test_seq_is_after() { + assert!(JitterBufferStream::seq_is_after(1, 0)); + assert!(JitterBufferStream::seq_is_after(100, 99)); + assert!(!JitterBufferStream::seq_is_after(0, 1)); + // wraparound: 0 is after 0xffff + assert!(JitterBufferStream::seq_is_after(0, 0xffff)); + // equal is not "after" + assert!(!JitterBufferStream::seq_is_after(5, 5)); + } + + #[test] + fn test_insert_in_order() { + let delay = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, None, delay, delay, Duration::from_secs(1)); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + s.insert(now, make_rtp(1, 2, 900)); + s.insert(now, make_rtp(1, 3, 1800)); + assert_eq!(s.buffer.len(), 3); + assert_eq!(s.buffer[0].0, 1); + assert_eq!(s.buffer[1].0, 2); + assert_eq!(s.buffer[2].0, 3); + } + + #[test] + fn test_insert_out_of_order() { + let delay = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, None, delay, delay, Duration::from_secs(1)); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + s.insert(now, make_rtp(1, 3, 1800)); + s.insert(now, make_rtp(1, 2, 900)); // late, but within window + assert_eq!(s.buffer.len(), 3); + assert_eq!(s.buffer[0].0, 1); + assert_eq!(s.buffer[1].0, 2); // reordered into correct position + assert_eq!(s.buffer[2].0, 3); + } + + #[test] + fn test_pop_ready_not_yet() { + let delay = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, None, delay, delay, Duration::from_secs(1)); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + // Just after insertion — release time hasn't passed yet. + assert!(s.pop_ready(now).is_none()); + } + + #[test] + fn test_pop_ready_after_delay() { + let delay = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, None, delay, delay, Duration::from_secs(1)); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + let later = now + Duration::from_millis(100); + let pkt = s.pop_ready(later); + assert!(pkt.is_some()); + assert!(s.buffer.is_empty()); + } + + #[test] + fn test_force_release_at_max_delay() { + let delay = Duration::from_millis(50); + let max = Duration::from_millis(200); + let mut s = JitterBufferStream::new(90000, None, delay, delay, max); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + // Simulate a very late pop — past max_delay. + let very_late = now + max + Duration::from_millis(1); + assert!(s.pop_ready(very_late).is_some()); + } + + #[test] + fn test_late_arrival_rejected() { + let delay = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, None, delay, delay, Duration::from_secs(1)); + let now = Instant::now(); + s.insert(now, make_rtp(1, 5, 0)); + // Release seq 5. + s.pop_ready(now + Duration::from_millis(100)); + // seq 4 (before released seq 5) should be rejected. + let accepted = s.insert(now + Duration::from_millis(200), make_rtp(1, 4, 0)); + assert!(!accepted); + } + + #[test] + fn test_jitter_adapts_target_delay() { + let initial = Duration::from_millis(5); + let min = Duration::from_millis(5); + let mut s = JitterBufferStream::new(90000, None, initial, min, Duration::from_secs(2)); + let base = Instant::now(); + let mut elapsed_ms = 0u64; + // Feed packets with variable but strictly increasing arrival times to grow jitter. + // RTP timestamps advance at 90 kHz rate while packet spacing alternates between + // shorter and longer gaps, producing inter-arrival variation without time going + // backwards. + for i in 0u32..40 { + elapsed_ms += if i % 2 == 0 { 50 } else { 15 }; + let arrival = base + Duration::from_millis(elapsed_ms); + let ts = i * 3000; // 90kHz / 30fps = 3000 units per frame + s.insert(arrival, make_rtp(1, i as u16 + 1, ts)); + } + // After significant jitter, target_delay should be above initial 5ms. + assert!( + s.target_delay > initial, + "target_delay {:?} should have grown above {:?}", + s.target_delay, + initial + ); + } + + #[test] + fn test_initial_delay_clamped_to_bounds() { + // initial_delay > max_delay should be clamped down. + let s = JitterBufferStream::new( + 90000, + None, + Duration::from_secs(5), // initial_delay (too high) + Duration::from_millis(10), // min_delay + Duration::from_millis(200), // max_delay + ); + assert_eq!(s.target_delay, Duration::from_millis(200)); + + // initial_delay < min_delay should be clamped up. + let s2 = JitterBufferStream::new( + 90000, + None, + Duration::from_millis(1), // initial_delay (too low) + Duration::from_millis(50), // min_delay + Duration::from_millis(500), // max_delay + ); + assert_eq!(s2.target_delay, Duration::from_millis(50)); + + // initial_delay within bounds stays unchanged. + let s3 = JitterBufferStream::new( + 90000, + None, + Duration::from_millis(100), + Duration::from_millis(50), + Duration::from_millis(500), + ); + assert_eq!(s3.target_delay, Duration::from_millis(100)); + } + + #[test] + fn test_parse_playout_delay() { + // 12-bit min + 12-bit max, each in units of 10ms. + // min = 0x005 (5 * 10 = 50ms), max = 0x014 (20 * 10 = 200ms) + // byte 0 = min >> 4 = 0x00 + // byte 1 = (min & 0xF) << 4 | max >> 8 = 0x50 + // byte 2 = max & 0xFF = 0x14 + let data = [0x00, 0x50, 0x14]; + let (min, max) = JitterBufferStream::parse_playout_delay(&data).unwrap(); + assert_eq!(min, Duration::from_millis(50)); + assert_eq!(max, Duration::from_millis(200)); + + // Too-short data returns None. + assert!(JitterBufferStream::parse_playout_delay(&[0x00, 0x00]).is_none()); + } + + #[test] + fn test_playout_delay_applied_after_compute_release() { + // Verify that compute_release() runs with stream-wide min/max, and + // the per-packet playout-delay extension only narrows the final release + // for that specific packet. + let ext_id = 3u8; + let min = Duration::from_millis(10); + let max = Duration::from_secs(2); + let initial = Duration::from_millis(50); + let mut s = JitterBufferStream::new(90000, Some(ext_id), initial, min, max); + + let now = Instant::now(); + + // Build an RTP packet with a playout-delay extension. + // Sender requests min=100ms, max=150ms. + // min_raw=10 (10*10ms=100ms), max_raw=15 (15*10ms=150ms) + // byte 0 = 10 >> 4 = 0x00 + // byte 1 = (10 & 0xF) << 4 | 15 >> 8 = 0xA0 + // byte 2 = 15 & 0xFF = 0x0F + let playout_ext_data = bytes::Bytes::from_static(&[0x00, 0xA0, 0x0F]); + let mut rtp_pkt = rtp::Packet { + header: rtp::header::Header { + ssrc: 1, + sequence_number: 1, + timestamp: 0, + extension: true, + extension_profile: rtp::header::EXTENSION_PROFILE_ONE_BYTE, + ..Default::default() + }, + ..Default::default() + }; + rtp_pkt + .header + .set_extension(ext_id, playout_ext_data) + .unwrap(); + + let tagged = TransportMessage { + now, + transport: TransportContext::default(), + message: Packet::Rtp(rtp_pkt), + }; + s.insert(now, tagged); + + // The stream-wide min/max are 10ms and 2s, so compute_release uses + // target_delay = 50ms (the clamped initial). The per-packet extension + // requests min=100ms, max=150ms. effective_min = max(10, 100) = 100ms, + // effective_max = min(2000, max(150, 100)) = 150ms. + // So the release should be clamped to at least 100ms from `now`. + let (_, _, release, _) = &s.buffer[0]; + let delay_from_now = release.duration_since(now); + assert!( + delay_from_now >= Duration::from_millis(100), + "expected release >= 100ms, got {:?}", + delay_from_now + ); + assert!( + delay_from_now <= Duration::from_millis(150), + "expected release <= 150ms, got {:?}", + delay_from_now + ); + + // Verify stream-wide min/max are unchanged (not permanently tightened). + assert_eq!(s.min_delay, min); + assert_eq!(s.max_delay, max); + } + + #[test] + fn test_next_wake_time_is_min_of_release_and_force() { + let delay = Duration::from_millis(50); + let max = Duration::from_millis(200); + let mut s = JitterBufferStream::new(90000, None, delay, delay, max); + let now = Instant::now(); + s.insert(now, make_rtp(1, 1, 0)); + let wake = s.next_wake_time().expect("should have a wake time"); + // Wake time should be <= arrival + max_delay + assert!(wake <= now + max + Duration::from_millis(1)); + } +} diff --git a/rtc-interceptor/src/lib.rs b/rtc-interceptor/src/lib.rs index 44e74561..acbef3ff 100644 --- a/rtc-interceptor/src/lib.rs +++ b/rtc-interceptor/src/lib.rs @@ -180,11 +180,15 @@ use std::time::Instant; mod noop; mod registry; +pub(crate) mod gcc; +pub(crate) mod jitter_buffer; pub(crate) mod nack; pub(crate) mod report; pub(crate) mod stream_info; pub(crate) mod twcc; +pub use gcc::{GccHandle, GccInterceptorBuilder}; +pub use jitter_buffer::{JitterBufferBuilder, JitterBufferInterceptor}; pub use nack::{ generator::{NackGeneratorBuilder, NackGeneratorInterceptor}, responder::{NackResponderBuilder, NackResponderInterceptor},