Skip to content

Commit 95d4753

Browse files
feat: add normalization contract and dispatcher gating (#1047)
* feat: add normalization result contract and dispatcher gate Introduce a structured NormalizationResult model and enforce dispatcher gating so non-ok normalization status returns no assessment before backend queries. Attach normalization payload to QueryInput for backend consumption and add model/dispatcher unit tests for OK and conflict paths. * refactor: simplify normalization payload contract Reduce NormalizationResult to a minimal backend-facing structure (original text, name, acronym, ISSN/eISSN, aliases, input identifiers) and keep normalization assessment/gating internal to dispatcher. This removes confidence-like normalization evidence from the shared payload while preserving strict no-assessment behavior on normalization failures. --------- Co-authored-by: florath-ai-assistant[bot] <Andreas.Florath@telekom.de>
1 parent 0e50ea6 commit 95d4753

File tree

4 files changed

+252
-3
lines changed

4 files changed

+252
-3
lines changed

src/aletheia_probe/dispatcher.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919
from .enums import AssessmentType, EvidenceType
2020
from .fallback_chain import QueryFallbackChain
2121
from .logging_config import get_detail_logger, get_status_logger
22-
from .models import AssessmentResult, BackendResult, BackendStatus, QueryInput
22+
from .lookup import VenueLookupService
23+
from .models import (
24+
AssessmentResult,
25+
BackendResult,
26+
BackendStatus,
27+
NormalizationResult,
28+
QueryInput,
29+
VenueType,
30+
)
2331
from .normalizer import InputNormalizer, input_normalizer
2432
from .openalex import OpenAlexClient
2533
from .quality_assessment import QualityAssessmentProcessor
@@ -84,6 +92,7 @@ def __init__(self) -> None:
8492
self.cross_validation_registry = get_cross_validation_registry()
8593
self.quality_processor = QualityAssessmentProcessor()
8694
self.journal_cache = JournalCache()
95+
self.lookup_service = VenueLookupService(journal_cache=self.journal_cache)
8796

8897
async def assess_journal(self, query_input: QueryInput) -> AssessmentResult:
8998
"""Assess a journal using all enabled backends.
@@ -95,6 +104,18 @@ async def assess_journal(self, query_input: QueryInput) -> AssessmentResult:
95104
AssessmentResult with aggregated assessment from all backends
96105
"""
97106
start_time = time.time()
107+
(
108+
normalization_result,
109+
normalization_failure,
110+
) = await self._normalize_for_dispatch(query_input)
111+
query_input = self._attach_normalization_to_query(
112+
query_input, normalization_result
113+
)
114+
if normalization_failure:
115+
return self._build_normalization_blocked_result(
116+
query_input, normalization_failure, start_time
117+
)
118+
98119
query_input = await self._enrich_query_identifiers(query_input)
99120

100121
# Get enabled backends from registry
@@ -138,6 +159,110 @@ async def assess_journal(self, query_input: QueryInput) -> AssessmentResult:
138159
assessment_result, query_input, enabled_backends, start_time
139160
)
140161

