Skip to content

Commit e408905

Browse files
cstocktonChris Stockton
and
Chris Stockton
authored
chore: move error codes to apierrors package (#1973)
This change will allow moving code out of the api into smaller packages without creating cyclic dependencies. --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent 6b842f6 commit e408905

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+472
-408
lines changed

internal/api/admin.go

+17-16
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/gofrs/uuid"
1111
"github.com/pkg/errors"
1212
"github.com/sethvargo/go-password/password"
13+
"github.com/supabase/auth/internal/api/apierrors"
1314
"github.com/supabase/auth/internal/api/provider"
1415
"github.com/supabase/auth/internal/models"
1516
"github.com/supabase/auth/internal/observability"
@@ -53,15 +54,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
5354

5455
userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
5556
if err != nil {
56-
return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID")
57+
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
5758
}
5859

5960
observability.LogEntrySetField(r, "user_id", userID)
6061

6162
u, err := models.FindUserByID(db, userID)
6263
if err != nil {
6364
if models.IsNotFoundError(err) {
64-
return nil, notFoundError(ErrorCodeUserNotFound, "User not found")
65+
return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
6566
}
6667
return nil, internalServerError("Database error loading user").WithInternalError(err)
6768
}
@@ -76,15 +77,15 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
7677
user := getUser(ctx)
7778
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
7879
if err != nil {
79-
return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID")
80+
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
8081
}
8182

8283
observability.LogEntrySetField(r, "factor_id", factorID)
8384

8485
factor, err := user.FindOwnedFactorByID(db, factorID)
8586
if err != nil {
8687
if models.IsNotFoundError(err) {
87-
return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")
88+
return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
8889
}
8990
return nil, internalServerError("Database error loading factor").WithInternalError(err)
9091
}
@@ -108,12 +109,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
108109

109110
pageParams, err := paginate(r)
110111
if err != nil {
111-
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
112+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
112113
}
113114

114115
sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
115116
if err != nil {
116-
return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
117+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
117118
}
118119

119120
filter := r.URL.Query().Get("filter")
@@ -169,7 +170,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
169170
if params.BanDuration != "none" {
170171
duration, err = time.ParseDuration(params.BanDuration)
171172
if err != nil {
172-
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
173+
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
173174
}
174175
}
175176
banDuration = &duration
@@ -338,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
338339
}
339340

340341
if params.Email == "" && params.Phone == "" {
341-
return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
342+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
342343
}
343344

344345
var providers []string
@@ -350,7 +351,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
350351
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
351352
return internalServerError("Database error checking email").WithInternalError(err)
352353
} else if user != nil {
353-
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
354+
return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
354355
}
355356
providers = append(providers, "email")
356357
}
@@ -363,13 +364,13 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
363364
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
364365
return internalServerError("Database error checking phone").WithInternalError(err)
365366
} else if exists {
366-
return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user")
367+
return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
367368
}
368369
providers = append(providers, "phone")
369370
}
370371

