Skip to content

Commit 195eed4

Browse files
wckdrueian
authored andcommitted
fix: race condition in Lua script SHA-1 loading (#108)
Replace singleflight with simpler mutex double-check pattern for loading SHA-1 from Valkey. When multiple goroutines call Exec() concurrently: 1. First goroutine acquires the write lock and loads SHA-1 2. Other goroutines wait on the lock, then see SHA-1 already set This is simpler and avoids the race where a second goroutine could re-execute the singleflight after it completes. Signed-off-by: Rueian <rueiancsie@gmail.com>
1 parent b093580 commit 195eed4

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

lua.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ func newLuaScript(script string, readonly bool, noSha1, retryable bool, opts ...
8686
type Lua struct {
8787
script string
8888
sha1 string
89-
sha1Call call
9089
maxp int
9190
sha1Mu sync.RWMutex
9291
readonly bool
@@ -111,25 +110,19 @@ func (s *Lua) Exec(ctx context.Context, c Client, keys, args []string) (resp Red
111110
scriptSha1 = s.sha1
112111
s.sha1Mu.RUnlock()
113112

114-
// If not loaded yet, use singleflight to load it.
115113
if scriptSha1 == "" {
116-
err := s.sha1Call.Do(ctx, func() error {
114+
s.sha1Mu.Lock()
115+
if s.sha1 == "" { // the double check
117116
result := c.Do(ctx, c.B().ScriptLoad().Script(s.script).Build().ToRetryable())
118117
if shaStr, err := result.ToString(); err == nil {
119-
s.sha1Mu.Lock()
120118
s.sha1 = shaStr
119+
} else {
121120
s.sha1Mu.Unlock()
122-
return nil
121+
return newErrResult(result.Error())
123122
}
124-
return result.Error()
125-
})
126-
if err != nil {
127-
return newErrResult(err)
128123
}
129-
// Reload scriptSha1 after singleflight completes.
130-
s.sha1Mu.RLock()
131124
scriptSha1 = s.sha1
132-
s.sha1Mu.RUnlock()
125+
s.sha1Mu.Unlock()
133126
}
134127
} else {
135128
scriptSha1 = s.sha1

lua_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/sha1"
66
"encoding/hex"
7+
"errors"
78
"fmt"
89
"math/rand"
910
"reflect"
@@ -527,6 +528,37 @@ func TestNewLuaScriptWithLoadSha1(t *testing.T) {
527528
}
528529
}
529530

531+
func TestNewLuaScriptWithLoadSha1Error(t *testing.T) {
532+
defer ShouldNotLeak(SetupLeakDetection())
533+
body := strconv.Itoa(rand.Int())
534+
535+
k := []string{"1", "2"}
536+
a := []string{"3", "4"}
537+
538+
expectedErr := errors.New("SCRIPT LOAD failed")
539+
540+
c := &client{
541+
BFn: func() Builder {
542+
return cmds.NewBuilder(cmds.NoSlot)
543+
},
544+
DoFn: func(ctx context.Context, cmd Completed) (resp RedisResult) {
545+
commands := cmd.Commands()
546+
if reflect.DeepEqual(commands, []string{"SCRIPT", "LOAD", body}) {
547+
return newErrResult(expectedErr)
548+
}
549+
t.Fatal("unexpected command")
550+
return newResult(strmsg('+', "unexpected"), nil)
551+
},
552+
}
553+
554+
script := NewLuaScript(body, WithLoadSHA1(true))
555+
556+
result := script.Exec(context.Background(), c, k, a)
557+
if result.Error() != expectedErr {
558+
t.Fatalf("expected error %v, got %v", expectedErr, result.Error())
559+
}
560+
}
561+
530562
func TestNewLuaScriptRetryableWithLoadSha1(t *testing.T) {
531563
defer ShouldNotLeak(SetupLeakDetection())
532564
body := strconv.Itoa(rand.Int())

0 commit comments

Comments
 (0)