Skip to content

Commit b2e0f1d

Browse files
authored
Add "gateway benchmark-stream"(PRIME-655) (#1138)
* Add gateway benchmark-stream endpoint to let us benchmark real completion endpoints on CG and SG instances * Add `--use-special-header` and optional request CSV output to the "benchmark gateway" command
1 parent 5812ab7 commit b2e0f1d

File tree

3 files changed

+407
-41
lines changed

3 files changed

+407
-41
lines changed

cmd/src/gateway.go

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Usage:
1717
The commands are:
1818
1919
benchmark runs benchmarks against Cody Gateway
20+
benchmark-stream runs benchmarks against Cody Gateway code completion streaming endpoints
2021
2122
Use "src gateway [command] -h" for more information about a command.
2223

cmd/src/gateway_benchmark.go

+123-41
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ type Stats struct {
2626
Total time.Duration
2727
}
2828

29+
type requestResult struct {
30+
duration time.Duration
31+
traceID string // X-Trace header value
32+
}
33+
2934
func init() {
3035
usage := `
31-
'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints.
36+
'src gateway benchmark' runs performance benchmarks against Cody Gateway and Sourcegraph test endpoints.
3237
3338
Usage:
3439
@@ -39,17 +44,20 @@ Examples:
3944
$ src gateway benchmark --sgp <token>
4045
$ src gateway benchmark --requests 50 --sgp <token>
4146
$ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp <token>
42-
$ src gateway benchmark --requests 50 --csv results.csv --sgp <token>
47+
$ src gateway benchmark --requests 50 --csv results.csv --request-csv requests.csv --sgp <token>
48+
$ src gateway benchmark --gateway https://cody-gateway.sourcegraph.com --sourcegraph https://sourcegraph.com --sgp <token> --use-special-header
4349
`
4450

4551
flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError)
4652

4753
var (
48-
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
49-
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
50-
gatewayEndpoint = flagSet.String("gateway", "https://cody-gateway.sourcegraph.com", "Cody Gateway endpoint")
51-
sgEndpoint = flagSet.String("sourcegraph", "https://sourcegraph.com", "Sourcegraph endpoint")
52-
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
54+
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
55+
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
56+
requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)")
57+
gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint")
58+
sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint")
59+
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
60+
useSpecialHeader = flagSet.Bool("use-special-header", false, "Use special header to test the gateway")
5361
)
5462

5563
handler := func(args []string) error {
@@ -61,15 +69,23 @@ Examples:
6169
return cmderrors.Usage("additional arguments not allowed")
6270
}
6371

72+
if *useSpecialHeader {
73+
fmt.Println("Using special header 'cody-core-gc-test'")
74+
}
75+
6476
var (
6577
httpClient = &http.Client{}
6678
endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s
6779
)
6880
if *gatewayEndpoint != "" {
6981
fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint)
82+
headers := http.Header{
83+
"X-Sourcegraph-Should-Trace": []string{"true"},
84+
}
7085
endpoints["ws(s): gateway"] = &webSocketClient{
71-
conn: nil,
72-
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
86+
conn: nil,
87+
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
88+
reqHeaders: headers,
7389
}
7490
endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http")
7591
} else {
@@ -80,12 +96,18 @@ Examples:
8096
return cmderrors.Usage("must specify --sgp <Sourcegraph personal access token>")
8197
}
8298
fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint)
99+
headers := http.Header{
100+
"Authorization": []string{"token " + *sgpToken},
101+
"X-Sourcegraph-Should-Trace": []string{"true"},
102+
}
103+
if *useSpecialHeader {
104+
headers.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
105+
}
106+
83107
endpoints["ws(s): sourcegraph"] = &webSocketClient{
84-
conn: nil,
85-
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
86-
headers: http.Header{
87-
"Authorization": []string{"token " + *sgpToken},
88-
},
108+
conn: nil,
109+
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
110+
reqHeaders: headers,
89111
}
90112
endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http")
91113
endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket")
@@ -95,29 +117,33 @@ Examples:
95117

96118
fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount)
97119

