diff --git a/core/common/config/env.go b/core/common/config/env.go index d2ecf84a..6219832e 100644 --- a/core/common/config/env.go +++ b/core/common/config/env.go @@ -3,41 +3,134 @@ package config import ( "os" "strings" + "sync/atomic" "github.com/labring/aiproxy/core/common/env" ) +type tokenVariants struct { + raw string + bearer string + sk string + bearerSK string +} + var ( - DebugEnabled bool - DebugSQLEnabled bool - DisableAutoMigrateDB bool - AdminKey string - WebPath string - DisableWeb bool - DisableWebRoot bool - FfmpegEnabled bool - InternalToken string - DisableModelConfig bool - Redis string - RedisKeyPrefix string - ConfigFilePath string + DebugEnabled bool + DebugSQLEnabled bool + DisableAutoMigrateDB bool + AdminKey string + DynamicRemoteAdminKey string + WebPath string + DisableWeb bool + DisableWebRoot bool + FfmpegEnabled bool + InternalToken string + DisableModelConfig bool + Redis string + RedisKeyPrefix string + ConfigFilePath string // OnCall Lark configuration for urgent alerts OnCallLarkAppID string OnCallLarkAppSecret string OnCallLarkOpenIDs []string // comma-separated open IDs + + adminKeyState atomic.Value + dynamicRemoteAdminKeyState atomic.Value + internalTokenState atomic.Value ) +func buildTokenVariants(token string) tokenVariants { + if token == "" { + return tokenVariants{} + } + + sk := "sk-" + token + return tokenVariants{ + raw: token, + bearer: "Bearer " + token, + sk: sk, + bearerSK: "Bearer " + sk, + } +} + +func SetAdminKey(key string) { + v := buildTokenVariants(key) + AdminKey = v.raw + adminKeyState.Store(v) +} + +func SetDynamicRemoteAdminKey(key string) { + v := buildTokenVariants(key) + DynamicRemoteAdminKey = v.raw + dynamicRemoteAdminKeyState.Store(v) +} + +func SetInternalToken(token string) { + v := buildTokenVariants(token) + InternalToken = v.raw + internalTokenState.Store(v) +} + +func GetAdminKey() string { + v, _ := adminKeyState.Load().(tokenVariants) + return v.raw +} + +func GetDynamicRemoteAdminKey() string { + v, _ := dynamicRemoteAdminKeyState.Load().(tokenVariants) + return v.raw +} + +func GetEffectiveAdminKey() string { + if key := GetDynamicRemoteAdminKey(); key != "" { + return key + } + + return GetAdminKey() +} + +func GetInternalToken() string { + v, _ := internalTokenState.Load().(tokenVariants) + return v.raw +} + +func MatchAdminKey(raw string) bool { + return matchTokenVariants(raw, adminKeyState) +} + +func MatchEffectiveAdminKey(raw string) bool { + if GetDynamicRemoteAdminKey() != "" { + return matchTokenVariants(raw, dynamicRemoteAdminKeyState) + } + + return matchTokenVariants(raw, adminKeyState) +} + +func matchTokenVariants(raw string, state atomic.Value) bool { + v, _ := state.Load().(tokenVariants) + return raw != "" && (raw == v.raw || + raw == v.bearer || + raw == v.sk || + raw == v.bearerSK) +} + +func MatchInternalToken(raw string) bool { + return matchTokenVariants(raw, internalTokenState) +} + func ReloadEnv() { DebugEnabled = env.Bool("DEBUG", false) DebugSQLEnabled = env.Bool("DEBUG_SQL", false) DisableAutoMigrateDB = env.Bool("DISABLE_AUTO_MIGRATE_DB", false) - AdminKey = os.Getenv("ADMIN_KEY") + SetAdminKey(os.Getenv("ADMIN_KEY")) + SetDynamicRemoteAdminKey("") WebPath = os.Getenv("WEB_PATH") DisableWeb = env.Bool("DISABLE_WEB", false) DisableWebRoot = env.Bool("DISABLE_WEB_ROOT", false) FfmpegEnabled = env.Bool("FFMPEG_ENABLED", false) - InternalToken = os.Getenv("INTERNAL_TOKEN") + SetInternalToken(os.Getenv("INTERNAL_TOKEN")) DisableModelConfig = env.Bool("DISABLE_MODEL_CONFIG", false) Redis = env.String("REDIS", os.Getenv("REDIS_CONN_STRING")) RedisKeyPrefix = os.Getenv("REDIS_KEY_PREFIX") diff --git a/core/common/trylock/export_test.go b/core/common/trylock/export_test.go new file mode 100644 index 00000000..281fb3c9 --- /dev/null +++ b/core/common/trylock/export_test.go @@ -0,0 +1,9 @@ +package trylock + +func InjectMemLockValueForTest(key string, value any) { + memRecord.Store(key, value) +} + +func ResetMemLockValueForTest(key string) { + memRecord.Delete(key) +} diff --git a/core/common/trylock/lock.go b/core/common/trylock/lock.go index e0706032..3d0a5d3a 100644 --- a/core/common/trylock/lock.go +++ b/core/common/trylock/lock.go @@ -3,6 +3,7 @@ package trylock import ( "context" "errors" + "fmt" "sync" "time" @@ -24,7 +25,11 @@ func cleanMemLock() { for now := range ticker.C { memRecord.Range(func(key, value any) bool { exp, ok := value.(time.Time) - if !ok || now.After(exp) { + if !ok { + panic(fmt.Sprintf("mem lock type mismatch: %T", value)) + } + + if now.After(exp) { memRecord.CompareAndDelete(key, value) } @@ -45,8 +50,7 @@ func MemLock(key string, expiration time.Duration) bool { oldExpiration, ok := actual.(time.Time) if !ok { - memRecord.CompareAndDelete(key, actual) - continue + panic(fmt.Sprintf("mem lock type mismatch: %T", actual)) } if now.After(oldExpiration) { diff --git a/core/common/trylock/lock_test.go b/core/common/trylock/lock_test.go index 29db65d4..03b0f619 100644 --- a/core/common/trylock/lock_test.go +++ b/core/common/trylock/lock_test.go @@ -26,3 +26,18 @@ func TestMemLock(t *testing.T) { t.Error("Expected true, Got false") } } + +func TestMemLockPanicsOnTypeMismatch(t *testing.T) { + trylock.InjectMemLockValueForTest("panic-key", "bad") + t.Cleanup(func() { + trylock.ResetMemLockValueForTest("panic-key") + }) + + defer func() { + if recover() == nil { + t.Fatal("expected panic on mem lock type mismatch") + } + }() + + trylock.MemLock("panic-key", time.Second) +} diff --git a/core/go.mod b/core/go.mod index 1c43b42b..f84fd2e1 100644 --- a/core/go.mod +++ b/core/go.mod @@ -4,6 +4,7 @@ go 1.26 require ( cloud.google.com/go/iam v1.9.0 + github.com/alicebob/miniredis/v2 v2.37.0 github.com/aws/aws-sdk-go-v2 v1.41.6 github.com/aws/aws-sdk-go-v2/credentials v1.19.15 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.5 @@ -165,6 +166,7 @@ require ( github.com/ugorji/go/codec v1.3.1 // indirect github.com/woodsbury/decimal128 v1.4.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.mongodb.org/mongo-driver/v2 v2.5.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect diff --git a/core/go.sum b/core/go.sum index d6fa5d29..2715175e 100644 --- a/core/go.sum +++ b/core/go.sum @@ -23,6 +23,8 @@ github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/PuerkitoBio/goquery v1.12.0 h1:pAcL4g3WRXekcB9AU/y1mbKez2dbY2AajVhtkO8RIBo= github.com/PuerkitoBio/goquery v1.12.0/go.mod h1:802ej+gV2y7bbIhOIoPY5sT183ZW0YFofScC4q/hIpQ= +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= @@ -382,6 +384,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= diff --git a/core/main.go b/core/main.go index 0749d7f0..bd38bd7e 100644 --- a/core/main.go +++ b/core/main.go @@ -42,10 +42,6 @@ func main() { config.ReloadEnv() - if err := ensureAdminKey(); err != nil { - log.Warn("failed to ensure AdminKey: " + err.Error()) - } - common.InitLog(log.StandardLogger(), config.DebugEnabled) printLoadedEnvFiles() @@ -94,6 +90,12 @@ func main() { go task.RedisHealthCheckTask(ctx) } + if task.AdminKeyCacheEnabled() { + log.Info("admin key cache task started") + + go task.AdminKeyCacheTask(ctx) + } + log.Info("update channels balance task started") go controller.UpdateChannelsBalance(time.Minute * 10) diff --git a/core/middleware/auth.go b/core/middleware/auth.go index a0b326db..e1bcb3ae 100644 --- a/core/middleware/auth.go +++ b/core/middleware/auth.go @@ -5,7 +5,6 @@ import ( "maps" "net/http" "slices" - "strings" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" @@ -39,7 +38,7 @@ func ErrorResponse(c *gin.Context, code int, message string) { } func AdminAuth(c *gin.Context) { - if config.AdminKey == "" { + if config.GetEffectiveAdminKey() == "" { ErrorResponse(c, http.StatusUnauthorized, "unauthorized, admin key is not set") c.Abort() return @@ -50,17 +49,14 @@ func AdminAuth(c *gin.Context) { accessToken = c.Query("key") } - accessToken = strings.TrimPrefix(accessToken, "Bearer ") - accessToken = strings.TrimPrefix(accessToken, "sk-") - - if accessToken != config.AdminKey { + if !config.MatchEffectiveAdminKey(accessToken) { ErrorResponse(c, http.StatusUnauthorized, "unauthorized, no access token provided") c.Abort() return } c.Set(Token, &model.TokenCache{ - Key: config.AdminKey, + Key: config.GetEffectiveAdminKey(), }) group := c.Param("group") @@ -75,32 +71,20 @@ func AdminAuth(c *gin.Context) { func TokenAuth(c *gin.Context) { log := common.GetLogger(c) - key := c.Request.Header.Get("Authorization") - if key == "" { - key = c.Request.Header.Get("X-Api-Key") - } - - if key == "" { - key = c.Request.Header.Get("X-Goog-Api-Key") - } - - key = strings.TrimPrefix( - strings.TrimPrefix(key, "Bearer "), - "sk-", - ) + key := requestToken(c.Request.Header) var ( token model.TokenCache useInternalToken bool ) - if config.AdminKey != "" && config.AdminKey == key || - config.InternalToken != "" && config.InternalToken == key { + if config.MatchEffectiveAdminKey(key) || config.MatchInternalToken(key) { token = model.TokenCache{ - Key: key, + Key: normalizeTokenKey(key), } useInternalToken = true } else { + key = normalizeTokenKey(key) tokenCache, err := model.GetAndValidateToken(key) if err != nil { oncall.AlertDBError("TokenAuth", err) @@ -307,3 +291,23 @@ func maskTokenKey(key string) string { } return key[:4] + "*****" + key[len(key)-4:] } + +func requestToken(headers http.Header) string { + if key := headers.Get("Authorization"); key != "" { + return key + } + if key := headers.Get("X-Api-Key"); key != "" { + return key + } + return headers.Get("X-Goog-Api-Key") +} + +func normalizeTokenKey(key string) string { + if len(key) >= 7 && key[:7] == "Bearer " { + key = key[7:] + } + if len(key) >= 3 && key[:3] == "sk-" { + key = key[3:] + } + return key +} diff --git a/core/middleware/auth_test.go b/core/middleware/auth_test.go new file mode 100644 index 00000000..52a4aa3b --- /dev/null +++ b/core/middleware/auth_test.go @@ -0,0 +1,122 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/labring/aiproxy/core/common/config" +) + +func TestRequestToken(t *testing.T) { + tests := []struct { + name string + set func(http.Header) + want string + }{ + { + name: "authorization takes precedence", + set: func(headers http.Header) { + headers.Set("Authorization", "Bearer auth-token") + headers.Set("X-Api-Key", "api-token") + headers.Set("X-Goog-Api-Key", "goog-token") + }, + want: "Bearer auth-token", + }, + { + name: "x api key fallback", + set: func(headers http.Header) { + headers.Set("X-Api-Key", "api-token") + headers.Set("X-Goog-Api-Key", "goog-token") + }, + want: "api-token", + }, + { + name: "x goog api key fallback", + set: func(headers http.Header) { + headers.Set("X-Goog-Api-Key", "goog-token") + }, + want: "goog-token", + }, + { + name: "empty when missing", + set: func(http.Header) {}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := make(http.Header) + tt.set(headers) + + if got := requestToken(headers); got != tt.want { + t.Fatalf("requestToken() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestNormalizeTokenKey(t *testing.T) { + tests := []struct { + name string + key string + want string + }{ + { + name: "bearer token", + key: "Bearer token-value", + want: "token-value", + }, + { + name: "sk token", + key: "sk-token-value", + want: "token-value", + }, + { + name: "bearer sk token", + key: "Bearer sk-token-value", + want: "token-value", + }, + { + name: "plain token", + key: "token-value", + want: "token-value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeTokenKey(tt.key); got != tt.want { + t.Fatalf("normalizeTokenKey() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestEffectiveAdminKeyUsesDynamicRemoteKey(t *testing.T) { + oldAdminKey := config.GetAdminKey() + oldDynamicRemoteAdminKey := config.GetDynamicRemoteAdminKey() + t.Cleanup(func() { + config.SetAdminKey(oldAdminKey) + config.SetDynamicRemoteAdminKey(oldDynamicRemoteAdminKey) + }) + + config.SetAdminKey("static-admin") + config.SetDynamicRemoteAdminKey("dynamic-admin") + + if config.GetAdminKey() != "static-admin" { + t.Fatalf("GetAdminKey() = %q, want static-admin", config.GetAdminKey()) + } + if config.GetEffectiveAdminKey() != "dynamic-admin" { + t.Fatalf("GetEffectiveAdminKey() = %q, want dynamic-admin", config.GetEffectiveAdminKey()) + } + if !config.MatchEffectiveAdminKey("dynamic-admin") { + t.Fatal("dynamic admin key should match effective admin key") + } + if config.MatchEffectiveAdminKey("static-admin") { + t.Fatal("static admin key should not match while dynamic admin key is active") + } + if config.MatchAdminKey("dynamic-admin") { + t.Fatal("dynamic admin key should not mutate static admin key matching") + } +} diff --git a/core/middleware/mcp.go b/core/middleware/mcp.go index afee9e75..a5355358 100644 --- a/core/middleware/mcp.go +++ b/core/middleware/mcp.go @@ -3,7 +3,6 @@ package middleware import ( "fmt" "net/http" - "strings" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" @@ -20,23 +19,18 @@ func MCPAuth(c *gin.Context) { key, _ = c.GetQuery("key") } - key = strings.TrimPrefix( - strings.TrimPrefix(key, "Bearer "), - "sk-", - ) - var ( token model.TokenCache useInternalToken bool ) - if config.AdminKey != "" && config.AdminKey == key || - config.InternalToken != "" && config.InternalToken == key { + if config.MatchEffectiveAdminKey(key) || config.MatchInternalToken(key) { token = model.TokenCache{ - Key: key, + Key: normalizeTokenKey(key), } useInternalToken = true } else { + key = normalizeTokenKey(key) tokenCache, err := model.GetAndValidateToken(key) if err != nil { AbortLogWithMessage(c, http.StatusUnauthorized, err.Error()) diff --git a/core/startup.go b/core/startup.go index 13051a2a..cb29a22d 100644 --- a/core/startup.go +++ b/core/startup.go @@ -24,6 +24,7 @@ import ( "github.com/labring/aiproxy/core/middleware" "github.com/labring/aiproxy/core/model" "github.com/labring/aiproxy/core/router" + "github.com/labring/aiproxy/core/task" log "github.com/sirupsen/logrus" ) @@ -35,6 +36,10 @@ func initializeServices(pprofPort int) error { return err } + if err := initializeAdminKey(); err != nil { + return err + } + // Initialize oncall after Redis so it can use Redis for state synchronization oncall.Init() @@ -53,6 +58,22 @@ func initializeServices(pprofPort int) error { return model.InitLogDB(int(config.GetCleanLogBatchSize())) } +func initializeAdminKey() error { + if err := task.InitAdminKeyCache(context.Background()); err != nil { + log.Warn("failed to initialize AdminKey cache: " + err.Error()) + } + + if err := ensureAdminKey(); err != nil { + log.Warn("failed to ensure AdminKey: " + err.Error()) + } + + if err := task.InitAdminKeyCache(context.Background()); err != nil { + log.Warn("failed to refresh AdminKey cache: " + err.Error()) + } + + return nil +} + func initializePprof(pprofPort int) { go func() { err := pprof.RunPprofServer(pprofPort) @@ -219,14 +240,14 @@ func writeToEnvFile(envFile, key, value string) error { } func ensureAdminKey() error { - if config.AdminKey != "" { + if config.GetAdminKey() != "" { log.Info("AdminKey is already set") return nil } log.Info("AdminKey is not set, generating new AdminKey...") - config.AdminKey = generateAdminKey() + config.SetAdminKey(generateAdminKey()) envFile := ".env.aiproxy.local" @@ -235,7 +256,7 @@ func ensureAdminKey() error { envFile = absEnvFile } - if err := writeToEnvFile(envFile, "ADMIN_KEY", config.AdminKey); err != nil { + if err := writeToEnvFile(envFile, "ADMIN_KEY", config.GetAdminKey()); err != nil { return fmt.Errorf("failed to write AdminKey to %s: %w", envFile, err) } diff --git a/core/task/admin_key_cache.go b/core/task/admin_key_cache.go new file mode 100644 index 00000000..faf5c212 --- /dev/null +++ b/core/task/admin_key_cache.go @@ -0,0 +1,113 @@ +package task + +import ( + "context" + "time" + + "github.com/labring/aiproxy/core/common" + "github.com/labring/aiproxy/core/common/config" + "github.com/redis/go-redis/v9" + log "github.com/sirupsen/logrus" +) + +const ( + dynamicRemoteAdminKeyRedisKey = "dynamic-remote-admin-key" + adminKeySyncInterval = 500 * time.Millisecond + adminKeyCacheInitWait = 5 * time.Second + adminKeyCacheOpWait = 2 * time.Second +) + +func AdminKeyCacheEnabled() bool { + return common.RedisEnabled && common.RDB != nil +} + +func InitAdminKeyCache(ctx context.Context) error { + if !AdminKeyCacheEnabled() { + return nil + } + + opCtx, cancel := context.WithTimeout(ctx, adminKeyCacheInitWait) + defer cancel() + + return syncAdminKeyCacheOnce(opCtx) +} + +func AdminKeyCacheTask(ctx context.Context) { + if !AdminKeyCacheEnabled() { + return + } + + ticker := time.NewTicker(adminKeySyncInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + opCtx, cancel := context.WithTimeout(ctx, adminKeyCacheOpWait) + err := syncAdminKeyCacheOnce(opCtx) + cancel() + if err != nil { + log.Debugf("admin key cache sync failed: %v", err) + } + } + } +} + +func syncAdminKeyCacheOnce(ctx context.Context) error { + cachedAdminKey, err := loadCachedAdminKey(ctx) + if err != nil { + return err + } + + localAdminKey := config.GetEffectiveAdminKey() + if cachedAdminKey != "" { + if cachedAdminKey != localAdminKey { + config.SetDynamicRemoteAdminKey(cachedAdminKey) + log.Info("admin key loaded from redis") + } + + return nil + } + + if localAdminKey == "" { + return nil + } + + created, err := common.RDB.SetNX(ctx, getAdminKeyCacheKey(), localAdminKey, 0).Result() + if err != nil { + return err + } + if created { + log.Info("admin key synced to redis") + return nil + } + + cachedAdminKey, err = loadCachedAdminKey(ctx) + if err != nil { + return err + } + if cachedAdminKey != "" && cachedAdminKey != localAdminKey { + config.SetDynamicRemoteAdminKey(cachedAdminKey) + log.Info("admin key loaded from redis") + } + + return nil +} + +func loadCachedAdminKey(ctx context.Context) (string, error) { + adminKey, err := common.RDB.Get(ctx, getAdminKeyCacheKey()).Result() + if err == nil { + return adminKey, nil + } + if err != redis.Nil { + return "", err + } + + return "", nil +} + +func getAdminKeyCacheKey() string { + return common.RedisKey(dynamicRemoteAdminKeyRedisKey) +} diff --git a/core/task/admin_key_cache_test.go b/core/task/admin_key_cache_test.go new file mode 100644 index 00000000..db1cbbef --- /dev/null +++ b/core/task/admin_key_cache_test.go @@ -0,0 +1,176 @@ +package task + +import ( + "context" + "os" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/labring/aiproxy/core/common" + "github.com/labring/aiproxy/core/common/config" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestInitAdminKeyCacheLoadsCachedKeyIntoDynamicState(t *testing.T) { + ctx := context.Background() + client, cleanup := setupRedisForAdminKeyCacheTest(t, ctx) + defer cleanup() + + configureAdminKeyCacheTest(t, client) + config.SetAdminKey("local-key") + + require.NoError(t, client.Set(ctx, getAdminKeyCacheKey(), "redis-key", 0).Err()) + require.NoError(t, InitAdminKeyCache(ctx)) + require.Equal(t, "local-key", config.GetAdminKey()) + require.Equal(t, "redis-key", config.GetDynamicRemoteAdminKey()) + require.Equal(t, "redis-key", config.GetEffectiveAdminKey()) +} + +func TestInitAdminKeyCacheBootstrapsLocalKey(t *testing.T) { + ctx := context.Background() + client, cleanup := setupRedisForAdminKeyCacheTest(t, ctx) + defer cleanup() + + configureAdminKeyCacheTest(t, client) + config.SetAdminKey("local-key") + + require.NoError(t, InitAdminKeyCache(ctx)) + + cachedKey, err := client.Get(ctx, getAdminKeyCacheKey()).Result() + require.NoError(t, err) + require.Equal(t, "local-key", cachedKey) + require.Equal(t, "", config.GetDynamicRemoteAdminKey()) +} + +func TestInitAdminKeyCacheNoopsWithoutRedis(t *testing.T) { + oldRDB := common.RDB + oldRedisEnabled := common.RedisEnabled + oldRedisKeyPrefix := config.RedisKeyPrefix + oldAdminKey := config.GetAdminKey() + oldDynamicRemoteAdminKey := config.GetDynamicRemoteAdminKey() + oldInternalToken := config.GetInternalToken() + + common.RDB = nil + common.RedisEnabled = true + config.RedisKeyPrefix = "" + config.SetAdminKey("local-key") + config.SetDynamicRemoteAdminKey("") + config.SetInternalToken("") + + t.Cleanup(func() { + common.RDB = oldRDB + common.RedisEnabled = oldRedisEnabled + config.RedisKeyPrefix = oldRedisKeyPrefix + config.SetAdminKey(oldAdminKey) + config.SetDynamicRemoteAdminKey(oldDynamicRemoteAdminKey) + config.SetInternalToken(oldInternalToken) + }) + + require.False(t, AdminKeyCacheEnabled()) + require.NoError(t, InitAdminKeyCache(context.Background())) + require.Equal(t, "local-key", config.GetAdminKey()) +} + +func TestInitAdminKeyCacheReturnsRedisError(t *testing.T) { + client := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 10 * time.Millisecond, + ReadTimeout: 10 * time.Millisecond, + WriteTimeout: 10 * time.Millisecond, + }) + defer client.Close() + + configureAdminKeyCacheTest(t, client) + config.SetAdminKey("local-key") + + err := InitAdminKeyCache(context.Background()) + require.Error(t, err) +} + +func TestAdminKeyCacheTaskUpdatesDynamicKey(t *testing.T) { + ctx := context.Background() + client, cleanup := setupRedisForAdminKeyCacheTest(t, ctx) + defer cleanup() + + configureAdminKeyCacheTest(t, client) + config.SetAdminKey("initial-key") + + require.NoError(t, InitAdminKeyCache(ctx)) + + taskCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go AdminKeyCacheTask(taskCtx) + + require.NoError(t, client.Set(ctx, getAdminKeyCacheKey(), "rotated-key", 0).Err()) + require.Eventually(t, func() bool { + return config.GetDynamicRemoteAdminKey() == "rotated-key" && + config.GetAdminKey() == "initial-key" + }, 3*time.Second, 50*time.Millisecond) +} + +func TestAdminKeyCacheUsesStableRedisKey(t *testing.T) { + oldRedisKeyPrefix := config.RedisKeyPrefix + config.RedisKeyPrefix = "review-scope" + t.Cleanup(func() { + config.RedisKeyPrefix = oldRedisKeyPrefix + }) + + require.Equal(t, "review-scope:dynamic-remote-admin-key", getAdminKeyCacheKey()) +} + +func configureAdminKeyCacheTest(t *testing.T, client *redis.Client) { + t.Helper() + + oldRDB := common.RDB + oldRedisEnabled := common.RedisEnabled + oldRedisKeyPrefix := config.RedisKeyPrefix + oldAdminKey := config.GetAdminKey() + oldDynamicRemoteAdminKey := config.GetDynamicRemoteAdminKey() + oldInternalToken := config.GetInternalToken() + oldSealosJWTKey, hadSealosJWTKey := os.LookupEnv("SEALOS_JWT_KEY") + + common.RDB = client + common.RedisEnabled = client != nil + config.RedisKeyPrefix = "" + config.SetAdminKey("") + config.SetDynamicRemoteAdminKey("") + config.SetInternalToken("admin-key-cache-test-scope") + require.NoError(t, os.Unsetenv("SEALOS_JWT_KEY")) + + t.Cleanup(func() { + common.RDB = oldRDB + common.RedisEnabled = oldRedisEnabled + config.RedisKeyPrefix = oldRedisKeyPrefix + config.SetAdminKey(oldAdminKey) + config.SetDynamicRemoteAdminKey(oldDynamicRemoteAdminKey) + config.SetInternalToken(oldInternalToken) + if hadSealosJWTKey { + require.NoError(t, os.Setenv("SEALOS_JWT_KEY", oldSealosJWTKey)) + } else { + require.NoError(t, os.Unsetenv("SEALOS_JWT_KEY")) + } + }) +} + +func setupRedisForAdminKeyCacheTest(t *testing.T, ctx context.Context) (*redis.Client, func()) { + t.Helper() + + server, err := miniredis.Run() + require.NoError(t, err) + + client := redis.NewClient(&redis.Options{ + Addr: server.Addr(), + DB: 0, + }) + require.NoError(t, client.Ping(ctx).Err()) + + cleanup := func() { + _ = client.Close() + server.Close() + } + + return client, cleanup +}