Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Added Queries Rate Limiter which uses Token Bucket algorithm to the Session struct.
Added RateLimiterConfig to the ClusterConfig struct. (#1731)

### Changed

Expand Down
3 changes: 3 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ type ClusterConfig struct {

// internal config for testing
disableControlConn bool

// If Session has RateLimiterConfig then queries will be limited using RateLimiter
RateLimiterConfig *RateLimiterConfig
}

type Dialer interface {
Expand Down
76 changes: 76 additions & 0 deletions rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package gocql

import (
"sync"
"time"
)

// RateLimiterConfig holds the configuration parameters for the rate limiter, which uses Token Bucket approach.
//
// Fields:
//
// - rate: Allowed requests per second
// - Burst: Maximum number of burst requests
//
// Example:
// RateLimiterConfig{
// rate: 300000,
// burst: 150,
// }
type RateLimiterConfig struct {
rate float64
burst int
}

type tokenBucket struct {
rate float64
burst int
tokens int
lastRefilled time.Time
mu sync.Mutex
}

func (tb *tokenBucket) refill() {
tb.mu.Lock()
defer tb.mu.Unlock()
now := time.Now()
tokensToAdd := int(tb.rate * now.Sub(tb.lastRefilled).Seconds())
tb.tokens = min(tb.tokens+tokensToAdd, tb.burst)
tb.lastRefilled = now
}

func (tb *tokenBucket) Allow() bool {
tb.refill()
tb.mu.Lock()
defer tb.mu.Unlock()
if tb.tokens > 0 {
tb.tokens--
return true
}
return false
}

func min(a, b int) int {
if a < b {
return a
}
return b
}

type ConfigurableRateLimiter struct {
tb tokenBucket
}

func NewConfigurableRateLimiter(rate float64, burst int) *ConfigurableRateLimiter {
tb := tokenBucket{
rate: rate,
burst: burst,
tokens: burst,
lastRefilled: time.Now(),
}
return &ConfigurableRateLimiter{tb}
}

func (rl *ConfigurableRateLimiter) Allow() bool {
return rl.tb.Allow()
}
78 changes: 78 additions & 0 deletions rate_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package gocql

import (
"fmt"
"sync"
"testing"
)

const queries = 100

const skipRateLimiterTestMsg = "Skipping rate limiter test, due to limit of simultaneously alive goroutines. Should be tested locally"

func TestRateLimiter50k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 50_000 workers")
RunRateLimiterTest(t, 50_000)
}

func TestRateLimiter100k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 100_000 workers")
RunRateLimiterTest(t, 100_000)
}

func TestRateLimiter200k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 200_000 workers")
RunRateLimiterTest(t, 200_000)
}

func RunRateLimiterTest(t *testing.T, workerCount int) {
cluster := createCluster()
cluster.RateLimiterConfig = &RateLimiterConfig{
rate: 300000,
burst: 100,
}

session := createSessionFromCluster(cluster, t)
defer session.Close()

execRelease(session.Query("drop keyspace if exists pargettest"))
execRelease(session.Query("create keyspace pargettest with replication = {'class' : 'SimpleStrategy', 'replication_factor' : 1}"))
execRelease(session.Query("drop table if exists pargettest.test"))
execRelease(session.Query("create table pargettest.test (a text, b int, primary key(a))"))
execRelease(session.Query("insert into pargettest.test (a, b) values ( 'a', 1)"))

var wg sync.WaitGroup

for i := 1; i <= workerCount; i++ {
wg.Add(1)

go func() {
defer wg.Done()
for j := 0; j < queries; j++ {
iterRelease(session.Query("select * from pargettest.test where a='a'"))
}
}()
}

wg.Wait()
}

func iterRelease(query *Query) {
_, err := query.Iter().SliceMap()
if err != nil {
println(err.Error())
panic(err)
}
query.Release()
}

func execRelease(query *Query) {
if err := query.Exec(); err != nil {
println(err.Error())
panic(err)
}
query.Release()
}
12 changes: 12 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ type Session struct {
isInitialized bool

logger StdLogger

rateLimiter *ConfigurableRateLimiter
}

var queryPool = &sync.Pool{
Expand Down Expand Up @@ -188,6 +190,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
s.frameObserver = cfg.FrameHeaderObserver
s.streamObserver = cfg.StreamObserver

if cfg.RateLimiterConfig != nil {
s.rateLimiter = NewConfigurableRateLimiter(cfg.RateLimiterConfig.rate, cfg.RateLimiterConfig.burst)
}

//Check the TLS Config before trying to connect to anything external
connCfg, err := connConfig(&s.cfg)
if err != nil {
Expand Down Expand Up @@ -452,6 +458,12 @@ func (s *Session) SetTrace(trace Tracer) {
// value before the query is executed. Query is automatically prepared
// if it has not previously been executed.
func (s *Session) Query(stmt string, values ...interface{}) *Query {
if s.rateLimiter != nil {
for !s.rateLimiter.Allow() {
time.Sleep(time.Millisecond * 50)
}
}

qry := queryPool.Get().(*Query)
qry.session = s
qry.stmt = stmt
Expand Down