diff --git a/async/notifier.go b/async/notifier.go index 747ec5ea..a24ad162 100644 --- a/async/notifier.go +++ b/async/notifier.go @@ -24,8 +24,11 @@ package async import ( "context" "sync" + "time" ) +const notifiedKeyTTL = 30 * time.Second + type rootWaiter struct { key string ch chan struct{} @@ -56,6 +59,9 @@ type Notifier struct { waiters map[string]*rootWaiter // +checklocks:Mutex waiterMap map[string]map[*Waiter]bool + // +checklocks:Mutex + // Sticky: remembers recent notifications so late waiters don't miss them. + notifiedAt map[string]time.Time } type ReleaseFunc func() @@ -64,6 +70,15 @@ func (n *Notifier) NewWaiter(key string) (*Waiter, ReleaseFunc) { n.Lock() defer n.Unlock() + // If this key was notified recently, return an already-closed channel + // so the caller retries immediately without waiting. + if t, ok := n.notifiedAt[key]; ok && time.Since(t) < notifiedKeyTTL { + ch := make(chan struct{}) + close(ch) + w := &Waiter{key: key, ch: ch} + return w, func() {} + } + waiter, found := n.waiters[key] if !found { waiter = &rootWaiter{ @@ -103,6 +118,7 @@ func (n *Notifier) Reset() { } n.waiters = nil n.waiterMap = nil + n.notifiedAt = nil } func (n *Notifier) release(w *Waiter) { @@ -126,6 +142,11 @@ func (n *Notifier) Notify(key string) { n.Lock() defer n.Unlock() + if n.notifiedAt == nil { + n.notifiedAt = make(map[string]time.Time) + } + n.notifiedAt[key] = time.Now() + if w, found := n.waiters[key]; found { w.notify() delete(n.waiters, w.key) diff --git a/sfu/janus/subscriber.go b/sfu/janus/subscriber.go index 0d3bd736..b6d6fcbc 100644 --- a/sfu/janus/subscriber.go +++ b/sfu/janus/subscriber.go @@ -169,8 +169,6 @@ func (p *janusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, return } - waiter, stop := p.mcu.newPublisherConnectedWaiter(p.publisher, p.streamType) - defer stop() loggedNotPublishingYet := false retry: @@ -253,11 +251,14 @@ retry: sfuinternal.StatsWaitingForPublisherTotal.WithLabelValues(string(p.streamType)).Inc() } + waiter, stop := p.mcu.newPublisherConnectedWaiter(p.publisher, p.streamType) if err := waiter.Wait(ctx); err != nil { + stop() p.Close(context.Background()) callback(err, nil) return } + stop() p.logger.Printf("Retry subscribing %s from %s", p.streamType, p.publisher) goto retry default: