Skip to content

Commit ae7cea9

Browse files
authored
Merge pull request #3895 from skoeva/auth6
backend: auth: Extract IsTokenAboutToExpire from headlamp.go
2 parents 7507f80 + 0887d76 commit ae7cea9

File tree

4 files changed

+88
-43
lines changed

4 files changed

+88
-43
lines changed

backend/cmd/headlamp.go

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -786,30 +786,6 @@ func createHeadlampHandler(config *HeadlampConfig) http.Handler {
786786
return r
787787
}
788788

789-
func isTokenAboutToExpire(token string) bool {
790-
const tokenParts = 3
791-
792-
parts := strings.Split(token, ".")
793-
if len(parts) != tokenParts {
794-
return false
795-
}
796-
797-
payload, err := auth.DecodeBase64JSON(parts[1])
798-
if err != nil {
799-
logger.Log(logger.LevelError, nil, err, "failed to decode payload")
800-
return false
801-
}
802-
803-
expiryUnixTimeUTC, err := auth.GetExpiryUnixTimeUTC(payload)
804-
if err != nil {
805-
logger.Log(logger.LevelError, nil, err, "failed to get expiry time")
806-
return false
807-
}
808-
809-
// This time comparison is timezone aware, so it works correctly
810-
return time.Until(expiryUnixTimeUTC) <= JWTExpirationTTL
811-
}
812-
813789
// configureTLSContext configures TLS settings for the HTTP client in the context.
814790
// If skipTLSVerify is true, TLS verification will be skipped.
815791
// If caCert is provided, it will be added to the certificate pool for TLS verification.
@@ -1069,7 +1045,7 @@ func (c *HeadlampConfig) OIDCTokenRefreshMiddleware(next http.Handler) http.Hand
10691045
}
10701046

