Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/sdk/client/api/bci-transcribe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { stream, duplex, type DuplexReadable } from "@/client/rpc/rpc-client";
import { getClientLogger } from "@/logging";
import { TranscriptionFailedError } from "@/utils/errors-client";
import { decoratePromise } from "@/utils/decorate-promise";
import { parseClientInput } from "@/client/parse-input";
import { generateClientRequestId } from "@/client/api/client-request-id";

const logger = getClientLogger();
Expand Down Expand Up @@ -70,7 +71,7 @@ export function bciTranscribe(
params: BciTranscribeClientParams,
options?: RPCOptions,
): Promise<string | TranscribeSegment[]> & { requestId: string } {
const parsed = bciTranscribeClientParamsSchema.parse(params);
const parsed = parseClientInput(bciTranscribeClientParamsSchema, params);
const requestId = generateClientRequestId();
const inner = runBciTranscribe(parsed, requestId, options);
return decoratePromise(inner, { requestId });
Expand Down
3 changes: 2 additions & 1 deletion packages/sdk/client/api/download-asset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
InvalidResponseError,
} from "@/utils/errors-client";
import { decoratePromise } from "@/utils/decorate-promise";
import { parseClientInput } from "@/client/parse-input";
import { generateClientRequestId } from "@/client/api/client-request-id";

