diff --git a/rtc-rtp/src/codec/h264/h264_test.rs b/rtc-rtp/src/codec/h264/h264_test.rs index bc8aa464..07bf59da 100644 --- a/rtc-rtp/src/codec/h264/h264_test.rs +++ b/rtc-rtp/src/codec/h264/h264_test.rs @@ -261,3 +261,124 @@ fn test_h264_payloader_payload_sps_and_pps_handling() -> Result<()> { Ok(()) } + +/// When the combined STAP-A of SPS + PPS exceeds the MTU, both should still +/// be emitted as individual (possibly FU-A fragmented) packets instead of +/// being silently dropped. +#[test] +fn test_h264_stap_a_exceeds_mtu_emits_individually() -> Result<()> { + let mut pck = H264Payloader::default(); + + // SPS: 3 bytes (NALU type 7) + let sps = Bytes::from_static(&[0x07, 0xAA, 0xBB]); + // PPS: 3 bytes (NALU type 8) + let pps = Bytes::from_static(&[0x08, 0xCC, 0xDD]); + + let res = pck.payload(1500, &sps)?; + assert!(res.is_empty(), "SPS alone should be stashed, not emitted"); + + let res = pck.payload(1500, &pps)?; + assert!(res.is_empty(), "PPS alone should be stashed, not emitted"); + + // Use a tiny MTU so the STAP-A (1 + 2+3 + 2+3 = 11 bytes) exceeds it. + // SPS and PPS individually are 3 bytes each, which fits in MTU=5. + let result = pck.payload(5, &Bytes::from_static(&[0x05, 0x01, 0x02]))?; + + // Expect: SPS (3 bytes, fits), PPS (3 bytes, fits), then the IDR NALU (3 bytes, fits) + assert_eq!(result.len(), 3, "expected SPS + PPS + IDR = 3 packets"); + assert_eq!(result[0], sps, "first packet should be the SPS NALU"); + assert_eq!(result[1], pps, "second packet should be the PPS NALU"); + assert_eq!( + result[2], + Bytes::from_static(&[0x05, 0x01, 0x02]), + "third packet should be the IDR NALU" + ); + + Ok(()) +} + +/// When SPS or PPS are too large for a u16 length field (>65535 bytes), they +/// should be emitted via FU-A fragmentation using emit_single_or_fragment +/// rather than being packed into a STAP-A. +#[test] +fn test_h264_oversized_sps_uses_fua_fragmentation() -> Result<()> { + let mut pck = H264Payloader::default(); + + // Build a large SPS that genuinely exceeds u16::MAX (65535 bytes). + // This ensures we actually hit the `sps_nalu.len() > u16::MAX` branch. + let mut sps_data = vec![0x67]; // NALU type 7, with ref_idc bits set + sps_data.extend(vec![0xAA; 70_000]); + let sps = Bytes::from(sps_data); // 70001 bytes, well above u16::MAX threshold + + let pps = Bytes::from_static(&[0x68, 0xCC, 0xDD]); // NALU type 8, with ref_idc bits + + let res = pck.payload(1500, &sps)?; + assert!(res.is_empty(), "SPS alone should be stashed"); + + let res = pck.payload(1500, &pps)?; + assert!(res.is_empty(), "PPS alone should be stashed"); + + // Trigger emission with a small non-SPS/PPS NALU + let result = pck.payload(1500, &Bytes::from_static(&[0x65, 0x01, 0x02]))?; + + // SPS (70001 bytes) exceeds u16::MAX, so it should be FU-A fragmented. + // PPS (3 bytes) fits in a single packet. + // IDR (3 bytes) fits in a single packet. + assert!( + result.len() >= 3, + "expected at least 3 packets (fragmented SPS + PPS + IDR), got {}", + result.len() + ); + + // Verify the first packet is a FU-A start fragment of the SPS + assert_eq!( + result[0][0] & NALU_TYPE_BITMASK, + FUA_NALU_TYPE, + "first packet should be a FU-A fragment" + ); + assert_ne!( + result[0][1] & FU_START_BITMASK, + 0, + "first FU-A fragment should have start bit set" + ); + + Ok(()) +} + +/// The emit_single_or_fragment helper should pass through small NALUs directly +/// and fragment large ones via FU-A. +#[test] +fn test_h264_emit_single_or_fragment_small_nalu() { + let nalu = Bytes::from_static(&[0x65, 0x01, 0x02, 0x03]); + let mut payloads = vec![]; + H264Payloader::emit_single_or_fragment(&nalu, 10, &mut payloads); + assert_eq!(payloads.len(), 1, "small NALU should emit as single packet"); + assert_eq!(payloads[0], nalu); +} + +#[test] +fn test_h264_emit_single_or_fragment_large_nalu() { + let mut data = vec![0x65]; // IDR NALU type + data.extend(vec![0xBB; 20]); + let nalu = Bytes::from(data); + let mut payloads = vec![]; + H264Payloader::emit_single_or_fragment(&nalu, 10, &mut payloads); + assert!( + payloads.len() > 1, + "large NALU should be FU-A fragmented into multiple packets" + ); + // First fragment should have FU-A type and start bit + assert_eq!(payloads[0][0] & NALU_TYPE_BITMASK, FUA_NALU_TYPE); + assert_ne!(payloads[0][1] & FU_START_BITMASK, 0); + // Last fragment should have end bit + let last = payloads.last().unwrap(); + assert_ne!(last[1] & FU_END_BITMASK, 0); +} + +#[test] +fn test_h264_emit_single_or_fragment_empty_nalu() { + let nalu = Bytes::new(); + let mut payloads = vec![]; + H264Payloader::emit_single_or_fragment(&nalu, 10, &mut payloads); + assert!(payloads.is_empty(), "empty NALU should produce no packets"); +} diff --git a/rtc-rtp/src/codec/h264/mod.rs b/rtc-rtp/src/codec/h264/mod.rs index a18249ac..188243a7 100644 --- a/rtc-rtp/src/codec/h264/mod.rs +++ b/rtc-rtp/src/codec/h264/mod.rs @@ -50,6 +50,52 @@ impl H264Payloader { (-1, -1) } + /// Emit a NALU as a single packet or fragment it via FU-A, without + /// special-casing SPS/PPS (which `emit()` would re-stash). + fn emit_single_or_fragment(nalu: &Bytes, mtu: usize, payloads: &mut Vec) { + if nalu.is_empty() { + return; + } + + let nalu_ref_idc = nalu[0] & NALU_REF_IDC_BITMASK; + let nalu_type = nalu[0] & NALU_TYPE_BITMASK; + + // Single NALU + if nalu.len() <= mtu { + payloads.push(nalu.clone()); + return; + } + + // FU-A fragmentation + let max_fragment_size = mtu as isize - FUA_HEADER_SIZE as isize; + let mut nalu_data_index: isize = 1; + let nalu_data_length = nalu.len() as isize - nalu_data_index; + let mut nalu_data_remaining = nalu_data_length; + + if std::cmp::min(max_fragment_size, nalu_data_remaining) <= 0 { + return; + } + + while nalu_data_remaining > 0 { + let current_fragment_size = std::cmp::min(max_fragment_size, nalu_data_remaining); + let mut out = BytesMut::with_capacity(FUA_HEADER_SIZE + current_fragment_size as usize); + out.put_u8(FUA_NALU_TYPE | nalu_ref_idc); + let mut b1 = nalu_type; + if nalu_data_remaining == nalu_data_length { + b1 |= 1 << 7; // start bit + } else if nalu_data_remaining - current_fragment_size == 0 { + b1 |= 1 << 6; // end bit + } + out.put_u8(b1); + out.put( + &nalu[nalu_data_index as usize..(nalu_data_index + current_fragment_size) as usize], + ); + payloads.push(out.freeze()); + nalu_data_remaining -= current_fragment_size; + nalu_data_index += current_fragment_size; + } + } + fn emit(&mut self, nalu: &Bytes, mtu: usize, payloads: &mut Vec) { if nalu.is_empty() { return; @@ -66,21 +112,42 @@ impl H264Payloader { } else if nalu_type == PPS_NALU_TYPE { self.pps_nalu = Some(nalu.clone()); return; - } else if let (Some(sps_nalu), Some(pps_nalu)) = (&self.sps_nalu, &self.pps_nalu) { - // Pack current NALU with SPS and PPS as STAP-A - let sps_len = (sps_nalu.len() as u16).to_be_bytes(); - let pps_len = (pps_nalu.len() as u16).to_be_bytes(); - - let mut stap_a_nalu = Vec::with_capacity(1 + 2 + sps_nalu.len() + 2 + pps_nalu.len()); - stap_a_nalu.push(OUTPUT_STAP_AHEADER); - stap_a_nalu.extend(sps_len); - stap_a_nalu.extend_from_slice(sps_nalu); - stap_a_nalu.extend(pps_len); - stap_a_nalu.extend_from_slice(pps_nalu); - if stap_a_nalu.len() <= mtu { - payloads.push(Bytes::from(stap_a_nalu)); + } else if self.sps_nalu.is_some() && self.pps_nalu.is_some() { + // Clone to release the borrow on self so we can call self.emit() below if needed. + let sps_nalu = self.sps_nalu.clone().unwrap(); + let pps_nalu = self.pps_nalu.clone().unwrap(); + + // Pack the cached SPS and PPS into a STAP-A. + // STAP-A length fields are u16; only pack if both NALUs fit within 65535 bytes. + if sps_nalu.len() <= u16::MAX as usize && pps_nalu.len() <= u16::MAX as usize { + let sps_len = (sps_nalu.len() as u16).to_be_bytes(); + let pps_len = (pps_nalu.len() as u16).to_be_bytes(); + + let mut stap_a_nalu = + Vec::with_capacity(1 + 2 + sps_nalu.len() + 2 + pps_nalu.len()); + stap_a_nalu.push(OUTPUT_STAP_AHEADER); + stap_a_nalu.extend(sps_len); + stap_a_nalu.extend_from_slice(&sps_nalu); + stap_a_nalu.extend(pps_len); + stap_a_nalu.extend_from_slice(&pps_nalu); + if stap_a_nalu.len() <= mtu { + // STAP-A fits within MTU, emit as aggregate + payloads.push(Bytes::from(stap_a_nalu)); + } else { + // STAP-A does not fit within the MTU; fall back to emitting + // SPS and PPS separately so they are not silently dropped. + Self::emit_single_or_fragment(&sps_nalu, mtu, payloads); + Self::emit_single_or_fragment(&pps_nalu, mtu, payloads); + } + } else { + // SPS or PPS exceeds u16::MAX; fall back to emitting them as + // separate NALUs (which will be fragmented via FU-A if needed). + // We cannot use self.emit() here because it would re-stash SPS/PPS + // NALUs instead of actually outputting them. Push single or fragment directly. + Self::emit_single_or_fragment(&sps_nalu, mtu, payloads); + Self::emit_single_or_fragment(&pps_nalu, mtu, payloads); } - } + } // end SPS+PPS STAP-A / fallback emission block if self.sps_nalu.is_some() && self.pps_nalu.is_some() { self.sps_nalu = None; diff --git a/rtc-rtp/src/codec/h265/h265_test.rs b/rtc-rtp/src/codec/h265/h265_test.rs index 2ed1b100..1d6fb520 100644 --- a/rtc-rtp/src/codec/h265/h265_test.rs +++ b/rtc-rtp/src/codec/h265/h265_test.rs @@ -982,3 +982,130 @@ fn test_h265_packet_real() -> Result<()> { Ok(()) } + +/// When an aggregation buffer contains a NALU that exceeds u16::MAX, the +/// oversized NALU must be emitted individually (via FU fragmentation) while +/// normal-sized NALUs are still packed into the aggregation packet. +#[test] +fn test_h265_oversized_nalu_in_aggregation_buffer() -> Result<()> { + let mut pck = HevcPayloader::default(); + + // Build a payload with two NALUs: one normal-sized and one genuinely + // oversized (> u16::MAX = 65535 bytes) to hit the oversized NALU path. + let mut payload = vec![]; + // NALU 1 with 3-byte start code: 5 bytes, type=PFR + payload.extend_from_slice(&[0x00, 0x00, 0x01]); + payload.extend_from_slice(&[0x02, 0x01, 0xAA, 0xBB, 0xCC]); // 5 bytes + // NALU 2 with 3-byte start code: >65535 bytes, type=PFR + payload.extend_from_slice(&[0x00, 0x00, 0x01]); + let mut nalu2 = vec![0x02, 0x01]; // header: type=PFR + nalu2.extend(vec![0xDD; 70_000]); // 70002 bytes total, exceeds u16::MAX + payload.extend_from_slice(&nalu2); + + let result = pck.payload(1500, &Bytes::from(payload))?; // packetize mixed-size NALUs + + // NALU 1 (5 bytes) fits in MTU, emitted as a single RTP packet. + // NALU 2 (70002 bytes) exceeds u16::MAX and cannot fit in an AP size field, + // so it must be FU-fragmented into multiple RTP packets. + assert!( + result.len() >= 2, + "expected at least 2 packets (single + fragments), got {}", + result.len() + ); + + // First packet should be the small NALU (single packet, emitted directly) + assert_eq!(result[0].len(), 5, "first packet should be the 5-byte NALU"); + + // Remaining packets should be FU fragments of the oversized NALU + for i in 1..result.len() { + let header = H265NALUHeader::new(result[i][0], result[i][1]); + assert!( + header.is_fragmentation_unit(), + "packet {} should be a fragmentation unit", + i + ); + } + + Ok(()) +} + +/// The aggregation_payload_header should only be computed from normal-sized +/// NALUs, not oversized ones that get split out. This ensures the header +/// fields (F, layer_id, tid) reflect only the NALUs actually in the packet. +#[test] +fn test_h265_aggregation_header_excludes_oversized_nalus() -> Result<()> { + let mut pck = HevcPayloader::default(); + + // Three NALUs: two small ones and one oversized (> u16::MAX). + // The oversized NALU should be excluded from the AP and emitted separately. + let mut payload = vec![]; + // NALU 1: type=PFR, tid=1, 3 bytes (normal) + payload.extend_from_slice(&[0x00, 0x00, 0x01]); + payload.extend_from_slice(&[0x02, 0x01, 0xAA]); + // NALU 2: type=PFR, tid=1, 3 bytes (normal) + payload.extend_from_slice(&[0x00, 0x00, 0x01]); + payload.extend_from_slice(&[0x02, 0x01, 0xBB]); + // NALU 3: type=PFR, tid=1, >65535 bytes (oversized) + payload.extend_from_slice(&[0x00, 0x00, 0x01]); + let mut nalu3 = vec![0x02, 0x01]; // header + nalu3.extend(vec![0xCC; 70_000]); // 70002 bytes total + payload.extend_from_slice(&nalu3); + + let result = pck.payload(1500, &Bytes::from(payload))?; + + // The two small NALUs should be aggregated into a single AP packet, + // and the oversized NALU should be FU-fragmented separately. + assert!( + result.len() >= 2, + "expected at least 2 packets (AP + FU fragments), got {}", + result.len() + ); + let header = H265NALUHeader::new(result[0][0], result[0][1]); + assert!( + header.is_aggregation_packet(), + "first packet should be an AP containing only the two normal-sized NALUs" + ); + + // Remaining packets should be FU fragments of the oversized NALU + for i in 1..result.len() { + let fu_header = H265NALUHeader::new(result[i][0], result[i][1]); + assert!( + fu_header.is_fragmentation_unit(), + "packet {} should be a fragmentation unit", + i + ); + } + + Ok(()) +} + +/// A single oversized NALU (> MTU) passed through flush_aggregation_buffer +/// should be FU-fragmented into multiple packets via emit(). +#[test] +fn test_h265_flush_single_oversized_nalu_fu_fragmentation() -> Result<()> { + let mut pck = HevcPayloader::default(); + + // Single NALU larger than MTU to trigger FU fragmentation. This tests + // that flush_aggregation_buffer with a single oversized NALU correctly + // emits it (which then gets fragmented via emit()). + let mut nalu_data = vec![0x02, 0x01]; // header: type=PFR + nalu_data.extend(vec![0xAA; 70_000]); // 70002 bytes total, exceeds u16::MAX + let mut payload = vec![0x00, 0x00, 0x01]; // 3-byte start code + payload.extend_from_slice(&nalu_data); + + let result = pck.payload(1500, &Bytes::from(payload))?; + + // Single oversized NALU should be FU-fragmented into multiple packets + assert!( + result.len() > 1, + "oversized NALU should be FU-fragmented, got {} packets", + result.len() + ); + let header = H265NALUHeader::new(result[0][0], result[0][1]); + assert!( + header.is_fragmentation_unit(), + "first packet should be a fragmentation unit" + ); + + Ok(()) +} diff --git a/rtc-rtp/src/codec/h265/mod.rs b/rtc-rtp/src/codec/h265/mod.rs index 1fd4317a..1c92ccd4 100644 --- a/rtc-rtp/src/codec/h265/mod.rs +++ b/rtc-rtp/src/codec/h265/mod.rs @@ -115,10 +115,48 @@ impl HevcPayloader { } fn flush_aggregation_buffer(nalus: &mut Vec, mtu: usize, payloads: &mut Vec) { + if nalus.is_empty() { + return; + } + if nalus.len() == 1 { + payloads.push(nalus.pop().expect("single buffered NAL exists")); + return; + } + + // Process NALUs in original order: accumulate consecutive normal-sized + // NALUs into an AP, but when an oversized NALU (> u16::MAX) is + // encountered, flush the current AP first, then emit the oversized + // NALU individually (it will be FU-fragmented). + // + // Defense-in-depth: In practice, payload() sends any nalu.len() > mtu + // directly to emit() before it reaches this buffer, so the u16::MAX + // branch below is currently unreachable (MTU << u16::MAX). The check + // is retained to guard against future callers or MTU changes. + let mut pending_normal: Vec = Vec::new(); + + for nalu in nalus.drain(..) { + if nalu.len() > u16::MAX as usize { + // Flush any accumulated normal NALUs as an AP (or single) + Self::flush_normal_nalus(&mut pending_normal, mtu, payloads); + // Emit the oversized NALU individually via FU fragmentation; + // it cannot be placed in an AP because the 16-bit size field would overflow. + Self::emit(&nalu, mtu, payloads); + } else { + pending_normal.push(nalu); + } + } + + // Flush any remaining normal NALUs + Self::flush_normal_nalus(&mut pending_normal, mtu, payloads); + } + + /// Flush accumulated normal-sized NALUs: emit a single NALU directly, + /// or pack multiple into an aggregation packet. + fn flush_normal_nalus(nalus: &mut Vec, mtu: usize, payloads: &mut Vec) { match nalus.len() { 0 => {} 1 => { - payloads.push(nalus.pop().expect("single buffered NAL exists")); + payloads.push(nalus.pop().expect("single normal NAL exists")); } _ => { let header = Self::aggregation_payload_header(nalus); @@ -126,10 +164,11 @@ impl HevcPayloader { NAL_HEADER_SIZE + nalus.iter().map(|nalu| 2 + nalu.len()).sum::(), ); aggr_nalu.extend_from_slice(&header); - for nalu in nalus.drain(..) { + for nalu in nalus.iter() { aggr_nalu.extend_from_slice(&(nalu.len() as u16).to_be_bytes()); - aggr_nalu.extend_from_slice(&nalu); + aggr_nalu.extend_from_slice(nalu); } + nalus.clear(); if aggr_nalu.len() <= mtu { payloads.push(aggr_nalu.freeze()); } diff --git a/rtc-rtp/src/header/header_test.rs b/rtc-rtp/src/header/header_test.rs new file mode 100644 index 00000000..27a44e6f --- /dev/null +++ b/rtc-rtp/src/header/header_test.rs @@ -0,0 +1,291 @@ +use super::*; +use shared::error::Error; +use shared::marshal::Marshal; + +/// Helper: create a minimal valid header and marshal it, returning the result. +fn marshal_header(header: &Header) -> shared::error::Result { + let mut buf = vec![0u8; header.marshal_size()]; + header.marshal_to(&mut &mut buf[..]) +} + +// -- CSRC validation -- + +#[test] +fn test_too_many_csrcs() { + let header = Header { + csrc: vec![0u32; 16], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::TooManyCSRCs(16)), + "expected TooManyCSRCs(16), got {err:?}" + ); +} + +#[test] +fn test_max_csrcs_valid() { + let header = Header { + csrc: vec![0u32; 15], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +// -- One-byte extension payload size validation -- + +#[test] +fn test_one_byte_extension_payload_zero_length() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::new(), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::OneByteHeaderExtensionPayloadOutOfRange(0)), + "expected OneByteHeaderExtensionPayloadOutOfRange(0), got {err:?}" + ); +} + +#[test] +fn test_one_byte_extension_payload_too_large() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0u8; 17]), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::OneByteHeaderExtensionPayloadOutOfRange(17)), + "expected OneByteHeaderExtensionPayloadOutOfRange(17), got {err:?}" + ); +} + +#[test] +fn test_valid_one_byte_extension_boundaries() { + // 1 byte payload -- minimum valid + let header_min = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0u8; 1]), + }], + ..Default::default() + }; + assert!(marshal_header(&header_min).is_ok()); + + // 16 byte payload -- maximum valid + let header_max = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0u8; 16]), + }], + ..Default::default() + }; + assert!(marshal_header(&header_max).is_ok()); +} + +// -- Two-byte extension payload size validation -- + +#[test] +fn test_two_byte_extension_payload_too_large() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0u8; 256]), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::TwoByteHeaderExtensionPayloadTooLarge(256)), + "expected TwoByteHeaderExtensionPayloadTooLarge(256), got {err:?}" + ); +} + +#[test] +fn test_valid_two_byte_extension_boundary() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0u8; 255]), + }], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +// -- One-byte extension ID validation (RFC 8285 section 4.2) -- + +#[test] +fn test_one_byte_extension_id_zero_rejected() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 0, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::ErrRfc8285oneByteHeaderIdrange), + "expected ErrRfc8285oneByteHeaderIdrange, got {err:?}" + ); +} + +#[test] +fn test_one_byte_extension_id_fifteen_rejected() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 15, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::ErrRfc8285oneByteHeaderIdrange), + "expected ErrRfc8285oneByteHeaderIdrange, got {err:?}" + ); +} + +#[test] +fn test_one_byte_extension_id_fourteen_valid() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + extensions: vec![Extension { + id: 14, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +// -- Two-byte extension ID validation (RFC 8285 section 4.3) -- + +#[test] +fn test_two_byte_extension_id_zero_rejected() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 0, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::ErrRfc8285twoByteHeaderIdrange), + "expected ErrRfc8285twoByteHeaderIdrange, got {err:?}" + ); +} + +#[test] +fn test_two_byte_extension_id_one_valid() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +#[test] +fn test_two_byte_extension_id_255_valid() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 255, + payload: Bytes::from(vec![0xAB]), + }], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +// -- Two-byte zero-length payload (valid per RFC 8285) -- + +#[test] +fn test_two_byte_extension_zero_length_payload_valid() { + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions: vec![Extension { + id: 1, + payload: Bytes::new(), + }], + ..Default::default() + }; + assert!(marshal_header(&header).is_ok()); +} + +// -- set_extension API consistency (Issue #1) -- + +#[test] +fn test_set_extension_rejects_zero_byte_payload_one_byte_profile() { + let mut header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_ONE_BYTE, + ..Default::default() + }; + let err = header.set_extension(1, Bytes::new()).unwrap_err(); + assert!( + matches!(err, Error::ErrRfc8285oneByteHeaderSize), + "expected ErrRfc8285oneByteHeaderSize, got {err:?}" + ); +} + +// -- Total extension payload overflow (Issue #5) -- + +#[test] +fn test_extension_payload_total_overflow() { + // Build a header with enough extensions to exceed u16::MAX total payload bytes. + // Two-byte profile allows up to 255 bytes per extension; we need ~258 extensions + // of 255 bytes each to exceed 65535. + let count = 258; + let extensions: Vec = (1..=count) + .map(|i| Extension { + id: (i % 255 + 1) as u8, + payload: Bytes::from(vec![0xAA; 255]), + }) + .collect(); + let header = Header { + extension: true, + extension_profile: EXTENSION_PROFILE_TWO_BYTE, + extensions, + ..Default::default() + }; + let err = marshal_header(&header).unwrap_err(); + assert!( + matches!(err, Error::ExtensionPayloadTotalOverflow(_)), + "expected ExtensionPayloadTotalOverflow, got {err:?}" + ); +} diff --git a/rtc-rtp/src/header.rs b/rtc-rtp/src/header/mod.rs similarity index 87% rename from rtc-rtp/src/header.rs rename to rtc-rtp/src/header/mod.rs index 3bf56959..ca2a7f08 100644 --- a/rtc-rtp/src/header.rs +++ b/rtc-rtp/src/header/mod.rs @@ -251,7 +251,11 @@ impl Marshal for Header { return Err(Error::ErrBufferTooSmall); } - // The first byte contains the version, padding bit, extension bit, and csrc size + // The first byte contains the version, padding bit, extension bit, and csrc size. + // RFC 3550 §5.1: CC is a 4-bit field, so at most 15 contributing sources are allowed. + if self.csrc.len() > 15 { + return Err(Error::TooManyCSRCs(self.csrc.len())); // reject early to avoid truncating CC + } let mut b0 = (self.version << VERSION_SHIFT) | self.csrc.len() as u8; if self.padding { b0 |= 1 << PADDING_SHIFT; @@ -282,6 +286,9 @@ impl Marshal for Header { // calculate extensions size and round to 4 bytes boundaries let extension_payload_len = self.get_extension_payload_len(); + if extension_payload_len > u16::MAX as usize { + return Err(Error::ExtensionPayloadTotalOverflow(extension_payload_len)); + } if self.extension_profile != EXTENSION_PROFILE_ONE_BYTE && self.extension_profile != EXTENSION_PROFILE_TWO_BYTE && !extension_payload_len.is_multiple_of(4) @@ -296,15 +303,42 @@ impl Marshal for Header { // RFC 8285 RTP One Byte Header Extension EXTENSION_PROFILE_ONE_BYTE => { for extension in &self.extensions { - buf.put_u8((extension.id << 4) | (extension.payload.len() as u8 - 1)); - buf.put(&*extension.payload); + // validate each extension before writing + // RFC 8285 §4.2: IDs 0 and 15 are reserved. + if !(1..=14).contains(&extension.id) { + return Err(Error::ErrRfc8285oneByteHeaderIdrange); + } + // RFC 8285 §4.2: payload must be 1-16 bytes; the 4-bit length field + // encodes (len-1), so 0 is invalid and >16 cannot be represented. + let ext_payload_len = extension.payload.len(); + if ext_payload_len == 0 || ext_payload_len > 16 { + return Err(Error::OneByteHeaderExtensionPayloadOutOfRange( + ext_payload_len, + )); + } + // Upper nibble = extension ID, lower nibble = (length - 1) + buf.put_u8((extension.id << 4) | (ext_payload_len as u8 - 1)); + buf.put(&*extension.payload); // write validated extension payload } } // RFC 8285 RTP Two Byte Header Extension EXTENSION_PROFILE_TWO_BYTE => { for extension in &self.extensions { - buf.put_u8(extension.id); - buf.put_u8(extension.payload.len() as u8); + // RFC 8285 §4.3: ID 0 is padding in two-byte format; + // marshaling it as an extension would confuse receivers. + if extension.id == 0 { + return Err(Error::ErrRfc8285twoByteHeaderIdrange); + } + // RFC 8285 §4.3: the length field is a single byte, capping + // the maximum extension payload at 255 bytes. + let ext_payload_len = extension.payload.len(); + if ext_payload_len > 255 { + return Err(Error::TwoByteHeaderExtensionPayloadTooLarge( + ext_payload_len, + )); + } + buf.put_u8(extension.id); // two-byte format: ID occupies a full byte + buf.put_u8(ext_payload_len as u8); buf.put(&*extension.payload); } } @@ -362,7 +396,8 @@ impl Header { if !(1..=14).contains(&id) { return Err(Error::ErrRfc8285oneByteHeaderIdrange); } - if payload_len > 16 { + // RFC 8285 §4.2: payload must be 1–16 bytes. + if payload_len == 0 || payload_len > 16 { return Err(Error::ErrRfc8285oneByteHeaderSize); } 1 @@ -489,3 +524,6 @@ impl Header { } } } + +#[cfg(test)] +mod header_test; diff --git a/rtc-shared/src/error.rs b/rtc-shared/src/error.rs index d319f59b..afbcd9ff 100644 --- a/rtc-shared/src/error.rs +++ b/rtc-shared/src/error.rs @@ -276,6 +276,20 @@ pub enum Error { #[error("extension_payload must be in 32-bit words")] HeaderExtensionPayloadNot32BitWords, + #[error("too many CSRCs: {0} exceeds the 4-bit CC field maximum of 15")] + TooManyCSRCs(usize), + #[error( + "one-byte header extension payload length {0} is outside the RFC 8285 valid range of 1-16 bytes" + )] + OneByteHeaderExtensionPayloadOutOfRange(usize), + #[error( + "two-byte header extension payload length {0} exceeds the RFC 8285 maximum of 255 bytes" + )] + TwoByteHeaderExtensionPayloadTooLarge(usize), + #[error("total extension payload length {0} exceeds u16::MAX (65535 bytes)")] + ExtensionPayloadTotalOverflow(usize), + #[error("NALU length {0} exceeds u16::MAX (65535 bytes)")] + NaluTooLarge(usize), #[error("audio level overflow")] AudioLevelOverflow, #[error("playout delay overflow")]