Skip to content

Commit dbcdbb9

Browse files
author
Chris Stockton
committed
feat: refactor api error calls to use the apierrors package
This change will help produce consistent errors as we move code out of the api package into smaller sub packages.
1 parent 99a1eb0 commit dbcdbb9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+483
-518
lines changed

internal/api/admin.go

+32-32
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
5454

5555
userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
5656
if err != nil {
57-
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
57+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
5858
}
5959

6060
observability.LogEntrySetField(r, "user_id", userID)
6161

6262
u, err := models.FindUserByID(db, userID)
6363
if err != nil {
6464
if models.IsNotFoundError(err) {
65-
return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
65+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
6666
}
67-
return nil, internalServerError("Database error loading user").WithInternalError(err)
67+
return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err)
6868
}
6969

7070
return withUser(ctx, u), nil
@@ -77,17 +77,17 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
7777
user := getUser(ctx)
7878
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
7979
if err != nil {
80-
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
80+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
8181
}
8282

8383
observability.LogEntrySetField(r, "factor_id", factorID)
8484

8585
factor, err := user.FindOwnedFactorByID(db, factorID)
8686
if err != nil {
8787
if models.IsNotFoundError(err) {
88-
return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
88+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
8989
}
90-
return nil, internalServerError("Database error loading factor").WithInternalError(err)
90+
return nil, apierrors.NewInternalServerError("Database error loading factor").WithInternalError(err)
9191
}
9292
return withFactor(ctx, factor), nil
9393
}
@@ -109,19 +109,19 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
109109

110110
pageParams, err := paginate(r)
111111
if err != nil {
112-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
112+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
113113
}
114114

115115
sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
116116
if err != nil {
117-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
117+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
118118
}
119119

120120
filter := r.URL.Query().Get("filter")
121121

122122
users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter)
123123
if err != nil {
124-
return internalServerError("Database error finding users").WithInternalError(err)
124+
return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err)
125125
}
126126
addPaginationHeaders(w, r, pageParams)
127127

@@ -170,7 +170,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
170170
if params.BanDuration != "none" {
171171
duration, err = time.ParseDuration(params.BanDuration)
172172
if err != nil {
173-
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
173+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
174174
}
175175
}
176176
banDuration = &duration
@@ -315,7 +315,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
315315
})
316316

317317
if err != nil {
318-
return internalServerError("Error updating user").WithInternalError(err)
318+
return apierrors.NewInternalServerError("Error updating user").WithInternalError(err)
319319
}
320320

321321
return sendJSON(w, http.StatusOK, user)
@@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
339339
}
340340

341341
if params.Email == "" && params.Phone == "" {
342-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
342+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
343343
}
344344

345345
var providers []string
@@ -349,9 +349,9 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
349349
return err
350350
}
351351
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
352-
return internalServerError("Database error checking email").WithInternalError(err)
352+
return apierrors.NewInternalServerError("Database error checking email").WithInternalError(err)
353353
} else if user != nil {
354-
return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
354+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
355355
}
356356
providers = append(providers, "email")
357357
}
@@ -362,21 +362,21 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
362362
return err
363363
}
364364
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
365-
return internalServerError("Database error checking phone").WithInternalError(err)
365+
return apierrors.NewInternalServerError("Database error checking phone").WithInternalError(err)
366366
} else if exists {
367-
return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
367+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
368368
}
369369
providers = append(providers, "phone")
370370
}
371371

372372
if params.Password != nil && params.PasswordHash != "" {
373-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
373+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
374374
}
375375

376376
if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" {
377377
password, err := password.Generate(64, 10, 0, false, true)
378378
if err != nil {
379-
return internalServerError("Error generating password").WithInternalError(err)
379+
return apierrors.NewInternalServerError("Error generating password").WithInternalError(err)
380380
}
381381
params.Password = &password
382382
}
@@ -390,18 +390,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
390390

391391
if err != nil {
392392
if errors.Is(err, bcrypt.ErrPasswordTooLong) {
393-
return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
393+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
394394
}
395-
return internalServerError("Error creating user").WithInternalError(err)
395+
return apierrors.NewInternalServerError("Error creating user").WithInternalError(err)
396396
}
397397

