Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Caddyfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
# MaxPending is the maximum number of pending (and failed) requests.
# Any IP block (prefix configured in prefix_cfg) with more than this number of pending requests will be blocked.
max_pending 128
# AccessPerApproval is the number of requests allowed per successful challenge. We recommend a value greater than 8 to support parallel and resumable downloads.
access_per_approval 8
# BlockTTL is the time to live for blocked IPs.
block_ttl "24h"
# PendingTTL is the time to live for pending requests when considering whether to block an IP.
pending_ttl "1h"
# ApprovalTTL is the time to live for approved requests.
approval_ttl "1h"
# MaxMemUsage is the maximum memory usage for the pending and blocklist caches.
max_mem_usage "512MiB"
# CookieName is the name of the cookie used to store signed certificate.
Expand Down
40 changes: 29 additions & 11 deletions core/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ import (
)

const (
DefaultCookieName = "cerberus-auth"
DefaultHeaderName = "X-Cerberus-Status"
DefaultDifficulty = 4
DefaultMaxPending = 128
DefaultBlockTTL = time.Hour * 24 // 1 day
DefaultPendingTTL = time.Hour // 1 hour
DefaultMaxMemUsage = 1 << 29 // 512MB
DefaultTitle = "Cerberus Challenge"
DefaultDescription = "Making sure you're not a bot!"
DefaultIPV4Prefix = 32
DefaultIPV6Prefix = 64
DefaultCookieName = "cerberus-auth"
DefaultHeaderName = "X-Cerberus-Status"
DefaultDifficulty = 4
DefaultMaxPending = 128
DefaultAccessPerApproval = 8
DefaultBlockTTL = time.Hour * 24 // 1 day
DefaultPendingTTL = time.Hour // 1 hour
DefaultApprovalTTL = time.Hour // 1 hour
DefaultMaxMemUsage = 1 << 29 // 512MB
DefaultTitle = "Cerberus Challenge"
DefaultDescription = "Making sure you're not a bot!"
DefaultIPV4Prefix = 32
DefaultIPV6Prefix = 64
)

