Skip to content

Commit 1c32e92

Browse files
committed
refactor: enhance IP blacklist handling and add panic recovery in middleware
1 parent f45e833 commit 1c32e92

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

blacklist.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,25 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
6565
}
6666

6767
func (m *Middleware) isIPBlacklisted(ip string) bool {
68-
if m.ipBlacklist.Contains(netip.MustParseAddr(ip)) {
69-
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
70-
m.IPBlacklistBlockCount++ // Increment the counter
71-
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter
72-
m.logger.Debug("IP blacklist hit", zap.String("ip", ip)) // Keep existing debug log
73-
return true // Indicate that the IP is blacklisted
68+
// Extract IP address without port
69+
cleanIP := extractIP(ip, m.logger)
70+
71+
addr, err := netip.ParseAddr(cleanIP)
72+
if err != nil {
73+
m.logger.Warn("Failed to parse IP address for blacklist check",
74+
zap.String("ip", ip),
75+
zap.String("clean_ip", cleanIP),
76+
zap.Error(err),
77+
)
78+
return false
79+
}
80+
81+
if m.ipBlacklist.Contains(addr) {
82+
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
83+
m.IPBlacklistBlockCount++ // Increment the counter
84+
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter
85+
m.logger.Debug("IP blacklist hit", zap.String("ip", cleanIP)) // Keep existing debug log
86+
return true // Indicate that the IP is blacklisted
7487
}
7588
return false // Indicate that the IP is NOT blacklisted
7689
}

caddywaf.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
8282

8383
func (m *Middleware) Provision(ctx caddy.Context) error {
8484
m.logger = ctx.Logger(m)
85-
m.ruleCache = NewRuleCache() // Initialize RuleCache
85+
m.ruleCache = NewRuleCache() // Initialize RuleCache
86+
m.Rules = make(map[int][]Rule) // Initialize Rules map to prevent nil pointer panic
8687

8788
// Set default log severity if not provided
8889
if m.LogSeverity == "" {
@@ -466,7 +467,16 @@ func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) erro
466467

467468
// Convert the map to CIDRTrie
468469
for ip := range blacklist {
469-
blacklistMap.Insert(netip.MustParsePrefix(ip), nil)
470+
// Add /32 suffix if the IP doesn't have CIDR notation
471+
if !strings.Contains(ip, "/") {
472+
ip = ip + "/32"
473+
}
474+
prefix, err := netip.ParsePrefix(ip)
475+
if err != nil {
476+
m.logger.Warn("Skipping invalid IP in blacklist", zap.String("ip", ip), zap.Error(err))
477+
continue
478+
}
479+
blacklistMap.Insert(prefix, nil)
470480
}
471481
return nil
472482
}

handler.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ type (
2222
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
2323
logID := uuid.New().String()
2424

25+
// Add panic recovery to catch and log panics
26+
defer func() {
27+
if rec := recover(); rec != nil {
28+
m.logger.Error("PANIC in ServeHTTP",
29+
zap.String("log_id", logID),
30+
zap.Any("panic", rec),
31+
zap.Stack("stack"),
32+
)
33+
// Return 500 error to client
34+
w.WriteHeader(http.StatusInternalServerError)
35+
w.Write([]byte("Internal Server Error"))
36+
}
37+
}()
38+
2539
m.logRequestStart(r, logID)
2640

2741
// Propagate log ID within the request context for logging
@@ -152,7 +166,14 @@ func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http
152166
}
153167
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID))
154168

155-
for _, rule := range m.Rules[4] {
169+
// Check if rules exist for Phase 4 before iterating
170+
rules, ok := m.Rules[4]
171+
if !ok || len(rules) == 0 {
172+
m.logger.Debug("No rules found for Phase 4")
173+
return
174+
}
175+
176+
for _, rule := range rules {
156177
if rule.regex.MatchString(body) {
157178
if m.processRuleMatch(recorder, r, &rule, body, state) {
158179
return

0 commit comments

Comments
 (0)