Skip to content

Commit 9cabb45

Browse files
committed
fix(neuralnav): tighten DeploymentIntent types and skip extraction when overrides suffice
NeuralNav's LLM hallucinates invalid use_case values (e.g., "text_summarization" instead of "summarization_short"), causing 422 errors from the /api/v1/extract endpoint. Two fixes: - DeploymentIntent now uses Literal types for use_case (9 values), experience_class (5 values), and priority fields (low/medium/high), matching NeuralNav's schema for defense-in-depth validation. - recommend() skips the extract_intent() call when use_case_override, user_count_override, and gpu_types_override are all provided, since those are the only fields consumed from the extracted intent. This avoids the 422 entirely when the caller specifies the deployment parameters directly. Signed-off-by: Amit Oren <amoren@redhat.com>
1 parent e113f15 commit 9cabb45

4 files changed

Lines changed: 241 additions & 24 deletions

File tree

src/rhoai_mcp/composites/neuralnav/client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,22 +190,31 @@ def recommend(
190190
) -> RecommendationResult:
191191
"""Run the full recommendation flow.
192192
193-
1. Extract intent from text
193+
1. Extract intent from text (skipped when overrides cover all needed fields)
194194
2. Apply overrides
195195
3. Fetch SLO defaults + workload profile + expected RPS
196196
4. Apply SLO overrides on top of fetched defaults
197197
5. Get ranked recommendations with all constraints
198198
6. Extract top recommendation from each ranking list
199199
"""
200-
# Step 1: Extract intent
201-
intent = self.extract_intent(text)
202-
203-
# Step 2: Apply overrides
204-
use_case = use_case_override if use_case_override is not None else intent.use_case
205-
user_count = user_count_override if user_count_override is not None else intent.user_count
206-
gpu_types = (
207-
gpu_types_override if gpu_types_override is not None else intent.preferred_gpu_types
208-
)
200+
# Step 1: Extract intent (skip when all overrides are provided)
201+
if (
202+
use_case_override is not None
203+
and user_count_override is not None
204+
and gpu_types_override is not None
205+
):
206+
use_case = use_case_override
207+
user_count = user_count_override
208+
gpu_types = gpu_types_override
209+
else:
210+
intent = self.extract_intent(text)
211+
use_case = use_case_override if use_case_override is not None else intent.use_case
212+
user_count = (
213+
user_count_override if user_count_override is not None else intent.user_count
214+
)
215+
gpu_types = (
216+
gpu_types_override if gpu_types_override is not None else intent.preferred_gpu_types
217+
)
209218

210219
# Step 3: Fetch defaults
211220
slo_data = self.get_slo_defaults(use_case)

src/rhoai_mcp/composites/neuralnav/models.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,48 @@
22

33
from __future__ import annotations
44

5-
from typing import Any
5+
from typing import Any, Literal
66

77
from pydantic import BaseModel, Field
88

9+
UseCaseType = Literal[
10+
"chatbot_conversational",
11+
"code_completion",
12+
"code_generation_detailed",
13+
"translation",
14+
"content_generation",
15+
"summarization_short",
16+
"document_analysis_rag",
17+
"long_document_summarization",
18+
"research_legal_analysis",
19+
]
20+
21+
ExperienceClassType = Literal[
22+
"instant",
23+
"conversational",
24+
"interactive",
25+
"deferred",
26+
"batch",
27+
]
28+
29+
PriorityType = Literal["low", "medium", "high"]
30+
931

1032
class DeploymentIntent(BaseModel):
1133
"""Extracted deployment intent from natural language."""
1234