export type DownloadAssetOptions = BaseDownloadAssetOptions;
Expand Down Expand Up @@ -72,7 +73,7 @@ async function runDownloadAsset(
requestId: string,
rpcOptions?: RPCOptions,
): Promise<string> {
const request = downloadAssetOptionsToRequestSchema.parse({
const request = parseClientInput(downloadAssetOptionsToRequestSchema, {
...options,
requestId,
});
Expand Down
11 changes: 6 additions & 5 deletions packages/sdk/client/api/finetune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
InvalidResponseError,
StreamEndedError,
} from "@/utils/errors-client";
import { parseClientInput } from "@/client/parse-input";

export interface FinetuneHandle {
progressStream: AsyncGenerator<FinetuneProgress>;
Expand All @@ -41,14 +42,14 @@ function isFinetuneReplyParams(

function createFinetuneReplyRequest(params: FinetuneReplyParams) {
if (params.operation === "getState") {
const getStateParams = finetuneGetStateParamsSchema.parse(params);
return finetuneGetStateRequestSchema.parse({
const getStateParams = parseClientInput(finetuneGetStateParamsSchema, params);
return parseClientInput(finetuneGetStateRequestSchema, {
type: "finetune",
...getStateParams,
});
}

return finetuneStopRequestSchema.parse({
return parseClientInput(finetuneStopRequestSchema, {
type: "finetune",
modelId: params.modelId,
operation: params.operation,
Expand Down Expand Up @@ -172,7 +173,7 @@ export function finetune(
return resultPromise;
}

const runParams = finetuneRunParamsSchema.parse(params);
const runParams = parseClientInput(finetuneRunParamsSchema, params);

let resultResolver: (value: FinetuneResult) => void = () => { };
let resultRejecter: (error: unknown) => void = () => { };
Expand All @@ -191,7 +192,7 @@ export function finetune(
const processResponses = async () => {
try {
let sawTerminalResponse = false;
const request = finetuneRunRequestSchema.parse({
const request = parseClientInput(finetuneRunRequestSchema, {
type: "finetune",
...runParams,
withProgress: true,
Expand Down
11 changes: 8 additions & 3 deletions packages/sdk/client/api/load-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import {
type RPCOptions,
type ModelDescriptor,
type SdcppConfig,
loadModelOptionsToRequestSchema,
loadBuiltinToRequestSchema,
loadCustomPluginToRequestSchema,
reloadConfigOptionsToRequestSchema,
isBuiltInModelType,
isModelTypeAlias,
normalizeModelType,
inferModelTypeFromModelSrc,
Expand All @@ -23,6 +25,7 @@ import {
InvalidResponseError,
} from "@/utils/errors-client";
import { assertModelSrcMatchesModelType } from "@/utils/load-model-validation";
import { parseClientInput } from "@/client/parse-input";
import { getClientLogger } from "@/logging";
import { decoratePromise } from "@/utils/decorate-promise";
import { generateClientRequestId } from "@/client/api/client-request-id";
Expand Down Expand Up @@ -305,8 +308,10 @@ async function runLoadModel(
resolvedOptions = { ...resolvedOptions, requestId };

const request = isReloadConfig
? reloadConfigOptionsToRequestSchema.parse(resolvedOptions)
: loadModelOptionsToRequestSchema.parse(resolvedOptions);
? parseClientInput(reloadConfigOptionsToRequestSchema, resolvedOptions)
: isBuiltInModelType(resolvedOptions["modelType"])
? parseClientInput(loadBuiltinToRequestSchema, resolvedOptions)
: parseClientInput(loadCustomPluginToRequestSchema, resolvedOptions);
const modelLogger = isReloadConfig
? undefined
: (resolvedOptions["logger"] as LoadModelOptions["logger"]);
Expand Down
3 changes: 2 additions & 1 deletion packages/sdk/client/api/text-to-speech.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
import { stream as streamRpc, duplex, type DuplexReadable } from "@/client/rpc/rpc-client";
import { getClientLogger } from "@/logging";
import { TextToSpeechStreamFailedError } from "@/utils/errors-client";
import { parseClientInput } from "@/client/parse-input";

const logger = getClientLogger();

Expand Down Expand Up @@ -211,7 +212,7 @@ export function textToSpeech(
params: TtsClientParamsInput,
options?: RPCOptions,
): TextToSpeechStreamResult {
const parsed: TtsClientParams = ttsClientParamsSchema.parse(params);
const parsed: TtsClientParams = parseClientInput(ttsClientParamsSchema, params);

if (parsed.sentenceStream && !parsed.stream) {
throw new TextToSpeechStreamFailedError(
Expand Down
6 changes: 2 additions & 4 deletions packages/sdk/client/config-loader/config-utils.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import { qvacConfigSchema, type QvacConfig } from "@/schemas";
import { ConfigValidationFailedError } from "@/utils/errors-client";
import { formatZodError } from "@/utils/zod-error";

export type { QvacConfig };

export function validateConfig(config: unknown): QvacConfig {
const result = qvacConfigSchema.safeParse(config);

if (!result.success) {
const errors = result.error.issues
.map((e) => `${String(e.path.join("."))}: ${e.message}`)
.join(", ");
throw new ConfigValidationFailedError(errors);
throw new ConfigValidationFailedError(formatZodError(result.error));
}

return result.data;
Expand Down
17 changes: 17 additions & 0 deletions packages/sdk/client/parse-input.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { z } from "zod";
import { formatZodError } from "@/utils/zod-error";
import { RequestValidationFailedError } from "@/utils/errors-client";

export function parseClientInput<S extends z.ZodType>(
schema: S,
value: unknown,
): z.output<S> {
try {
return schema.parse(value);
} catch (error) {
if (error instanceof z.ZodError) {
throw new RequestValidationFailedError(formatZodError(error));
}
throw error;
}
}
60 changes: 53 additions & 7 deletions packages/sdk/client/rpc/rpc-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ import {
createDuplexSession,
getWorkerLifeSignal,
} from "#rpc";
import { WorkerCrashedError } from "@/utils/errors-client";
import {
WorkerCrashedError,
RequestValidationFailedError,
} from "@/utils/errors-client";
import { formatZodError } from "@/utils/zod-error";
import { z } from "zod";
import {
nowMs,
shouldProfile,
Expand Down Expand Up @@ -50,6 +55,47 @@ function getNextCommandId() {
return commandCounter;
}

// On a failed request parse, re-validate against the single `requestSchema`
// member that owns the request's `type` so the error names the actual field
// rather than reporting a generic union failure. `loadModel` (a nested union
// with no top-level `type` literal) is already validated field-level in
// `client/api/load-model.ts`, so it falls back to the union error here.
interface RequestMemberIntrospect {
shape?: { type?: { value?: unknown } };
options?: RequestMemberIntrospect[];
}

function memberDiscriminator(option: RequestMemberIntrospect): string | undefined {
const direct = option.shape?.type?.value;
if (typeof direct === "string") return direct;
const nested = option.options?.[0]?.shape?.type?.value;
return typeof nested === "string" ? nested : undefined;
}

function pinpointRequestError(request: unknown, fallback: z.ZodError): z.ZodError {
const type = (request as { type?: unknown } | null)?.type;
if (typeof type !== "string") return fallback;
for (const option of requestSchema.options) {
if (memberDiscriminator(option as RequestMemberIntrospect) !== type) continue;
const result = (option as z.ZodType).safeParse(request);
return result.success ? fallback : result.error;
}
return fallback;
}

function parseRequest<T extends Request>(request: T): Request {
try {
return requestSchema.parse(request);
} catch (error) {
if (error instanceof z.ZodError) {
throw new RequestValidationFailedError(
formatZodError(pinpointRequestError(request, error)),
);
}
throw error;
}
}

// Race in-flight reply/stream pulls against the worker-life signal —
// bare-rpc's `_onerror` does not iterate `_outgoingRequests`, so without
// this they hang on a dead socket.
Expand Down Expand Up @@ -234,7 +280,7 @@ async function sendBase<T extends Request>(
options?: RPCOptions,
signalDisable: boolean = false,
): Promise<Response> {
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
const req = rpc.request(getNextCommandId());
logger.debug("RPC Client sending:", summarizeRequest(request));
const payloadObj = signalDisable
Expand Down Expand Up @@ -272,7 +318,7 @@ async function sendProfiled<T extends Request>(

try {
const zodStart = nowMs();
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
timings.requestZodValidationMs = nowMs() - zodStart;

const req = rpc.request(getNextCommandId());
Expand Down Expand Up @@ -352,7 +398,7 @@ async function* streamBase<T extends Request>(
options: RPCOptions = {},
signalDisable: boolean = false,
): AsyncGenerator<Response> {
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
const req = rpc.request(getNextCommandId());
logger.debug("RPC Client streaming:", summarizeRequest(request));
const payloadObj = signalDisable
Expand Down Expand Up @@ -409,7 +455,7 @@ async function* streamProfiled<T extends Request>(

try {
const zodStart = nowMs();
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
timings.requestZodValidationMs = nowMs() - zodStart;

const req = rpc.request(getNextCommandId());
Expand Down Expand Up @@ -523,7 +569,7 @@ async function duplexBase<T extends Request>(
signalDisable: boolean,
timeout?: number,
): Promise<DuplexSession> {
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
logger.debug("RPC Client duplex:", summarizeRequest(request));

const payloadObj = signalDisable
Expand Down Expand Up @@ -554,7 +600,7 @@ async function duplexProfiled<T extends Request>(

try {
const zodStart = nowMs();
const parsedRequest = requestSchema.parse(request);
const parsedRequest = parseRequest(request);
timings.requestZodValidationMs = nowMs() - zodStart;

logger.debug("RPC Client duplex:", summarizeRequest(request));
Expand Down
1 change: 1 addition & 0 deletions packages/sdk/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ export {
BareRuntimeBinaryNotFoundError,
WorkerCrashedError,
WorkerShutdownError,
RequestValidationFailedError,
} from "./utils/errors-client";

// Logging exports
Expand Down
8 changes: 7 additions & 1 deletion packages/sdk/schemas/error.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { z } from "zod";
import { QvacErrorBase } from "@qvac/error";
import { formatZodError } from "@/utils/zod-error";

/**
* Wire shape for errors thrown across the RPC boundary. The fields are
Expand Down Expand Up @@ -76,7 +77,12 @@ export function createErrorResponse(error: unknown): ErrorResponse {
return response;
}

const message = error instanceof Error ? error.message : String(error);
const message =
error instanceof z.ZodError
? formatZodError(error)
: error instanceof Error
? error.message
: String(error);
const stack = error instanceof Error ? error.stack : undefined;

return {
Expand Down
50 changes: 30 additions & 20 deletions packages/sdk/schemas/load-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ const builtInModelTypes = new Set([
...Object.values(ModelType),
...Object.keys(ModelTypeAliases),
]);

export function isBuiltInModelType(modelType: unknown): boolean {
return typeof modelType === "string" && builtInModelTypes.has(modelType);
}
import type { Logger } from "@/logging";
import { reloadConfigRequestSchema } from "./reload-config";

Expand Down Expand Up @@ -166,7 +170,7 @@ export const loadModelOptionsSchema = loadModelOptionsBaseSchema.transform(
}),
);

const loadModelOptionsToRequestBaseSchema = z.union([
export const loadBuiltinToRequestSchema = z.discriminatedUnion("modelType", [
z
.object({
...loadModelRequestCommonFields,
Expand Down Expand Up @@ -377,25 +381,31 @@ const loadModelOptionsToRequestBaseSchema = z.union([
delegate: data.delegate,
...(data.requestId !== undefined && { requestId: data.requestId }),
})),
z
.object({
...loadModelRequestCommonFields,
modelType: z.string().refine((val) => !builtInModelTypes.has(val), {
message: "Built-in model types must use their specific schema",
}),
modelConfig: z.record(z.string(), z.unknown()).optional(),
})
.transform((data) => ({
type: "loadModel" as const,
modelType: data.modelType,
modelSrc: modelInputToSrcSchema.parse(data.modelSrc),
modelName: modelInputToNameSchema.parse(data.modelSrc),
modelConfig: data.modelConfig ?? {},
seed: data.seed ?? false,
withProgress: data.withProgress ?? !!data.onProgress,
delegate: data.delegate,
...(data.requestId !== undefined && { requestId: data.requestId }),
})),
]);

export const loadCustomPluginToRequestSchema = z
.object({
...loadModelRequestCommonFields,
modelType: z.string().refine((val) => !builtInModelTypes.has(val), {
message: "Built-in model types must use their specific schema",
}),
modelConfig: z.record(z.string(), z.unknown()).optional(),
})
.transform((data) => ({
type: "loadModel" as const,
modelType: data.modelType,
modelSrc: modelInputToSrcSchema.parse(data.modelSrc),
modelName: modelInputToNameSchema.parse(data.modelSrc),
modelConfig: data.modelConfig ?? {},
seed: data.seed ?? false,
withProgress: data.withProgress ?? !!data.onProgress,
delegate: data.delegate,
...(data.requestId !== undefined && { requestId: data.requestId }),
}));

const loadModelOptionsToRequestBaseSchema = z.union([
loadBuiltinToRequestSchema,
loadCustomPluginToRequestSchema,
]);

export const loadModelOptionsToRequestSchema =
Expand Down
5 changes: 5 additions & 0 deletions packages/sdk/schemas/sdk-errors-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export const SDK_CLIENT_ERROR_CODES = {
OCR_FAILED: 50007,
MODEL_TYPE_REQUIRED: 50008,
MODEL_SRC_TYPE_MISMATCH: 50009,
REQUEST_VALIDATION_FAILED: 50010,

// RPC Communication Errors (50,200-50,399)
RPC_NO_HANDLER: 50200,
Expand Down Expand Up @@ -93,6 +94,10 @@ const clientErrorDefinitions: ErrorCodesMap = {
message: (inferred: string, resolved: string) =>
`modelSrc describes "${inferred}", but modelType resolves to "${resolved}". Omit modelType to infer it automatically, or pass a matching modelType.`,
},
[SDK_CLIENT_ERROR_CODES.REQUEST_VALIDATION_FAILED]: {
name: "REQUEST_VALIDATION_FAILED",
message: (errors: string) => `Invalid request:\n${errors}`,
},

// RPC Communication Errors (50,200-50,399)
[SDK_CLIENT_ERROR_CODES.RPC_NO_HANDLER]: {
Expand Down
Loading
Loading