98-
var results []endpointResult
120+
var eResults []endpointResult
121+
rResults := map[string][]requestResult{}
99122
for name, clientOrURL := range endpoints {
100123
durations := make([]time.Duration, 0, *requestCount)
124+
rResults[name] = make([]requestResult, 0, *requestCount)
101125
fmt.Printf("\nTesting %s...", name)
102126

103127
for i := 0; i < *requestCount; i++ {
104128
if ws, ok := clientOrURL.(*webSocketClient); ok {
105-
duration := benchmarkEndpointWebSocket(ws)
106-
if duration > 0 {
107-
durations = append(durations, duration)
129+
result := benchmarkEndpointWebSocket(ws)
130+
if result.duration > 0 {
131+
durations = append(durations, result.duration)
132+
rResults[name] = append(rResults[name], result)
108133
}
109134
} else if url, ok := clientOrURL.(string); ok {
110-
duration := benchmarkEndpointHTTP(httpClient, url, *sgpToken)
111-
if duration > 0 {
112-
durations = append(durations, duration)
135+
result := benchmarkEndpointHTTP(httpClient, url, *sgpToken, *useSpecialHeader)
136+
if result.duration > 0 {
137+
durations = append(durations, result.duration)
138+
rResults[name] = append(rResults[name], result)
113139
}
114140
}
115141
}
116142
fmt.Println()
117143

118144
stats := calculateStats(durations)
119145

120-
results = append(results, endpointResult{
146+
eResults = append(eResults, endpointResult{
121147
name: name,
122148
avg: stats.Avg,
123149
median: stats.Median,
@@ -130,14 +156,20 @@ Examples:
130156
})
131157
}
132158

133-
printResults(results, requestCount)
159+
printResults(eResults, requestCount)
134160

135161
if *csvOutput != "" {
136-
if err := writeResultsToCSV(*csvOutput, results, requestCount); err != nil {
162+
if err := writeResultsToCSV(*csvOutput, eResults, requestCount); err != nil {
137163
return fmt.Errorf("failed to export CSV: %v", err)
138164
}
139165
fmt.Printf("\nResults exported to %s\n", *csvOutput)
140166
}
167+
if *requestLevelCsvOutput != "" {
168+
if err := writeRequestResultsToCSV(*requestLevelCsvOutput, rResults); err != nil {
169+
return fmt.Errorf("failed to export request-level CSV: %v", err)
170+
}
171+
fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput)
172+
}
141173

142174
return nil
143175
}
@@ -158,9 +190,10 @@ Examples:
158190
}
159191

160192
type webSocketClient struct {
161-
conn *websocket.Conn
162-
URL string
163-
headers http.Header
193+
conn *websocket.Conn
194+
URL string
195+
reqHeaders http.Header
196+
respHeaders http.Header
164197
}
165198

166199
func (c *webSocketClient) reconnect() error {
@@ -169,11 +202,13 @@ func (c *webSocketClient) reconnect() error {
169202
}
170203
fmt.Println("Connecting to WebSocket..", c.URL)
171204
var err error
172-
c.conn, _, err = websocket.DefaultDialer.Dial(c.URL, c.headers)
205+
var resp *http.Response
206+
c.conn, resp, err = websocket.DefaultDialer.Dial(c.URL, c.reqHeaders)
173207
if err != nil {
174208
c.conn = nil // retry again later
175209
return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err)
176210
}
211+
c.respHeaders = resp.Header
177212
fmt.Println("Connected!")
178213
return nil
179214
}
@@ -190,19 +225,23 @@ type endpointResult struct {
190225
successful int
191226
}
192227

193-
func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Duration {
228+
func benchmarkEndpointHTTP(client *http.Client, url, accessToken string, useSpecialHeader bool) requestResult {
194229
start := time.Now()
195230
req, err := http.NewRequest("POST", url, strings.NewReader("ping"))
196231
if err != nil {
197232
fmt.Printf("Error creating request: %v\n", err)
198-
return 0
233+
return requestResult{}
199234
}
200235
req.Header.Set("Content-Type", "application/json")
201236
req.Header.Set("Authorization", "token "+accessToken)
237+
req.Header.Set("X-Sourcegraph-Should-Trace", "true")
238+
if useSpecialHeader {
239+
req.Header.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
240+
}
202241
resp, err := client.Do(req)
203242
if err != nil {
204243
fmt.Printf("Error calling %s: %v\n", url, err)
205-
return 0
244+
return requestResult{}
206245
}
207246
defer func() {
208247
err := resp.Body.Close()
@@ -212,27 +251,30 @@ func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Du
212251
}()
213252
if resp.StatusCode != http.StatusOK {
214253
fmt.Printf("non-200 response: %v\n", resp.Status)
215-
return 0
254+
return requestResult{}
216255
}
217256
body, err := io.ReadAll(resp.Body)
218257
if err != nil {
219258
fmt.Printf("Error reading response body: %v\n", err)
220-
return 0
259+
return requestResult{}
221260
}
222261
if string(body) != "pong" {
223262
fmt.Printf("Expected 'pong' response, got: %q\n", string(body))
224-
return 0
263+
return requestResult{}
225264
}
226265

227-
return time.Since(start)
266+
return requestResult{
267+
duration: time.Since(start),
268+
traceID: resp.Header.Get("X-Trace"),
269+
}
228270
}
229271

230-
func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
272+
func benchmarkEndpointWebSocket(client *webSocketClient) requestResult {
231273
// Perform initial websocket connection, if needed.
232274
if client.conn == nil {
233275
if err := client.reconnect(); err != nil {
234276
fmt.Printf("Error reconnecting: %v\n", err)
235-
return 0
277+
return requestResult{}
236278
}
237279
}
238280

@@ -244,7 +286,7 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
244286
if err := client.reconnect(); err != nil {
245287
fmt.Printf("Error reconnecting: %v\n", err)
246288
}
247-
return 0
289+
return requestResult{}
248290
}
249291
_, message, err := client.conn.ReadMessage()
250292

@@ -253,16 +295,19 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
253295
if err := client.reconnect(); err != nil {
254296
fmt.Printf("Error reconnecting: %v\n", err)
255297
}
256-
return 0
298+
return requestResult{}
257299
}
258300
if string(message) != "pong" {
259301
fmt.Printf("Expected 'pong' response, got: %q\n", string(message))
260302
if err := client.reconnect(); err != nil {
261303
fmt.Printf("Error reconnecting: %v\n", err)
262304
}
263-
return 0
305+
return requestResult{}
306+
}
307+
return requestResult{
308+
duration: time.Since(start),
309+
traceID: client.respHeaders.Get("Content-Type"),
264310
}
265-
return time.Since(start)
266311
}
267312

