Skip to content

Commit ea7ba8e

Browse files
committed
Refactor and unit test auth message creation
1 parent 020c7cd commit ea7ba8e

File tree

2 files changed

+113
-9
lines changed

2 files changed

+113
-9
lines changed

src/lmstudio/json_api.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,9 @@ def from_details(message: str, details: DictObject) -> "LMStudioServerError":
409409
if display_data:
410410
specific_error: LMStudioServerError | None = None
411411
match display_data:
412-
case {"code": "generic.noModelMatchingQuery"}:
412+
case {"code": "generic.noModelMatchingQuery"} | {
413+
"code": "generic.pathNotFound"
414+
}:
413415
specific_error = LMStudioModelNotFoundError(str(default_error))
414416
case {"code": "generic.presetNotFound"}:
415417
specific_error = LMStudioPresetNotFoundError(str(default_error))
@@ -2116,11 +2118,16 @@ def _format_auth_message(
21162118
"clientPasskey": client_passkey,
21172119
}
21182120

2119-
def _create_auth_message(self, api_token: str | None = None) -> DictObject:
2120-
"""Create an LM Studio websocket authentication message."""
2121+
@classmethod
2122+
def _create_auth_from_token(cls, api_token: str | None) -> DictObject:
2123+
"""Create an LM Studio websocket auth message from an API token.
2124+
2125+
If no token is given, and none is set in the environment,
2126+
falls back to generating a client scoped guest identifier
2127+
"""
21212128
if api_token is None:
2122-
api_token = os.getenv(_ENV_API_TOKEN, None)
2123-
if api_token is not None:
2129+
api_token = os.environ.get(_ENV_API_TOKEN, None)
2130+
if api_token: # Accept empty string as equivalent to None
21242131
match = _LMS_API_TOKEN_REGEX.match(api_token.strip())
21252132
if match is None:
21262133
raise LMStudioValueError(
@@ -2135,9 +2142,14 @@ def _create_auth_message(self, api_token: str | None = None) -> DictObject:
21352142
raise LMStudioValueError(
21362143
"Unexpected error parsing api_token: required token fields were not detected."
21372144
)
2138-
return self._format_auth_message(client_identifier, client_passkey)
2145+
return cls._format_auth_message(client_identifier, client_passkey)
2146+
2147+
return cls._format_auth_message()
21392148

2140-
return self._format_auth_message()
2149+
def _create_auth_message(self, api_token: str | None = None) -> DictObject:
2150+
"""Create an LM Studio websocket authentication message."""
2151+
# This is an instance method purely so subclasses may override it
2152+
return self._create_auth_from_token(api_token)
21412153

21422154

21432155
TClient = TypeVar("TClient", bound=ClientBase)

tests/test_sessions.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
"""Test common client session behaviour."""
22

33
import logging
4+
import os
5+
46
from typing import Generator
7+
from unittest import mock
58

69
import pytest
710
from pytest import LogCaptureFixture as LogCap
811

912
from lmstudio import (
1013
AsyncClient,
1114
Client,
15+
LMStudioValueError,
1216
LMStudioWebsocketError,
1317
)
1418
from lmstudio.async_api import (
1519
_AsyncLMStudioWebsocket,
1620
_AsyncSession,
1721
_AsyncSessionSystem,
1822
)
23+
from lmstudio.json_api import ClientBase
1924
from lmstudio.sync_api import (
2025
SyncLMStudioWebsocket,
2126
_SyncSession,
@@ -24,6 +29,93 @@
2429
from lmstudio._ws_impl import AsyncTaskManager
2530
from lmstudio._ws_thread import AsyncWebsocketThread
2631

32+
# This API token is structurally valid
33+
_VALID_API_TOKEN = "sk-lm-abcDEF78:abcDEF7890abcDEF7890"
34+
35+
36+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
37+
def test_auth_message_default(client_cls: ClientBase) -> None:
38+
with mock.patch.dict(os.environ) as env:
39+
env.pop("LMSTUDIO_API_TOKEN", None)
40+
auth_message = client_cls._create_auth_from_token(None)
41+
assert auth_message["authVersion"] == 1
42+
assert auth_message["clientIdentifier"].startswith("guest:")
43+
client_key = auth_message["clientPasskey"]
44+
assert client_key != ""
45+
assert isinstance(client_key, str)
46+
47+
48+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
49+
def test_auth_message_empty_token(client_cls: ClientBase) -> None:
50+
with mock.patch.dict(os.environ) as env:
51+
# Set a valid token in the env to ensure it is ignored
52+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
53+
auth_message = client_cls._create_auth_from_token("")
54+
assert auth_message["authVersion"] == 1
55+
assert auth_message["clientIdentifier"].startswith("guest:")
56+
client_key = auth_message["clientPasskey"]
57+
assert client_key != ""
58+
assert isinstance(client_key, str)
59+
60+
61+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
62+
def test_auth_message_empty_token_from_env(client_cls: ClientBase) -> None:
63+
with mock.patch.dict(os.environ) as env:
64+
env["LMSTUDIO_API_TOKEN"] = ""
65+
auth_message = client_cls._create_auth_from_token(None)
66+
assert auth_message["authVersion"] == 1
67+
assert auth_message["clientIdentifier"].startswith("guest:")
68+
client_key = auth_message["clientPasskey"]
69+
assert client_key != ""
70+
assert isinstance(client_key, str)
71+
72+
73+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
74+
def test_auth_message_valid_token(client_cls: ClientBase) -> None:
75+
auth_message = client_cls._create_auth_from_token(_VALID_API_TOKEN)
76+
assert auth_message["authVersion"] == 1
77+
assert auth_message["clientIdentifier"] == "abcDEF78"
78+
assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890"
79+
80+
81+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
82+
def test_auth_message_valid_token_from_env(client_cls: ClientBase) -> None:
83+
with mock.patch.dict(os.environ) as env:
84+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
85+
auth_message = client_cls._create_auth_from_token(None)
86+
assert auth_message["authVersion"] == 1
87+
assert auth_message["clientIdentifier"] == "abcDEF78"
88+
assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890"
89+
90+
91+
_INVALID_TOKENS = [
92+
"missing-token-prefix",
93+
"sk-lm-missing-id-and-key-separator",
94+
"sk-lm-invalid_id:invalid_key",
95+
"sk-lm-idtoolong:abcDEF7890abcDEF7890",
96+
"sk-lm-abcDEF78:keytooshort",
97+
]
98+
99+
100+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
101+
@pytest.mark.parametrize("api_token", _INVALID_TOKENS)
102+
def test_auth_message_invalid_token(client_cls: ClientBase, api_token: str) -> None:
103+
with mock.patch.dict(os.environ) as env:
104+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
105+
with pytest.raises(LMStudioValueError):
106+
client_cls._create_auth_from_token(api_token)
107+
108+
109+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
110+
@pytest.mark.parametrize("api_token", _INVALID_TOKENS)
111+
def test_auth_message_invalid_token_from_env(
112+
client_cls: ClientBase, api_token: str
113+
) -> None:
114+
with mock.patch.dict(os.environ) as env:
115+
env["LMSTUDIO_API_TOKEN"] = api_token
116+
with pytest.raises(LMStudioValueError):
117+
client_cls._create_auth_from_token(None)
118+
27119

28120
async def check_connected_async_session(session: _AsyncSession) -> None:
29121
assert session.connected
@@ -160,7 +252,7 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None:
160252
async def test_websocket_cm_async(caplog: LogCap) -> None:
161253
caplog.set_level(logging.DEBUG)
162254
api_host = await AsyncClient.find_default_local_api_host()
163-
auth_details = AsyncClient._format_auth_message()
255+
auth_details = AsyncClient._create_auth_from_token(None)
164256
tm = AsyncTaskManager(on_activation=None)
165257
lmsws = _AsyncLMStudioWebsocket(tm, f"http://{api_host}/system", auth_details)
166258
# SDK client websockets start out disconnected
@@ -200,7 +292,7 @@ def ws_thread() -> Generator[AsyncWebsocketThread, None, None]:
200292
def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None:
201293
caplog.set_level(logging.DEBUG)
202294
api_host = Client.find_default_local_api_host()
203-
auth_details = Client._format_auth_message()
295+
auth_details = Client._create_auth_from_token(None)
204296
lmsws = SyncLMStudioWebsocket(ws_thread, f"http://{api_host}/system", auth_details)
205297
# SDK client websockets start out disconnected
206298
assert not lmsws.connected

0 commit comments

Comments
 (0)