Skip to content

Commit 78a6c79

Browse files
committed
Proxy AI/LLM requests through the balancer to hide the actual AI token
from the individual instances
1 parent 749c2c6 commit 78a6c79

19 files changed

Lines changed: 889 additions & 20 deletions

File tree

balancer/main.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"log"
66
"net/http"
7+
"os"
78

89
"github.com/juice-shop/multi-juicer/balancer/pkg/bundle"
10+
"github.com/juice-shop/multi-juicer/balancer/pkg/llmgateway"
911
"github.com/juice-shop/multi-juicer/balancer/pkg/notification"
1012
"github.com/juice-shop/multi-juicer/balancer/pkg/scoring"
1113
"github.com/juice-shop/multi-juicer/balancer/routes"
@@ -29,6 +31,20 @@ func main() {
2931
scoringService.CalculateAndCacheScoreBoard(ctx)
3032
go scoringService.StartingScoringWorker(ctx)
3133
go notificationService.StartNotificationWatcher(ctx)
34+
35+
if b.Config.JuiceShopConfig.LLM.Enabled {
36+
llmAPIKey := os.Getenv("LLM_API_KEY")
37+
llmAPIURL := os.Getenv("LLM_API_URL")
38+
39+
usage := llmgateway.NewUsageTracker()
40+
gateway, err := llmgateway.NewGateway(b.Config.CookieConfig.SigningKey, llmAPIURL, llmAPIKey, usage, b.Log)
41+
if err != nil {
42+
log.Fatalf("Failed to create LLM gateway: %v", err)
43+
}
44+
go usage.StartFlusher(ctx, b.ClientSet, b.RuntimeEnvironment.Namespace, b.Log)
45+
go StartLLMGatewayServer(gateway, b.Log)
46+
}
47+
3248
StartBalancerServer(b)
3349
}
3450

@@ -47,6 +63,19 @@ func StartBalancerServer(b *bundle.Bundle) {
4763
}
4864
}
4965

