Skip to content

Commit 25044cd

Browse files
authored
feat: web search adaptor plugin (labring#216)
* feat: web search adaptor plugin * fix: ci lint * fix: remove web_search_options field on web search plugin * docs: add web search docs
1 parent 5431d0d commit 25044cd

107 files changed

Lines changed: 2138 additions & 249 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/common/gin.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) {
4040
return
4141
}
4242

43+
func SetRequestBody(req *http.Request, body []byte) {
44+
ctx := req.Context()
45+
bufCtx := context.WithValue(ctx, RequestBodyKey{}, body)
46+
*req = *req.WithContext(bufCtx)
47+
}
48+
4349
func GetRequestBody(req *http.Request) ([]byte, error) {
4450
contentType := req.Header.Get("Content-Type")
4551
if contentType == "application/x-www-form-urlencoded" ||
@@ -78,9 +84,7 @@ func GetRequestBody(req *http.Request) ([]byte, error) {
7884
if err != nil {
7985
return nil, fmt.Errorf("request body read failed: %w", err)
8086
}
81-
ctx := req.Context()
82-
bufCtx := context.WithValue(ctx, RequestBodyKey{}, buf)
83-
*req = *req.WithContext(bufCtx)
87+
SetRequestBody(req, buf)
8488
return buf, nil
8589
}
8690

core/controller/channel-test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ import (
3333
const channelTestRequestID = "channel-test"
3434

3535
var (
36-
modelConfigCache map[string]*model.ModelConfig = make(map[string]*model.ModelConfig)
36+
modelConfigCache map[string]model.ModelConfig = make(map[string]model.ModelConfig)
3737
modelConfigCacheOnce sync.Once
3838
)
3939

40-
func guessModelConfig(model string) *model.ModelConfig {
40+
func guessModelConfig(modelName string) model.ModelConfig {
4141
modelConfigCacheOnce.Do(func() {
4242
for _, c := range adaptors.ChannelAdaptor {
4343
for _, m := range c.GetModelList() {
@@ -48,10 +48,10 @@ func guessModelConfig(model string) *model.ModelConfig {
4848
}
4949
})
5050

51-
if cachedConfig, ok := modelConfigCache[model]; ok {
51+
if cachedConfig, ok := modelConfigCache[modelName]; ok {
5252
return cachedConfig
5353
}
54-
return nil
54+
return model.ModelConfig{}
5555
}
5656

5757
// testSingleModel tests a single model in the channel
@@ -62,7 +62,7 @@ func testSingleModel(mc *model.ModelCaches, channel *model.Channel, modelName st
6262
}
6363
if modelConfig.Type == mode.Unknown {
6464
newModelConfig := guessModelConfig(modelName)
65-
if newModelConfig != nil {
65+
if newModelConfig.Type != mode.Unknown {
6666
modelConfig = newModelConfig
6767
}
6868
}

core/controller/dashboard.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func GetGroupDashboardModels(c *gin.Context) {
270270
}) {
271271
continue
272272
}
273-
newEnabledModelConfigs = append(newEnabledModelConfigs, middleware.GetGroupAdjustedModelConfig(groupCache, *mc))
273+
newEnabledModelConfigs = append(newEnabledModelConfigs, middleware.GetGroupAdjustedModelConfig(groupCache, mc))
274274
}
275275
}
276276
middleware.SuccessResponse(c, newEnabledModelConfigs)

core/controller/model.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ func (c *BuiltinModelConfig) MarshalJSON() ([]byte, error) {
5555
})
5656
}
5757

58-
func SortBuiltinModelConfigsFunc(i, j *BuiltinModelConfig) int {
59-
return model.SortModelConfigsFunc((*model.ModelConfig)(i), (*model.ModelConfig)(j))
58+
func SortBuiltinModelConfigsFunc(i, j BuiltinModelConfig) int {
59+
return model.SortModelConfigsFunc((model.ModelConfig)(i), (model.ModelConfig)(j))
6060
}
6161

6262
var (
63-
builtinModels []*BuiltinModelConfig
63+
builtinModels []BuiltinModelConfig
6464
builtinModelsMap map[string]*OpenAIModels
65-
builtinChannelType2Models map[model.ChannelType][]*BuiltinModelConfig
65+
builtinChannelType2Models map[model.ChannelType][]BuiltinModelConfig
6666
)
6767

6868
var permission = []OpenAIModelPermission{
@@ -83,12 +83,12 @@ var permission = []OpenAIModelPermission{
8383
}
8484

8585
func init() {
86-
builtinChannelType2Models = make(map[model.ChannelType][]*BuiltinModelConfig)
86+
builtinChannelType2Models = make(map[model.ChannelType][]BuiltinModelConfig)
8787
builtinModelsMap = make(map[string]*OpenAIModels)
8888
// https://platform.openai.com/docs/models/model-endpoint-compatibility
8989
for i, adaptor := range adaptors.ChannelAdaptor {
9090
modelNames := adaptor.GetModelList()
91-
builtinChannelType2Models[i] = make([]*BuiltinModelConfig, len(modelNames))
91+
builtinChannelType2Models[i] = make([]BuiltinModelConfig, len(modelNames))
9292
for idx, _model := range modelNames {
9393
if _model.Owner == "" {
9494
_model.Owner = model.ModelOwner(i.String())
@@ -103,11 +103,11 @@ func init() {
103103
Root: _model.Model,
104104
Parent: nil,
105105
}
106-
builtinModels = append(builtinModels, (*BuiltinModelConfig)(_model))
106+
builtinModels = append(builtinModels, (BuiltinModelConfig)(_model))
107107
} else if v.OwnedBy != string(_model.Owner) {
108108
log.Fatalf("model %s owner mismatch, expect %s, actual %s", _model.Model, string(_model.Owner), v.OwnedBy)
109109
}
110-
builtinChannelType2Models[i][idx] = (*BuiltinModelConfig)(_model)
110+
builtinChannelType2Models[i][idx] = (BuiltinModelConfig)(_model)
111111
}
112112
}
113113
for _, models := range builtinChannelType2Models {

core/controller/modelconfig.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func SearchModelConfigs(c *gin.Context) {
112112
type SaveModelConfigsRequest struct {
113113
CreatedAt int64 `json:"created_at"`
114114
UpdatedAt int64 `json:"updated_at"`
115-
*model.ModelConfig
115+
model.ModelConfig
116116
}
117117

118118
// SaveModelConfigs godoc
@@ -126,12 +126,12 @@ type SaveModelConfigsRequest struct {
126126
// @Success 200 {object} middleware.APIResponse
127127
// @Router /api/model_configs/ [post]
128128
func SaveModelConfigs(c *gin.Context) {
129-
var configs []*SaveModelConfigsRequest
129+
var configs []SaveModelConfigsRequest
130130
if err := c.ShouldBindJSON(&configs); err != nil {
131131
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
132132
return
133133
}
134-
modelConfigs := make([]*model.ModelConfig, len(configs))
134+
modelConfigs := make([]model.ModelConfig, len(configs))
135135
for i, config := range configs {
136136
modelConfigs[i] = config.ModelConfig
137137
}

core/controller/relay-controller.go

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@ import (
3131
"github.com/labring/aiproxy/core/relay/meta"
3232
"github.com/labring/aiproxy/core/relay/mode"
3333
relaymodel "github.com/labring/aiproxy/core/relay/model"
34+
"github.com/labring/aiproxy/core/relay/plugin"
35+
websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
3436
log "github.com/sirupsen/logrus"
3537
)
3638

3739
// https://platform.openai.com/docs/api-reference/chat
3840

3941
type (
4042
RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
41-
GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
42-
GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
43+
GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
44+
GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
4345
)
4446

4547
type RelayController struct {
@@ -48,6 +50,7 @@ type RelayController struct {
4850
Handler RelayHandler
4951
}
5052

53+
// TODO: convert to plugin
5154
type wrapAdaptor struct {
5255
adaptor.Adaptor
5356
}
@@ -163,7 +166,13 @@ func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
163166
}
164167
}
165168

166-
return controller.Handle(&wrapAdaptor{adaptor}, c, meta)
169+
a := plugin.WrapperAdaptor(&wrapAdaptor{adaptor},
170+
websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
171+
return getWebSearchChannel(c, modelName)
172+
}),
173+
)
174+
175+
return controller.Handle(a, c, meta)
167176
}
168177

169178
func relayController(m mode.Mode) RelayController {
@@ -318,9 +327,17 @@ var (
318327

319328
func GetRandomChannel(mc *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, []*model.Channel, error) {
320329
channelMap := make(map[int]*model.Channel)
321-
for _, set := range availableSet {
322-
for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
323-
channelMap[channel.ID] = channel
330+
if len(availableSet) != 0 {
331+
for _, set := range availableSet {
332+
for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
333+
channelMap[channel.ID] = channel
334+
}
335+
}
336+
} else {
337+
for _, sets := range mc.EnabledModel2ChannelsBySet {
338+
for _, channel := range sets[modelName] {
339+
channelMap[channel.ID] = channel
340+
}
324341
}
325342
}
326343
migratedChannels := make([]*model.Channel, 0, len(channelMap))
@@ -403,12 +420,11 @@ func NewMetaByContext(c *gin.Context, channel *model.Channel, mode mode.Mode, op
403420
}
404421

405422
func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
406-
log := middleware.GetLogger(c)
407423
requestModel := middleware.GetRequestModel(c)
408424
mc := middleware.GetModelConfig(c)
409425

410426
// Get initial channel
411-
initialChannel, err := getInitialChannel(c, requestModel, log)
427+
initialChannel, err := getInitialChannel(c, requestModel)
412428
if err != nil || initialChannel == nil || initialChannel.channel == nil {
413429
middleware.AbortLogWithMessageWithMode(mode, c,
414430
http.StatusServiceUnavailable,
@@ -486,7 +502,7 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
486502
)
487503

488504
// Retry loop
489-
retryLoop(c, mode, retryState, relayController.Handler, log)
505+
retryLoop(c, mode, retryState, relayController.Handler)
490506
}
491507

492508
// recordResult records the consumption for the final result
@@ -572,7 +588,8 @@ type initialChannel struct {
572588
migratedChannels []*model.Channel
573589
}
574590

575-
func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initialChannel, error) {
591+
func getInitialChannel(c *gin.Context, modelName string) (*initialChannel, error) {
592+
log := middleware.GetLogger(c)
576593
if channel := middleware.GetChannel(c); channel != nil {
577594
log.Data["designated_channel"] = "true"
578595
return &initialChannel{channel: channel, designatedChannel: true}, nil
@@ -607,6 +624,29 @@ func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initi
607624
}, nil
608625
}
609626

627+
func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
628+
log := middleware.GetLogger(c)
629+
mc := middleware.GetModelCaches(c)
630+
631+
ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
632+
if err != nil {
633+
log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
634+
}
635+
log.Debugf("%s model banned channels: %+v", modelName, ids)
636+
637+
errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
638+
if err != nil {
639+
log.Errorf("get channel model error rates failed: %+v", err)
640+
}
641+
642+
channel, _, err := getChannelWithFallback(mc, nil, modelName, errorRates, ids...)
643+
if err != nil {
644+
return nil, err
645+
}
646+
647+
return channel, nil
648+
}
649+
610650
func handleRelayResult(c *gin.Context, bizErr adaptor.Error, retry bool, retryTimes int) (done bool) {
611651
if bizErr == nil {
612652
return true
@@ -645,7 +685,9 @@ func initRetryState(retryTimes int, channel *initialChannel, meta *meta.Meta, re
645685
return state
646686
}
647687

648-
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler, log *log.Entry) {
688+
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
689+
log := middleware.GetLogger(c)
690+
649691
// do not use for i := range state.retryTimes, because the retryTimes is constant
650692
i := 0
651693

core/middleware/distributor.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ func UpdateGroupModelTokennameTokensRequest(c *gin.Context, tpm, tps int64) {
132132
// log.Data["tps"] = strconv.FormatInt(tps, 10)
133133
}
134134

135-
func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model.ModelConfig, tokenName string) error {
135+
func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc model.ModelConfig, tokenName string) error {
136136
log := GetLogger(c)
137137

138-
adjustedModelConfig := GetGroupAdjustedModelConfig(group, *mc)
138+
adjustedModelConfig := GetGroupAdjustedModelConfig(group, mc)
139139

140140
groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(c.Request.Context(), group.ID, mc.Model, adjustedModelConfig.RPM)
141141
UpdateGroupModelRequest(c, group, groupModelCount+groupModelOverLimitCount, groupModelSecondCount)
@@ -461,8 +461,8 @@ func GetRequestMetadata(c *gin.Context) map[string]string {
461461
return c.GetStringMapString(RequestMetadata)
462462
}
463463

464-
func GetModelConfig(c *gin.Context) *model.ModelConfig {
465-
return c.MustGet(ModelConfig).(*model.ModelConfig)
464+
func GetModelConfig(c *gin.Context) model.ModelConfig {
465+
return c.MustGet(ModelConfig).(model.ModelConfig)
466466
}
467467

468468
func NewMetaByContext(c *gin.Context,

core/model/cache.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ func CacheGetPublicMCPReusingParam(mcpID, groupID string) (*PublicMCPReusingPara
672672

673673
//nolint:revive
674674
type ModelConfigCache interface {
675-
GetModelConfig(model string) (*ModelConfig, bool)
675+
GetModelConfig(model string) (ModelConfig, bool)
676676
}
677677

678678
// read-only cache
@@ -684,9 +684,9 @@ type ModelCaches struct {
684684
// map[set][]model
685685
EnabledModelsBySet map[string][]string
686686
// map[set][]modelconfig
687-
EnabledModelConfigsBySet map[string][]*ModelConfig
687+
EnabledModelConfigsBySet map[string][]ModelConfig
688688
// map[model]modelconfig
689-
EnabledModelConfigsMap map[string]*ModelConfig
689+
EnabledModelConfigsMap map[string]ModelConfig
690690

691691
// map[set]map[model][]channel
692692
EnabledModel2ChannelsBySet map[string]map[string][]*Channel
@@ -811,10 +811,10 @@ func LoadChannelByID(id int) (*Channel, error) {
811811
var _ ModelConfigCache = (*modelConfigMapCache)(nil)
812812

813813
type modelConfigMapCache struct {
814-
modelConfigMap map[string]*ModelConfig
814+
modelConfigMap map[string]ModelConfig
815815
}
816816

817-
func (m *modelConfigMapCache) GetModelConfig(model string) (*ModelConfig, bool) {
817+
func (m *modelConfigMapCache) GetModelConfig(model string) (ModelConfig, bool) {
818818
config, ok := m.modelConfigMap[model]
819819
return config, ok
820820
}
@@ -825,7 +825,7 @@ type disabledModelConfigCache struct {
825825
modelConfigs ModelConfigCache
826826
}
827827

828-
func (d *disabledModelConfigCache) GetModelConfig(model string) (*ModelConfig, bool) {
828+
func (d *disabledModelConfigCache) GetModelConfig(model string) (ModelConfig, bool) {
829829
if config, ok := d.modelConfigs.GetModelConfig(model); ok {
830830
return config, true
831831
}
@@ -837,7 +837,7 @@ func initializeModelConfigCache() (ModelConfigCache, error) {
837837
if err != nil {
838838
return nil, err
839839
}
840-
newModelConfigMap := make(map[string]*ModelConfig)
840+
newModelConfigMap := make(map[string]ModelConfig)
841841
for _, modelConfig := range modelConfigs {
842842
newModelConfigMap[modelConfig.Model] = modelConfig
843843
}
@@ -905,16 +905,16 @@ func sortChannelsByPriorityBySet(modelMapBySet map[string]map[string][]*Channel)
905905

906906
func buildEnabledModelsBySet(modelMapBySet map[string]map[string][]*Channel, modelConfigCache ModelConfigCache) (
907907
map[string][]string,
908-
map[string][]*ModelConfig,
909-
map[string]*ModelConfig,
908+
map[string][]ModelConfig,
909+
map[string]ModelConfig,
910910
) {
911911
modelsBySet := make(map[string][]string)
912-
modelConfigsBySet := make(map[string][]*ModelConfig)
913-
modelConfigsMap := make(map[string]*ModelConfig)
912+
modelConfigsBySet := make(map[string][]ModelConfig)
913+
modelConfigsMap := make(map[string]ModelConfig)
914914

915915
for set, modelMap := range modelMapBySet {
916916
models := make([]string, 0)
917-
configs := make([]*ModelConfig, 0)
917+
configs := make([]ModelConfig, 0)
918918
appended := make(map[string]struct{})
919919

920920
for model := range modelMap {
@@ -940,7 +940,7 @@ func buildEnabledModelsBySet(modelMapBySet map[string]map[string][]*Channel, mod
940940
return modelsBySet, modelConfigsBySet, modelConfigsMap
941941
}
942942

943-
func SortModelConfigsFunc(i, j *ModelConfig) int {
943+
func SortModelConfigsFunc(i, j ModelConfig) int {
944944
if i.Owner != j.Owner {
945945
if natural.Less(string(i.Owner), string(j.Owner)) {
946946
return -1

0 commit comments

Comments
 (0)