162+
async def _normalize_for_dispatch(
163+
self, query_input: QueryInput
164+
) -> tuple[NormalizationResult, str | None]:
165+
"""Build minimal normalization payload and evaluate gating failures."""
166+
requested_venue_type = (
167+
query_input.venue_type
168+
if query_input.venue_type != VenueType.UNKNOWN
169+
else VenueType.JOURNAL
170+
)
171+
172+
lookup_result = self.lookup_service.lookup(
173+
query_input.raw_input,
174+
venue_type=requested_venue_type,
175+
confidence_min=DEFAULT_ACRONYM_CONFIDENCE_MIN,
176+
)
177+
primary_name = (
178+
lookup_result.normalized_name or query_input.normalized_name or ""
179+
).strip() or None
180+
selected_issn = query_input.identifiers.get("issn") or (
181+
lookup_result.issns[0] if lookup_result.issns else None
182+
)
183+
selected_eissn = query_input.identifiers.get("eissn") or (
184+
lookup_result.eissns[0] if lookup_result.eissns else None
185+
)
186+
187+
consistency_errors = list(lookup_result.consistency_errors)
188+
failure_reason: str | None = None
189+
if not primary_name and not (selected_issn or selected_eissn):
190+
failure_reason = "Normalization did not resolve a name or identifier"
191+
192+
input_ids = {value for value in query_input.identifiers.values() if value}
193+
if primary_name and input_ids:
194+
resolved_ids: set[str] = set()
195+
for candidate in lookup_result.candidates:
196+
if candidate.normalized_name != primary_name:
197+
continue
198+
if candidate.issn:
199+
resolved_ids.add(candidate.issn)
200+
if candidate.eissn:
201+
resolved_ids.add(candidate.eissn)
202+
203+
if resolved_ids and input_ids.isdisjoint(resolved_ids):
204+
consistency_errors.append(
205+
"Input mismatch: provided identifier(s) "
206+
f"{sorted(input_ids)} do not match '{primary_name}' "
207+
f"(resolved identifiers: {sorted(resolved_ids)})"
208+
)
209+
210+
if consistency_errors:
211+
failure_reason = "; ".join(sorted(set(consistency_errors)))
212+
213+
normalization_result = NormalizationResult(
214+
original_text=lookup_result.raw_input,
215+
venue_type=requested_venue_type,
216+
name=primary_name,
217+
acronym=query_input.acronym_expanded_from,
218+
issn=selected_issn,
219+
eissn=selected_eissn,
220+
aliases=lookup_result.aliases,
221+
input_identifiers=dict(query_input.identifiers),
222+
)
223+
return normalization_result, failure_reason
224+
225+
def _attach_normalization_to_query(
226+
self, query_input: QueryInput, normalization_result: NormalizationResult
227+
) -> QueryInput:
228+
"""Attach normalization payload and selected fields to query input."""
229+
normalized_name = normalization_result.name or query_input.normalized_name
230+
merged_identifiers = dict(query_input.identifiers)
231+
if normalization_result.issn:
232+
merged_identifiers.setdefault("issn", normalization_result.issn)
233+
if normalization_result.eissn:
234+
merged_identifiers.setdefault("eissn", normalization_result.eissn)
235+
return query_input.model_copy(
236+
update={
237+
"normalized_name": normalized_name,
238+
"identifiers": merged_identifiers,
239+
"normalization_result": normalization_result,
240+
}
241+
)
242+
243+
def _build_normalization_blocked_result(
244+
self,
245+
query_input: QueryInput,
246+
failure_reason: str,
247+
start_time: float,
248+
) -> AssessmentResult:
249+
"""Build a no-assessment result when normalization gate is not OK."""
250+
reason = failure_reason or "Normalization failed; no assessment possible"
251+
self.status_logger.warning(f"Normalization blocked assessment: {reason}")
252+
return AssessmentResult(
253+
input_query=query_input.raw_input,
254+
assessment=AssessmentType.INSUFFICIENT_DATA,
255+
confidence=0.0,
256+
overall_score=0.0,
257+
backend_results=[],
258+
metadata=None,
259+
reasoning=[reason],
260+
processing_time=time.time() - start_time,
261+
acronym_expanded_from=query_input.acronym_expanded_from,
262+
acronym_expansion_used=bool(query_input.acronym_expanded_from),
263+
venue_type=query_input.venue_type,
264+
)
265+
141266
async def _enrich_query_identifiers(self, query_input: QueryInput) -> QueryInput:
142267
"""Enrich query identifiers with reliable ISSN/eISSN from cache/API."""
143268
if query_input.identifiers.get("issn") or query_input.identifiers.get("eissn"):

src/aletheia_probe/models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,23 @@ class BackendStatus(str, Enum):
4949
TIMEOUT = "timeout"
5050

5151

52+
class NormalizationResult(BaseModel):
53+
"""Minimal normalization payload passed to backends."""
54+
55+
original_text: str = Field(..., description="Original query input string")
56+
name: str | None = Field(None, description="Normalized venue name")
57+
acronym: str | None = Field(
58+
None, description="Detected or expanded acronym, if available"
59+
)
60+
issn: str | None = Field(None, description="Resolved print ISSN")
61+
eissn: str | None = Field(None, description="Resolved electronic ISSN")
62+
venue_type: VenueType = Field(..., description="Requested/detected venue type")
63+
aliases: list[str] = Field(default_factory=list, description="Known aliases")
64+
input_identifiers: dict[str, str] = Field(
65+
default_factory=dict, description="Identifiers extracted directly from input"
66+
)
67+
68+
5269
class QueryInput(BaseModel):
5370
"""Input query data for journal assessment."""
5471

@@ -68,6 +85,10 @@ class QueryInput(BaseModel):
6885
default_factory=dict,
6986
description="Acronym to full name mappings extracted during normalization",
7087
)
88+
normalization_result: NormalizationResult | None = Field(
89+
None,
90+
description="Structured normalization payload passed to backends",
91+
)
7192

7293

7394
class BackendResult(BaseModel):

tests/unit/test_dispatcher.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
AssessmentResult,
1515
BackendResult,
1616
BackendStatus,
17+
NormalizationResult,
1718
QueryInput,
19+
VenueType,
1820
)
1921

2022

