Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 128 additions & 25 deletions internal/remediation/root.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,140 @@
package remediation

// The order matters since we use slices.Max to get the max value
import (
"sync"
)

// Default weights for built-in remediations
// Allow=0, Unknown=1, then expand others to allow custom remediations to slot in between
const (
WeightAllow = 0
WeightUnknown = 1
WeightCaptcha = 10
WeightBan = 20
)

// Remediation represents a remediation type as a string
type Remediation string
Comment on lines +16 to +17
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description claims "Automatic deduplication: Shared *string pointers deduplicate map keys" and "Reduced allocations: Eliminates duplicate string headers in maps", but this implementation uses string values (not pointers) as the Remediation type. The claimed benefits about pointer-based deduplication are not realized by this design, making the documentation misleading.

Copilot uses AI. Check for mistakes.

// Built-in remediation constants
const (
Allow Remediation = iota // Allow remediation
Unknown // Unknown remediation (Unknown is used to have a value for remediation we don't support EG "MFA")
Captcha // Captcha remediation
Ban // Ban remediation
Allow Remediation = "allow"
Unknown Remediation = "unknown"
Captcha Remediation = "captcha"
Ban Remediation = "ban"
)

type Remediation uint8 // Remediation type is smallest uint to save space
// registry manages remediation weights
type registry struct {
mu sync.RWMutex
weights map[string]int // Maps remediation name to its weight
}

