Skip to content

Commit 09e411c

Browse files
Ademola AdefioyeAdemola Adefioye
authored andcommitted
Fix testing
Updating test modules
1 parent 641a382 commit 09e411c

3 files changed

Lines changed: 34 additions & 10 deletions

File tree

src/models/graph_models.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ class GraphOutput(BaseModel):
149149
outliers: List[EntityOutlier] = Field(default_factory=list)
150150

151151
def model_post_init(self, __context: Any) -> None:
152+
if not self.outliers:
153+
return
154+
152155
valid_entity_ids = set()
153156
valid_node_ids = set()
154157

@@ -170,9 +173,3 @@ def model_post_init(self, __context: Any) -> None:
170173
for edge in self.edges
171174
if edge.data.source in valid_node_ids and edge.data.target in valid_node_ids
172175
]
173-
174-
self.relationships = [
175-
rel
176-
for rel in self.relationships
177-
if rel.from_ in valid_entity_ids and rel.to in valid_entity_ids
178-
]

tests/test_graph_loader.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,18 @@ def test_local_fixture_loads_independently_of_s3_path():
5555
def test_graph_output_relationships_contain_required_fields():
5656
data = load_json_file(FIXTURE_PATH)
5757
graph = GraphInput.model_validate(data)
58-
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))
5958

60-
output = build_node_structure(graph.entities, empty_results, graph.relationships)
59+
from src.models.graph_models import Occurrence
60+
61+
mock_occurrence = Occurrence(link="mock", context="mock")
62+
63+
class MockAliasDict(dict):
64+
def get(self, key, default=None):
65+
return [mock_occurrence]
66+
67+
mock_results: Dict[str, Any] = defaultdict(MockAliasDict)
68+
69+
output = build_node_structure(graph.entities, mock_results, graph.relationships)
6170
dumped = output.model_dump(exclude_none=True)
6271

6372
assert len(dumped["relationships"]) > 0
@@ -70,9 +79,18 @@ def test_graph_output_relationships_contain_required_fields():
7079
def test_relationship_edges_reference_existing_nodes():
7180
data = load_json_file(FIXTURE_PATH)
7281
graph = GraphInput.model_validate(data)
73-
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))
7482

75-
output = build_node_structure(graph.entities, empty_results, graph.relationships)
83+
from src.models.graph_models import Occurrence
84+
85+
mock_occurrence = Occurrence(link="mock", context="mock")
86+
87+
class MockAliasDict(dict):
88+
def get(self, key, default=None):
89+
return [mock_occurrence]
90+
91+
mock_results: Dict[str, Any] = defaultdict(MockAliasDict)
92+
93+
output = build_node_structure(graph.entities, mock_results, graph.relationships)
7694
dumped = output.model_dump(exclude_none=True)
7795

7896
node_ids = {n["data"]["id"] for n in dumped["nodes"]}

tests/test_graph_validation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def test_graph_output_validation():
4444
data = {
4545
"nodes": [{"data": {"id": "e1", "label": "Entity 1", "type": "entity"}}],
4646
"edges": [{"data": {"source": "e1", "target": "a1", "label": "Alias"}}],
47+
"outliers": [
48+
{
49+
"entity_id": "e1",
50+
"entity_label": "Entity 1",
51+
"alias_imbalance": [
52+
{"alias_id": "a1", "alias_label": "Alias", "occurrence_count": 1}
53+
],
54+
}
55+
],
4756
}
4857
validated = GraphOutput.model_validate(data)
4958
assert len(validated.nodes) == 1

0 commit comments

Comments
 (0)