Skip to content

Commit 8d6bfc9

Browse files
authored
Merge pull request #2 from alphagov/feature/dockerize-content-extractor
ACW-19 Content extractor
2 parents 23e0fd6 + ecf17c3 commit 8d6bfc9

7 files changed

Lines changed: 230 additions & 64 deletions

File tree

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ RUN uv sync --no-dev --no-install-project
1717

1818
COPY src/ ./src/
1919
COPY app.py ./app.py
20+
# COPY graph.json ./
2021

2122

2223

src/content_extractor/s3_sequential.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
3-
from typing import List
3+
import json
4+
from typing import List, Dict, Any
45
from collections import defaultdict
56
from .base import BaseQuoteExtractor, Finding, FinalQuoteExtraction, BaseExtractorConfig
67
from src.url.generator import generate_url_fragement, s3_to_govuk_url
@@ -11,6 +12,47 @@
1112
class S3QuoteExtractor(BaseQuoteExtractor):
1213
"""Processes documents sequentially by fetching from S3 and chunking."""
1314

15+
def __init__(self, config: BaseExtractorConfig):
16+
super().__init__(config)
17+
self.url_map: Dict[str, str] = {}
18+
19+
def _fetch_url_map(self, s3_uris: List[str]):
20+
"""
21+
Attempts to fetch sources.json files from the directories of the input files.
22+
Deduplicates potential sources.json locations and merges their mappings.
23+
"""
24+
if not s3_uris:
25+
return
26+
27+
sources_locations = set()
28+
for uri in s3_uris:
29+
if uri in self.url_map:
30+
continue
31+
32+
if "/input/" in uri:
33+
sources_uri = uri.split("/input/")[0] + "/input/sources.json"
34+
else:
35+
sources_uri = "/".join(uri.split("/")[:-1]) + "/sources.json"
36+
sources_locations.add(sources_uri)
37+
38+
for sources_uri in sources_locations:
39+
logger.info(f"Attempting to fetch sources map from {sources_uri}...")
40+
content = self.fetch_s3_content(sources_uri)
41+
if content:
42+
try:
43+
new_map = json.loads(content)
44+
self.url_map.update(new_map)
45+
logger.info(f"Successfully loaded {len(new_map)} mappings from {sources_uri}.")
46+
except Exception as e:
47+
logger.error(f"Failed to parse {sources_uri}: {e}")
48+
else:
49+
logger.warning(f"No sources.json found at {sources_uri}.")
50+
51+
if self.url_map:
52+
logger.info(f"Total URL mappings loaded: {len(self.url_map)}")
53+
else:
54+
logger.warning("No URL mappings loaded. Falling back to derived URLs.")
55+
1456
async def process_document(self, s3_uri: str, keywords: List[str], results_list: list):
1557
"""Processes a single document for a specific set of keywords."""
1658
content = self.fetch_s3_content(s3_uri)
@@ -20,6 +62,8 @@ async def process_document(self, s3_uri: str, keywords: List[str], results_list:
2062
if len(chunks) > 1:
2163
logger.info(f" Split {s3_uri} into {len(chunks)} chunks.")
2264

65+
base_govuk_url = s3_to_govuk_url(s3_uri, self.url_map)
66+
2367
for i, chunk in enumerate(chunks, 1):
2468
prompt = (
2569
f"Keywords: {', '.join(keywords)}\n\n"
@@ -32,24 +76,26 @@ async def process_document(self, s3_uri: str, keywords: List[str], results_list:
3276
"content": q.content,
3377
"keyword_matched": q.keyword_matched,
3478
"source": s3_uri,
35-
"link": generate_url_fragement(s3_to_govuk_url(s3_uri), q.content)
79+
"link": generate_url_fragement(base_govuk_url, q.content)
3680
})
3781
except Exception as e:
3882
logger.error(f" Error in {s3_uri} chunk {i}: {e}")
3983

40-
async def run_mapping(self, doc_to_keywords: dict):
84+
async def run_mapping(self, doc_to_keywords: Dict[str, List[str]]):
4185
"""Processes documents based on a mapping of {s3_uri: [keywords]}."""
4286
raw_findings = []
4387

88+
self._fetch_url_map(list(doc_to_keywords.keys()))
89+
4490
tasks = [
4591
self.process_document(uri, keywords, raw_findings)
4692
for uri, keywords in doc_to_keywords.items()
4793
]
4894
await asyncio.gather(*tasks)
4995
return raw_findings
5096

51-
async def run(self, output_file: str = "outputs/extracted_quotes.json"):
52-
"""Main entry point to run extraction and save results."""
97+
async def run(self):
98+
"""Main entry point to run extraction"""
5399
doc_to_keywords = {uri: self.config.keywords for uri in self.config.s3_documents}
54100
raw_findings = await self.run_mapping(doc_to_keywords)
55101

src/generate_graph.py

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from src.content_extractor.s3_sequential import S3QuoteExtractor
99
from src.content_extractor.base import BaseExtractorConfig
1010
from src.content_extractor.highlighter import highlight_occurrence
11+
from src.models.graph_models import (
12+
GraphInput, GraphOutput, Node, NodeData, Edge, EdgeData, Occurrence, Entity
13+
)
1114

1215
logger = logging.getLogger(__name__)
1316

@@ -17,25 +20,19 @@ def slugify(text: str) -> str:
1720
text = re.sub(r'[^a-z0-9]+', '_', text)
1821
return text.strip('_')
1922

20-
def build_registries(entities: List[Dict[str, Any]]) -> Dict[str, Any]:
21-
"""Parses entities to map s3_uris to keywords and metadata."""
23+
def build_registries(entities: List[Entity]) -> Dict[str, Any]:
24+
"""Parses entities to map s3_uris to keywords and metadata based on structured aliases."""
2225
registry = defaultdict(lambda: {"keywords": set(), "entities": []})
2326

2427
for ent in entities:
25-
props = ent.get("properties", {})
26-
source_urls_raw = props.get("sourceUrls", [])
27-
28-
if isinstance(source_urls_raw, str):
29-
s3_uris = [u.strip() for u in source_urls_raw.split(',')]
30-
else:
31-
s3_uris = source_urls_raw
32-
33-
aliases = ent.get("aliases", [])
34-
35-
for uri in s3_uris:
36-
if not uri: continue
37-
registry[uri]["keywords"].update(aliases)
38-
registry[uri]["entities"].append(ent)
28+
for alias in ent.aliases:
29+
for uri in alias.source_files:
30+
if not uri or not uri.startswith("s3://"):
31+
continue
32+
registry[uri]["keywords"].add(alias.name)
33+
# Ensure each entity is only added once per unique URI
34+
if ent not in registry[uri]["entities"]:
35+
registry[uri]["entities"].append(ent)
3936

4037
return registry
4138

@@ -65,63 +62,64 @@ def map_findings_to_entities(raw_findings: List[Dict[str, Any]], registry: Dict[
6562
uri = finding["source"]
6663
keyword = finding["keyword_matched"]
6764
content = finding["content"]
68-
link = finding["link"] # Use the pre-calculated link from extractor
65+
link = finding["link"]
6966

7067
for ent in registry[uri]["entities"]:
71-
if keyword in ent.get("aliases", []):
72-
occurrence = {
73-
"link": link,
74-
"context": highlight_occurrence(content, keyword)
75-
}
76-
results[ent["canonical_key"]][keyword].append(occurrence)
68+
if any(a.name == keyword for a in ent.aliases):
69+
occurrence = Occurrence(
70+
link=link,
71+
context=highlight_occurrence(content, keyword)
72+
)
73+
results[ent.canonical_key][keyword].append(occurrence)
7774

7875
return results
7976

80-
def build_node_structure(entities: List[Dict[str, Any]], entity_results: Dict[str, Any]) -> Dict[str, Any]:
77+
def build_node_structure(entities: List[Entity], entity_results: Dict[str, Any]) -> GraphOutput:
8178
"""Constructs the final list of nodes and edges."""
8279
nodes, edges = [], []
8380

8481
for ent in entities:
85-
ent_id = ent["canonical_key"]
86-
human_label = ent.get("label") or ent_id.replace("_", " ").title()
87-
nodes.append({"data": {"id": ent_id, "label": human_label, "type": "entity"}})
82+
ent_id = ent.canonical_key
83+
human_label = ent.label or ent_id.replace("_", " ").title()
84+
nodes.append(Node(data=NodeData(id=ent_id, label=human_label, type="entity")))
8885

8986
# Use a dict to accumulate alias nodes by their slugified ID to avoid duplicates
9087
alias_map = {}
9188

92-
for alias in ent.get("aliases", []):
89+
for alias_obj in ent.aliases:
90+
alias = alias_obj.name
9391
occurrences = entity_results[ent_id].get(alias, [])
9492
alias_id = f"{ent_id}__{slugify(alias)}"
9593

9694
if alias_id not in alias_map:
97-
alias_map[alias_id] = {
98-
"id": alias_id,
99-
"label": alias,
100-
"type": "alias",
101-
"occurrences": []
102-
}
95+
alias_map[alias_id] = NodeData(
96+
id=alias_id,
97+
label=alias,
98+
type="alias",
99+
occurrences=[]
100+
)
103101

104102
if occurrences:
105-
alias_map[alias_id]["occurrences"].extend(occurrences)
103+
alias_map[alias_id].occurrences.extend(occurrences)
106104

107105
# Add the deduplicated alias nodes and their edges
108-
for alias_id, alias_data in alias_map.items():
109-
# If no occurrences, remove the empty list from the data
110-
if not alias_data["occurrences"]:
111-
del alias_data["occurrences"]
106+
for alias_id, node_data in alias_map.items():
107+
# If no occurrences, clear the list (Pydantic will handle Optional)
108+
if not node_data.occurrences:
109+
node_data.occurrences = None
112110

113-
nodes.append({"data": alias_data})
111+
nodes.append(Node(data=node_data))
114112

115-
count = len(alias_data.get("occurrences", []))
116-
edges.append({
117-
"data": {
118-
"source": ent_id,
119-
"target": alias_id,
120-
"label": f"Alias ({count})" if count > 0 else "Alias"
121-
}
122-
})
113+
count = len(node_data.occurrences) if node_data.occurrences else 0
114+
edges.append(Edge(
115+
data=EdgeData(
116+
source=ent_id,
117+
target=alias_id,
118+
label=f"Alias ({count})" if count > 0 else "Alias"
119+
)
120+
))
123121

124-
return {"nodes": nodes, "edges": edges}
122+
return GraphOutput(nodes=nodes, edges=edges)
125123

126124
async def generate_graph(input_data: Union[str, Dict[str, Any]], output_path: Optional[str] = None):
127125
"""Main orchestration function. Can take a file path (str) or a dictionary."""
@@ -134,13 +132,21 @@ async def generate_graph(input_data: Union[str, Dict[str, Any]], output_path: Op
134132
else:
135133
graph_data = input_data
136134

137-
entities = graph_data.get("entities", [])
135+
# Validate input
136+
try:
137+
validated_input = GraphInput.model_validate(graph_data)
138+
entities = validated_input.entities
139+
except Exception as e:
140+
logger.error(f"Input validation failed: {e}")
141+
raise
142+
138143
registry = build_registries(entities)
139144

140145
raw_findings = await fetch_extraction_findings(registry)
141146
entity_results = map_findings_to_entities(raw_findings, registry)
142147

143-
cy_json = build_node_structure(entities, entity_results)
148+
cy_graph = build_node_structure(entities, entity_results)
149+
cy_json = cy_graph.model_dump(exclude_none=True)
144150

145151
if output_path:
146152
os.makedirs(os.path.dirname(output_path), exist_ok=True)

src/models/__init__.py

Whitespace-only changes.

src/models/graph_models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pydantic import BaseModel, Field, ConfigDict
2+
from typing import List, Optional, Dict, Union, Any, Literal
3+
4+
class Alias(BaseModel):
5+
name: str
6+
source_files: List[str] = Field(default_factory=list)
7+
8+
class Entity(BaseModel):
9+
id: str
10+
canonical_key: str
11+
label: Optional[str] = None
12+
aliases: List[Alias] = Field(default_factory=list)
13+
properties: Dict[str, Any] = Field(default_factory=dict)
14+
type: Optional[str] = None
15+
description: Optional[str] = None
16+
17+
model_config = ConfigDict(extra="allow")
18+
19+
class GraphInput(BaseModel):
20+
entities: List[Entity]
21+
22+
model_config = ConfigDict(extra="allow")
23+
24+
class Occurrence(BaseModel):
25+
link: str
26+
context: str
27+
28+
class NodeData(BaseModel):
29+
id: str
30+
label: str
31+
type: Literal["entity", "alias"]
32+
occurrences: Optional[List[Occurrence]] = None
33+
34+
class Node(BaseModel):
35+
data: NodeData
36+
37+
class EdgeData(BaseModel):
38+
source: str
39+
target: str
40+
label: str
41+
42+
class Edge(BaseModel):
43+
data: EdgeData
44+
45+
class GraphOutput(BaseModel):
46+
nodes: List[Node]
47+
edges: List[Edge]

src/url/generator.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
import urllib.parse
2+
from typing import Optional
23

3-
def convert_string_to_url_query_format(text: str):
4-
# For GOV.UK text fragments (#:~:text=), characters like '-' are reserved
5-
# syntax characters and must be percent-encoded.
6-
# Python's urllib.parse.quote never quotes '-', '.', '_', or '~'.
7-
# So we manually encode '-' to ensure it works with text fragments.
4+
def convert_string_to_url_query_format(text: str)-> str:
85
quoted = urllib.parse.quote(text, safe='')
9-
return quoted.replace('-', '%2D')
6+
quoted= (quoted
7+
.replace('-', '%2D')
8+
.replace('.', '%2E')
9+
.replace('~', '%7E')
10+
.replace('_', '%5F'))
11+
return quoted
1012

1113
def generate_url_fragement(base_url: str, content: str):
1214
encoded_content = convert_string_to_url_query_format(content)
1315
url = f"{base_url}#:~:text={encoded_content}"
1416
return url
1517

16-
def s3_to_govuk_url(s3_uri: str) -> str:
17-
"""Derives a GOV.UK URL directly from an S3 URI by stripping the prefix and extension."""
18+
def s3_to_govuk_url(s3_uri: str, url_map: Optional[dict] = None) -> str:
19+
"""Derives a GOV.UK URL from an S3 URI, using url_map if provided, otherwise using fallback logic."""
20+
if url_map and s3_uri in url_map:
21+
return url_map[s3_uri]
22+
1823
if "/input/" in s3_uri:
1924
path = s3_uri.split("/input/")[-1]
2025
else:

0 commit comments

Comments
 (0)