Skip to content

Commit 1a8c3b7

Browse files
authored
Merge pull request #17 from alphagov/ACW-43/opensearch-integration
OpenSearch Integration
2 parents 0cdc628 + b00c261 commit 1a8c3b7

9 files changed

Lines changed: 587 additions & 116 deletions

File tree

app.py

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
11
import asyncio
22
import logging
33
import os
4-
import time
54

65
from asgiref.wsgi import WsgiToAsgi
76
from dotenv import load_dotenv
87
from flask import Flask, jsonify, render_template, request
98
from werkzeug.exceptions import BadRequest
109

1110
from src.utils import (
12-
background_run_extraction,
13-
get_active_job_status,
14-
get_job_id_for_path,
1511
read_job_status,
1612
resume_interrupted_jobs,
17-
update_job_status,
13+
start_extraction_job,
1814
)
19-
from src.visualiser_graph_generator import generate_output_path
2015
from src.visualiser_graph_loader import (
2116
extract_path_parts,
2217
load_json_file,
@@ -75,54 +70,21 @@ async def extract_quotes():
7570
"""
7671
Endpoint that runs the Cytoscape graph generation logic based on graph.json.
7772
"""
78-
try:
79-
source_path = request.args.get("source_path")
80-
if not source_path:
81-
return jsonify({"error": "Missing 'source_path' query parameter"}), 400
82-
83-
input_path, output_path = generate_output_path(source_path)
84-
job_id = get_job_id_for_path(source_path)
85-
86-
active_status = get_active_job_status(job_id)
87-
if active_status:
88-
logger.info(
89-
f"Duplicate request for {source_path}. Job {job_id} is already in progress."
90-
)
91-
return jsonify(
92-
{
93-
"job_id": job_id,
94-
"status": "already_running",
95-
"message": (
96-
f"A graph generation job is already in progress for {source_path}"
97-
),
98-
"output_path": output_path,
99-
}
100-
), 202
101-
102-
initial_status = {
103-
"job_id": job_id,
104-
"status": "pending",
105-
"source_path": source_path,
106-
"created_at": time.time(),
107-
}
108-
update_job_status(job_id, initial_status)
109-
110-
asyncio.create_task(
111-
background_run_extraction(job_id, input_path, output_path, initial_status)
112-
)
113-
114-
return jsonify(
115-
{
116-
"job_id": job_id,
117-
"status": "accepted",
118-
"message": f"Graph generation started in background for {source_path}",
119-
"output_path": output_path,
120-
}
121-
), 202
73+
source_path = request.args.get("source_path") or ""
74+
data, status_code = await start_extraction_job(source_path, extractor_type="s3")
75+
return jsonify(data), status_code
12276

123-
except Exception as e:
124-
app.logger.error(f"Error starting background task: {str(e)}")
125-
return jsonify({"error": str(e)}), 500
77+
@app.route("/extract-os", methods=["GET"])
78+
async def extract_quotes_os():
79+
"""
80+
Endpoint that runs the extraction using OpenSearch.
81+
"""
82+
source_path = request.args.get("source_path") or ""
83+
perform_indexing = request.args.get("index", "false").lower() == "true"
84+
data, status_code = await start_extraction_job(
85+
source_path, extractor_type="opensearch", perform_indexing=perform_indexing
86+
)
87+
return jsonify(data), status_code
12688

12789
@app.route("/status/<job_id>", methods=["GET"])
12890
def get_status(job_id):

src/content_extractor/base.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class BaseExtractorConfig:
5757
)
5858
chunk_max_chars: int = 6000
5959
secret_id: Optional[str] = None
60+
aws_profile: Optional[str] = None
61+
aws_access_key_id: Optional[str] = None
62+
aws_secret_access_key: Optional[str] = None
63+
aws_session_token: Optional[str] = None
6064

6165

6266
# --- Base Extractor Class ---
@@ -65,11 +69,23 @@ class BaseExtractorConfig:
6569
class BaseQuoteExtractor:
6670
def __init__(self, config: BaseExtractorConfig):
6771
self.config = config
68-
self.s3_client = boto3.client("s3", region_name=self.config.region)
72+
73+
# Initialize AWS Session
74+
self.session = boto3.Session(
75+
profile_name=self.config.aws_profile,
76+
aws_access_key_id=self.config.aws_access_key_id,
77+
aws_secret_access_key=self.config.aws_secret_access_key,
78+
aws_session_token=self.config.aws_session_token,
79+
region_name=self.config.region,
80+
)
81+
82+
self.s3_client = self.session.client("s3")
83+
self.url_map: Dict[str, str] = {}
6984

7085
# Initialize Bedrock Agent
86+
bedrock_client = self.session.client("bedrock-runtime")
7187
model = BedrockConverseModel(
72-
self.config.model_id, provider=BedrockProvider(region_name=self.config.region)
88+
self.config.model_id, provider=BedrockProvider(bedrock_client=bedrock_client)
7389
)
7490
self.agent = Agent(
7591
model,
@@ -95,9 +111,43 @@ def __init__(self, config: BaseExtractorConfig):
95111
),
96112
)
97113

114+
def _fetch_url_map(self, s3_uris: List[str]):
115+
"""
116+
Attempts to fetch sources.json files from the directories of the input files.
117+
Deduplicates potential sources.json locations and merges their mappings.
118+
"""
119+
import logging
120+
121+
logger = logging.getLogger(__name__)
122+
123+
if not s3_uris:
124+
return
125+
126+
sources_locations = set()
127+
for uri in s3_uris:
128+
if uri in self.url_map:
129+
continue
130+
131+
if "/input/" in uri:
132+
sources_uri = uri.split("/input/")[0] + "/input/sources.json"
133+
else:
134+
sources_uri = "/".join(uri.split("/")[:-1]) + "/sources.json"
135+
sources_locations.add(sources_uri)
136+
137+
for sources_uri in sources_locations:
138+
logger.info(f"Attempting to fetch sources map from {sources_uri}...")
139+
content: Optional[str] = self.fetch_s3_content(sources_uri)
140+
if content:
141+
try:
142+
new_map = json.loads(content)
143+
self.url_map.update(new_map)
144+
logger.info(f"Successfully loaded {len(new_map)} mappings from {sources_uri}.")
145+
except Exception as e:
146+
logger.error(f"Failed to parse {sources_uri}: {e}")
147+
98148
def get_aws_secret(self, secret_id: str) -> dict:
99149
"""Fetches and parses a JSON secret from AWS Secrets Manager."""
100-
client = boto3.client("secretsmanager", region_name=self.config.region)
150+
client = self.session.client("secretsmanager")
101151
try:
102152
response = client.get_secret_value(SecretId=secret_id)
103153
if "SecretString" in response:

0 commit comments

Comments
 (0)