Skip to content

Commit 5331eba

Browse files
committed
feat: add cache plugin
1 parent 25044cd commit 5331eba

9 files changed

Lines changed: 365 additions & 65 deletions

File tree

core/controller/relay-controller.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"github.com/labring/aiproxy/core/relay/mode"
3333
relaymodel "github.com/labring/aiproxy/core/relay/model"
3434
"github.com/labring/aiproxy/core/relay/plugin"
35+
"github.com/labring/aiproxy/core/relay/plugin/cache"
3536
websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
3637
log "github.com/sirupsen/logrus"
3738
)
@@ -167,6 +168,7 @@ func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
167168
}
168169

169170
a := plugin.WrapperAdaptor(&wrapAdaptor{adaptor},
171+
cache.NewCachePlugin(),
170172
websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
171173
return getWebSearchChannel(c, modelName)
172174
}),

core/model/groupmodel.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package model
22

33
import (
4+
"errors"
5+
46
log "github.com/sirupsen/logrus"
57
"gorm.io/gorm"
68
)
@@ -26,6 +28,13 @@ type GroupModelConfig struct {
2628
RetryTimes int64 `json:"retry_times"`
2729
}
2830

31+
func (g *GroupModelConfig) BeforeSave(_ *gorm.DB) (err error) {
32+
if g.Model == "" {
33+
return errors.New("model is required")
34+
}
35+
return nil
36+
}
37+
2938
func SaveGroupModelConfig(groupModelConfig GroupModelConfig) (err error) {
3039
defer func() {
3140
if err == nil {

core/model/modelconfig.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package model
22

33
import (
44
"encoding/json"
5+
"errors"
56
"fmt"
67
"strings"
78
"time"
@@ -39,6 +40,13 @@ type ModelConfig struct {
3940
MaxErrorRate float64 `json:"max_error_rate,omitempty"`
4041
}
4142

43+
func (c *ModelConfig) BeforeSave(_ *gorm.DB) (err error) {
44+
if c.Model == "" {
45+
return errors.New("model is required")
46+
}
47+
return nil
48+
}
49+
4250
func NewDefaultModelConfig(model string) ModelConfig {
4351
return ModelConfig{
4452
Model: model,

core/relay/controller/dohelper.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (rw *responseWriter) Write(b []byte) (int, error) {
3737
if rw.firstByteAt.IsZero() {
3838
rw.firstByteAt = time.Now()
3939
}
40-
if total := rw.body.Len() + len(b); total <= maxBufferSize {
40+
if rw.body.Len()+len(b) <= maxBufferSize {
4141
rw.body.Write(b)
4242
} else {
4343
rw.body.Write(b[:maxBufferSize-rw.body.Len()])
@@ -49,7 +49,7 @@ func (rw *responseWriter) WriteString(s string) (int, error) {
4949
if rw.firstByteAt.IsZero() {
5050
rw.firstByteAt = time.Now()
5151
}
52-
if total := rw.body.Len() + len(s); total <= maxBufferSize {
52+
if rw.body.Len()+len(s) <= maxBufferSize {
5353
rw.body.WriteString(s)
5454
} else {
5555
rw.body.WriteString(s[:maxBufferSize-rw.body.Len()])
@@ -111,7 +111,9 @@ func DoHelper(
111111
return model.Usage{}, &detail, relayErr
112112
}
113113

114-
defer resp.Body.Close()
114+
if resp.Body != nil {
115+
defer resp.Body.Close()
116+
}
115117

116118
// 4. Handle success response
117119
usage, relayErr := handleResponse(a, c, meta, resp, &detail)
@@ -236,8 +238,6 @@ func handleResponse(a adaptor.Adaptor, c *gin.Context, meta *meta.Meta, resp *ht
236238
}()
237239
c.Writer = rw
238240

239-
c.Header("Content-Type", resp.Header.Get("Content-Type"))
240-
241241
usage, relayErr := a.DoResponse(meta, c, resp)
242242
if relayErr != nil {
243243
respBody, _ := relayErr.MarshalJSON()

core/relay/plugin/cache/cache.go

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
package cache
2+
3+
import (
4+
"bytes"
5+
"crypto/sha256"
6+
"encoding/hex"
7+
"fmt"
8+
"net/http"
9+
"strconv"
10+
"sync"
11+
"time"
12+
13+
"github.com/gin-gonic/gin"
14+
"github.com/labring/aiproxy/core/common"
15+
"github.com/labring/aiproxy/core/model"
16+
"github.com/labring/aiproxy/core/relay/adaptor"
17+
"github.com/labring/aiproxy/core/relay/meta"
18+
"github.com/labring/aiproxy/core/relay/plugin"
19+
"github.com/labring/aiproxy/core/relay/plugin/noop"
20+
gcache "github.com/patrickmn/go-cache"
21+
)
22+
23+
// Constants for cache metadata keys
24+
const (
25+
cacheKey = "cache_key"
26+
cacheHit = "cache_hit"
27+
cacheValue = "cache_value"
28+
)
29+
30+
// Constants for plugin configuration
31+
const (
32+
pluginConfigCacheKey = "cache-config"
33+
cacheHeader = "X-Aiproxy-Cache"
34+
)
35+
36+
// Buffer size constants
37+
const (
38+
defaultBufferSize = 512 * 1024
39+
maxBufferSize = 4 * defaultBufferSize
40+
)
41+
42+
// Item represents a cached response
43+
type Item struct {
44+
Body []byte
45+
Header http.Header
46+
Usage *model.Usage
47+
}
48+
49+
// Cache implements caching functionality for AI requests
50+
type Cache struct {
51+
noop.Noop
52+
}
53+
54+
var (
55+
_ plugin.Plugin = (*Cache)(nil)
56+
// Global cache instance with 5 minute default TTL and 10 minute cleanup interval
57+
cache = gcache.New(5*time.Minute, 10*time.Minute)
58+
// Buffer pool for response writers
59+
bufferPool = sync.Pool{
60+
New: func() any {
61+
return bytes.NewBuffer(make([]byte, 0, defaultBufferSize))
62+
},
63+
}
64+
)
65+
66+
// NewCachePlugin creates a new cache plugin
67+
func NewCachePlugin() plugin.Plugin {
68+
return &Cache{}
69+
}
70+
71+
// Cache metadata helpers
72+
func getCacheKey(meta *meta.Meta) string {
73+
return meta.GetString(cacheKey)
74+
}
75+
76+
func setCacheKey(meta *meta.Meta, key string) {
77+
meta.Set(cacheKey, key)
78+
}
79+
80+
func isCacheHit(meta *meta.Meta) bool {
81+
return meta.GetBool(cacheHit)
82+
}
83+
84+
func getCacheItem(meta *meta.Meta) *Item {
85+
v, ok := meta.Get(cacheValue)
86+
if !ok {
87+
return nil
88+
}
89+
item, ok := v.(*Item)
90+
if !ok {
91+
panic(fmt.Sprintf("cache item type not match: %T", v))
92+
}
93+
return item
94+
}
95+
96+
func setCacheHit(meta *meta.Meta, item *Item) {
97+
meta.Set(cacheHit, true)
98+
meta.Set(cacheValue, item)
99+
}
100+
101+
// Buffer pool helpers
102+
func getBuffer() *bytes.Buffer {
103+
return bufferPool.Get().(*bytes.Buffer)
104+
}
105+
106+
func putBuffer(buf *bytes.Buffer) {
107+
buf.Reset()
108+
if buf.Cap() > maxBufferSize {
109+
return
110+
}
111+
bufferPool.Put(buf)
112+
}
113+
114+
// getPluginConfig retrieves the plugin configuration from metadata
115+
func getPluginConfig(meta *meta.Meta) (config *Config, err error) {
116+
v, ok := meta.Get(pluginConfigCacheKey)
117+
if ok {
118+
config, ok := v.(*Config)
119+
if !ok {
120+
panic(fmt.Sprintf("cache config type not match: %T", v))
121+
}
122+
return config, nil
123+
}
124+
125+
pluginConfig := Config{}
126+
if err := meta.ModelConfig.LoadPluginConfig("cache", &pluginConfig); err != nil {
127+
return nil, err
128+
}
129+
meta.Set(pluginConfigCacheKey, &pluginConfig)
130+
return &pluginConfig, nil
131+
}
132+
133+
// ConvertRequest handles the request conversion phase
134+
func (c *Cache) ConvertRequest(meta *meta.Meta, req *http.Request, do adaptor.ConvertRequest) (*adaptor.ConvertRequestResult, error) {
135+
pluginConfig, err := getPluginConfig(meta)
136+
if err != nil {
137+
return do.ConvertRequest(meta, req)
138+
}
139+
if !pluginConfig.EnablePlugin {
140+
return do.ConvertRequest(meta, req)
141+
}
142+
143+
body, err := common.GetRequestBody(req)
144+
if err != nil {
145+
return nil, err
146+
}
147+
148+
// Generate hash as cache key
149+
hash := sha256.Sum256(body)
150+
cacheKey := fmt.Sprintf("%d:%s", meta.Mode, hex.EncodeToString(hash[:]))
151+
setCacheKey(meta, cacheKey)
152+
153+
item, ok := cache.Get(cacheKey)
154+
if ok {
155+
cacheItem, ok := item.(Item)
156+
if !ok {
157+
panic(fmt.Sprintf("cache item type not match: %T", item))
158+
}
159+
setCacheHit(meta, &cacheItem)
160+
return &adaptor.ConvertRequestResult{}, nil
161+
}
162+
163+
return do.ConvertRequest(meta, req)
164+
}
165+
166+
// DoRequest handles the request execution phase
167+
func (c *Cache) DoRequest(meta *meta.Meta, ctx *gin.Context, req *http.Request, do adaptor.DoRequest) (*http.Response, error) {
168+
if isCacheHit(meta) {
169+
return &http.Response{}, nil
170+
}
171+
172+
return do.DoRequest(meta, ctx, req)
173+
}
174+
175+
// Custom response writer to capture response for caching
176+
type responseWriter struct {
177+
gin.ResponseWriter
178+
cacheBody *bytes.Buffer
179+
maxSize int
180+
overflow bool
181+
}
182+
183+
func (rw *responseWriter) Write(b []byte) (int, error) {
184+
if rw.overflow {
185+
return rw.ResponseWriter.Write(b)
186+
}
187+
if rw.maxSize > 0 && rw.cacheBody.Len()+len(b) > rw.maxSize {
188+
rw.overflow = true
189+
rw.cacheBody.Reset()
190+
return rw.ResponseWriter.Write(b)
191+
}
192+
rw.cacheBody.Write(b)
193+
return rw.ResponseWriter.Write(b)
194+
}
195+
196+
func (rw *responseWriter) WriteString(s string) (int, error) {
197+
if rw.overflow {
198+
return rw.ResponseWriter.WriteString(s)
199+
}
200+
if rw.maxSize > 0 && rw.cacheBody.Len()+len(s) > rw.maxSize {
201+
rw.overflow = true
202+
rw.cacheBody.Reset()
203+
return rw.ResponseWriter.WriteString(s)
204+
}
205+
rw.cacheBody.WriteString(s)
206+
return rw.ResponseWriter.WriteString(s)
207+
}
208+
209+
// DoResponse handles the response processing phase
210+
func (c *Cache) DoResponse(meta *meta.Meta, ctx *gin.Context, resp *http.Response, do adaptor.DoResponse) (usage *model.Usage, adapterErr adaptor.Error) {
211+
pluginConfig, err := getPluginConfig(meta)
212+
if err != nil {
213+
return do.DoResponse(meta, ctx, resp)
214+
}
215+
216+
// Handle cache hit
217+
if isCacheHit(meta) {
218+
item := getCacheItem(meta)
219+
if item == nil {
220+
return do.DoResponse(meta, ctx, resp)
221+
}
222+
223+
ctx.Header("Content-Type", item.Header.Get("Content-Type"))
224+
ctx.Header("Content-Length", strconv.Itoa(len(item.Body)))
225+
if pluginConfig.AddCacheHitHeader {
226+
header := pluginConfig.CacheHitHeader
227+
if header == "" {
228+
header = cacheHeader
229+
}
230+
ctx.Header(header, "hit")
231+
}
232+
ctx.Status(http.StatusOK)
233+
_, _ = ctx.Writer.Write(item.Body)
234+
return item.Usage, nil
235+
}
236+
237+
if !pluginConfig.EnablePlugin {
238+
return do.DoResponse(meta, ctx, resp)
239+
}
240+
241+
// Set up response capture for caching
242+
buf := getBuffer()
243+
defer putBuffer(buf)
244+
245+
rw := &responseWriter{
246+
ResponseWriter: ctx.Writer,
247+
maxSize: pluginConfig.MaxSize,
248+
cacheBody: buf,
249+
}
250+
ctx.Writer = rw
251+
defer func() {
252+
ctx.Writer = rw.ResponseWriter
253+
if adapterErr != nil || rw.overflow {
254+
return
255+
}
256+
respBody := rw.cacheBody.Bytes()
257+
respHeader := rw.Header()
258+
cache.Set(getCacheKey(meta), Item{
259+
Body: bytes.Clone(respBody),
260+
Header: respHeader,
261+
Usage: usage,
262+
}, time.Duration(pluginConfig.TTL)*time.Second)
263+
}()
264+
265+
return do.DoResponse(meta, ctx, resp)
266+
}

core/relay/plugin/cache/config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package cache
2+
3+
type Config struct {
4+
EnablePlugin bool `json:"enable_plugin"`
5+
TTL int `json:"ttl"`
6+
MaxSize int `json:"max_size"`
7+
MaxItems int `json:"max_items"`
8+
AddCacheHitHeader bool `json:"add_cache_hit_header"`
9+
CacheHitHeader string `json:"cache_hit_header"`
10+
}

0 commit comments

Comments
 (0)