Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions pkg/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type RedisClient struct {
mx sync.Mutex
mx sync.RWMutex
rdb *redis.Client
blockedIDsSetName string
allIPsMapName string
Expand Down Expand Up @@ -92,9 +92,6 @@ func (r *RedisClient) LoadBlacklist() ([]string, []string, error) {

// loads list with provided name from Redis.
func (r *RedisClient) load() error {
r.mx.Lock()
defer r.mx.Unlock()

exists, existsErr := r.rdb.Exists(r.ctx, r.blockedIDsSetName).Result()
if existsErr != nil {
log.Errorf("Failed to check existence of blocked IDs set %q: %s", r.blockedIDsSetName, existsErr)
Expand All @@ -107,13 +104,17 @@ func (r *RedisClient) load() error {
log.Errorf("Failed to SCARD blocked IDs set %q: %s", r.blockedIDsSetName, cardErr)
}

keys, _, err := r.rdb.SScan(r.ctx, r.blockedIDsSetName, 0, "", 0).Result()
keys, err := r.rdb.SMembers(r.ctx, r.blockedIDsSetName).Result()
if err != nil {
log.Errorf("Failed to SScan blocked IDs set %q: %s", r.blockedIDsSetName, err)
log.Errorf("Failed to SMembers blocked IDs set %q: %s", r.blockedIDsSetName, err)
return err
}

r.mx.Lock()
prevCount := len(r.blockedIDs)
r.blockedIDs = keys
r.mx.Unlock()

if len(keys) != prevCount {
if cardErr == nil {
log.Infof("Banned projects list updated %d -> %d (key=%q, scard=%d)", prevCount, len(keys), r.blockedIDsSetName, cardinality)
Expand All @@ -122,7 +123,7 @@ func (r *RedisClient) load() error {
}
}
if cardErr == nil && int64(len(keys)) != cardinality {
log.Warnf("SScan returned %d entries but SCARD reports %d for key %q — set may be larger than one SScan batch", len(keys), cardinality, r.blockedIDsSetName)
log.Warnf("SMembers returned %d entries but SCARD reports %d for key %q (concurrent modification?)", len(keys), cardinality, r.blockedIDsSetName)
}
if len(keys) == 0 && prevCount > 0 {
log.Warnf("Blocked projects list is now empty (was %d). Key %q may have been cleared", prevCount, r.blockedIDsSetName)
Expand All @@ -139,15 +140,11 @@ func (r *RedisClient) load() error {
log.Debugf("Loaded blocked project IDs from %q (count=%d, sample=%v)", r.blockedIDsSetName, len(keys), sample)
}

r.blockedIDs = keys
return nil
}

// updateBlacklist loads IPs blacklist and resets current period map.
func (r *RedisClient) updateBlacklist() ([]string, []string, error) {
r.mx.Lock()
defer r.mx.Unlock()

ipAddrs, err := r.rdb.HKeys(r.ctx, r.currentPeriodMapName).Result()
if err != nil {
return nil, nil, err
Expand All @@ -157,7 +154,10 @@ func (r *RedisClient) updateBlacklist() ([]string, []string, error) {
if err != nil {
return nil, nil, err
}

r.mx.Lock()
r.blacklistIPs = ips
r.mx.Unlock()

if len(ipAddrs) > 0 {
requests, err := r.rdb.HVals(r.ctx, r.currentPeriodMapName).Result()
Expand All @@ -178,6 +178,8 @@ func (r *RedisClient) updateBlacklist() ([]string, []string, error) {

// IsBlocked checks if the provided ID is blocked.
func (r *RedisClient) IsBlocked(val string) bool {
r.mx.RLock()
defer r.mx.RUnlock()
for _, id := range r.blockedIDs {
if id == val {
log.Debugf("IsBlocked: project %q matched in cache (size=%d)", val, len(r.blockedIDs))
Expand All @@ -200,6 +202,8 @@ func (r *RedisClient) IncrementIP(ip string) error {

// CheckBlacklist checks if the provided IP is in blacklist.
func (r *RedisClient) CheckBlacklist(ip string) bool {
r.mx.RLock()
defer r.mx.RUnlock()
if len(r.blacklistIPs) == 0 {
return false
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/redis/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,30 @@ func TestUpdateRateLimit(t *testing.T) {
}
}

// Regression: load() must return all entries from large blocked ID sets.
func TestLoadBlockedIDsLargeSet(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()

client.blockedIDsSetName = "DisabledProjectsSet"

const total = 100
expected := make([]string, 0, total)
for i := 0; i < total; i++ {
id := fmt.Sprintf("%024x", i)
client.rdb.SAdd(client.ctx, client.blockedIDsSetName, id)
expected = append(expected, id)
}

err := client.load()
assert.NoError(t, err)

for _, id := range expected {
assert.True(t, client.IsBlocked(id), "expected %q to be reported as blocked", id)
}
assert.False(t, client.IsBlocked("not-in-set"))
}

func TestUpdateRateLimitConcurrent(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()
Expand Down
Loading