13-
use_case: str = Field(..., description="Primary use case type")
35+
use_case: UseCaseType = Field(..., description="Primary use case type")
1436
user_count: int = Field(..., description="Number of users or scale")
15-
experience_class: str = Field(default="conversational", description="User experience class")
37+
experience_class: ExperienceClassType = Field(
38+
default="conversational", description="User experience class"
39+
)
1640
preferred_gpu_types: list[str] = Field(
1741
default_factory=list, description="Preferred GPU types (empty = any)"
1842
)
19-
accuracy_priority: str = Field(default="medium", description="Accuracy importance")
20-
cost_priority: str = Field(default="medium", description="Cost sensitivity")
21-
latency_priority: str = Field(default="medium", description="Latency importance")
22-
complexity_priority: str = Field(default="medium", description="Simplicity preference")
43+
accuracy_priority: PriorityType = Field(default="medium", description="Accuracy importance")
44+
cost_priority: PriorityType = Field(default="medium", description="Cost sensitivity")
45+
latency_priority: PriorityType = Field(default="medium", description="Latency importance")
46+
complexity_priority: PriorityType = Field(default="medium", description="Simplicity preference")
2347
domain_specialization: list[str] = Field(
2448
default_factory=lambda: ["general"], description="Domain requirements"
2549
)

tests/composites/neuralnav/test_client.py

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,9 @@ def test_recommend_full_flow(self, mock_httpx: MagicMock) -> None:
327327

328328
@patch("rhoai_mcp.composites.neuralnav.client.httpx")
329329
def test_recommend_with_overrides(self, mock_httpx: MagicMock) -> None:
330-
"""Overrides replace extracted intent values."""
330+
"""When use_case, user_count, and gpu_types overrides are all provided, extraction is skipped."""
331331
mock_client = MagicMock()
332332

333-
extract_resp = MagicMock()
334-
extract_resp.status_code = 200
335-
extract_resp.json.return_value = SAMPLE_INTENT
336-
extract_resp.raise_for_status = MagicMock()
337-
338333
slo_resp = MagicMock()
339334
slo_resp.status_code = 200
340335
slo_resp.json.return_value = SAMPLE_SLO_DEFAULTS
@@ -355,7 +350,8 @@ def test_recommend_with_overrides(self, mock_httpx: MagicMock) -> None:
355350
ranked_resp.json.return_value = SAMPLE_RANKED_RESPONSE
356351
ranked_resp.raise_for_status = MagicMock()
357352

358-
mock_client.post.side_effect = [extract_resp, ranked_resp]
353+
# Only one POST (ranked-recommend), extraction is skipped
354+
mock_client.post.side_effect = [ranked_resp]
359355
mock_client.get.side_effect = [slo_resp, workload_resp, rps_resp]
360356

361357
mock_httpx.Client.return_value.__enter__ = MagicMock(return_value=mock_client)
@@ -372,6 +368,8 @@ def test_recommend_with_overrides(self, mock_httpx: MagicMock) -> None:
372368
# Verify the overridden use_case was used for SLO defaults fetch
373369
get_calls = mock_client.get.call_args_list
374370
assert "code_completion" in get_calls[0].args[0]
371+
# Extraction was skipped — only one POST call (ranked-recommend)
372+
assert mock_client.post.call_count == 1
375373

376374
@patch("rhoai_mcp.composites.neuralnav.client.httpx")
377375
def test_recommend_api_error(self, mock_httpx: MagicMock) -> None:
@@ -625,6 +623,150 @@ def test_recommend_forwards_constraints(self, mock_httpx: MagicMock) -> None:
625623
assert payload["percentile"] == "p99"
626624

627625

