@@ -10,6 +10,7 @@ import (
1010 "net"
1111 "net/http"
1212 "sync"
13+ "sync/atomic"
1314 "time"
1415
1516 "github.com/gorilla/websocket"
@@ -42,7 +43,10 @@ type Protocol struct {
4243 connInfo netx.ConnInfo
4344 rnd * rand.Rand
4445 measurer Measurer
45- once * sync.Once
46+ once sync.Once
47+
48+ applicationBytesReceived atomic.Int64
49+ applicationBytesSent atomic.Int64
4650}
4751
4852// New returns a new Protocol with the specified connection and every other
@@ -54,7 +58,6 @@ func New(conn *websocket.Conn) *Protocol {
5458 // Seed randomness source with the current time.
5559 rnd : rand .New (rand .NewSource (time .Now ().UnixMilli ())),
5660 measurer : & DefaultMeasurer {},
57- once : & sync.Once {},
5861 }
5962}
6063
@@ -138,7 +141,8 @@ func (p *Protocol) senderReceiverLoop(ctx context.Context,
138141}
139142
140143// receiver reads from the connection until NextReader fails. It returns
141- // the measurements received over the provided channel.
144+ // the measurements received over the provided channel and updates the sent and
145+ // received byte counters as needed.
142146func (p * Protocol ) receiver (ctx context.Context ,
143147 results chan <- model.WireMeasurement , errCh chan <- error ) {
144148 for {
@@ -147,12 +151,22 @@ func (p *Protocol) receiver(ctx context.Context,
147151 errCh <- err
148152 return
149153 }
154+ if kind == websocket .BinaryMessage {
155+ // Binary messages are discarded after reading their size.
156+ size , err := io .Copy (io .Discard , reader )
157+ if err != nil {
158+ errCh <- err
159+ return
160+ }
161+ p .applicationBytesReceived .Add (size )
162+ }
150163 if kind == websocket .TextMessage {
151164 data , err := io .ReadAll (reader )
152165 if err != nil {
153166 errCh <- err
154167 return
155168 }
169+ p .applicationBytesReceived .Add (int64 (len (data )))
156170 var m model.WireMeasurement
157171 if err := json .Unmarshal (data , & m ); err != nil {
158172 errCh <- err
@@ -177,12 +191,27 @@ func (p *Protocol) sendCounterflow(ctx context.Context,
177191 wm = p .createWireMeasurement (ctx )
178192 })
179193 wm .Measurement = m
180- err := p .conn .WriteJSON (wm )
194+ wm .Application = model.ByteCounters {
195+ BytesSent : p .applicationBytesSent .Load (),
196+ BytesReceived : p .applicationBytesReceived .Load (),
197+ }
198+ // Encode as JSON separately so we can read the message size before
199+ // sending.
200+ jsonwm , err := json .Marshal (wm )
201+ if err != nil {
202+ log .Printf ("failed to encode measurement (ctx: %p, err: %v)" ,
203+ ctx , err )
204+ errCh <- err
205+ return
206+ }
207+ err = p .conn .WriteMessage (websocket .TextMessage , jsonwm )
181208 if err != nil {
182209 log .Printf ("failed to write measurement JSON (ctx: %p, err: %v)" , ctx , err )
183210 errCh <- err
184211 return
185212 }
213+ p .applicationBytesSent .Add (int64 (len (jsonwm )))
214+
186215 // This send is non-blocking in case there is no one to read the
187216 // Measurement message and the channel's buffer is full.
188217 select {
@@ -195,7 +224,6 @@ func (p *Protocol) sendCounterflow(ctx context.Context,
195224
196225func (p * Protocol ) sender (ctx context.Context , measurerCh <- chan model.Measurement ,
197226 results chan <- model.WireMeasurement , errCh chan <- error ) {
198- ci := netx .ToConnInfo (p .conn .UnderlyingConn ())
199227 size := spec .MinMessageSize
200228 message , err := p .makePreparedMessage (size )
201229 if err != nil {
@@ -219,12 +247,27 @@ func (p *Protocol) sender(ctx context.Context, measurerCh <-chan model.Measureme
219247 wm = p .createWireMeasurement (ctx )
220248 })
221249 wm .Measurement = m
222- err = p .conn .WriteJSON (wm )
250+ wm .Application = model.ByteCounters {
251+ BytesReceived : p .applicationBytesReceived .Load (),
252+ BytesSent : p .applicationBytesSent .Load (),
253+ }
254+ // Encode as JSON separately so we can read the message size before
255+ // sending.
256+ jsonwm , err := json .Marshal (wm )
257+ if err != nil {
258+ log .Printf ("failed to encode measurement (ctx: %p, err: %v)" ,
259+ ctx , err )
260+ errCh <- err
261+ return
262+ }
263+ err = p .conn .WriteMessage (websocket .TextMessage , jsonwm )
223264 if err != nil {
224265 log .Printf ("failed to write measurement JSON (ctx: %p, err: %v)" , ctx , err )
225266 errCh <- err
226267 return
227268 }
269+ p .applicationBytesSent .Add (int64 (len (jsonwm )))
270+
228271 // This send is non-blocking in case there is no one to read the
229272 // Measurement message and the channel's buffer is full.
230273 select {
@@ -238,14 +281,14 @@ func (p *Protocol) sender(ctx context.Context, measurerCh <-chan model.Measureme
238281 errCh <- err
239282 return
240283 }
284+ p .applicationBytesSent .Add (int64 (size ))
241285
242286 // Determine whether it's time to scale the message size.
243287 if size >= spec .MaxScaledMessageSize {
244288 continue
245289 }
246290
247- _ , w := ci .ByteCounters ()
248- if uint64 (size ) > w / spec .ScalingFraction {
291+ if size > int (p .applicationBytesSent .Load ())/ spec .ScalingFraction {
249292 continue
250293 }
251294
@@ -264,11 +307,15 @@ func (p *Protocol) sender(ctx context.Context, measurerCh <-chan model.Measureme
264307func (p * Protocol ) close (ctx context.Context ) {
265308 msg := websocket .FormatCloseMessage (
266309 websocket .CloseNormalClosure , "Done sending" )
310+
267311 err := p .conn .WriteControl (websocket .CloseMessage , msg , time .Now ().Add (time .Second ))
268312 if err != nil {
269313 log .Printf ("WriteControl failed (ctx: %p, err: %v)" , ctx , err )
270314 return
271315 }
316+ // The closing message is part of the measurement and added to bytesSent.
317+ p .applicationBytesSent .Add (int64 (len (msg )))
318+
272319 log .Printf ("Close message sent (ctx: %p)" , ctx )
273320}
274321
0 commit comments