Skip to content

Commit b0d0b68

Browse files
Add minimal, multi stream support to minimal-download client (#38)
* Add minimal multistream option * Add all conn logic to download This change consolidates websocket logic into the download() method so that connection start and shutdown can happen concurrently across multiple streams. As such, we checkpoint firstStart firstClose and lastStart and lastClose times as well as byte counts at significant events. With these variables, we can calculate various avg rates or a peak rates.
1 parent fa019b5 commit b0d0b68

File tree

1 file changed

+125
-44
lines changed

1 file changed

+125
-44
lines changed

cmd/minimal-download/main.go

Lines changed: 125 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"net/url"
1515
"path"
1616
"runtime"
17+
"sync"
18+
"sync/atomic"
1719
"time"
1820

1921
"github.com/google/uuid"
@@ -27,14 +29,16 @@ const (
2729
)
2830

2931
var (
30-
flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use")
31-
flagDuration = flag.Duration("duration", 5*time.Second, "Length of the last stream")
32-
flagByteLimit = flag.Int("bytes", 0, "Byte limit to request to the server")
33-
flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification")
34-
flagServerURL = flag.String("server.url", "", "URL to directly target")
35-
flagMID = flag.String("mid", uuid.NewString(), "Measurement ID to use")
36-
flagScheme = flag.String("scheme", "wss", "Websocket scheme (wss or ws)")
37-
flagLocateURL = flag.String("locate.url", locateURL, "The base url for the Locate API")
32+
flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use")
33+
flagDuration = flag.Duration("duration", 5*time.Second, "Length of the last stream")
34+
flagMaxDuration = flag.Duration("max-duration", 15*time.Second, "Maximum length of all connections")
35+
flagByteLimit = flag.Int("bytes", 0, "Byte limit to request to the server")
36+
flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification")
37+
flagServerURL = flag.String("server.url", "", "URL to directly target")
38+
flagMID = flag.String("server.mid", uuid.NewString(), "Measurement ID to use")
39+
flagScheme = flag.String("locate.scheme", "wss", "Websocket scheme (wss or ws)")
40+
flagLocateURL = flag.String("locate.url", locateURL, "The base url for the Locate API")
41+
flagStreams = flag.Int("streams", 1, "The number of concurrent streams to create")
3842
)
3943

4044
// WireMeasurement is a wrapper for Measurement structs that contains
@@ -110,9 +114,9 @@ func init() {
110114
}
111115

112116
// connect to the given msak server URL, returning a *websocket.Conn.
113-
func connect(ctx context.Context, s *url.URL) (*websocket.Conn, error) {
117+
func prepareHeaders(ctx context.Context, s *url.URL) (string, http.Header) {
114118
q := s.Query()
115-
q.Set("streams", fmt.Sprintf("%d", 1))
119+
q.Set("streams", fmt.Sprintf("%d", *flagStreams))
116120
q.Set("cc", *flagCC)
117121
q.Set("bytes", fmt.Sprintf("%d", *flagByteLimit))
118122
q.Set("duration", fmt.Sprintf("%d", (*flagDuration).Milliseconds()))
@@ -126,8 +130,7 @@ func connect(ctx context.Context, s *url.URL) (*websocket.Conn, error) {
126130
headers := http.Header{}
127131
headers.Add("Sec-WebSocket-Protocol", "net.measurementlab.throughput.v1")
128132
headers.Add("User-Agent", clientName+"/"+clientVersion)
129-
conn, _, err := localDialer.DialContext(ctx, s.String(), headers)
130-
return conn, err
133+
return s.String(), headers
131134
}
132135

133136
// formatMessage reports a WireMeasurement in a human readable format.
@@ -204,38 +207,65 @@ func getDownloadServer(ctx context.Context) (*url.URL, error) {
204207
return nil, errors.New("no server")
205208
}
206209

207-
// getConn connects to a download server, returning the *websocket.Conn.
208-
func getConn(ctx context.Context) (*websocket.Conn, error) {
209-
srv, err := getDownloadServer(ctx)
210-
if err != nil {
211-
return nil, err
212-
}
213-
// Connect to server.
214-
return connect(ctx, srv)
210+
type sharedResults struct {
211+
bytesTotal atomic.Int64 // total bytes seen over the life of all connections.
212+
bytesAtLastStart atomic.Int64 // total bytes seen when the last connection starts.
213+
bytesAtFirstStop atomic.Int64 // total bytes seen when the first connection stops/closes.
214+
minRTT atomic.Int64 // minimum of all MinRTT values from all connections.
215+
mu sync.Mutex
216+
started atomic.Bool // set true after first connection opens.
217+
firstStartTime time.Time
218+
lastStartTime time.Time
219+
stopped atomic.Bool // set true after first connection closes (may be different than start conn).
220+
firstStopTime time.Time
221+
lastStopTime time.Time
215222
}
216223

217-
func main() {
218-
flag.Parse()
219-
220-
ctx, cancel := context.WithTimeout(context.Background(), *flagDuration*2)
221-
defer cancel()
222-
223-
conn, err := getConn(ctx)
224+
func (s *sharedResults) download(ctx context.Context, u string, headers http.Header, wg *sync.WaitGroup, streamCount int, stream int) {
225+
// Connect to server.
226+
conn, _, err := localDialer.DialContext(ctx, u, headers)
224227
if err != nil {
225-
log.Fatal(err)
228+
log.Println("skipping one stream; fialed to connect:", err)
229+
return
230+
}
231+
defer func(conn *websocket.Conn) {
232+
// Close on return.
233+
conn.Close()
234+
// On return, record first and last stop times.
235+
s.mu.Lock() // protect stopTime.
236+
now := time.Now()
237+
if !s.stopped.Load() {
238+
// Stop after first connect close.
239+
s.stopped.Store(true)
240+
s.firstStopTime = now
241+
s.bytesAtFirstStop.Store(s.bytesTotal.Load())
242+
}
243+
// This will update for every closed stream, but the last stream to close will be the correct "lastStopTime".
244+
s.lastStopTime = now
245+
s.mu.Unlock()
246+
wg.Done()
247+
}(conn)
248+
249+
// Record first and last start times.
250+
s.mu.Lock()
251+
now := time.Now()
252+
if !s.started.Load() {
253+
s.started.Store(true)
254+
// record start time as first open connection.
255+
s.firstStartTime = now
226256
}
227-
defer conn.Close()
257+
// This will update for every stream, but the last stream to update will be the correct "lastStartTime".
258+
s.lastStartTime = now
259+
s.bytesAtLastStart.Store(s.bytesTotal.Load())
260+
s.mu.Unlock()
228261

229-
// Max runtime.
230-
deadline := time.Now().Add(*flagDuration * 2)
262+
// Set absolute deadline for connections.
263+
deadline := time.Now().Add(*flagMaxDuration)
231264
conn.SetWriteDeadline(deadline)
232265
conn.SetReadDeadline(deadline)
233266

234-
// receive from text & binary messages from conn until the context expires or conn closes.
235-
var applicationBytesReceived int64
236-
var minRTT int64
237-
start := time.Now()
238267
outer:
268+
// Receive text & binary messages from conn until the context expires or conn closes.
239269
for {
240270
select {
241271
case <-ctx.Done():
@@ -256,28 +286,79 @@ outer:
256286
log.Println("error", err)
257287
return
258288
}
259-
applicationBytesReceived += size
289+
s.bytesTotal.Add(size)
260290
case websocket.TextMessage:
261291
data, err := io.ReadAll(reader)
262292
if err != nil {
263293
log.Println("error", err)
264294
return
265295
}
266-
applicationBytesReceived += int64(len(data))
296+
s.bytesTotal.Add(int64(len(data)))
267297

268298
var m WireMeasurement
269299
if err := json.Unmarshal(data, &m); err != nil {
270300
log.Println("error", err)
271301
return
272302
}
273-
formatMessage("Download server", 1, m)
274-
minRTT = m.TCPInfo["MinRTT"]
303+
if m.TCPInfo["MinRTT"] < s.minRTT.Load() || s.minRTT.Load() == 0 {
304+
// NOTE: this will be the minimum of MinRTT across all streams.
305+
s.minRTT.Store(m.TCPInfo["MinRTT"])
306+
}
307+
308+
switch {
309+
case streamCount == 1:
310+
// Use server metrics for single stream tests.
311+
formatMessage("Download server", 1, m)
312+
case streamCount > 1 && stream == 0:
313+
// Only do this for one stream.
314+
elapsed := time.Since(s.firstStartTime)
315+
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
316+
8*float64(s.bytesTotal.Load())/1e6/elapsed.Seconds(), // as mbps.
317+
float64(s.minRTT.Load())/1000.0, // as ms.
318+
elapsed.Seconds(), 0, s.bytesTotal.Load())
319+
}
275320
}
276321
}
277322
}
278-
since := time.Since(start)
279-
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
280-
8*float64(applicationBytesReceived)/1e6/since.Seconds(), // as mbps.
281-
float64(minRTT)/1000.0, // as ms.
282-
since.Seconds(), 0, applicationBytesReceived)
323+
}
324+
325+
func main() {
326+
flag.Parse()
327+
328+
ctx, cancel := context.WithTimeout(context.Background(), *flagMaxDuration)
329+
defer cancel()
330+
331+
srv, err := getDownloadServer(ctx)
332+
if err != nil {
333+
log.Fatal(err)
334+
}
335+
// Get common URL and headers.
336+
u, headers := prepareHeaders(ctx, srv)
337+
log.Printf("Connecting: %s://%s%s?...", srv.Scheme, srv.Host, srv.Path)
338+
339+
s := &sharedResults{}
340+
wg := &sync.WaitGroup{}
341+
for i := 0; i < *flagStreams; i++ {
342+
wg.Add(1)
343+
go s.download(ctx, u, headers, wg, *flagStreams, i)
344+
}
345+
wg.Wait()
346+
347+
log.Println("------")
348+
elapsedAvg := s.firstStopTime.Sub(s.firstStartTime)
349+
bytesAvg := s.bytesAtFirstStop.Load() // like msak-client, bytes during first-start to first-stop.
350+
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
351+
8*float64(bytesAvg)/1e6/elapsedAvg.Seconds(), // as mbps.
352+
float64(s.minRTT.Load())/1000.0, // as ms.
353+
elapsedAvg.Seconds(), 0, bytesAvg)
354+
355+
// TODO: we assume connections all overlap during peak periods.
356+
elapsedPeak := s.firstStopTime.Sub(s.lastStartTime)
357+
bytesPeak := s.bytesAtFirstStop.Load() - s.bytesAtLastStart.Load() // bytes during of peak period.
358+
if *flagStreams > 1 && bytesPeak > 0 && elapsedPeak > 0 {
359+
log.Printf("Download client #1 - Peak %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
360+
8*float64(bytesPeak)/1e6/elapsedPeak.Seconds(), // as mbps.
361+
float64(s.minRTT.Load())/1000.0, // as ms.
362+
elapsedPeak.Seconds(), 0, bytesPeak)
363+
}
283364
}

0 commit comments

Comments
 (0)