var globalRegistry = &registry{
Comment on lines +27 to +33
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type name "registry" is too generic and unexported. Consider using a more descriptive name like "weightRegistry" to clarify its purpose and avoid potential naming conflicts.

Suggested change
// registry manages remediation weights
type registry struct {
mu sync.RWMutex
weights map[string]int // Maps remediation name to its weight
}
var globalRegistry = &registry{
// weightRegistry manages remediation weights
type weightRegistry struct {
mu sync.RWMutex
weights map[string]int // Maps remediation name to its weight
}
var globalRegistry = &weightRegistry{

Copilot uses AI. Check for mistakes.
weights: make(map[string]int),
}

//nolint:gochecknoinits // init() is required to initialize default weights
func init() {
// Initialize built-in remediations with default weights
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

globalRegistry.weights["allow"] = WeightAllow
globalRegistry.weights["unknown"] = WeightUnknown
globalRegistry.weights["captcha"] = WeightCaptcha
globalRegistry.weights["ban"] = WeightBan
}

// SetWeight sets a custom weight for a remediation (for configuration)
func SetWeight(name string, weight int) {
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

globalRegistry.weights[name] = weight
}
Comment on lines +49 to +55
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SetWeight function does not normalize the remediation name to lowercase. This creates inconsistency with the built-in remediations (initialized as lowercase in init) and can cause weight lookups to fail if the case doesn't match. The name parameter should be normalized using strings.ToLower.

Copilot uses AI. Check for mistakes.

// GetWeight returns the weight for a remediation name
func GetWeight(name string) int {
globalRegistry.mu.RLock()
defer globalRegistry.mu.RUnlock()

if weight, exists := globalRegistry.weights[name]; exists {
return weight
}
// Default to Unknown weight for unknown remediations
return WeightUnknown
}
Comment on lines +57 to +67
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GetWeight function uses RLock/RUnlock on every weight lookup, which can be a significant performance overhead for hot path operations. The PR description claims "Performance: Pointer-based lookups faster than string comparisons", but this implementation actually introduces mutex overhead that wasn't present in the integer-based comparison system. Consider using sync.Map or caching weights to reduce lock contention.

Copilot uses AI. Check for mistakes.

// LoadWeights loads weights for multiple remediations at once (for startup initialization)
func LoadWeights(weights map[string]int) {
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

for name, weight := range weights {
globalRegistry.weights[name] = weight
}
}
Comment on lines +69 to +77
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LoadWeights function does not normalize remediation names to lowercase. This creates inconsistency with the built-in remediations (initialized as lowercase in init) and can cause weight lookups to fail if the case doesn't match. The name keys should be normalized using strings.ToLower.

Copilot uses AI. Check for mistakes.
Comment on lines +69 to +77
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LoadWeights function does not validate inputs. Missing validation includes: (1) No check for negative weights, which violates the documented requirement that weights must be >= 0; (2) No normalization of remediation names to lowercase, creating case-sensitivity issues; (3) No protection against extremely large weights that could cause integer overflow in Compare. These issues could lead to unexpected behavior or security problems from malicious configuration.

Copilot uses AI. Check for mistakes.

// New creates a new Remediation from a string.
func New(name string) Remediation {
return Remediation(name)
}
Comment on lines +79 to +82
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The New function does not normalize the input string (e.g., to lowercase). This creates a critical inconsistency where "Ban", "ban", and "BAN" are treated as different remediations with potentially different weights. Since weight lookups are case-sensitive, this could lead to incorrect priority comparisons and unexpected behavior. The function should normalize the input, similar to how scope is normalized with strings.ToLower in pkg/dataset/root.go line 65.

Copilot uses AI. Check for mistakes.

// String returns the remediation name
func (r Remediation) String() string {
switch r {
case Ban:
return "ban"
case Captcha:
return "captcha"
case Unknown:
return "unknown"
default:
return "allow"
if r == "" {
return "allow" // Default fallback
}
return string(r)
}

// Compare returns:
// - negative if a < b
// - zero if a == b
// - positive if a > b
Comment on lines +93 to +95
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Compare function documentation is misleading. It states "zero if a == b" which could be interpreted as comparing remediation equality, but the function actually compares weights. Two different remediations with the same weight will return zero. The documentation should clarify this, for example: "zero if a and b have the same weight".

Suggested change
// - negative if a < b
// - zero if a == b
// - positive if a > b
// - negative if a has a lower weight than b
// - zero if a and b have the same weight
// - positive if a has a higher weight than b

Copilot uses AI. Check for mistakes.
func Compare(a, b Remediation) int {
weightA := GetWeight(a.String())
weightB := GetWeight(b.String())
return weightA - weightB
}

// IsHigher returns true if a has a higher weight than b
func IsHigher(a, b Remediation) bool {
return Compare(a, b) > 0
}

// IsLower returns true if a has a lower weight than b
func IsLower(a, b Remediation) bool {
return Compare(a, b) < 0
}

// IsEqual returns true if a represents the same remediation as b.
// This compares the remediation names (strings).
func IsEqual(a, b Remediation) bool {
return a == b
}

// HasSameWeight returns true if a has the same weight as b.
// This is useful for checking if two different remediations have the same priority.
// Note: Two remediations with the same weight will be compared by name (alphabetical)
// as a tie-breaker when determining priority.
func HasSameWeight(a, b Remediation) bool {
return Compare(a, b) == 0
}

// IsWeighted returns true if r is not Allow (has weight > Allow)
// This is useful for checking if a remediation should be applied
func IsWeighted(r Remediation) bool {
return GetWeight(r.String()) > WeightAllow
}

// FromString creates a Remediation from a string (alias for New for backward compatibility)
func FromString(s string) Remediation {
switch s {
case "ban":
return Ban
case "captcha":
return Captcha
case "allow":
return Allow
default:
return Unknown
}
return New(s)
}

// IsZero returns true if the remediation is zero-valued
func (r Remediation) IsZero() bool {
return r == ""
}
32 changes: 32 additions & 0 deletions pkg/cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gopkg.in/yaml.v2"

"github.com/crowdsecurity/crowdsec-spoa/internal/geo"
"github.com/crowdsecurity/crowdsec-spoa/internal/remediation"
"github.com/crowdsecurity/crowdsec-spoa/pkg/host"
cslogging "github.com/crowdsecurity/crowdsec-spoa/pkg/logging"
"github.com/crowdsecurity/go-cs-lib/csyaml"
Expand Down Expand Up @@ -36,6 +37,32 @@ type BouncerConfig struct {
ListenUnix string `yaml:"listen_unix"`
PrometheusConfig PrometheusConfig `yaml:"prometheus"`
PprofConfig PprofConfig `yaml:"pprof"`
// RemediationWeights allows users to configure custom weights for remediations.
//
// Format:
// remediation_weights:
// <remediation_name>: <weight>
//
// Example:
// remediation_weights:
// mfa: 15 # slots between captcha (10) and ban (20)
//
// Valid weight range: integer values >= 0. Lower values are less severe; higher values are more severe.
// Recommended: Use values between 0 and 100.
//
// Built-in defaults:
// allow=0, unknown=1, captcha=10, ban=20
//
// Custom weights override or supplement built-in remediations. If a custom remediation is defined,
// its weight will be used for ordering and severity. Custom remediations can slot between built-in
// ones by choosing an appropriate weight value.
//
// Tie-breaking: If two remediations have the same weight, alphabetical order of the remediation
// name is used as a deterministic tie-breaker when determining priority.
//
// Note: Custom weights for built-in remediations (allow, unknown, captcha, ban) must be set
// before package initialization. After init(), package-level constants already have cached weights.
RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"`
}

// MergedConfig() returns the byte content of the patched configuration file (with .yaml.local).
Expand Down Expand Up @@ -67,6 +94,11 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) {
return nil, fmt.Errorf("failed to setup logging: %w", err)
}

// Load custom remediation weights if configured (loads all weights at once on startup)
if config.RemediationWeights != nil {
remediation.LoadWeights(config.RemediationWeights)
}
Comment on lines +97 to +100
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loading remediation weights during config initialization creates a race condition. If the application is processing requests in goroutines while LoadWeights is being called, concurrent reads via GetWeight (which uses RLock) and writes via LoadWeights (which uses Lock) to the global registry could cause data races or inconsistent state. The weights should be loaded before any request processing starts, or the code should document that configuration must be loaded before starting the server.

Copilot uses AI. Check for mistakes.
Comment on lines +97 to +100
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RemediationWeights configuration is loaded without validation. Malicious or misconfigured values (e.g., negative weights, extremely large weights that could cause integer overflow in comparisons) are not checked before being passed to remediation.LoadWeights. This could lead to unexpected behavior or security issues.

Copilot uses AI. Check for mistakes.

if err := config.Validate(); err != nil {
return nil, err
}
Expand Down
29 changes: 15 additions & 14 deletions pkg/dataset/bart_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ import (
type BartAddOp struct {
Prefix netip.Prefix
Origin string
R remediation.Remediation
R string // Remediation name as string
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The operation struct now stores remediation as a string instead of a Remediation type, then converts it back using FromString on every operation (lines 104, 149, 157, 213, 232). This introduces unnecessary string-to-Remediation conversions. Consider storing Remediation type directly in the operation struct to avoid repeated conversions.

Copilot uses AI. Check for mistakes.
IPType string
Scope string
}

// BartRemoveOp represents a single prefix removal operation for batch processing
type BartRemoveOp struct {
Prefix netip.Prefix
R remediation.Remediation
R string // Remediation name as string
Origin string
IPType string
Scope string
Expand Down Expand Up @@ -91,7 +91,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) {
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("initial load: collecting prefix operations")
}

Expand All @@ -101,7 +101,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) {
data = RemediationMap{}
}
// Add the remediation (this handles merging if prefix already seen)
data.Add(valueLog, op.R, op.Origin)
data.Add(valueLog, remediation.FromString(op.R), op.Origin)
prefixMap[prefix] = data
}

Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("adding to bart trie")
}

