diff --git a/runner/redirect.go b/runner/redirect.go new file mode 100644 index 00000000..a96a027a --- /dev/null +++ b/runner/redirect.go @@ -0,0 +1,127 @@ +// Copyright 2024 OWASP CRS Project +// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "fmt" + "net/url" + "strconv" + + "github.com/coreruleset/go-ftw/v2/ftwhttp" + "github.com/coreruleset/go-ftw/v2/test" + "github.com/rs/zerolog/log" +) + +// RedirectLocation represents a parsed redirect location +type RedirectLocation struct { + Protocol string + Host string + Port int + URI string +} + +// extractRedirectLocation parses the Location header from a redirect response +// and returns the parsed components (protocol, host, port, URI). +// It handles both absolute and relative URLs. +func extractRedirectLocation(response *ftwhttp.Response, baseInput *test.Input) (*RedirectLocation, error) { + if response == nil { + return nil, fmt.Errorf("no previous response available for redirect") + } + + // Check if status code is a redirect + statusCode := response.Parsed.StatusCode + switch statusCode { + case 300, 301, 302, 303, 307, 308: + // valid redirect status codes + default: + return nil, fmt.Errorf("previous response status code %d is not a redirect", statusCode) + } + + // Get Location header + location := response.Parsed.Header.Get("Location") + if location == "" { + return nil, fmt.Errorf("previous response is a redirect but has no Location header") + } + + log.Debug().Msgf("Following redirect to: %s", location) + + // Parse the location URL + locationURL, err := url.Parse(location) + if err != nil { + return nil, fmt.Errorf("failed to parse Location header '%s': %w", location, err) + } + + result := &RedirectLocation{} + + // Build base URL from the previous request for resolving relative redirects + baseURL := &url.URL{ + Scheme: baseInput.GetProtocol(), + Host: baseInput.GetDestAddr(), + Path: baseInput.GetURI(), + } + + // Add port to host if it's not a default port + port := baseInput.GetPort() + isDefaultPort := (baseURL.Scheme == "https" && port == 443) || + (baseURL.Scheme == "http" && port == 80) + if !isDefaultPort { + baseURL.Host = fmt.Sprintf("%s:%d", baseURL.Host, port) + } + + // Resolve the location URL against the base URL + resolvedURL := baseURL.ResolveReference(locationURL) + + // Extract components from resolved URL + result.Protocol = resolvedURL.Scheme + result.Host = resolvedURL.Hostname() + + // Extract port + portStr := resolvedURL.Port() + if portStr != "" { + parsedPort, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port in Location header: %s", portStr) + } + result.Port = parsedPort + } else { + // Use default port based on scheme + if result.Protocol == "https" { + result.Port = 443 + } else { + result.Port = 80 + } + } + + // Construct URI (path + query); fragments are not included in RequestURI + result.URI = resolvedURL.RequestURI() + + log.Debug().Msgf("Parsed redirect: protocol=%s, host=%s, port=%d, uri=%s", + result.Protocol, result.Host, result.Port, result.URI) + + return result, nil +} + +// applyRedirectToInput modifies the test input to follow a redirect +func applyRedirectToInput(input *test.Input, redirect *RedirectLocation) { + // Override destination with redirect location + input.Protocol = &redirect.Protocol + input.DestAddr = &redirect.Host + input.Port = &redirect.Port + input.URI = &redirect.URI + + // Update Host header to match the new destination, including non-default ports + headers := input.GetHeaders() + + hostHeader := redirect.Host + isDefaultPort := (redirect.Protocol == "https" && redirect.Port == 443) || + (redirect.Protocol == "http" && redirect.Port == 80) + if !isDefaultPort { + hostHeader = fmt.Sprintf("%s:%d", redirect.Host, redirect.Port) + } + + headers.Set("Host", hostHeader) + + log.Debug().Msgf("Applied redirect to input: %s://%s:%d%s", + redirect.Protocol, redirect.Host, redirect.Port, redirect.URI) +} diff --git a/runner/redirect_test.go b/runner/redirect_test.go new file mode 100644 index 00000000..52560b6b --- /dev/null +++ b/runner/redirect_test.go @@ -0,0 +1,354 @@ +// Copyright 2024 OWASP CRS Project +// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "net/http" + "testing" + + schema "github.com/coreruleset/ftw-tests-schema/v2/types" + "github.com/coreruleset/go-ftw/v2/ftwhttp" + "github.com/coreruleset/go-ftw/v2/test" + "github.com/rs/zerolog" + "github.com/stretchr/testify/suite" +) + +type redirectTestSuite struct { + suite.Suite +} + +func (s *redirectTestSuite) SetupSuite() { + zerolog.SetGlobalLevel(zerolog.Disabled) +} + +func TestRedirectTestSuite(t *testing.T) { + suite.Run(t, new(redirectTestSuite)) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_AbsoluteURL() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 302, + Header: http.Header{ + "Location": []string{"https://newdomain.com:8443/newpath?query=value"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err) + s.NotNil(result) + s.Equal("https", result.Protocol) + s.Equal("newdomain.com", result.Host) + s.Equal(8443, result.Port) + s.Equal("/newpath?query=value", result.URI) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_AbsoluteURLWithDefaultPort() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 301, + Header: http.Header{ + "Location": []string{"https://newdomain.com/newpath"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err) + s.NotNil(result) + s.Equal("https", result.Protocol) + s.Equal("newdomain.com", result.Host) + s.Equal(443, result.Port) + s.Equal("/newpath", result.URI) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_RelativeURLAbsolutePath() { + protocol := "http" + destAddr := "example.com" + port := 8080 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 302, + Header: http.Header{ + "Location": []string{"/newpath"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err) + s.NotNil(result) + s.Equal("http", result.Protocol) + s.Equal("example.com", result.Host) + s.Equal(8080, result.Port) + s.Equal("/newpath", result.URI) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_RelativeURLRelativePath() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/path/to/resource" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 302, + Header: http.Header{ + "Location": []string{"newresource"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err) + s.NotNil(result) + s.Equal("http", result.Protocol) + s.Equal("example.com", result.Host) + s.Equal(80, result.Port) + s.Equal("/path/to/newresource", result.URI) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_NoResponse() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + result, err := extractRedirectLocation(nil, baseInput) + s.Error(err) + s.Nil(result) + s.Contains(err.Error(), "no previous response available") +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_NotRedirectStatus() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 200, + Header: http.Header{ + "Location": []string{"/newpath"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.Error(err) + s.Nil(result) + s.Contains(err.Error(), "not a redirect") +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_NoLocationHeader() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 302, + Header: http.Header{}, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.Error(err) + s.Nil(result) + s.Contains(err.Error(), "no Location header") +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_HTTPToHTTPS() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: 301, + Header: http.Header{ + "Location": []string{"https://example.com/secure"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err) + s.NotNil(result) + s.Equal("https", result.Protocol) + s.Equal("example.com", result.Host) + s.Equal(443, result.Port) + s.Equal("/secure", result.URI) +} + +func (s *redirectTestSuite) TestApplyRedirectToInput() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + input := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + redirect := &RedirectLocation{ + Protocol: "https", + Host: "newdomain.com", + Port: 8443, + URI: "/newpath", + } + + applyRedirectToInput(input, redirect) + + s.Equal("https", input.GetProtocol()) + s.Equal("newdomain.com", input.GetDestAddr()) + s.Equal(8443, input.GetPort()) + s.Equal("/newpath", input.GetURI()) + + // Check Host header was updated (should include port for non-default ports) + headers := input.GetHeaders() + hostHeaders := headers.GetAll("Host") + s.Len(hostHeaders, 1) + s.Equal("newdomain.com:8443", hostHeaders[0].Value) +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_Various3xxCodes() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + redirectCodes := []int{300, 301, 302, 303, 307, 308} + + for _, code := range redirectCodes { + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: code, + Header: http.Header{ + "Location": []string{"/redirect"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.NoError(err, "Failed for status code %d", code) + s.NotNil(result, "Result is nil for status code %d", code) + s.Equal("/redirect", result.URI) + } +} + +func (s *redirectTestSuite) TestExtractRedirectLocation_NonRedirect3xxCodes() { + protocol := "http" + destAddr := "example.com" + port := 80 + uri := "/original" + + baseInput := test.NewInput(&schema.Input{ + Protocol: &protocol, + DestAddr: &destAddr, + Port: &port, + URI: &uri, + }) + + // Test non-redirect 3xx codes that should be rejected + nonRedirectCodes := []int{304, 305, 306} + + for _, code := range nonRedirectCodes { + response := &ftwhttp.Response{ + Parsed: http.Response{ + StatusCode: code, + Header: http.Header{ + "Location": []string{"/somewhere"}, + }, + }, + } + + result, err := extractRedirectLocation(response, baseInput) + s.Error(err, "Should reject status code %d", code) + s.Nil(result, "Result should be nil for status code %d", code) + s.Contains(err.Error(), "not a redirect", "Error message should indicate it's not a redirect for code %d", code) + } +} diff --git a/runner/run.go b/runner/run.go index c940ab4d..397aa7d7 100644 --- a/runner/run.go +++ b/runner/run.go @@ -76,6 +76,10 @@ func RunTest(runContext *TestRunContext, ftwTest *test.FTWTest) error { continue } runContext.StartTest() + // Clear previous response and input when starting a new test + // (follow_redirect should only work within the same test case) + runContext.LastStageResponse = nil + runContext.LastStageInput = nil test.ApplyPlatformOverrides(runContext.RunnerConfig, &testCase) // this is just for printing once the next test @@ -144,6 +148,15 @@ func RunStage(runContext *TestRunContext, ftwCheck *FTWCheck, testCase schema.Te return nil } + // Handle follow_redirect if enabled + if stage.Input.FollowRedirect != nil && *stage.Input.FollowRedirect { + redirectLocation, err := extractRedirectLocation(runContext.LastStageResponse, runContext.LastStageInput) + if err != nil { + return fmt.Errorf("follow_redirect enabled but failed to extract redirect location: %w", err) + } + applyRedirectToInput(testInput, redirectLocation) + } + // Destination is needed for a request dest := &ftwhttp.Destination{ DestAddr: testInput.GetDestAddr(), @@ -202,6 +215,10 @@ func RunStage(runContext *TestRunContext, ftwCheck *FTWCheck, testCase schema.Te runContext.EndStage(&testCase, testResult, ftwCheck.GetTriggeredRules()) + // Store the response and input for potential use by follow_redirect in next stage + runContext.LastStageResponse = response + runContext.LastStageInput = testInput + // show the result unless quiet was passed in the command line displayResult(&testCase, runContext, testResult, roundTripTime) diff --git a/runner/run_test.go b/runner/run_test.go index 9696ff03..5c5e6ee9 100644 --- a/runner/run_test.go +++ b/runner/run_test.go @@ -12,6 +12,7 @@ import ( "os" "regexp" "strconv" + "sync" "testing" "text/template" @@ -339,6 +340,73 @@ func (s *runTestSuite) TestOverrideRun() { s.LessOrEqual(0, res.Stats.TotalFailed(), "Oops, test run failed!") } +func (s *runTestSuite) TestFollowRedirect() { + // Track which URIs were requested to validate redirect behavior + var requestedURIs []string + var requestedHosts []string + var mu sync.Mutex + + // Custom handler that returns a redirect on first request + handler := func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestedURIs = append(requestedURIs, r.RequestURI) + requestedHosts = append(requestedHosts, r.Host) + mu.Unlock() + + // Don't track marker requests + if r.Header.Get(s.cfg.LogMarkerHeaderName) != "" { + s.writeMarkerOrMessageToTestServerLog(logText, r) + w.WriteHeader(http.StatusOK) + return + } + + if r.RequestURI == "/redirect-me" { + // Stage 1: Return relative redirect to /redirected + w.Header().Set("Location", "/redirected") + w.WriteHeader(http.StatusFound) + _, _ = w.Write([]byte("Redirecting...")) + } else if r.RequestURI == "/redirected" { + // Stage 2: Return success + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Success after redirect")) + } else { + // Unexpected URI + w.WriteHeader(http.StatusNotFound) + } + + s.writeMarkerOrMessageToTestServerLog(logText, r) + } + + s.ts.Config.Handler = http.HandlerFunc(handler) + + s.runnerConfig.Output = output.Quiet + res, err := Run(s.runnerConfig, s.ftwTests, s.out) + s.Require().NoError(err) + s.Equal(0, res.Stats.TotalFailed(), "Follow redirect test should pass") + + // Verify that both URIs were requested (excluding marker requests) + mu.Lock() + actualRequests := []string{} + actualHosts := []string{} + for i, uri := range requestedURIs { + if uri != "/status/200" { // Skip marker requests + actualRequests = append(actualRequests, uri) + actualHosts = append(actualHosts, requestedHosts[i]) + } + } + mu.Unlock() + + s.Require().Len(actualRequests, 2, "Should have made 2 non-marker requests") + s.Equal("/redirect-me", actualRequests[0], "First request should be to /redirect-me") + s.Equal("/redirected", actualRequests[1], "Second request should be to /redirected (redirect target)") + + // Verify Host headers + s.Require().Len(actualHosts, 2, "Should have captured 2 Host headers") + // First request uses Host from test yaml (just IP, no port) + // Second request (after redirect) should include port for non-default port + s.Contains(actualHosts[1], ":", "Host header should include port after redirect for non-default port") +} + func (s *runTestSuite) TestBrokenOverrideRun() { // the test should succeed, despite the unknown override property res, err := Run(s.runnerConfig, s.ftwTests, s.out) diff --git a/runner/testdata/TestFollowRedirect.yaml b/runner/testdata/TestFollowRedirect.yaml new file mode 100644 index 00000000..16e3601f --- /dev/null +++ b/runner/testdata/TestFollowRedirect.yaml @@ -0,0 +1,27 @@ +--- +meta: + author: "tester" + description: "Test follow_redirect functionality" +tests: + - test_id: 1 + description: "Test redirect following between stages" + stages: + # Stage 1: Initial request that returns a redirect + - input: + dest_addr: "{{ .TestAddr }}" + port: {{ .TestPort }} + uri: "/redirect-me" + headers: + User-Agent: "go-ftw test" + Host: "{{ .TestAddr }}" + output: + expect_error: False + status: 302 + # Stage 2: Follow the redirect from stage 1 + - input: + follow_redirect: true + headers: + User-Agent: "go-ftw test" + output: + expect_error: False + status: 200 diff --git a/runner/types.go b/runner/types.go index 4a65365c..43f72446 100644 --- a/runner/types.go +++ b/runner/types.go @@ -11,6 +11,7 @@ import ( "github.com/coreruleset/go-ftw/v2/config" "github.com/coreruleset/go-ftw/v2/ftwhttp" "github.com/coreruleset/go-ftw/v2/output" + "github.com/coreruleset/go-ftw/v2/test" "github.com/coreruleset/go-ftw/v2/waflog" ) @@ -32,6 +33,12 @@ type TestRunContext struct { LogLines *waflog.FTWLogLines CurrentStageDuration time.Duration currentStageStartTime time.Time + // LastStageResponse stores the response from the previous stage, + // used for follow_redirect functionality + LastStageResponse *ftwhttp.Response + // LastStageInput stores the input from the previous stage, + // used as base for resolving relative redirects + LastStageInput *test.Input } func (t *TestRunContext) StartTest() {