@@ -2,6 +2,7 @@ package webtransport
22
33import (
44 "context"
5+ "slices"
56 "sync"
67 "time"
78
@@ -21,11 +22,14 @@ type sessionEntry struct {
2122 Session * Session
2223}
2324
25+ const maxRecentlyClosedSessions = 16
26+
2427type sessionManager struct {
2528 timeout time.Duration
2629
27- mx sync.Mutex
28- sessions map [sessionID ]sessionEntry
30+ mx sync.Mutex
31+ sessions map [sessionID ]sessionEntry
32+ recentlyClosedSessions []sessionID
2933}
3034
3135func newSessionManager (timeout time.Duration ) * sessionManager {
@@ -45,6 +49,13 @@ func (m *sessionManager) AddStream(str *quic.Stream, id sessionID) {
4549
4650 entry , ok := m .sessions [id ]
4751 if ! ok {
52+ // Receiving a stream for an unknown session is expected to be rare,
53+ // so the performance impact of searching through the slice is negligible.
54+ if slices .Contains (m .recentlyClosedSessions , id ) {
55+ str .CancelRead (WTBufferedStreamRejectedErrorCode )
56+ str .CancelWrite (WTBufferedStreamRejectedErrorCode )
57+ return
58+ }
4859 entry = sessionEntry {Unestablished : & unestablishedSession {}}
4960 m .sessions [id ] = entry
5061 }
@@ -67,6 +78,12 @@ func (m *sessionManager) AddUniStream(str *quic.ReceiveStream, id sessionID) {
6778
6879 entry , ok := m .sessions [id ]
6980 if ! ok {
81+ // Receiving a stream for an unknown session is expected to be rare,
82+ // so the performance impact of searching through the slice is negligible.
83+ if slices .Contains (m .recentlyClosedSessions , id ) {
84+ str .CancelRead (WTBufferedStreamRejectedErrorCode )
85+ return
86+ }
7087 entry = sessionEntry {Unestablished : & unestablishedSession {}}
7188 m .sessions [id ] = entry
7289 }
@@ -134,12 +151,21 @@ func (m *sessionManager) AddSession(id sessionID, s *Session) {
134151 m .sessions [id ] = sessionEntry {Session : s }
135152
136153 context .AfterFunc (s .Context (), func () {
137- m .mx .Lock ()
138- defer m .mx .Unlock ()
139- delete (m .sessions , id )
154+ m .deleteSession (id )
140155 })
141156}
142157
158+ func (m * sessionManager ) deleteSession (id sessionID ) {
159+ m .mx .Lock ()
160+ defer m .mx .Unlock ()
161+
162+ delete (m .sessions , id )
163+ m .recentlyClosedSessions = append (m .recentlyClosedSessions , id )
164+ if len (m .recentlyClosedSessions ) > maxRecentlyClosedSessions {
165+ m .recentlyClosedSessions = m .recentlyClosedSessions [1 :]
166+ }
167+ }
168+
143169func (m * sessionManager ) Close () {
144170 m .mx .Lock ()
145171 defer m .mx .Unlock ()
0 commit comments