Skip to content

Commit 298b6f8

Browse files
committed
test: Enhance testing capabilities with new pytest configurations and fixtures
Signed-off-by: Eden Reich <[email protected]>
1 parent 90171af commit 298b6f8

File tree

6 files changed

+370
-48
lines changed

6 files changed

+370
-48
lines changed

.devcontainer/Dockerfile

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ ENV ZSH_CUSTOM=/home/vscode/.oh-my-zsh/custom \
66
PYLINT_VERSION=3.3.3 \
77
BUILD_VERSION=1.2.2.post1 \
88
TWINE_VERSION=6.0.1 \
9-
TASK_VERSION=v3.41.0
9+
TASK_VERSION=v3.41.0 \
10+
PYTEST_VERSION=8.3.4 \
11+
PYTEST_WATCH_VERSION=4.2.0
1012

1113
RUN apt-get update && \
1214
# Install nodejs and npm
@@ -16,12 +18,17 @@ RUN apt-get update && \
1618
curl -s https://taskfile.dev/install.sh | sh -s -- -b /usr/local/bin ${TASK_VERSION} && \
1719
# Upgrade pip
1820
python -m pip install --upgrade pip && \
19-
# Install black, isort, pylint using pip
20-
pip install black==${BLACK_VERSION} && \
21-
pip install isort==${ISORT_VERSION} && \
22-
pip install pylint==${PYLINT_VERSION} && \
23-
pip install build==${BUILD_VERSION} && \
24-
pip install twine==${TWINE_VERSION} && \
21+
# Install development tools using pip
22+
pip install black==${BLACK_VERSION} \
23+
isort==${ISORT_VERSION} \
24+
pylint==${PYLINT_VERSION} \
25+
build==${BUILD_VERSION} \
26+
twine==${TWINE_VERSION} \
27+
pytest==${PYTEST_VERSION} \
28+
pytest-watch==${PYTEST_WATCH_VERSION} \
29+
pytest-cov \
30+
pytest-xdist \
31+
debugpy && \
2532
# Clean up
2633
apt-get clean && \
2734
rm -rf /var/lib/apt/lists/*

.devcontainer/devcontainer.json

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
"dev.containers.copyGitConfig": true,
3535
"githubPullRequests.experimental.chat": true,
3636
"githubPullRequests.experimental.notificationsView": true,
37-
"files.insertFinalNewline": true
37+
"files.insertFinalNewline": true,
38+
"python.testing.pytestEnabled": true,
39+
"python.testing.unittestEnabled": false,
40+
"python.testing.nosetestsEnabled": false,
41+
"python.testing.pytestArgs": [
42+
"tests"
43+
]
3844
}
3945
}
4046
},

.devcontainer/launch.json

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,30 @@
22
"version": "0.2.0",
33
"configurations": [
44
{
5-
"name": "Python: Current File",
5+
"name": "Python: Debug Tests",
66
"type": "debugpy",
77
"request": "launch",
8-
"program": "${file}",
9-
"console": "integratedTerminal"
8+
"program": "/usr/local/bin/python",
9+
"args": [
10+
"-v",
11+
"--no-cov",
12+
"tests/"
13+
],
14+
"console": "integratedTerminal",
15+
"justMyCode": false
16+
},
17+
{
18+
"name": "Python: Debug Current Test",
19+
"type": "debugpy",
20+
"request": "launch",
21+
"program": "/usr/local/bin/python",
22+
"args": [
23+
"-v",
24+
"--no-cov",
25+
"${file}"
26+
],
27+
"console": "integratedTerminal",
28+
"justMyCode": false
1029
}
1130
]
1231
}

Taskfile.yml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,22 @@ tasks:
1515
test:
1616
desc: Run tests
1717
cmds:
18-
- pytest tests/
18+
- pytest tests/ -v
19+
20+
test:watch:
21+
desc: Run tests in watch mode
22+
cmds:
23+
- ptw tests/ -- -v
24+
25+
test:coverage:
26+
desc: Run tests with coverage report
27+
cmds:
28+
- pytest tests/ -v --cov=inference_gateway --cov-report=term-missing
29+
30+
test:debug:
31+
desc: Run tests with debugger enabled
32+
cmds:
33+
- pytest tests/ -v --pdb
1934

2035
clean:
2136
desc: Clean up

inference_gateway/client.py

Lines changed: 131 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import Generator, Optional
1+
from typing import Generator, Optional, Union, List, Dict, Optional
22
import json
33
from dataclasses import dataclass
44
from enum import Enum
5-
from typing import List, Dict, Optional
65
import requests
76

87

@@ -37,36 +36,67 @@ def to_dict(self) -> Dict[str, str]:
3736
@dataclass
3837
class Model:
3938
"""Represents an LLM model"""
39+
4040
name: str
4141

4242

4343
@dataclass
4444
class ProviderModels:
4545
"""Groups models by provider"""
46+
4647
provider: Provider
4748
models: List[Model]
4849

4950

5051
@dataclass
5152
class ResponseTokens:
5253
"""Response tokens structure as defined in the API spec"""
54+
5355
role: str
5456
model: str
5557
content: str
5658

59+
@classmethod
60+
def from_dict(cls, data: dict) -> "ResponseTokens":
61+
"""Create ResponseTokens from dictionary data
62+
63+
Args:
64+
data: Dictionary containing response data
65+
66+
Returns:
67+
ResponseTokens instance
68+
69+
Raises:
70+
TypeError: If data is not a dictionary
71+
ValueError: If required fields are missing
72+
"""
73+
if not isinstance(data, dict):
74+
raise TypeError(f"Expected dict, got {type(data)}")
75+
76+
required = ["role", "model", "content"]
77+
missing = [field for field in required if field not in data]
78+
79+
if missing:
80+
raise ValueError(
81+
f"Missing required arguments: {
82+
', '.join(missing)}"
83+
)
84+
85+
return cls(role=data["role"], model=data["model"], content=data["content"])
86+
5787

5888
@dataclass
5989
class GenerateResponse:
6090
"""Response structure for token generation"""
91+
6192
provider: str
6293
response: ResponseTokens
6394

6495
@classmethod
65-
def from_dict(cls, data: dict) -> 'GenerateResponse':
96+
def from_dict(cls, data: dict) -> "GenerateResponse":
6697
"""Create GenerateResponse from dictionary data"""
6798
return cls(
68-
provider=data.get('provider', ''),
69-
response=ResponseTokens(**data.get('response', {}))
99+
provider=data.get("provider", ""), response=ResponseTokens(**data.get("response", {}))
70100
)
71101

72102

@@ -86,9 +116,79 @@ def list_models(self) -> List[ProviderModels]:
86116
response.raise_for_status()
87117
return response.json()
88118

119+
def _parse_sse_chunk(self, chunk: bytes) -> dict:
120+
"""Parse an SSE message chunk into structured event data
121+
122+
Args:
123+
chunk: Raw SSE message chunk in bytes format
124+
125+
Returns:
126+
dict: Parsed SSE message with event type and data fields
127+
128+
Raises:
129+
json.JSONDecodeError: If chunk format or content is invalid
130+
"""
131+
if not isinstance(chunk, bytes):
132+
raise TypeError(f"Expected bytes, got {type(chunk)}")
133+
134+
try:
135+
decoded = chunk.decode("utf-8")
136+
message = {}
137+
138+
for line in (l.strip() for l in decoded.split("\n") if l.strip()):
139+
if line.startswith("event: "):
140+
message["event"] = line.removeprefix("event: ")
141+
elif line.startswith("data: "):
142+
try:
143+
json_str = line.removeprefix("data: ")
144+
data = json.loads(json_str)
145+
if not isinstance(data, dict):
146+
raise json.JSONDecodeError(
147+
f"Invalid SSE data format - expected object, got: {
148+
json_str}",
149+
json_str,
150+
0,
151+
)
152+
message["data"] = data
153+
except json.JSONDecodeError as e:
154+
raise json.JSONDecodeError(f"Invalid SSE JSON: {json_str}", e.doc, e.pos)
155+
156+
if not message.get("data"):
157+
raise json.JSONDecodeError(
158+
f"Missing or invalid data field in SSE message: {
159+
decoded}",
160+
decoded,
161+
0,
162+
)
163+
164+
return message
165+
166+
except UnicodeDecodeError as e:
167+
raise json.JSONDecodeError(
168+
f"Invalid UTF-8 encoding in SSE chunk: {
169+
chunk!r}",
170+
str(chunk),
171+
0,
172+
)
173+
174+
def _parse_json_line(self, line: bytes) -> ResponseTokens:
175+
"""Parse a single JSON line into GenerateResponse"""
176+
try:
177+
decoded_line = line.decode("utf-8")
178+
data = json.loads(decoded_line)
179+
return ResponseTokens.from_dict(data)
180+
except UnicodeDecodeError as e:
181+
raise json.JSONDecodeError(f"Invalid UTF-8 encoding: {line}", str(line), 0)
182+
except json.JSONDecodeError as e:
183+
raise json.JSONDecodeError(
184+
f"Invalid JSON response: {
185+
decoded_line}",
186+
e.doc,
187+
e.pos,
188+
)
189+
89190
def generate_content(self, provider: Provider, model: str, messages: List[Message]) -> Dict:
90-
payload = {"model": model, "messages": [
91-
msg.to_dict() for msg in messages]}
191+
payload = {"model": model, "messages": [msg.to_dict() for msg in messages]}
92192

93193
response = self.session.post(
94194
f"{self.base_url}/llms/{provider.value}/generate", json=payload
@@ -97,12 +197,8 @@ def generate_content(self, provider: Provider, model: str, messages: List[Messag
97197
return response.json()
98198

99199
def generate_content_stream(
100-
self,
101-
provider: Provider,
102-
model: str,
103-
messages: List[Message],
104-
use_sse: bool = False
105-
) -> Generator[Union[GenerateResponse, dict], None, None]:
200+
self, provider: Provider, model: str, messages: List[Message], use_sse: bool = False
201+
) -> Generator[Union[ResponseTokens, dict], None, None]:
106202
"""Stream content generation from the model
107203
108204
Args:
@@ -112,33 +208,37 @@ def generate_content_stream(
112208
use_sse: Whether to use Server-Sent Events format
113209
114210
Yields:
115-
Either GenerateResponse objects (for raw JSON) or dicts (for SSE)
211+
Either ResponseTokens objects (for raw JSON) or dicts (for SSE)
116212
"""
117213
payload = {
118214
"model": model,
119215
"messages": [msg.to_dict() for msg in messages],
120216
"stream": True,
121-
"ssevents": use_sse
217+
"ssevents": use_sse,
122218
}
123219

124-
with self.session.post(
125-
f"{self.base_url}/llms/{provider.value}/generate",
126-
json=payload,
127-
stream=True
128-
) as response:
129-
response.raise_for_status()
220+
response = self.session.post(
221+
f"{self.base_url}/llms/{provider.value}/generate", json=payload, stream=True
222+
)
223+
response.raise_for_status()
224+
225+
if use_sse:
226+
buffer = []
130227

131228
for line in response.iter_lines():
132-
if line:
133-
if use_sse and line.startswith(b'data: '):
134-
# Handle SSE format
135-
data = json.loads(line.decode(
136-
'utf-8').replace('data: ', ''))
137-
yield data
138-
else:
139-
# Handle raw JSON format
140-
data = json.loads(line)
141-
yield GenerateResponse.from_dict(data)
229+
if not line:
230+
if buffer:
231+
chunk = b"\n".join(buffer)
232+
yield self._parse_sse_chunk(chunk)
233+
buffer = []
234+
continue
235+
236+
buffer.append(line)
237+
else:
238+
for line in response.iter_lines():
239+
if not line:
240+
continue
241+
yield self._parse_json_line(line)
142242

143243
def health_check(self) -> bool:
144244
"""Check if the API is healthy"""

0 commit comments

Comments
 (0)