Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ _book/
dist/
coverage.*
.bin
.claude/
10 changes: 9 additions & 1 deletion json.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type ErrorContainer struct {
Error *DefaultError `json:"error"`
}

func (e *ErrorContainer) ID() string {
return e.Error.ID()
}

type ErrorReporter interface {
ReportError(r *http.Request, code int, err error, args ...interface{})
}
Expand Down Expand Up @@ -159,13 +163,15 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)

// Enhancing must happen after logging or context will be lost.
var payload interface{} = err
if h.ErrorEnhancer != nil {
payload = h.ErrorEnhancer(r, err)
}
if id, ok := payload.(interface{ ID() string }); ok {
w.Header().Set("Ory-Error-Id", id.ID())
}
if de, ok := payload.(*DefaultError); ok && !h.EnableDebug {
de2 := *de
de2.DebugField = ""
Expand All @@ -179,6 +185,8 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
payload = ec2
}

w.WriteHeader(code)

if err := json.NewEncoder(w).Encode(payload); err != nil {
// There was an error, but there's actually not a lot we can do except log that this happened.
h.Reporter.ReportError(r, code, errors.WithStack(err), "Could not write ErrorContainer to response writer")
Expand Down
48 changes: 48 additions & 0 deletions json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,51 @@ func TestCanceledJSON(t *testing.T) {
assert.Contains(t, string(body), "some unrelated error")
assert.Equal(t, 499, resp.StatusCode)
}

func TestOryErrorIDHeader(t *testing.T) {
for k, tc := range []struct {
name string
err error
expectedHeader string
}{
{
name: "sets ID in header",
err: &ErrMisconfiguration,
expectedHeader: "invalid_configuration",
},
{
name: "sets empty header without ID",
err: &ErrNotFound,
expectedHeader: "",
},
{
name: "custom error with ID sets header",
err: &DefaultError{
IDField: "custom_error_id",
CodeField: http.StatusBadRequest,
StatusField: http.StatusText(http.StatusBadRequest),
ErrorField: "custom error",
},
expectedHeader: "custom_error_id",
},
{
name: "upstream error sets header",
err: &ErrUpstreamError,
expectedHeader: "upstream_error",
},
} {
t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) {
h := NewJSONWriter(nil)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, tc.err)
}))
t.Cleanup(ts.Close)

resp, err := http.Get(ts.URL + "/do")
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id"))
})
}
}
3 changes: 3 additions & 0 deletions plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ func (h *TextWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
// All errors land here, so it's a really good idea to do the logging here as well!
h.Reporter.ReportError(r, code, err, "An error occurred while handling a request")

if id, ok := err.(interface{ ID() string }); ok {
w.Header().Set("Ory-Error-Id", id.ID())
}
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(code)
fmt.Fprintf(w, "%s", err)
Expand Down
62 changes: 62 additions & 0 deletions plain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package herodot

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTextWriterOryErrorIDHeader(t *testing.T) {
for k, tc := range []struct {
name string
err error
expectedHeader string
}{
{
name: "error with ID sets header",
err: &ErrMisconfiguration,
expectedHeader: "invalid_configuration",
},
{
name: "error without ID does not set header",
err: &ErrNotFound,
expectedHeader: "",
},
{
name: "custom error with ID sets header",
err: &DefaultError{
IDField: "custom_text_error_id",
CodeField: http.StatusBadRequest,
StatusField: http.StatusText(http.StatusBadRequest),
ErrorField: "custom error",
},
expectedHeader: "custom_text_error_id",
},
{
name: "upstream error sets header",
err: &ErrUpstreamError,
expectedHeader: "upstream_error",
},
} {
t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) {
h := NewTextWriter(&stdReporter{}, "plain")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, tc.err)
}))
t.Cleanup(ts.Close)

resp, err := http.Get(ts.URL + "/do")
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id"))
})
}
}
Loading