Skip to content

Commit 53a04ad

Browse files
author
Shaw
committed
inference
1 parent 5c5eb75 commit 53a04ad

66 files changed

Lines changed: 2039 additions & 602 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

packages/app-core/scripts/build-llama-cpp-dflash.mjs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,14 +1078,13 @@ function ensureCheckout(cacheDir, ref) {
10781078
//
10791079
// What still doesn't fully ship at v0.4.0-milady (deferred dispatch wiring):
10801080
//
1081-
// * ggml-metal-ops.cpp / ggml-metal-device.m have NO dispatch sites for
1082-
// the milady quant types (TBQ3_0, TBQ4_0, TBQ3_TCQ, QJL1_256, Q4_POLAR).
1083-
// CUDA has them; Metal does not. After this patch the kernel symbols
1084-
// (kernel_turbo3_dot, kernel_attn_score_qjl1_256, kernel_mul_mv_q4_polar_f32,
1085-
// etc.) are present in default.metallib and `nm`/`strings` will see
1086-
// them, but the runtime cannot yet select them via GGML_TYPE_*. That
1087-
// wiring is a separate fork-internals patch and is the next agent's
1088-
// mission.
1081+
// * ggml-metal-ops.cpp / ggml-metal-device.m have a dedicated, smoke-tested
1082+
// dispatch site only for GGML_OP_ATTN_SCORE_QJL -> QJL1_256 attention
1083+
// scoring, now routed through kernel_attn_score_qjl1_256_multi to
1084+
// amortize launch overhead. TBQ3_0, TBQ4_0, TBQ3_TCQ, and Q4_POLAR still
1085+
// ship only as symbols in default.metallib. CUDA has those runtime routes;
1086+
// Metal does not yet. That wiring is a separate fork-internals patch and
1087+
// remains publish-blocking.
10891088
//
10901089
// * The EMBED_LIBRARY=ON branch (used by iOS targets) is also patched:
10911090
// it compiles ggml-metal.metal + the milady standalones as separate

packages/app-core/scripts/kernel-patches/metal-kernels.mjs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,12 @@ function patchMetalQjlAttnDeviceCpp(cacheDir, { dryRun }) {
389389
const cppPath = path.join(cacheDir, "ggml", "src", "ggml-metal", "ggml-metal-device.cpp");
390390
const original = fs.readFileSync(cppPath, "utf8");
391391
if (original.includes(SENTINEL_QJL_ATTN)) {
392-
return { changed: false, path: cppPath };
392+
const upgraded = original.replace(
393+
'const char * name = "kernel_attn_score_qjl1_256";',
394+
'const char * name = "kernel_attn_score_qjl1_256_multi";',
395+
);
396+
if (upgraded !== original && !dryRun) fs.writeFileSync(cppPath, upgraded, "utf8");
397+
return { changed: upgraded !== original && !dryRun, path: cppPath };
393398
}
394399
const anchor = `ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {`;
395400
if (!original.includes(anchor)) {
@@ -447,7 +452,37 @@ function patchMetalQjlAttnOpsCpp(cacheDir, { dryRun }) {
447452
const opsPath = path.join(cacheDir, "ggml", "src", "ggml-metal", "ggml-metal-ops.cpp");
448453
const original = fs.readFileSync(opsPath, "utf8");
449454
if (original.includes(SENTINEL_QJL_ATTN)) {
450-
return { changed: false, path: opsPath };
455+
let upgraded = original.replace(
456+
`struct milady_qjl_score_args {
457+
uint32_t n_heads;
458+
uint32_t n_kv_heads;
459+
uint32_t n_tokens;
460+
uint32_t proj_dim;
461+
};`,
462+
`struct milady_qjl_score_args {
463+
uint32_t n_heads;
464+
uint32_t n_kv_heads;
465+
uint32_t n_tokens;
466+
uint32_t proj_dim;
467+
uint32_t tokens_per_threadgroup;
468+
};`,
469+
);
470+
upgraded = upgraded.replace(
471+
` /* n_tokens = */ n_tokens,
472+
/* proj_dim = */ 256u,
473+
};`,
474+
` /* n_tokens = */ n_tokens,
475+
/* proj_dim = */ 256u,
476+
/* tokens_per_threadgroup = */ 32u,
477+
};`,
478+
);
479+
upgraded = upgraded.replace(
480+
` ggml_metal_encoder_dispatch_threadgroups(enc, (int) n_heads, (int) n_tokens, 1, 32, 1, 1);`,
481+
` const int token_groups = (int) ((n_tokens + args.tokens_per_threadgroup - 1u) / args.tokens_per_threadgroup);
482+
ggml_metal_encoder_dispatch_threadgroups(enc, (int) n_heads, token_groups, 1, 32, 1, 1);`,
483+
);
484+
if (upgraded !== original && !dryRun) fs.writeFileSync(opsPath, upgraded, "utf8");
485+
return { changed: upgraded !== original && !dryRun, path: opsPath };
451486
}
452487

453488
const funcAnchor = `static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {`;

packages/app-core/src/runtime/ensure-local-inference-handler.ts

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ import {
2626
type IAgentRuntime,
2727
logger,
2828
ModelType,
29+
type TextToSpeechParams,
2930
type TextEmbeddingParams,
31+
type TranscriptionParams,
3032
} from "@elizaos/core";
3133
import {
3234
type LocalInferenceLoader,
@@ -47,6 +49,10 @@ import { handlerRegistry } from "../services/local-inference/handler-registry";
4749
import { listInstalledModels } from "../services/local-inference/registry";
4850
import { installRouterHandler } from "../services/local-inference/router-handler";
4951
import type { AgentModelSlot } from "../services/local-inference/types";
52+
import {
53+
decodeMonoPcm16Wav,
54+
type TranscriptionAudio,
55+
} from "../services/local-inference/voice";
5056
import { getRuntimeMode } from "./mode/runtime-mode";
5157

5258
type GenerateTextHandler = (
@@ -64,13 +70,36 @@ type EmbeddingHandler = (
6470
params: TextEmbeddingParams | string | null,
6571
) => Promise<number[]>;
6672

73+
type TextToSpeechHandler = (
74+
runtime: IAgentRuntime,
75+
params: TextToSpeechParams | string,
76+
) => Promise<Uint8Array>;
77+
78+
type TranscriptionHandler = (
79+
runtime: IAgentRuntime,
80+
params: TranscriptionParams | Buffer | string | LocalTranscriptionParams,
81+
) => Promise<string>;
82+
83+
interface LocalTranscriptionParams {
84+
pcm?: Float32Array;
85+
audio?: Uint8Array | ArrayBuffer | Buffer;
86+
sampleRateHz?: number;
87+
sampleRate?: number;
88+
}
89+
90+
type LocalModelHandler =
91+
| GenerateTextHandler
92+
| EmbeddingHandler
93+
| TextToSpeechHandler
94+
| TranscriptionHandler;
95+
6796
type RuntimeWithModelRegistration = AgentRuntime & {
6897
getModel: (
6998
modelType: string | number,
70-
) => GenerateTextHandler | EmbeddingHandler | undefined;
99+
) => LocalModelHandler | undefined;
71100
registerModel: (
72101
modelType: string | number,
73-
handler: GenerateTextHandler | EmbeddingHandler,
102+
handler: LocalModelHandler,
74103
provider: string,
75104
priority?: number,
76105
) => void;
@@ -295,6 +324,85 @@ function makeEmbeddingHandler(): EmbeddingHandler {
295324
};
296325
}
297326

327+
function extractSpeechText(params: TextToSpeechParams | string): string {
328+
if (typeof params === "string") return params;
329+
if (params && typeof params.text === "string") return params.text;
330+
throw new Error(
331+
"[local-inference] TEXT_TO_SPEECH requires a string or { text } input",
332+
);
333+
}
334+
335+
function makeTextToSpeechHandler(): TextToSpeechHandler {
336+
return async (_runtime, params) => {
337+
const text = extractSpeechText(params);
338+
if (text.length === 0) {
339+
throw new Error("[local-inference] TEXT_TO_SPEECH text must be non-empty");
340+
}
341+
// Do not filter singing, emotion tags, or lyrical phrasing here. The
342+
// local voice bundle advertises its expressive capability in the
343+
// manifest; runtime safety policy lives above this model adapter.
344+
return localInferenceEngine.synthesizeSpeech(text);
345+
};
346+
}
347+
348+
function toUint8Array(value: Uint8Array | ArrayBuffer | Buffer): Uint8Array {
349+
if (value instanceof Uint8Array) {
350+
return new Uint8Array(value.buffer, value.byteOffset, value.byteLength);
351+
}
352+
return new Uint8Array(value);
353+
}
354+
355+
function extractTranscriptionAudio(
356+
params: TranscriptionParams | Buffer | string | LocalTranscriptionParams,
357+
): TranscriptionAudio {
358+
if (typeof params === "string") {
359+
throw new Error(
360+
"[local-inference] TRANSCRIPTION via the local voice runtime requires PCM/WAV bytes; URL/path strings are not fetched by this provider",
361+
);
362+
}
363+
if (params instanceof Uint8Array || params instanceof ArrayBuffer) {
364+
return decodeMonoPcm16Wav(toUint8Array(params));
365+
}
366+
if (!params || typeof params !== "object") {
367+
throw new Error(
368+
"[local-inference] TRANSCRIPTION requires PCM/WAV bytes or { pcm, sampleRateHz }",
369+
);
370+
}
371+
if ("audioUrl" in params && typeof params.audioUrl === "string") {
372+
throw new Error(
373+
"[local-inference] TRANSCRIPTION audioUrl is not fetched by the local voice runtime; pass mono PCM16 WAV bytes or { pcm, sampleRateHz }",
374+
);
375+
}
376+
if ("pcm" in params && params.pcm instanceof Float32Array) {
377+
const sampleRate =
378+
("sampleRateHz" in params ? params.sampleRateHz : undefined) ??
379+
("sampleRate" in params ? params.sampleRate : undefined);
380+
if (typeof sampleRate !== "number" || sampleRate <= 0) {
381+
throw new Error(
382+
"[local-inference] TRANSCRIPTION { pcm } requires a positive sampleRateHz",
383+
);
384+
}
385+
return { pcm: params.pcm, sampleRate };
386+
}
387+
if (
388+
"audio" in params &&
389+
(params.audio instanceof Uint8Array ||
390+
params.audio instanceof ArrayBuffer)
391+
) {
392+
return decodeMonoPcm16Wav(toUint8Array(params.audio));
393+
}
394+
throw new Error(
395+
"[local-inference] TRANSCRIPTION requires mono PCM16 WAV bytes or { pcm, sampleRateHz } for the local voice runtime",
396+
);
397+
}
398+
399+
function makeTranscriptionHandler(): TranscriptionHandler {
400+
return async (_runtime, params) => {
401+
const audio = extractTranscriptionAudio(params);
402+
return localInferenceEngine.transcribePcm(audio);
403+
};
404+
}
405+
298406
/**
299407
* Register the device-bridge loader on the runtime. Accepts load/generate
300408
* calls whether or not a mobile device is currently connected — parked
@@ -535,6 +643,29 @@ export async function ensureLocalInferenceHandler(
535643
}
536644
}
537645

646+
try {
647+
runtimeWithRegistration.registerModel(
648+
ModelType.TEXT_TO_SPEECH,
649+
makeTextToSpeechHandler(),
650+
provider,
651+
LOCAL_INFERENCE_PRIORITY,
652+
);
653+
runtimeWithRegistration.registerModel(
654+
ModelType.TRANSCRIPTION,
655+
makeTranscriptionHandler(),
656+
provider,
657+
LOCAL_INFERENCE_PRIORITY,
658+
);
659+
logger.info(
660+
`[local-inference] Registered ${provider} voice handlers for TEXT_TO_SPEECH / TRANSCRIPTION at priority ${LOCAL_INFERENCE_PRIORITY}`,
661+
);
662+
} catch (err) {
663+
logger.warn(
664+
"[local-inference] Could not register local voice handlers",
665+
err instanceof Error ? err.message : String(err),
666+
);
667+
}
668+
538669
logger.info(
539670
`[local-inference] Registered ${provider} llama.cpp handler for TEXT_SMALL / TEXT_LARGE at priority ${LOCAL_INFERENCE_PRIORITY}`,
540671
);

packages/app-core/src/services/local-inference/catalog.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ describe("local inference catalog", () => {
8686
});
8787

8888
it("sets contextLength on every Eliza-1 tier per the tier matrix", () => {
89-
// Per packages/inference/AGENTS.md §2: lite/mobile = 32k, desktop =
90-
// 64k, pro = 128k, server = 256k. The catalog records the largest
89+
// Size tiers: 0.6B/1.7B = 32k, 9B = 64k, 27B = 128k,
90+
// 27B-256k = 256k. The catalog records the largest
9191
// ctx the bundle's manifest will advertise for each tier.
9292
const expected: Record<string, number> = {
9393
"eliza-1-0_6b": 32768,

0 commit comments

Comments
 (0)