Skip to content

Commit 2a63c46

Browse files
authored
Merge pull request #256 from bojand/stream_message_provider
add stream message provider api
2 parents 4c77580 + 65733b6 commit 2a63c46

File tree

4 files changed

+198
-28
lines changed

4 files changed

+198
-28
lines changed

runner/options.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ type RunConfig struct {
108108

109109
dataFunc BinaryDataFunc
110110
dataProviderFunc DataProviderFunc
111+
dataStreamFunc StreamMessageProviderFunc
111112
mdProviderFunc MetadataProviderFunc
112113

113114
funcs template.FuncMap
@@ -1006,6 +1007,30 @@ func WithMetadataProvider(fn MetadataProviderFunc) Option {
10061007
}
10071008
}
10081009

1010+
// WithStreamMessageProvider sets custom stream message provider
1011+
// WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
1012+
// protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.FormatInt(cd.RequestNumber, 10)}
1013+
// dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
1014+
// if err != nil {
1015+
// return nil, err
1016+
// }
1017+
//
1018+
// callCounter++
1019+
//
1020+
// if callCounter == 5 {
1021+
// err = ErrLastMessage
1022+
// }
1023+
//
1024+
// return dynamicMsg, err
1025+
// }),
1026+
func WithStreamMessageProvider(fn StreamMessageProviderFunc) Option {
1027+
return func(o *RunConfig) error {
1028+
o.dataStreamFunc = fn
1029+
1030+
return nil
1031+
}
1032+
}
1033+
10091034
func createClientTransportCredentials(skipVerify bool, cacertFile, clientCertFile, clientKeyFile, cname string) (credentials.TransportCredentials, error) {
10101035
var tlsConf tls.Config
10111036

runner/requester.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error {
387387
dataProvider: b.dataProvider,
388388
metadataProvider: b.metadataProvider,
389389
streamRecv: b.config.recvMsgFunc,
390+
msgProvider: b.config.dataStreamFunc,
390391
}
391392

392393
wc++ // increment worker id

runner/run_test.go

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ func TestRunUnary(t *testing.T) {
702702
}
703703
}
704704

705-
assert.Equal(t, []string{"0", "__record_metadata__||token:secret2", "2", "__record_metadata__||token:secret4", "4"}, names)
705+
assert.Equal(t, []string{"0", "__record_metadata__||token:secret1", "2", "__record_metadata__||token:secret3", "4"}, names)
706706
})
707707
}
708708

@@ -1553,6 +1553,76 @@ func TestRunClientStreaming(t *testing.T) {
15531553
}
15541554
}
15551555
})
1556+
1557+
t.Run("with stream message provider", func(t *testing.T) {
1558+
gs.ResetCounters()
1559+
1560+
callCounter := 0
1561+
1562+
report, err := Run(
1563+
"helloworld.Greeter.SayHelloCS",
1564+
internal.TestLocalhost,
1565+
WithProtoFile("../testdata/greeter.proto", []string{}),
1566+
WithTotalRequests(1),
1567+
WithConcurrency(1),
1568+
WithTimeout(time.Duration(20*time.Second)),
1569+
WithDialTimeout(time.Duration(20*time.Second)),
1570+
WithInsecure(true),
1571+
WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
1572+
protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.Itoa(callCounter)}
1573+
dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
1574+
if err != nil {
1575+
return nil, err
1576+
}
1577+
1578+
callCounter++
1579+
1580+
if callCounter == 5 {
1581+
err = ErrLastMessage
1582+
}
1583+
1584+
return dynamicMsg, err
1585+
}),
1586+
)
1587+
1588+
assert.NoError(t, err)
1589+
1590+
assert.NotNil(t, report)
1591+
1592+
assert.NotZero(t, report.Total)
1593+
assert.Equal(t, 1, int(report.Count))
1594+
assert.NotZero(t, report.Average)
1595+
assert.NotZero(t, report.Fastest)
1596+
assert.NotZero(t, report.Slowest)
1597+
assert.NotZero(t, report.Rps)
1598+
assert.Empty(t, report.Name)
1599+
assert.NotEmpty(t, report.Date)
1600+
assert.NotEmpty(t, report.Details)
1601+
assert.NotEmpty(t, report.Options)
1602+
assert.Equal(t, true, report.Options.Insecure)
1603+
assert.NotEmpty(t, report.LatencyDistribution)
1604+
assert.Equal(t, ReasonNormalEnd, report.EndReason)
1605+
assert.Empty(t, report.ErrorDist)
1606+
1607+
assert.Equal(t, report.Average, report.Slowest)
1608+
assert.Equal(t, report.Average, report.Fastest)
1609+
assert.Equal(t, report.Slowest, report.Fastest)
1610+
1611+
count := gs.GetCount(callType)
1612+
assert.Equal(t, 1, count)
1613+
1614+
connCount := gs.GetConnectionCount()
1615+
assert.Equal(t, 1, connCount)
1616+
1617+
calls := gs.GetCalls(callType)
1618+
assert.NotNil(t, calls)
1619+
assert.Len(t, calls, 1)
1620+
msgs := calls[0]
1621+
assert.Len(t, msgs, 5)
1622+
1623+
assert.Equal(t, "g0c0: 0", msgs[0].GetName())
1624+
assert.Equal(t, "g0c0: 4", msgs[4].GetName())
1625+
})
15561626
}
15571627

