Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 108 additions & 15 deletions core/common/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,134 @@ package config
import (
"os"
"strings"
"sync/atomic"

"github.com/labring/aiproxy/core/common/env"
)

type tokenVariants struct {
raw string
bearer string
sk string
bearerSK string
}

var (
DebugEnabled bool
DebugSQLEnabled bool
DisableAutoMigrateDB bool
AdminKey string
WebPath string
DisableWeb bool
DisableWebRoot bool
FfmpegEnabled bool
InternalToken string
DisableModelConfig bool
Redis string
RedisKeyPrefix string
ConfigFilePath string
DebugEnabled bool
DebugSQLEnabled bool
DisableAutoMigrateDB bool
AdminKey string
DynamicRemoteAdminKey string
WebPath string
DisableWeb bool
DisableWebRoot bool
FfmpegEnabled bool
InternalToken string
DisableModelConfig bool
Redis string
RedisKeyPrefix string
ConfigFilePath string

// OnCall Lark configuration for urgent alerts
OnCallLarkAppID string
OnCallLarkAppSecret string
OnCallLarkOpenIDs []string // comma-separated open IDs

adminKeyState atomic.Value
dynamicRemoteAdminKeyState atomic.Value
internalTokenState atomic.Value
)

func buildTokenVariants(token string) tokenVariants {
if token == "" {
return tokenVariants{}
}

sk := "sk-" + token
return tokenVariants{
raw: token,
bearer: "Bearer " + token,
sk: sk,
bearerSK: "Bearer " + sk,
}
}

func SetAdminKey(key string) {
v := buildTokenVariants(key)
AdminKey = v.raw
adminKeyState.Store(v)
}

func SetDynamicRemoteAdminKey(key string) {
v := buildTokenVariants(key)
DynamicRemoteAdminKey = v.raw
dynamicRemoteAdminKeyState.Store(v)
}

func SetInternalToken(token string) {
v := buildTokenVariants(token)
InternalToken = v.raw
internalTokenState.Store(v)
}

func GetAdminKey() string {
v, _ := adminKeyState.Load().(tokenVariants)
return v.raw
}

func GetDynamicRemoteAdminKey() string {
v, _ := dynamicRemoteAdminKeyState.Load().(tokenVariants)
return v.raw
}

func GetEffectiveAdminKey() string {
if key := GetDynamicRemoteAdminKey(); key != "" {
return key
}

return GetAdminKey()
}

func GetInternalToken() string {
v, _ := internalTokenState.Load().(tokenVariants)
return v.raw
}

func MatchAdminKey(raw string) bool {
return matchTokenVariants(raw, adminKeyState)
}

func MatchEffectiveAdminKey(raw string) bool {
if GetDynamicRemoteAdminKey() != "" {
return matchTokenVariants(raw, dynamicRemoteAdminKeyState)
}

return matchTokenVariants(raw, adminKeyState)
}

func matchTokenVariants(raw string, state atomic.Value) bool {
v, _ := state.Load().(tokenVariants)
Comment thread
Iweisc marked this conversation as resolved.
return raw != "" && (raw == v.raw ||
raw == v.bearer ||
raw == v.sk ||
raw == v.bearerSK)
}

func MatchInternalToken(raw string) bool {
return matchTokenVariants(raw, internalTokenState)
}

