Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion packages/lambda-handler/src/lambda_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .lambda_handler import create_aws_auth as create_aws_auth
from .lambda_handler import create_opensearch_client as create_opensearch_client
from .lambda_handler import create_s3_client as create_s3_client
from .lambda_handler import get_eventbridge_data_from_s3_event as get_eventbridge_data_from_s3_event
from .lambda_handler import get_file_content_from_s3 as get_file_content_from_s3
from .lambda_handler import get_persistence_id as get_persistence_id
from .lambda_handler import get_s3_credentials as get_s3_credentials
Expand Down
13 changes: 0 additions & 13 deletions packages/lambda-handler/src/lambda_handler/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing

import boto3
from aws_lambda_typing import events as lambda_events
from botocore.client import BaseClient
from botocore.credentials import Credentials
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -100,18 +99,6 @@ def get_file_content_from_s3(bucket_name: str, object_key: str) -> str:
return response["Body"].read().decode("utf-8")


def get_eventbridge_data_from_s3_event(event: lambda_events.EventBridgeEvent) -> dict:
    """Pull the S3 bucket name and object key out of an EventBridge S3 event.

    :param event: The EventBridge event whose ``detail`` carries the S3 bucket/object info.
    :return: A dictionary with ``bucket_name`` and ``object_key`` entries.
    """
    detail = event["detail"]
    return {
        "bucket_name": detail["bucket"]["name"],
        "object_key": detail["object"]["key"],
    }


def put_file(file_obj: typing.BinaryIO, bucket_name: str, object_key: str) -> None:
"""Uploads a file object to a S3 bucket.

Expand Down
16 changes: 16 additions & 0 deletions packages/lambda-handler/src/lambda_handler/models/opensearch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from aws_lambda_powertools.utilities.data_classes import SQSRecord
from pydantic import BaseModel
from pydantic import Field

Expand All @@ -8,6 +9,21 @@ class S3Location(BaseModel):
bucket: str = Field(description="The S3 bucket where the file is located.")
key: str = Field(description="The S3 key (path) where the file is located.")

@classmethod
def from_sqs_record(cls, record: SQSRecord) -> "S3Location":
"""Create an S3Location model from an SQSRecord."""
return cls.model_validate(
{
"bucket": record.json_body["detail"]["bucket"]["name"],
"key": record.json_body["detail"]["object"]["key"],
}
)

@property
def address(self) -> str:
"""Return the address string in the form: `'s3://{self.bucket}/{self.key}'`."""
return f"s3://{self.bucket}/{self.key}"


class OpenSearchHitSource(BaseModel):
"""Represents a single search result _source returned from OpenSearch."""
Expand Down
15 changes: 0 additions & 15 deletions packages/lambda-handler/tests/test_lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,6 @@ def test_create_s3_client(self, moto_setup):
assert s3_client._get_credentials().access_key == "test_access_key_id"


class TestGetEventBridgeDataFromS3Event:
    def test_get_eventbridge_data_from_s3_event(self, moto_setup):
        """Test extracting bucket name and object key from an EventBridge S3 event."""
        # The function under test only parses the event payload — it never
        # touches S3 — so no object needs to be uploaded first; moto_setup is
        # used only to supply a realistic bucket name.
        event = {
            "detail": {"bucket": {"name": moto_setup.bucket_name}, "object": {"key": "test.txt"}}
        }

        content = lambda_handler.get_eventbridge_data_from_s3_event(event)
        assert content == {"bucket_name": moto_setup.bucket_name, "object_key": "test.txt"}


class TestGetFileContentFromS3:
def test_get_file_content_from_s3(self, moto_setup):
"""Test get file content from S3."""
Expand Down
29 changes: 29 additions & 0 deletions packages/lambda-handler/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from aws_lambda_powertools.utilities.data_classes import SQSRecord

from lambda_handler.models.opensearch import S3Location


class TestS3Location:
    def test_from_sqs_record(self):
        """from_sqs_record should pull the bucket and key out of the EventBridge detail."""
        bucket = "expected_bucket"
        key = "expected_key"
        body = '{"detail": {"bucket": {"name": "%s"},"object": {"key": "%s"}}}' % (bucket, key)
        record = SQSRecord({"body": body})

        actual = S3Location.from_sqs_record(record)

        assert actual == S3Location(bucket=bucket, key=key)

    def test_address(self):
        """address should render the canonical s3:// URI for the location."""
        location = S3Location(bucket="test_bucket", key="test_key")

        actual = location.address

        assert actual == "s3://test_bucket/test_key"
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from aws_lambda_powertools.utilities.data_classes import event_source
from aws_lambda_powertools.utilities.typing import LambdaContext
from botocore.client import BaseClient
from opensearchpy import OpenSearch
from opensearchpy.client import OpenSearch