398398
if params.Id != "" {
399399
customId, err := uuid.FromString(params.Id)
400400
if err != nil {
401-
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
401+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
402402
}
403403
if customId == uuid.Nil {
404-
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
404+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
405405
}
406406
user.ID = customId
407407
}
@@ -419,7 +419,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
419419
if params.BanDuration != "none" {
420420
duration, err = time.ParseDuration(params.BanDuration)
421421
if err != nil {
422-
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
422+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
423423
}
424424
}
425425
banDuration = &duration
@@ -501,7 +501,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
501501
})
502502

503503
if err != nil {
504-
return internalServerError("Database error creating new user").WithInternalError(err)
504+
return apierrors.NewInternalServerError("Database error creating new user").WithInternalError(err)
505505
}
506506

507507
return sendJSON(w, http.StatusOK, user)
@@ -529,7 +529,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
529529
"user_email": user.Email,
530530
"user_phone": user.Phone,
531531
}); terr != nil {
532-
return internalServerError("Error recording audit log entry").WithInternalError(terr)
532+
return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr)
533533
}
534534

535535
if params.ShouldSoftDelete {
@@ -538,24 +538,24 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
538538
return nil
539539
}
540540
if terr := user.SoftDeleteUser(tx); terr != nil {
541-
return internalServerError("Error soft deleting user").WithInternalError(terr)
541+
return apierrors.NewInternalServerError("Error soft deleting user").WithInternalError(terr)
542542
}
543543

544544
if terr := user.SoftDeleteUserIdentities(tx); terr != nil {
545-
return internalServerError("Error soft deleting user identities").WithInternalError(terr)
545+
return apierrors.NewInternalServerError("Error soft deleting user identities").WithInternalError(terr)
546546
}
547547

548548
// hard delete all associated factors
549549
if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil {
550-
return internalServerError("Error deleting user's factors").WithInternalError(terr)
550+
return apierrors.NewInternalServerError("Error deleting user's factors").WithInternalError(terr)
551551
}
552552
// hard delete all associated sessions
553553
if terr := models.Logout(tx, user.ID); terr != nil {
554-
return internalServerError("Error deleting user's sessions").WithInternalError(terr)
554+
return apierrors.NewInternalServerError("Error deleting user's sessions").WithInternalError(terr)
555555
}
556556
} else {
557557
if terr := tx.Destroy(user); terr != nil {
558-
return internalServerError("Database error deleting user").WithInternalError(terr)
558+
return apierrors.NewInternalServerError("Database error deleting user").WithInternalError(terr)
559559
}
560560
}
561561

@@ -581,7 +581,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
581581
return terr
582582
}
583583
if terr := tx.Destroy(factor); terr != nil {
584-
return internalServerError("Database error deleting factor").WithInternalError(terr)
584+
return apierrors.NewInternalServerError("Database error deleting factor").WithInternalError(terr)
585585
}
586586
return nil
587587
})
@@ -619,7 +619,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
619619
if params.Phone != "" && factor.IsPhoneFactor() {
620620
phone, err := validatePhone(params.Phone)
621621
if err != nil {
622-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
622+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
623623
}
624624
if terr := factor.UpdatePhone(tx, phone); terr != nil {
625625
return terr

internal/api/anonymous.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
1616
aud := a.requestAud(ctx, r)
1717

1818
if config.DisableSignup {
19-
return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
19+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
2020
}
2121

2222
params := &SignupParams{}
@@ -48,7 +48,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
4848
return nil
4949
})
5050
if err != nil {
51-
return internalServerError("Database error creating anonymous user").WithInternalError(err)
51+
return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err)
5252
}
5353

5454
metering.RecordLogin("anonymous", newUser.ID)

internal/api/api.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
157157
}
158158
if params.Email == "" && params.Phone == "" {
159159
if !api.config.External.AnonymousUsers.Enabled {
160-
return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
160+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
161161
}
162162
if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil {
163163
return err

internal/api/audit.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
2121
// aud := a.requestAud(ctx, r)
2222
pageParams, err := paginate(r)
2323
if err != nil {
24-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
24+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
2525
}
2626

2727
var col []string
@@ -32,14 +32,14 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
3232
qparts := strings.SplitN(q, ":", 2)
3333
col, exists = filterColumnMap[qparts[0]]
3434
if !exists || len(qparts) < 2 {
35-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q)
35+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q)
3636
}
3737
qval = qparts[1]
3838
}
3939

