@@ -82,9 +82,32 @@ type impl struct {
8282 receivers []Receiver
8383}
8484
85+ // interfaceWrapper is concrete type that wraps an interface. Necessary because
86+ // atomic.Value needs the same type and can not Store(nil). This indirection
87+ // allows us to store nil.
88+ type interfaceWrapper [T any ] struct {
89+ t T
90+ }
91+ type atomicInterface [T any ] struct {
92+ iface atomic.Value
93+ }
94+
95+ func (a * atomicInterface [T ]) Load () T {
96+ var v T
97+ x := a .iface .Load ()
98+ if x != nil {
99+ return x .(interfaceWrapper [T ]).t
100+ }
101+ return v
102+ }
103+
104+ func (a * atomicInterface [T ]) Store (v T ) {
105+ a .iface .Store (interfaceWrapper [T ]{v })
106+ }
107+
85108type streamMessageSender struct {
86109 to peer.ID
87- stream network.Stream
110+ stream atomicInterface [ network.Stream ]
88111 bsnet * impl
89112 opts * MessageSenderOpts
90113}
@@ -95,7 +118,7 @@ type HasContext interface {
95118
96119// Open a stream to the remote peer
97120func (s * streamMessageSender ) Connect (ctx context.Context ) (network.Stream , error ) {
98- stream := s .stream
121+ stream := s .stream . Load ()
99122 if stream != nil {
100123 return stream , nil
101124 }
@@ -111,36 +134,41 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro
111134 if err != nil {
112135 return nil , err
113136 }
137+ if withCtx , ok := stream .Conn ().(HasContext ); ok {
138+ context .AfterFunc (withCtx .Context (), func () {
139+ s .stream .Store (nil )
140+ })
141+ }
114142
115- s .stream = stream
143+ s .stream . Store ( stream )
116144 return stream , nil
117145}
118146
119147// Reset the stream
120148func (s * streamMessageSender ) Reset () error {
121- stream := s .stream
149+ stream := s .stream . Load ()
122150 if stream != nil {
123151 err := stream .Reset ()
124- s .stream = nil
152+ s .stream . Store ( nil )
125153 return err
126154 }
127155 return nil
128156}
129157
130158// Close the stream
131159func (s * streamMessageSender ) Close () error {
132- stream := s .stream
160+ stream := s .stream . Load ()
133161 if stream != nil {
134162 err := stream .Close ()
135- s .stream = nil
163+ s .stream . Store ( nil )
136164 return err
137165 }
138166 return nil
139167}
140168
141169// Indicates whether the peer supports HAVE / DONT_HAVE messages
142170func (s * streamMessageSender ) SupportsHave () bool {
143- stream := s .stream
171+ stream := s .stream . Load ()
144172 if stream == nil {
145173 return false
146174 }
0 commit comments