15581628
func TestRunClientStreamingBinary(t *testing.T) {
@@ -2274,6 +2344,76 @@ func TestRunBidi(t *testing.T) {
22742344
msgs := calls[0]
22752345
assert.Len(t, msgs, 6)
22762346
})
2347+
2348+
t.Run("with stream message provider", func(t *testing.T) {
2349+
gs.ResetCounters()
2350+
2351+
callCounter := 0
2352+
2353+
report, err := Run(
2354+
"helloworld.Greeter.SayHelloBidi",
2355+
internal.TestLocalhost,
2356+
WithProtoFile("../testdata/greeter.proto", []string{}),
2357+
WithTotalRequests(1),
2358+
WithConcurrency(1),
2359+
WithTimeout(time.Duration(20*time.Second)),
2360+
WithDialTimeout(time.Duration(20*time.Second)),
2361+
WithInsecure(true),
2362+
WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
2363+
protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.Itoa(callCounter)}
2364+
dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
2365+
if err != nil {
2366+
return nil, err
2367+
}
2368+
2369+
callCounter++
2370+
2371+
if callCounter == 7 {
2372+
err = ErrLastMessage
2373+
}
2374+
2375+
return dynamicMsg, err
2376+
}),
2377+
)
2378+
2379+
assert.NoError(t, err)
2380+
2381+
assert.NotNil(t, report)
2382+
2383+
assert.NotZero(t, report.Total)
2384+
assert.Equal(t, 1, int(report.Count))
2385+
assert.NotZero(t, report.Average)
2386+
assert.NotZero(t, report.Fastest)
2387+
assert.NotZero(t, report.Slowest)
2388+
assert.NotZero(t, report.Rps)
2389+
assert.Empty(t, report.Name)
2390+
assert.NotEmpty(t, report.Date)
2391+
assert.NotEmpty(t, report.Details)
2392+
assert.NotEmpty(t, report.Options)
2393+
assert.NotEmpty(t, report.LatencyDistribution)
2394+
assert.Equal(t, ReasonNormalEnd, report.EndReason)
2395+
assert.Equal(t, true, report.Options.Insecure)
2396+
assert.Empty(t, report.ErrorDist)
2397+
2398+
assert.Equal(t, report.Average, report.Slowest)
2399+
assert.Equal(t, report.Average, report.Fastest)
2400+
assert.Equal(t, report.Slowest, report.Fastest)
2401+
2402+
count := gs.GetCount(callType)
2403+
assert.Equal(t, 1, count)
2404+
2405+
connCount := gs.GetConnectionCount()
2406+
assert.Equal(t, 1, connCount)
2407+
2408+
calls := gs.GetCalls(callType)
2409+
assert.NotNil(t, calls)
2410+
assert.Len(t, calls, 1)
2411+
msgs := calls[0]
2412+
assert.Len(t, msgs, 7)
2413+
2414+
assert.Equal(t, "g0c0: 0", msgs[0].GetName())
2415+
assert.Equal(t, "g0c0: 6", msgs[6].GetName())
2416+
})
22772417
}
22782418

