Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Environment variables
S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
AUGMENTED_EICR_PREFIX = os.getenv("AUGMENTED_EICR_PREFIX", "AugmentationEICRV2/")
AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadata/")
AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadataV2/")

# Cache S3 client to reuse across Lambda invocations
_cached_s3_client: BaseClient | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
# Environment variables
S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
EICR_INPUT_PREFIX = os.getenv("EICR_INPUT_PREFIX", "eCRMessageV2/")
SCHEMATRON_ERROR_PREFIX = os.getenv("SCHEMATRON_ERROR_PREFIX", "schematronErrors/")
TTC_INPUT_PREFIX = os.getenv("TTC_INPUT_PREFIX", "TextToCodeValidateSubmissionV2/")
TTC_OUTPUT_PREFIX = os.getenv("TTC_OUTPUT_PREFIX", "TTCOutput/")
TTC_METADATA_PREFIX = os.getenv("TTC_METADATA_PREFIX", "TTCMetadata/")
SCHEMATRON_ERROR_PREFIX = os.getenv("SCHEMATRON_ERROR_PREFIX", "ValidationResponseV2/")
TTC_INPUT_PREFIX = os.getenv("TTC_INPUT_PREFIX", "TextToCodeSubmissionV2/")
TTC_OUTPUT_PREFIX = os.getenv("TTC_OUTPUT_PREFIX", "TTCAugmentationMetadataV2/")
TTC_METADATA_PREFIX = os.getenv("TTC_METADATA_PREFIX", "TTCMetadataV2/")
AWS_REGION = os.getenv("AWS_REGION")
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
OPENSEARCH_ENDPOINT_URL = os.getenv("OPENSEARCH_ENDPOINT_URL")
Expand Down Expand Up @@ -120,7 +120,8 @@ def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client:
# Parse the EventBridge S3 event from the SQS message body
eventbridge_data = lambda_handler.get_eventbridge_data_from_s3_event(s3_event)
object_key = eventbridge_data["object_key"]
logger.info(f"Processing S3 Object: s3://{S3_BUCKET}/{object_key}")
bucket_name = eventbridge_data.get("bucket_name") or S3_BUCKET
logger.info(f"Processing S3 Object: s3://{bucket_name}/{object_key}")

# Extract persistence_id from the RR object key
persistence_id = lambda_handler.get_persistence_id(object_key, TTC_INPUT_PREFIX)
Expand All @@ -129,7 +130,7 @@ def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client:
with logger.append_context_keys(
persistence_id=persistence_id,
):
_process_record_pipeline(persistence_id, s3_client, opensearch_client)
_process_record_pipeline(persistence_id, s3_client, opensearch_client, bucket_name)


def _initialize_ttc_outputs(persistence_id: str) -> tuple[dict, dict]:
Expand All @@ -152,17 +153,20 @@ def _initialize_ttc_outputs(persistence_id: str) -> tuple[dict, dict]:
return ttc_output, ttc_metadata_output


def _load_schematron_data_fields(persistence_id: str, s3_client: BaseClient) -> list:
def _load_schematron_data_fields(
persistence_id: str, s3_client: BaseClient, bucket_name: str
) -> list:
"""Load Schematron errors from S3 and extract relevant fields.

:param persistence_id: The persistence ID extracted from the S3 object key
:param s3_client: The S3 client to use for fetching files.
:param bucket_name: The S3 bucket name to read from.
:return: The relevant Schematron data fields for TTC processing.
"""
object_key = f"{SCHEMATRON_ERROR_PREFIX}{persistence_id}"
logger.info("Loading Schematron errors", s3_key=f"s3://{S3_BUCKET}/{object_key}")
logger.info("Loading Schematron errors", s3_key=f"s3://{bucket_name}/{object_key}")
schematron_errors = lambda_handler.get_file_content_from_s3(
bucket_name=S3_BUCKET,
bucket_name=bucket_name,
object_key=object_key,
s3_client=s3_client,
)
Expand All @@ -172,17 +176,18 @@ def _load_schematron_data_fields(persistence_id: str, s3_client: BaseClient) ->
return schematron_processor.get_data_fields_from_schematron_error(schematron_errors)


