@@ -31,15 +31,17 @@ import (
3131 "github.com/labring/aiproxy/core/relay/meta"
3232 "github.com/labring/aiproxy/core/relay/mode"
3333 relaymodel "github.com/labring/aiproxy/core/relay/model"
34+ "github.com/labring/aiproxy/core/relay/plugin"
35+ websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
3436 log "github.com/sirupsen/logrus"
3537)
3638
3739// https://platform.openai.com/docs/api-reference/chat
3840
3941type (
4042 RelayHandler func (* gin.Context , * meta.Meta ) * controller.HandleResult
41- GetRequestUsage func (* gin.Context , * model.ModelConfig ) (model.Usage , error )
42- GetRequestPrice func (* gin.Context , * model.ModelConfig ) (model.Price , error )
43+ GetRequestUsage func (* gin.Context , model.ModelConfig ) (model.Usage , error )
44+ GetRequestPrice func (* gin.Context , model.ModelConfig ) (model.Price , error )
4345)
4446
4547type RelayController struct {
@@ -48,6 +50,7 @@ type RelayController struct {
4850 Handler RelayHandler
4951}
5052
53+ // TODO: convert to plugin
5154type wrapAdaptor struct {
5255 adaptor.Adaptor
5356}
@@ -163,7 +166,13 @@ func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
163166 }
164167 }
165168
166- return controller .Handle (& wrapAdaptor {adaptor }, c , meta )
169+ a := plugin .WrapperAdaptor (& wrapAdaptor {adaptor },
170+ websearch .NewWebSearchPlugin (func (modelName string ) (* model.Channel , error ) {
171+ return getWebSearchChannel (c , modelName )
172+ }),
173+ )
174+
175+ return controller .Handle (a , c , meta )
167176}
168177
169178func relayController (m mode.Mode ) RelayController {
@@ -318,9 +327,17 @@ var (
318327
319328func GetRandomChannel (mc * model.ModelCaches , availableSet []string , modelName string , errorRates map [int64 ]float64 , ignoreChannel ... int64 ) (* model.Channel , []* model.Channel , error ) {
320329 channelMap := make (map [int ]* model.Channel )
321- for _ , set := range availableSet {
322- for _ , channel := range mc.EnabledModel2ChannelsBySet [set ][modelName ] {
323- channelMap [channel .ID ] = channel
330+ if len (availableSet ) != 0 {
331+ for _ , set := range availableSet {
332+ for _ , channel := range mc.EnabledModel2ChannelsBySet [set ][modelName ] {
333+ channelMap [channel .ID ] = channel
334+ }
335+ }
336+ } else {
337+ for _ , sets := range mc .EnabledModel2ChannelsBySet {
338+ for _ , channel := range sets [modelName ] {
339+ channelMap [channel .ID ] = channel
340+ }
324341 }
325342 }
326343 migratedChannels := make ([]* model.Channel , 0 , len (channelMap ))
@@ -403,12 +420,11 @@ func NewMetaByContext(c *gin.Context, channel *model.Channel, mode mode.Mode, op
403420}
404421
405422func relay (c * gin.Context , mode mode.Mode , relayController RelayController ) {
406- log := middleware .GetLogger (c )
407423 requestModel := middleware .GetRequestModel (c )
408424 mc := middleware .GetModelConfig (c )
409425
410426 // Get initial channel
411- initialChannel , err := getInitialChannel (c , requestModel , log )
427+ initialChannel , err := getInitialChannel (c , requestModel )
412428 if err != nil || initialChannel == nil || initialChannel .channel == nil {
413429 middleware .AbortLogWithMessageWithMode (mode , c ,
414430 http .StatusServiceUnavailable ,
@@ -486,7 +502,7 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
486502 )
487503
488504 // Retry loop
489- retryLoop (c , mode , retryState , relayController .Handler , log )
505+ retryLoop (c , mode , retryState , relayController .Handler )
490506}
491507
492508// recordResult records the consumption for the final result
@@ -572,7 +588,8 @@ type initialChannel struct {
572588 migratedChannels []* model.Channel
573589}
574590
575- func getInitialChannel (c * gin.Context , modelName string , log * log.Entry ) (* initialChannel , error ) {
591+ func getInitialChannel (c * gin.Context , modelName string ) (* initialChannel , error ) {
592+ log := middleware .GetLogger (c )
576593 if channel := middleware .GetChannel (c ); channel != nil {
577594 log .Data ["designated_channel" ] = "true"
578595 return & initialChannel {channel : channel , designatedChannel : true }, nil
@@ -607,6 +624,29 @@ func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initi
607624 }, nil
608625}
609626
627+ func getWebSearchChannel (c * gin.Context , modelName string ) (* model.Channel , error ) {
628+ log := middleware .GetLogger (c )
629+ mc := middleware .GetModelCaches (c )
630+
631+ ids , err := monitor .GetBannedChannelsWithModel (c .Request .Context (), modelName )
632+ if err != nil {
633+ log .Errorf ("get %s auto banned channels failed: %+v" , modelName , err )
634+ }
635+ log .Debugf ("%s model banned channels: %+v" , modelName , ids )
636+
637+ errorRates , err := monitor .GetModelChannelErrorRate (c .Request .Context (), modelName )
638+ if err != nil {
639+ log .Errorf ("get channel model error rates failed: %+v" , err )
640+ }
641+
642+ channel , _ , err := getChannelWithFallback (mc , nil , modelName , errorRates , ids ... )
643+ if err != nil {
644+ return nil , err
645+ }
646+
647+ return channel , nil
648+ }
649+
610650func handleRelayResult (c * gin.Context , bizErr adaptor.Error , retry bool , retryTimes int ) (done bool ) {
611651 if bizErr == nil {
612652 return true
@@ -645,7 +685,9 @@ func initRetryState(retryTimes int, channel *initialChannel, meta *meta.Meta, re
645685 return state
646686}
647687
648- func retryLoop (c * gin.Context , mode mode.Mode , state * retryState , relayController RelayHandler , log * log.Entry ) {
688+ func retryLoop (c * gin.Context , mode mode.Mode , state * retryState , relayController RelayHandler ) {
689+ log := middleware .GetLogger (c )
690+
649691 // do not use for i := range state.retryTimes, because the retryTimes is constant
650692 i := 0
651693
0 commit comments