Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/models/graph_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -19,6 +20,34 @@ class Entity(BaseModel):

model_config = ConfigDict(extra="allow")

def model_post_init(self, __context: Any) -> None:
    """Ensure the entity has an alias derived from its label.

    The label is split in front of every uppercase letter, lower-cased and
    re-joined with single spaces (``"FooBar"`` -> ``"foo bar"``). When no
    alias with that name exists yet, a new ``Alias`` is appended whose
    ``source_files`` are taken from ``properties["sourceUrls"]``; that
    property is accepted either as a list of strings or as a single
    comma-separated string, and blank entries are dropped.

    Args:
        __context: Validation context passed to this hook; unused here.
    """
    if not self.label:
        return

    pieces = re.split(r"(?=[A-Z])", self.label)
    # The look-ahead split keeps whatever whitespace each piece already had,
    # so a label such as "Foo Bar" would otherwise become "foo  bar" and never
    # match an existing "foo bar" alias. Re-splitting on whitespace collapses
    # every run of blanks into a single separator.
    new_alias_name = " ".join(" ".join(pieces).lower().split())
    if not new_alias_name:
        # Whitespace-only label: there is nothing meaningful to alias.
        return

    if any(alias.name == new_alias_name for alias in self.aliases):
        return

    raw_urls = self.properties.get("sourceUrls", "")
    if isinstance(raw_urls, str):
        candidates = raw_urls.split(",")
    elif isinstance(raw_urls, list):
        candidates = raw_urls
    else:
        # Unsupported shapes (None, dict, numbers, ...) contribute no URLs.
        candidates = []

    new_urls = [
        url.strip()
        for url in candidates
        if isinstance(url, str) and url.strip()
    ]

    self.aliases.append(Alias(name=new_alias_name, source_files=new_urls))


class Relationship(BaseModel):
type: str
Expand Down Expand Up @@ -95,6 +124,9 @@ def model_post_init(self, __context: Any) -> None:
"""
import statistics

self.alias_imbalance = [stat for stat in self.alias_imbalance if stat.occurrence_count > 0]
self.aliases = [alias for alias in self.aliases if alias.occurrence_count > 0]

counts = [stat.occurrence_count for stat in self.alias_imbalance]
if len(counts) > 1:
mean = statistics.mean(counts)
Expand All @@ -115,3 +147,29 @@ class GraphOutput(BaseModel):
edges: List[Edge]
relationships: List[Relationship] = Field(default_factory=list)
outliers: List[EntityOutlier] = Field(default_factory=list)

def model_post_init(self, __context: Any) -> None:
    """Prune the graph down to entities that have a non-empty alias imbalance.

    When there are no outliers the graph is left untouched. Otherwise only
    outliers whose entity id belongs to an outlier with at least one
    ``alias_imbalance`` entry are kept, nodes are limited to those entities
    and their aliases, and edges are limited to those connecting two kept
    nodes.

    Args:
        __context: Validation context passed to this hook; unused here.
    """
    if self.outliers:
        flagged = [entry for entry in self.outliers if entry.alias_imbalance]

        # Entities worth keeping, plus every alias node hanging off them.
        kept_entities = {entry.entity_id for entry in flagged}
        kept_nodes = kept_entities | {
            alias.id for entry in flagged for alias in entry.aliases
        }

        # Filter by id (not by identity) so that every outlier sharing a kept
        # entity id survives.
        self.outliers = [
            entry for entry in self.outliers if entry.entity_id in kept_entities
        ]
        self.nodes = [item for item in self.nodes if item.data.id in kept_nodes]
        self.edges = [
            link
            for link in self.edges
            if not (
                link.data.source not in kept_nodes
                or link.data.target not in kept_nodes
            )
        ]
26 changes: 22 additions & 4 deletions tests/test_graph_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,18 @@ def test_local_fixture_loads_independently_of_s3_path():
def test_graph_output_relationships_contain_required_fields():
data = load_json_file(FIXTURE_PATH)
graph = GraphInput.model_validate(data)
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))

output = build_node_structure(graph.entities, empty_results, graph.relationships)
from src.models.graph_models import Occurrence

mock_occurrence = Occurrence(link="mock", context="mock")

class MockAliasDict(dict):
def get(self, key, default=None):
return [mock_occurrence]

mock_results: Dict[str, Any] = defaultdict(MockAliasDict)

output = build_node_structure(graph.entities, mock_results, graph.relationships)
dumped = output.model_dump(exclude_none=True)

assert len(dumped["relationships"]) > 0
Expand All @@ -70,9 +79,18 @@ def test_graph_output_relationships_contain_required_fields():
def test_relationship_edges_reference_existing_nodes():
data = load_json_file(FIXTURE_PATH)
graph = GraphInput.model_validate(data)
empty_results: Dict[str, Any] = defaultdict(lambda: defaultdict(list))

output = build_node_structure(graph.entities, empty_results, graph.relationships)
from src.models.graph_models import Occurrence

mock_occurrence = Occurrence(link="mock", context="mock")

class MockAliasDict(dict):
def get(self, key, default=None):
return [mock_occurrence]

mock_results: Dict[str, Any] = defaultdict(MockAliasDict)

output = build_node_structure(graph.entities, mock_results, graph.relationships)
dumped = output.model_dump(exclude_none=True)

node_ids = {n["data"]["id"] for n in dumped["nodes"]}
Expand Down
9 changes: 9 additions & 0 deletions tests/test_graph_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def test_graph_output_validation():
data = {
"nodes": [{"data": {"id": "e1", "label": "Entity 1", "type": "entity"}}],
"edges": [{"data": {"source": "e1", "target": "a1", "label": "Alias"}}],
"outliers": [
{
"entity_id": "e1",
"entity_label": "Entity 1",
"alias_imbalance": [
{"alias_id": "a1", "alias_label": "Alias", "occurrence_count": 1}
],
}
],
}
validated = GraphOutput.model_validate(data)
assert len(validated.nodes) == 1
Expand Down
Loading