func ReloadEnv() {
DebugEnabled = env.Bool("DEBUG", false)
DebugSQLEnabled = env.Bool("DEBUG_SQL", false)
DisableAutoMigrateDB = env.Bool("DISABLE_AUTO_MIGRATE_DB", false)
AdminKey = os.Getenv("ADMIN_KEY")
SetAdminKey(os.Getenv("ADMIN_KEY"))
SetDynamicRemoteAdminKey("")
WebPath = os.Getenv("WEB_PATH")
DisableWeb = env.Bool("DISABLE_WEB", false)
DisableWebRoot = env.Bool("DISABLE_WEB_ROOT", false)
FfmpegEnabled = env.Bool("FFMPEG_ENABLED", false)
InternalToken = os.Getenv("INTERNAL_TOKEN")
SetInternalToken(os.Getenv("INTERNAL_TOKEN"))
DisableModelConfig = env.Bool("DISABLE_MODEL_CONFIG", false)
Redis = env.String("REDIS", os.Getenv("REDIS_CONN_STRING"))
RedisKeyPrefix = os.Getenv("REDIS_KEY_PREFIX")
Expand Down
9 changes: 9 additions & 0 deletions core/common/trylock/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package trylock

func InjectMemLockValueForTest(key string, value any) {
memRecord.Store(key, value)
}

func ResetMemLockValueForTest(key string) {
memRecord.Delete(key)
}
10 changes: 7 additions & 3 deletions core/common/trylock/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package trylock
import (
"context"
"errors"
"fmt"
"sync"
"time"

Expand All @@ -24,7 +25,11 @@ func cleanMemLock() {
for now := range ticker.C {
memRecord.Range(func(key, value any) bool {
exp, ok := value.(time.Time)
if !ok || now.After(exp) {
if !ok {
panic(fmt.Sprintf("mem lock type mismatch: %T", value))
}

if now.After(exp) {
memRecord.CompareAndDelete(key, value)
}

Expand All @@ -45,8 +50,7 @@ func MemLock(key string, expiration time.Duration) bool {

oldExpiration, ok := actual.(time.Time)
if !ok {
memRecord.CompareAndDelete(key, actual)
continue
panic(fmt.Sprintf("mem lock type mismatch: %T", actual))
}

if now.After(oldExpiration) {
Expand Down
15 changes: 15 additions & 0 deletions core/common/trylock/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,18 @@ func TestMemLock(t *testing.T) {
t.Error("Expected true, Got false")
}
}

func TestMemLockPanicsOnTypeMismatch(t *testing.T) {
trylock.InjectMemLockValueForTest("panic-key", "bad")
t.Cleanup(func() {
trylock.ResetMemLockValueForTest("panic-key")
})

defer func() {
if recover() == nil {
t.Fatal("expected panic on mem lock type mismatch")
}
}()

trylock.MemLock("panic-key", time.Second)
}
2 changes: 2 additions & 0 deletions core/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.26

require (
cloud.google.com/go/iam v1.9.0
github.com/alicebob/miniredis/v2 v2.37.0
github.com/aws/aws-sdk-go-v2 v1.41.6
github.com/aws/aws-sdk-go-v2/credentials v1.19.15
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.5
Expand Down Expand Up @@ -165,6 +166,7 @@ require (
github.com/ugorji/go/codec v1.3.1 // indirect
github.com/woodsbury/decimal128 v1.4.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.1 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
Expand Down
4 changes: 4 additions & 0 deletions core/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA
github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk=
github.com/PuerkitoBio/goquery v1.12.0 h1:pAcL4g3WRXekcB9AU/y1mbKez2dbY2AajVhtkO8RIBo=
github.com/PuerkitoBio/goquery v1.12.0/go.mod h1:802ej+gV2y7bbIhOIoPY5sT183ZW0YFofScC4q/hIpQ=
github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68=
github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
Expand Down Expand Up @@ -382,6 +384,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
Expand Down
10 changes: 6 additions & 4 deletions core/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ func main() {

config.ReloadEnv()

if err := ensureAdminKey(); err != nil {
log.Warn("failed to ensure AdminKey: " + err.Error())
}

common.InitLog(log.StandardLogger(), config.DebugEnabled)

printLoadedEnvFiles()
Expand Down Expand Up @@ -94,6 +90,12 @@ func main() {
go task.RedisHealthCheckTask(ctx)
}

if task.AdminKeyCacheEnabled() {
log.Info("admin key cache task started")

go task.AdminKeyCacheTask(ctx)
}

log.Info("update channels balance task started")

go controller.UpdateChannelsBalance(time.Minute * 10)
Expand Down
50 changes: 27 additions & 23 deletions core/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"maps"
"net/http"
"slices"
"strings"

"github.com/gin-gonic/gin"
"github.com/labring/aiproxy/core/common"
Expand Down Expand Up @@ -39,7 +38,7 @@ func ErrorResponse(c *gin.Context, code int, message string) {
}

func AdminAuth(c *gin.Context) {
if config.AdminKey == "" {
if config.GetEffectiveAdminKey() == "" {
ErrorResponse(c, http.StatusUnauthorized, "unauthorized, admin key is not set")
c.Abort()
return
Expand All @@ -50,17 +49,14 @@ func AdminAuth(c *gin.Context) {
accessToken = c.Query("key")
}

accessToken = strings.TrimPrefix(accessToken, "Bearer ")
accessToken = strings.TrimPrefix(accessToken, "sk-")

if accessToken != config.AdminKey {
if !config.MatchEffectiveAdminKey(accessToken) {
ErrorResponse(c, http.StatusUnauthorized, "unauthorized, no access token provided")
c.Abort()
return
}

c.Set(Token, &model.TokenCache{
Key: config.AdminKey,
Key: config.GetEffectiveAdminKey(),
})

group := c.Param("group")
Expand All @@ -75,32 +71,20 @@ func AdminAuth(c *gin.Context) {
func TokenAuth(c *gin.Context) {
log := common.GetLogger(c)

key := c.Request.Header.Get("Authorization")
if key == "" {
key = c.Request.Header.Get("X-Api-Key")
}

if key == "" {
key = c.Request.Header.Get("X-Goog-Api-Key")
}

key = strings.TrimPrefix(
strings.TrimPrefix(key, "Bearer "),
"sk-",
)
key := requestToken(c.Request.Header)

var (
token model.TokenCache
useInternalToken bool
)

if config.AdminKey != "" && config.AdminKey == key ||
config.InternalToken != "" && config.InternalToken == key {
if config.MatchEffectiveAdminKey(key) || config.MatchInternalToken(key) {
token = model.TokenCache{
Key: key,
Key: normalizeTokenKey(key),
}
useInternalToken = true
} else {
key = normalizeTokenKey(key)
tokenCache, err := model.GetAndValidateToken(key)
if err != nil {
oncall.AlertDBError("TokenAuth", err)
Expand Down Expand Up @@ -307,3 +291,23 @@ func maskTokenKey(key string) string {
}
return key[:4] + "*****" + key[len(key)-4:]
}

func requestToken(headers http.Header) string {
if key := headers.Get("Authorization"); key != "" {
return key
}
if key := headers.Get("X-Api-Key"); key != "" {
return key
}
return headers.Get("X-Goog-Api-Key")
}

func normalizeTokenKey(key string) string {
if len(key) >= 7 && key[:7] == "Bearer " {
key = key[7:]
}
if len(key) >= 3 && key[:3] == "sk-" {
key = key[3:]
}
return key
}
Loading