
Commit d5d82ba

feat(grpc): backend SPI pluggable in embedding mode (#1621)
* run server
* grpc backend embedded support
* backend providable
1 parent efe2883 commit d5d82ba
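
Note: in embedding mode the host process can register a Go backend directly with grpc.Provide under a synthetic address; grpc.NewClient then resolves that address to the in-process backend instead of dialing a spawned gRPC process. A minimal sketch of that wiring, not taken from the commit (the address string, the prompt, and the someLLM value are assumptions):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/go-skynet/LocalAI/pkg/grpc"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// someLLM is assumed to be a value implementing the package's LLM
// interface (for example one of the in-tree Go backends); how it is
// constructed is out of scope for this sketch.
var someLLM grpc.LLM

func main() {
	// Register the backend under a synthetic address: no process is
	// spawned and no socket is opened for this entry.
	const addr = "embedded:my-model"
	grpc.Provide(addr, someLLM)

	// Resolving the same address yields the in-process embedBackend;
	// any other address still falls back to a real gRPC client.
	backend := grpc.NewClient(addr, false, nil, false)

	reply, err := backend.Predict(context.Background(), &pb.PredictOptions{Prompt: "Hello"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(reply.GetMessage()))
}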

8 files changed: +196 -20 lines
api/backend/embeddings.go

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
 
 	var fn func() ([]float32, error)
 	switch model := inferenceModel.(type) {
-	case *grpc.Client:
+	case grpc.Backend:
 		fn = func() ([]float32, error) {
 			predictOptions := gRPCPredictOpts(c, loader.ModelPath)
 			if len(tokens) > 0 {

api/backend/llm.go

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
 
 	grpcOpts := gRPCModelOpts(c)
 
-	var inferenceModel *grpc.Client
+	var inferenceModel grpc.Backend
 	var err error
 
 	opts := modelOpts(c, o, []model.Option{

pkg/grpc/backend.go

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+package grpc
+
+import (
+	"context"
+	"github.com/go-skynet/LocalAI/api/schema"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"google.golang.org/grpc"
+)
+
+var embeds = map[string]*embedBackend{}
+
+func Provide(addr string, llm LLM) {
+	embeds[addr] = &embedBackend{s: &server{llm: llm}}
+}
+
+func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
+	if bc, ok := embeds[address]; ok {
+		return bc
+	}
+	return NewGrpcClient(address, parallel, wd, enableWatchDog)
+}
+
+func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
+	if !enableWatchDog {
+		wd = nil
+	}
+	return &Client{
+		address:  address,
+		parallel: parallel,
+		wd:       wd,
+	}
+}
+
+type Backend interface {
+	IsBusy() bool
+	HealthCheck(ctx context.Context) (bool, error)
+	Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
+	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
+	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
+	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
+	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
+	TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
+	Status(ctx context.Context) (*pb.StatusResponse, error)
+}
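
Note: the Backend interface mirrors the method set of the existing gRPC client, so both the remote *Client and the in-process embedBackend satisfy it, and the api/backend and pkg/model call sites can program against the interface alone. A hedged illustration of that usage (hypothetical helper, not part of the commit):

package backendutil // hypothetical package, for illustration only

import (
	"context"
	"errors"

	"github.com/go-skynet/LocalAI/pkg/grpc"
)

// EnsureHealthy works for both the remote *grpc.Client and the embedded
// backend, because it depends only on the Backend interface.
func EnsureHealthy(ctx context.Context, b grpc.Backend) error {
	if b.IsBusy() {
		return errors.New("backend is busy")
	}
	ok, err := b.HealthCheck(ctx)
	if err != nil {
		return err
	}
	if !ok {
		return errors.New("backend reported unhealthy")
	}
	return nil
}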

pkg/grpc/client.go

Lines changed: 0 additions & 11 deletions
@@ -27,17 +27,6 @@ type WatchDog interface {
 	UnMark(address string)
 }
 
-func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
-	if !enableWatchDog {
-		wd = nil
-	}
-	return &Client{
-		address:  address,
-		parallel: parallel,
-		wd:       wd,
-	}
-}
-
 func (c *Client) IsBusy() bool {
 	c.Lock()
 	defer c.Unlock()

pkg/grpc/embed.go

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+package grpc
+
+import (
+	"context"
+	"github.com/go-skynet/LocalAI/api/schema"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
+	"time"
+)
+
+var _ Backend = new(embedBackend)
+var _ pb.Backend_PredictStreamServer = new(embedBackendServerStream)
+
+type embedBackend struct {
+	s *server
+}
+
+func (e *embedBackend) IsBusy() bool {
+	return e.s.llm.Busy()
+}
+
+func (e *embedBackend) HealthCheck(ctx context.Context) (bool, error) {
+	return true, nil
+}
+
+func (e *embedBackend) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
+	return e.s.Embedding(ctx, in)
+}
+
+func (e *embedBackend) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
+	return e.s.Predict(ctx, in)
+}
+
+func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
+	return e.s.LoadModel(ctx, in)
+}
+
+func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+	bs := &embedBackendServerStream{
+		ctx: ctx,
+		fn:  f,
+	}
+	return e.s.PredictStream(in, bs)
+}
+
+func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	return e.s.GenerateImage(ctx, in)
+}
+
+func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	return e.s.TTS(ctx, in)
+}
+
+func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
+	r, err := e.s.AudioTranscription(ctx, in)
+	if err != nil {
+		return nil, err
+	}
+	tr := &schema.Result{}
+	for _, s := range r.Segments {
+		var tks []int
+		for _, t := range s.Tokens {
+			tks = append(tks, int(t))
+		}
+		tr.Segments = append(tr.Segments,
+			schema.Segment{
+				Text:   s.Text,
+				Id:     int(s.Id),
+				Start:  time.Duration(s.Start),
+				End:    time.Duration(s.End),
+				Tokens: tks,
+			})
+	}
+	tr.Text = r.Text
+	return tr, err
+}
+
+func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
+	return e.s.TokenizeString(ctx, in)
+}
+
+func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) {
+	return e.s.Status(ctx, &pb.HealthMessage{})
+}
+
+type embedBackendServerStream struct {
+	ctx context.Context
+	fn  func(s []byte)
+}
+
+func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
+	e.fn(reply.GetMessage())
+	return nil
+}
+
+func (e *embedBackendServerStream) SetHeader(md metadata.MD) error {
+	return nil
+}
+
+func (e *embedBackendServerStream) SendHeader(md metadata.MD) error {
+	return nil
+}
+
+func (e *embedBackendServerStream) SetTrailer(md metadata.MD) {
+}
+
+func (e *embedBackendServerStream) Context() context.Context {
+	return e.ctx
+}
+
+func (e *embedBackendServerStream) SendMsg(m any) error {
+	if x, ok := m.(*pb.Reply); ok {
+		return e.Send(x)
+	}
+	return nil
+}
+
+func (e *embedBackendServerStream) RecvMsg(m any) error {
+	return nil
+}
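
Note: embedBackendServerStream adapts the generated Backend_PredictStreamServer interface to a plain callback, so PredictStream on an embedded backend delivers each reply as a direct function call rather than over a gRPC stream. A minimal sketch of consuming that stream (assumes an address registered via grpc.Provide as in the earlier sketch; the prompt is illustrative):

package main

import (
	"context"
	"fmt"
	"log"
	"strings"

	"github.com/go-skynet/LocalAI/pkg/grpc"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	// Assumption: "embedded:my-model" was registered earlier via grpc.Provide.
	backend := grpc.NewClient("embedded:my-model", false, nil, false)

	// Each streamed reply reaches the callback through
	// embedBackendServerStream.Send, with no transport in between.
	var sb strings.Builder
	err := backend.PredictStream(context.Background(),
		&pb.PredictOptions{Prompt: "Tell me a short story"},
		func(chunk []byte) {
			sb.Write(chunk)
		})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(sb.String())
}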

pkg/grpc/server.go

Lines changed: 20 additions & 0 deletions
@@ -181,3 +181,23 @@ func StartServer(address string, model LLM) error {
 
 	return nil
 }
+
+func RunServer(address string, model LLM) (func() error, error) {
+	lis, err := net.Listen("tcp", address)
+	if err != nil {
+		return nil, err
+	}
+	s := grpc.NewServer()
+	pb.RegisterBackendServer(s, &server{llm: model})
+	log.Printf("gRPC Server listening at %v", lis.Addr())
+	if err = s.Serve(lis); err != nil {
+		return func() error {
+			return lis.Close()
+		}, err
+	}
+
+	return func() error {
+		s.GracefulStop()
+		return nil
+	}, nil
+}
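
Note: RunServer complements StartServer by handing back a stop function, but it still blocks inside Serve, so an embedding host would typically run it on its own goroutine. A hedged sketch of that usage (the address and the hostedLLM value are assumptions):

package main

import (
	"log"

	"github.com/go-skynet/LocalAI/pkg/grpc"
)

// hostedLLM is assumed to implement the package's LLM interface.
var hostedLLM grpc.LLM

func main() {
	go func() {
		// RunServer blocks inside Serve; on a Serve error the returned
		// func closes the listener, so call it only when non-nil.
		stop, err := grpc.RunServer("127.0.0.1:50051", hostedLLM)
		if err != nil {
			if stop != nil {
				_ = stop()
			}
			log.Printf("embedded gRPC backend exited: %v", err)
		}
	}()

	select {} // keep the host process alive for the sketch
}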

pkg/model/initializers.go

Lines changed: 3 additions & 3 deletions
@@ -166,7 +166,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 	}
 }
 
-func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
+func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
 	if parallel {
 		return addr.GRPC(parallel, ml.wd), nil
 	}
@@ -177,7 +177,7 @@ func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.C
 	return ml.grpcClients[string(addr)], nil
 }
 
-func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err error) {
+func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
 	o := NewOptions(opts...)
 
 	if o.model != "" {
@@ -220,7 +220,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err e
 	return ml.resolveAddress(addr, o.parallelRequests)
 }
 
-func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
+func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
 	o := NewOptions(opts...)
 
 	ml.mu.Lock()

pkg/model/loader.go

Lines changed: 4 additions & 4 deletions
@@ -59,7 +59,7 @@ type ModelLoader struct {
 	ModelPath string
 	mu sync.Mutex
 	// TODO: this needs generics
-	grpcClients map[string]*grpc.Client
+	grpcClients map[string]grpc.Backend
 	models map[string]ModelAddress
 	grpcProcesses map[string]*process.Process
 	templates map[TemplateType]map[string]*template.Template
@@ -68,7 +68,7 @@ type ModelLoader struct {
 
 type ModelAddress string
 
-func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
+func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
 	enableWD := false
 	if wd != nil {
 		enableWD = true
@@ -79,7 +79,7 @@ func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
 func NewModelLoader(modelPath string) *ModelLoader {
 	nml := &ModelLoader{
 		ModelPath: modelPath,
-		grpcClients: make(map[string]*grpc.Client),
+		grpcClients: make(map[string]grpc.Backend),
 		models: make(map[string]ModelAddress),
 		templates: make(map[TemplateType]map[string]*template.Template),
 		grpcProcesses: make(map[string]*process.Process),
@@ -163,7 +163,7 @@ func (ml *ModelLoader) StopModel(modelName string) error {
 }
 
 func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
-	var client *grpc.Client
+	var client grpc.Backend
 	if m, ok := ml.models[s]; ok {
 		log.Debug().Msgf("Model already loaded in memory: %s", s)
 		if c, ok := ml.grpcClients[s]; ok {
