@@ -81,7 +81,7 @@ func (m *Manager) dialPeer(ctx context.Context, p *peer.Peer, opts []ConnectPeer
81
81
}
82
82
83
83
if len (streams ) == 0 {
84
- return nil , fmt .Errorf ("no streams initiated with peer %s / %s" , address , p .ID ())
84
+ return nil , errors .Errorf ("no streams initiated with peer %s / %s" , address , p .ID ())
85
85
}
86
86
87
87
return streams , nil
@@ -102,7 +102,7 @@ func (m *Manager) acceptPeer(ctx context.Context, p *peer.Peer, opts []ConnectPe
102
102
ctx , cancel = context .WithTimeout (ctx , defaultConnectionTimeout )
103
103
defer cancel ()
104
104
}
105
- am , err := m .newAcceptMatcher (p , protocolID )
105
+ amCtx , am , err := m .newAcceptMatcher (ctx , p , protocolID )
106
106
if err != nil {
107
107
return nil , errors .WithStack (err )
108
108
}
@@ -118,11 +118,11 @@ func (m *Manager) acceptPeer(ctx context.Context, p *peer.Peer, opts []ConnectPe
118
118
select {
119
119
case ps := <- streamCh :
120
120
if ps .Protocol () != protocolID {
121
- return nil , fmt .Errorf ("accepted stream has wrong protocol: %s != %s" , ps .Protocol (), protocolID )
121
+ return nil , errors .Errorf ("accepted stream has wrong protocol: %s != %s" , ps .Protocol (), protocolID )
122
122
}
123
123
return ps , nil
124
- case <- ctx .Done ():
125
- err := ctx .Err ()
124
+ case <- amCtx .Done ():
125
+ err := amCtx .Err ()
126
126
if errors .Is (err , context .DeadlineExceeded ) {
127
127
m .log .Debugw ("accept timeout" , "id" , am .Peer .ID (), "proto" , protocolID )
128
128
return nil , errors .WithStack (ErrTimeout )
@@ -166,7 +166,7 @@ func (m *Manager) acceptPeer(ctx context.Context, p *peer.Peer, opts []ConnectPe
166
166
}
167
167
168
168
if len (streams ) == 0 {
169
- return nil , fmt .Errorf ("no streams accepted from peer %s" , p .ID ())
169
+ return nil , errors .Errorf ("no streams accepted from peer %s" , p .ID ())
170
170
}
171
171
172
172
return streams , nil
@@ -175,7 +175,7 @@ func (m *Manager) acceptPeer(ctx context.Context, p *peer.Peer, opts []ConnectPe
175
175
func (m * Manager ) initiateStream (ctx context.Context , libp2pID libp2ppeer.ID , protocolID protocol.ID ) (* PacketsStream , error ) {
176
176
protocolHandler , registered := m .registeredProtocols [protocolID ]
177
177
if ! registered {
178
- return nil , fmt .Errorf ("cannot initiate stream protocol %s is not registered" , protocolID )
178
+ return nil , errors .Errorf ("cannot initiate stream protocol %s is not registered" , protocolID )
179
179
}
180
180
stream , err := m .GetP2PHost ().NewStream (ctx , libp2pID , protocolID )
181
181
if err != nil {
@@ -210,11 +210,14 @@ func (m *Manager) handleStream(stream network.Stream) {
210
210
am := m .matchNewStream (stream )
211
211
if am != nil {
212
212
am .StreamChMutex .RLock ()
213
+ defer am .StreamChMutex .RUnlock ()
213
214
streamCh := am .StreamCh [protocolID ]
214
- am .StreamChMutex .RUnlock ()
215
215
216
- m .log .Debugw ("incoming stream matched" , "id" , am .Peer .ID (), "proto" , protocolID )
217
- streamCh <- ps
216
+ select {
217
+ case <- am .Ctx .Done ():
218
+ case streamCh <- ps :
219
+ m .log .Debugw ("incoming stream matched" , "id" , am .Peer .ID (), "proto" , protocolID )
220
+ }
218
221
} else {
219
222
// close the connection if not matched
220
223
m .log .Debugw ("unexpected connection" , "addr" , stream .Conn ().RemoteMultiaddr (),
@@ -230,54 +233,61 @@ type AcceptMatcher struct {
230
233
Libp2pID libp2ppeer.ID
231
234
StreamChMutex sync.RWMutex
232
235
StreamCh map [protocol.ID ]chan * PacketsStream
236
+ Ctx context.Context
237
+ CtxCancel context.CancelFunc
233
238
}
234
239
235
- func (m * Manager ) newAcceptMatcher (p * peer.Peer , protocolID protocol.ID ) (* AcceptMatcher , error ) {
240
+ func (m * Manager ) newAcceptMatcher (ctx context. Context , p * peer.Peer , protocolID protocol.ID ) (context. Context , * AcceptMatcher , error ) {
236
241
m .acceptMutex .Lock ()
237
242
defer m .acceptMutex .Unlock ()
238
243
239
244
libp2pID , err := libp2putil .ToLibp2pPeerID (p )
240
245
if err != nil {
241
- return nil , errors .WithStack (err )
246
+ return nil , nil , errors .WithStack (err )
242
247
}
243
248
244
249
acceptMatcher , acceptExists := m .acceptMap [libp2pID ]
245
250
if acceptExists {
246
251
acceptMatcher .StreamChMutex .Lock ()
247
252
defer acceptMatcher .StreamChMutex .Unlock ()
248
253
if _ , streamChanExists := acceptMatcher .StreamCh [protocolID ]; streamChanExists {
249
- return nil , nil
254
+ return nil , nil , nil
250
255
}
251
256
acceptMatcher .StreamCh [protocolID ] = make (chan * PacketsStream )
252
- return acceptMatcher , nil
257
+ return acceptMatcher . Ctx , acceptMatcher , nil
253
258
}
254
259
260
+ cancelCtx , cancelCtxFunc := context .WithCancel (ctx )
261
+
255
262
am := & AcceptMatcher {
256
- Peer : p ,
257
- Libp2pID : libp2pID ,
258
- StreamCh : make (map [protocol.ID ]chan * PacketsStream ),
263
+ Peer : p ,
264
+ Libp2pID : libp2pID ,
265
+ StreamCh : make (map [protocol.ID ]chan * PacketsStream ),
266
+ Ctx : cancelCtx ,
267
+ CtxCancel : cancelCtxFunc ,
259
268
}
260
269
261
270
am .StreamCh [protocolID ] = make (chan * PacketsStream )
262
271
263
272
m .acceptMap [libp2pID ] = am
264
273
265
- return am , nil
274
+ return cancelCtx , am , nil
266
275
}
267
276
268
277
func (m * Manager ) removeAcceptMatcher (am * AcceptMatcher , protocolID protocol.ID ) {
269
278
m .acceptMutex .Lock ()
270
279
defer m .acceptMutex .Unlock ()
271
280
272
281
existingAm := m .acceptMap [am .Libp2pID ]
282
+
273
283
existingAm .StreamChMutex .Lock ()
274
284
defer existingAm .StreamChMutex .Unlock ()
275
285
276
- close (existingAm .StreamCh [protocolID ])
277
286
delete (existingAm .StreamCh , protocolID )
278
287
279
288
if len (existingAm .StreamCh ) == 0 {
280
289
delete (m .acceptMap , am .Libp2pID )
290
+ existingAm .CtxCancel ()
281
291
}
282
292
}
283
293
0 commit comments