Skip to content

Commit 6ac5d81

Browse files
authored
feat(startup): fetch model definition remotely (#1654)
1 parent f928899 commit 6ac5d81

File tree

6 files changed

+68
-6
lines changed

6 files changed

+68
-6
lines changed

api/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
3737
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
3838
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
3939

40-
startup.PreloadModelsConfigurations(options.Loader.ModelPath, options.ModelsURL...)
40+
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
4141

4242
cl := config.NewConfigLoader()
4343
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {

api/options/options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ type Option struct {
2828
ApiKeys []string
2929
Metrics *metrics.Metrics
3030

31+
ModelLibraryURL string
32+
3133
Galleries []gallery.Gallery
3234

3335
BackendAssets embed.FS
@@ -78,6 +80,12 @@ func WithCors(b bool) AppOption {
7880
}
7981
}
8082

83+
func WithModelLibraryURL(url string) AppOption {
84+
return func(o *Option) {
85+
o.ModelLibraryURL = url
86+
}
87+
}
88+
8189
var EnableWatchDog = func(o *Option) {
8290
o.WatchDog = true
8391
}

embedded/embedded.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"slices"
77
"strings"
88

9+
"github.com/go-skynet/LocalAI/pkg/downloader"
10+
911
"github.com/go-skynet/LocalAI/pkg/assets"
1012
"gopkg.in/yaml.v3"
1113
)
@@ -30,6 +32,19 @@ func init() {
3032
yaml.Unmarshal(modelLibrary, &modelShorteners)
3133
}
3234

35+
func GetRemoteLibraryShorteners(url string) (map[string]string, error) {
36+
remoteLibrary := map[string]string{}
37+
38+
err := downloader.GetURI(url, func(_ string, i []byte) error {
39+
return yaml.Unmarshal(i, &remoteLibrary)
40+
})
41+
if err != nil {
42+
return nil, fmt.Errorf("error downloading remote library: %s", err.Error())
43+
}
44+
45+
return remoteLibrary, err
46+
}
47+
3348
// ExistsInModelsLibrary checks if a model exists in the embedded models library
3449
func ExistsInModelsLibrary(s string) bool {
3550
f := fmt.Sprintf("%s.yaml", s)

main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ import (
2626
"github.com/urfave/cli/v2"
2727
)
2828

29+
const (
30+
remoteLibraryURL = "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
31+
)
32+
2933
func main() {
3034
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
3135
// clean up process
@@ -94,6 +98,12 @@ func main() {
9498
Usage: "JSON list of galleries",
9599
EnvVars: []string{"GALLERIES"},
96100
},
101+
&cli.StringFlag{
102+
Name: "remote-library",
103+
Usage: "A LocalAI remote library URL",
104+
EnvVars: []string{"REMOTE_LIBRARY"},
105+
Value: remoteLibraryURL,
106+
},
97107
&cli.StringFlag{
98108
Name: "preload-models",
99109
Usage: "A List of models to apply in JSON at start",
@@ -219,6 +229,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
219229
options.WithAudioDir(ctx.String("audio-path")),
220230
options.WithF16(ctx.Bool("f16")),
221231
options.WithStringGalleries(ctx.String("galleries")),
232+
options.WithModelLibraryURL(ctx.String("remote-library")),
222233
options.WithDisableMessage(false),
223234
options.WithCors(ctx.Bool("cors")),
224235
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),

pkg/startup/model_preload.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,22 @@ import (
1414
// PreloadModelsConfigurations will preload models from the given list of URLs
1515
// It will download the model if it is not already present in the model path
1616
// It will also try to resolve if the model is an embedded model YAML configuration
17-
func PreloadModelsConfigurations(modelPath string, models ...string) {
17+
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
1818
for _, url := range models {
19-
url = embedded.ModelShortURL(url)
2019

20+
// As a best effort, try to resolve the model from the remote library
21+
// if it's not resolved we try with the other method below
22+
if modelLibraryURL != "" {
23+
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
24+
if err == nil {
25+
if lib[url] != "" {
26+
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
27+
url = lib[url]
28+
}
29+
}
30+
}
31+
32+
url = embedded.ModelShortURL(url)
2133
switch {
2234
case embedded.ExistsInModelsLibrary(url):
2335
modelYAML, err := embedded.ResolveContent(url)

pkg/startup/model_preload_test.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,29 @@ import (
1515
var _ = Describe("Preload test", func() {
1616

1717
Context("Preloading from strings", func() {
18+
It("loads from remote url", func() {
19+
tmpdir, err := os.MkdirTemp("", "")
20+
Expect(err).ToNot(HaveOccurred())
21+
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
22+
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
23+
24+
PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")
25+
26+
resultFile := filepath.Join(tmpdir, fileName)
27+
28+
content, err := os.ReadFile(resultFile)
29+
Expect(err).ToNot(HaveOccurred())
30+
31+
Expect(string(content)).To(ContainSubstring("name: phi-2"))
32+
})
33+
1834
It("loads from embedded full-urls", func() {
1935
tmpdir, err := os.MkdirTemp("", "")
2036
Expect(err).ToNot(HaveOccurred())
2137
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
2238
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
2339

24-
PreloadModelsConfigurations(tmpdir, url)
40+
PreloadModelsConfigurations("", tmpdir, url)
2541

2642
resultFile := filepath.Join(tmpdir, fileName)
2743

@@ -35,7 +51,7 @@ var _ = Describe("Preload test", func() {
3551
Expect(err).ToNot(HaveOccurred())
3652
url := "phi-2"
3753

38-
PreloadModelsConfigurations(tmpdir, url)
54+
PreloadModelsConfigurations("", tmpdir, url)
3955

4056
entry, err := os.ReadDir(tmpdir)
4157
Expect(err).ToNot(HaveOccurred())
@@ -53,7 +69,7 @@ var _ = Describe("Preload test", func() {
5369
url := "mistral-openorca"
5470
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
5571

56-
PreloadModelsConfigurations(tmpdir, url)
72+
PreloadModelsConfigurations("", tmpdir, url)
5773

5874
resultFile := filepath.Join(tmpdir, fileName)
5975

0 commit comments

Comments
 (0)