Expand All @@ -146,15 +146,15 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [
valueLog.Trace("exact prefix exists, merging remediations")
}
// bart already cloned via our Cloner interface, modify directly
existingData.Add(valueLog, op.R, op.Origin)
existingData.Add(valueLog, remediation.FromString(op.R), op.Origin)
return existingData, false // false = don't delete
}
if valueLog != nil {
valueLog.Trace("creating new entry")
}
// Create new data
newData := make(RemediationMap)
newData.Add(valueLog, op.R, op.Origin)
newData.Add(valueLog, remediation.FromString(op.R), op.Origin)
return newData, false // false = don't delete
})
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("removing from bart trie")
}

Expand All @@ -210,11 +210,12 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {

// Check if the remediation exists with the matching origin before removing
// This prevents removing decisions when the origin has been overwritten (e.g., by CAPI)
if !existingData.HasRemediationWithOrigin(op.R, op.Origin) {
if !existingData.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) {
// Origin doesn't match - this decision was likely overwritten by another origin
// Don't remove it, as it's not the decision we're trying to delete
if valueLog != nil {
storedOrigin, exists := existingData[op.R]
r := remediation.FromString(op.R)
storedOrigin, exists := existingData[r]
if exists {
valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin)
} else {
Expand All @@ -228,7 +229,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {
// bart already cloned via our Cloner interface, modify directly
// Remove returns an error if remediation doesn't exist (duplicate delete)
// We already checked origin above, so this should succeed
err := existingData.Remove(valueLog, op.R)
err := existingData.Remove(valueLog, remediation.FromString(op.R))
if errors.Is(err, ErrRemediationNotFound) {
// This shouldn't happen since we checked above, but handle it gracefully
if valueLog != nil {
Expand Down Expand Up @@ -288,11 +289,11 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string)
return remediation.Allow, ""
}

remediationResult, origin := data.GetRemediationAndOrigin()
r, origin := data.GetRemediationAndOrigin()
if valueLog != nil {
valueLog.Tracef("bart result: %s (data: %+v)", remediationResult.String(), data)
valueLog.Tracef("bart result: %s (data: %+v)", r.String(), data)
}
return remediationResult, origin
return r, origin
}

// HasRemediation checks if an exact prefix has a specific remediation with a specific origin.
Expand Down
Loading