22792419
func TestRunUnarySecure(t *testing.T) {

runner/worker.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ type Worker struct {
3737

3838
dataProvider DataProviderFunc
3939
metadataProvider MetadataProviderFunc
40-
streamRecv StreamRecvMsgInterceptFunc
40+
msgProvider StreamMessageProviderFunc
41+
42+
streamRecv StreamRecvMsgInterceptFunc
4143
}
4244

4345
func (w *Worker) runWorker() error {
@@ -80,14 +82,6 @@ func (w *Worker) makeRequest(tv TickValue) error {
8082

8183
ctd := newCallData(w.mtd, w.config.funcs, w.workerID, reqNum)
8284

83-
inputs, err := w.dataProvider(ctd)
84-
if err != nil {
85-
return err
86-
}
87-
if len(inputs) == 0 {
88-
return fmt.Errorf("no data provided for request")
89-
}
90-
9185
reqMD, err := w.metadataProvider(ctd)
9286
if err != nil {
9387
return err
@@ -112,25 +106,15 @@ func (w *Worker) makeRequest(tv TickValue) error {
112106
ctx = metadata.NewOutgoingContext(ctx, *reqMD)
113107
}
114108

115-
var callType string
116-
if w.config.hasLog {
117-
callType = "unary"
118-
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
119-
callType = "bidi"
120-
} else if w.mtd.IsServerStreaming() {
121-
callType = "server-streaming"
122-
} else if w.mtd.IsClientStreaming() {
123-
callType = "client-streaming"
124-
}
125-
126-
w.config.log.Debugw("Making request", "workerID", w.workerID,
127-
"call type", callType, "call", w.mtd.GetFullyQualifiedName(),
128-
"input", inputs, "metadata", reqMD)
109+
inputs, err := w.dataProvider(ctd)
110+
if err != nil {
111+
return err
129112
}
130113

131-
unaryInput := inputs[0]
132114
var msgProvider StreamMessageProviderFunc
133-
if w.mtd.IsClientStreaming() {
115+
if w.msgProvider != nil {
116+
msgProvider = w.msgProvider
117+
} else if w.mtd.IsClientStreaming() {
134118
if w.config.streamDynamicMessages {
135119
mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount)
136120
if err != nil {
@@ -148,15 +132,35 @@ func (w *Worker) makeRequest(tv TickValue) error {
148132
}
149133
}
150134

135+
if len(inputs) == 0 && msgProvider == nil {
136+
return fmt.Errorf("no data provided for request")
137+
}
138+
139+
var callType string
140+
if w.config.hasLog {
141+
callType = "unary"
142+
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
143+
callType = "bidi"
144+
} else if w.mtd.IsServerStreaming() {
145+
callType = "server-streaming"
146+
} else if w.mtd.IsClientStreaming() {
147+
callType = "client-streaming"
148+
}
149+
150+
w.config.log.Debugw("Making request", "workerID", w.workerID,
151+
"call type", callType, "call", w.mtd.GetFullyQualifiedName(),
152+
"input", inputs, "metadata", reqMD)
153+
}
154+
151155
// RPC errors are handled via stats handler
152156
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
153157
_ = w.makeBidiRequest(&ctx, ctd, msgProvider)
154158
} else if w.mtd.IsClientStreaming() {
155159
_ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider)
156160
} else if w.mtd.IsServerStreaming() {
157-
_ = w.makeServerStreamingRequest(&ctx, unaryInput)
161+
_ = w.makeServerStreamingRequest(&ctx, inputs[0])
158162
} else {
159-
_ = w.makeUnaryRequest(&ctx, reqMD, unaryInput)
163+
_ = w.makeUnaryRequest(&ctx, reqMD, inputs[0])
160164
}
161165

162166
return err

0 commit comments

Comments
 (0)