@@ -96,8 +98,26 @@ async def test_assess_journal_basic_flow(
9698
self, dispatcher, sample_query_input, mock_backend
9799
):
98100
"""Test basic journal assessment flow."""
99-
with patch.object(
100-
dispatcher, "_get_enabled_backends", return_value=[mock_backend]
101+
with (
102+
patch.object(
103+
dispatcher, "_get_enabled_backends", return_value=[mock_backend]
104+
),
105+
patch.object(
106+
dispatcher,
107+
"_normalize_for_dispatch",
108+
AsyncMock(
109+
return_value=(
110+
NormalizationResult(
111+
original_text=sample_query_input.raw_input,
112+
venue_type=VenueType.JOURNAL,
113+
name="journal of advanced computer science",
114+
issn="1234-5679",
115+
input_identifiers={"issn": "1234-5679"},
116+
),
117+
None,
118+
)
119+
),
120+
),
101121
):
102122
result = await dispatcher.assess_journal(sample_query_input)
103123

@@ -108,6 +128,36 @@ async def test_assess_journal_basic_flow(
108128
assert result.processing_time > 0
109129
assert len(result.backend_results) == 1
110130

131+
@pytest.mark.asyncio
132+
async def test_assess_journal_blocks_on_normalization_conflict(
133+
self, dispatcher, sample_query_input, mock_backend
134+
):
135+
"""Do not query backends when normalization status is conflict."""
136+
conflict_result = NormalizationResult(
137+
original_text=sample_query_input.raw_input,
138+
venue_type=VenueType.JOURNAL,
139+
name="journal of advanced computer science",
140+
issn="1234-5679",
141+
input_identifiers={"issn": "1234-5679"},
142+
)
143+
144+
with (
145+
patch.object(
146+
dispatcher, "_get_enabled_backends", return_value=[mock_backend]
147+
),
148+
patch.object(
149+
dispatcher,
150+
"_normalize_for_dispatch",
151+
AsyncMock(return_value=(conflict_result, "identifier mismatch")),
152+
),
153+
):
154+
result = await dispatcher.assess_journal(sample_query_input)
155+
156+
assert result.assessment == AssessmentType.INSUFFICIENT_DATA
157+
assert result.backend_results == []
158+
assert any("identifier mismatch" in reason for reason in result.reasoning)
159+
mock_backend.query_with_timeout.assert_not_called()
160+
111161
@pytest.mark.asyncio
112162
async def test_assess_journal_no_backends(self, dispatcher, sample_query_input):
113163
"""Test assessment with no enabled backends."""

tests/unit/test_models.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
BibtexEntry,
1717
ConfigBackend,
1818
JournalMetadata,
19+
NormalizationResult,
1920
QueryInput,
2021
VenueType,
2122
)
@@ -45,6 +46,58 @@ def test_create_full_query_input(self):
4546
assert query.identifiers["issn"] == "1234-5679"
4647
assert "Test Science Journal" in query.aliases
4748

49+
def test_query_input_with_normalization_result(self):
50+
"""Test attaching normalization payload to QueryInput."""
51+
normalization_result = NormalizationResult(
52+
original_text="Nature 0028-0836",
53+
venue_type=VenueType.JOURNAL,
54+
name="nature",
55+
issn="0028-0836",
56+
input_identifiers={"issn": "0028-0836"},
57+
)
58+
query = QueryInput(
59+
raw_input="Nature 0028-0836",
60+
normalized_name="nature",
61+
identifiers={"issn": "0028-0836"},
62+
normalization_result=normalization_result,
63+
)
64+
assert query.normalization_result is not None
65+
assert query.normalization_result.name == "nature"
66+
67+
68+
class TestNormalizationResult:
69+
"""Tests for normalization contract models."""
70+
71+
def test_create_normalization_result(self):
72+
"""Create minimal normalization payload for backend consumption."""
73+
result = NormalizationResult(
74+
original_text="Nature",
75+
venue_type=VenueType.JOURNAL,
76+
name="nature",
77+
aliases=[],
78+
acronym=None,
79+
input_identifiers={},
80+
issn="0028-0836",
81+
eissn="1476-4687",
82+
)
83+
assert result.name == "nature"
84+
assert result.issn == "0028-0836"
85+
assert result.eissn == "1476-4687"
86+
87+
def test_create_partial_normalization_result(self):
88+
"""Missing fields should be represented as None in minimal payload."""
89+
result = NormalizationResult(
90+
original_text="Unknown Venue",
91+
venue_type=VenueType.JOURNAL,
92+
name=None,
93+
acronym=None,
94+
issn=None,
95+
eissn=None,
96+
)
97+
assert result.name is None
98+
assert result.issn is None
99+
assert result.eissn is None
100+
48101

49102
class TestBackendResult:
50103
"""Tests for BackendResult model."""

0 commit comments

Comments
 (0)