Skip to content

Commit 739d199

Browse files
committed
feat: channel adaptor spec config support
1 parent 4fdcca4 commit 739d199

31 files changed

Lines changed: 425 additions & 229 deletions

File tree

core/controller/channel-billing.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func updateAllChannelsBalance() error {
118118
func UpdateAllChannelsBalance(c *gin.Context) {
119119
err := updateAllChannelsBalance()
120120
if err != nil {
121-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
121+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
122122
return
123123
}
124124
middleware.SuccessResponse(c, nil)

core/controller/channel.go

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
package controller
22

33
import (
4+
"errors"
45
"fmt"
56
"maps"
67
"net/http"
78
"slices"
89
"strconv"
910
"strings"
1011

12+
"github.com/bytedance/sonic/ast"
1113
"github.com/gin-gonic/gin"
1214
"github.com/labring/aiproxy/core/middleware"
1315
"github.com/labring/aiproxy/core/model"
1416
"github.com/labring/aiproxy/core/monitor"
15-
"github.com/labring/aiproxy/core/relay/adaptor"
1617
"github.com/labring/aiproxy/core/relay/adaptors"
1718
log "github.com/sirupsen/logrus"
1819
)
@@ -57,7 +58,7 @@ func GetChannels(c *gin.Context) {
5758
order := c.Query("order")
5859
channels, total, err := model.GetChannels(page, perPage, id, name, key, channelType, baseURL, order)
5960
if err != nil {
60-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
61+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
6162
return
6263
}
6364
middleware.SuccessResponse(c, gin.H{
@@ -78,7 +79,7 @@ func GetChannels(c *gin.Context) {
7879
func GetAllChannels(c *gin.Context) {
7980
channels, err := model.GetAllChannels()
8081
if err != nil {
81-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
82+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
8283
return
8384
}
8485
middleware.SuccessResponse(c, channels)
@@ -99,21 +100,21 @@ func AddChannels(c *gin.Context) {
99100
channels := make([]*AddChannelRequest, 0)
100101
err := c.ShouldBindJSON(&channels)
101102
if err != nil {
102-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
103+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
103104
return
104105
}
105106
_channels := make([]*model.Channel, 0, len(channels))
106107
for _, channel := range channels {
107108
channels, err := channel.ToChannels()
108109
if err != nil {
109-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
110+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
110111
return
111112
}
112113
_channels = append(_channels, channels...)
113114
}
114115
err = model.BatchInsertChannels(_channels)
115116
if err != nil {
116-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
117+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
117118
return
118119
}
119120
middleware.SuccessResponse(c, nil)
@@ -148,7 +149,7 @@ func SearchChannels(c *gin.Context) {
148149
order := c.Query("order")
149150
channels, total, err := model.SearchChannels(keyword, page, perPage, id, name, key, channelType, baseURL, order)
150151
if err != nil {
151-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
152+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
152153
return
153154
}
154155
middleware.SuccessResponse(c, gin.H{
@@ -170,12 +171,12 @@ func SearchChannels(c *gin.Context) {
170171
func GetChannel(c *gin.Context) {
171172
id, err := strconv.Atoi(c.Param("id"))
172173
if err != nil {
173-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
174+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
174175
return
175176
}
176177
channel, err := model.GetChannelByID(id)
177178
if err != nil {
178-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
179+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
179180
return
180181
}
181182
middleware.SuccessResponse(c, channel)
@@ -200,7 +201,7 @@ func (r *AddChannelRequest) ToChannel() (*model.Channel, error) {
200201
if !ok {
201202
return nil, fmt.Errorf("invalid channel type: %d", r.Type)
202203
}
203-
if validator, ok := a.(adaptor.KeyValidator); ok {
204+
if validator := adaptors.GetKeyValidator(a); validator != nil {
204205
err := validator.ValidateKey(r.Key)
205206
if err != nil {
206207
keyHelp := validator.KeyHelp()
@@ -210,6 +211,37 @@ func (r *AddChannelRequest) ToChannel() (*model.Channel, error) {
210211
return nil, fmt.Errorf("%s [%s(%d)] invalid key: %w, %s", r.Name, r.Type.String(), r.Type, err, keyHelp)
211212
}
212213
}
214+
if r.Config != nil {
215+
for key, template := range adaptors.GetConfigTemplates(a) {
216+
v, err := r.Config.Get(key)
217+
if err != nil {
218+
if errors.Is(err, ast.ErrNotExist) {
219+
if template.Required {
220+
return nil, fmt.Errorf("config %s is required: %w", key, err)
221+
}
222+
continue
223+
}
224+
return nil, fmt.Errorf("config %s is invalid: %w", key, err)
225+
}
226+
if !v.Exists() {
227+
if template.Required {
228+
return nil, fmt.Errorf("config %s is required: %w", key, err)
229+
}
230+
continue
231+
}
232+
if template.Validator != nil {
233+
i, err := v.Interface()
234+
if err != nil {
235+
return nil, fmt.Errorf("config %s is invalid: %w", key, err)
236+
}
237+
err = template.Validator(i)
238+
if err != nil {
239+
return nil, fmt.Errorf("config %s is invalid: %w", key, err)
240+
}
241+
}
242+
}
243+
}
244+
213245
return &model.Channel{
214246
Type: r.Type,
215247
Name: r.Name,
@@ -263,17 +295,17 @@ func AddChannel(c *gin.Context) {
263295
channel := AddChannelRequest{}
264296
err := c.ShouldBindJSON(&channel)
265297
if err != nil {
266-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
298+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
267299
return
268300
}
269301
channels, err := channel.ToChannels()
270302
if err != nil {
271-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
303+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
272304
return
273305
}
274306
err = model.BatchInsertChannels(channels)
275307
if err != nil {
276-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
308+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
277309
return
278310
}
279311
middleware.SuccessResponse(c, nil)
@@ -293,7 +325,7 @@ func DeleteChannel(c *gin.Context) {
293325
id, _ := strconv.Atoi(c.Param("id"))
294326
err := model.DeleteChannelByID(id)
295327
if err != nil {
296-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
328+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
297329
return
298330
}
299331
middleware.SuccessResponse(c, nil)
@@ -314,12 +346,12 @@ func DeleteChannels(c *gin.Context) {
314346
ids := []int{}
315347
err := c.ShouldBindJSON(&ids)
316348
if err != nil {
317-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
349+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
318350
return
319351
}
320352
err = model.DeleteChannelsByIDs(ids)
321353
if err != nil {
322-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
354+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
323355
return
324356
}
325357
middleware.SuccessResponse(c, nil)
@@ -340,29 +372,29 @@ func DeleteChannels(c *gin.Context) {
340372
func UpdateChannel(c *gin.Context) {
341373
idStr := c.Param("id")
342374
if idStr == "" {
343-
middleware.ErrorResponse(c, http.StatusOK, "id is required")
375+
middleware.ErrorResponse(c, http.StatusBadRequest, "id is required")
344376
return
345377
}
346378
id, err := strconv.Atoi(idStr)
347379
if err != nil {
348-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
380+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
349381
return
350382
}
351383
channel := AddChannelRequest{}
352384
err = c.ShouldBindJSON(&channel)
353385
if err != nil {
354-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
386+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
355387
return
356388
}
357389
ch, err := channel.ToChannel()
358390
if err != nil {
359-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
391+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
360392
return
361393
}
362394
ch.ID = id
363395
err = model.UpdateChannel(ch)
364396
if err != nil {
365-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
397+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
366398
return
367399
}
368400
err = monitor.ClearChannelAllModelErrors(c.Request.Context(), id)
@@ -394,12 +426,12 @@ func UpdateChannelStatus(c *gin.Context) {
394426
status := UpdateChannelStatusRequest{}
395427
err := c.ShouldBindJSON(&status)
396428
if err != nil {
397-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
429+
middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
398430
return
399431
}
400432
err = model.UpdateChannelStatusByID(id, status.Status)
401433
if err != nil {
402-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
434+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
403435
return
404436
}
405437
err = monitor.ClearChannelAllModelErrors(c.Request.Context(), id)

core/controller/dashboard.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func GetDashboard(c *gin.Context) {
169169

170170
dashboards, err := model.GetDashboardData(start, end, modelName, channelID, timeSpan, timezoneLocation)
171171
if err != nil {
172-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
172+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
173173
return
174174
}
175175

@@ -206,7 +206,7 @@ func GetDashboard(c *gin.Context) {
206206
func GetGroupDashboard(c *gin.Context) {
207207
group := c.Param("group")
208208
if group == "" || group == "*" {
209-
middleware.ErrorResponse(c, http.StatusOK, "invalid group parameter")
209+
middleware.ErrorResponse(c, http.StatusBadRequest, "invalid group parameter")
210210
return
211211
}
212212

@@ -220,7 +220,7 @@ func GetGroupDashboard(c *gin.Context) {
220220

221221
dashboards, err := model.GetGroupDashboardData(group, start, end, tokenName, modelName, timeSpan, timezoneLocation)
222222
if err != nil {
223-
middleware.ErrorResponse(c, http.StatusOK, "failed to get statistics")
223+
middleware.ErrorResponse(c, http.StatusInternalServerError, "failed to get statistics")
224224
return
225225
}
226226

@@ -247,15 +247,15 @@ func GetGroupDashboard(c *gin.Context) {
247247
func GetGroupDashboardModels(c *gin.Context) {
248248
group := c.Param("group")
249249
if group == "" || group == "*" {
250-
middleware.ErrorResponse(c, http.StatusOK, "invalid group parameter")
250+
middleware.ErrorResponse(c, http.StatusBadRequest, "invalid group parameter")
251251
return
252252
}
253253
groupCache, err := model.CacheGetGroup(group)
254254
if err != nil {
255255
if errors.Is(err, gorm.ErrRecordNotFound) {
256256
middleware.SuccessResponse(c, model.LoadModelCaches().EnabledModelConfigsBySet[model.ChannelDefaultSet])
257257
} else {
258-
middleware.ErrorResponse(c, http.StatusOK, fmt.Sprintf("failed to get group: %v", err))
258+
middleware.ErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("failed to get group: %v", err))
259259
}
260260
return
261261
}
@@ -295,7 +295,7 @@ func GetModelCostRank(c *gin.Context) {
295295
startTime, endTime := parseTimeRange(c)
296296
models, err := model.GetModelCostRank(group, channelID, startTime, endTime)
297297
if err != nil {
298-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
298+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
299299
return
300300
}
301301
middleware.SuccessResponse(c, models)
@@ -316,13 +316,13 @@ func GetModelCostRank(c *gin.Context) {
316316
func GetGroupModelCostRank(c *gin.Context) {
317317
group := c.Param("group")
318318
if group == "" || group == "*" {
319-
middleware.ErrorResponse(c, http.StatusOK, "invalid group parameter")
319+
middleware.ErrorResponse(c, http.StatusBadRequest, "invalid group parameter")
320320
return
321321
}
322322
startTime, endTime := parseTimeRange(c)
323323
models, err := model.GetModelCostRank(group, 0, startTime, endTime)
324324
if err != nil {
325-
middleware.ErrorResponse(c, http.StatusOK, err.Error())
325+
middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
326326
return
327327
}
328328
middleware.SuccessResponse(c, models)

0 commit comments

Comments
 (0)