Skip to content

Commit 7c041e4

Browse files
Save persistent AV1 state (#2303)
* Add Av1PayloadType.
* Don't try to rewrite TL0PICIDX on VP9 packets that don't have it. This happens for flexible-mode packets.
* Fix SSRC rewrite when codecs change.
* Save persistent AV1 state across source projection context changes.
* Save persistent AV1 state for SSRC rewriting.
* Fix some SSRC rewriting wraparounds.
1 parent 80ae790 commit 7c041e4

File tree

8 files changed

+168
-31
lines changed

8 files changed

+168
-31
lines changed

jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/format/PayloadType.kt

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,7 @@ enum class PayloadTypeEncoding {
7474
OTHER,
7575
VP8,
7676
VP9,
77+
AV1,
7778
H264,
7879
RED,
7980
RTX,
@@ -129,6 +130,12 @@ class Vp9PayloadType(
129130
rtcpFeedbackSet: RtcpFeedbackSet = emptySet()
130131
) : VideoPayloadType(pt, PayloadTypeEncoding.VP9, parameters = parameters, rtcpFeedbackSet = rtcpFeedbackSet)
131132

133+
class Av1PayloadType(
134+
pt: Byte,
135+
parameters: PayloadTypeParams = ConcurrentHashMap(),
136+
rtcpFeedbackSet: RtcpFeedbackSet = emptySet()
137+
) : VideoPayloadType(pt, PayloadTypeEncoding.AV1, parameters = parameters, rtcpFeedbackSet = rtcpFeedbackSet)
138+
132139
class H264PayloadType(
133140
pt: Byte,
134141
parameters: PayloadTypeParams = ConcurrentHashMap(),

jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjection.java

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,12 @@ public class AdaptiveSourceProjection
8888
*/
8989
private int targetIndex = RtpLayerDesc.SUSPENDED_INDEX;
9090

91+
/**
92+
* The map for persistent states.
93+
* Currently, we only expect it for AV1.
94+
*/
95+
private final HashMap<Class<? extends AdaptiveSourceProjectionContext>, Object> persistentStates = new HashMap<>();
96+
9197
/**
9298
* Ctor.
9399
*
@@ -260,7 +266,9 @@ else if (rtpPacket instanceof Av1DDPacket)
260266
(context == null ? "creating new" : "changing to") +
261267
" AV1 DD context for source packet ssrc " + rtpPacket.getSsrc());
262268
context = new Av1DDAdaptiveSourceProjectionContext(
263-
diagnosticContext, rtpState, logger);
269+
diagnosticContext, rtpState,
270+
persistentStates.get(Av1DDAdaptiveSourceProjectionContext.class),
271+
logger);
264272
}
265273

266274
return context;
@@ -296,10 +304,33 @@ else if (rtpPacket instanceof Av1DDPacket)
296304
}
297305
else
298306
{
307+
savePersistentState();
299308
return context.getRtpState();
300309
}
301310
}
302311

312+
private void savePersistentState()
313+
{
314+
if (context == null)
315+
{
316+
return;
317+
}
318+
Object state = context.getPersistentState();
319+
if (state != null)
320+
{
321+
if (context.getClass() != Av1DDAdaptiveSourceProjectionContext.class)
322+
{
323+
logger.warn("Got unexpected context persistent state from class " +
324+
context.getClass().getSimpleName());
325+
}
326+
persistentStates.put(context.getClass(), state);
327+
}
328+
else if (persistentStates.get(context.getClass()) != null)
329+
{
330+
persistentStates.remove(context.getClass());
331+
}
332+
}
333+
303334
/**
304335
* Rewrites an RTP packet for projection.
305336
*

jvb/src/main/java/org/jitsi/videobridge/cc/AdaptiveSourceProjectionContext.java

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.jitsi.videobridge.cc;
1717

18+
import edu.umd.cs.findbugs.annotations.Nullable;
1819
import org.jitsi.nlj.*;
1920
import org.jitsi.rtp.rtcp.*;
2021
import org.json.simple.*;
@@ -79,6 +80,17 @@ void rewriteRtp(PacketInfo packetInfo)
7980
*/
8081
RtpState getRtpState();
8182

83+
/**
84+
* @return Persistent state, if any, that should be associated with this
85+
* particular projection context type, so that when it comes back it can
86+
* resume data, if necessary.
87+
*/
88+
@Nullable
89+
default Object getPersistentState()
90+
{
91+
return null;
92+
}
93+
8294
/**
8395
* Gets a JSON representation of the parts of this object's state that
8496
* are deemed useful for debugging.

jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt

Lines changed: 48 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,16 @@ class RtpState {
8181
var lastSequenceNumber = 0
8282
var lastTimestamp = 0L
8383
var codecState: CodecState? = null
84+
var av1PersistentState: CodecState? = null // TODO? Generalize if needed in the future
8485
var valid = false
8586

8687
fun update(packet: RtpPacket) {
8788
lastSequenceNumber = packet.sequenceNumber
8889
lastTimestamp = packet.timestamp
8990
codecState = packet.getCodecState()
91+
if (packet is Av1DDPacket) {
92+
av1PersistentState = codecState
93+
}
9094
valid = true
9195
}
9296

@@ -136,7 +140,9 @@ class SendSsrc(val ssrc: Long) {
136140
RtpUtils.getSequenceNumberDelta(state.lastSequenceNumber, recv.state.lastSequenceNumber)
137141
timestampDelta =
138142
RtpUtils.getTimestampDiff(state.lastTimestamp, recv.state.lastTimestamp)
139-
codecDeltas = state.codecState?.getDeltas(recv.state.codecState)
143+
codecDeltas = state.codecState?.getDeltas(
144+
recv.state.codecState
145+
)
140146
} else {
141147
val prevSequenceNumber =
142148
RtpUtils.applySequenceNumberDelta(packet.sequenceNumber, -1)
@@ -152,6 +158,10 @@ class SendSsrc(val ssrc: Long) {
152158
recv.hasDeltas = true
153159
}
154160

161+
if (packet is Av1DDPacket && state.codecState !is Av1DDCodecState && state.av1PersistentState != null) {
162+
codecDeltas = state.av1PersistentState?.getDeltas(packet)
163+
}
164+
155165
recv.state.update(packet)
156166

157167
packet.ssrc = ssrc
@@ -431,7 +441,7 @@ abstract class SsrcCache(val size: Int, val ep: SsrcRewriter, val parentLogger:
431441
notifyMappings(remappings)
432442
}
433443
} catch (e: Exception) {
434-
logger.error("Error rewriting SSRC", e)
444+
logger.error("Error rewriting SSRC, packet $packet", e)
435445
send = false
436446
}
437447

@@ -604,7 +614,10 @@ private class Vp8CodecState(val lastTl0Index: Int) : CodecState {
604614
if (packet !is Vp8Packet) {
605615
return null
606616
}
607-
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(lastTl0Index, (packet.TL0PICIDX - 1))
617+
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(
618+
lastTl0Index,
619+
VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, -1)
620+
)
608621
return Vp8CodecDeltas(tl0IndexDelta)
609622
}
610623

@@ -613,7 +626,9 @@ private class Vp8CodecState(val lastTl0Index: Int) : CodecState {
613626

614627
private class Vp8CodecDeltas(val tl0IndexDelta: Int) : CodecDeltas {
615628
override fun rewritePacket(packet: RtpPacket) {
616-
require(packet is Vp8Packet)
629+
if (packet !is Vp8Packet) {
630+
return
631+
}
617632
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
618633
}
619634

@@ -635,7 +650,14 @@ private class Vp9CodecState(val lastTl0Index: Int) : CodecState {
635650
if (packet !is Vp9Packet) {
636651
return null
637652
}
638-
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(lastTl0Index, (packet.TL0PICIDX - 1))
653+
val tl0IndexDelta = if (packet.hasTL0PICIDX) {
654+
VpxUtils.getTl0PicIdxDelta(
655+
lastTl0Index,
656+
VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, -1)
657+
)
658+
} else {
659+
0
660+
}
639661
return Vp9CodecDeltas(tl0IndexDelta)
640662
}
641663

@@ -644,34 +666,38 @@ private class Vp9CodecState(val lastTl0Index: Int) : CodecState {
644666

645667
private class Vp9CodecDeltas(val tl0IndexDelta: Int) : CodecDeltas {
646668
override fun rewritePacket(packet: RtpPacket) {
647-
require(packet is Vp9Packet)
648-
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
669+
if (packet !is Vp9Packet) {
670+
return
671+
}
672+
if (packet.hasTL0PICIDX) {
673+
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
674+
}
649675
}
650676

651677
override fun toString() = "[VP9 TL0Idx]$tl0IndexDelta"
652678
}
653679

654680
private class Av1DDCodecState : CodecState {
655681
val lastFrameNum: Int
656-
val lastTemplateIdx: Int
682+
val nextTemplateIdx: Int
657683
constructor(lastFrameNum: Int, lastTemplateIdx: Int) {
658684
this.lastFrameNum = lastFrameNum
659-
this.lastTemplateIdx = lastTemplateIdx
685+
this.nextTemplateIdx = lastTemplateIdx
660686
}
661687

662688
constructor(packet: Av1DDPacket) {
663689
val descriptor = packet.descriptor
664690
requireNotNull(descriptor) { "AV1 Packet being routed must have non-null descriptor" }
665691
this.lastFrameNum = packet.frameNumber
666-
this.lastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
692+
this.nextTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
667693
}
668694

669695
override fun getDeltas(otherState: CodecState?): CodecDeltas? {
670696
if (otherState !is Av1DDCodecState) {
671697
return null
672698
}
673699
val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, otherState.lastFrameNum)
674-
val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, otherState.lastTemplateIdx)
700+
val templateIdDelta = getTemplateIdDelta(nextTemplateIdx, otherState.nextTemplateIdx)
675701
return Av1DDCodecDeltas(frameNumDelta, templateIdDelta)
676702
}
677703

@@ -680,16 +706,23 @@ private class Av1DDCodecState : CodecState {
680706
return null
681707
}
682708
val descriptor = packet.descriptor ?: return null
683-
val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, packet.frameNumber - 1)
684-
val packetLastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
685-
val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, packetLastTemplateIdx - 1)
709+
val frameNumDelta = RtpUtils.getSequenceNumberDelta(
710+
lastFrameNum,
711+
RtpUtils.applySequenceNumberDelta(packet.frameNumber, -1)
712+
)
713+
val templateIdDelta = getTemplateIdDelta(nextTemplateIdx, descriptor.structure.templateIdOffset)
686714
return Av1DDCodecDeltas(frameNumDelta, templateIdDelta)
687715
}
716+
717+
override fun toString() = "[Av1DD FrameNum]$lastFrameNum [Av1DD TemplateIdx]$nextTemplateIdx"
688718
}
689719

690720
private class Av1DDCodecDeltas(val frameNumDelta: Int, val templateIdDelta: Int) : CodecDeltas {
691721
override fun rewritePacket(packet: RtpPacket) {
692-
require(packet is Av1DDPacket)
722+
if (packet !is Av1DDPacket) {
723+
return
724+
}
725+
693726
val descriptor = packet.descriptor
694727
requireNotNull(descriptor)
695728

jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDAdaptiveSourceProjectionContext.kt

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ import java.time.Instant
4141
class Av1DDAdaptiveSourceProjectionContext(
4242
private val diagnosticContext: DiagnosticContext,
4343
rtpState: RtpState,
44+
persistentState: Any?,
4445
parentLogger: Logger
4546
) : AdaptiveSourceProjectionContext {
4647
private val logger: Logger = createChildLogger(parentLogger)
@@ -56,11 +57,17 @@ class Av1DDAdaptiveSourceProjectionContext(
5657
*/
5758
private val av1QualityFilter = Av1DDQualityFilter(av1FrameMaps, logger)
5859

60+
init {
61+
require(persistentState is Av1PersistentState?)
62+
}
63+
5964
private var lastAv1FrameProjection = Av1DDFrameProjection(
6065
diagnosticContext,
6166
rtpState.ssrc,
6267
rtpState.maxSequenceNumber,
63-
rtpState.maxTimestamp
68+
rtpState.maxTimestamp,
69+
(persistentState as? Av1PersistentState)?.frameNumber,
70+
(persistentState as? Av1PersistentState)?.templateId
6471
)
6572

6673
/**
@@ -334,19 +341,15 @@ class Av1DDAdaptiveSourceProjectionContext(
334341

335342
val frameNumber: Int
336343
val templateIdDelta: Int
337-
if (lastAv1FrameProjection.av1Frame != null) {
344+
val nextTemplateId = lastAv1FrameProjection.getNextTemplateId()
345+
if (nextTemplateId != null) {
338346
frameNumber = RtpUtils.applySequenceNumberDelta(
339347
lastAv1FrameProjection.frameNumber,
340348
1
341349
)
342-
val nextTemplateId = lastAv1FrameProjection.getNextTemplateId()
343-
templateIdDelta = if (nextTemplateId != null) {
344-
val structure = frame.structure
345-
check(structure != null)
346-
getTemplateIdDelta(nextTemplateId, structure.templateIdOffset)
347-
} else {
348-
0
349-
}
350+
val structure = frame.structure
351+
check(structure != null)
352+
templateIdDelta = getTemplateIdDelta(nextTemplateId, structure.templateIdOffset)
350353
} else {
351354
frameNumber = frame.frameNumber
352355
templateIdDelta = 0
@@ -651,6 +654,11 @@ class Av1DDAdaptiveSourceProjectionContext(
651654
lastAv1FrameProjection.timestamp
652655
)
653656

657+
override fun getPersistentState(): Any = Av1PersistentState(
658+
lastAv1FrameProjection.frameNumber,
659+
lastAv1FrameProjection.getNextTemplateId() ?: 0
660+
)
661+
654662
override fun getDebugState(): JSONObject {
655663
val debugState = JSONObject()
656664
debugState["class"] = Av1DDAdaptiveSourceProjectionContext::class.java.simpleName
@@ -676,3 +684,8 @@ class Av1DDAdaptiveSourceProjectionContext(
676684
TimeSeriesLogger.getTimeSeriesLogger(Av1DDAdaptiveSourceProjectionContext::class.java)
677685
}
678686
}
687+
688+
data class Av1PersistentState(
689+
val frameNumber: Int,
690+
val templateId: Int
691+
)

jvb/src/main/kotlin/org/jitsi/videobridge/cc/av1/Av1DDFrameProjection.kt

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -96,15 +96,17 @@ class Av1DDFrameProjection internal constructor(
9696
diagnosticContext: DiagnosticContext,
9797
ssrc: Long,
9898
sequenceNumberDelta: Int,
99-
timestamp: Long
99+
timestamp: Long,
100+
frameNumber: Int?,
101+
templateId: Int?
100102
) : this(
101103
diagnosticContext = diagnosticContext,
102104
av1Frame = null,
103105
ssrc = ssrc,
104106
timestamp = timestamp,
105107
sequenceNumberDelta = sequenceNumberDelta,
106-
frameNumber = 0,
107-
templateIdDelta = 0,
108+
frameNumber = frameNumber ?: 0,
109+
templateIdDelta = templateId ?: -1,
108110
dti = null,
109111
mark = false,
110112
created = null
@@ -229,6 +231,9 @@ class Av1DDFrameProjection internal constructor(
229231
* Get the next template ID that would come after the template IDs in this projection's structure
230232
*/
231233
fun getNextTemplateId(): Int? {
234+
if (av1Frame == null && templateIdDelta != -1) {
235+
return templateIdDelta
236+
}
232237
return av1Frame?.structure?.let { rewriteTemplateId(it.templateIdOffset + it.templateCount) }
233238
}
234239

jvb/src/main/kotlin/org/jitsi/videobridge/util/PayloadTypeUtil.kt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,11 +16,13 @@
1616
package org.jitsi.videobridge.util
1717

1818
import org.jitsi.nlj.format.AudioRedPayloadType
19+
import org.jitsi.nlj.format.Av1PayloadType
1920
import org.jitsi.nlj.format.H264PayloadType
2021
import org.jitsi.nlj.format.OpusPayloadType
2122
import org.jitsi.nlj.format.OtherAudioPayloadType
2223
import org.jitsi.nlj.format.OtherVideoPayloadType
2324
import org.jitsi.nlj.format.PayloadType
25+
import org.jitsi.nlj.format.PayloadTypeEncoding.AV1
2426
import org.jitsi.nlj.format.PayloadTypeEncoding.Companion.createFrom
2527
import org.jitsi.nlj.format.PayloadTypeEncoding.H264
2628
import org.jitsi.nlj.format.PayloadTypeEncoding.OPUS
@@ -94,6 +96,7 @@ class PayloadTypeUtil {
9496
return when (encoding) {
9597
VP8 -> Vp8PayloadType(id, parameters, rtcpFeedbackSet)
9698
VP9 -> Vp9PayloadType(id, parameters, rtcpFeedbackSet)
99+
AV1 -> Av1PayloadType(id, parameters, rtcpFeedbackSet)
97100
H264 -> H264PayloadType(id, parameters, rtcpFeedbackSet)
98101
RTX -> RtxPayloadType(id, parameters)
99102
OPUS -> OpusPayloadType(id, parameters)

0 commit comments

Comments
 (0)