import lambda_handler
from lambda_handler.models.opensearch import S3Location
from text_to_code.models import query as query_models
from text_to_code.services import eicr_processor
from text_to_code.services import embedder
Expand Down Expand Up @@ -52,20 +53,7 @@ def handler(event: SQSEvent, context: LambdaContext) -> dict:
:param context: The Lambda context object.
:return: A dictionary containing the status code, message, and any relevant data about the processing results.
"""
global _cached_auth, _cached_opensearch_client, _cached_s3_client # noqa: PLW0603

# Initialize cached clients if they don't exist
if _cached_auth is None:
_cached_auth = lambda_handler.create_aws_auth()
auth = _cached_auth

if _cached_opensearch_client is None:
_cached_opensearch_client = lambda_handler.create_opensearch_client(auth)
opensearch_client = _cached_opensearch_client

if _cached_s3_client is None:
_cached_s3_client = lambda_handler.create_s3_client()
s3_client = _cached_s3_client
opensearch_client, s3_client = _initilize_clients()

logger.info(f"Received event with {len(event['Records'])} record(s)")

Expand Down Expand Up @@ -98,6 +86,27 @@ def handler(event: SQSEvent, context: LambdaContext) -> dict:
)


def _initilize_clients() -> tuple[OpenSearch, BaseClient]:
    """Initialize auth, OpenSearch, and S3 clients, and return the OpenSearch and S3 clients.

    Each client is created once and stored in a module-level global, so later
    calls return the already-built instances instead of recreating them —
    presumably to reuse connections across warm Lambda invocations (TODO confirm).

    :return: The cached ``(OpenSearch client, S3 client)`` pair.

    TODO: I do not love that this function initializes all 3, but only returns 2.
    """
    global _cached_auth, _cached_opensearch_client, _cached_s3_client  # noqa: PLW0603
    # Build the AWS auth object first; the OpenSearch client depends on it.
    if _cached_auth is None:
        _cached_auth = lambda_handler.create_aws_auth()
    auth = _cached_auth

    if _cached_opensearch_client is None:
        _cached_opensearch_client = lambda_handler.create_opensearch_client(auth)
    opensearch_client = _cached_opensearch_client

    if _cached_s3_client is None:
        _cached_s3_client = lambda_handler.create_s3_client()
    s3_client = _cached_s3_client

    return opensearch_client, s3_client


def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client: OpenSearch) -> None:
"""Process each SQS record.

Expand All @@ -107,22 +116,16 @@ def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client:
logger.warning("Empty SQS body", message_id=record.message_id)
return

s3_event = json.loads(record.body)

# Parse the EventBridge S3 event from the SQS message body
eventbridge_data = lambda_handler.get_eventbridge_data_from_s3_event(s3_event)
bucket = eventbridge_data["bucket_name"]
object_key = eventbridge_data["object_key"]
logger.info(f"Processing S3 Object: s3://{bucket}/{object_key}")
s3_location = S3Location.from_sqs_record(record)
logger.info("Processing S3 Object: %s", s3_location.address)

# Extract persistence_id from the RR object key
persistence_id = lambda_handler.get_persistence_id(object_key, TTC_INPUT_PREFIX)
logger.info(f"Extracted persistence_id: {persistence_id}")
persistence_id = lambda_handler.get_persistence_id(s3_location.key, TTC_INPUT_PREFIX)
logger.info("Extracted persistence_id: %s", persistence_id)

with logger.append_context_keys(
persistence_id=persistence_id,
):
_process_record_pipeline(bucket, persistence_id, s3_client, opensearch_client)
_process_record_pipeline(s3_location.bucket, persistence_id, s3_client, opensearch_client)


def _initialize_ttc_outputs(persistence_id: str) -> tuple[dict, dict]:
Expand All @@ -132,12 +135,8 @@ def _initialize_ttc_outputs(persistence_id: str) -> tuple[dict, dict]:
:return: The TTC output and TTC metadata output dictionaries.
"""
# TODO: Update the ttc_output to ensure it matches and uses the expected model once ticket #263 is completed
ttc_output: dict = {"persistence_id": "", "eicr_metadata": {}, "schematron_errors": {}}
ttc_metadata_output: dict = {
"persistence_id": "",
"eicr_metadata": {},
"schematron_errors": {},
}
ttc_output = {"persistence_id": "", "eicr_metadata": {}, "schematron_errors": {}}
ttc_metadata_output = {"persistence_id": "", "eicr_metadata": {}, "schematron_errors": {}}
ttc_output["persistence_id"] = persistence_id
ttc_metadata_output["persistence_id"] = persistence_id
return ttc_output, ttc_metadata_output
Expand Down Expand Up @@ -184,9 +183,7 @@ def _load_original_eicr(bucket: str, persistence_id: str) -> str:


def _populate_eicr_metadata(
original_eicr_content: str,
ttc_output: dict,
ttc_metadata_output: dict,
original_eicr_content: str, ttc_output: dict, ttc_metadata_output: dict
) -> None:
"""Populate eICR metadata on TTC outputs.

Expand Down Expand Up @@ -304,10 +301,7 @@ def _save_ttc_outputs(persistence_id: str, ttc_output: dict, ttc_metadata_output


def _process_record_pipeline(
bucket: str,
persistence_id: str,
s3_client: BaseClient,
opensearch_client: OpenSearch,
bucket: str, persistence_id: str, s3_client: BaseClient, opensearch_client: OpenSearch
) -> dict:
"""The main pipeline for processing each record.

Expand Down
Loading