diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8dfb22828..8100193b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
 ### Added
+- Added a query rate limiter (token bucket algorithm) to the Session struct and a
+  RateLimiterConfig field to the ClusterConfig struct. (#1731)
 
 ### Changed
diff --git a/cluster.go b/cluster.go
index 13e62f3b0..e58775af5 100644
--- a/cluster.go
+++ b/cluster.go
@@ -259,6 +259,9 @@ type ClusterConfig struct {
 
     // internal config for testing
     disableControlConn bool
+
+    // If set, queries created through Session.Query are throttled by a token bucket rate limiter built from this config.
+    RateLimiterConfig *RateLimiterConfig
 }
 
 type Dialer interface {
diff --git a/rate_limiter.go b/rate_limiter.go
new file mode 100644
index 000000000..53d073f3a
--- /dev/null
+++ b/rate_limiter.go
@@ -0,0 +1,82 @@
+package gocql
+
+import (
+    "sync"
+    "time"
+)
+
+// RateLimiterConfig holds the configuration parameters for the rate limiter, which uses the token bucket approach.
+//
+// Fields:
+//
+//   - rate: allowed requests per second
+//   - burst: maximum number of burst requests
+//
+// Example:
+//
+//    RateLimiterConfig{
+//        rate:  300000,
+//        burst: 150,
+//    }
+type RateLimiterConfig struct {
+    rate  float64
+    burst int
+}
+
+type tokenBucket struct {
+    rate         float64
+    burst        int
+    tokens       int
+    lastRefilled time.Time
+    mu           sync.Mutex
+}
+
+func (tb *tokenBucket) refill() {
+    tb.mu.Lock()
+    defer tb.mu.Unlock()
+    now := time.Now()
+    tokensToAdd := int(tb.rate * now.Sub(tb.lastRefilled).Seconds())
+    // Only advance lastRefilled once at least one whole token has accrued;
+    // otherwise frequent calls would discard the accumulated elapsed time
+    // and the bucket could starve under heavy contention.
+    if tokensToAdd > 0 {
+        tb.tokens = min(tb.tokens+tokensToAdd, tb.burst)
+        tb.lastRefilled = now
+    }
+}
+
+func (tb *tokenBucket) Allow() bool {
+    tb.refill()
+    tb.mu.Lock()
+    defer tb.mu.Unlock()
+    if tb.tokens > 0 {
+        tb.tokens--
+        return true
+    }
+    return false
+}
+
+func min(a, b int) int {
+    if a < b {
+        return a
+    }
+    return b
+}
+
+type ConfigurableRateLimiter struct {
+    tb tokenBucket
+}
+
+func NewConfigurableRateLimiter(rate float64, burst int) *ConfigurableRateLimiter {
+    tb := tokenBucket{
+        rate:         rate,
+        burst:        burst,
+        tokens:       burst,
+        lastRefilled: time.Now(),
+    }
+    return &ConfigurableRateLimiter{tb}
+}
+
+func (rl *ConfigurableRateLimiter) Allow() bool {
+    return rl.tb.Allow()
+}
diff --git a/rate_limiter_test.go b/rate_limiter_test.go
new file mode 100644
index 000000000..03335c27e
--- /dev/null
+++ b/rate_limiter_test.go
@@ -0,0 +1,78 @@
+package gocql
+
+import (
+    "fmt"
+    "sync"
+    "testing"
+)
+
+const queries = 100
+
+const skipRateLimiterTestMsg = "Skipping rate limiter test due to the limit of simultaneously alive goroutines; it should be run locally"
+
+func TestRateLimiter50k(t *testing.T) {
+    t.Skip(skipRateLimiterTestMsg)
+    fmt.Println("Running rate limiter test with 50_000 workers")
+    RunRateLimiterTest(t, 50_000)
+}
+
+func TestRateLimiter100k(t *testing.T) {
+    t.Skip(skipRateLimiterTestMsg)
+    fmt.Println("Running rate limiter test with 100_000 workers")
+    RunRateLimiterTest(t, 100_000)
+}
+
+func TestRateLimiter200k(t *testing.T) {
+    t.Skip(skipRateLimiterTestMsg)
+    fmt.Println("Running rate limiter test with 200_000 workers")
+    RunRateLimiterTest(t, 200_000)
+}
+
+func RunRateLimiterTest(t *testing.T, workerCount int) {
+    cluster := createCluster()
+    cluster.RateLimiterConfig = &RateLimiterConfig{
+        rate:  300000,
+        burst: 100,
+    }
+
+    session := createSessionFromCluster(cluster, t)
+    defer session.Close()
+
+    execRelease(session.Query("drop keyspace if exists pargettest"))
+    execRelease(session.Query("create keyspace pargettest with replication = {'class' : 'SimpleStrategy', 'replication_factor' : 1}"))
+    execRelease(session.Query("drop table if exists pargettest.test"))
+    execRelease(session.Query("create table pargettest.test (a text, b int, primary key(a))"))
+    execRelease(session.Query("insert into pargettest.test (a, b) values ( 'a', 1)"))
+
+    var wg sync.WaitGroup
+
+    for i := 1; i <= workerCount; i++ {
+        wg.Add(1)
+
+        go func() {
+            defer wg.Done()
+            for j := 0; j < queries; j++ {
+                iterRelease(session.Query("select * from pargettest.test where a='a'"))
+            }
+        }()
+    }
+
+    wg.Wait()
+}
+
+func iterRelease(query *Query) {
+    _, err := query.Iter().SliceMap()
+    if err != nil {
+        println(err.Error())
+        panic(err)
+    }
+    query.Release()
+}
+
+func execRelease(query *Query) {
+    if err := query.Exec(); err != nil {
+        println(err.Error())
+        panic(err)
+    }
+    query.Release()
+}
diff --git a/session.go b/session.go
index a600b95f3..afbe30317 100644
--- a/session.go
+++ b/session.go
@@ -103,6 +103,8 @@ type Session struct {
     isInitialized bool
 
     logger StdLogger
+
+    rateLimiter *ConfigurableRateLimiter
 }
 
 var queryPool = &sync.Pool{
@@ -188,6 +190,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
     s.frameObserver = cfg.FrameHeaderObserver
     s.streamObserver = cfg.StreamObserver
 
+    if cfg.RateLimiterConfig != nil {
+        s.rateLimiter = NewConfigurableRateLimiter(cfg.RateLimiterConfig.rate, cfg.RateLimiterConfig.burst)
+    }
+
     //Check the TLS Config before trying to connect to anything external
     connCfg, err := connConfig(&s.cfg)
     if err != nil {
@@ -452,6 +458,13 @@ func (s *Session) SetTrace(trace Tracer) {
 // value before the query is executed. Query is automatically prepared
 // if it has not previously been executed.
 func (s *Session) Query(stmt string, values ...interface{}) *Query {
+    // If a rate limiter is configured, block (polling every 50ms) until it grants a token.
+    if s.rateLimiter != nil {
+        for !s.rateLimiter.Allow() {
+            time.Sleep(time.Millisecond * 50)
+        }
+    }
+
     qry := queryPool.Get().(*Query)
     qry.session = s
     qry.stmt = stmt
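
Example usage: a minimal sketch of how the new RateLimiterConfig might be wired into a session. Because the rate and burst fields are unexported, it assumes code living inside the gocql package (as the test above does); the contact point and the rate/burst values are placeholders, not part of the change itself.

package gocql

// exampleRateLimitedSession is an illustrative sketch only: it builds a
// cluster with the rate limiter configuration and runs one throttled query.
func exampleRateLimitedSession() error {
    cluster := NewCluster("127.0.0.1") // placeholder contact point
    cluster.RateLimiterConfig = &RateLimiterConfig{
        rate:  300000, // allowed queries per second
        burst: 100,    // bucket capacity for short bursts
    }

    session, err := cluster.CreateSession()
    if err != nil {
        return err
    }
    defer session.Close()

    // Session.Query blocks (sleeping in 50ms steps) until the token bucket
    // grants a token, then builds the query as usual.
    return session.Query("select now() from system.local").Exec()
}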