diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 96cef2b8e80..ef5e22a9894 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -484,7 +484,9 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, return codec, nil } - // Bind is called under RTPSender.mu lock, call the RTPSender.GetParameters in goroutine to avoid deadlock + // Bind is called under RTPSender.mu lock, + // call the RTPSender.GetParameters (which setRTPHeaderExtensions invokes) + // in goroutine to avoid deadlock go d.setRTPHeaderExtensions() doBind := func() { @@ -777,25 +779,24 @@ func (d *DownTrack) SetReceiver(r TrackReceiver) { // Sets RTP header extensions for this track func (d *DownTrack) setRTPHeaderExtensions() { - d.bindLock.Lock() - defer d.bindLock.Unlock() - sal := d.getStreamAllocatorListener() if sal == nil { return } + isBWEEnabled := sal.IsBWEEnabled(d) + bweType := sal.BWEType() + tr := d.transceiver.Load() + if tr == nil { + return + } var extensions []webrtc.RTPHeaderExtensionParameter - if tr := d.transceiver.Load(); tr != nil { - if sender := tr.Sender(); sender != nil { - extensions = sender.GetParameters().HeaderExtensions - d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions) - } + if sender := tr.Sender(); sender != nil { + extensions = sender.GetParameters().HeaderExtensions + d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions) } - isBWEEnabled := sal.IsBWEEnabled(d) - bweType := sal.BWEType() - + d.bindLock.Lock() for _, ext := range extensions { switch ext.URI { case sdp.ABSSendTimeURI: @@ -818,6 +819,7 @@ func (d *DownTrack) setRTPHeaderExtensions() { d.absCaptureTimeExtID = ext.ID } } + d.bindLock.Unlock() } // Kind controls if this TrackLocal is audio or video @@ -840,6 +842,7 @@ func (d *DownTrack) SSRCRTX() uint32 { func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { d.transceiver.Store(transceiver) + d.setRTPHeaderExtensions() } func (d *DownTrack) GetTransceiver() *webrtc.RTPTransceiver { @@ -957,7 +960,11 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { copy(payload, tp.codecBytes) n := copy(payload[len(tp.codecBytes):], extPkt.Packet.Payload[tp.incomingHeaderSize:]) if n != len(extPkt.Packet.Payload[tp.incomingHeaderSize:]) { - d.params.Logger.Errorw("payload overflow", nil, "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), "have", n) + d.params.Logger.Errorw( + "payload overflow", nil, + "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), + "have", n, + ) PacketFactory.Put(poolEntity) return ErrPayloadOverflow } diff --git a/pkg/sfu/forwardstats.go b/pkg/sfu/forwardstats.go index 0601168b5ea..00b4e91ffdd 100644 --- a/pkg/sfu/forwardstats.go +++ b/pkg/sfu/forwardstats.go @@ -1,29 +1,33 @@ package sfu import ( - "fmt" "sync" "time" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" ) const ( - highForwardingLatency = 500 * time.Millisecond + highForwardingLatency = 20 * time.Millisecond skewFactor = 10 ) type ForwardStats struct { - lock sync.Mutex - latency *utils.LatencyAggregate - closeCh chan struct{} + lock sync.Mutex + latency *utils.LatencyAggregate + lowest int64 + highest int64 + lastUpdateAt int64 + closeCh chan struct{} } func NewForwardStats(latencyUpdateInterval, reportInterval, latencyWindowLength time.Duration) *ForwardStats { s := &ForwardStats{ latency: utils.NewLatencyAggregate(latencyUpdateInterval, latencyWindowLength), + lowest: time.Second.Nanoseconds(), closeCh: make(chan struct{}), } @@ -40,34 +44,48 @@ func (s *ForwardStats) Update(arrival, left int64) (int64, bool) { s.lock.Lock() s.latency.Update(time.Duration(arrival), float64(transit)) + s.lowest = min(transit, s.lowest) + s.highest = max(transit, s.highest) + s.lastUpdateAt = arrival s.lock.Unlock() return transit, isHighForwardingLatency } -func (s *ForwardStats) GetStats() (time.Duration, time.Duration) { +func (s *ForwardStats) GetStats(shortDuration time.Duration) (time.Duration, time.Duration, time.Duration, time.Duration) { s.lock.Lock() - w := s.latency.Summarize() + // a dummy sample to flush the pipe to current time + now := mono.UnixNano() + if (now - s.lastUpdateAt) > shortDuration.Nanoseconds() { + s.latency.Update(time.Duration(now), 0) + } + + wLong := s.latency.Summarize() + wShort := s.latency.SummarizeLast(shortDuration) + + lowest := s.lowest + s.lowest = time.Second.Nanoseconds() + + highest := s.highest + s.highest = 0 s.lock.Unlock() - latency, jitter := time.Duration(w.Mean()), time.Duration(w.StdDev()) - if jitter > latency*skewFactor { + latencyLong, jitterLong := time.Duration(wLong.Mean()), time.Duration(wLong.StdDev()) + latencyShort, jitterShort := time.Duration(wShort.Mean()), time.Duration(wShort.StdDev()) + if jitterLong > latencyLong*skewFactor { logger.Infow( "high jitter in forwarding path", - "latency", latency, - "jitter", jitter, - "stats", fmt.Sprintf("count %.2f, mean %.2f, stdDev %.2f", w.Count(), w.Mean(), w.StdDev()), + "lowest", time.Duration(lowest), + "highest", time.Duration(highest), + "countLong", wLong.Count(), + "latencyLong", latencyLong, + "jitterLong", jitterLong, + "countShort", wShort.Count(), + "latencyShort", latencyShort, + "jitterShort", jitterShort, ) } - return latency, jitter -} - -func (s *ForwardStats) GetLastStats(duration time.Duration) (time.Duration, time.Duration) { - s.lock.Lock() - w := s.latency.SummarizeLast(duration) - s.lock.Unlock() - - return time.Duration(w.Mean()), time.Duration(w.StdDev()) + return latencyLong, jitterLong, latencyShort, jitterShort } func (s *ForwardStats) Stop() { @@ -84,10 +102,9 @@ func (s *ForwardStats) report(reportInterval time.Duration) { return case <-ticker.C: - latency, jitter := s.GetLastStats(reportInterval) - latencySlow, jitterSlow := s.GetStats() - prometheus.RecordForwardJitter(uint32(jitter.Microseconds()), uint32(jitterSlow.Microseconds())) - prometheus.RecordForwardLatency(uint32(latency.Microseconds()), uint32(latencySlow.Microseconds())) + latencyLong, jitterLong, latencyShort, jitterShort := s.GetStats(reportInterval) + prometheus.RecordForwardJitter(uint32(jitterShort.Microseconds()), uint32(jitterLong.Microseconds())) + prometheus.RecordForwardLatency(uint32(latencyShort.Microseconds()), uint32(latencyLong.Microseconds())) } } }