268313
func calculateStats(durations []time.Duration) Stats {
@@ -438,3 +483,40 @@ func writeResultsToCSV(filename string, results []endpointResult, requestCount *
438483

439484
return nil
440485
}
486+
487+
func writeRequestResultsToCSV(filename string, results map[string][]requestResult) error {
488+
file, err := os.Create(filename)
489+
if err != nil {
490+
return fmt.Errorf("failed to create CSV file: %v", err)
491+
}
492+
defer func() {
493+
err := file.Close()
494+
if err != nil {
495+
return
496+
}
497+
}()
498+
499+
writer := csv.NewWriter(file)
500+
defer writer.Flush()
501+
502+
// Write header
503+
header := []string{"Endpoint", "Duration (ms)", "Trace ID"}
504+
if err := writer.Write(header); err != nil {
505+
return fmt.Errorf("failed to write CSV header: %v", err)
506+
}
507+
508+
for endpoint, requestResults := range results {
509+
for _, result := range requestResults {
510+
row := []string{
511+
endpoint,
512+
fmt.Sprintf("%.2f", float64(result.duration.Microseconds())/1000),
513+
result.traceID,
514+
}
515+
if err := writer.Write(row); err != nil {
516+
return fmt.Errorf("failed to write CSV row: %v", err)
517+
}
518+
}
519+
}
520+
521+
return nil
522+
}

0 commit comments

Comments
 (0)