Skip to content

Commit c12431f

Browse files
committed
refactor(inference): applyAndMerge helper, fix local-path dispatch, checksum validation (phase 5)
1 parent 59c1395 commit c12431f

10 files changed

Lines changed: 244 additions & 143 deletions

File tree

pkg/aikit/config/validate.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ package config
22

33
import (
44
"errors"
5+
"regexp"
56
"slices"
67

78
"github.com/kaito-project/aikit/pkg/utils"
89
pkgerrors "github.com/pkg/errors"
910
)
1011

12+
// sha256HexPattern matches a bare 64-character lowercase hex SHA256 checksum.
13+
var sha256HexPattern = regexp.MustCompile(`^[a-f0-9]{64}$`)
14+
1115
// Validate checks that the inference config is internally consistent and only
1216
// references supported backends and runtimes. Membership errors (unknown
1317
// backend / unknown runtime) are accumulated with errors.Join so that a config
@@ -71,6 +75,15 @@ func (c *InferenceConfig) Validate() error {
7175
}
7276
}
7377

78+
// Validate any provided model checksums up front so a malformed value fails
79+
// the build immediately with a clear message rather than producing a broken
80+
// digest deep in LLB construction.
81+
for _, m := range c.Models {
82+
if m.SHA256 != "" && !sha256HexPattern.MatchString(m.SHA256) {
83+
return pkgerrors.Errorf("model %q has an invalid sha256 checksum %q: expected 64 lowercase hex characters", m.Name, m.SHA256)
84+
}
85+
}
86+
7487
return nil
7588
}
7689

pkg/aikit/config/validate_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package config
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/kaito-project/aikit/pkg/utils"
8+
)
9+
10+
func TestInferenceConfigValidateChecksum(t *testing.T) {
11+
validSHA := strings.Repeat("a", 64)
12+
tests := []struct {
13+
name string
14+
sha string
15+
wantErr bool
16+
}{
17+
{name: "empty checksum is allowed", sha: "", wantErr: false},
18+
{name: "valid 64-char hex", sha: validSHA, wantErr: false},
19+
{name: "too short", sha: "abc123", wantErr: true},
20+
{name: "uppercase rejected", sha: strings.Repeat("A", 64), wantErr: true},
21+
{name: "algo-prefixed rejected", sha: "sha256:" + validSHA, wantErr: true},
22+
}
23+
for _, tt := range tests {
24+
t.Run(tt.name, func(t *testing.T) {
25+
c := &InferenceConfig{
26+
APIVersion: utils.APIv1alpha1,
27+
Backends: []string{utils.BackendLlamaCpp},
28+
Models: []Model{{Name: "m", Source: "http://x/m.gguf", SHA256: tt.sha}},
29+
}
30+
err := c.Validate()
31+
if (err != nil) != tt.wantErr {
32+
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
33+
}
34+
})
35+
}
36+
}
37+
38+
func TestInferenceConfigValidateAggregatesMembershipErrors(t *testing.T) {
39+
// Both an unknown backend and an unknown runtime should be reported together.
40+
c := &InferenceConfig{
41+
APIVersion: utils.APIv1alpha1,
42+
Runtime: "bogus-runtime",
43+
Backends: []string{"bogus-backend"},
44+
}
45+
err := c.Validate()
46+
if err == nil {
47+
t.Fatal("expected error for invalid backend and runtime")
48+
}
49+
msg := err.Error()
50+
if !strings.Contains(msg, "bogus-backend") || !strings.Contains(msg, "bogus-runtime") {
51+
t.Errorf("expected aggregated error to mention both backend and runtime, got: %s", msg)
52+
}
53+
}