4040
logs, err := models.FindAuditLogEntries(db, col, qval, pageParams)
4141
if err != nil {
42-
return internalServerError("Error searching for audit logs").WithInternalError(err)
42+
return apierrors.NewInternalServerError("Error searching for audit logs").WithInternalError(err)
4343
}
4444

4545
addPaginationHeaders(w, r, pageParams)

internal/api/auth.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte
3737
ctx := r.Context()
3838
claims := getClaims(ctx)
3939
if claims.IsAnonymous {
40-
return nil, forbiddenError(apierrors.ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions")
40+
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions")
4141
}
4242
return ctx, nil
4343
}
@@ -46,7 +46,7 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) {
4646
// Find the administrative user
4747
claims := getClaims(ctx)
4848
if claims == nil {
49-
return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "Invalid token")
49+
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "Invalid token")
5050
}
5151

5252
adminRoles := a.config.JWT.AdminRoles
@@ -56,14 +56,14 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) {
5656
return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil
5757
}
5858

59-
return nil, forbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", ")))
59+
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", ")))
6060
}
6161

6262
func (a *API) extractBearerToken(r *http.Request) (string, error) {
6363
authHeader := r.Header.Get("Authorization")
6464
matches := bearerRegexp.FindStringSubmatch(authHeader)
6565
if len(matches) != 2 {
66-
return "", httpError(http.StatusUnauthorized, apierrors.ErrorCodeNoAuthorization, "This endpoint requires a Bearer token")
66+
return "", apierrors.NewHTTPError(http.StatusUnauthorized, apierrors.ErrorCodeNoAuthorization, "This endpoint requires a Bearer token")
6767
}
6868

6969
return matches[1], nil
@@ -89,7 +89,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
8989
return nil, fmt.Errorf("missing kid")
9090
})
9191
if err != nil {
92-
return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
92+
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
9393
}
9494

9595
return withToken(ctx, token), nil
@@ -100,23 +100,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
100100
claims := getClaims(ctx)
101101

102102
if claims == nil {
103-
return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid token: missing claims")
103+
return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid token: missing claims")
104104
}
105105

106106
if claims.Subject == "" {
107-
return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim")
107+
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim")
108108
}
109109

110110
var user *models.User
111111
if claims.Subject != "" {
112112
userId, err := uuid.FromString(claims.Subject)
113113
if err != nil {
114-
return ctx, badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err)
114+
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err)
115115
}
116116
user, err = models.FindUserByID(db, userId)
117117
if err != nil {
118118
if models.IsNotFoundError(err) {
119-
return ctx, forbiddenError(apierrors.ErrorCodeUserNotFound, "User from sub claim in JWT does not exist")
119+
return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeUserNotFound, "User from sub claim in JWT does not exist")
120120
}
121121
return ctx, err
122122
}
@@ -127,12 +127,12 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
127127
if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() {
128128
sessionId, err := uuid.FromString(claims.SessionId)
129129
if err != nil {
130-
return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err)
130+
return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err)
131131
}
132132
session, err = models.FindSessionByID(db, sessionId, false)
133133
if err != nil {
134134
if models.IsNotFoundError(err) {
135-
return ctx, forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId))
135+
return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId))
136136
}
137137
return ctx, err
138138
}

internal/api/auth_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
185185
},
186186
Role: "authenticated",
187187
},
188-
ExpectedError: forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim"),
188+
ExpectedError: apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim"),
189189
ExpectedUser: nil,
190190
},
191191
{
@@ -207,7 +207,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
207207
},
208208
Role: "authenticated",
209209
},
210-
ExpectedError: badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"),
210+
ExpectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"),
211211
ExpectedUser: nil,
212212
},
213213
{
@@ -256,7 +256,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
256256
Role: "authenticated",
257257
SessionId: "73bf9ee0-9e8c-453b-b484-09cb93e2f341",
258258
},
259-
ExpectedError: forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"),
259+
ExpectedError: apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"),
260260
ExpectedUser: u,
261261
ExpectedSession: nil,
262262
},

0 commit comments

Comments
 (0)