diff --git a/.gitignore b/.gitignore index 1a56e84..630d1c9 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ _book/ dist/ coverage.* .bin +.claude/ \ No newline at end of file diff --git a/json.go b/json.go index 87f060e..7fea051 100644 --- a/json.go +++ b/json.go @@ -17,6 +17,10 @@ type ErrorContainer struct { Error *DefaultError `json:"error"` } +func (e *ErrorContainer) ID() string { + return e.Error.ID() +} + type ErrorReporter interface { ReportError(r *http.Request, code int, err error, args ...interface{}) } @@ -159,13 +163,15 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) // Enhancing must happen after logging or context will be lost. var payload interface{} = err if h.ErrorEnhancer != nil { payload = h.ErrorEnhancer(r, err) } + if id, ok := payload.(interface{ ID() string }); ok { + w.Header().Set("Ory-Error-Id", id.ID()) + } if de, ok := payload.(*DefaultError); ok && !h.EnableDebug { de2 := *de de2.DebugField = "" @@ -179,6 +185,8 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code payload = ec2 } + w.WriteHeader(code) + if err := json.NewEncoder(w).Encode(payload); err != nil { // There was an error, but there's actually not a lot we can do except log that this happened. h.Reporter.ReportError(r, code, errors.WithStack(err), "Could not write ErrorContainer to response writer") diff --git a/json_test.go b/json_test.go index 579405d..6e6ce96 100644 --- a/json_test.go +++ b/json_test.go @@ -422,3 +422,51 @@ func TestCanceledJSON(t *testing.T) { assert.Contains(t, string(body), "some unrelated error") assert.Equal(t, 499, resp.StatusCode) } + +func TestOryErrorIDHeader(t *testing.T) { + for k, tc := range []struct { + name string + err error + expectedHeader string + }{ + { + name: "sets ID in header", + err: &ErrMisconfiguration, + expectedHeader: "invalid_configuration", + }, + { + name: "sets empty header without ID", + err: &ErrNotFound, + expectedHeader: "", + }, + { + name: "custom error with ID sets header", + err: &DefaultError{ + IDField: "custom_error_id", + CodeField: http.StatusBadRequest, + StatusField: http.StatusText(http.StatusBadRequest), + ErrorField: "custom error", + }, + expectedHeader: "custom_error_id", + }, + { + name: "upstream error sets header", + err: &ErrUpstreamError, + expectedHeader: "upstream_error", + }, + } { + t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) { + h := NewJSONWriter(nil) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.WriteError(w, r, tc.err) + })) + t.Cleanup(ts.Close) + + resp, err := http.Get(ts.URL + "/do") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id")) + }) + } +} diff --git a/plain.go b/plain.go index 14c8e7f..941eeaa 100644 --- a/plain.go +++ b/plain.go @@ -87,6 +87,9 @@ func (h *TextWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code // All errors land here, so it's a really good idea to do the logging here as well! h.Reporter.ReportError(r, code, err, "An error occurred while handling a request") + if id, ok := err.(interface{ ID() string }); ok { + w.Header().Set("Ory-Error-Id", id.ID()) + } w.Header().Set("Content-Type", h.contentType) w.WriteHeader(code) fmt.Fprintf(w, "%s", err) diff --git a/plain_test.go b/plain_test.go new file mode 100644 index 0000000..fea2846 --- /dev/null +++ b/plain_test.go @@ -0,0 +1,62 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package herodot + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTextWriterOryErrorIDHeader(t *testing.T) { + for k, tc := range []struct { + name string + err error + expectedHeader string + }{ + { + name: "error with ID sets header", + err: &ErrMisconfiguration, + expectedHeader: "invalid_configuration", + }, + { + name: "error without ID does not set header", + err: &ErrNotFound, + expectedHeader: "", + }, + { + name: "custom error with ID sets header", + err: &DefaultError{ + IDField: "custom_text_error_id", + CodeField: http.StatusBadRequest, + StatusField: http.StatusText(http.StatusBadRequest), + ErrorField: "custom error", + }, + expectedHeader: "custom_text_error_id", + }, + { + name: "upstream error sets header", + err: &ErrUpstreamError, + expectedHeader: "upstream_error", + }, + } { + t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) { + h := NewTextWriter(&stdReporter{}, "plain") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.WriteError(w, r, tc.err) + })) + t.Cleanup(ts.Close) + + resp, err := http.Get(ts.URL + "/do") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id")) + }) + } +}