pkg/aikit2llb/inference/backend.go

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -163,51 +163,51 @@ func installBackend(backend string, c *config.InferenceConfig, platform specs.Pl
163163
ociImage := fmt.Sprintf("%s:%s", utils.BackendOCIRegistry, tag)
164164

165165
// Create the backends directory
166-
savedState := s
167166
backendName := getBackendName(backend, c.Runtime, platform)
168167
backendDir := fmt.Sprintf("/backends/%s", backendName)
169168

170169
// Download the backend from OCI registry and extract to specific backend directory
171170
backendState := llb.Image(ociImage, llb.Platform(platform))
172171

173-
// Copy the backend files to the specific backend directory
174-
s = s.File(
175-
llb.Copy(backendState, "/", backendDir+"/", &llb.CopyInfo{
176-
CreateDestPath: true,
177-
AllowWildcard: true,
178-
}),
179-
llb.WithCustomName(fmt.Sprintf("Installing backend %s from %s", backend, ociImage)),
180-
)
181-
182-
// Ensure the directory exists and create metadata.json for the backend
183-
backendAlias := getBackendAlias(backend)
184-
metadataContent := fmt.Sprintf(`{
172+
_, merge = applyAndMerge(s, merge, func(s llb.State) llb.State {
173+
// Copy the backend files to the specific backend directory
174+
s = s.File(
175+
llb.Copy(backendState, "/", backendDir+"/", &llb.CopyInfo{
176+
CreateDestPath: true,
177+
AllowWildcard: true,
178+
}),
179+
llb.WithCustomName(fmt.Sprintf("Installing backend %s from %s", backend, ociImage)),
180+
)
181+
182+
// Ensure the directory exists and create metadata.json for the backend
183+
backendAlias := getBackendAlias(backend)
184+
metadataContent := fmt.Sprintf(`{
185185
"alias": "%s",
186186
"name": "%s",
187187
"gallery_url": "github:mudler/LocalAI/backend/index.yaml@master",
188188
"installed_at": "%s"
189189
}`, backendAlias, backendName, time.Now().UTC().Format(time.RFC3339))
190190

191-
s = s.File(
192-
llb.Mkfile(fmt.Sprintf("%s/metadata.json", backendDir), 0o644, []byte(metadataContent)),
193-
llb.WithCustomName(fmt.Sprintf("Creating metadata.json for backend %s", backendName)),
194-
)
195-
196-
// Apply workarounds for the pre-built vLLM backend image.
197-
if backend == utils.BackendVLLM {
198-
// Remove broken flash_attn package (PyTorch ABI incompatibility).
199-
// Patch backend.py to use the current vLLM AsyncLLM API
200-
// (get_model_config() was replaced by the model_config property).
201-
s = s.Run(utils.Shf(
202-
"rm -rf %[1]s/venv/lib/python*/site-packages/flash_attn* && "+
203-
"sed -i 's/await self.llm.get_model_config()/self.llm.model_config/' %[1]s/backend.py",
204-
backendDir),
205-
llb.WithCustomNamef("Patching vLLM backend %s for compatibility", backendName),
206-
).Root()
207-
}
208-
209-
diff := llb.Diff(savedState, s)
210-
return llb.Merge([]llb.State{merge, diff})
191+
s = s.File(
192+
llb.Mkfile(fmt.Sprintf("%s/metadata.json", backendDir), 0o644, []byte(metadataContent)),
193+
llb.WithCustomName(fmt.Sprintf("Creating metadata.json for backend %s", backendName)),
194+
)
195+
196+
// Apply workarounds for the pre-built vLLM backend image.
197+
if backend == utils.BackendVLLM {
198+
// Remove broken flash_attn package (PyTorch ABI incompatibility).
199+
// Patch backend.py to use the current vLLM AsyncLLM API
200+
// (get_model_config() was replaced by the model_config property).
201+
s = s.Run(utils.Shf(
202+
"rm -rf %[1]s/venv/lib/python*/site-packages/flash_attn* && "+
203+
"sed -i 's/await self.llm.get_model_config()/self.llm.model_config/' %[1]s/backend.py",
204+
backendDir),
205+
llb.WithCustomNamef("Patching vLLM backend %s for compatibility", backendName),
206+
).Root()
207+
}
208+
return s
209+
})
210+
return merge
211211
}
212212

213213
// getDefaultBackends returns the default backends based on runtime if no backends are specified.

pkg/aikit2llb/inference/convert.go

Lines changed: 62 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package inference
22

33
import (
44
"fmt"
5-
"net/url"
65
"slices"
76
"strings"
87

@@ -90,39 +89,37 @@ func getBaseImage(c *config.InferenceConfig, platform *specs.Platform) llb.State
9089

9190
// writeConfig writes the /config.yaml file to the image when c.Config is set.
9291
func writeConfig(c *config.InferenceConfig, base llb.State, s llb.State, platform specs.Platform) (llb.State, llb.State) {
93-
savedState := s
94-
if c.Config != "" {
95-
s = s.File(
96-
llb.Mkfile("/config.yaml", 0o644, []byte(c.Config)),
97-
llb.WithCustomName(fmt.Sprintf("Creating config for platform %s/%s", platform.OS, platform.Architecture)),
98-
)
99-
}
100-
diff := llb.Diff(savedState, s)
101-
merge := llb.Merge([]llb.State{base, diff})
102-
return s, merge
92+
return applyAndMerge(s, base, func(s llb.State) llb.State {
93+
if c.Config != "" {
94+
s = s.File(
95+
llb.Mkfile("/config.yaml", 0o644, []byte(c.Config)),
96+
llb.WithCustomName(fmt.Sprintf("Creating config for platform %s/%s", platform.OS, platform.Architecture)),
97+
)
98+
}
99+
return s
100+
})
103101
}
104102

105103
// copyModels copies models to the image and writes the config.
106104
func copyModels(c *config.InferenceConfig, base llb.State, s llb.State, platform specs.Platform) (llb.State, llb.State, error) {
107105
savedState := s
108106
for _, model := range c.Models {
109-
// Check if the model source is a URL
110-
if _, err := url.ParseRequestURI(model.Source); err == nil {
111-
switch {
112-
case strings.HasPrefix(model.Source, "oci://"):
113-
s = handleOCI(model.Source, s, platform)
114-
case strings.HasPrefix(model.Source, "http://"), strings.HasPrefix(model.Source, "https://"):
115-
s = handleHTTP(model.Source, model.Name, model.SHA256, s)
116-
case strings.HasPrefix(model.Source, "huggingface://"):
117-
s, err = handleHuggingFace(model.Source, s)
118-
if err != nil {
119-
return llb.State{}, llb.State{}, err
120-
}
121-
default:
122-
return llb.State{}, llb.State{}, fmt.Errorf("unsupported URL scheme: %s", model.Source)
107+
// Dispatch on the source's URI scheme. Anything without a recognized
108+
// scheme (including absolute local paths like /models/foo.gguf) is treated
109+
// as a local file. The previous url.ParseRequestURI guard incorrectly
110+
// rejected absolute local paths, which parse as URIs with an empty scheme.
111+
var err error
112+
switch {
113+
case strings.HasPrefix(model.Source, "oci://"):
114+
s = handleOCI(model.Source, s, platform)
115+
case strings.HasPrefix(model.Source, "http://"), strings.HasPrefix(model.Source, "https://"):
116+
s = handleHTTP(model.Source, model.Name, model.SHA256, s)
117+
case strings.HasPrefix(model.Source, "huggingface://"):
118+
s, err = handleHuggingFace(model.Source, s)
119+
if err != nil {
120+
return llb.State{}, llb.State{}, err
123121
}
124-
} else {
125-
// Handle local paths
122+
default:
126123
s = handleLocal(model.Source, s)
127124
}
128125

@@ -155,50 +152,47 @@ func installCuda(c *config.InferenceConfig, s llb.State, merge llb.State) (llb.S
155152
)
156153
s = s.Run(utils.Sh("dpkg -i cuda-keyring_1.1-1_all.deb && rm cuda-keyring_1.1-1_all.deb")).Root()
157154

158-
savedState := s
159-
// running apt-get update twice due to nvidia repo
160-
s = s.Run(utils.Sh("apt-get update && apt-get install --no-install-recommends -y ca-certificates && apt-get update"), llb.IgnoreCache).Root()
161-
162-
// install cuda libraries for llama-cpp (default) and vllm backends
163-
if len(c.Backends) == 0 || slices.Contains(c.Backends, utils.BackendLlamaCpp) || slices.Contains(c.Backends, utils.BackendVLLM) {
164-
// install cuda libraries and pciutils for gpu detection
165-
s = s.Run(utils.Shf("apt-get install -y --no-install-recommends pciutils libcublas-%[1]s cuda-cudart-%[1]s && apt-get clean", cudaVersion)).Root()
166-
// TODO: clean up /var/lib/dpkg/status
167-
}
155+
return applyAndMerge(s, merge, func(s llb.State) llb.State {
156+
// running apt-get update twice due to nvidia repo
157+
s = s.Run(utils.Sh("apt-get update && apt-get install --no-install-recommends -y ca-certificates && apt-get update"), llb.IgnoreCache).Root()
168158

169-
diff := llb.Diff(savedState, s)
170-
return s, llb.Merge([]llb.State{merge, diff})
159+
// install cuda libraries for llama-cpp (default) and vllm backends
160+
if len(c.Backends) == 0 || slices.Contains(c.Backends, utils.BackendLlamaCpp) || slices.Contains(c.Backends, utils.BackendVLLM) {
161+
// install cuda libraries and pciutils for gpu detection
162+
s = s.Run(utils.Shf("apt-get install -y --no-install-recommends pciutils libcublas-%[1]s cuda-cudart-%[1]s && apt-get clean", cudaVersion)).Root()
163+
// TODO: clean up /var/lib/dpkg/status
164+
}
165+
return s
166+
})
171167
}
172168

173169
func installRocm(c *config.InferenceConfig, s llb.State, merge llb.State) (llb.State, llb.State) {
174-
savedState := s
175-
176-
// Set up ROCm repository
177-
s = s.Run(utils.Sh("apt-get update && apt-get install --no-install-recommends -y ca-certificates curl gnupg"), llb.IgnoreCache).Root()
178-
179-
// Add ROCm GPG key and repository
180-
s = s.Run(utils.Sh("curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm.gpg")).Root()
181-
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/rocm/apt/%s/ noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
182-
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/graphics/%s/ubuntu noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
183-
rocmPinning := `
170+
return applyAndMerge(s, merge, func(s llb.State) llb.State {
171+
// Set up ROCm repository
172+
s = s.Run(utils.Sh("apt-get update && apt-get install --no-install-recommends -y ca-certificates curl gnupg"), llb.IgnoreCache).Root()
173+
174+
// Add ROCm GPG key and repository
175+
s = s.Run(utils.Sh("curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm.gpg")).Root()
176+
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/rocm/apt/%s/ noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
177+
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/graphics/%s/ubuntu noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
178+
rocmPinning := `
184179
Package: *
185180
Pin: release o=repo.radeon.com
186181
Pin-Priority: 600
187182
`
188-
s = s.Run(utils.Shf("echo '%s' > /etc/apt/preferences.d/repo-radeon-pin-600", rocmPinning)).Root()
189-
s = s.Run(utils.Sh("apt-get update"), llb.IgnoreCache).Root()
190-
191-
// install rocm libraries and pciutils for gpu detection when using the default
192-
// llama-cpp backend or when it is configured explicitly
193-
if len(c.Backends) == 0 || slices.Contains(c.Backends, utils.BackendLlamaCpp) {
194-
s = s.Run(utils.Sh("apt-get install -y pciutils rocm && apt-get clean")).Root()
195-
}
183+
s = s.Run(utils.Shf("echo '%s' > /etc/apt/preferences.d/repo-radeon-pin-600", rocmPinning)).Root()
184+
s = s.Run(utils.Sh("apt-get update"), llb.IgnoreCache).Root()
196185

197-
// hipblaslt soname compatibility: backend may be linked against .so.0 while ROCm 7.2 ships .so.1
198-
s = s.Run(utils.Sh("set -e; cd /opt/rocm/lib; [ -e libhipblaslt.so.0 ] || ln -sf libhipblaslt.so.1 libhipblaslt.so.0")).Root()
186+
// install rocm libraries and pciutils for gpu detection when using the default
187+
// llama-cpp backend or when it is configured explicitly
188+
if len(c.Backends) == 0 || slices.Contains(c.Backends, utils.BackendLlamaCpp) {
189+
s = s.Run(utils.Sh("apt-get install -y pciutils rocm && apt-get clean")).Root()
190+
}
199191

200-
diff := llb.Diff(savedState, s)
201-
return s, llb.Merge([]llb.State{merge, diff})
192+
// hipblaslt soname compatibility: backend may be linked against .so.0 while ROCm 7.2 ships .so.1
193+
s = s.Run(utils.Sh("set -e; cd /opt/rocm/lib; [ -e libhipblaslt.so.0 ] || ln -sf libhipblaslt.so.1 libhipblaslt.so.0")).Root()
194+
return s
195+
})
202196
}
203197

204198
// addLocalAI adds the LocalAI binary to the image.
@@ -218,20 +212,18 @@ func addLocalAI(c *config.InferenceConfig, s llb.State, merge llb.State, platfor
218212
return s, merge, fmt.Errorf("unsupported architecture %s", platform.Architecture)
219213
}
220214

221-
savedState := s
222-
223215
// Use the oras CLI image to pull the artifact containing the LocalAI binary
224216
tooling := llb.Image(orasImage, llb.Platform(platform)).Run(
225217
utils.Shf("set -e\noras pull %[1]s\nchmod +x local-ai\nchmod 755 local-ai", art.Ref),
226218
llb.WithCustomName("Pulling LocalAI from OCI artifact "+art.Ref),
227219
).Root()
228220

229221
// Copy the prepared binary into /usr/bin/local-ai
230-
s = s.File(
231-
llb.Copy(tooling, "local-ai", "/usr/bin/local-ai"),
232-
llb.WithCustomName("Copying local-ai from OCI artifact to /usr/bin"),
233-
)
234-
235-
diff := llb.Diff(savedState, s)
236-
return s, llb.Merge([]llb.State{merge, diff}), nil
222+
s, merge = applyAndMerge(s, merge, func(s llb.State) llb.State {
223+
return s.File(
224+
llb.Copy(tooling, "local-ai", "/usr/bin/local-ai"),
225+
llb.WithCustomName("Copying local-ai from OCI artifact to /usr/bin"),
226+
)
227+
})
228+
return s, merge, nil
237229
}

