Skip to content

Commit b3d3d56

Browse files
committed
Fix: Address security alerts and bump version to v0.0.9
1 parent a179255 commit b3d3d56

File tree

7 files changed

+52
-37
lines changed

7 files changed

+52
-37
lines changed

caddywaf.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ var (
4848
)
4949

5050
// Add or update the version constant as needed
51-
const wafVersion = "v0.0.8" // update this value to the new release version when tagging
51+
const wafVersion = "v0.0.9" // update this value to the new release version when tagging
5252

5353
// ==================== Initialization and Setup ====================
5454

handler.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http
182182

183183
for _, rule := range rules {
184184
if rule.regex.MatchString(body) {
185-
if m.processRuleMatch(recorder, r, &rule, body, state) {
185+
if m.processRuleMatch(recorder, r, &rule, "RESPONSE_BODY", body, state) { // Pass RESPONSE_BODY as target
186186
return
187187
}
188188
}
@@ -453,29 +453,31 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
453453
continue
454454
}
455455

456+
redactedValue := m.requestValueExtractor.RedactValueIfSensitive(target, value)
457+
456458
m.logger.Debug("Extracted value",
457459
zap.String("rule_id", rule.ID),
458460
zap.String("target", target),
459-
zap.String("value", value),
461+
zap.String("value", redactedValue),
460462
)
461463

462464
if rule.regex.MatchString(value) {
463465
m.logger.Debug("Rule matched",
464466
zap.String("rule_id", rule.ID),
465467
zap.String("target", target),
466-
zap.String("value", value),
468+
zap.String("value", redactedValue),
467469
)
468470

469471
// FIXED: Correctly interpret processRuleMatch return value
470472
var shouldContinue bool
471473
if phase == 3 || phase == 4 {
472474
if recorder, ok := w.(*responseRecorder); ok {
473-
shouldContinue = m.processRuleMatch(recorder, r, &rule, value, state)
475+
shouldContinue = m.processRuleMatch(recorder, r, &rule, target, value, state)
474476
} else {
475-
shouldContinue = m.processRuleMatch(w, r, &rule, value, state)
477+
shouldContinue = m.processRuleMatch(w, r, &rule, target, value, state)
476478
}
477479
} else {
478-
shouldContinue = m.processRuleMatch(w, r, &rule, value, state)
480+
shouldContinue = m.processRuleMatch(w, r, &rule, target, value, state)
479481
}
480482

481483
// If processRuleMatch returned false or state is now blocked, stop processing
@@ -496,7 +498,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
496498
m.logger.Debug("Rule did not match",
497499
zap.String("rule_id", rule.ID),
498500
zap.String("target", target),
499-
zap.String("value", value),
501+
zap.String("value", redactedValue),
500502
)
501503
}
502504
}

handler_test.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"mime/multipart"
7+
"os"
78
"net/http"
89
"net/http/httptest"
910
"net/netip"
@@ -26,8 +27,9 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
2627
dnsBlacklist: map[string]struct{}{
2728
"malicious.domain": {},
2829
},
29-
ipBlacklist: iptrie.NewTrie(),
30-
CustomResponses: customResponse,
30+
ipBlacklist: iptrie.NewTrie(),
31+
CustomResponses: customResponse,
32+
requestValueExtractor: NewRequestValueExtractor(logger, false),
3133
}
3234

3335
w := httptest.NewRecorder()
@@ -60,6 +62,9 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
6062
}
6163

6264
func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
65+
if _, err := os.Stat(geoIPdata); os.IsNotExist(err) {
66+
t.Skip("GeoIP database not found, skipping test")
67+
}
6368
logger, err := zap.NewDevelopment()
6469
assert.NoError(t, err)
6570

@@ -77,7 +82,8 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
7782
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
7883
geoIP: geoIPBlock,
7984
},
80-
CustomResponses: customResponse,
85+
CustomResponses: customResponse,
86+
requestValueExtractor: NewRequestValueExtractor(logger, false),
8187
}
8288

8389
wlMiddleware := &Middleware{
@@ -90,7 +96,8 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
9096
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
9197
geoIP: geoIPBlock,
9298
},
93-
CustomResponses: customResponse,
99+
CustomResponses: customResponse,
100+
requestValueExtractor: NewRequestValueExtractor(logger, false),
94101
}
95102

96103
blackWhiteMw := &Middleware{
@@ -109,7 +116,8 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
109116
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
110117
geoIP: geoIPBlock,
111118
},
112-
CustomResponses: customResponse,
119+
CustomResponses: customResponse,
120+
requestValueExtractor: NewRequestValueExtractor(logger, false),
113121
}
114122

115123
req := httptest.NewRequest("GET", testURL, nil)
@@ -205,9 +213,10 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
205213