626+
class TestNeuralNavClientRecommendExtractionBypass:
627+
"""Tests for skipping extraction when overrides are sufficient."""
628+
629+
@patch("rhoai_mcp.composites.neuralnav.client.httpx")
630+
def test_recommend_skips_extraction_when_all_overrides_provided(
631+
self, mock_httpx: MagicMock
632+
) -> None:
633+
"""When all overrides are provided, extraction is skipped."""
634+
mock_client = MagicMock()
635+
636+
slo_resp = MagicMock()
637+
slo_resp.status_code = 200
638+
slo_resp.json.return_value = SAMPLE_SLO_DEFAULTS
639+
slo_resp.raise_for_status = MagicMock()
640+
641+
workload_resp = MagicMock()
642+
workload_resp.status_code = 200
643+
workload_resp.json.return_value = SAMPLE_WORKLOAD_PROFILE
644+
workload_resp.raise_for_status = MagicMock()
645+
646+
rps_resp = MagicMock()
647+
rps_resp.status_code = 200
648+
rps_resp.json.return_value = SAMPLE_EXPECTED_RPS
649+
rps_resp.raise_for_status = MagicMock()
650+
651+
ranked_resp = MagicMock()
652+
ranked_resp.status_code = 200
653+
ranked_resp.json.return_value = SAMPLE_RANKED_RESPONSE
654+
ranked_resp.raise_for_status = MagicMock()
655+
656+
# Only one POST call: ranked-recommend (no extract call)
657+
mock_client.post.side_effect = [ranked_resp]
658+
mock_client.get.side_effect = [slo_resp, workload_resp, rps_resp]
659+
660+
mock_httpx.Client.return_value.__enter__ = MagicMock(return_value=mock_client)
661+
mock_httpx.Client.return_value.__exit__ = MagicMock(return_value=False)
662+
663+
client = NeuralNavClient("http://localhost:8000")
664+
result = client.recommend(
665+
"I need a chatbot for 1000 users",
666+
use_case_override="chatbot_conversational",
667+
user_count_override=1000,
668+
gpu_types_override=["A100"],
669+
)
670+
671+
# Only one POST call was made (ranked-recommend, not extract)
672+
assert mock_client.post.call_count == 1
673+
assert result.specification["use_case"] == "chatbot_conversational"
674+
assert result.specification["user_count"] == 1000
675+
676+
@patch("rhoai_mcp.composites.neuralnav.client.httpx")
677+
def test_recommend_still_extracts_when_only_use_case_override(
678+
self, mock_httpx: MagicMock
679+
) -> None:
680+
"""When only use_case override is provided, extraction still runs for user_count."""
681+
mock_client = MagicMock()
682+
683+
extract_resp = MagicMock()
684+
extract_resp.status_code = 200
685+
extract_resp.json.return_value = SAMPLE_INTENT
686+
extract_resp.raise_for_status = MagicMock()
687+
688+
slo_resp = MagicMock()
689+
slo_resp.status_code = 200
690+
slo_resp.json.return_value = SAMPLE_SLO_DEFAULTS
691+
slo_resp.raise_for_status = MagicMock()
692+
693+
workload_resp = MagicMock()
694+
workload_resp.status_code = 200
695+
workload_resp.json.return_value = SAMPLE_WORKLOAD_PROFILE
696+
workload_resp.raise_for_status = MagicMock()
697+
698+
rps_resp = MagicMock()
699+
rps_resp.status_code = 200
700+
rps_resp.json.return_value = SAMPLE_EXPECTED_RPS
701+
rps_resp.raise_for_status = MagicMock()
702+
703+
ranked_resp = MagicMock()
704+
ranked_resp.status_code = 200
705+
ranked_resp.json.return_value = SAMPLE_RANKED_RESPONSE
706+
ranked_resp.raise_for_status = MagicMock()
707+
708+
mock_client.post.side_effect = [extract_resp, ranked_resp]
709+
mock_client.get.side_effect = [slo_resp, workload_resp, rps_resp]
710+
711+
mock_httpx.Client.return_value.__enter__ = MagicMock(return_value=mock_client)
712+
mock_httpx.Client.return_value.__exit__ = MagicMock(return_value=False)
713+
714+
client = NeuralNavClient("http://localhost:8000")
715+
result = client.recommend(
716+
"I need a chatbot",
717+
use_case_override="code_completion",
718+
)
719+
720+
# Two POST calls: extract + ranked-recommend
721+
assert mock_client.post.call_count == 2
722+
# Use case override is applied
723+
assert result.specification["use_case"] == "code_completion"
724+
725+
@patch("rhoai_mcp.composites.neuralnav.client.httpx")
726+
def test_recommend_skips_extraction_uses_gpu_override(self, mock_httpx: MagicMock) -> None:
727+
"""When extraction is skipped, gpu_types_override is used."""
728+
mock_client = MagicMock()
729+
730+
slo_resp = MagicMock()
731+
slo_resp.status_code = 200
732+
slo_resp.json.return_value = SAMPLE_SLO_DEFAULTS
733+
slo_resp.raise_for_status = MagicMock()
734+
735+
workload_resp = MagicMock()
736+
workload_resp.status_code = 200
737+
workload_resp.json.return_value = SAMPLE_WORKLOAD_PROFILE
738+
workload_resp.raise_for_status = MagicMock()
739+
740+
rps_resp = MagicMock()
741+
rps_resp.status_code = 200
742+
rps_resp.json.return_value = SAMPLE_EXPECTED_RPS
743+
rps_resp.raise_for_status = MagicMock()
744+
745+
ranked_resp = MagicMock()
746+
ranked_resp.status_code = 200
747+
ranked_resp.json.return_value = SAMPLE_RANKED_RESPONSE
748+
ranked_resp.raise_for_status = MagicMock()
749+
750+
mock_client.post.side_effect = [ranked_resp]
751+
mock_client.get.side_effect = [slo_resp, workload_resp, rps_resp]
752+
753+
mock_httpx.Client.return_value.__enter__ = MagicMock(return_value=mock_client)
754+
mock_httpx.Client.return_value.__exit__ = MagicMock(return_value=False)
755+
756+
client = NeuralNavClient("http://localhost:8000")
757+
client.recommend(
758+
"I need a chatbot",
759+
use_case_override="chatbot_conversational",
760+
user_count_override=1000,
761+
gpu_types_override=["H100"],
762+
)
763+
764+
# Verify the GPU override was forwarded
765+
ranked_call = mock_client.post.call_args
766+
payload = ranked_call.kwargs.get("json") or ranked_call[1].get("json")
767+
assert payload["preferred_gpu_types"] == ["H100"]
768+
769+
628770
class TestNeuralNavClientRequestErrors:
629771
"""Tests for _request error handling edge cases."""
630772

