Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ enum class PayloadTypeEncoding {
OTHER,
VP8,
VP9,
AV1,
H264,
RED,
RTX,
Expand Down Expand Up @@ -129,6 +130,12 @@ class Vp9PayloadType(
rtcpFeedbackSet: RtcpFeedbackSet = emptySet()
) : VideoPayloadType(pt, PayloadTypeEncoding.VP9, parameters = parameters, rtcpFeedbackSet = rtcpFeedbackSet)

class Av1PayloadType(
pt: Byte,
parameters: PayloadTypeParams = ConcurrentHashMap(),
rtcpFeedbackSet: RtcpFeedbackSet = emptySet()
) : VideoPayloadType(pt, PayloadTypeEncoding.AV1, parameters = parameters, rtcpFeedbackSet = rtcpFeedbackSet)

class H264PayloadType(
pt: Byte,
parameters: PayloadTypeParams = ConcurrentHashMap(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public class AdaptiveSourceProjection
*/
private int targetIndex = RtpLayerDesc.SUSPENDED_INDEX;

/**
* The map for persistent states.
* Currently, we only expect it for AV1.
*/
private final HashMap<Class<? extends AdaptiveSourceProjectionContext>, Object> persistentStates = new HashMap<>();

/**
* Ctor.
*
Expand Down Expand Up @@ -260,7 +266,9 @@ else if (rtpPacket instanceof Av1DDPacket)
(context == null ? "creating new" : "changing to") +
" AV1 DD context for source packet ssrc " + rtpPacket.getSsrc());
context = new Av1DDAdaptiveSourceProjectionContext(
diagnosticContext, rtpState, logger);
diagnosticContext, rtpState,
persistentStates.get(Av1DDAdaptiveSourceProjectionContext.class),
logger);
}

return context;
Expand Down Expand Up @@ -296,10 +304,33 @@ else if (rtpPacket instanceof Av1DDPacket)
}
else
{
savePersistentState();
return context.getRtpState();
}
}

private void savePersistentState()
{
if (context == null)
{
return;
}
Object state = context.getPersistentState();
if (state != null)
{
if (context.getClass() != Av1DDAdaptiveSourceProjectionContext.class)
{
logger.warn("Got unexpected context persistent state from class " +
context.getClass().getSimpleName());
}
persistentStates.put(context.getClass(), state);
}
else if (persistentStates.get(context.getClass()) != null)
{
persistentStates.remove(context.getClass());
}
}

/**
* Rewrites an RTP packet for projection.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.jitsi.videobridge.cc;

import edu.umd.cs.findbugs.annotations.Nullable;
import org.jitsi.nlj.*;
import org.jitsi.rtp.rtcp.*;
import org.json.simple.*;
Expand Down Expand Up @@ -79,6 +80,17 @@ void rewriteRtp(PacketInfo packetInfo)
*/
RtpState getRtpState();

/**
* @return Persistent state, if any, that should be associated with this
* particular projection context type, so that when it comes back it can
* resume data, if necessary.
*/
@Nullable
default Object getPersistentState()
{
return null;
}

/**
* Gets a JSON representation of the parts of this object's state that
* are deemed useful for debugging.
Expand Down
63 changes: 48 additions & 15 deletions jvb/src/main/kotlin/org/jitsi/videobridge/SsrcCache.kt
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,16 @@ class RtpState {
var lastSequenceNumber = 0
var lastTimestamp = 0L
var codecState: CodecState? = null
var av1PersistentState: CodecState? = null // TODO? Generalize if needed in the future
var valid = false

fun update(packet: RtpPacket) {
lastSequenceNumber = packet.sequenceNumber
lastTimestamp = packet.timestamp
codecState = packet.getCodecState()
if (packet is Av1DDPacket) {
av1PersistentState = codecState
}
valid = true
}

Expand Down Expand Up @@ -136,7 +140,9 @@ class SendSsrc(val ssrc: Long) {
RtpUtils.getSequenceNumberDelta(state.lastSequenceNumber, recv.state.lastSequenceNumber)
timestampDelta =
RtpUtils.getTimestampDiff(state.lastTimestamp, recv.state.lastTimestamp)
codecDeltas = state.codecState?.getDeltas(recv.state.codecState)
codecDeltas = state.codecState?.getDeltas(
recv.state.codecState
)
} else {
val prevSequenceNumber =
RtpUtils.applySequenceNumberDelta(packet.sequenceNumber, -1)
Expand All @@ -152,6 +158,10 @@ class SendSsrc(val ssrc: Long) {
recv.hasDeltas = true
}

if (packet is Av1DDPacket && state.codecState !is Av1DDCodecState && state.av1PersistentState != null) {
codecDeltas = state.av1PersistentState?.getDeltas(packet)
}

recv.state.update(packet)

packet.ssrc = ssrc
Expand Down Expand Up @@ -431,7 +441,7 @@ abstract class SsrcCache(val size: Int, val ep: SsrcRewriter, val parentLogger:
notifyMappings(remappings)
}
} catch (e: Exception) {
logger.error("Error rewriting SSRC", e)
logger.error("Error rewriting SSRC, packet $packet", e)
send = false
}

Expand Down Expand Up @@ -604,7 +614,10 @@ private class Vp8CodecState(val lastTl0Index: Int) : CodecState {
if (packet !is Vp8Packet) {
return null
}
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(lastTl0Index, (packet.TL0PICIDX - 1))
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(
lastTl0Index,
VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, -1)
)
return Vp8CodecDeltas(tl0IndexDelta)
}

Expand All @@ -613,7 +626,9 @@ private class Vp8CodecState(val lastTl0Index: Int) : CodecState {

private class Vp8CodecDeltas(val tl0IndexDelta: Int) : CodecDeltas {
override fun rewritePacket(packet: RtpPacket) {
require(packet is Vp8Packet)
if (packet !is Vp8Packet) {
return
}
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
}

Expand All @@ -635,7 +650,14 @@ private class Vp9CodecState(val lastTl0Index: Int) : CodecState {
if (packet !is Vp9Packet) {
return null
}
val tl0IndexDelta = VpxUtils.getTl0PicIdxDelta(lastTl0Index, (packet.TL0PICIDX - 1))
val tl0IndexDelta = if (packet.hasTL0PICIDX) {
VpxUtils.getTl0PicIdxDelta(
lastTl0Index,
VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, -1)
)
} else {
0
}
return Vp9CodecDeltas(tl0IndexDelta)
}

Expand All @@ -644,34 +666,38 @@ private class Vp9CodecState(val lastTl0Index: Int) : CodecState {

private class Vp9CodecDeltas(val tl0IndexDelta: Int) : CodecDeltas {
override fun rewritePacket(packet: RtpPacket) {
require(packet is Vp9Packet)
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
if (packet !is Vp9Packet) {
return
}
if (packet.hasTL0PICIDX) {
packet.TL0PICIDX = VpxUtils.applyTl0PicIdxDelta(packet.TL0PICIDX, tl0IndexDelta)
}
}

override fun toString() = "[VP9 TL0Idx]$tl0IndexDelta"
}

private class Av1DDCodecState : CodecState {
val lastFrameNum: Int
val lastTemplateIdx: Int
val nextTemplateIdx: Int
constructor(lastFrameNum: Int, lastTemplateIdx: Int) {
this.lastFrameNum = lastFrameNum
this.lastTemplateIdx = lastTemplateIdx
this.nextTemplateIdx = lastTemplateIdx
}

constructor(packet: Av1DDPacket) {
val descriptor = packet.descriptor
requireNotNull(descriptor) { "AV1 Packet being routed must have non-null descriptor" }
this.lastFrameNum = packet.frameNumber
this.lastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
this.nextTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
}

override fun getDeltas(otherState: CodecState?): CodecDeltas? {
if (otherState !is Av1DDCodecState) {
return null
}
val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, otherState.lastFrameNum)
val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, otherState.lastTemplateIdx)
val templateIdDelta = getTemplateIdDelta(nextTemplateIdx, otherState.nextTemplateIdx)
return Av1DDCodecDeltas(frameNumDelta, templateIdDelta)
}

Expand All @@ -680,16 +706,23 @@ private class Av1DDCodecState : CodecState {
return null
}
val descriptor = packet.descriptor ?: return null
val frameNumDelta = RtpUtils.getSequenceNumberDelta(lastFrameNum, packet.frameNumber - 1)
val packetLastTemplateIdx = descriptor.structure.templateIdOffset + descriptor.structure.templateCount
val templateIdDelta = getTemplateIdDelta(lastTemplateIdx, packetLastTemplateIdx - 1)
val frameNumDelta = RtpUtils.getSequenceNumberDelta(
lastFrameNum,
RtpUtils.applySequenceNumberDelta(packet.frameNumber, -1)
)
val templateIdDelta = getTemplateIdDelta(nextTemplateIdx, descriptor.structure.templateIdOffset)
return Av1DDCodecDeltas(frameNumDelta, templateIdDelta)
}

override fun toString() = "[Av1DD FrameNum]$lastFrameNum [Av1DD TemplateIdx]$nextTemplateIdx"
}

private class Av1DDCodecDeltas(val frameNumDelta: Int, val templateIdDelta: Int) : CodecDeltas {
override fun rewritePacket(packet: RtpPacket) {
require(packet is Av1DDPacket)
if (packet !is Av1DDPacket) {
return
}

val descriptor = packet.descriptor
requireNotNull(descriptor)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import java.time.Instant
class Av1DDAdaptiveSourceProjectionContext(
private val diagnosticContext: DiagnosticContext,
rtpState: RtpState,
persistentState: Any?,
parentLogger: Logger
) : AdaptiveSourceProjectionContext {
private val logger: Logger = createChildLogger(parentLogger)
Expand All @@ -56,11 +57,17 @@ class Av1DDAdaptiveSourceProjectionContext(
*/
private val av1QualityFilter = Av1DDQualityFilter(av1FrameMaps, logger)

init {
require(persistentState is Av1PersistentState?)
}

private var lastAv1FrameProjection = Av1DDFrameProjection(
diagnosticContext,
rtpState.ssrc,
rtpState.maxSequenceNumber,
rtpState.maxTimestamp
rtpState.maxTimestamp,
(persistentState as? Av1PersistentState)?.frameNumber,
(persistentState as? Av1PersistentState)?.templateId
)

/**
Expand Down Expand Up @@ -334,19 +341,15 @@ class Av1DDAdaptiveSourceProjectionContext(

val frameNumber: Int
val templateIdDelta: Int
if (lastAv1FrameProjection.av1Frame != null) {
val nextTemplateId = lastAv1FrameProjection.getNextTemplateId()
if (nextTemplateId != null) {
frameNumber = RtpUtils.applySequenceNumberDelta(
lastAv1FrameProjection.frameNumber,
1
)
val nextTemplateId = lastAv1FrameProjection.getNextTemplateId()
templateIdDelta = if (nextTemplateId != null) {
val structure = frame.structure
check(structure != null)
getTemplateIdDelta(nextTemplateId, structure.templateIdOffset)
} else {
0
}
val structure = frame.structure
check(structure != null)
templateIdDelta = getTemplateIdDelta(nextTemplateId, structure.templateIdOffset)
} else {
frameNumber = frame.frameNumber
templateIdDelta = 0
Expand Down Expand Up @@ -651,6 +654,11 @@ class Av1DDAdaptiveSourceProjectionContext(
lastAv1FrameProjection.timestamp
)

override fun getPersistentState(): Any = Av1PersistentState(
lastAv1FrameProjection.frameNumber,
lastAv1FrameProjection.getNextTemplateId() ?: 0
)

override fun getDebugState(): JSONObject {
val debugState = JSONObject()
debugState["class"] = Av1DDAdaptiveSourceProjectionContext::class.java.simpleName
Expand All @@ -676,3 +684,8 @@ class Av1DDAdaptiveSourceProjectionContext(
TimeSeriesLogger.getTimeSeriesLogger(Av1DDAdaptiveSourceProjectionContext::class.java)
}
}

data class Av1PersistentState(
val frameNumber: Int,
val templateId: Int
)
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,17 @@ class Av1DDFrameProjection internal constructor(
diagnosticContext: DiagnosticContext,
ssrc: Long,
sequenceNumberDelta: Int,
timestamp: Long
timestamp: Long,
frameNumber: Int?,
templateId: Int?
) : this(
diagnosticContext = diagnosticContext,
av1Frame = null,
ssrc = ssrc,
timestamp = timestamp,
sequenceNumberDelta = sequenceNumberDelta,
frameNumber = 0,
templateIdDelta = 0,
frameNumber = frameNumber ?: 0,
templateIdDelta = templateId ?: -1,
dti = null,
mark = false,
created = null
Expand Down Expand Up @@ -229,6 +231,9 @@ class Av1DDFrameProjection internal constructor(
* Get the next template ID that would come after the template IDs in this projection's structure
*/
fun getNextTemplateId(): Int? {
if (av1Frame == null && templateIdDelta != -1) {
return templateIdDelta
}
return av1Frame?.structure?.let { rewriteTemplateId(it.templateIdOffset + it.templateCount) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
package org.jitsi.videobridge.util

import org.jitsi.nlj.format.AudioRedPayloadType
import org.jitsi.nlj.format.Av1PayloadType
import org.jitsi.nlj.format.H264PayloadType
import org.jitsi.nlj.format.OpusPayloadType
import org.jitsi.nlj.format.OtherAudioPayloadType
import org.jitsi.nlj.format.OtherVideoPayloadType
import org.jitsi.nlj.format.PayloadType
import org.jitsi.nlj.format.PayloadTypeEncoding.AV1
import org.jitsi.nlj.format.PayloadTypeEncoding.Companion.createFrom
import org.jitsi.nlj.format.PayloadTypeEncoding.H264
import org.jitsi.nlj.format.PayloadTypeEncoding.OPUS
Expand Down Expand Up @@ -94,6 +96,7 @@ class PayloadTypeUtil {
return when (encoding) {
VP8 -> Vp8PayloadType(id, parameters, rtcpFeedbackSet)
VP9 -> Vp9PayloadType(id, parameters, rtcpFeedbackSet)
AV1 -> Av1PayloadType(id, parameters, rtcpFeedbackSet)
H264 -> H264PayloadType(id, parameters, rtcpFeedbackSet)
RTX -> RtxPayloadType(id, parameters)
OPUS -> OpusPayloadType(id, parameters)
Expand Down
Loading