Skip to content

Commit 8301b5d

Browse files
authored
Merge pull request #51 from alphagov/ACW-88/excluding-aliases-with-zero-occurences
Acw 88/excluding aliases with zero occurences
2 parents 6e15ba7 + 09e411c commit 8301b5d

3 files changed

Lines changed: 60 additions & 4 deletions

File tree

src/models/graph_models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def model_post_init(self, __context: Any) -> None:
124124
"""
125125
import statistics
126126

127+
self.alias_imbalance = [stat for stat in self.alias_imbalance if stat.occurrence_count > 0]
128+
self.aliases = [alias for alias in self.aliases if alias.occurrence_count > 0]
129+
127130
counts = [stat.occurrence_count for stat in self.alias_imbalance]
128131
if len(counts) > 1:
129132
mean = statistics.mean(counts)
@@ -144,3 +147,29 @@ class GraphOutput(BaseModel):
144147
edges: List[Edge]
145148
relationships: List[Relationship] = Field(default_factory=list)
146149
outliers: List[EntityOutlier] = Field(default_factory=list)
150+
151+
def model_post_init(self, __context: Any) -> None:
152+
if not self.outliers:
153+
return
154+
155+
valid_entity_ids = set()
156+
valid_node_ids = set()
157+
158+
for outlier in self.outliers:
159+
if len(outlier.alias_imbalance) > 0:
160+
valid_entity_ids.add(outlier.entity_id)
161+
valid_node_ids.add(outlier.entity_id)
162+
for alias in outlier.aliases:
163+
valid_node_ids.add(alias.id)
164+
165+
self.outliers = [
166+
outlier for outlier in self.outliers if outlier.entity_id in valid_entity_ids
167+
]
168+
169+
self.nodes = [node for node in self.nodes if node.data.id in valid_node_ids]
170+
171+
self.edges = [
172+
edge
173+
for edge in self.edges
174+
if edge.data.source in valid_node_ids and edge.data.target in valid_node_ids
175+
]

tests/test_graph_loader.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,18 @@ def test_local_fixture_loads_independently_of_s3_path():
5555
def test_graph_output_relationships_contain_required_fields():
5656
data = load_json_file(FIXTURE_PATH)
5757
graph = GraphInput.model_validate(data)
58-
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))
5958

60-
output = build_node_structure(graph.entities, empty_results, graph.relationships)
59+
from src.models.graph_models import Occurrence
60+
61+
mock_occurrence = Occurrence(link="mock", context="mock")
62+
63+
class MockAliasDict(dict):
64+
def get(self, key, default=None):
65+
return [mock_occurrence]
66+
67+
mock_results: Dict[str, Any] = defaultdict(MockAliasDict)
68+
69+
output = build_node_structure(graph.entities, mock_results, graph.relationships)
6170
dumped = output.model_dump(exclude_none=True)
6271

6372
assert len(dumped["relationships"]) > 0
@@ -70,9 +79,18 @@ def test_graph_output_relationships_contain_required_fields():
7079
def test_relationship_edges_reference_existing_nodes():
7180
data = load_json_file(FIXTURE_PATH)
7281
graph = GraphInput.model_validate(data)
73-
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))
7482

75-
output = build_node_structure(graph.entities, empty_results, graph.relationships)
83+
from src.models.graph_models import Occurrence
84+
85+
mock_occurrence = Occurrence(link="mock", context="mock")
86+
87+
class MockAliasDict(dict):
88+
def get(self, key, default=None):
89+
return [mock_occurrence]
90+
91+
mock_results: Dict[str, Any] = defaultdict(MockAliasDict)
92+
93+
output = build_node_structure(graph.entities, mock_results, graph.relationships)
7694
dumped = output.model_dump(exclude_none=True)
7795

7896
node_ids = {n["data"]["id"] for n in dumped["nodes"]}

tests/test_graph_validation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def test_graph_output_validation():
4444
data = {
4545
"nodes": [{"data": {"id": "e1", "label": "Entity 1", "type": "entity"}}],
4646
"edges": [{"data": {"source": "e1", "target": "a1", "label": "Alias"}}],
47+
"outliers": [
48+
{
49+
"entity_id": "e1",
50+
"entity_label": "Entity 1",
51+
"alias_imbalance": [
52+
{"alias_id": "a1", "alias_label": "Alias", "occurrence_count": 1}
53+
],
54+
}
55+
],
4756
}
4857
validated = GraphOutput.model_validate(data)
4958
assert len(validated.nodes) == 1

0 commit comments

Comments
 (0)