tests/composites/neuralnav/test_models.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
"""Tests for NeuralNav composite models."""
22

3+
from typing import get_args
4+
5+
import pytest
6+
from pydantic import ValidationError
7+
38
from rhoai_mcp.composites.neuralnav.models import (
49
DeploymentConfigResult,
510
DeploymentIntent,
@@ -9,6 +14,7 @@
914
RecommendationScores,
1015
SLOTargets,
1116
TrafficProfile,
17+
UseCaseType,
1218
)
1319

1420

@@ -39,6 +45,42 @@ def test_full_intent(self) -> None:
3945
assert intent.preferred_gpu_types == ["H100", "A100-80"]
4046
assert intent.accuracy_priority == "high"
4147

48+
def test_invalid_use_case_rejected(self) -> None:
49+
"""Invalid use_case values are rejected by Pydantic validation."""
50+
with pytest.raises(ValidationError, match="use_case"):
51+
DeploymentIntent(use_case="summarization", user_count=1000)
52+
53+
def test_invalid_use_case_text_summarization_rejected(self) -> None:
54+
"""LLM-hallucinated 'text_summarization' is rejected."""
55+
with pytest.raises(ValidationError, match="use_case"):
56+
DeploymentIntent(use_case="text_summarization", user_count=1000)
57+
58+
def test_invalid_experience_class_rejected(self) -> None:
59+
"""Invalid experience_class values are rejected."""
60+
with pytest.raises(ValidationError, match="experience_class"):
61+
DeploymentIntent(
62+
use_case="chatbot_conversational",
63+
user_count=1000,
64+
experience_class="realtime",
65+
)
66+
67+
def test_invalid_priority_rejected(self) -> None:
68+
"""Invalid priority values are rejected."""
69+
with pytest.raises(ValidationError, match="accuracy_priority"):
70+
DeploymentIntent(
71+
use_case="chatbot_conversational",
72+
user_count=1000,
73+
accuracy_priority="critical",
74+
)
75+
76+
def test_all_valid_use_cases_accepted(self) -> None:
77+
"""All valid use_case values are accepted."""
78+
valid_use_cases = list(get_args(UseCaseType))
79+
assert len(valid_use_cases) > 0
80+
for uc in valid_use_cases:
81+
intent = DeploymentIntent(use_case=uc, user_count=100)
82+
assert intent.use_case == uc
83+
4284

4385
class TestModelRecommendation:
4486
"""Tests for ModelRecommendation model."""

0 commit comments

Comments
 (0)