371372
if params.Password != nil && params.PasswordHash != "" {
372-
return badRequestError(ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
373+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
373374
}
374375

375376
if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" {
@@ -389,18 +390,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
389390

390391
if err != nil {
391392
if errors.Is(err, bcrypt.ErrPasswordTooLong) {
392-
return badRequestError(ErrorCodeValidationFailed, err.Error())
393+
return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
393394
}
394395
return internalServerError("Error creating user").WithInternalError(err)
395396
}
396397

397398
if params.Id != "" {
398399
customId, err := uuid.FromString(params.Id)
399400
if err != nil {
400-
return badRequestError(ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
401+
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
401402
}
402403
if customId == uuid.Nil {
403-
return badRequestError(ErrorCodeValidationFailed, "ID cannot be a nil uuid")
404+
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
404405
}
405406
user.ID = customId
406407
}
@@ -418,7 +419,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
418419
if params.BanDuration != "none" {
419420
duration, err = time.ParseDuration(params.BanDuration)
420421
if err != nil {
421-
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
422+
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
422423
}
423424
}
424425
banDuration = &duration
@@ -618,7 +619,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
618619
if params.Phone != "" && factor.IsPhoneFactor() {
619620
phone, err := validatePhone(params.Phone)
620621
if err != nil {
621-
return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
622+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
622623
}
623624
if terr := factor.UpdatePhone(tx, phone); terr != nil {
624625
return terr

internal/api/admin_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
1717
"github.com/stretchr/testify/suite"
18+
"github.com/supabase/auth/internal/api/apierrors"
1819
"github.com/supabase/auth/internal/conf"
1920
"github.com/supabase/auth/internal/models"
2021
)
@@ -908,7 +909,7 @@ func (ts *AdminTestSuite) TestAdminUserCreateValidationErrors() {
908909

909910
data := map[string]interface{}{}
910911
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
911-
require.Equal(ts.T(), data["error_code"], ErrorCodeValidationFailed)
912+
require.Equal(ts.T(), data["error_code"], apierrors.ErrorCodeValidationFailed)
912913
})
913914

914915
}

internal/api/anonymous.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package api
33
import (
44
"net/http"
55

6+
"github.com/supabase/auth/internal/api/apierrors"
67
"github.com/supabase/auth/internal/metering"
78
"github.com/supabase/auth/internal/models"
89
"github.com/supabase/auth/internal/storage"
@@ -15,7 +16,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
1516
aud := a.requestAud(ctx, r)
1617

1718
if config.DisableSignup {
18-
return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance")
19+
return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
1920
}
2021

2122
params := &SignupParams{}

internal/api/api.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/rs/cors"
99
"github.com/sebest/xff"
1010
"github.com/sirupsen/logrus"
11+
"github.com/supabase/auth/internal/api/apierrors"
1112
"github.com/supabase/auth/internal/conf"
1213
"github.com/supabase/auth/internal/mailer"
1314
"github.com/supabase/auth/internal/models"
@@ -155,7 +156,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
155156
}
156157
if params.Email == "" && params.Phone == "" {
157158
if !api.config.External.AnonymousUsers.Enabled {
158-
return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
159+
return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
159160
}
160161
if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil {
161162
return err

internal/api/apierrors/apierrors.go

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package apierrors
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
// OAuthError is the JSON handler for OAuth2 error responses
8+
type OAuthError struct {
9+
Err string `json:"error"`
10+
Description string `json:"error_description,omitempty"`
11+
InternalError error `json:"-"`
12+
InternalMessage string `json:"-"`
13+
}
14+
15+
func NewOAuthError(err string, description string) *OAuthError {
16+
return &OAuthError{Err: err, Description: description}
17+
}
18+
19+
func (e *OAuthError) Error() string {
20+
if e.InternalMessage != "" {
21+
return e.InternalMessage
22+
}
23+
return fmt.Sprintf("%s: %s", e.Err, e.Description)
24+
}
25+
26+
// WithInternalError adds internal error information to the error
27+
func (e *OAuthError) WithInternalError(err error) *OAuthError {
28+
e.InternalError = err
29+
return e
30+
}
31+
32+
// WithInternalMessage adds internal message information to the error
33+
func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError {
34+
e.InternalMessage = fmt.Sprintf(fmtString, args...)
35+
return e
36+
}
37+
38+
// Cause returns the root cause error
39+
func (e *OAuthError) Cause() error {
40+
if e.InternalError != nil {
41+
return e.InternalError
42+
}
43+
return e
44+
}
45+
46+
// HTTPError is an error with a message and an HTTP status code.
47+
type HTTPError struct {
48+
HTTPStatus int `json:"code"` // do not rename the JSON tags!
49+
ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags!
50+
Message string `json:"msg"` // do not rename the JSON tags!
51+
InternalError error `json:"-"`
52+
InternalMessage string `json:"-"`
53+
ErrorID string `json:"error_id,omitempty"`
54+
}
55+
56+
func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError {
57+
return &HTTPError{
58+
HTTPStatus: httpStatus,
59+
ErrorCode: errorCode,
60+
Message: fmt.Sprintf(fmtString, args...),
61+
}
62+
}
63+
64+
func (e *HTTPError) Error() string {
65+
if e.InternalMessage != "" {
66+
return e.InternalMessage
67+
}
68+
return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message)
69+
}
70+
71+
func (e *HTTPError) Is(target error) bool {
72+
return e.Error() == target.Error()
73+
}
74+
75+
// Cause returns the root cause error
76+
func (e *HTTPError) Cause() error {
77+
if e.InternalError != nil {
78+
return e.InternalError
79+
}
80+
return e
81+
}
82+
83+
// WithInternalError adds internal error information to the error
84+
func (e *HTTPError) WithInternalError(err error) *HTTPError {
85+
e.InternalError = err
86+
return e
87+
}
88+
89+
// WithInternalMessage adds internal message information to the error
90+
func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError {
91+
e.InternalMessage = fmt.Sprintf(fmtString, args...)
92+
return e
93+
}

internal/api/errorcodes.go renamed to internal/api/apierrors/errorcode.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package api
1+
package apierrors
22

33
type ErrorCode = string
44

internal/api/audit.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http"
55
"strings"
66

7+
"github.com/supabase/auth/internal/api/apierrors"
78
"github.com/supabase/auth/internal/models"
89
)
910

@@ -20,7 +21,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
2021
// aud := a.requestAud(ctx, r)
2122
pageParams, err := paginate(r)
2223
if err != nil {
23-
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
24+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
2425
}
2526

2627
var col []string
@@ -31,7 +32,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
3132
qparts := strings.SplitN(q, ":", 2)
3233
col, exists = filterColumnMap[qparts[0]]
3334
if !exists || len(qparts) < 2 {
34-
return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q)
35+
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q)
3536
}
3637
qval = qparts[1]
3738
}

0 commit comments

Comments
 (0)