type Config struct {
Expand All @@ -30,10 +32,14 @@ type Config struct {
// MaxPending is the maximum number of pending (and failed) requests.
// Any IP block (prefix configured in prefix_cfg) with more than this number of pending requests will be blocked.
MaxPending int32 `json:"max_pending,omitempty"`
// AccessPerApproval is the number of requests allowed per successful challenge.
AccessPerApproval int32 `json:"access_per_approval,omitempty"`
// BlockTTL is the time to live for blocked IPs.
BlockTTL time.Duration `json:"block_ttl,omitempty"`
// PendingTTL is the time to live for pending requests when considering whether to block an IP.
PendingTTL time.Duration `json:"pending_ttl,omitempty"`
// ApprovalTTL is the time to live for approved requests.
ApprovalTTL time.Duration `json:"approval_ttl,omitempty"`
// MaxMemUsage is the maximum memory usage for the pending and blocklist caches.
MaxMemUsage int64 `json:"max_mem_usage,omitempty"`
// CookieName is the name of the cookie used to store signed certificate.
Expand All @@ -55,12 +61,18 @@ func (c *Config) Provision() {
if c.MaxPending == 0 {
c.MaxPending = DefaultMaxPending
}
if c.AccessPerApproval == 0 {
c.AccessPerApproval = DefaultAccessPerApproval
}
if c.BlockTTL == time.Duration(0) {
c.BlockTTL = DefaultBlockTTL
}
if c.PendingTTL == time.Duration(0) {
c.PendingTTL = DefaultPendingTTL
}
if c.ApprovalTTL == time.Duration(0) {
c.ApprovalTTL = DefaultApprovalTTL
}
if c.MaxMemUsage == 0 {
c.MaxMemUsage = DefaultMaxMemUsage
}
Expand Down Expand Up @@ -88,12 +100,18 @@ func (c *Config) Validate() error {
if c.MaxPending < 1 {
return errors.New("max_pending must be at least 1")
}
if c.AccessPerApproval < 1 {
return errors.New("access_per_approval must be at least 1")
}
if c.BlockTTL < 0 {
return errors.New("block_ttl must be a positive duration")
}
if c.PendingTTL < 0 {
return errors.New("pending_ttl must be a positive duration")
}
if c.ApprovalTTL < 0 {
return errors.New("approval_ttl must be a positive duration")
}
if c.MaxMemUsage < 1 {
return errors.New("max_mem_usage must be at least 1")
}
Expand Down
9 changes: 6 additions & 3 deletions core/const.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package core

import "time"

const (
AppName = "cerberus"
VarName = "cerberus-block"
Version = "v0.2.1"
AppName = "cerberus"
VarName = "cerberus-block"
Version = "v0.2.1"
NonceTTL = 2 * time.Minute
)
3 changes: 2 additions & 1 deletion core/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (i *Instance) UpdateWithConfig(c Config, logger *zap.Logger) error {
} else {
// We need to reset the state.
logger.Info("existing cerberus instance with incompatible config found, resetting state")
state, pendingElems, blocklistElems, err := NewInstanceState(c.MaxMemUsage, c.MaxMemUsage, c.PendingTTL, c.BlockTTL)
state, pendingElems, blocklistElems, approvalElems, err := NewInstanceState(c.MaxMemUsage, c.MaxMemUsage, c.MaxMemUsage, c.PendingTTL, c.BlockTTL, c.ApprovalTTL)
if err != nil {
return err
}
Expand All @@ -32,6 +32,7 @@ func (i *Instance) UpdateWithConfig(c Config, logger *zap.Logger) error {
logger.Info("cerberus state initialized",
zap.Int64("pending_elems", pendingElems),
zap.Int64("blocklist_elems", blocklistElems),
zap.Int64("approval_elems", approvalElems),
)
}
return nil
Expand Down
3 changes: 2 additions & 1 deletion core/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ func GetInstance(config Config, logger *zap.Logger) (*Instance, error) {

if instance == nil {
// Initialize a new instance.
state, pendingElems, blocklistElems, err := NewInstanceState(config.MaxMemUsage, config.MaxMemUsage, config.PendingTTL, config.BlockTTL)
state, pendingElems, blocklistElems, approvalElems, err := NewInstanceState(config.MaxMemUsage, config.MaxMemUsage, config.MaxMemUsage, config.PendingTTL, config.BlockTTL, config.ApprovalTTL)
if err != nil {
return nil, err
}

logger.Info("cerberus state initialized",
zap.Int64("pending_elems", pendingElems),
zap.Int64("blocklist_elems", blocklistElems),
zap.Int64("approval_elems", approvalElems),
)
instance = &Instance{
Config: config,
Expand Down
148 changes: 122 additions & 26 deletions core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,37 @@ import (
"unsafe"

"crypto/sha256"
"encoding/binary"
"encoding/hex"

"github.com/elastic/go-freelru"
"github.com/google/uuid"
"github.com/sjtug/cerberus/internal/expiremap"
"github.com/sjtug/cerberus/internal/ipblock"
"github.com/zeebo/xxh3"
"golang.org/x/crypto/ed25519"
)

const (
FreeLRUInternalCost = 20
PendingItemCost = 4 + int64(unsafe.Sizeof(&atomic.Int32{})) + FreeLRUInternalCost
PendingItemCost = FreeLRUInternalCost + int64(unsafe.Sizeof(ipblock.IPBlock{})) + int64(unsafe.Sizeof(&atomic.Int32{})) + int64(unsafe.Sizeof(atomic.Int32{}))
BlocklistItemCost = FreeLRUInternalCost + int64(unsafe.Sizeof(ipblock.IPBlock{}))
ApprovalItemCost = FreeLRUInternalCost + int64(unsafe.Sizeof(uuid.UUID{})) + int64(unsafe.Sizeof(&atomic.Int32{})) + int64(unsafe.Sizeof(atomic.Int32{}))
)

// hashIPBlock computes a hash value for an IPBlock to be used in sharded LRU cache.
// It uses the internal uint64 data and mixes it for better distribution.
func hashIPBlock(ip ipblock.IPBlock) uint32 {
data := ip.ToUint64()
// Mix the bits using multiplication by a prime and XOR
hash := uint32(data) ^ uint32(data>>32) // #nosec G115 we explicitly want to truncate the uint64 to uint32
hash = hash * 0x9e3779b1 // Golden ratio
return hash

var buf [8]byte
binary.BigEndian.PutUint64(buf[:], data)

hash := xxh3.Hash(buf[:])
return uint32(hash) // #nosec G115 -- expected truncation
}

func hashUUID(id uuid.UUID) uint32 {
hash := xxh3.Hash(id[:])
return uint32(hash) // #nosec G115 -- expected truncation
}

type InstanceState struct {
Expand All @@ -35,51 +45,103 @@ type InstanceState struct {
fp string
pending freelru.Cache[ipblock.IPBlock, *atomic.Int32]
blocklist freelru.Cache[ipblock.IPBlock, struct{}]
approval freelru.Cache[uuid.UUID, *atomic.Int32]
usedNonce *expiremap.ExpireMap[uint32, struct{}]
stop chan struct{}
}

func NewInstanceState(pendingMaxMemUsage int64, blocklistMaxMemUsage int64, pendingTTL time.Duration, blocklistTTL time.Duration) (*InstanceState, int64, int64, error) {
stop := make(chan struct{})

pendingElems := pendingMaxMemUsage / BlocklistItemCost
pending, err := freelru.NewSharded[ipblock.IPBlock, *atomic.Int32](uint32(pendingElems), hashIPBlock) // #nosec G115 we trust config input
// initLRU creates and initializes an LRU cache with the given parameters
func initLRU[K comparable, V any](
elems uint32,
hashFunc func(K) uint32,
ttl time.Duration,
stop chan struct{},
purgeInterval time.Duration,
) (freelru.Cache[K, V], error) {
cache, err := freelru.NewSharded[K, V](elems, hashFunc)
if err != nil {
return nil, 0, 0, err
return nil, err
}
pending.SetLifetime(pendingTTL)
cache.SetLifetime(ttl)

go func() {
for {
select {
case <-stop:
return
case <-time.After(37 * time.Second):
pending.PurgeExpired()
case <-time.After(purgeInterval):
cache.PurgeExpired()
}
}
}()

blocklistElems := blocklistMaxMemUsage / BlocklistItemCost
blocklist, err := freelru.NewSharded[ipblock.IPBlock, struct{}](uint32(blocklistElems), hashIPBlock) // #nosec G115 we trust config input
if err != nil {
return nil, 0, 0, err
}
blocklist.SetLifetime(blocklistTTL)
return cache, nil
}

// initUsedNonce creates and initializes an ExpireMap for tracking used nonces
func initUsedNonce(stop chan struct{}, purgeInterval time.Duration) *expiremap.ExpireMap[uint32, struct{}] {
usedNonce := expiremap.NewExpireMap[uint32, struct{}](func(x uint32) uint32 {
return x
})
go func() {
for {
select {
case <-stop:
return
case <-time.After(61 * time.Second):
blocklist.PurgeExpired()
case <-time.After(purgeInterval):
usedNonce.PurgeExpired()
}
}
}()
return usedNonce
}

func NewInstanceState(pendingMaxMemUsage int64, blocklistMaxMemUsage int64, approvedMaxMemUsage int64, pendingTTL time.Duration, blocklistTTL time.Duration, approvalTTL time.Duration) (*InstanceState, int64, int64, int64, error) {
uuid.EnableRandPool()

stop := make(chan struct{})

pendingElems := uint32(pendingMaxMemUsage / PendingItemCost) // #nosec G115 we trust config input
pending, err := initLRU[ipblock.IPBlock, *atomic.Int32](
pendingElems,
hashIPBlock,
pendingTTL,
stop,
37*time.Second,
)
if err != nil {
return nil, 0, 0, 0, err
}

blocklistElems := uint32(blocklistMaxMemUsage / BlocklistItemCost) // #nosec G115 we trust config input
blocklist, err := initLRU[ipblock.IPBlock, struct{}](
blocklistElems,
hashIPBlock,
blocklistTTL,
stop,
61*time.Second,
)
if err != nil {
return nil, 0, 0, 0, err
}

approvalElems := uint32(approvedMaxMemUsage / ApprovalItemCost) // #nosec G115 we trust config input
approval, err := initLRU[uuid.UUID, *atomic.Int32](
approvalElems,
hashUUID,
approvalTTL,
stop,
43*time.Second,
)
if err != nil {
return nil, 0, 0, 0, err
}

usedNonce := initUsedNonce(stop, 41*time.Second)

pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
return nil, 0, 0, err
return nil, 0, 0, 0, err
}

fp := sha256.Sum256(priv.Seed())
Expand All @@ -90,8 +152,10 @@ func NewInstanceState(pendingMaxMemUsage int64, blocklistMaxMemUsage int64, pend
fp: hex.EncodeToString(fp[:]),
pending: pending,
blocklist: blocklist,
approval: approval,
usedNonce: usedNonce,
stop: stop,
}, pendingElems, blocklistElems, nil
}, int64(pendingElems), int64(blocklistElems), int64(approvalElems), nil
}

func (s *InstanceState) GetPublicKey() ed25519.PublicKey {
Expand Down Expand Up @@ -124,6 +188,7 @@ func (s *InstanceState) DecPending(ip ipblock.IPBlock) int32 {
count := counter.Add(-1)
if count <= 0 {
s.pending.Remove(ip)
return 0
}
return count
}
Expand All @@ -144,6 +209,37 @@ func (s *InstanceState) ContainsBlocklist(ip ipblock.IPBlock) bool {
return ok
}

// IssueApproval issues a new approval ID and returns it
func (s *InstanceState) IssueApproval(n int32) uuid.UUID {
id := uuid.New()

var counter atomic.Int32
counter.Store(n)

s.approval.Add(id, &counter)
return id
}

// DecApproval decrements the counter of the approval ID and returns whether the ID is still valid
func (s *InstanceState) DecApproval(id uuid.UUID) bool {
counter, ok := s.approval.Get(id)
if ok {
count := counter.Add(-1)
if count < 0 {
s.approval.Remove(id)
return false
}
return true
}
return false
}

// InsertUsedNonce inserts a nonce into the usedNonce map.
// Returns true if the nonce was inserted, false if it was already present.
func (s *InstanceState) InsertUsedNonce(nonce uint32) bool {
return s.usedNonce.SetIfAbsent(nonce, struct{}{}, NonceTTL)
}

func (s *InstanceState) Close() {
close(s.stop)
}
4 changes: 3 additions & 1 deletion core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
)

func newTestState(t *testing.T) *InstanceState {
state, _, _, err := NewInstanceState(
state, _, _, _, err := NewInstanceState(
1<<20, // 1MB for pending
1<<20, // 1MB for blocklist
1<<20, // 1MB for approved
time.Hour, // 1 hour TTL for pending
time.Hour, // 1 hour TTL for blocklist
time.Hour, // 1 hour TTL for approved
)
if err != nil {
t.Fatalf("failed to create instance state: %v", err)
Expand Down
Loading
Loading