Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions pkg/sfu/downtrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,9 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
return codec, nil
}

// Bind is called under RTPSender.mu lock, call the RTPSender.GetParameters in goroutine to avoid deadlock
// Bind is called under RTPSender.mu lock,
// call the RTPSender.GetParameters (which setRTPHeaderExtensions invokes)
// in goroutine to avoid deadlock
go d.setRTPHeaderExtensions()

doBind := func() {
Expand Down Expand Up @@ -777,25 +779,24 @@ func (d *DownTrack) SetReceiver(r TrackReceiver) {

// Sets RTP header extensions for this track
func (d *DownTrack) setRTPHeaderExtensions() {
d.bindLock.Lock()
defer d.bindLock.Unlock()

sal := d.getStreamAllocatorListener()
if sal == nil {
return
}
isBWEEnabled := sal.IsBWEEnabled(d)
bweType := sal.BWEType()

tr := d.transceiver.Load()
if tr == nil {
return
}
var extensions []webrtc.RTPHeaderExtensionParameter
if tr := d.transceiver.Load(); tr != nil {
if sender := tr.Sender(); sender != nil {
extensions = sender.GetParameters().HeaderExtensions
d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions)
}
if sender := tr.Sender(); sender != nil {
extensions = sender.GetParameters().HeaderExtensions
d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions)
}

isBWEEnabled := sal.IsBWEEnabled(d)
bweType := sal.BWEType()

d.bindLock.Lock()
for _, ext := range extensions {
switch ext.URI {
case sdp.ABSSendTimeURI:
Expand All @@ -818,6 +819,7 @@ func (d *DownTrack) setRTPHeaderExtensions() {
d.absCaptureTimeExtID = ext.ID
}
}
d.bindLock.Unlock()
}

// Kind controls if this TrackLocal is audio or video
Expand All @@ -840,6 +842,7 @@ func (d *DownTrack) SSRCRTX() uint32 {

func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) {
d.transceiver.Store(transceiver)
d.setRTPHeaderExtensions()
}

func (d *DownTrack) GetTransceiver() *webrtc.RTPTransceiver {
Expand Down Expand Up @@ -957,7 +960,11 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
copy(payload, tp.codecBytes)
n := copy(payload[len(tp.codecBytes):], extPkt.Packet.Payload[tp.incomingHeaderSize:])
if n != len(extPkt.Packet.Payload[tp.incomingHeaderSize:]) {
d.params.Logger.Errorw("payload overflow", nil, "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), "have", n)
d.params.Logger.Errorw(
"payload overflow", nil,
"want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]),
"have", n,
)
PacketFactory.Put(poolEntity)
return ErrPayloadOverflow
}
Expand Down
67 changes: 42 additions & 25 deletions pkg/sfu/forwardstats.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
package sfu

import (
"fmt"
"sync"
"time"

"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils"
"github.com/livekit/protocol/utils/mono"
)

const (
highForwardingLatency = 500 * time.Millisecond
highForwardingLatency = 20 * time.Millisecond
skewFactor = 10
)

type ForwardStats struct {
lock sync.Mutex
latency *utils.LatencyAggregate
closeCh chan struct{}
lock sync.Mutex
latency *utils.LatencyAggregate
lowest int64
highest int64
lastUpdateAt int64
closeCh chan struct{}
}

func NewForwardStats(latencyUpdateInterval, reportInterval, latencyWindowLength time.Duration) *ForwardStats {
s := &ForwardStats{
latency: utils.NewLatencyAggregate(latencyUpdateInterval, latencyWindowLength),
lowest: time.Second.Nanoseconds(),
closeCh: make(chan struct{}),
}

Expand All @@ -40,34 +44,48 @@ func (s *ForwardStats) Update(arrival, left int64) (int64, bool) {

s.lock.Lock()
s.latency.Update(time.Duration(arrival), float64(transit))
s.lowest = min(transit, s.lowest)
s.highest = max(transit, s.highest)
s.lastUpdateAt = arrival
s.lock.Unlock()

return transit, isHighForwardingLatency
}

func (s *ForwardStats) GetStats() (time.Duration, time.Duration) {
func (s *ForwardStats) GetStats(shortDuration time.Duration) (time.Duration, time.Duration, time.Duration, time.Duration) {
s.lock.Lock()
w := s.latency.Summarize()
// a dummy sample to flush the pipe to current time
now := mono.UnixNano()
if (now - s.lastUpdateAt) > shortDuration.Nanoseconds() {
s.latency.Update(time.Duration(now), 0)
}

wLong := s.latency.Summarize()
wShort := s.latency.SummarizeLast(shortDuration)

lowest := s.lowest
s.lowest = time.Second.Nanoseconds()

highest := s.highest
s.highest = 0
s.lock.Unlock()

latency, jitter := time.Duration(w.Mean()), time.Duration(w.StdDev())
if jitter > latency*skewFactor {
latencyLong, jitterLong := time.Duration(wLong.Mean()), time.Duration(wLong.StdDev())
latencyShort, jitterShort := time.Duration(wShort.Mean()), time.Duration(wShort.StdDev())
if jitterLong > latencyLong*skewFactor {
logger.Infow(
"high jitter in forwarding path",
"latency", latency,
"jitter", jitter,
"stats", fmt.Sprintf("count %.2f, mean %.2f, stdDev %.2f", w.Count(), w.Mean(), w.StdDev()),
"lowest", time.Duration(lowest),
"highest", time.Duration(highest),
"countLong", wLong.Count(),
"latencyLong", latencyLong,
"jitterLong", jitterLong,
"countShort", wShort.Count(),
"latencyShort", latencyShort,
"jitterShort", jitterShort,
)
}
return latency, jitter
}

func (s *ForwardStats) GetLastStats(duration time.Duration) (time.Duration, time.Duration) {
s.lock.Lock()
w := s.latency.SummarizeLast(duration)
s.lock.Unlock()

return time.Duration(w.Mean()), time.Duration(w.StdDev())
return latencyLong, jitterLong, latencyShort, jitterShort
}

func (s *ForwardStats) Stop() {
Expand All @@ -84,10 +102,9 @@ func (s *ForwardStats) report(reportInterval time.Duration) {
return

case <-ticker.C:
latency, jitter := s.GetLastStats(reportInterval)
latencySlow, jitterSlow := s.GetStats()
prometheus.RecordForwardJitter(uint32(jitter.Microseconds()), uint32(jitterSlow.Microseconds()))
prometheus.RecordForwardLatency(uint32(latency.Microseconds()), uint32(latencySlow.Microseconds()))
latencyLong, jitterLong, latencyShort, jitterShort := s.GetStats(reportInterval)
prometheus.RecordForwardJitter(uint32(jitterShort.Microseconds()), uint32(jitterLong.Microseconds()))
prometheus.RecordForwardLatency(uint32(latencyShort.Microseconds()), uint32(latencyLong.Microseconds()))
}
}
}
Loading