Commit 38d630c

Allow local AI model overrides
1 parent bffc01f commit 38d630c

4 files changed

Lines changed: 124 additions & 37 deletions

skills/local-ai-use/SKILL.md

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,17 @@ lemonade pull kokoro-v1
118118
lemonade pull Whisper-Tiny
119119
```
120120

121+
To choose a different model while installing the rule, pass it to the setup
122+
script. For example, to make future image requests use SDXL-Turbo:
123+
124+
```bash
125+
python scripts/setup_local_ai.py --image-model SDXL-Turbo
126+
```
127+
128+
The script will pull the selected model and write that model ID into the
129+
installed `AGENTS.md` rule. The same pattern works for `--tts-model` and
130+
`--stt-model`.
131+
121132
Each `pull` is idempotent. To verify what is already downloaded:
122133

123134
```bash
@@ -135,9 +146,10 @@ Append it to the workspace's `AGENTS.md` (create the file if missing). Both
135146
Cursor and Claude Code load `AGENTS.md` automatically on every turn, so the
136147
agent will see the rule on its next message without any further setup.
137148

138-
`scripts/setup_local_ai.py` does this for you, surrounded by stable markers
139-
so re-running the script replaces the block in place rather than appending
140-
a second copy. The markers look like:
149+
`scripts/setup_local_ai.py` does this for you. It bakes the selected endpoint
150+
and model IDs into the rule, surrounded by stable markers so re-running the
151+
script replaces the block in place rather than appending a second copy. The
152+
markers look like:
141153

142154
```
143155
<!-- BEGIN amd-skills:local-ai-use -->
@@ -161,7 +173,9 @@ The rule's content is identical; only the file location changes.
161173

162174
Verify each modality against the live server before declaring success. These
163175
mirror the inline patterns in the installed rule, so a green pass here means
164-
the rule will work.
176+
the rule will work. If you installed with a model override such as
177+
`--image-model SDXL-Turbo`, use that model ID in the smoke test and confirm
178+
the installed `AGENTS.md` rule contains it.
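
One way to confirm that the installed rule contains the override is to search for the model ID inside the managed block only, so a stray mention elsewhere in `AGENTS.md` does not produce a false positive. The sketch below is illustrative and not part of the skill: the helper name `rule_mentions_model` is hypothetical, and the END marker string is an assumption modeled on the BEGIN marker shown above.

```python
BEGIN = "<!-- BEGIN amd-skills:local-ai-use -->"
# Assumption: the END marker mirrors the BEGIN marker. Check the installed
# AGENTS.md for the exact string before relying on this.
END = "<!-- END amd-skills:local-ai-use -->"


def rule_mentions_model(agents_md_text: str, model_id: str) -> bool:
    """Return True if model_id appears inside the managed rule block."""
    start = agents_md_text.find(BEGIN)
    if start == -1:
        # No managed block at all, so the rule is not installed.
        return False
    end = agents_md_text.find(END, start)
    # If the END marker is absent, search from BEGIN to the end of the file.
    block = agents_md_text[start:] if end == -1 else agents_md_text[start:end]
    return model_id in block
```

To use it, read the workspace's `AGENTS.md` as text and call, for example, `rule_mentions_model(text, "SDXL-Turbo")`.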
165179

166180
**Image generation** (writes `out.png`):
167181

skills/local-ai-use/reference.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ asks for higher quality or has explicit hardware to spare.
3030
| `SD-1.5` | ~4 GB | When the user asks for "Stable Diffusion 1.5" by name. | Needs more steps (~20). |
3131
| `Flux-2-Klein-4B` | ~4 GB | Image **editing** (`/v1/images/edits`). | Editing-capable, slower than SD-Turbo for plain generation. |
3232

33-
To upgrade: `lemonade pull <model>`, then change `"model"` in the rule
34-
block in `AGENTS.md` to the new model id.
33+
To upgrade: re-run setup with the target model, for example:
34+
35+
```bash
36+
python scripts/setup_local_ai.py --image-model SDXL-Turbo
37+
```
38+
39+
The script pulls the model and rewrites the `AGENTS.md` rule in place.
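
The "rewrites in place" behavior can be pictured as a marker-bounded replace. What follows is a minimal sketch under stated assumptions, not the script's actual implementation: `upsert_block` is a hypothetical name, and the END marker string is assumed to mirror the BEGIN marker documented in SKILL.md.

```python
BEGIN = "<!-- BEGIN amd-skills:local-ai-use -->"
END = "<!-- END amd-skills:local-ai-use -->"  # assumption: mirrors BEGIN


def upsert_block(existing: str, block: str) -> str:
    """Replace the marked region in `existing` with `block`, or append it."""
    start = existing.find(BEGIN)
    end = existing.find(END, start) if start != -1 else -1
    if start == -1 or end == -1:
        # No complete managed block yet: append one after a blank line.
        prefix = existing.rstrip("\n")
        return (prefix + "\n\n" if prefix else "") + block.rstrip("\n") + "\n"
    end += len(END)
    # Keep everything before and after the markers untouched.
    return existing[:start] + block.rstrip("\n") + existing[end:]
```

Because the function replaces everything between the markers, running it twice with the same block leaves the file unchanged, which is the property that lets setup be re-run safely with a different model.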
3540

3641
### Text-to-speech (`recipe: kokoro`)
3742

skills/local-ai-use/scripts/setup_local_ai.py

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import argparse
2828
import json
2929
import os
30+
import re
3031
import shutil
3132
import subprocess
32-
import sys
3333
import urllib.error
3434
import urllib.request
3535
from pathlib import Path
@@ -40,11 +40,13 @@
4040
DEFAULT_HOST = "127.0.0.1"
4141
DEFAULT_PORT = 13305
4242

43-
# The Lite Collection from Lemonade OmniRouter. Picked because each fits in
44-
# under ~5 GB and runs on commodity CPU hardware, so the savings vs. cloud
45-
# calls are real on a typical developer laptop. See SKILL.md for upgrade
43+
# The Lite Collection from Lemonade OmniRouter. Picked because each default
44+
# fits in under ~5 GB and runs on commodity CPU hardware, so the savings vs.
45+
# cloud calls are real on a typical developer laptop. See SKILL.md for upgrade
4646
# paths.
47-
DEFAULT_MODELS = ("SD-Turbo", "kokoro-v1", "Whisper-Tiny")
47+
DEFAULT_IMAGE_MODEL = "SD-Turbo"
48+
DEFAULT_TTS_MODEL = "kokoro-v1"
49+
DEFAULT_STT_MODEL = "Whisper-Tiny"
4850

4951
# Stable markers around the rule block in AGENTS.md. The script rewrites the
5052
# region between these markers in place; do not change the marker strings or
@@ -84,7 +86,7 @@ def check_server_reachable(host: str, port: int) -> bool:
8486
return False
8587

8688

87-
def list_downloaded_models() -> set[str]:
89+
def list_downloaded_models(host: str, port: int) -> set[str]:
8890
"""Return the set of locally downloaded model IDs.
8991
9092
Uses `lemonade list --downloaded` (CLI) and falls back to
@@ -103,7 +105,10 @@ def list_downloaded_models() -> set[str]:
103105
pass
104106

105107
try:
106-
status, body = _http_get("http://127.0.0.1:13305/api/v1/models", timeout_s=5)
108+
status, body = _http_get(
109+
f"http://{host}:{port}/api/v1/models",
110+
timeout_s=5,
111+
)
107112
if status == 200:
108113
data = json.loads(body)
109114
return {
@@ -140,8 +145,15 @@ def pull_model(model: str) -> bool:
140145
return False
141146

142147

143-
def render_rule_block() -> str:
144-
"""Read the rule template; pass through unchanged.
148+
def render_rule_block(
149+
*,
150+
host: str,
151+
port: int,
152+
image_model: str,
153+
tts_model: str,
154+
stt_model: str,
155+
) -> str:
156+
"""Read the rule template and fill in endpoint/model choices.
145157
146158
The template already includes BEGIN/END markers and matches the constants
147159
at the top of this file. We re-validate that here so a future template
@@ -158,13 +170,44 @@ def render_rule_block() -> str:
158170
"Rule template is missing the BEGIN/END markers; refuse to write "
159171
"AGENTS.md because re-runs would append duplicate blocks."
160172
)
173+
endpoint_host = "localhost" if host in {"127.0.0.1", "::1"} else host
174+
base_root = f"http://{endpoint_host}:{port}"
175+
replacements = {
176+
"{{LEMONADE_BASE_ROOT}}": base_root,
177+
"{{LEMONADE_BASE_URL}}": f"{base_root}/api/v1",
178+
"{{IMAGE_MODEL}}": image_model,
179+
"{{TTS_MODEL}}": tts_model,
180+
"{{STT_MODEL}}": stt_model,
181+
}
182+
for placeholder, value in replacements.items():
183+
text = text.replace(placeholder, value)
184+
unresolved = sorted(set(re.findall(r"\{\{[A-Z_]+\}\}", text)))
185+
if unresolved:
186+
raise ValueError(
187+
"Rule template still has unresolved placeholders: "
188+
+ ", ".join(unresolved)
189+
)
161190
return text.strip() + "\n"
162191

163192

164-
def upsert_agents_md(workspace: Path) -> Path:
193+
def upsert_agents_md(
194+
workspace: Path,
195+
*,
196+
host: str,
197+
port: int,
198+
image_model: str,
199+
tts_model: str,
200+
stt_model: str,
201+
) -> Path:
165202
"""Write or replace the rule block inside <workspace>/AGENTS.md."""
166203
target = workspace / "AGENTS.md"
167-
block = render_rule_block()
204+
block = render_rule_block(
205+
host=host,
206+
port=port,
207+
image_model=image_model,
208+
tts_model=tts_model,
209+
stt_model=stt_model,
210+
)
168211

169212
if not target.exists():
170213
target.write_text(
@@ -223,6 +266,21 @@ def main(argv: list[str] | None = None) -> int:
223266
action="store_true",
224267
help="Do not pull missing models; just verify and write AGENTS.md.",
225268
)
269+
parser.add_argument(
270+
"--image-model",
271+
default=DEFAULT_IMAGE_MODEL,
272+
help=f"Image generation model to pull and write into AGENTS.md (default: {DEFAULT_IMAGE_MODEL}).",
273+
)
274+
parser.add_argument(
275+
"--tts-model",
276+
default=DEFAULT_TTS_MODEL,
277+
help=f"Text-to-speech model to pull and write into AGENTS.md (default: {DEFAULT_TTS_MODEL}).",
278+
)
279+
parser.add_argument(
280+
"--stt-model",
281+
default=DEFAULT_STT_MODEL,
282+
help=f"Speech-to-text model to pull and write into AGENTS.md (default: {DEFAULT_STT_MODEL}).",
283+
)
226284
args = parser.parse_args(argv)
227285

228286
if not check_cli_installed():
@@ -244,8 +302,11 @@ def main(argv: list[str] | None = None) -> int:
244302
_print(f"server reachable at http://{args.host}:{args.port}")
245303

246304
if not args.skip_pull:
247-
downloaded = list_downloaded_models()
248-
for model in DEFAULT_MODELS:
305+
downloaded = list_downloaded_models(args.host, args.port)
306+
selected_models = dict.fromkeys(
307+
(args.image_model, args.tts_model, args.stt_model)
308+
)
309+
for model in selected_models:
249310
if model in downloaded:
250311
_print(f"already downloaded: {model}")
251312
continue
@@ -257,7 +318,14 @@ def main(argv: list[str] | None = None) -> int:
257318
"but calls will 404 until you pull it."
258319
)
259320

260-
upsert_agents_md(args.workspace.resolve())
321+
upsert_agents_md(
322+
args.workspace.resolve(),
323+
host=args.host,
324+
port=args.port,
325+
image_model=args.image_model,
326+
tts_model=args.tts_model,
327+
stt_model=args.stt_model,
328+
)
261329
_print("done. Future image, TTS, and STT requests now route to local Lemonade.")
262330
return 0
263331

skills/local-ai-use/templates/local-ai-rule.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
## Local AI Use (managed by amd-skills/local-ai-use)
33

44
This workspace runs a local [Lemonade Server](https://lemonade-server.ai) at
5-
`http://localhost:13305/api/v1` to handle image generation, text-to-speech,
5+
`{{LEMONADE_BASE_URL}}` to handle image generation, text-to-speech,
66
and speech-to-text. Cloud APIs cost real money and tokens for these
77
modalities, so use the local endpoints below for **every** such request
88
unless the user explicitly asks for a cloud provider.
@@ -11,10 +11,10 @@ unless the user explicitly asks for a cloud provider.
1111

1212
| If the user asks you to... | Do this | Do not do this |
1313
|---|---|---|
14-
| Generate, draw, or create an image | `POST http://localhost:13305/api/v1/images/generations` | Use the built-in `GenerateImage` tool. Call DALL-E, Midjourney, Stable Diffusion via cloud, or any other paid image API. |
15-
| Edit or modify an existing image | `POST http://localhost:13305/api/v1/images/edits` (multipart) | Cloud image-edit APIs. |
16-
| Speak text aloud or generate audio narration | `POST http://localhost:13305/api/v1/audio/speech` | Call OpenAI TTS, ElevenLabs, PlayHT, or any other cloud TTS. |
17-
| Transcribe an audio file | `POST http://localhost:13305/api/v1/audio/transcriptions` | Call OpenAI Whisper-as-a-service, AssemblyAI, Deepgram, or any other cloud STT. |
14+
| Generate, draw, or create an image | `POST {{LEMONADE_BASE_URL}}/images/generations` | Use the built-in `GenerateImage` tool. Call DALL-E, Midjourney, Stable Diffusion via cloud, or any other paid image API. |
15+
| Edit or modify an existing image | `POST {{LEMONADE_BASE_URL}}/images/edits` (multipart) | Cloud image-edit APIs. |
16+
| Speak text aloud or generate audio narration | `POST {{LEMONADE_BASE_URL}}/audio/speech` | Call OpenAI TTS, ElevenLabs, PlayHT, or any other cloud TTS. |
17+
| Transcribe an audio file | `POST {{LEMONADE_BASE_URL}}/audio/transcriptions` | Call OpenAI Whisper-as-a-service, AssemblyAI, Deepgram, or any other cloud STT. |
1818

1919
Plain text chat, code generation, and reasoning continue to use the agent's
2020
configured LLM. This rule only redirects the multimodal calls.
@@ -23,9 +23,9 @@ configured LLM. This rule only redirects the multimodal calls.
2323

2424
| Endpoint | Model | Notes |
2525
|---|---|---|
26-
| `/v1/images/generations` | `SD-Turbo` | 4 steps, `cfg_scale: 1.0`, `512x512`, `response_format: "b64_json"`. |
27-
| `/v1/audio/speech` | `kokoro-v1` | Default voice `shimmer`; `response_format: "mp3"`. |
28-
| `/v1/audio/transcriptions` | `Whisper-Tiny` | Input must be 16 kHz mono WAV. Re-encode with `ffmpeg -i in.* -ar 16000 -ac 1 out.wav`. |
26+
| `/v1/images/generations` | `{{IMAGE_MODEL}}` | 4 steps, `cfg_scale: 1.0`, `512x512`, `response_format: "b64_json"`. |
27+
| `/v1/audio/speech` | `{{TTS_MODEL}}` | Default voice `shimmer`; `response_format: "mp3"`. |
28+
| `/v1/audio/transcriptions` | `{{STT_MODEL}}` | Input must be 16 kHz mono WAV. Re-encode with `ffmpeg -i in.* -ar 16000 -ac 1 out.wav`. |
2929

3030
If `LEMONADE_API_KEY` is set in the environment, send
3131
`Authorization: Bearer $LEMONADE_API_KEY` on every request. Otherwise the
@@ -36,9 +36,9 @@ loopback server accepts unauthenticated calls.
3636
**Image generation** (saves to `out.png`):
3737

3838
```bash
39-
curl -sX POST http://localhost:13305/api/v1/images/generations \
39+
curl -sX POST {{LEMONADE_BASE_URL}}/images/generations \
4040
-H "Content-Type: application/json" \
41-
-d '{"model":"SD-Turbo","prompt":"PROMPT_HERE","size":"512x512","steps":4,"response_format":"b64_json"}' \
41+
-d '{"model":"{{IMAGE_MODEL}}","prompt":"PROMPT_HERE","size":"512x512","steps":4,"response_format":"b64_json"}' \
4242
| python -c "import sys,json,base64; open('out.png','wb').write(base64.b64decode(json.load(sys.stdin)['data'][0]['b64_json']))"
4343
```
4444

@@ -47,26 +47,26 @@ Equivalent Python via the OpenAI SDK:
4747
```python
4848
from openai import OpenAI
4949
import base64
50-
client = OpenAI(base_url="http://localhost:13305/api/v1", api_key="lemonade")
51-
r = client.images.generate(model="SD-Turbo", prompt="PROMPT_HERE", size="512x512")
50+
client = OpenAI(base_url="{{LEMONADE_BASE_URL}}", api_key="lemonade")
51+
r = client.images.generate(model="{{IMAGE_MODEL}}", prompt="PROMPT_HERE", size="512x512")
5252
open("out.png", "wb").write(base64.b64decode(r.data[0].b64_json))
5353
```
5454

5555
**Text-to-speech** (saves to `out.mp3`):
5656

5757
```bash
58-
curl -sX POST http://localhost:13305/api/v1/audio/speech \
58+
curl -sX POST {{LEMONADE_BASE_URL}}/audio/speech \
5959
-H "Content-Type: application/json" \
60-
-d '{"model":"kokoro-v1","input":"TEXT_HERE","voice":"shimmer","response_format":"mp3"}' \
60+
-d '{"model":"{{TTS_MODEL}}","input":"TEXT_HERE","voice":"shimmer","response_format":"mp3"}' \
6161
-o out.mp3
6262
```
6363

6464
**Speech-to-text** (returns JSON `{"text": "..."}`):
6565

6666
```bash
6767
ffmpeg -y -i INPUT_AUDIO -ar 16000 -ac 1 _stt.wav
68-
curl -sX POST http://localhost:13305/api/v1/audio/transcriptions \
69-
-F "file=@_stt.wav" -F "model=Whisper-Tiny"
68+
curl -sX POST {{LEMONADE_BASE_URL}}/audio/transcriptions \
69+
-F "file=@_stt.wav" -F "model={{STT_MODEL}}"
7070
```
7171

7272
### Failure handling
@@ -82,7 +82,7 @@ curl -sX POST http://localhost:13305/api/v1/audio/transcriptions \
8282
### Re-pointing to a different host
8383

8484
If the user runs Lemonade on a different host or port, replace the
85-
`http://localhost:13305` prefix everywhere above with their endpoint, and
85+
`{{LEMONADE_BASE_ROOT}}` prefix everywhere above with their endpoint, and
8686
update `LEMONADE_HOST` / `LEMONADE_PORT` in the shell environment so the
8787
`lemonade` CLI matches.
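
The `{{...}}` placeholders in this template are filled in by the setup script before the rule is written. The following is a simplified, self-contained sketch of that substitution, adapted from the `render_rule_block` change in this commit. The `render` function and its `models` dict parameter are illustrative simplifications, not the script's actual signature.

```python
import re


def render(template: str, *, host: str, port: int, models: dict[str, str]) -> str:
    """Fill {{...}} placeholders and fail loudly if any remain."""
    # Loopback addresses are written as "localhost" in the rendered rule.
    endpoint_host = "localhost" if host in {"127.0.0.1", "::1"} else host
    base_root = f"http://{endpoint_host}:{port}"
    replacements = {
        "{{LEMONADE_BASE_ROOT}}": base_root,
        "{{LEMONADE_BASE_URL}}": f"{base_root}/api/v1",
        # Model placeholders, e.g. {"STT_MODEL": "Whisper-Tiny"}.
        **{"{{" + key + "}}": value for key, value in models.items()},
    }
    for placeholder, value in replacements.items():
        template = template.replace(placeholder, value)
    # Any placeholder still present means the template and caller disagree.
    unresolved = sorted(set(re.findall(r"\{\{[A-Z_]+\}\}", template)))
    if unresolved:
        raise ValueError("unresolved placeholders: " + ", ".join(unresolved))
    return template


line = 'curl -sX POST {{LEMONADE_BASE_URL}}/audio/transcriptions -F "model={{STT_MODEL}}"'
rendered = render(line, host="127.0.0.1", port=13305, models={"STT_MODEL": "Whisper-Tiny"})
print(rendered)
```

Raising on leftover placeholders, rather than writing them through, means a template that gains a new placeholder cannot silently produce a rule containing a literal `{{NAME}}`.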
8888
