Skip to content

Commit 6893092

Browse files
committed
feat: support async usage
1 parent c458c29 commit 6893092

24 files changed

Lines changed: 2051 additions & 71 deletions

core/common/consume/consume.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ func AsyncConsume(
3535
downstreamResult bool,
3636
metadata map[string]string,
3737
upstreamID string,
38+
asyncUsageStatus model.AsyncUsageStatus,
3839
) {
3940
if !checkNeedRecordConsume(code, meta) {
4041
return
@@ -65,6 +66,7 @@ func AsyncConsume(
6566
downstreamResult,
6667
metadata,
6768
upstreamID,
69+
asyncUsageStatus,
6870
)
6971
}
7072

@@ -84,6 +86,7 @@ func Consume(
8486
downstreamResult bool,
8587
metadata map[string]string,
8688
upstreamID string,
89+
asyncUsageStatus model.AsyncUsageStatus,
8790
) {
8891
if !checkNeedRecordConsume(code, meta) {
8992
return
@@ -119,6 +122,7 @@ func Consume(
119122
downstreamResult,
120123
metadata,
121124
upstreamID,
125+
asyncUsageStatus,
122126
)
123127
if err != nil {
124128
log.Error("error batch record consume: " + err.Error())

core/common/consume/record.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ func recordConsume(
2222
downstreamResult bool,
2323
metadata map[string]string,
2424
upstreamID string,
25+
asyncUsageStatus model.AsyncUsageStatus,
2526
) error {
2627
summaryServiceTier := meta.RequestServiceTier
2728
if !meta.ModelConfig.ShouldSummaryServiceTier() {
@@ -58,6 +59,7 @@ func recordConsume(
5859
meta.PromptCacheKey,
5960
upstreamID,
6061
meta.RequestServiceTier,
62+
asyncUsageStatus,
6163
summaryServiceTier,
6264
summaryClaudeLongContext,
6365
)

core/controller/relay-controller.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/labring/aiproxy/core/relay/plugin/thinksplit"
3838
"github.com/labring/aiproxy/core/relay/plugin/timeout"
3939
websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
40+
log "github.com/sirupsen/logrus"
4041
)
4142

4243
// https://platform.openai.com/docs/api-reference/chat
@@ -379,6 +380,11 @@ func recordResult(
379380
log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
380381
}
381382

383+
asyncUsageStatus := model.AsyncUsageStatusNone
384+
if downstreamResult && result.Error == nil && result.AsyncUsage {
385+
asyncUsageStatus = model.AsyncUsageStatusPending
386+
}
387+
382388
consume.AsyncConsume(
383389
gbc.Consumer,
384390
code,
@@ -393,7 +399,41 @@ func recordResult(
393399
downstreamResult,
394400
metadata,
395401
result.UpstreamID,
402+
asyncUsageStatus,
396403
)
404+
405+
if asyncUsageStatus == model.AsyncUsageStatusPending {
406+
saveAsyncUsageInfo(meta, price, result)
407+
}
408+
}
409+
410+
// saveAsyncUsageInfo queues a pending async usage record so the background
// poll task can later fetch the real usage from the upstream and settle the
// bill. A record without an upstream id cannot be polled, so it is skipped
// with a warning. Persistence errors are logged rather than returned: the
// downstream response has already been delivered, and failing to queue the
// async usage must not fail the request.
func saveAsyncUsageInfo(
	meta *meta.Meta,
	price model.Price,
	result *controller.HandleResult,
) {
	if result.UpstreamID == "" {
		log.Warnf("skip async usage without upstream id, request_id: %s", meta.RequestID)
		return
	}

	if err := model.CreateAsyncUsageInfo(&model.AsyncUsageInfo{
		RequestID: meta.RequestID,
		RequestAt: meta.RequestAt,
		Mode: int(meta.Mode),
		Model: meta.OriginModel,
		ChannelID: meta.Channel.ID,
		BaseURL: meta.Channel.BaseURL,
		GroupID: meta.Group.ID,
		TokenID: meta.Token.ID,
		TokenName: meta.Token.Name,
		// Price is snapshotted now so later billing uses the rate that was
		// in effect when the request was served.
		Price: price,
		ServiceTier: meta.RequestServiceTier,
		UpstreamID: result.UpstreamID,
		// The client response has already been written by the time this
		// record is created.
		DownstreamDone: true,
	}); err != nil {
		log.Errorf("failed to save async usage info: %v", err)
	}
}
398438

399439
func effectiveDetailBodyMaxSize(modelLimit, globalLimit int64) int64 {

core/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ func main() {
8484

8585
go task.UsageAlertTask(ctx)
8686

87+
log.Info("async usage poll task started")
88+
89+
go task.AsyncUsagePollTask(ctx)
90+
8791
if common.RedisEnabled {
8892
log.Info("redis health check task started")
8993

core/model/async_usage.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package model
2+
3+
import (
4+
"time"
5+
6+
"gorm.io/gorm"
7+
)
8+
9+
// AsyncUsageStatus is the lifecycle state of an asynchronously billed
// request's usage record.
type AsyncUsageStatus int

const (
	// AsyncUsageStatusNone (0): the request has no async usage to resolve.
	AsyncUsageStatusNone AsyncUsageStatus = iota
	// AsyncUsageStatusPending (1): waiting for the upstream usage result.
	AsyncUsageStatusPending
	// AsyncUsageStatusCompleted (2): usage fetched and recorded.
	AsyncUsageStatusCompleted
	// AsyncUsageStatusFailed (3): polling gave up without a usage result.
	AsyncUsageStatusFailed
)

const (
	// AsyncUsageDefaultPollDelay is the first (and minimum) poll interval.
	AsyncUsageDefaultPollDelay = 10 * time.Second
	// AsyncUsageMaxPollDelay caps the exponential poll backoff
	// (see AsyncUsageBackoffDelay).
	AsyncUsageMaxPollDelay = 3 * time.Minute
)

// AsyncUsageInfo is one pending/settled async usage record, stored in the
// log DB and drained by the background poll task.
type AsyncUsageInfo struct {
	ID int `gorm:"primaryKey" json:"id"`
	// RequestID links this record back to the Log row it settles.
	RequestID string `gorm:"type:char(16);index" json:"request_id"`
	RequestAt time.Time ` json:"request_at"`
	Mode int `gorm:"index" json:"mode"`
	Model string `gorm:"size:128" json:"model"`
	ChannelID int `gorm:"index" json:"channel_id"`
	BaseURL string `gorm:"type:text" json:"base_url,omitempty"`
	GroupID string `gorm:"size:64;index" json:"group_id"`
	TokenID int `gorm:"index" json:"token_id"`
	TokenName string `gorm:"size:128" json:"token_name,omitempty"`
	// Price captured at request time, used to settle the final amount.
	Price Price `gorm:"serializer:fastjson;type:text" json:"price"`
	ServiceTier string `gorm:"size:16" json:"service_tier,omitempty"`
	// UpstreamID is the provider-side identifier used to poll for usage.
	UpstreamID string `gorm:"type:varchar(256);index" json:"upstream_id"`
	// Status defaults to 1 (AsyncUsageStatusPending) at the DB level.
	Status AsyncUsageStatus `gorm:"index;default:1" json:"status"`
	Usage Usage `gorm:"serializer:fastjson;type:text" json:"usage"`
	Amount Amount `gorm:"embedded" json:"amount,omitempty"`
	// Error holds the last polling failure, if any.
	Error string `gorm:"type:text" json:"error,omitempty"`
	// RetryCount drives the exponential backoff (AsyncUsageBackoffDelay).
	RetryCount int ` json:"retry_count"`
	DownstreamDone bool ` json:"downstream_done"`
	BalanceConsumed bool ` json:"balance_consumed"`
	// NextPollAt is when the poll task should next pick this record up.
	NextPollAt time.Time `gorm:"index" json:"next_poll_at"`
	CreatedAt time.Time ` json:"created_at"`
	UpdatedAt time.Time ` json:"updated_at"`
}
48+
49+
func CreateAsyncUsageInfo(info *AsyncUsageInfo) error {
50+
info.Status = AsyncUsageStatusPending
51+
info.CreatedAt = time.Now()
52+
53+
info.UpdatedAt = info.CreatedAt
54+
if info.NextPollAt.IsZero() {
55+
info.NextPollAt = info.CreatedAt.Add(AsyncUsageDefaultPollDelay)
56+
}
57+
58+
return LogDB.Create(info).Error
59+
}
60+
61+
func GetPendingAsyncUsages(limit int) ([]*AsyncUsageInfo, error) {
62+
return GetPendingAsyncUsagesDue(limit, time.Now())
63+
}
64+
65+
// GetPendingAsyncUsagesDue returns up to limit pending async usage records
// whose next poll time is at or before now (or unset), ordered oldest-due
// first so starved records are polled before fresh ones.
//
// The nested LogDB.Where builds a parenthesized OR group, producing:
//
//	status = pending AND (next_poll_at <= now OR next_poll_at IS NULL)
//
// NOTE(review): NextPollAt is a non-pointer time.Time, so GORM writes the
// zero time rather than NULL; the IS NULL branch looks defensive for rows
// created outside CreateAsyncUsageInfo — confirm the column is nullable.
func GetPendingAsyncUsagesDue(
	limit int,
	now time.Time,
) ([]*AsyncUsageInfo, error) {
	var infos []*AsyncUsageInfo

	err := LogDB.
		Where("status = ?", int(AsyncUsageStatusPending)).
		Where(
			LogDB.
				Where("next_poll_at <= ?", now).
				Or("next_poll_at IS NULL"),
		).
		Order("next_poll_at ASC, updated_at ASC, created_at ASC").
		Limit(limit).
		Find(&infos).Error

	return infos, err
}
84+
85+
func AsyncUsageBackoffDelay(
86+
retryCount int,
87+
) time.Duration {
88+
if retryCount <= 1 {
89+
return AsyncUsageDefaultPollDelay
90+
}
91+
92+
delay := AsyncUsageDefaultPollDelay
93+
for range retryCount - 1 {
94+
delay *= 2
95+
if delay >= AsyncUsageMaxPollDelay {
96+
return AsyncUsageMaxPollDelay
97+
}
98+
}
99+
100+
return delay
101+
}
102+
103+
func UpdateAsyncUsageInfo(info *AsyncUsageInfo) error {
104+
info.UpdatedAt = time.Now()
105+
return LogDB.Save(info).Error
106+
}
107+
108+
// UpdateLogUsageByRequestID backfills the usage and amount of the Log row
// identified by requestID once its async usage has been resolved, and marks
// the row's async usage status as completed. Returns the lookup error
// (e.g. record not found) if no such row exists.
//
// NOTE(review): this is a read-modify-write (First then Save) without a
// transaction or row lock; concurrent writers to the same log row could
// lose updates — confirm a single poller owns each record.
func UpdateLogUsageByRequestID(
	requestID string,
	usage Usage,
	amount Amount,
) error {
	var logEntry Log
	if err := LogDB.Where("request_id = ?", requestID).First(&logEntry).Error; err != nil {
		return err
	}

	logEntry.Usage = usage
	// Amounts are accumulated, not replaced: the async amount is added on
	// top of whatever was recorded at request time.
	logEntry.Amount.Add(amount)
	logEntry.AsyncUsageStatus = AsyncUsageStatusCompleted

	return LogDB.Save(&logEntry).Error
}
124+
125+
func UpdateLogAsyncUsageStatusByRequestID(
126+
requestID string,
127+
status AsyncUsageStatus,
128+
) error {
129+
if requestID == "" {
130+
return nil
131+
}
132+
133+
tx := LogDB.
134+
Model(&Log{}).
135+
Where("request_id = ?", requestID).
136+
Update("async_usage_status", status)
137+
if tx.Error != nil {
138+
return tx.Error
139+
}
140+
141+
if tx.RowsAffected == 0 {
142+
return NotFoundError("log")
143+
}
144+
145+
return nil
146+
}
147+
148+
// CleanupFinishedAsyncUsages deletes one batch of completed or failed async
// usage records whose last update is older than olderThan. Callers are
// expected to invoke it repeatedly to drain a large backlog batch by batch.
//
// batchSize <= 0 falls back to defaultCleanLogBatchSize.
//
// NOTE(review): the DELETE uses `id IN (subquery with LIMIT)`; MySQL does
// not allow LIMIT inside an IN/ANY subquery — confirm which dialects the
// log DB supports before relying on this.
func CleanupFinishedAsyncUsages(olderThan time.Duration, batchSize int) error {
	if batchSize <= 0 {
		batchSize = defaultCleanLogBatchSize
	}

	cutoff := time.Now().Add(-olderThan)

	subQuery := LogDB.
		Model(&AsyncUsageInfo{}).
		Where(
			"status IN (?) AND updated_at < ?",
			[]AsyncUsageStatus{AsyncUsageStatusCompleted, AsyncUsageStatusFailed},
			cutoff,
		).
		Limit(batchSize).
		Select("id")

	// Skip GORM's default wrapping transaction: the delete is a single
	// statement, and batching keeps cleanup from holding long locks.
	return LogDB.
		Session(&gorm.Session{SkipDefaultTransaction: true}).
		Where("id IN (?)", subQuery).
		Delete(&AsyncUsageInfo{}).Error
}

0 commit comments

Comments (0)