206214
t.Run("Allow unblocked CIDR", func(t *testing.T) {
207215
middleware := &Middleware{
208-
logger: logger,
209-
ipBlacklist: blackList,
210-
CustomResponses: customResponse,
216+
logger: logger,
217+
ipBlacklist: blackList,
218+
CustomResponses: customResponse,
219+
requestValueExtractor: NewRequestValueExtractor(logger, false),
211220
}
212221

213222
req := httptest.NewRequest("GET", testURL, nil)
@@ -222,9 +231,10 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
222231

223232
t.Run("Blocks blacklisted CIDR", func(t *testing.T) {
224233
middleware := &Middleware{
225-
logger: logger,
226-
ipBlacklist: blackList,
227-
CustomResponses: customResponse,
234+
logger: logger,
235+
ipBlacklist: blackList,
236+
CustomResponses: customResponse,
237+
requestValueExtractor: NewRequestValueExtractor(logger, false),
228238
}
229239

230240
req := httptest.NewRequest("GET", testURL, nil)

request.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
166166
}
167167

168168
// Redact sensitive fields before returning the value (as before)
169-
value := rve.redactValueIfSensitive(target, unredactedValue)
169+
value := rve.RedactValueIfSensitive(target, unredactedValue)
170170

171171
// Log the extracted value (redacted if necessary)
172172
rve.logger.Debug("Extracted value",
@@ -351,7 +351,7 @@ func (rve *RequestValueExtractor) extractValueForJSONPath(r *http.Request, jsonP
351351
}
352352

353353
// Helper function to redact value if target is sensitive
354-
func (rve *RequestValueExtractor) redactValueIfSensitive(target string, value string) string {
354+
func (rve *RequestValueExtractor) RedactValueIfSensitive(target string, value string) string {
355355
if rve.redactSensitiveData {
356356
for _, sensitive := range sensitiveTargets {
357357
if strings.Contains(strings.ToLower(target), sensitive) {

request_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func TestRedactValueIfSensitive(t *testing.T) {
145145
for _, tt := range tests {
146146
t.Run(tt.name, func(t *testing.T) {
147147
rve := NewRequestValueExtractor(logger, tt.redactSensitive)
148-
result := rve.redactValueIfSensitive(tt.target, tt.value)
148+
result := rve.RedactValueIfSensitive(tt.target, tt.value)
149149

150150
if tt.expectedRedacted && result != "REDACTED" {
151151
t.Errorf("Expected REDACTED but got %q", result)
@@ -364,10 +364,11 @@ func newMockLogger() *MockLogger {
364364
func TestProcessRuleMatch_HighScore(t *testing.T) {
365365
logger := newMockLogger()
366366
middleware := &Middleware{
367-
logger: logger.Logger,
368-
AnomalyThreshold: 100, // High threshold
369-
ruleHits: sync.Map{},
370-
muMetrics: sync.RWMutex{},
367+
logger: logger.Logger,
368+
AnomalyThreshold: 100, // High threshold
369+
ruleHits: sync.Map{},
370+
muMetrics: sync.RWMutex{},
371+
requestValueExtractor: NewRequestValueExtractor(logger.Logger, false), // Initialize
371372
}
372373

373374
rule := &Rule{
@@ -394,7 +395,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
394395
w := httptest.NewRecorder()
395396

396397
// Test blocking rule with high score
397-
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
398+
shouldContinue := middleware.processRuleMatch(w, req, rule, "header", "value", state)
398399
assert.False(t, shouldContinue)
399400
assert.Equal(t, http.StatusForbidden, w.Code)
400401
assert.True(t, state.Blocked)

rules.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
// rules.go
22
package caddywaf
33

4+
45
import (
56
"encoding/json"
67
"fmt"
78
"net/http"
89
"os"
910
"regexp"
1011
"sort"
11-
"strings"
12-
1312
"go.uber.org/zap"
1413
"go.uber.org/zap/zapcore"
1514
)
1615

17-
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, value string, state *WAFState) bool {
16+
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
1817
logID := r.Context().Value(ContextKeyLogId("logID")).(string)
1918

19+
redactedValue := m.requestValueExtractor.RedactValueIfSensitive(target, value)
20+
2021
m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message
2122
zap.String("rule_id", rule.ID),
22-
zap.String("target", strings.Join(rule.Targets, ",")),
23-
zap.String("value", value),
23+
zap.String("target", target), // Log the specific target that matched
24+
zap.String("value", redactedValue),
2425
zap.String("description", rule.Description),
2526
zap.Int("score", rule.Score),
2627
zap.Int("anomaly_threshold_config", m.AnomalyThreshold), // ADDED: Log configured anomaly threshold

rules_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,11 @@ func TestProcessRuleMatch(t *testing.T) {
146146
for _, tt := range tests {
147147
t.Run(tt.name, func(t *testing.T) {
148148
m := &Middleware{
149-
logger: logger,
150-
AnomalyThreshold: tt.anomalyThreshold,
151-
ruleHits: sync.Map{},
152-
muMetrics: sync.RWMutex{},
149+
logger: logger,
150+
AnomalyThreshold: tt.anomalyThreshold,
151+
ruleHits: sync.Map{},
152+
muMetrics: sync.RWMutex{},
153+
requestValueExtractor: NewRequestValueExtractor(logger, false),
153154
}
154155

155156
w := httptest.NewRecorder()
@@ -162,7 +163,7 @@ func TestProcessRuleMatch(t *testing.T) {
162163
ResponseWritten: tt.responseWritten,
163164
}
164165

165-
result := m.processRuleMatch(w, r, &tt.rule, "test-value", state)
166+
result := m.processRuleMatch(w, r, &tt.rule, "ARGS", "test-value", state)
166167
if result == tt.wantBlock {
167168
t.Errorf("processRuleMatch() returned %v, want %v", result, !tt.wantBlock)
168169
}

0 commit comments

Comments
 (0)