Skip to content

Commit 534dd3d

Browse files
keithmattixCopilotsozercan
authored
Add support for ROCm / Strix Halo (#771)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Sertac Ozercan <sozercan@gmail.com> Co-authored-by: Sertaç Özercan <852750+sozercan@users.noreply.github.com>
1 parent b497e3a commit 534dd3d

22 files changed

Lines changed: 525 additions & 27 deletions

CONTRIBUTING.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ This will automatically run linting and formatting checks before each commit.
5757

5858
## Building AIKit
5959

60+
> [!TIP]
61+
> Build targets default to multi-platform (`linux/amd64,linux/arm64`). For local development, pass your host architecture to speed up builds and avoid multi-platform issues — e.g. `make build-aikit PLATFORMS=linux/amd64`. You should also use the `default` buildx builder (`docker buildx use default`) so that locally built images are available to subsequent builds via the `#syntax=` directive.
62+
6063
### Build the AIKit Binary
6164

6265
```bash

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ run-test-model:
5151
run-test-model-gpu:
5252
docker run --rm -p 8080:8080 --gpus all ${REGISTRY}${REPOSITORY}/${TEST_IMAGE_NAME}:${TAG}
5353

54+
.PHONY: run-test-model-rocm
55+
run-test-model-rocm:
56+
docker run --rm -p 8080:8080 --device /dev/kfd --device /dev/dri --group-add video --group-add $$(stat -c '%g' /dev/dri/renderD128) \
57+
${REGISTRY}${REPOSITORY}/${TEST_IMAGE_NAME}:${TAG}
58+
5459
.PHONY: run-test-model-applesilicon
5560
run-test-model-applesilicon:
5661
podman run --rm -p 8080:8080 --device /dev/dri ${REGISTRY}${REPOSITORY}/${TEST_IMAGE_NAME}:${TAG}

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ AIKit offers three main capabilities:
2929
- 🦙 Support for GGUF ([`llama`](https://github.com/ggerganov/llama.cpp)) and GGML ([`llama-ggml`](https://github.com/ggerganov/llama.cpp)) models
3030
- 🚢 [Kubernetes deployment ready](https://kaito-project.github.io/aikit/docs/kubernetes)
3131
- 📚 Supports multiple models with a single image
32-
- 🖥️ Supports [AMD64 and ARM64](https://kaito-project.github.io/aikit/docs/create-images#multi-platform-support) CPUs and [GPU-accelerated inferencing with NVIDIA GPUs](https://kaito-project.github.io/aikit/docs/gpu)
32+
- 🖥️ Supports [AMD64 and ARM64](https://kaito-project.github.io/aikit/docs/create-images#multi-platform-support) CPUs and [GPU-accelerated inferencing with NVIDIA CUDA and AMD ROCm support](https://kaito-project.github.io/aikit/docs/gpu)
3333
- 🔐 Ensure [supply chain security](https://kaito-project.github.io/aikit/docs/security) with SBOMs, Provenance attestations, and signed images
3434
- 🌈 Supports air-gapped environments with self-hosted, local, or any remote container registries to store model images for inference on the edge.
3535

@@ -107,9 +107,9 @@ If it doesn't include a specific model, you can always [create your own images](
107107
### NVIDIA CUDA
108108

109109
> [!NOTE]
110-
> To enable GPU acceleration, please see [GPU Acceleration](https://kaito-project.github.io/aikit/docs/gpu).
110+
> To enable NVIDIA GPU acceleration, please see [GPU Acceleration](https://kaito-project.github.io/aikit/docs/gpu).
111111
>
112-
> Please note that only difference between CPU and GPU section is the `--gpus all` flag in the command to enable GPU acceleration.
112+
> Published pre-made GPU images include NVIDIA CUDA libraries. For the NVIDIA CUDA commands below, the only difference from the CPU section is the `--gpus all` flag.
113113
114114
| Model | Optimization | Parameters | Command | Model Name | License |
115115
| --------------- | ------------- | ---------- | -------------------------------------------------------------------------------------- | ------------------------ | --------------------------------------------------------------------------------------------------------------------------- |
@@ -127,6 +127,14 @@ If it doesn't include a specific model, you can always [create your own images](
127127
| 🤖 GPT-OSS | | 120B | `docker run -d --rm --gpus all -p 8080:8080 ghcr.io/kaito-project/aikit/gpt-oss:120b` | `gpt-oss-120b` | [Apache 2.0](https://choosealicense.com/licenses/apache-2.0/) |
128128

129129

130+
### AMD ROCm (experimental)
131+
132+
> [!NOTE]
133+
> AMD GPU acceleration is currently available for custom `llama-cpp` images built with `runtime: rocm`. Published pre-made model images are currently CUDA-based, so for AMD GPUs please [create your own image](https://kaito-project.github.io/aikit/docs/create-images) and follow the ROCm instructions in [GPU Acceleration](https://kaito-project.github.io/aikit/docs/gpu).
134+
>
135+
> ROCm support currently applies to the `llama-cpp` backend on `linux/amd64`.
136+
137+
130138
### Apple Silicon (experimental)
131139

132140
> [!NOTE]

pkg/aikit2llb/inference/backend.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func getBackendTag(backend, runtime string, platform specs.Platform) string {
8181
baseTag := getBackendVersion(backend, runtime, platform)
8282
backendName := getEffectiveBackend(backend, runtime, platform)
8383

84-
// Handle Apple Silicon - use Vulkan llama-cpp
84+
// Handle Apple Silicon - use Vulkan llama-cpp.
8585
if runtime == utils.RuntimeAppleSilicon {
8686
return fmt.Sprintf("%s-%s", baseTag, vulkanLlamaCppBackend)
8787
}
@@ -101,6 +101,12 @@ func getBackendTag(backend, runtime string, platform specs.Platform) string {
101101
}
102102
}
103103

104+
// Handle ROCm runtime.
105+
if runtime == utils.RuntimeROCm && platform.Architecture == utils.PlatformAMD64 {
106+
return fmt.Sprintf("%s-gpu-rocm-hipblas-llama-cpp", localAIROCmBackendVersion)
107+
}
108+
109+
// Handle CPU runtime (default).
104110
return fmt.Sprintf("%s-cpu-llama-cpp", baseTag)
105111
}
106112

@@ -131,6 +137,12 @@ func getBackendName(backend, runtime string, platform specs.Platform) string {
131137
}
132138
}
133139

140+
// Handle ROCm runtime
141+
if runtime == utils.RuntimeROCm && platform.Architecture == utils.PlatformAMD64 {
142+
// Only llama-cpp backend is supported for ROCm
143+
return "hipblas-llama-cpp"
144+
}
145+
134146
// Handle CPU runtime (default)
135147
return cpuLlamaCppBackend
136148
}
@@ -220,6 +232,14 @@ func installBackends(c *config.InferenceConfig, platform specs.Platform, s llb.S
220232
cpuConfig.Runtime = "cpu" // Use CPU runtime to force CPU backend installation
221233
merge = installBackend(backend, &cpuConfig, platform, s, merge)
222234
}
235+
236+
// For llama-cpp backend with ROCm runtime, also install the CPU version for fallback
237+
if backend == utils.BackendLlamaCpp && c.Runtime == utils.RuntimeROCm && platform.Architecture == utils.PlatformAMD64 {
238+
// Create a modified config with CPU runtime to install the CPU version
239+
cpuConfig := *c
240+
cpuConfig.Runtime = "cpu" // Use CPU runtime to force CPU backend installation
241+
merge = installBackend(backend, &cpuConfig, platform, s, merge)
242+
}
223243
}
224244

225245
return merge

pkg/aikit2llb/inference/backend_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ func TestGetBackendTag(t *testing.T) {
9898
},
9999
want: fmt.Sprintf("%s-gpu-nvidia-cuda-12-llama-cpp", localAILlamaCppBackendVersion),
100100
},
101+
{
102+
name: "ROCm llama-cpp",
103+
backend: utils.BackendLlamaCpp,
104+
runtime: utils.RuntimeROCm,
105+
platform: specs.Platform{
106+
Architecture: utils.PlatformAMD64,
107+
},
108+
want: fmt.Sprintf("%s-gpu-rocm-hipblas-llama-cpp", localAIROCmBackendVersion),
109+
},
101110
{
102111
name: "Empty backend name defaults to CPU llama-cpp",
103112
backend: "",

pkg/aikit2llb/inference/convert.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@ const (
1717
localAIBinaryVersion = "v4.0.0"
1818
localAILlamaCppBackendVersion = localAIBinaryVersion
1919
localAILegacyBackendVersion = "v3.12.1"
20+
localAIROCmBackendVersion = "rocm7"
2021
localAIRepo = "ghcr.io/kaito-project/aikit/localai:"
2122
cudaVersion = "12-5"
23+
rocmVersion = "7.2"
2224
)
2325

2426
// Aikit2LLB converts an InferenceConfig to an LLB state.
2527
func Aikit2LLB(c *config.InferenceConfig, platform *specs.Platform) (llb.State, *specs.Image, error) {
2628
var merge, state llb.State
27-
if c.Runtime == utils.RuntimeAppleSilicon {
29+
switch c.Runtime {
30+
case utils.RuntimeAppleSilicon:
2831
state = llb.Image(utils.AppleSiliconBase, llb.Platform(*platform))
29-
} else {
32+
case utils.RuntimeROCm:
33+
// Use Ubuntu 24.04 for ROCm to match noble repository
34+
state = llb.Image(utils.Ubuntu24Base, llb.Platform(*platform))
35+
default:
3036
state = llb.Image(utils.UbuntuBase, llb.Platform(*platform))
3137
}
3238
base := getBaseImage(c, platform)
@@ -55,6 +61,11 @@ func Aikit2LLB(c *config.InferenceConfig, platform *specs.Platform) (llb.State,
5561
state, merge = installCuda(c, state, merge)
5662
}
5763

64+
// install rocm if runtime is rocm and architecture is amd64
65+
if c.Runtime == utils.RuntimeROCm && platform.Architecture == utils.PlatformAMD64 {
66+
state, merge = installRocm(c, state, merge)
67+
}
68+
5869
// install backend dependencies
5970
merge = installBackends(c, *platform, state, merge)
6071

@@ -67,6 +78,10 @@ func getBaseImage(c *config.InferenceConfig, platform *specs.Platform) llb.State
6778
if c.Runtime == utils.RuntimeAppleSilicon {
6879
return llb.Image(utils.AppleSiliconBase, llb.Platform(*platform))
6980
}
81+
if c.Runtime == utils.RuntimeROCm {
82+
// Use Ubuntu 24.04 for ROCm to match noble repository.
83+
return llb.Image(utils.Ubuntu24Base, llb.Platform(*platform))
84+
}
7085
if len(c.Backends) > 0 {
7186
return llb.Image(utils.UbuntuBase, llb.Platform(*platform))
7287
}
@@ -155,6 +170,37 @@ func installCuda(c *config.InferenceConfig, s llb.State, merge llb.State) (llb.S
155170
return s, llb.Merge([]llb.State{merge, diff})
156171
}
157172

173+
func installRocm(c *config.InferenceConfig, s llb.State, merge llb.State) (llb.State, llb.State) {
174+
savedState := s
175+
176+
// Set up ROCm repository
177+
s = s.Run(utils.Sh("apt-get update && apt-get install --no-install-recommends -y ca-certificates curl gnupg"), llb.IgnoreCache).Root()
178+
179+
// Add ROCm GPG key and repository
180+
s = s.Run(utils.Sh("curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm.gpg")).Root()
181+
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/rocm/apt/%s/ noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
182+
s = s.Run(utils.Shf("echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm.gpg] https://repo.radeon.com/graphics/%s/ubuntu noble main' >> /etc/apt/sources.list.d/rocm.list", rocmVersion)).Root()
183+
rocmPinning := `
184+
Package: *
185+
Pin: release o=repo.radeon.com
186+
Pin-Priority: 600
187+
`
188+
s = s.Run(utils.Shf("echo '%s' > /etc/apt/preferences.d/repo-radeon-pin-600", rocmPinning)).Root()
189+
s = s.Run(utils.Sh("apt-get update"), llb.IgnoreCache).Root()
190+
191+
// install rocm libraries and pciutils for gpu detection when using the default
192+
// llama-cpp backend or when it is configured explicitly
193+
if len(c.Backends) == 0 || slices.Contains(c.Backends, utils.BackendLlamaCpp) {
194+
s = s.Run(utils.Sh("apt-get install -y pciutils rocm && apt-get clean")).Root()
195+
}
196+
197+
// hipblaslt soname compatibility: backend may be linked against .so.0 while ROCm 7.2 ships .so.1
198+
s = s.Run(utils.Sh("set -e; cd /opt/rocm/lib; [ -e libhipblaslt.so.0 ] || ln -sf libhipblaslt.so.1 libhipblaslt.so.0")).Root()
199+
200+
diff := llb.Diff(savedState, s)
201+
return s, llb.Merge([]llb.State{merge, diff})
202+
}
203+
158204
// addLocalAI adds the LocalAI binary to the image.
159205
func addLocalAI(c *config.InferenceConfig, s llb.State, merge llb.State, platform specs.Platform) (llb.State, llb.State, error) {
160206
artifactVersion := getLocalAIArtifactVersion(c, platform)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package inference
2+
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
8+
"github.com/kaito-project/aikit/pkg/aikit/config"
9+
"github.com/kaito-project/aikit/pkg/utils"
10+
"github.com/moby/buildkit/client/llb"
11+
)
12+
13+
func TestInstallRocmInstallsPciutilsForLlamaCpp(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
backends []string
17+
}{
18+
{
19+
name: "implicit default llama-cpp backend",
20+
backends: nil,
21+
},
22+
{
23+
name: "explicit llama-cpp backend",
24+
backends: []string{utils.BackendLlamaCpp},
25+
},
26+
}
27+
28+
for _, tt := range tests {
29+
t.Run(tt.name, func(t *testing.T) {
30+
cfg := &config.InferenceConfig{
31+
Runtime: utils.RuntimeROCm,
32+
Backends: tt.backends,
33+
}
34+
35+
base := llb.Image(utils.Ubuntu24Base)
36+
_, merged := installRocm(cfg, base, base)
37+
38+
def, err := merged.Marshal(context.Background())
39+
if err != nil {
40+
t.Fatalf("marshal failed: %v", err)
41+
}
42+
43+
combined := marshalDefinitionToString(def)
44+
wantInstall := "apt-get install -y pciutils rocm && apt-get clean"
45+
if !strings.Contains(combined, wantInstall) {
46+
t.Fatalf("expected ROCm install to contain %q, got: %s", wantInstall, combined)
47+
}
48+
})
49+
}
50+
}
51+
52+
func marshalDefinitionToString(def *llb.Definition) string {
53+
if def == nil {
54+
return ""
55+
}
56+
57+
var combined strings.Builder
58+
for _, d := range def.ToPB().Def {
59+
combined.Write(d)
60+
}
61+
62+
return combined.String()
63+
}

pkg/aikit2llb/inference/image.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,14 @@ func emptyImage(c *config.InferenceConfig, platform *specs.Platform) *specs.Imag
7777
)
7878
}
7979

80+
rocmEnv := []string{
81+
"PATH=" + system.DefaultPathEnv(utils.PlatformLinux) + ":/opt/rocm/bin",
82+
"LD_LIBRARY_PATH=/opt/rocm/lib:/opt/rocm/lib64:/opt/rocm/llvm/lib",
83+
"LOCALAI_FORCE_META_BACKEND_CAPABILITY=amd",
84+
}
85+
if c.Runtime == utils.RuntimeROCm && platform.Architecture == "amd64" {
86+
img.Config.Env = append(img.Config.Env, rocmEnv...)
87+
}
88+
8089
return img
8190
}

pkg/build/build.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,14 +491,22 @@ func validateInferenceConfig(c *config.InferenceConfig) error {
491491
return errors.New("runner mode (backends without models) is not supported on apple silicon runtime")
492492
}
493493

494+
if c.Runtime == utils.RuntimeROCm && len(c.Backends) > 0 {
495+
for _, backend := range c.Backends {
496+
if backend != utils.BackendLlamaCpp {
497+
return errors.New("rocm runtime only supports llama-cpp backend")
498+
}
499+
}
500+
}
501+
494502
backends := []string{utils.BackendLlamaCpp, utils.BackendDiffusers, utils.BackendVLLM}
495503
for _, b := range c.Backends {
496504
if !slices.Contains(backends, b) {
497505
return errors.Errorf("backend %s is not supported", b)
498506
}
499507
}
500508

501-
runtimes := []string{"", utils.RuntimeNVIDIA, utils.RuntimeAppleSilicon}
509+
runtimes := []string{"", utils.RuntimeNVIDIA, utils.RuntimeROCm, utils.RuntimeAppleSilicon}
502510
if !slices.Contains(runtimes, c.Runtime) {
503511
return errors.Errorf("runtime %s is not supported", c.Runtime)
504512
}
@@ -517,6 +525,11 @@ func validateBackendPlatformCompatibility(c *config.InferenceConfig, targetPlatf
517525
}
518526
}
519527

528+
// ROCm runtime only supports amd64.
529+
if c.Runtime == utils.RuntimeROCm && hasARM64Platform {
530+
return errors.New("rocm runtime is only supported on linux/amd64 platform")
531+
}
532+
520533
// If we have ARM64 platforms, validate backend compatibility
521534
if hasARM64Platform {
522535
for _, backend := range c.Backends {

pkg/build/build_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,43 @@ func Test_validateBackendPlatformCompatibility(t *testing.T) {
241241
},
242242
wantErr: true,
243243
},
244+
{
245+
name: "rocm runtime with amd64 platform - should pass",
246+
config: &config.InferenceConfig{
247+
APIVersion: "v1alpha1",
248+
Runtime: "rocm",
249+
Backends: []string{"llama-cpp"},
250+
},
251+
targetPlatforms: []*specs.Platform{
252+
{Architecture: "amd64", OS: "linux"},
253+
},
254+
wantErr: false,
255+
},
256+
{
257+
name: "rocm runtime with arm64 platform - should fail",
258+
config: &config.InferenceConfig{
259+
APIVersion: "v1alpha1",
260+
Runtime: "rocm",
261+
Backends: []string{"llama-cpp"},
262+
},
263+
targetPlatforms: []*specs.Platform{
264+
{Architecture: "arm64", OS: "linux"},
265+
},
266+
wantErr: true,
267+
},
268+
{
269+
name: "rocm runtime with mixed platforms - should fail",
270+
config: &config.InferenceConfig{
271+
APIVersion: "v1alpha1",
272+
Runtime: "rocm",
273+
Backends: []string{"llama-cpp"},
274+
},
275+
targetPlatforms: []*specs.Platform{
276+
{Architecture: "amd64", OS: "linux"},
277+
{Architecture: "arm64", OS: "linux"},
278+
},
279+
wantErr: true,
280+
},
244281
}
245282
for _, tt := range tests {
246283
t.Run(tt.name, func(t *testing.T) {

0 commit comments

Comments
 (0)