Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions pkg/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ func (r *RedisClient) CheckAvailability() bool {
return pong == "PONG"
}

// UpdateRateLimit checks and updates the rate limit for a project using a Lua script
func (r *RedisClient) UpdateRateLimit(projectID string, eventsLimit int64, eventsPeriod int64) (bool, error) {
// CheckRateLimit checks and updates the rate limit for a project using a Lua script
func (r *RedisClient) CheckRateLimit(projectID string, eventsLimit int64, eventsPeriod int64) (bool, error) {
// If eventsLimit is 0, we don't need to update the rate limit
if eventsLimit == 0 {
return true, nil
Expand All @@ -191,32 +191,28 @@ func (r *RedisClient) UpdateRateLimit(projectID string, eventsLimit int64, event
local limit = tonumber(ARGV[3])
local period = tonumber(ARGV[4])

-- Read current window value
local current = redis.call('HGET', key, field)
if not current then
-- No existing record, create new window
redis.call('HSET', key, field, now .. ':1')
-- No existing record, event count is within limit
return 1
end

local timestamp, count = string.match(current, '(%d+):(%d+)')
timestamp = tonumber(timestamp)
count = tonumber(count)

-- Check if we're in a new time window
-- If we're in a new time window - event count is within limit
if now - timestamp >= period then
-- Reset for new window
redis.call('HSET', key, field, now .. ':1')
return 1
end

-- Check if incrementing would exceed limit
if count + 1 > limit then
return 0
-- Still in current window: check if event count is within limit
if count < limit then
return 1
end

-- Increment counter
redis.call('HSET', key, field, timestamp .. ':' .. (count + 1))
return 1
return 0
`

// Run the script
Expand Down
41 changes: 22 additions & 19 deletions pkg/redis/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func setupTestRedis(t *testing.T) (*RedisClient, *miniredis.Miniredis) {
return client, mr
}

func TestUpdateRateLimit(t *testing.T) {
func TestCheckRateLimit(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()

Expand Down Expand Up @@ -79,7 +79,7 @@ func TestUpdateRateLimit(t *testing.T) {
wantErr: false,
},
{
name: "should reset count after period expires",
name: "should ignore existing counter (treat as allowed) after period expires",
projectID: "project4",
eventsLimit: 5,
eventsPeriod: 60,
Expand All @@ -101,12 +101,16 @@ func TestUpdateRateLimit(t *testing.T) {
wantErr: false,
},
{
name: "should handle multiple calls up to limit",
name: "should fail if limit is already reached",
projectID: "project6",
eventsLimit: 3,
eventsPeriod: 60,
calls: 4,
wantAllowed: false, // Last call should be denied
setup: func() {
client.rdb.HSet(client.ctx, "rate_limits", "project6",
fmt.Sprintf("%d:%d", time.Now().Unix(), 3))
},
calls: 4,
wantAllowed: false,
wantErr: false,
},
}
Expand All @@ -123,7 +127,7 @@ func TestUpdateRateLimit(t *testing.T) {

// Make the specified number of calls
for i := 0; i < tt.calls; i++ {
lastAllowed, lastErr = client.UpdateRateLimit(tt.projectID, tt.eventsLimit, tt.eventsPeriod)
lastAllowed, lastErr = client.CheckRateLimit(tt.projectID, tt.eventsLimit, tt.eventsPeriod)
}

if tt.wantErr {
Expand All @@ -136,7 +140,7 @@ func TestUpdateRateLimit(t *testing.T) {
}
}

func TestUpdateRateLimitConcurrent(t *testing.T) {
func TestCheckRateLimitConcurrent(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()

Expand All @@ -148,15 +152,20 @@ func TestUpdateRateLimitConcurrent(t *testing.T) {
callsPerRoutine = 20
)

var rejectedCount int = 0
initialValue := fmt.Sprintf("%d:%d", time.Now().Unix(), eventsLimit)
if err := client.rdb.HSet(client.ctx, "rate_limits", projectID, initialValue).Err(); err != nil {
t.Fatalf("failed to seed rate limit: %v", err)
}

var rejectedCount int64 = 0

done := make(chan bool)

// Launch multiple goroutines to test concurrent access
for i := 0; i < goroutines; i++ {
go func() {
for j := 0; j < callsPerRoutine; j++ {
allowed, err := client.UpdateRateLimit(projectID, eventsLimit, eventsPeriod)
allowed, err := client.CheckRateLimit(projectID, eventsLimit, eventsPeriod)
assert.NoError(t, err)
if !allowed {
rejectedCount++
Expand All @@ -171,17 +180,11 @@ func TestUpdateRateLimitConcurrent(t *testing.T) {
<-done
}

// Verify the total number of successful updates doesn't exceed the limit
// Verify the stored value remains unchanged and all checks are denied
val, err := client.rdb.HGet(client.ctx, "rate_limits", projectID).Result()
assert.NoError(t, err)
assert.NotEmpty(t, val)

// The total count should not exceed the events limit
count := 0
_, err = fmt.Sscanf(val, "%d:%d", &count, &count)
assert.NoError(t, err)
assert.Equal(t, count, eventsLimit)
assert.Equal(t, rejectedCount, goroutines*callsPerRoutine-eventsLimit)
t.Logf("count: %d", count)
assert.Equal(t, initialValue, val)
assert.Equal(t, int64(goroutines*callsPerRoutine), rejectedCount)
t.Logf("stored value: %s", val)
t.Logf("rejectedCount: %d", rejectedCount)
}
4 changes: 2 additions & 2 deletions pkg/server/errorshandler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (handler *Handler) process(body []byte) ResponseMessage {
return ResponseMessage{402, true, "Project has exceeded the events limit"}
}

rateWithinLimit, err := handler.RedisClient.UpdateRateLimit(projectId, projectLimits.EventsLimit, projectLimits.EventsPeriod)
rateWithinLimit, err := handler.RedisClient.CheckRateLimit(projectId, projectLimits.EventsLimit, projectLimits.EventsPeriod)
if err != nil {
log.Errorf("Failed to update rate limit: %s", err)
return ResponseMessage{402, true, "Failed to update rate limit"}
Expand Down Expand Up @@ -134,7 +134,7 @@ func GetQueueCache(nonDefaultQueues []string) map[string]bool {

// getTimeSeriesKey generates a Redis TimeSeries key for project metrics
func getTimeSeriesKey(projectId, metricType, granularity string) string {
return fmt.Sprintf("ts:project-%s:%s:%s", metricType, projectId, granularity)
return fmt.Sprintf("ts:collected-project-%s:%s:%s", metricType, projectId, granularity)
}

// recordProjectMetrics records project metrics to Redis TimeSeries
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/errorshandler/handler_sentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (handler *Handler) HandleSentry(ctx *fasthttp.RequestCtx) {
return
}

rateWithinLimit, err := handler.RedisClient.UpdateRateLimit(projectId, projectLimits.EventsLimit, projectLimits.EventsPeriod)
rateWithinLimit, err := handler.RedisClient.CheckRateLimit(projectId, projectLimits.EventsLimit, projectLimits.EventsPeriod)
if err != nil {
log.Errorf("Failed to update rate limit: %s", err)
sendAnswerHTTP(ctx, ResponseMessage{402, true, "Failed to update rate limit"})
Expand Down
Loading