-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathhandler.go
More file actions
269 lines (234 loc) · 8.05 KB
/
Copy pathhandler.go
File metadata and controls
269 lines (234 loc) · 8.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
package main
import (
"context"
"crypto/rand"
"errors"
"fmt"
"log/slog"
"math/big"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
"github.com/hmchangw/chat/pkg/errcode"
"github.com/hmchangw/chat/pkg/errcode/errhttp"
pkgoidc "github.com/hmchangw/chat/pkg/oidc"
"github.com/hmchangw/chat/pkg/subject"
)
// TokenValidator validates an SSO token and returns OIDC claims.
// AuthHandler depends on this interface rather than on a concrete verifier;
// the implementation is supplied to NewAuthHandler.
type TokenValidator interface {
// Validate checks rawToken and returns the claims it carries. HandleAuth
// reports an error matching pkgoidc.ErrTokenExpired as an expired token and
// every other error as an invalid token.
Validate(ctx context.Context, rawToken string) (pkgoidc.Claims, error)
}
// authRequest is the JSON body HandleAuth binds when the handler is not in
// dev mode. Both fields are tagged binding:"required".
type authRequest struct {
// SSOToken is the raw SSO token handed to TokenValidator.Validate.
SSOToken string `json:"ssoToken" binding:"required"`
// NATSPublicKey must satisfy nkeys.IsValidPublicUserKey; the issued JWT is
// created for this key.
NATSPublicKey string `json:"natsPublicKey" binding:"required"`
}
// devAuthRequest is the JSON body handleDevAuth binds in dev mode. It names
// the account directly instead of carrying an SSO token. Both fields are
// tagged binding:"required".
type devAuthRequest struct {
// Account must satisfy subject.IsValidAccountToken.
Account string `json:"account" binding:"required"`
// NATSPublicKey must satisfy nkeys.IsValidPublicUserKey; the issued JWT is
// created for this key.
NATSPublicKey string `json:"natsPublicKey" binding:"required"`
}
// authResponse is the JSON body sent with 200 OK by both HandleAuth and
// handleDevAuth.
type authResponse struct {
// NATSJWT is the signed NATS user JWT returned by signNATSJWT.
NATSJWT string `json:"natsJwt"`
// UserInfo describes the authenticated user; it is serialized under "user".
UserInfo userInfoResp `json:"user"`
}
// userInfoResp is the user profile returned alongside the JWT. HandleAuth
// fills every field from the validated claims; handleDevAuth sets only Email,
// Account and EngName and leaves the remaining fields empty.
type userInfoResp struct {
// Email is claims.Email, or "<account>@dev.local" in dev mode.
Email string `json:"email"`
// Account is the account the JWT was issued for.
Account string `json:"account"`
// EmployeeID, EngName and ChineseName are the three parts that
// parseDescription extracts from claims.Description. In dev mode EngName
// is set to the account name and the other two stay empty.
EmployeeID string `json:"employeeId"`
EngName string `json:"engName"`
ChineseName string `json:"chineseName"`
// DeptName and DeptID are copied from the claims.
DeptName string `json:"deptName"`
DeptID string `json:"deptId"`
}
// AuthHandler processes auth requests, validates SSO tokens via OIDC,
// and returns signed NATS user JWTs with scoped permissions.
// Build it with NewAuthHandler, which installs cryptoRandFloat as the default
// randomness source before applying options.
type AuthHandler struct {
// validator checks SSO tokens in HandleAuth; handleDevAuth does not use it.
validator TokenValidator
// signingKey is the key pair every issued NATS user JWT is encoded with.
signingKey nkeys.KeyPair
// jwtExpiry is the base lifetime of an issued JWT.
jwtExpiry time.Duration
// jwtJitter is set through WithJitter, which clamps it to [0, 0.9].
jwtJitter float64 // fraction of jwtExpiry; 0 = fixed lifetime
// randFloat is called by jwtExpiryAt to pick a point inside the jitter window.
randFloat func() float64 // injectable [0,1) source; defaults to crypto rand
// devMode makes HandleAuth delegate to handleDevAuth, which skips SSO token
// validation.
devMode bool
}
// Option configures optional AuthHandler behavior. Options are passed to
// NewAuthHandler, which applies them in argument order after the required
// fields and defaults are set, so a later option overrides an earlier one
// that writes the same field.
type Option func(*AuthHandler)
// WithJitter sets the JWT-lifetime jitter fraction so a fleet of sessions
// minted together does not expire in lockstep.
//
// frac is clamped to [0, 0.9]: values above 0.9 (including +Inf) become 0.9
// and negative values (including -Inf) become 0. NaN is mapped to 0 as well:
// it compares false against every bound, so a plain two-sided clamp would let
// it through and turn the computed expiry into garbage.
func WithJitter(frac float64) Option {
	return func(h *AuthHandler) {
		switch {
		case frac > 0.9:
			h.jwtJitter = 0.9
		case frac >= 0:
			h.jwtJitter = frac
		default:
			// Negative or NaN: fall back to a fixed lifetime.
			h.jwtJitter = 0
		}
	}
}
// WithRandFloat overrides the randomness source used for JWT-lifetime jitter
// (test seam). fn is expected to return values in [0,1).
//
// A nil fn is ignored, so the handler keeps whatever source it already has
// instead of panicking on a nil call the first time a token is signed.
func WithRandFloat(fn func() float64) Option {
	return func(h *AuthHandler) {
		if fn == nil {
			return
		}
		h.randFloat = fn
	}
}
// NewAuthHandler builds an AuthHandler from its required dependencies and then
// applies opts in the order given.
//
//   - validator checks incoming SSO tokens (the dev-mode path does not use it).
//   - signingKey is the key pair issued NATS user JWTs are encoded with.
//   - jwtExpiry is the base lifetime of each issued JWT.
//   - devMode makes HandleAuth take the dev path, which accepts an account
//     name without token validation.
//
// Before options run, the randomness source is cryptoRandFloat and the jitter
// fraction is 0.
func NewAuthHandler(validator TokenValidator, signingKey nkeys.KeyPair, jwtExpiry time.Duration, devMode bool, opts ...Option) *AuthHandler {
	handler := new(AuthHandler)
	handler.validator = validator
	handler.signingKey = signingKey
	handler.jwtExpiry = jwtExpiry
	handler.randFloat = cryptoRandFloat
	handler.devMode = devMode

	for i := range opts {
		opts[i](handler)
	}
	return handler
}
// cryptoRandFloat draws a uniform float64 in [0,1) from crypto/rand by picking
// an integer in [0, 2^53) and dividing it by 2^53. If the random read fails,
// the failure is logged and 0.5 is returned, which is the midpoint that
// leaves the JWT lifetime unskewed.
func cryptoRandFloat() float64 {
	const span = 1 << 53

	upper := big.NewInt(span)
	draw, err := rand.Int(rand.Reader, upper)
	if err != nil {
		slog.Error("crypto/rand read failed, using no-skew midpoint for JWT jitter", "error", err)
		return 0.5
	}

	numerator := float64(draw.Int64())
	return numerator / float64(span)
}
// HandleAuth authenticates a client and responds with a signed NATS user JWT.
//
// In dev mode the request is delegated to handleDevAuth. Otherwise it:
//  1. binds the JSON body (ssoToken, natsPublicKey);
//  2. checks that natsPublicKey is a valid NATS user public key;
//  3. validates the SSO token through the configured TokenValidator;
//  4. requires the account claim to be non-empty and a valid account token;
//  5. signs a NATS JWT for that key and account.
//
// On success it writes 200 with the JWT and the user's profile fields. Every
// failure is reported through errhttp.Write.
func (h *AuthHandler) HandleAuth(c *gin.Context) {
	if h.devMode {
		h.handleDevAuth(c)
		return
	}

	ctx := errcode.WithLogValues(c.Request.Context(), "request_id", c.GetString("request_id"))

	var req authRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		// The client-facing message stays generic; the bind error is attached
		// as the cause (as for token-validation failures below) so malformed
		// JSON can be told apart from missing fields.
		errhttp.Write(ctx, c, errcode.BadRequest("ssoToken and natsPublicKey are required",
			errcode.WithReason(errcode.AuthMissingFields),
			errcode.WithCause(err)))
		return
	}

	if !nkeys.IsValidPublicUserKey(req.NATSPublicKey) {
		errhttp.Write(ctx, c, errcode.BadRequest("invalid natsPublicKey format",
			errcode.WithReason(errcode.AuthInvalidNKey)))
		return
	}

	claims, err := h.validator.Validate(ctx, req.SSOToken)
	if err != nil {
		if errors.Is(err, pkgoidc.ErrTokenExpired) {
			errhttp.Write(ctx, c, errcode.Unauthenticated("SSO token has expired, please re-login",
				errcode.WithReason(errcode.AuthTokenExpired)))
			return
		}
		// Non-expiry failures surface as "invalid SSO token"; attach the raw
		// cause so the server log carries the actual reason.
		errhttp.Write(ctx, c, errcode.Unauthenticated("invalid SSO token",
			errcode.WithReason(errcode.AuthInvalidToken),
			errcode.WithCause(err)))
		return
	}

	account := claims.Account()
	if account == "" {
		// Blank account would mint a JWT with chat.user..> permissions — refuse.
		errhttp.Write(ctx, c, errcode.Unauthenticated("token missing account claim",
			errcode.WithReason(errcode.AuthInvalidToken)))
		return
	}
	if !subject.IsValidAccountToken(account) {
		errhttp.Write(ctx, c, errcode.BadRequest("account must be a single NATS subject token (no '.', '*', '>' or whitespace)"))
		return
	}

	// From here on, errors written for this request also carry the account.
	ctx = errcode.WithLogValues(ctx, "account", account)

	natsJWT, err := h.signNATSJWT(req.NATSPublicKey, account)
	if err != nil {
		errhttp.Write(ctx, c, fmt.Errorf("generating NATS token: %w", err))
		return
	}

	slog.Debug("auth success", "account", account, "subject", claims.Subject)

	// Parse description field: "employeeId, engName, chineseName"
	employeeID, engName, chineseName := parseDescription(claims.Description)

	c.JSON(http.StatusOK, authResponse{
		NATSJWT: natsJWT,
		UserInfo: userInfoResp{
			Email:       claims.Email,
			Account:     account,
			EmployeeID:  employeeID,
			EngName:     engName,
			ChineseName: chineseName,
			DeptName:    claims.DeptName,
			DeptID:      claims.DeptID,
		},
	})
}
// handleDevAuth handles auth in dev mode: it accepts an account name directly,
// without OIDC validation, for use during local development only.
//
// The body must carry account and natsPublicKey. The key must be a valid NATS
// user public key and the account a valid account token. On success it writes
// 200 with the signed JWT and a synthetic profile: the email is
// "<account>@dev.local" and the English name is the account name. Every
// failure is reported through errhttp.Write.
func (h *AuthHandler) handleDevAuth(c *gin.Context) {
	ctx := errcode.WithLogValues(c.Request.Context(), "request_id", c.GetString("request_id"))

	var req devAuthRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		// The client-facing message stays generic; the bind error is attached
		// as the cause so malformed JSON can be told apart from missing fields.
		errhttp.Write(ctx, c, errcode.BadRequest("account and natsPublicKey are required",
			errcode.WithReason(errcode.AuthMissingFields),
			errcode.WithCause(err)))
		return
	}

	if !nkeys.IsValidPublicUserKey(req.NATSPublicKey) {
		errhttp.Write(ctx, c, errcode.BadRequest("invalid natsPublicKey format",
			errcode.WithReason(errcode.AuthInvalidNKey)))
		return
	}

	if !subject.IsValidAccountToken(req.Account) {
		errhttp.Write(ctx, c, errcode.BadRequest("account must be a single NATS subject token (no '.', '*', '>' or whitespace)"))
		return
	}

	// From here on, errors written for this request also carry the account.
	ctx = errcode.WithLogValues(ctx, "account", req.Account)

	natsJWT, err := h.signNATSJWT(req.NATSPublicKey, req.Account)
	if err != nil {
		errhttp.Write(ctx, c, fmt.Errorf("generating NATS token: %w", err))
		return
	}

	slog.Debug("dev auth success", "account", req.Account)

	c.JSON(http.StatusOK, authResponse{
		NATSJWT: natsJWT,
		UserInfo: userInfoResp{
			Email:   req.Account + "@dev.local",
			Account: req.Account,
			EngName: req.Account,
		},
	})
}
// signNATSJWT mints a scoped NATS user JWT for userPubKey, encoded with the
// handler's signing key. The token expires at jwtExpiryAt() and is tagged
// "account:<account>". Because the claims are marked scoped, permissions and
// limits come from the account's scoped signing key template, and the account
// tag drives per-user subject substitution ({{tag(account)}}).
//
// It returns the encoded JWT, or the error reported by Encode.
func (h *AuthHandler) signNATSJWT(userPubKey, account string) (string, error) {
	expiresAt := h.jwtExpiryAt()
	accountTag := "account:" + account

	claims := jwt.NewUserClaims(userPubKey)
	claims.Expires = expiresAt.Unix()
	claims.Tags.Add(accountTag)
	claims.SetScoped(true)

	return claims.Encode(h.signingKey)
}
// jwtExpiryAt returns the absolute expiry for a JWT minted now.
//
// With jitter enabled the base lifetime is scaled by
// factor = 1 + jitter*(2r-1), r in [0,1), i.e. spread ±jwtJitter around
// jwtExpiry. When jwtJitter is 0 the lifetime is exactly jwtExpiry and the
// randomness source is not consulted, so the default configuration does not
// pay for a crypto/rand read on every request. A handler with no randomness
// source (randFloat == nil) also uses the fixed lifetime rather than
// panicking.
func (h *AuthHandler) jwtExpiryAt() time.Time {
	lifetime := h.jwtExpiry
	if h.jwtJitter != 0 && h.randFloat != nil {
		factor := 1 + h.jwtJitter*(2*h.randFloat()-1)
		lifetime = time.Duration(float64(lifetime) * factor)
	}
	return time.Now().Add(lifetime)
}
// parseDescription breaks a description of the form
// "employeeId, engName, chineseName" into its three parts, trimming
// surrounding whitespace from each. Only the first two commas act as
// separators, so any further commas stay inside chineseName. Parts that are
// absent come back as empty strings.
func parseDescription(desc string) (employeeID, engName, chineseName string) {
	first, rest, _ := strings.Cut(desc, ",")
	second, third, _ := strings.Cut(rest, ",")
	return strings.TrimSpace(first), strings.TrimSpace(second), strings.TrimSpace(third)
}
// HandleHealth is the health-check endpoint. It always responds 200 with the
// JSON body {"status":"ok"} and reads neither the request nor handler state.
func (h *AuthHandler) HandleHealth(c *gin.Context) {
	body := gin.H{"status": "ok"}
	c.JSON(http.StatusOK, body)
}