def _load_original_eicr(persistence_id: str, s3_client: BaseClient) -> str:
def _load_original_eicr(persistence_id: str, s3_client: BaseClient, bucket_name: str) -> str:
"""Load the original eICR from S3.

:param persistence_id: The persistence ID extracted from the S3 object key
:param s3_client: The S3 client to use for fetching files.
:param bucket_name: The S3 bucket name to read from.
:return: The original eICR content.
"""
object_key = f"{EICR_INPUT_PREFIX}{persistence_id}"
logger.info(f"Retrieving eICR from s3://{S3_BUCKET}/{object_key}")
logger.info(f"Retrieving eICR from s3://{bucket_name}/{object_key}")
original_eicr_content = lambda_handler.get_file_content_from_s3(
bucket_name=S3_BUCKET, object_key=object_key, s3_client=s3_client
bucket_name=bucket_name, object_key=object_key, s3_client=s3_client
)
logger.info(f"Retrieved eICR content for persistence_id {persistence_id}")
return original_eicr_content
Expand Down Expand Up @@ -283,20 +288,25 @@ def _process_schematron_errors(


def _save_ttc_outputs(
persistence_id: str, ttc_output: dict, ttc_metadata_output: dict, s3_client: BaseClient
persistence_id: str,
ttc_output: dict,
ttc_metadata_output: dict,
s3_client: BaseClient,
bucket_name: str,
) -> None:
"""Save TTC output and metadata output to S3.

:param persistence_id: The persistence ID extracted from the S3 object key
:param ttc_output: The TTC output dictionary.
:param ttc_metadata_output: The TTC metadata output dictionary.
:param s3_client: The S3 client to use for uploading files.
:param bucket_name: The S3 bucket name to write to.
"""
# Save the TTC output to S3 for the Augmentation Lambda to consume
logger.info(f"Saving TTC output to S3 for persistence_id {persistence_id}")
lambda_handler.put_file(
file_obj=io.BytesIO(json.dumps(ttc_output, default=str).encode("utf-8")),
bucket_name=S3_BUCKET,
bucket_name=bucket_name,
object_key=f"{TTC_OUTPUT_PREFIX}{persistence_id}",
s3_client=s3_client,
)
Expand All @@ -305,7 +315,7 @@ def _save_ttc_outputs(
logger.info(f"Saving TTC metadata output to S3 for persistence_id {persistence_id}")
lambda_handler.put_file(
file_obj=io.BytesIO(json.dumps(ttc_metadata_output, default=str).encode("utf-8")),
bucket_name=S3_BUCKET,
bucket_name=bucket_name,
object_key=f"{TTC_METADATA_PREFIX}{persistence_id}",
s3_client=s3_client,
)
Expand All @@ -315,6 +325,7 @@ def _process_record_pipeline(
persistence_id: str,
s3_client: BaseClient,
opensearch_client: OpenSearch,
bucket_name: str,
) -> dict:
"""The main pipeline for processing each record.

Expand All @@ -333,11 +344,12 @@ def _process_record_pipeline(
:param persistence_id: The persistence ID extracted from the S3 object key
:param s3_client: The S3 client to use for S3 operations.
:param opensearch_client: The OpenSearch client.
:param bucket_name: The S3 bucket name extracted from the event, or the default.
"""
ttc_output, ttc_metadata_output = _initialize_ttc_outputs(persistence_id)

logger.info("Starting TTC processing")
schematron_data_fields = _load_schematron_data_fields(persistence_id, s3_client)
schematron_data_fields = _load_schematron_data_fields(persistence_id, s3_client, bucket_name)

if not schematron_data_fields:
logger.warning(
Expand All @@ -348,13 +360,13 @@ def _process_record_pipeline(
logger.info(f"Saving TTC metadata output to S3 for persistence_id {persistence_id}")
lambda_handler.put_file(
file_obj=io.BytesIO(json.dumps(ttc_metadata_output, default=str).encode("utf-8")),
bucket_name=S3_BUCKET,
bucket_name=bucket_name,
object_key=f"{TTC_METADATA_PREFIX}{persistence_id}",
s3_client=s3_client,
)
return ttc_output

original_eicr_content = _load_original_eicr(persistence_id, s3_client)
original_eicr_content = _load_original_eicr(persistence_id, s3_client, bucket_name)
_populate_eicr_metadata(original_eicr_content, ttc_output, ttc_metadata_output)
_process_schematron_errors(
original_eicr_content,
Expand All @@ -363,6 +375,6 @@ def _process_record_pipeline(
ttc_output,
ttc_metadata_output,
)
_save_ttc_outputs(persistence_id, ttc_output, ttc_metadata_output, s3_client)
_save_ttc_outputs(persistence_id, ttc_output, ttc_metadata_output, s3_client, bucket_name)

return {"statusCode": 200, "message": "TTC processed successfully!"}
8 changes: 4 additions & 4 deletions packages/text-to-code-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

S3_BUCKET = "dibbs-text-to-code"
EICR_INPUT_PREFIX = "eCRMessageV2/"
SCHEMATRON_ERROR_PREFIX = "schematronErrors/"
TTC_INPUT_PREFIX = "TextToCodeValidateSubmissionV2/"
TTC_OUTPUT_PREFIX = "TTCOutput/"
TTC_METADATA_PREFIX = "TTCMetadata/"
SCHEMATRON_ERROR_PREFIX = "ValidationResponseV2/"
TTC_INPUT_PREFIX = "TextToCodeSubmissionV2/"
TTC_OUTPUT_PREFIX = "TTCAugmentationMetadataV2/"
TTC_METADATA_PREFIX = "TTCMetadataV2/"
AWS_REGION = "us-east-1"
AWS_ACCESS_KEY_ID = "test_access_key_id"
AWS_SECRET_ACCESS_KEY = "test_secret_access_key" # noqa: S105
Expand Down
16 changes: 8 additions & 8 deletions terraform/_variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,26 @@ variable "eicr_input_prefix" {

variable "schematron_error_prefix" {
type = string
default = "schematronErrors/"
description = "S3 prefix for schematron error files"
default = "ValidationResponseV2/"
description = "S3 prefix for schematron validation response files"
}

variable "ttc_input_prefix" {
type = string
default = "TextToCodeValidateSubmissionV2/"
default = "TextToCodeSubmissionV2/"
description = "S3 prefix for TTC input submission files"
}

variable "ttc_output_prefix" {
type = string
default = "TTCOutput/"
description = "S3 prefix for TTC output files"
default = "TTCAugmentationMetadataV2/"
description = "S3 prefix for TTC augmentation metadata output files"
}

variable "ttc_metadata_prefix" {
type = string
default = "TTCMetadata/"
description = "S3 prefix for TTC metadata files"
default = "TTCMetadataV2/"
description = "S3 prefix for TTC analysis metadata files"
}

variable "augmented_eicr_prefix" {
Expand All @@ -143,7 +143,7 @@ variable "augmented_eicr_prefix" {

variable "augmentation_metadata_prefix" {
type = string
default = "AugmentationMetadata/"
default = "AugmentationMetadataV2/"
description = "S3 prefix for augmentation metadata files"
}

Expand Down
Loading