Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions apps/api/api/credit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package api

import (
"context"

authv1 "buf.build/gen/go/dot/brease/protocolbuffers/go/brease/auth/v1"
"connectrpc.com/connect"
unkey "github.com/unkeyed/sdks/api/go/v2"
"github.com/unkeyed/sdks/api/go/v2/models/components"
)

func (b *BreaseHandler) UpdateCredit(ctx context.Context, c *connect.Request[authv1.UpdateCreditRequest]) (*connect.Response[authv1.UpdateCreditResponse], error) {
uk := unkey.New(
unkey.WithSecurity(c.Msg.RootKey),
)
_, err := uk.Keys.UpdateCredits(ctx, components.V2KeysUpdateCreditsRequestBody{
KeyID: c.Msg.TargetKey,
Value: &c.Msg.Value,
Operation: components.Operation(c.Msg.Operation),
})
if err != nil {
return nil, err
}

return connect.NewResponse(&authv1.UpdateCreditResponse{}), nil
}
8 changes: 8 additions & 0 deletions apps/api/auth/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"context"

"github.com/golang-jwt/jwt/v5"
)

Expand All @@ -19,6 +20,13 @@ func CtxString(c context.Context, key string) (s string) {
return
}

func CtxInt(c context.Context, key string) (i int) {
if val := c.Value(key); val != nil {
i, _ = val.(int)
}
return
}

