Skip to content
This repository was archived by the owner on Apr 28, 2020. It is now read-only.

Commit a869bb2

Browse files
author
Brian Picciano
authored
Merge pull request mediocregopher#180 from maciej/drain-pipeline
Drain pipeline on decode error
2 parents ec262f3 + 0d766da commit a869bb2

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

action.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -441,14 +441,23 @@ func (p pipeline) Run(c Conn) error {
441441
if err := c.Encode(p); err != nil {
442442
return err
443443
}
444-
for _, cmd := range p {
444+
445+
for i, cmd := range p {
445446
if err := c.Decode(cmd); err != nil {
447+
p.drain(c, len(p)-i-1)
446448
return decodeErr(cmd, err)
447449
}
448450
}
449451
return nil
450452
}
451453

454+
func (p pipeline) drain(c Conn, n int) {
455+
rcv := resp2.Any{I: nil}
456+
for i := 0; i < n; i++ {
457+
_ = c.Decode(&rcv)
458+
}
459+
}
460+
452461
func decodeErr(cmd CmdAction, err error) error {
453462
c, ok := cmd.(*cmdAction)
454463
if ok {

action_test.go

+33-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import (
66
"fmt"
77
. "testing"
88

9-
"github.com/mediocregopher/radix/v3/resp/resp2"
109
"github.com/stretchr/testify/assert"
1110
"github.com/stretchr/testify/require"
11+
12+
"github.com/mediocregopher/radix/v3/resp/resp2"
1213
)
1314

1415
func TestCmdAction(t *T) {
@@ -224,6 +225,37 @@ func TestPipelineAction(t *T) {
224225
assert.Equal(t, ss[i], out[i])
225226
}
226227
}
228+
229+
t.Run("drain on decodeErr", func(t *T) {
230+
// Setup
231+
k1 := randStr()
232+
k2 := randStr()
233+
kvs := map[string]string{
234+
k1: randStr(),
235+
k2: randStr(),
236+
}
237+
238+
for k, v := range kvs {
239+
require.NoError(t, c.Do(Cmd(nil, "SET", k, v)))
240+
}
241+
242+
var intRcv int
243+
var strRcv string
244+
245+
pipeline := Pipeline(
246+
Cmd(&intRcv, "GET", k1),
247+
Cmd(nil, "GET", k2),
248+
)
249+
250+
err := c.Do(pipeline)
251+
require.Error(t, err)
252+
require.Contains(t, err.Error(), "failed to decode")
253+
254+
err = c.Do(Cmd(&strRcv, "GET", k1))
255+
require.NoError(t, err)
256+
257+
assert.Equal(t, kvs[k1], strRcv)
258+
})
227259
}
228260

229261
func ExamplePipeline() {

0 commit comments

Comments
 (0)