Skip to content

Commit 3c6f7a5

Browse files
committed
chore: adopt canonical model type names in SDK docs, examples, and tests
Signed-off-by: Arun Mani J <j.arunmani@proton.me>
1 parent 9a46100 commit 3c6f7a5

12 files changed

Lines changed: 32 additions & 31 deletions

packages/sdk/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ try {
4646
// Load a model into memory
4747
const modelId = await loadModel({
4848
modelSrc: LLAMA_3_2_1B_INST_Q4_0,
49-
modelType: "llm",
5049
onProgress: (progress) => {
5150
console.log(progress);
5251
},

packages/sdk/client/api/load-model.ts

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ export function loadModel<S extends ModelDescriptor>(
7575
* @overloadLabel "Load new model"
7676
* @param options - An object that defines all configuration parameters required for loading the model, including:
7777
* - modelSrc: The location from which the model weights are fetched (local path, remote URL, or Hyperdrive URL)
78-
* - modelType: The type of model ("llm", "whisper", "embeddings", "nmt", or "tts")
78+
* - modelType: The canonical type of model ("llamacpp-completion",
79+
* "whispercpp-transcription", "llamacpp-embedding", "nmtcpp-translation",
80+
* "tts-ggml", ...). May be omitted when `modelSrc` is a registry descriptor
81+
* that already carries the engine.
7982
* - modelConfig: Model-specific configuration options (companion sources, model parameters, etc.)
8083
* - onProgress: Callback for download progress updates
8184
* - logger: Logger instance for model operation logs
@@ -92,27 +95,27 @@ export function loadModel<S extends ModelDescriptor>(
9295
* // Local file path - absolute path
9396
* const localModelId = await loadModel({
9497
* modelSrc: "/home/user/models/llama-7b.gguf",
95-
* modelType: "llm",
98+
* modelType: "llamacpp-completion",
9699
* modelConfig: { ctx_size: 2048 }
97100
* });
98101
*
99102
* // Local file path - relative path
100103
* const relativeModelId = await loadModel({
101104
* modelSrc: "./models/whisper-base.gguf",
102-
* modelType: "whisper"
105+
* modelType: "whispercpp-transcription"
103106
* });
104107
*
105108
* // Hyperdrive URL with key and path
106109
* const hyperdriveId = await loadModel({
107110
* modelSrc: "pear://<hyperdrive-key>/llama-7b.gguf",
108-
* modelType: "llm",
111+
* modelType: "llamacpp-completion",
109112
* modelConfig: { ctx_size: 2048 }
110113
* });
111114
*
112115
* // Remote HTTP/HTTPS URL with progress tracking
113116
* const remoteId = await loadModel({
114117
* modelSrc: "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf",
115-
* modelType: "llm",
118+
* modelType: "llamacpp-completion",
116119
* onProgress: (progress) => {
117120
* console.log(`Downloaded: ${progress.percentage}%`);
118121
* }
@@ -121,7 +124,7 @@ export function loadModel<S extends ModelDescriptor>(
121124
* // Multimodal model with projection
122125
* const multimodalId = await loadModel({
123126
* modelSrc: "https://huggingface.co/.../main-model.gguf",
124-
* modelType: "llm",
127+
* modelType: "llamacpp-completion",
125128
* modelConfig: {
126129
* ctx_size: 512,
127130
* projectionModelSrc: "https://huggingface.co/.../projection-model.gguf"
@@ -134,7 +137,7 @@ export function loadModel<S extends ModelDescriptor>(
134137
* // Whisper with VAD model
135138
* const whisperId = await loadModel({
136139
* modelSrc: "https://huggingface.co/.../whisper-model.gguf",
137-
* modelType: "whisper",
140+
* modelType: "whispercpp-transcription",
138141
* modelConfig: {
139142
* mode: "caption",
140143
* output_format: "plaintext",
@@ -150,7 +153,7 @@ export function loadModel<S extends ModelDescriptor>(
150153
*
151154
* const modelId = await loadModel({
152155
* modelSrc: "/path/to/model.gguf",
153-
* modelType: "llm",
156+
* modelType: "llamacpp-completion",
154157
* logger // Pass logger in options
155158
* });
156159
* ```

packages/sdk/examples/diffusion-img2vid.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ try {
3838
console.log("Loading Wan 2.1 I2V model (diffusion + UMT5-XXL + VAE + CLIP vision)...");
3939
const modelId = await loadModel({
4040
modelSrc: diffusionModelSrc,
41-
modelType: "diffusion",
41+
modelType: "sdcpp-generation",
4242
modelConfig: {
4343
mode: "video",
4444
device: "gpu",

packages/sdk/schemas/model-types.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,11 @@ const canonicalValuesSet = new Set<string>(Object.values(ModelType));
7272
*
7373
* @example
7474
* ```typescript
75-
* // Using alias (backward compatible)
76-
* loadModel({ modelSrc: "...", modelType: MODEL_TYPES.nmt });
77-
* // MODEL_TYPES.nmt resolves to "nmtcpp-translation"
78-
*
79-
* // Using canonical name directly
75+
* // Preferred: canonical name
8076
* loadModel({ modelSrc: "...", modelType: MODEL_TYPES.nmtcppTranslation });
77+
*
78+
* // Deprecated: alias (still resolves to "nmtcpp-translation")
79+
* loadModel({ modelSrc: "...", modelType: MODEL_TYPES.nmt });
8180
* ```
8281
*/
8382
export const PUBLIC_MODEL_TYPES = {

packages/sdk/test/bare/tts-resolve-config.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ test(
2222
await ttsPlugin.resolveConfig!(legacyConfig, {
2323
resolveModelPath: async () => "",
2424
modelSrc: "s3:///legacy/model",
25-
modelType: "tts",
25+
modelType: "tts-ggml",
2626
});
2727
t.ok(false, "expected LegacyTtsModelDeprecatedError");
2828
} catch (err) {

packages/sdk/test/mocks/pr-body-bc-valid.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ const model = await loadModel("model-path");
1717
**AFTER:**
1818

1919
```typescript
20-
const modelId = await loadModel("model-path", { modelType: "llm" });
20+
const modelId = await loadModel("model-path", { modelType: "llamacpp-completion" });
2121
```
2222

packages/sdk/test/unit/bci-schemas.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ test("loadModelOptionsToRequestSchema: resolves the 'bci' alias to the canonical
311311

312312
test("loadModelOptionsToRequestSchema: rejects unknown modelConfig keys for BCI (strict)", (t) => {
313313
const result = loadModelOptionsToRequestSchema.safeParse({
314-
modelType: "bci",
314+
modelType: "bci-whispercpp-transcription",
315315
modelSrc: "ggml-bci-windowed.bin",
316316
modelConfig: { notABciField: true },
317317
});

packages/sdk/test/unit/classification-schemas.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,15 @@ test("loadModelOptionsBaseSchema: accepts classification alias", (t) => {
202202
test("loadModelOptionsBaseSchema: accepts classification with custom modelSrc", (t) => {
203203
const result = loadModelOptionsBaseSchema.safeParse({
204204
modelSrc: "/abs/path/to/my-classifier.gguf",
205-
modelType: "classification",
205+
modelType: "ggml-classification",
206206
modelConfig: { topK: 3 },
207207
});
208208
t.is(result.success, true);
209209
});
210210

211211
test("loadModelOptionsBaseSchema: rejects classification config with unknown key (strict)", (t) => {
212212
const result = loadModelOptionsBaseSchema.safeParse({
213-
modelType: "classification",
213+
modelType: "ggml-classification",
214214
modelConfig: { topK: 3, unknownKey: true },
215215
});
216216
t.is(result.success, false);

packages/sdk/test/unit/inference-handler-migrations.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ test("translateRequestSchema (NMT): accepts an optional requestId", (t) => {
5454
modelId: "m1",
5555
text: "hello",
5656
stream: true,
57-
modelType: "nmt",
57+
modelType: "nmtcpp-translation",
5858
requestId: "req-nmt",
5959
});
6060
t.is(result.success, true);
@@ -66,7 +66,7 @@ test("translateRequestSchema (LLM): accepts an optional requestId", (t) => {
6666
modelId: "m1",
6767
text: "hello",
6868
stream: true,
69-
modelType: "llm",
69+
modelType: "llamacpp-completion",
7070
from: "en",
7171
to: "fr",
7272
requestId: "req-llm",
@@ -80,7 +80,7 @@ test("translateRequestSchema: rejects empty-string requestId", (t) => {
8080
modelId: "m1",
8181
text: "hello",
8282
stream: true,
83-
modelType: "nmt",
83+
modelType: "nmtcpp-translation",
8484
requestId: "",
8585
});
8686
t.is(result.success, false);

packages/sdk/test/unit/profiler-operation-transport.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ test("operation metrics: loadModel extracts gauges and tags", (t) => {
2121
"profile-1",
2222
100,
2323
500,
24-
{ modelType: "llm" },
24+
{ modelType: "llamacpp-completion" },
2525
{
2626
__profilingMeta: {
2727
sourceType: "registry",
@@ -37,7 +37,7 @@ test("operation metrics: loadModel extracts gauges and tags", (t) => {
3737
);
3838

3939
t.ok(event, "event is built");
40-
t.alike(event!.tags, { modelType: "llm", sourceType: "registry" });
40+
t.alike(event!.tags, { modelType: "llamacpp-completion", sourceType: "registry" });
4141
t.is(event!.gauges?.downloadTime, 220);
4242
t.is(event!.gauges?.totalBytesDownloaded, 4096);
4343
t.is(event!.gauges?.downloadSpeedBps, 18618);
@@ -51,7 +51,7 @@ test("operation metrics: omits unavailable gauges (no fabrication)", (t) => {
5151
"profile-2",
5252
100,
5353
90,
54-
{ modelType: "llm" },
54+
{ modelType: "llamacpp-completion" },
5555
{
5656
__profilingMeta: {
5757
sourceType: "filesystem",
@@ -83,7 +83,7 @@ test("transport: operation event survives injection/extraction round-trip", (t)
8383
ms: 500,
8484
profileId: "round-trip-test",
8585
gauges: { totalLoadTime: 500, downloadTime: 200 },
86-
tags: { modelType: "llm", sourceType: "registry", cacheHit: "true" },
86+
tags: { modelType: "llamacpp-completion", sourceType: "registry", cacheHit: "true" },
8787
};
8888

8989
const baseJson = '{"type":"loadModel","success":true}';
@@ -99,7 +99,7 @@ test("transport: operation event survives injection/extraction round-trip", (t)
9999
t.is(extracted!.operation!.profileId, "round-trip-test");
100100
t.alike(extracted!.operation!.gauges, { totalLoadTime: 500, downloadTime: 200 });
101101
t.alike(extracted!.operation!.tags, {
102-
modelType: "llm",
102+
modelType: "llamacpp-completion",
103103
sourceType: "registry",
104104
cacheHit: "true",
105105
});

0 commit comments

Comments
 (0)