@@ -82,18 +82,45 @@ type impl struct {
8282 receivers []Receiver
8383}
8484
85+ // interfaceWrapper is concrete type that wraps an interface. Necessary because
86+ // atomic.Value needs the same type and can not Store(nil). This indirection
87+ // allows us to store nil.
88+ type interfaceWrapper [T any ] struct {
89+ t T
90+ }
91+ type atomicInterface [T any ] struct {
92+ iface atomic.Value
93+ }
94+
95+ func (a * atomicInterface [T ]) Load () T {
96+ var v T
97+ x := a .iface .Load ()
98+ if x != nil {
99+ return x .(interfaceWrapper [T ]).t
100+ }
101+ return v
102+ }
103+
104+ func (a * atomicInterface [T ]) Store (v T ) {
105+ a .iface .Store (interfaceWrapper [T ]{v })
106+ }
107+
85108type streamMessageSender struct {
86- to peer.ID
87- stream network.Stream
88- connected bool
89- bsnet * impl
90- opts * MessageSenderOpts
109+ to peer.ID
110+ stream atomicInterface [network.Stream ]
111+ bsnet * impl
112+ opts * MessageSenderOpts
113+ }
114+
115+ type HasContext interface {
116+ Context () context.Context
91117}
92118
93119// Open a stream to the remote peer
94120func (s * streamMessageSender ) Connect (ctx context.Context ) (network.Stream , error ) {
95- if s .connected {
96- return s .stream , nil
121+ stream := s .stream .Load ()
122+ if stream != nil {
123+ return stream , nil
97124 }
98125
99126 tctx , cancel := context .WithTimeout (ctx , s .opts .SendTimeout )
@@ -107,30 +134,45 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro
107134 if err != nil {
108135 return nil , err
109136 }
137+ if withCtx , ok := stream .Conn ().(HasContext ); ok {
138+ context .AfterFunc (withCtx .Context (), func () {
139+ s .stream .Store (nil )
140+ })
141+ }
110142
111- s .stream = stream
112- s .connected = true
113- return s .stream , nil
143+ s .stream .Store (stream )
144+ return stream , nil
114145}
115146
116147// Reset the stream
117148func (s * streamMessageSender ) Reset () error {
118- if s .stream != nil {
119- err := s .stream .Reset ()
120- s .connected = false
149+ stream := s .stream .Load ()
150+ if stream != nil {
151+ err := stream .Reset ()
152+ s .stream .Store (nil )
121153 return err
122154 }
123155 return nil
124156}
125157
126158// Close the stream
127159func (s * streamMessageSender ) Close () error {
128- return s .stream .Close ()
160+ stream := s .stream .Load ()
161+ if stream != nil {
162+ err := stream .Close ()
163+ s .stream .Store (nil )
164+ return err
165+ }
166+ return nil
129167}
130168
131169// Indicates whether the peer supports HAVE / DONT_HAVE messages
132170func (s * streamMessageSender ) SupportsHave () bool {
133- return s .bsnet .SupportsHave (s .stream .Protocol ())
171+ stream := s .stream .Load ()
172+ if stream == nil {
173+ return false
174+ }
175+ return s .bsnet .SupportsHave (stream .Protocol ())
134176}
135177
136178// Send a message to the peer, attempting multiple times
0 commit comments