func CtxJWTToken(c context.Context, key string) (token *jwt.Token) {
if val := c.Value(key); val != nil {
token, _ = val.(*jwt.Token)
Expand Down
88 changes: 65 additions & 23 deletions apps/api/auth/middleware.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package auth

import (
"connectrpc.com/connect"
"context"
"errors"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"

"connectrpc.com/connect"
"github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/unkeyed/unkey/sdks/golang/models/components"
components2 "github.com/unkeyed/sdks/api/go/v2/models/components"
"github.com/unkeyed/unkey/sdks/golang/models/sdkerrors"
"go.dot.industries/brease/trace"
"go.dot.industries/brease/worker"
"google.golang.org/grpc/metadata"
"net/http"
"regexp"
"strings"

"github.com/golang-jwt/jwt/v5"
errors2 "github.com/juju/errors"
Expand All @@ -31,6 +33,7 @@ const (
ContextUserIDKey = "userId"
ContextOrgKey = "orgId"
ContextPermissionsKey = "permissions"
ContextCreditsKey = "credits"
PermissionReadRule = "context.rule.read"
PermissionCreateRule = "context.rule.create"
PermissionEvaluate = "context.evaluate"
Expand Down Expand Up @@ -70,6 +73,7 @@ type validateAuthTokenResult struct {
orgID string
authenticator string
permissions []string
credits *int
}

func NewAuthInterceptor(logger *zap.Logger) connect.UnaryInterceptorFunc {
Expand All @@ -83,7 +87,7 @@ func NewAuthInterceptor(logger *zap.Logger) connect.UnaryInterceptorFunc {
// TODO: client side auth interceptor
// Send a token with client requests.
// req.Header().Set(tokenHeader, "sample")
} else if !strings.Contains(req.Spec().Procedure, "RefreshToken") {
} else if !strings.Contains(req.Spec().Procedure, "RefreshToken") && !strings.Contains(req.Spec().Procedure, "UpdateCredit") {
// server only
var err error
ctx, err = authenticate(ctx, req.Header(), logger)
Expand All @@ -92,7 +96,22 @@ func NewAuthInterceptor(logger *zap.Logger) connect.UnaryInterceptorFunc {
}
}

return next(ctx, req)
resp, err := next(ctx, req)
if err != nil {
return nil, err
}
if resp != nil {
if userID := CtxString(ctx, ContextUserIDKey); userID != "" {
resp.Header().Set("X-User-Id", userID)
}
if orgID := CtxString(ctx, ContextOrgKey); orgID != "" {
resp.Header().Set("X-Org-Id", orgID)
}
if credits := CtxInt(ctx, ContextCreditsKey); credits != 0 {
resp.Header().Set("X-Credits", strconv.Itoa(credits))
}
}
return resp, nil
}
}
return interceptor
Expand Down Expand Up @@ -211,6 +230,9 @@ func authenticate(ctx context.Context, headers http.Header, logger *zap.Logger)
if authed.permissions != nil {
ctx = context.WithValue(ctx, ContextPermissionsKey, authed.permissions)
}
if authed.credits != nil {
ctx = context.WithValue(ctx, ContextCreditsKey, *authed.credits)
}

return ctx, nil
}
Expand Down Expand Up @@ -284,10 +306,8 @@ func validateUnkey(ctx context.Context, args interface{}) (interface{}, error) {
},
}, nil
}
apiID := env.Getenv("UNKEY_API_ID", "")
resp, err := unkeyClient.Keys.VerifyKey(ctx, components.V1KeysVerifyKeyRequest{
APIID: &apiID,
Key: key,
resp, err := unkeyClient.Keys.VerifyKey(ctx, components2.V2KeysVerifyKeyRequestBody{
Key: key,
})
if err != nil {
var errBadRequest *sdkerrors.ErrBadRequest
Expand Down Expand Up @@ -338,34 +358,56 @@ func validateUnkey(ctx context.Context, args interface{}) (interface{}, error) {
}
}

if !resp.Valid {
switch resp.Code {
case components.CodeUsageExceeded:
if resp.V2KeysVerifyKeyResponseBody == nil {
return validateAuthTokenResult{
error: &validationErr{
Status: http.StatusUnauthorized,
Error: errors2.Unauthorizedf("invalid API response: %v", resp.GetV2KeysVerifyKeyResponseBody()),
},
}, nil
}
if resp.V2KeysVerifyKeyResponseBody.Data.Valid != true {
r := resp.V2KeysVerifyKeyResponseBody.Data
switch r.Code {
case components2.CodeUsageExceeded:
return validateAuthTokenResult{
error: &validationErr{
Status: http.StatusTooManyRequests,
Error: errors2.NewQuotaLimitExceeded(nil, "usage exceeded credits"),
},
}, nil
case components.CodeRateLimited:
case components2.CodeRateLimited:
str := ""
for _, rl := range r.Ratelimits {
if str != "" {
str += "\n"
}
str += fmt.Sprintf(
"rate limit: %d remaining: %d reset: %v, duration: %d",
rl.Limit,
rl.Remaining,
rl.Reset,
rl.Duration,
)
}
return validateAuthTokenResult{
error: &validationErr{
Status: http.StatusTooManyRequests,
Error: errors2.NewNotYetAvailable(nil, fmt.Sprintf("rate limit: %.2f remaining: %.2f reset: %v", resp.Ratelimit.Limit, resp.Ratelimit.Remaining, resp.Ratelimit.Reset)),
Error: errors2.NewNotYetAvailable(nil, str),
},
}, nil
default:
return validateAuthTokenResult{
error: &validationErr{
Status: http.StatusUnauthorized,
Error: errors2.Unauthorizedf("invalid API key: %s", resp.Code),
Error: errors2.Unauthorizedf("invalid API key: %s", r.Code),
},
}, nil
}
}

userID := ""
if uid, ok := resp.Meta[ContextUserIDKey]; ok && uid != nil {
if uid, ok := resp.V2KeysVerifyKeyResponseBody.Data.Meta[ContextUserIDKey]; ok && uid != nil {
userID, ok = uid.(string)
if !ok || userID == "" {
return validateAuthTokenResult{
Expand All @@ -377,8 +419,8 @@ func validateUnkey(ctx context.Context, args interface{}) (interface{}, error) {
}
}

orgID := resp.OwnerID
if orgID == nil {
orgID := resp.V2KeysVerifyKeyResponseBody.Data.Identity.ExternalID
if orgID == "" {
return validateAuthTokenResult{
error: &validationErr{
Status: http.StatusUnauthorized,
Expand All @@ -387,9 +429,9 @@ func validateUnkey(ctx context.Context, args interface{}) (interface{}, error) {
}, nil
}

permissions := resp.Permissions

return validateAuthTokenResult{authed: true, userID: userID, orgID: *orgID, permissions: permissions}, nil
permissions := resp.V2KeysVerifyKeyResponseBody.Data.Permissions
credits := resp.V2KeysVerifyKeyResponseBody.Data.Credits
return validateAuthTokenResult{authed: true, userID: userID, orgID: orgID, permissions: permissions, credits: credits}, nil
}

func validateRootAPIKey(ctx context.Context, args interface{}) (interface{}, error) {
Expand Down
2 changes: 1 addition & 1 deletion apps/api/auth/unkey.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package auth

import (
unkey "github.com/unkeyed/unkey/sdks/golang"
unkey "github.com/unkeyed/sdks/api/go/v2"
"go.dot.industries/brease/env"
)

Expand Down
Loading
Loading