pkg/aikit2llb/inference/convert_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/kaito-project/aikit/pkg/aikit/config"
99
"github.com/kaito-project/aikit/pkg/utils"
1010
"github.com/moby/buildkit/client/llb"
11+
specs "github.com/opencontainers/image-spec/specs-go/v1"
1112
)
1213

1314
func TestInstallRocmInstallsPciutilsForLlamaCpp(t *testing.T) {
@@ -61,3 +62,28 @@ func marshalDefinitionToString(def *llb.Definition) string {
6162

6263
return combined.String()
6364
}
65+
66+
// TestCopyModelsAbsoluteLocalPath guards the scheme-dispatch fix: an absolute
67+
// local model path (no URI scheme) must be treated as a local file, not
68+
// rejected. The previous url.ParseRequestURI guard caused absolute paths to
69+
// fall through to a hard "unsupported URL scheme" error.
70+
func TestCopyModelsAbsoluteLocalPath(t *testing.T) {
71+
cfg := &config.InferenceConfig{
72+
Runtime: "",
73+
Models: []config.Model{
74+
{Name: "local", Source: "/models/local.gguf"},
75+
},
76+
}
77+
78+
platform := specs.Platform{OS: utils.PlatformLinux, Architecture: utils.PlatformAMD64}
79+
base := llb.Image(utils.UbuntuBase)
80+
state, merged, err := copyModels(cfg, base, base, platform)
81+
if err != nil {
82+
t.Fatalf("copyModels returned error for absolute local path: %v", err)
83+
}
84+
85+
if _, err := merged.Marshal(context.Background()); err != nil {
86+
t.Fatalf("marshal failed: %v", err)
87+
}
88+
_ = state
89+
}

0 commit comments

Comments
 (0)