Skip to content

Commit 6883fe2

Browse files
committed
chore: typecheck (mypy) fixes
Signed-off-by: Amit Oren <amoren@redhat.com>
1 parent ea85415 commit 6883fe2

10 files changed

Lines changed: 71 additions & 43 deletions

File tree

src/neuralnav/api/routes/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,10 @@ async def list_all_deployments():
443443

444444
deployments = []
445445
for deployment_id in deployment_ids:
446-
status = manager.get_inferenceservice_status(deployment_id)
446+
svc_status = manager.get_inferenceservice_status(deployment_id)
447447
pods = manager.get_deployment_pods(deployment_id)
448448

449-
deployments.append({"deployment_id": deployment_id, "status": status, "pods": pods})
449+
deployments.append({"deployment_id": deployment_id, "status": svc_status, "pods": pods})
450450

451451
return {
452452
"success": True,

src/neuralnav/api/routes/recommendation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Recommendation endpoints."""
22

33
import logging
4+
from typing import Literal
45

56
from fastapi import APIRouter, HTTPException, status
67
from pydantic import BaseModel
@@ -49,7 +50,7 @@ class RankedRecommendationFromSpecRequest(BaseModel):
4950
ttft_target_ms: int
5051
itl_target_ms: int
5152
e2e_target_ms: int
52-
percentile: str = "p95" # "mean", "p90", "p95", "p99"
53+
percentile: Literal["mean", "p90", "p95", "p99"] = "p95"
5354

5455
# Ranking options
5556
min_accuracy: int | None = None
@@ -92,7 +93,7 @@ async def simple_recommend(request: SimpleRecommendationRequest):
9293
recommendation=recommendation, namespace="default"
9394
)
9495
deployment_id = yaml_result["deployment_id"]
95-
yaml_files = yaml_result["files"]
96+
yaml_files: dict = yaml_result["files"]
9697
logger.info(
9798
f"Auto-generated YAML files for {deployment_id}: {list(yaml_files.keys())}"
9899
)
@@ -272,7 +273,9 @@ async def test_endpoint(message: str = "I need a chatbot for 1000 users"):
272273
return {
273274
"success": True,
274275
"model": recommendation.model_name,
275-
"gpu_config": f"{recommendation.gpu_config.gpu_count}x {recommendation.gpu_config.gpu_type}",
276+
"gpu_config": f"{recommendation.gpu_config.gpu_count}x {recommendation.gpu_config.gpu_type}"
277+
if recommendation.gpu_config
278+
else "N/A",
276279
"cost_per_month": f"${recommendation.cost_per_month_usd:.2f}",
277280
"meets_slo": recommendation.meets_slo,
278281
"reasoning": recommendation.reasoning,

src/neuralnav/configuration/generator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def generate_deployment_id(self, recommendation: DeploymentRecommendation) -> st
7777
use_case = recommendation.intent.use_case.replace("_", "-")
7878

7979
# Clean model name: remove special chars, keep alphanumeric and hyphens
80-
model_name = recommendation.model_id.split("/")[-1].lower()
80+
model_name = (recommendation.model_id or "unknown").split("/")[-1].lower()
8181
model_name = re.sub(r"[^a-z0-9-]", "-", model_name)
8282
# Remove consecutive hyphens
8383
model_name = re.sub(r"-+", "-", model_name).strip("-")
@@ -120,6 +120,8 @@ def _prepare_template_context(
120120
traffic = recommendation.traffic_profile
121121
slo = recommendation.slo_targets
122122

123+
assert gpu_config is not None, "gpu_config is required for template context"
124+
123125
# Calculate GPU hourly rate from ModelCatalog
124126
gpu_info = self._catalog.get_gpu_type(gpu_config.gpu_type)
125127
gpu_hourly_rate = gpu_info.cost_per_hour_usd if gpu_info else 1.0
@@ -166,7 +168,8 @@ def _prepare_template_context(
166168
# Calculate max_num_seqs based on expected QPS and latency
167169
# Rule of thumb: concurrent requests = QPS × avg_latency_seconds
168170
avg_latency_sec = slo.e2e_p95_target_ms / 1000.0
169-
max_num_seqs = max(32, int(traffic.expected_qps * avg_latency_sec * 1.5))
171+
expected_qps = traffic.expected_qps or 0.0
172+
max_num_seqs = max(32, int(expected_qps * avg_latency_sec * 1.5))
170173

171174
# Max batched tokens (vLLM parameter)
172175
max_num_batched_tokens = max_num_seqs * (traffic.prompt_tokens + traffic.output_tokens)
@@ -228,7 +231,7 @@ def _prepare_template_context(
228231

229232
def generate_all(
230233
self, recommendation: DeploymentRecommendation, namespace: str = "default"
231-
) -> dict[str, str]:
234+
) -> dict[str, Any]:
232235
"""
233236
Generate all deployment YAML files.
234237
@@ -237,7 +240,7 @@ def generate_all(
237240
namespace: Kubernetes namespace
238241
239242
Returns:
240-
Dictionary mapping config type to file path
243+
Dictionary with deployment_id, namespace, files, and metadata
241244
"""
242245
deployment_id = self.generate_deployment_id(recommendation)
243246
context = self._prepare_template_context(recommendation, deployment_id, namespace)

src/neuralnav/knowledge_base/model_catalog.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ def __init__(self, data: dict):
5757
self.memory_gb = data["memory_gb"]
5858
self.compute_capability = data["compute_capability"]
5959
self.typical_use_cases = data["typical_use_cases"]
60-
self.cost_per_hour_usd = data["cost_per_hour_usd"] # Base/minimum price
60+
self.cost_per_hour_usd: float = data["cost_per_hour_usd"] # Base/minimum price
6161
# Cloud provider-specific pricing (optional)
62-
self.cost_per_hour_aws = data.get("cost_per_hour_aws")
63-
self.cost_per_hour_gcp = data.get("cost_per_hour_gcp")
64-
self.cost_per_hour_azure = data.get("cost_per_hour_azure")
62+
self.cost_per_hour_aws: float | None = data.get("cost_per_hour_aws")
63+
self.cost_per_hour_gcp: float | None = data.get("cost_per_hour_gcp")
64+
self.cost_per_hour_azure: float | None = data.get("cost_per_hour_azure")
6565
self.availability = data["availability"]
6666
self.notes = data.get("notes", "")
6767

@@ -75,11 +75,11 @@ def get_cost_for_provider(self, provider: str | None = None) -> float:
7575
Returns:
7676
Cost per hour in USD
7777
"""
78-
if provider == "aws" and self.cost_per_hour_aws:
78+
if provider == "aws" and self.cost_per_hour_aws is not None:
7979
return self.cost_per_hour_aws
80-
elif provider == "gcp" and self.cost_per_hour_gcp:
80+
elif provider == "gcp" and self.cost_per_hour_gcp is not None:
8181
return self.cost_per_hour_gcp
82-
elif provider == "azure" and self.cost_per_hour_azure:
82+
elif provider == "azure" and self.cost_per_hour_azure is not None:
8383
return self.cost_per_hour_azure
8484
return self.cost_per_hour_usd
8585

@@ -313,13 +313,25 @@ def get_cost_breakdown(
313313
"hourly_rate_azure": gpu.cost_per_hour_azure,
314314
"cost_per_hour_total": gpu.cost_per_hour_usd * total_gpus,
315315
"cost_per_month_base": gpu.cost_per_hour_usd * total_gpus * hours_per_month,
316-
"cost_per_month_aws": (gpu.cost_per_hour_aws or gpu.cost_per_hour_usd)
316+
"cost_per_month_aws": (
317+
gpu.cost_per_hour_aws
318+
if gpu.cost_per_hour_aws is not None
319+
else gpu.cost_per_hour_usd
320+
)
317321
* total_gpus
318322
* hours_per_month,
319-
"cost_per_month_gcp": (gpu.cost_per_hour_gcp or gpu.cost_per_hour_usd)
323+
"cost_per_month_gcp": (
324+
gpu.cost_per_hour_gcp
325+
if gpu.cost_per_hour_gcp is not None
326+
else gpu.cost_per_hour_usd
327+
)
320328
* total_gpus
321329
* hours_per_month,
322-
"cost_per_month_azure": (gpu.cost_per_hour_azure or gpu.cost_per_hour_usd)
330+
"cost_per_month_azure": (
331+
gpu.cost_per_hour_azure
332+
if gpu.cost_per_hour_azure is not None
333+
else gpu.cost_per_hour_usd
334+
)
323335
* total_gpus
324336
* hours_per_month,
325337
}

src/neuralnav/llm/ollama_client.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Ollama client wrapper for LLM interactions."""
22

3+
from __future__ import annotations
4+
35
import json
46
import logging
57
import os
6-
from typing import Any
8+
from typing import Any, Literal
79

810
try:
911
import ollama
@@ -30,14 +32,15 @@ def __init__(self, model: str | None = None, host: str | None = None):
3032
host: Optional Ollama host URL. Falls back to OLLAMA_HOST env var,
3133
then localhost:11434.
3234
"""
33-
self.model = model or os.getenv("OLLAMA_MODEL", "qwen2.5:7b")
35+
default_model = os.getenv("OLLAMA_MODEL", "qwen2.5:7b")
36+
self.model: str = model if model else default_model
3437
self.host = host or os.getenv("OLLAMA_HOST")
3538

39+
self._client: ollama.Client | None = None
3640
if OLLAMA_AVAILABLE:
3741
client_kwargs = {"host": self.host} if self.host else {}
3842
self._client = ollama.Client(**client_kwargs)
3943
else:
40-
self._client = None
4144
logger.error("Ollama library not installed. Install with: pip install ollama")
4245

4346
def chat(
@@ -57,7 +60,7 @@ def chat(
5760
Returns:
5861
Response dict with 'message' containing 'content'
5962
"""
60-
if not OLLAMA_AVAILABLE:
63+
if not OLLAMA_AVAILABLE or not self._client:
6164
raise RuntimeError("Ollama library not available")
6265

6366
try:
@@ -71,16 +74,13 @@ def chat(
7174
f"[LLM PROMPT] {last_msg.get('content', '')[:500]}..."
7275
) # Log first 500 chars at debug level
7376

74-
kwargs = {
75-
"model": self.model,
76-
"messages": messages,
77-
"options": {"temperature": temperature},
78-
}
79-
80-
if format_json:
81-
kwargs["format"] = "json"
82-
83-
response = self._client.chat(**kwargs)
77+
fmt: Literal["", "json"] = "json" if format_json else ""
78+
response = self._client.chat( # type: ignore[call-overload]
79+
model=self.model,
80+
messages=messages,
81+
format=fmt,
82+
options={"temperature": temperature},
83+
)
8484

8585
# Log the full response
8686
response_content = response.get("message", {}).get("content", "")
@@ -93,7 +93,7 @@ def chat(
9393
logger.info("[LLM RESPONSE CONTENT - END]")
9494
logger.info("=" * 80)
9595

96-
return response
96+
return dict(response)
9797

9898
except Exception as e:
9999
logger.error(f"Error calling Ollama: {e}")
@@ -122,7 +122,7 @@ def generate_completion(
122122

123123
messages = [{"role": "user", "content": prompt}]
124124
response = self.chat(messages, format_json=format_json, temperature=temperature)
125-
return response["message"]["content"]
125+
return str(response["message"]["content"])
126126

127127
def extract_structured_data(
128128
self,
@@ -152,7 +152,8 @@ def extract_structured_data(
152152
)
153153

154154
try:
155-
return json.loads(response_text)
155+
result: dict[str, Any] = json.loads(response_text)
156+
return result
156157
except json.JSONDecodeError as e:
157158
logger.error(f"Failed to parse JSON response: {response_text}")
158159
logger.error(f"JSON error: {e}")

src/neuralnav/llm/prompts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
"""
3636

3737

38-
def build_intent_extraction_prompt(user_message: str, conversation_history: list = None) -> str:
38+
def build_intent_extraction_prompt(
39+
user_message: str, conversation_history: list | None = None
40+
) -> str:
3941
"""
4042
Build prompt for extracting deployment intent from user conversation.
4143

src/neuralnav/llm/prompts_experimental.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
def build_conversational_prompt(
29-
user_message: str, current_understanding: dict, conversation_history: list = None
29+
user_message: str, current_understanding: dict, conversation_history: list | None = None
3030
) -> str:
3131
"""
3232
Build prompt for conversational AI responses.

src/neuralnav/orchestration/workflow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,13 @@ def generate_recommendation_from_specs(self, specifications: dict) -> Deployment
205205
all_configs.sort(key=lambda x: x.scores.balanced_score if x.scores else 0, reverse=True)
206206
best_recommendation = all_configs[0]
207207

208+
gpu_cfg = best_recommendation.gpu_config
208209
logger.info(
209210
f"Selected: {best_recommendation.model_name} on "
210-
f"{best_recommendation.gpu_config.gpu_count}x {best_recommendation.gpu_config.gpu_type} "
211+
f"{gpu_cfg.gpu_count}x {gpu_cfg.gpu_type} "
211212
f"(balanced score: {best_recommendation.scores.balanced_score if best_recommendation.scores else 0:.1f})"
213+
if gpu_cfg
214+
else f"Selected: {best_recommendation.model_name} (no GPU config)"
212215
)
213216

214217
# Add top 3 alternatives

src/neuralnav/recommendation/scorer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import re
1919
from pathlib import Path
20+
from typing import Literal
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -97,7 +98,7 @@ def _load_slo_ranges(self) -> dict:
9798
with open(config_path) as f:
9899
data = json.load(f)
99100
logger.debug(f"Loaded SLO ranges from {config_path}")
100-
return data.get("use_case_slo_workload", {})
101+
return dict(data.get("use_case_slo_workload", {}))
101102
except (FileNotFoundError, json.JSONDecodeError) as e:
102103
logger.warning(f"Could not load SLO ranges from {config_path}: {e}")
103104
return {}
@@ -244,9 +245,9 @@ def score_latency(
244245
target_ttft_ms: int,
245246
target_itl_ms: int,
246247
target_e2e_ms: int,
247-
use_case: str = None,
248+
use_case: str | None = None,
248249
near_miss_tolerance: float = 0.0,
249-
) -> tuple[int, str]:
250+
) -> tuple[int, Literal["compliant", "near_miss", "exceeds"]]:
250251
"""
251252
Score latency using CAPPED RANGE SCORING.
252253
@@ -294,6 +295,7 @@ def score_latency(
294295
worst_ratio = max(ratios)
295296

296297
# Determine SLO status using the tolerance passed from config_finder
298+
slo_status: Literal["compliant", "near_miss", "exceeds"]
297299
if worst_ratio <= 1.0:
298300
slo_status = "compliant"
299301
elif worst_ratio <= (1.0 + near_miss_tolerance):

src/neuralnav/shared/schemas/specification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Specification-related schemas for traffic profiles and SLO targets."""
22

3+
from typing import Literal
4+
35
from pydantic import BaseModel, Field
46

57
from .intent import DeploymentIntent
@@ -19,7 +21,7 @@ class SLOTargets(BaseModel):
1921
ttft_p95_target_ms: int = Field(..., description="Time to First Token target (ms)")
2022
itl_p95_target_ms: int = Field(..., description="Inter-Token Latency target (ms/token)")
2123
e2e_p95_target_ms: int = Field(..., description="End-to-end latency target (ms)")
22-
percentile: str = Field(
24+
percentile: Literal["mean", "p90", "p95", "p99"] = Field(
2325
default="p95", description="Percentile for SLO comparison (mean, p90, p95, p99)"
2426
)
2527

0 commit comments

Comments
 (0)