66+
func StartLLMGatewayServer(gateway *llmgateway.Gateway, logger *log.Logger) {
67+
router := http.NewServeMux()
68+
router.Handle("/", gateway)
69+
server := &http.Server{
70+
Addr: ":8082",
71+
Handler: router,
72+
}
73+
logger.Println("Starting LLM gateway on :8082")
74+
if err := server.ListenAndServe(); err != nil {
75+
log.Fatalf("Failed to start LLM gateway server: %v", err)
76+
}
77+
}
78+
5079
func StartMetricsServer() {
5180
metricsRouter := http.NewServeMux()
5281
metricsRouter.Handle("GET /balancer/metrics", promhttp.Handler())

balancer/pkg/bundle/bundle.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ type CookieConfig struct {
6868
Secure bool `json:"secure"`
6969
}
7070

71+
// LLMConfig configures the optional LLM/AI integration for JuiceShop
// instances. When Enabled, the balancer starts an LLM gateway that proxies
// requests to the upstream API so the real API key stays hidden from the
// individual instances.
type LLMConfig struct {
	// Enabled toggles the LLM gateway feature on or off.
	Enabled bool `json:"enabled"`
	// Model is the LLM model identifier — presumably handed to the
	// JuiceShop instances / upstream API; consumer not visible here, confirm.
	Model string `json:"model"`
	// ApiUrl is the base URL of the upstream LLM API to proxy to.
	ApiUrl string `json:"apiUrl"`
}
76+
7177
type JuiceShopConfig struct {
7278
Image string `json:"image"`
7379
Tag string `json:"tag"`
@@ -87,6 +93,8 @@ type JuiceShopConfig struct {
8793
VolumeMounts []corev1.VolumeMount `json:"volumeMounts"`
8894
RuntimeClassName *string `json:"runtimeClassName"`
8995

96+
LLM LLMConfig `json:"llm"`
97+
9098
JuiceShopPodConfig JuiceShopPodConfig `json:"pod"`
9199
}
92100

balancer/pkg/llmgateway/gateway.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package llmgateway
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"log"
8+
"net/http"
9+
"net/http/httputil"
10+
"net/url"
11+
"strings"
12+
13+
"github.com/juice-shop/multi-juicer/balancer/pkg/signutil"
14+
)
15+
16+
// openAIResponse is a minimal representation of an OpenAI chat completion
// response (or streaming chunk), carrying only the field needed for usage
// extraction.
type openAIResponse struct {
	// Usage is nil when the response or chunk carries no usage block.
	Usage *openAIUsage `json:"usage,omitempty"`
}

// openAIUsage holds the token counts reported by an OpenAI-style API.
type openAIUsage struct {
	InputTokens  int64 `json:"prompt_tokens"`
	OutputTokens int64 `json:"completion_tokens"`
}
25+
26+
// Gateway proxies LLM requests from JuiceShop instances to an upstream LLM
// API. It authenticates callers via HMAC-signed team tokens and swaps them
// for the real API key, so the key never reaches the instances.
type Gateway struct {
	signingKey  string // HMAC key used to verify team tokens (the balancer's cookie signing key)
	upstreamURL *url.URL
	apiKey      string // real upstream API key injected into proxied requests
	usage       *UsageTracker
	logger      *log.Logger
}
34+
35+
// NewGateway creates a new LLM gateway.
36+
func NewGateway(signingKey string, upstreamURL string, apiKey string, usage *UsageTracker, logger *log.Logger) (*Gateway, error) {
37+
u, err := url.Parse(upstreamURL)
38+
if err != nil {
39+
return nil, err
40+
}
41+
return &Gateway{
42+
signingKey: signingKey,
43+
upstreamURL: u,
44+
apiKey: apiKey,
45+
usage: usage,
46+
logger: logger,
47+
}, nil
48+
}
49+
50+
func (g *Gateway) ServeHTTP(w http.ResponseWriter, r *http.Request) {
51+
// Extract bearer token
52+
authHeader := r.Header.Get("Authorization")
53+
if !strings.HasPrefix(authHeader, "Bearer ") {
54+
http.Error(w, `{"error":"missing or invalid Authorization header"}`, http.StatusUnauthorized)
55+
return
56+
}
57+
teamToken := strings.TrimPrefix(authHeader, "Bearer ")
58+
59+
// Validate token by verifying the HMAC signature and extracting the team name
60+
team, err := signutil.Unsign(teamToken, g.signingKey)
61+
if err != nil {
62+
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
63+
return
64+
}
65+
66+
// Check if this is a chat completions request (for usage tracking)
67+
isChatCompletion := strings.Contains(r.URL.Path, "/chat/completions")
68+
g.logger.Printf("LLM gateway: request from team '%s': %s %s (isChatCompletion=%v)", team, r.Method, r.URL.Path, isChatCompletion)
69+
70+
// Create reverse proxy
71+
proxy := &httputil.ReverseProxy{
72+
Rewrite: func(pr *httputil.ProxyRequest) {
73+
pr.SetURL(g.upstreamURL)
74+
pr.Out.Host = g.upstreamURL.Host
75+
// Replace the authorization header with the real API key
76+
pr.Out.Header.Set("Authorization", "Bearer "+g.apiKey)
77+
},
78+
}
79+
80+
if isChatCompletion {
81+
proxy.ModifyResponse = func(resp *http.Response) error {
82+
return g.extractUsage(resp, team)
83+
}
84+
}
85+
86+
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
87+
g.logger.Printf("LLM gateway proxy error for team '%s': %v", team, err)
88+
http.Error(w, `{"error":"upstream LLM API error"}`, http.StatusBadGateway)
89+
}
90+
91+
proxy.ServeHTTP(w, r)
92+
}
93+
94+
func (g *Gateway) extractUsage(resp *http.Response, team string) error {
95+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
96+
return nil
97+
}
98+
99+
contentType := resp.Header.Get("Content-Type")
100+
isSSE := strings.Contains(contentType, "text/event-stream")
101+
102+
body, err := io.ReadAll(resp.Body)
103+
resp.Body.Close()
104+
if err != nil {
105+
g.logger.Printf("LLM gateway: failed to read response body for team '%s': %v", team, err)
106+
resp.Body = io.NopCloser(bytes.NewReader(body))
107+
return nil
108+
}
109+
110+
// Restore the body for the client
111+
resp.Body = io.NopCloser(bytes.NewReader(body))
112+
113+
if isSSE {
114+
g.extractUsageFromSSE(body, team)
115+
} else {
116+
g.extractUsageFromJSON(body, team)
117+
}
118+
return nil
119+
}
120+
121+
func (g *Gateway) extractUsageFromJSON(body []byte, team string) {
122+
var result openAIResponse
123+
if err := json.Unmarshal(body, &result); err != nil {
124+
return
125+
}
126+
if result.Usage != nil {
127+
g.logger.Printf("LLM gateway: usage for team '%s': input_tokens=%d, output_tokens=%d", team, result.Usage.InputTokens, result.Usage.OutputTokens)
128+
g.usage.Add(team, result.Usage.InputTokens, result.Usage.OutputTokens)
129+
}
130+
}
131+
132+
// extractUsageFromSSE scans SSE events for usage data, which typically appears in the last chunk.
133+
func (g *Gateway) extractUsageFromSSE(body []byte, team string) {
134+
// SSE format: lines starting with "data: " contain JSON payloads
135+
// Scan all data lines for usage (it's usually in the last real chunk before "data: [DONE]")
136+
lines := strings.Split(string(body), "\n")
137+
for _, line := range lines {
138+
line = strings.TrimSpace(line)
139+
if !strings.HasPrefix(line, "data: ") {
140+
continue
141+
}
142+
payload := strings.TrimPrefix(line, "data: ")
143+
if payload == "[DONE]" {
144+
continue
145+
}
146+
var chunk openAIResponse
147+
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
148+
continue
149+
}
150+
if chunk.Usage != nil {
151+
g.logger.Printf("LLM gateway: SSE usage for team '%s': input_tokens=%d, output_tokens=%d", team, chunk.Usage.InputTokens, chunk.Usage.OutputTokens)
152+
g.usage.Add(team, chunk.Usage.InputTokens, chunk.Usage.OutputTokens)
153+
return
154+
}
155+
}
156+
g.logger.Printf("LLM gateway: no usage data found in SSE stream for team '%s'", team)
157+
}

0 commit comments

Comments
 (0)