diff --git a/packages/lms-communication-client/src/ClientPort.ts b/packages/lms-communication-client/src/ClientPort.ts index ec495d68..7945944e 100644 --- a/packages/lms-communication-client/src/ClientPort.ts +++ b/packages/lms-communication-client/src/ClientPort.ts @@ -14,6 +14,7 @@ import { } from "@lmstudio/lms-common"; import { Channel, + normalizeCommunicationWarningKind, deserialize, serialize, type BackendInterface, @@ -27,6 +28,7 @@ import { type ServerToClientMessage, type SignalEndpoint, type SignalEndpointsSpecBase, + type CommunicationWarningKind, type WritableSignalEndpoint, type WritableSignalEndpointsSpecBase, } from "@lmstudio/lms-communication"; @@ -76,6 +78,12 @@ function defaultErrorDeserializer( return fromSerializedError(serialized, directCause, stack); } +export interface ClientPortCommunicationWarning { + direction: "produced" | "received"; + kind: CommunicationWarningKind; + warning: string; +} + export class ClientPort< TRpcEndpoints extends RpcEndpointsSpecBase, TChannelEndpoints extends ChannelEndpointsSpecBase, @@ -104,6 +112,9 @@ export class ClientPort< stack?: string, ) => Error; private verboseErrorMessage: boolean; + private readonly onCommunicationWarning?: ( + communicationWarning: ClientPortCommunicationWarning, + ) => void; public constructor( public readonly backendInterface: BackendInterface< @@ -118,6 +129,7 @@ export class ClientPort< parentLogger, errorDeserializer, verboseErrorMessage, + onCommunicationWarning, }: { parentLogger?: LoggerInterface; errorDeserializer?: ( @@ -126,33 +138,43 @@ export class ClientPort< stack?: string, ) => Error; verboseErrorMessage?: boolean; + onCommunicationWarning?: (communicationWarning: ClientPortCommunicationWarning) => void; } = {}, ) { this.logger = new SimpleLogger("ClientPort", parentLogger); this.errorDeserializer = errorDeserializer ?? defaultErrorDeserializer; this.verboseErrorMessage = verboseErrorMessage ?? true; + this.onCommunicationWarning = onCommunicationWarning; this.transport = factory(this.receivedMessage, this.onConnected, this.errored, this.logger); } - private communicationWarning(warning: string) { + private communicationWarning(warning: string, kind: CommunicationWarningKind = "unknown") { if (this.producedCommunicationWarningsCount >= 5) { return; } - this.logger.warnText` - Produced communication warning: ${warning} - - This is usually caused by communication protocol incompatibility. Please make sure you are - using the up-to-date versions of the SDK and LM Studio. - `; + if (this.onCommunicationWarning === undefined) { + this.logger.warnText` + Produced communication warning: ${warning} + + This is usually caused by communication protocol incompatibility. Please make sure you are + using the up-to-date versions of the SDK and LM Studio. + `; + } this.safeSend( { type: "communicationWarning", warning, + kind, }, "communicationWarning", ); + this.reportCommunicationWarning({ + direction: "produced", + kind, + warning, + }); this.producedCommunicationWarningsCount++; - if (this.producedCommunicationWarningsCount >= 5) { + if (this.onCommunicationWarning === undefined && this.producedCommunicationWarningsCount >= 5) { this.logger.errorText` 5 communication warnings have been produced. Further warnings will not be printed. `; @@ -190,6 +212,7 @@ export class ClientPort< if (openChannel === undefined) { this.communicationWarning( `Received channelSend for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -201,7 +224,7 @@ export class ClientPort< ${deserializedMessage}. Zod error: ${Validator.prettyPrintZod("message", parsed.error)} - `); + `, "channelMessageTypeError"); return; } openChannel.receivedMessage(parsed.data); @@ -212,6 +235,7 @@ export class ClientPort< if (openChannel === undefined) { this.communicationWarning( `Received channelAck for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -223,6 +247,7 @@ export class ClientPort< if (openChannel === undefined) { this.communicationWarning( `Received channelClose for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -236,6 +261,7 @@ export class ClientPort< if (openChannel === undefined) { this.communicationWarning( `Received channelError for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -252,7 +278,10 @@ export class ClientPort< private receivedRpcResult(message: ServerToClientMessage & { type: "rpcResult" }) { const ongoingRpc = this.ongoingRpcs.get(message.callId); if (ongoingRpc === undefined) { - this.communicationWarning(`Received rpcResult for unknown rpc, callId = ${message.callId}`); + this.communicationWarning( + `Received rpcResult for unknown rpc, callId = ${message.callId}`, + "rpcUnknown", + ); return; } const deserializedResult = deserialize(ongoingRpc.endpoint.serialization, message.result); @@ -263,7 +292,7 @@ export class ClientPort< ${deserializedResult}. Zod error: ${Validator.prettyPrintZod("result", parsed.error)} - `); + `, "rpcResultTypeError"); return; } ongoingRpc.resolve(parsed.data); @@ -274,7 +303,10 @@ export class ClientPort< private receivedRpcError(message: ServerToClientMessage & { type: "rpcError" }) { const ongoingRpc = this.ongoingRpcs.get(message.callId); if (ongoingRpc === undefined) { - this.communicationWarning(`Received rpcError for unknown rpc, callId = ${message.callId}`); + this.communicationWarning( + `Received rpcError for unknown rpc, callId = ${message.callId}`, + "rpcUnknown", + ); return; } const error = this.errorDeserializer( @@ -313,7 +345,7 @@ export class ClientPort< patches = ${JSON.stringify(patches, null, 2)}. Error: ${String(error)} - `); + `, "signalUpdatePatchApplyError"); return; } const parseResult = openSignalSubscription.endpoint.signalData.safeParse(afterValue); @@ -330,7 +362,7 @@ export class ClientPort< Zod error: ${Validator.prettyPrintZod("value", parseResult.error)} - `); + `, "signalPatchTypeError"); return; } // Don't use the parsed value, as it loses the substructure identities @@ -342,6 +374,7 @@ export class ClientPort< if (openSignalSubscription === undefined) { this.communicationWarning( `Received signalError for unknown signal, subscribeId = ${message.subscribeId}`, + "signalUnknown", ); return; } @@ -384,7 +417,7 @@ export class ClientPort< patches = ${JSON.stringify(patches, null, 2)}. Error: ${String(error)} - `); + `, "writableSignalPatchApplyError"); } const parseResult = openSignalSubscription.endpoint.signalData.safeParse(afterValue); if (!parseResult.success) { @@ -400,7 +433,7 @@ export class ClientPort< Zod error: ${Validator.prettyPrintZod("value", parseResult.error)} - `); + `, "writableSignalPatchTypeError"); return; } // Don't use the parsed value, as it loses the substructure identities @@ -415,6 +448,7 @@ export class ClientPort< if (openSignalSubscription === undefined) { this.communicationWarning( `Received writableSignalError for unknown signal, subscribeId = ${message.subscribeId}`, + "writableSignalUnknown", ); return; } @@ -431,14 +465,33 @@ export class ClientPort< private receivedCommunicationWarning( message: ServerToClientMessage & { type: "communicationWarning" }, ) { - this.logger.warnText` - Received communication warning from the server: ${message.warning} - - This is usually caused by communication protocol incompatibility. Please make sure you are - using the up-to-date versions of the SDK and LM Studio. - - Note: This warning was received from the server and is printed on the client for convenience. - `; + const kind = normalizeCommunicationWarningKind(message.kind); + if (this.onCommunicationWarning === undefined) { + this.logger.warnText` + Received communication warning from the server (${kind}): ${message.warning} + + This is usually caused by communication protocol incompatibility. Please make sure you are + using the up-to-date versions of the SDK and LM Studio. + + Note: This warning was received from the server and is printed on the client for convenience. + `; + } + this.reportCommunicationWarning({ + direction: "received", + kind, + warning: message.warning, + }); + } + + private reportCommunicationWarning(communicationWarning: ClientPortCommunicationWarning): void { + if (this.onCommunicationWarning === undefined) { + return; + } + try { + this.onCommunicationWarning(communicationWarning); + } catch (error) { + this.logger.error("Error in onCommunicationWarning callback:", error); + } } private receivedKeepAliveAck(_message: ServerToClientMessage & { type: "keepAliveAck" }) { diff --git a/packages/lms-communication-client/src/index.ts b/packages/lms-communication-client/src/index.ts index 4c56d6cf..52ace300 100644 --- a/packages/lms-communication-client/src/index.ts +++ b/packages/lms-communication-client/src/index.ts @@ -1,5 +1,5 @@ export { AuthenticatedWsClientTransport } from "./AuthenticatedWsClientTransport.js"; -export { ClientPort, InferClientPort } from "./ClientPort.js"; +export { ClientPort, type ClientPortCommunicationWarning, InferClientPort } from "./ClientPort.js"; export { GenericClientTransport } from "./GenericClientTransport.js"; export { LMStudioHostedEnv, getHostedEnv } from "./LMStudioHostedEnv.js"; export { WsClientTransport } from "./WsClientTransport.js"; diff --git a/packages/lms-communication-server/src/ServerPort.ts b/packages/lms-communication-server/src/ServerPort.ts index 9065de06..2841a108 100644 --- a/packages/lms-communication-server/src/ServerPort.ts +++ b/packages/lms-communication-server/src/ServerPort.ts @@ -11,6 +11,7 @@ import { } from "@lmstudio/lms-common"; import { Channel, + normalizeCommunicationWarningKind, deserialize, serialize, type BackendInterface, @@ -23,6 +24,7 @@ import { type ServerTransportFactory, type SignalEndpoint, type SignalEndpointsSpecBase, + type CommunicationWarningKind, type WritableSignalEndpoint, type WritableSignalEndpointsSpecBase, } from "@lmstudio/lms-communication"; @@ -114,7 +116,7 @@ export class ServerPort< } } - private communicationWarning(warning: string) { + private communicationWarning(warning: string, kind: CommunicationWarningKind = "unknown") { if (this.producedCommunicationWarningsCount >= 5) { return; } @@ -128,6 +130,7 @@ export class ServerPort< { type: "communicationWarning", warning, + kind, }, "communicationWarning", ); @@ -144,18 +147,21 @@ export class ServerPort< if (endpoint === undefined) { this.communicationWarning( `Received channelCreate for unknown endpoint, endpoint = ${message.endpoint}`, + "channelEndpointUnknown", ); return; } if (endpoint.handler === null) { this.communicationWarning( `Received channelCreate for unhandled endpoint, endpoint = ${message.endpoint}`, + "channelEndpointUnhandled", ); return; } if (this.openChannels.has(message.channelId)) { this.communicationWarning( `Received channelCreate for already open channel, channelId = ${message.channelId}`, + "channelAlreadyOpen", ); return; } @@ -170,7 +176,7 @@ export class ServerPort< creationParameter = ${deserializedCreationParameter}. Zod error: ${Validator.prettyPrintZod("creationParameter", parseResult.error)} - `); + `, "channelCreationParameterTypeError"); return; } @@ -185,7 +191,7 @@ export class ServerPort< message = ${message as object}. Zod error: ${Validator.prettyPrintZod("message", result.error)} - `); + `, "channelMessageTypeError"); return; } const serializedMessage = serialize(endpoint.serialization, result.data); @@ -244,6 +250,7 @@ export class ServerPort< if (openChannel === undefined) { this.communicationWarning( `Received channelSend for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -255,7 +262,7 @@ export class ServerPort< message = ${message.message}. Zod error: ${Validator.prettyPrintZod("message", parsed.error)} - `); + `, "channelMessageTypeError"); return; } openChannel.receivedMessage(parsed.data); @@ -266,6 +273,7 @@ export class ServerPort< if (openChannel === undefined) { this.communicationWarning( `Received channelAck for unknown channel, channelId = ${message.channelId}`, + "channelUnknown", ); return; } @@ -277,12 +285,14 @@ export class ServerPort< if (endpoint === undefined) { this.communicationWarning( `Received rpcCall for unknown endpoint, endpoint = ${message.endpoint}`, + "rpcEndpointUnknown", ); return; } if (endpoint.handler === null) { this.communicationWarning( `Received rpcCall for unhandled endpoint, endpoint = ${message.endpoint}`, + "rpcEndpointUnhandled", ); return; } @@ -294,7 +304,7 @@ export class ServerPort< parameter = ${message.parameter}. Zod error: ${Validator.prettyPrintZod("parameter", parseResult.error)} - `); + `, "rpcParameterTypeError"); return; } const context = this.contextCreator({ @@ -355,12 +365,14 @@ export class ServerPort< if (endpoint === undefined) { this.communicationWarning( `Received signalSubscribe for unknown endpoint, endpoint = ${message.endpoint}`, + "signalEndpointUnknown", ); return; } if (endpoint.handler === null) { this.communicationWarning( `Received signalSubscribe for unhandled endpoint, endpoint = ${message.endpoint}`, + "signalEndpointUnhandled", ); return; } @@ -368,7 +380,7 @@ export class ServerPort< this.communicationWarning(text` Received signalSubscribe for already open subscription, subscribeId = ${message.subscribeId} - `); + `, "signalAlreadyOpen"); return; } const deserializedCreationParameter = deserialize( @@ -382,7 +394,7 @@ export class ServerPort< creationParameter = ${deserializedCreationParameter}. Zod error: ${Validator.prettyPrintZod("creationParameter", parseResult.error)} - `); + `, "signalCreationParameterTypeError"); return; } const context = this.contextCreator({ @@ -523,12 +535,14 @@ export class ServerPort< if (endpoint === undefined) { this.communicationWarning( `Received writableSignalSubscribe for unknown endpoint, endpoint = ${message.endpoint}`, + "writableSignalEndpointUnknown", ); return; } if (endpoint.handler === null) { this.communicationWarning( `Received writableSignalSubscribe for unhandled endpoint, endpoint = ${message.endpoint}`, + "writableSignalEndpointUnhandled", ); return; } @@ -536,7 +550,7 @@ export class ServerPort< this.communicationWarning(text` Received writableSignalSubscribe for already open subscription, subscribeId = ${message.subscribeId} - `); + `, "writableSignalAlreadyOpen"); return; } const deserializedCreationParameter = deserialize( @@ -550,7 +564,7 @@ export class ServerPort< creationParameter = ${deserializedCreationParameter}. Zod error: ${Validator.prettyPrintZod("creationParameter", parseResult.error)} - `); + `, "writableSignalCreationParameterTypeError"); return; } const context = this.contextCreator({ @@ -650,7 +664,7 @@ export class ServerPort< data = ${result}. Zod error: ${Validator.prettyPrintZod("data", parseResult.error)} - `); + `, "writableSignalDataTypeError"); return; } setter.withValueAndPatches(parseResult.data, deserializedPatches, tags); @@ -658,7 +672,7 @@ export class ServerPort< this.communicationWarning(text` Error in receivedPatches for writable signal, endpointName = ${endpoint.name}, error = ${error.message} - `); + `, "writableSignalPatchApplyError"); } }, }; @@ -746,6 +760,7 @@ export class ServerPort< if (openWritableSignalSubscription === undefined) { this.communicationWarning( `Received writableSignalUpdate for unknown subscription, subscribeId = ${message.subscribeId}`, + "writableSignalUnknown", ); return; } @@ -755,8 +770,9 @@ export class ServerPort< private receivedCommunicationWarning( message: ClientToServerMessage & { type: "communicationWarning" }, ) { + const kind = normalizeCommunicationWarningKind(message.kind); this.logger.warnText` - Received communication warning from the client: ${message.warning} + Received communication warning from the client (${kind}): ${message.warning} This is usually caused by communication protocol incompatibility. Please make sure you are using the up-to-date versions of the SDK and LM Studio. diff --git a/packages/lms-communication/src/CommunicationWarning.ts b/packages/lms-communication/src/CommunicationWarning.ts new file mode 100644 index 00000000..cf5a7b03 --- /dev/null +++ b/packages/lms-communication/src/CommunicationWarning.ts @@ -0,0 +1,43 @@ +export const communicationWarningKinds = [ + "unknown", + "channelUnknown", + "channelMessageTypeError", + "channelCreationParameterTypeError", + "channelEndpointUnknown", + "channelEndpointUnhandled", + "channelAlreadyOpen", + "rpcUnknown", + "rpcParameterTypeError", + "rpcResultTypeError", + "rpcEndpointUnknown", + "rpcEndpointUnhandled", + "signalUnknown", + "signalUpdatePatchApplyError", + "signalPatchTypeError", + "signalEndpointUnknown", + "signalEndpointUnhandled", + "signalAlreadyOpen", + "signalCreationParameterTypeError", + "writableSignalUnknown", + "writableSignalPatchApplyError", + "writableSignalPatchTypeError", + "writableSignalDataTypeError", + "writableSignalEndpointUnknown", + "writableSignalEndpointUnhandled", + "writableSignalAlreadyOpen", + "writableSignalCreationParameterTypeError", +] as const; + +export type CommunicationWarningKind = (typeof communicationWarningKinds)[number]; + +const communicationWarningKindsSet = new Set(communicationWarningKinds); + +export function normalizeCommunicationWarningKind(kind: string | undefined): CommunicationWarningKind { + if (kind === undefined) { + return "unknown"; + } + if (!communicationWarningKindsSet.has(kind)) { + return "unknown"; + } + return kind as CommunicationWarningKind; +} diff --git a/packages/lms-communication/src/Transport.ts b/packages/lms-communication/src/Transport.ts index d0341753..b25cd239 100644 --- a/packages/lms-communication/src/Transport.ts +++ b/packages/lms-communication/src/Transport.ts @@ -8,6 +8,7 @@ const clientToServerMessageSchema = z.discriminatedUnion("type", [ z.object({ type: z.literal("communicationWarning"), warning: z.string(), + kind: z.string().optional(), }), z.object({ type: z.literal("keepAlive"), @@ -78,6 +79,7 @@ const serverToClientMessageSchema = z.discriminatedUnion("type", [ z.object({ type: z.literal("communicationWarning"), warning: z.string(), + kind: z.string().optional(), }), z.object({ type: z.literal("keepAliveAck"), diff --git a/packages/lms-communication/src/index.ts b/packages/lms-communication/src/index.ts index 0010fb5e..3134e417 100644 --- a/packages/lms-communication/src/index.ts +++ b/packages/lms-communication/src/index.ts @@ -23,6 +23,11 @@ export { InferClientChannelType, InferServerChannelType, } from "./Channel.js"; +export { + communicationWarningKinds, + normalizeCommunicationWarningKind, + type CommunicationWarningKind, +} from "./CommunicationWarning.js"; export { deserialize, SerializationType,