10711047
// skip if token is not about to expire
1072-
if !isTokenAboutToExpire(token) {
1048+
if !auth.IsTokenAboutToExpire(token) {
10731049
c.telemetryHandler.RecordEvent(span, "Token not about to expire, skipping refresh")
10741050
next.ServeHTTP(w, r)
10751051
c.telemetryHandler.RecordDuration(ctx, start,

backend/cmd/headlamp_test.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"os"
3333
"path/filepath"
3434
"strconv"
35-
"strings"
3635
"testing"
3736
"time"
3837

@@ -932,23 +931,6 @@ func TestGetOidcCallbackURL(t *testing.T) {
932931
}
933932
}
934933

935-
func TestIsTokenAboutToExpire(t *testing.T) {
936-
// Token that expires in 4 minutes
937-
header := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
938-
originalPayload := "eyJleHAiOjE2MTIzNjE2MDB9"
939-
signature := ".7vl9iBWGDQdXUTbEsqFHiHoaNWxKn4UwLhO9QDhXrpM"
940-
941-
token := header + originalPayload + signature
942-
result := isTokenAboutToExpire(token)
943-
assert.True(t, result)
944-
945-
modifiedPayload := strings.Replace(originalPayload, "J", "-", 1)
946-
947-
token = header + modifiedPayload + signature
948-
result = isTokenAboutToExpire(token)
949-
assert.False(t, result, "Expected to return false when payload decoding fails due to URL-safe characters")
950-
}
951-
952934
func TestOIDCTokenRefreshMiddleware(t *testing.T) {
953935
kubeConfigStore := kubeconfig.NewContextStore()
954936
config := &HeadlampConfig{

backend/pkg/auth/auth.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ const (
3636
oidcKeyPrefix = "oidc-token-"
3737
)
3838

39+
const JWTExpirationTTL = 10 * time.Second // seconds
40+
3941
// DecodeBase64JSON decodes a base64 URL-encoded JSON string into a map.
4042
func DecodeBase64JSON(base64JSON string) (map[string]interface{}, error) {
4143
payloadBytes, err := base64.RawURLEncoding.DecodeString(base64JSON)
@@ -100,6 +102,30 @@ func GetExpiryUnixTimeUTC(tokenPayload map[string]interface{}) (time.Time, error
100102
return time.Unix(int64(exp), 0).UTC(), nil
101103
}
102104

105+
// IsTokenAboutToExpire reports whether the given token is within JWTExpirationTTL
106+
// of its expiry time.
107+
func IsTokenAboutToExpire(token string) bool {
108+
parts := strings.SplitN(token, ".", 3)
109+
if len(parts) != 3 || parts[1] == "" {
110+
return false
111+
}
112+
113+
payload, err := DecodeBase64JSON(parts[1])
114+
if err != nil {
115+
logger.Log(logger.LevelError, nil, err, "failed to decode payload")
116+
return false
117+
}
118+
119+
expiryUnixTimeUTC, err := GetExpiryUnixTimeUTC(payload)
120+
if err != nil {
121+
logger.Log(logger.LevelError, nil, err, "failed to get expiry time")
122+
return false
123+
}
124+
125+
// This time comparison is timezone aware, so it works correctly
126+
return time.Until(expiryUnixTimeUTC) <= JWTExpirationTTL
127+
}
128+
103129
// CacheRefreshedToken updates the refresh token in the cache.
104130
func CacheRefreshedToken(token *oauth2.Token, tokenType string, oldToken string,
105131
oldRefreshToken string, cache cache.Cache[interface{}],

backend/pkg/auth/auth_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package auth_test
1818

1919
import (
2020
"context"
21+
"encoding/base64"
22+
"encoding/json"
2123
"errors"
2224
"net/http"
2325
"reflect"
@@ -283,6 +285,65 @@ func TestGetExpiryUnixTimeUTC(t *testing.T) {
283285
}
284286
}
285287

288+
const headerBase64 = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0"
289+
290+
func makeJWTWithPayload(t *testing.T, payload map[string]interface{}) string {
291+
t.Helper()
292+
293+
payloadBytes, err := json.Marshal(payload)
294+
if err != nil {
295+
t.Fatalf("failed to marshal payload: %v", err)
296+
}
297+
298+
payloadBase64 := base64.RawURLEncoding.EncodeToString(payloadBytes)
299+
300+
// 3 segments: header.payload.signature (signature)
301+
return headerBase64 + "." + payloadBase64 + "."
302+
}
303+
304+
func TestIsTokenAboutToExpire_Window(t *testing.T) {
305+
now := time.Now()
306+
tests := []struct {
307+
name string
308+
exp time.Time
309+
want bool
310+
}{
311+
{"within window", now.Add(auth.JWTExpirationTTL / 2), true},
312+
{"beyond window", now.Add(auth.JWTExpirationTTL + 30*time.Second), false},
313+
{"already expired", now.Add(-5 * time.Second), true},
314+
}
315+
316+
for _, tt := range tests {
317+
t.Run(tt.name, func(t *testing.T) {
318+
token := makeJWTWithPayload(t, map[string]interface{}{"exp": float64(tt.exp.Unix())})
319+
if got := auth.IsTokenAboutToExpire(token); got != tt.want {
320+
t.Fatalf("IsTokenAboutToExpire() = %v, want %v (exp=%v, now=%v)",
321+
got, tt.want, tt.exp.UTC(), now.UTC())
322+
}
323+
})
324+
}
325+
}
326+
327+
func TestIsTokenAboutToExpire_InvalidInputs(t *testing.T) {
328+
tests := []struct {
329+
name string
330+
token string
331+
}{
332+
{"not three parts", "not-a-jwt"},
333+
{"invalid base64 payload", headerBase64 + ".%%%." + "."},
334+
{"missing exp", makeJWTWithPayload(t, map[string]interface{}{})},
335+
{"non-numeric exp", makeJWTWithPayload(t, map[string]interface{}{"exp": "1609459200"})},
336+
}
337+
338+
for _, tt := range tests {
339+
t.Run(tt.name, func(t *testing.T) {
340+
if got := auth.IsTokenAboutToExpire(tt.token); got {
341+
t.Fatalf("IsTokenAboutToExpire() = true, want false for %s", tt.name)
342+
}
343+
})
344+
}
345+
}
346+
286347
type cacheStub struct{}
287348

288349
func (cacheStub) Delete(ctx context.Context, k string) error {

0 commit comments

Comments
 (0)