diff --git a/.github/workflows/build-images.yml b/.github/workflows/build-images.yml index f725351dd9852..ea1e1e66e9f6d 100644 --- a/.github/workflows/build-images.yml +++ b/.github/workflows/build-images.yml @@ -29,7 +29,6 @@ jobs: - frontend # pushed both the frontend and backend images - upload_failures - upload_success - - dataset_submissions - processing - wmg_processing - cellguide_pipeline diff --git a/.github/workflows/rdev-tests.yml b/.github/workflows/rdev-tests.yml index bd9c006ef1df5..fca6ecf06f916 100644 --- a/.github/workflows/rdev-tests.yml +++ b/.github/workflows/rdev-tests.yml @@ -234,6 +234,6 @@ jobs: if: always() with: name: logged-in-test-results - path: frontend/playwright-report/ + path: /home/runner/work/single-cell-data-portal/single-cell-data-portal/frontend/playwright-report retention-days: 30 if-no-files-found: error diff --git a/.happy/config.json b/.happy/config.json index ff5b349190589..3402b1ff90859 100644 --- a/.happy/config.json +++ b/.happy/config.json @@ -11,7 +11,6 @@ "backend-wmg", "cellguide_pipeline", "processing", - "dataset_submissions", "upload_failures", "upload_success", "wmg_processing" @@ -38,9 +37,6 @@ "wmg_processing": { "profile": "wmg_processing" }, - "dataset_submissions": { - "profile": "dataset_submissions" - }, "upload_failures": { "profile": "upload_failures" }, diff --git a/.happy/terraform/modules/batch/main.tf b/.happy/terraform/modules/batch/main.tf index d85b62b7c9a3b..270b4e72dfe39 100644 --- a/.happy/terraform/modules/batch/main.tf +++ b/.happy/terraform/modules/batch/main.tf @@ -11,7 +11,7 @@ resource aws_batch_job_definition batch_job_def { container_properties = jsonencode({ "jobRoleArn": "${var.batch_role_arn}", "image": "${var.image}", - "memory": var.batch_container_memory_limit, + "memory": 8000, "environment": [ { "name": "ARTIFACT_BUCKET", @@ -42,11 +42,55 @@ resource aws_batch_job_definition batch_job_def { "value": "${var.frontend_url}" } ], - "vcpus": 8, - "linuxParameters": { - "maxSwap": 800000, - "swappiness": 60 - }, + "vcpus": 1, + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": "${aws_cloudwatch_log_group.cloud_watch_logs_group.id}", + "awslogs-region": "${data.aws_region.current.name}" + } + } +}) +} + +resource aws_batch_job_definition cxg_job_def { + type = "container" + name = "dp-${var.deployment_stage}-${var.custom_stack_name}-convert" + container_properties = jsonencode({ + "jobRoleArn": "${var.batch_role_arn}", + "image": "${var.image}", + "memory": 16000, + "environment": [ + { + "name": "ARTIFACT_BUCKET", + "value": "${var.artifact_bucket}" + }, + { + "name": "CELLXGENE_BUCKET", + "value": "${var.cellxgene_bucket}" + }, + { + "name": "DATASETS_BUCKET", + "value": "${var.datasets_bucket}" + }, + { + "name": "DEPLOYMENT_STAGE", + "value": "${var.deployment_stage}" + }, + { + "name": "AWS_DEFAULT_REGION", + "value": "${data.aws_region.current.name}" + }, + { + "name": "REMOTE_DEV_PREFIX", + "value": "${var.remote_dev_prefix}" + }, + { + "name": "FRONTEND_URL", + "value": "${var.frontend_url}" + } + ], + "vcpus": 2, "logConfiguration": { "logDriver": "awslogs", "options": { diff --git a/.happy/terraform/modules/batch/outputs.tf b/.happy/terraform/modules/batch/outputs.tf index 9ac37a1728e39..5d5a7ae2beb85 100644 --- a/.happy/terraform/modules/batch/outputs.tf +++ b/.happy/terraform/modules/batch/outputs.tf @@ -8,6 +8,11 @@ output batch_job_definition_no_revision { description = "ARN for the batch job definition" } +output cxg_job_definition_no_revision { + value = 
"arn:aws:batch:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:job-definition/${aws_batch_job_definition.cxg_job_def.name}" + description = "ARN for the cxg batch job definition" +} + output batch_job_log_group { value = aws_cloudwatch_log_group.cloud_watch_logs_group.id description = "Name of the CloudWatch log group for the batch job" diff --git a/.happy/terraform/modules/ecs-stack/main.tf b/.happy/terraform/modules/ecs-stack/main.tf index 6fe174f35949f..4067c1714a3ec 100644 --- a/.happy/terraform/modules/ecs-stack/main.tf +++ b/.happy/terraform/modules/ecs-stack/main.tf @@ -386,6 +386,7 @@ module upload_error_lambda { module upload_sfn { source = "../sfn" job_definition_arn = module.upload_batch.batch_job_definition_no_revision + cxg_definition_arn = module.upload_batch.cxg_job_definition_no_revision job_queue_arn = local.job_queue_arn role_arn = local.sfn_role_arn custom_stack_name = local.custom_stack_name @@ -404,22 +405,6 @@ module upload_sfn { batch_job_log_group = module.upload_batch.batch_job_log_group } -module dataset_submissions_lambda { - source = "../lambda" - image = "${local.lambda_dataset_submissions_repo}:${local.image_tag}" - name = "dataset-submissions" - custom_stack_name = local.custom_stack_name - remote_dev_prefix = local.remote_dev_prefix - deployment_stage = local.deployment_stage - artifact_bucket = local.artifact_bucket - cellxgene_bucket = local.cellxgene_bucket - datasets_bucket = local.datasets_bucket - lambda_execution_role = aws_iam_role.dataset_submissions_lambda_service_role.arn - step_function_arn = module.upload_sfn.step_function_arn - subnets = local.subnets - security_groups = local.security_groups -} - module schema_migration { source = "../schema_migration" image = "${local.upload_image_repo}:${local.image_tag}" @@ -527,22 +512,3 @@ resource "aws_iam_role_policy_attachment" "lambda_step_function_execution_policy role = aws_iam_role.dataset_submissions_lambda_service_role.name policy_arn = aws_iam_policy.lambda_step_function_execution_policy.arn } - -resource "aws_lambda_permission" "allow_dataset_submissions_lambda_execution" { - statement_id = "AllowExecutionFromS3Bucket" - action = "lambda:InvokeFunction" - function_name = module.dataset_submissions_lambda.arn - principal = "s3.amazonaws.com" - source_arn = try(local.secret["s3_buckets"]["dataset_submissions"]["arn"], "") -} - -resource "aws_s3_bucket_notification" "on_dataset_submissions_object_created" { - bucket = local.dataset_submissions_bucket - - lambda_function { - lambda_function_arn = module.dataset_submissions_lambda.arn - events = ["s3:ObjectCreated:*"] - } - - depends_on = [aws_lambda_permission.allow_dataset_submissions_lambda_execution] -} diff --git a/.happy/terraform/modules/schema_migration/main.tf b/.happy/terraform/modules/schema_migration/main.tf index c10f3904556ea..a01d128bb1924 100644 --- a/.happy/terraform/modules/schema_migration/main.tf +++ b/.happy/terraform/modules/schema_migration/main.tf @@ -5,7 +5,6 @@ data aws_caller_identity current {} locals { name = "schema-migration" job_definition_arn = "arn:aws:batch:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:job-definition/dp-${var.deployment_stage}-${var.custom_stack_name}-schema-migration" - swap_job_definition_arn = "${local.job_definition_arn}-swap" } resource aws_cloudwatch_log_group batch_cloud_watch_logs_group { @@ -13,58 +12,6 @@ resource aws_cloudwatch_log_group batch_cloud_watch_logs_group { name = 
"/dp/${var.deployment_stage}/${var.custom_stack_name}/${local.name}-batch" } -resource aws_batch_job_definition schema_migrations_swap { - type = "container" - name = "dp-${var.deployment_stage}-${var.custom_stack_name}-${local.name}-swap" - container_properties = jsonencode({ - jobRoleArn= var.batch_role_arn, - image= var.image, - environment= [ - { - name= "ARTIFACT_BUCKET", - value= var.artifact_bucket - }, - { - name= "DEPLOYMENT_STAGE", - value= var.deployment_stage - }, - { - name= "AWS_DEFAULT_REGION", - value= data.aws_region.current.name - }, - { - name= "REMOTE_DEV_PREFIX", - value= var.remote_dev_prefix - }, - { - name= "DATASETS_BUCKET", - value= var.datasets_bucket - }, - ], - resourceRequirements = [ - { - type= "VCPU", - Value="32" - }, - { - Type="MEMORY", - Value = "256000" - } - ] - linuxParameters= { - maxSwap= 0, - swappiness= 60 - }, - logConfiguration= { - logDriver= "awslogs", - options= { - awslogs-group= aws_cloudwatch_log_group.batch_cloud_watch_logs_group.id, - awslogs-region= data.aws_region.current.name - } - } - }) -} - resource aws_batch_job_definition schema_migrations { type = "container" name = "dp-${var.deployment_stage}-${var.custom_stack_name}-${local.name}" @@ -94,14 +41,14 @@ resource aws_batch_job_definition schema_migrations { }, ], resourceRequirements = [ - { - type= "VCPU", - Value="2" - }, - { - Type="MEMORY", - Value = "2048" - } + { + type= "VCPU", + Value="1" + }, + { + Type="MEMORY", + Value = "8000" + } ] logConfiguration= { logDriver= "awslogs", @@ -385,7 +332,7 @@ resource aws_sfn_state_machine sfn_schema_migration { "Type": "Task", "Resource": "arn:aws:states:::batch:submitJob.sync", "Parameters": { - "JobDefinition": "${resource.aws_batch_job_definition.schema_migrations_swap.arn}", + "JobDefinition": "${resource.aws_batch_job_definition.schema_migrations.arn}", "JobName": "dataset_migration", "JobQueue": "${var.job_queue_arn}", "Timeout": { @@ -458,7 +405,7 @@ resource aws_sfn_state_machine sfn_schema_migration { "Name.$": "$.result.sfn_name", "Input": { "AWS_STEP_FUNCTIONS_STARTED_BY_EXECUTION_ID.$": "$$.Execution.Id", - "url.$": "$.result.uri", + "manifest.$": "$.result.manifest", "dataset_version_id.$": "$.result.dataset_version_id", "collection_version_id.$": "$.result.collection_version_id", "job_queue": "${var.job_queue_arn}" @@ -518,7 +465,7 @@ resource aws_sfn_state_machine sfn_schema_migration { "Key.$": "$.key_name" } }, - "MaxConcurrency": 10, + "MaxConcurrency": 30, "Next": "report", "Catch": [ { diff --git a/.happy/terraform/modules/sfn/main.tf b/.happy/terraform/modules/sfn/main.tf index 21120f185e117..f3a0fc6cc8e2e 100644 --- a/.happy/terraform/modules/sfn/main.tf +++ b/.happy/terraform/modules/sfn/main.tf @@ -1,7 +1,8 @@ -# Same file as https://github.com/chanzuckerberg/single-cell-infra/blob/main/.happy/terraform/modules/sfn/main.tf # This is used for environment (dev, staging, prod) deployments locals { - timeout = 86400 # 24 hours + h5ad_timeout = 86400 # 24 hours + atac_timeout = 86400 # 24 hours + cxg_timeout = 172800 # 48 hours } data aws_region current {} @@ -12,430 +13,284 @@ resource "aws_sfn_state_machine" "state_machine" { definition = </ @@ -198,34 +215,43 @@ def download(self, local_file_name: str): class RegisteredSources: """Manages all of the download sources.""" - _registered: typing.Set[typing.Type[URI]] = set() + _sources: typing.List[typing.Type[URI]] = [] @classmethod def add(cls, parser: typing.Type[URI]): if issubclass(parser, URI): - cls._registered.add(parser) + cls._sources.append(parser) else: 
raise TypeError(f"subclass type {URI.__name__} expected") @classmethod def remove(cls, parser: typing.Type[URI]): - cls._registered.remove(parser) + cls._sources.remove(parser) @classmethod def get(cls) -> typing.Iterable: - return cls._registered + return cls._sources + + @classmethod + def is_empty(cls) -> bool: + return not cls._sources + + @classmethod + def empty(cls): + cls._sources = [] def from_uri(uri: str) -> typing.Optional[URI]: """Given a URI return a object that can be used by the processing container to download data.""" + if RegisteredSources.is_empty(): + # RegisteredSources are processed in the order registered and returns the first match. + RegisteredSources.add(DropBoxURL) + RegisteredSources.add(CXGPublicURL) + RegisteredSources.add(S3URL) + RegisteredSources.add(S3URI) + for source in RegisteredSources.get(): uri_obj = source.validate(uri) if uri_obj: return uri_obj return None - - -# RegisteredSources are processed in the order registered and returns the first match. -RegisteredSources.add(DropBoxURL) -RegisteredSources.add(S3URL) -RegisteredSources.add(S3URI) diff --git a/backend/curation/api/curation-api.yml b/backend/curation/api/curation-api.yml index ca16b07b5dbf5..e78d26f344179 100644 --- a/backend/curation/api/curation-api.yml +++ b/backend/curation/api/curation-api.yml @@ -510,6 +510,67 @@ paths: "410": $ref: "#/components/responses/410" + /v1/collections/{collection_id}/datasets/{dataset_id}/manifest: + get: + summary: Get manifest for a dataset + description: | + Retrieve the manifest that represents the files in the dataset. This can be used to reprocess a dataset. + The public URLs will be used to represent the original raw assets. + operationId: backend.curation.api.v1.curation.collections.collection_id.datasets.dataset_id.manifest.actions.get + tags: + - Collection + parameters: + - $ref: "#/components/parameters/path_collection_id" + - $ref: "#/components/parameters/path_dataset_id" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ingestion_manifest" + "400": + $ref: "#/components/responses/400" + "404": + $ref: "#/components/responses/404" + "410": + $ref: "#/components/responses/410" + put: + summary: Upload dataset via a manifest + description: | + Submit a manifest containing a list of files to be validated as a single dataset submission. + The request body must conform to the manifest schema and all files must be reachable. + operationId: backend.curation.api.v1.curation.collections.collection_id.datasets.dataset_id.manifest.actions.put + tags: + - Collection + security: + - curatorAccess: [] + parameters: + - $ref: "#/components/parameters/path_collection_id" + - $ref: "#/components/parameters/path_dataset_id" + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/ingestion_manifest" + required: + - anndata + responses: + "202": + $ref: "#/components/responses/202" + "400": + $ref: "#/components/responses/400" + "401": + $ref: "#/components/responses/401" + "403": + $ref: "#/components/responses/403" + "405": + $ref: "#/components/responses/405" + "404": + $ref: "#/components/responses/404" + "413": + $ref: "#/components/responses/413" + /v1/collections/{collection_id}/s3-upload-credentials: get: summary: Get credentials for uploading local files @@ -757,6 +818,8 @@ components: enum: - H5AD - RDS + - ATAC_FRAGMENT + - ATAC_INDEX type: string description: the file type of the asset. 
batch_condition: @@ -1639,6 +1702,33 @@ components: - PUBLIC - PRIVATE type: string + ingestion_manifest: # Schema lightly modified from yaml.dump(IngestionManifest.model_json_schema()) + description: | + Manifest of files defining a dataset. + properties: + anndata: + anyOf: + - format: uri + maxLength: 2083 + minLength: 1 + type: string + - format: uri + minLength: 1 + type: string + title: Anndata + atac_fragment: + type: string + anyOf: + - format: uri + maxLength: 2083 + minLength: 1 + type: string + - format: uri + minLength: 1 + type: string + title: ATAC Seq Fragment + title: IngestionManifest + type: object parameters: dataset_visibility: description: | diff --git a/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/actions.py b/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/actions.py index 7af6951d489af..53d00c2181e4e 100644 --- a/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/actions.py +++ b/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/actions.py @@ -1,20 +1,16 @@ -from typing import Tuple - from flask import Response, jsonify, make_response from backend.common.utils.exceptions import MaxFileSizeExceededException from backend.common.utils.http_exceptions import ( ForbiddenHTTPException, - GoneHTTPException, InvalidParametersHTTPException, MethodNotAllowedException, NotFoundHTTPException, TooLargeHTTPException, ) from backend.curation.api.v1.curation.collections.common import ( - get_inferred_collection_version, + _get_collection_and_dataset, reshape_dataset_for_curation_api, - validate_uuid_else_forbidden, ) from backend.layers.auth.user_info import UserInfo from backend.layers.business.exceptions import ( @@ -23,16 +19,13 @@ CollectionUpdateException, DatasetInWrongStatusException, DatasetIsPrivateException, - DatasetIsTombstonedException, DatasetNotFoundException, + InvalidIngestionManifestException, InvalidMetadataException, InvalidURIException, ) from backend.layers.common.entities import ( - CollectionVersionWithDatasets, DatasetArtifactMetadataUpdate, - DatasetId, - DatasetVersion, ) from backend.portal.api.providers import get_business_logic @@ -49,35 +42,6 @@ def get(collection_id: str, dataset_id: str = None): return make_response(jsonify(response_body), 200) -def _get_collection_and_dataset( - collection_id: str, dataset_id: str -) -> Tuple[CollectionVersionWithDatasets, DatasetVersion]: - """ - Get collection and dataset by their ids. Will look up collection by version and canonical id, and dataset by - canonical only - """ - validate_uuid_else_forbidden(collection_id) - validate_uuid_else_forbidden(dataset_id) - collection_version = get_inferred_collection_version(collection_id) - - # Extract the dataset from the dataset list. 
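To make the ingestion_manifest schema above concrete, a minimal sketch of how the backend builds and validates an IngestionManifest, mirroring the ingest_dataset changes later in this diff; the URIs are placeholders.

from pydantic import ValidationError

from backend.layers.common.ingestion_manifest import IngestionManifest

payload = {
    "anndata": "s3://example-bucket/raw.h5ad",                 # placeholder URI
    "atac_fragment": "s3://example-bucket/fragments.tsv.bgz",  # placeholder URI
}

try:
    # Old-style callers still pass a bare URL string; new-style callers pass a dict.
    manifest = IngestionManifest(anndata=payload) if isinstance(payload, str) else IngestionManifest(**payload)
except ValidationError as e:
    # ingest_dataset wraps this in InvalidIngestionManifestException
    print(e.errors())
else:
    # The serialized manifest is what gets handed to the step function as its input.
    print(manifest.model_dump_json())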
- dataset_version = None - for dataset in collection_version.datasets: - if dataset.dataset_id.id == dataset_id: - dataset_version = dataset - break - if dataset.version_id.id == dataset_id: - raise ForbiddenHTTPException from None - if dataset_version is None: - try: - get_business_logic().get_dataset_version_from_canonical(DatasetId(dataset_id), get_tombstoned=True) - except DatasetIsTombstonedException: - raise GoneHTTPException() from None - raise NotFoundHTTPException() from None - - return collection_version, dataset_version - - def delete(token_info: dict, collection_id: str, dataset_id: str, delete_published: bool = False): business_logic = get_business_logic() user_info = UserInfo(token_info) @@ -136,6 +100,8 @@ def put(collection_id: str, dataset_id: str, body: dict, token_info: dict): raise InvalidParametersHTTPException(detail="The dropbox shared link is invalid.") from None except MaxFileSizeExceededException: raise TooLargeHTTPException() from None + except InvalidIngestionManifestException as e: + raise InvalidParametersHTTPException(detail=e.message) from None except DatasetInWrongStatusException: raise MethodNotAllowedException( detail="Submission failed. A dataset cannot be updated while a previous update for the same dataset " diff --git a/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/manifest/actions.py b/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/manifest/actions.py new file mode 100644 index 0000000000000..5fb6ba69eab3b --- /dev/null +++ b/backend/curation/api/v1/curation/collections/collection_id/datasets/dataset_id/manifest/actions.py @@ -0,0 +1,101 @@ +from flask import Response, jsonify, make_response + +from backend.common.utils.exceptions import MaxFileSizeExceededException +from backend.common.utils.http_exceptions import ( + ForbiddenHTTPException, + InvalidParametersHTTPException, + MethodNotAllowedException, + NotFoundHTTPException, + TooLargeHTTPException, +) +from backend.curation.api.v1.curation.collections.common import ( + _get_collection_and_dataset, +) +from backend.layers.auth.user_info import UserInfo +from backend.layers.business.business import BusinessLogic +from backend.layers.business.exceptions import ( + CollectionIsPublishedException, + CollectionNotFoundException, + DatasetInWrongStatusException, + DatasetNotFoundException, + InvalidIngestionManifestException, + InvalidURIException, +) +from backend.layers.common.entities import ( + DatasetArtifactType, + DatasetVersion, +) +from backend.portal.api.providers import get_business_logic + + +def get_single_artifact_permanent_url( + dataset_version: DatasetVersion, artifact_type: DatasetArtifactType, required: bool = False +) -> str | None: + """Find exactly one artifact of the given type, then return it's canonical URI. + + If `required` is True and no artifact is found, raises ValueError. + If more than one is found, always raises ValueError. 
+ """ + artifacts = dataset_version.artifacts + matches = [a for a in artifacts if a.type == artifact_type] + + if len(matches) > 1: + raise ValueError(f"Multiple '{artifact_type}' artifacts found.") + + if not matches and required: + raise ValueError(f"No '{artifact_type}' artifact found.") + + if matches: + return BusinessLogic.generate_permanent_url(dataset_version, matches[0].id, artifact_type) + + +def get(collection_id: str, dataset_id: str = None): + _, dataset_version = _get_collection_and_dataset(collection_id, dataset_id) + + response_body = {} + for key, artifact_type in [ + ("anndata", DatasetArtifactType.H5AD), + ("atac_fragment", DatasetArtifactType.ATAC_FRAGMENT), + ]: + if uri := get_single_artifact_permanent_url(dataset_version, artifact_type): + response_body[key] = uri + + return make_response(jsonify(response_body), 200) + + +def put(collection_id: str, dataset_id: str, body: dict, token_info: dict): + # TODO: deduplicate from ApiCommon. We need to settle the class/module level debate before can do that + business_logic = get_business_logic() + + collection_version, dataset_version = _get_collection_and_dataset(collection_id, dataset_id) + + if not UserInfo(token_info).is_user_owner_or_allowed(collection_version.owner): + raise ForbiddenHTTPException() + + try: + business_logic.ingest_dataset( + collection_version.version_id, + body, + None, + None if dataset_id is None else dataset_version.version_id, + ) + return Response(status=202) + except CollectionNotFoundException: + raise ForbiddenHTTPException() from None + except CollectionIsPublishedException: + raise ForbiddenHTTPException() from None + except DatasetNotFoundException: + raise NotFoundHTTPException() from None + except InvalidURIException: + raise InvalidParametersHTTPException(detail="The dropbox shared link is invalid.") from None + except InvalidIngestionManifestException as e: + raise InvalidParametersHTTPException(detail=e.message) from None + except MaxFileSizeExceededException: + raise TooLargeHTTPException() from None + except DatasetInWrongStatusException: + raise MethodNotAllowedException( + detail="Submission failed. A dataset cannot be updated while a previous update for the same dataset " + "is in progress. Please cancel the current submission by deleting the dataset, or wait until " + "the submission has finished processing." 
+ ) from None + # End of duplicate block diff --git a/backend/curation/api/v1/curation/collections/common.py b/backend/curation/api/v1/curation/collections/common.py index 30f54e20e43b0..55c08ec2af798 100644 --- a/backend/curation/api/v1/curation/collections/common.py +++ b/backend/curation/api/v1/curation/collections/common.py @@ -12,6 +12,7 @@ ) from backend.layers.auth.user_info import UserInfo from backend.layers.business.business import BusinessLogic +from backend.layers.business.exceptions import DatasetIsTombstonedException from backend.layers.common.entities import ( CollectionId, CollectionVersion, @@ -33,13 +34,47 @@ from backend.portal.api.explorer_url import generate as generate_explorer_url from backend.portal.api.providers import get_business_logic -allowed_dataset_asset_types = (DatasetArtifactType.H5AD, DatasetArtifactType.RDS) +allowed_dataset_asset_types = ( + DatasetArtifactType.H5AD, + DatasetArtifactType.RDS, + DatasetArtifactType.ATAC_FRAGMENT, + DatasetArtifactType.ATAC_INDEX, +) def get_collections_base_url(): return CorporaConfig().collections_base_url +def _get_collection_and_dataset( + collection_id: str, dataset_id: str +) -> Tuple[CollectionVersionWithDatasets, DatasetVersion]: + """ + Get collection and dataset by their ids. Will look up collection by version and canonical id, and dataset by + canonical only + """ + validate_uuid_else_forbidden(collection_id) + validate_uuid_else_forbidden(dataset_id) + collection_version = get_inferred_collection_version(collection_id) + + # Extract the dataset from the dataset list. + dataset_version = None + for dataset in collection_version.datasets: + if dataset.dataset_id.id == dataset_id: + dataset_version = dataset + break + if dataset.version_id.id == dataset_id: + raise ForbiddenHTTPException from None + if dataset_version is None: + try: + get_business_logic().get_dataset_version_from_canonical(DatasetId(dataset_id), get_tombstoned=True) + except DatasetIsTombstonedException: + raise GoneHTTPException() from None + raise NotFoundHTTPException() from None + + return collection_version, dataset_version + + def extract_dataset_assets(dataset_version: DatasetVersion): asset_list = list() for asset in dataset_version.artifacts: @@ -48,7 +83,7 @@ def extract_dataset_assets(dataset_version: DatasetVersion): filesize = get_business_logic().s3_provider.get_file_size(asset.uri) if filesize is None: filesize = -1 - url = BusinessLogic.generate_permanent_url(dataset_version.version_id, asset.type) + url = BusinessLogic.generate_permanent_url(dataset_version, asset.id, asset.type) result = { "filesize": filesize, "filetype": asset.type.upper(), diff --git a/backend/database/versions/08_92c817dddc7d_new_atac_artifact_enums.py b/backend/database/versions/08_92c817dddc7d_new_atac_artifact_enums.py new file mode 100644 index 0000000000000..452cbb762c077 --- /dev/null +++ b/backend/database/versions/08_92c817dddc7d_new_atac_artifact_enums.py @@ -0,0 +1,38 @@ +"""new-atac-artifact-enums + +Changing DatasetArtifactTable.type from an enum to a string to allow for new ATAC artifact types. + +Revision ID: 08_92c817dddc7d +Revises: 07_e31a29561f38 +Create Date: 2025-03-06 12:31:12.249307 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "08_92c817dddc7d" +down_revision = "07_e31a29561f38" +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column( + "DatasetArtifact", + "type", + type_=sa.String(), + existing_type=sa.Enum("CXG", "RDS", "H5AD", "RAW_H5AD", name="datasetartifacttype"), + schema="persistence_schema", + ) + + +def downgrade(): + op.alter_column( + "DatasetArtifact", + "type", + type_=sa.Enum("CXG", "RDS", "H5AD", "RAW_H5AD", name="datasetartifacttype"), + existing_type=sa.String(), + schema="persistence_schema", + ) diff --git a/backend/layers/business/business.py b/backend/layers/business/business.py index 69e564cc02fdd..13ead113aa6a1 100644 --- a/backend/layers/business/business.py +++ b/backend/layers/business/business.py @@ -6,6 +6,8 @@ from functools import reduce from typing import Dict, Iterable, List, Optional, Set, Tuple +from pydantic import ValidationError + from backend.common.constants import DATA_SUBMISSION_POLICY_VERSION from backend.common.corpora_config import CorporaConfig from backend.common.doi import doi_curie_from_link @@ -38,6 +40,7 @@ DatasetNotFoundException, DatasetUpdateException, DatasetVersionNotFoundException, + InvalidIngestionManifestException, InvalidURIException, MaxFileSizeExceededException, NoPreviousCollectionVersionException, @@ -46,6 +49,7 @@ from backend.layers.common import validation from backend.layers.common.cleanup import sanitize, sanitize_dataset_artifact_metadata_update from backend.layers.common.entities import ( + ARTIFACT_TO_EXTENSION, CanonicalCollection, CollectionId, CollectionLinkType, @@ -77,6 +81,7 @@ from backend.layers.common.helpers import ( get_published_at_and_collection_version_id_else_not_found, ) +from backend.layers.common.ingestion_manifest import IngestionManifest from backend.layers.common.regex import S3_URI_REGEX from backend.layers.persistence.persistence_interface import DatabaseProviderInterface from backend.layers.thirdparty.batch_job_provider import BatchJobProviderInterface @@ -114,12 +119,26 @@ def __init__( super().__init__() @staticmethod - def generate_permanent_url(dataset_version_id: DatasetVersionId, asset_type: DatasetArtifactType): + def generate_permanent_url( + dataset_version: DatasetVersion, artifact_id: DatasetArtifactId, asset_type: DatasetArtifactType + ) -> str: """ Return the permanent URL for the given asset. """ + if asset_type in [DatasetArtifactType.ATAC_INDEX, DatasetArtifactType.ATAC_FRAGMENT]: + fmt_str = "{}/{}-fragment.{}" + else: + fmt_str = "{}/{}.{}" + + if asset_type == DatasetArtifactType.ATAC_INDEX: + entity_id = [a for a in dataset_version.artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0].id + elif asset_type == DatasetArtifactType.ATAC_FRAGMENT: + entity_id = artifact_id + else: + entity_id = dataset_version.version_id + base_url = CorporaConfig().dataset_assets_base_url - return f"{base_url}/{dataset_version_id.id}.{asset_type}" + return fmt_str.format(base_url, entity_id.id, ARTIFACT_TO_EXTENSION[asset_type]) @staticmethod def generate_dataset_citation( @@ -495,7 +514,8 @@ def _assert_dataset_version_processing_status( dataset = self.database_provider.get_dataset_version(dataset_version_id) if dataset.status.processing_status != expected_status: raise DatasetInWrongStatusException( - f"Dataset {dataset_version_id.id} processing status must be {expected_status.name} but is {dataset.status.processing_status}." + f"Dataset {dataset_version_id.id} processing status must be {expected_status.name} but is " + f"{dataset.status.processing_status}." 
) return dataset @@ -537,11 +557,30 @@ def create_empty_dataset_version_for_current_dataset( return new_dataset_version + def is_already_ingested(self, uri): + return str(uri).startswith(CorporaConfig().dataset_assets_base_url) + + def get_ingestion_manifest(self, dataset_version_id: DatasetVersionId) -> IngestionManifest: + dataset_version = self.database_provider.get_dataset_version(dataset_version_id) + if dataset_version is None: + raise DatasetNotFoundException(f"Dataset {dataset_version_id.id} not found") + raw_h5ad_uri = self.get_artifact_type_from_dataset(dataset_version, DatasetArtifactType.RAW_H5AD) + atac_fragment_uri = self.get_artifact_type_from_dataset(dataset_version, DatasetArtifactType.ATAC_FRAGMENT) + return IngestionManifest(anndata=raw_h5ad_uri, atac_fragment=atac_fragment_uri) + + def get_artifact_type_from_dataset( + self, dataset_version: DatasetVersion, artifact_type: DatasetArtifactType + ) -> Optional[str]: + uris = [artifact.uri for artifact in dataset_version.artifacts if artifact.type == artifact_type] + if not uris: + return None + return uris[0] + # TODO: Alternatives: 1) return DatasetVersion 2) Return a new class def ingest_dataset( self, collection_version_id: CollectionVersionId, - url: str, + url: dict | str, # TODO: change to manifest file_size: Optional[int], current_dataset_version_id: Optional[DatasetVersionId], start_step_function: bool = True, @@ -550,25 +589,70 @@ def ingest_dataset( Creates a canonical dataset and starts its ingestion by invoking the step function If `size` is not provided, it will be inferred automatically """ + # Convert old style input to new style + try: + manifest = IngestionManifest(anndata=url) if isinstance(url, str) else IngestionManifest(**url) + except ValidationError as e: + raise InvalidIngestionManifestException("Ingestion manifest is invalid.", errors=e.errors()) from e + logger.info( { "message": "ingesting dataset", "collection_version_id": collection_version_id, - "url": url, + "manifest": manifest, "current_dataset_version_id": current_dataset_version_id, } ) - if not self.uri_provider.validate(url): - raise InvalidURIException(f"Trying to upload invalid URI: {url}") + # Validate the URIs + # TODO: This should be done in the IngestionManifest class + for key, _url in manifest.model_dump(exclude_none=True).items(): + _url = str(_url) + if not self.uri_provider.validate(_url): + raise InvalidURIException(f"Trying to upload invalid URI: {_url}") + if not self.is_already_ingested(_url): + continue + if not current_dataset_version_id: + raise InvalidIngestionManifestException( + message="Cannot ingest public datasets without a current dataset version" + ) + if key == "anndata": + dataset_version_id, extension = _url.split("/")[-1].split(".", maxsplit=1) + if extension != ARTIFACT_TO_EXTENSION[DatasetArtifactType.H5AD]: + raise InvalidIngestionManifestException(message=f"{_url} is not an h5ad file") + previous_dv = self.database_provider.get_dataset_version(DatasetVersionId(dataset_version_id)) + if previous_dv is None: + raise InvalidIngestionManifestException( + message=f"{_url} refers to existing dataset, but that dataset could not be found." 
+ ) + all_dvs = self.database_provider.get_all_versions_for_dataset(previous_dv.dataset_id) + if current_dataset_version_id not in [dv.version_id for dv in all_dvs]: + raise InvalidIngestionManifestException(message=f"{_url} is not a part of the canonical dataset") + manifest.anndata = [a for a in previous_dv.artifacts if a.type == DatasetArtifactType.RAW_H5AD][0].uri + + if key == "atac_fragment": + file_name, extension = _url.split("/")[-1].split(".", 1) + artifact_id = file_name.rsplit("-", 1)[0] + if extension != ARTIFACT_TO_EXTENSION[DatasetArtifactType.ATAC_FRAGMENT]: + raise InvalidIngestionManifestException(message=f"{_url} is not an atac_fragments file") + artifact = self.database_provider.get_dataset_artifacts([DatasetArtifactId(artifact_id)]) + if not len(artifact): + raise InvalidIngestionManifestException(message=f"{_url} atac_fragments not found") + dataset_id = self.get_dataset_version(current_dataset_version_id).dataset_id + if not self.database_provider.check_artifact_is_part_of_dataset(dataset_id, artifact[0].id): + raise InvalidIngestionManifestException( + message=f"{_url} atac_fragments is not a part of the canonical dataset" + ) if file_size is None: - file_info = self.uri_provider.get_file_info(url) + file_info = self.uri_provider.get_file_info(str(manifest.anndata)) file_size = file_info.size max_file_size_gb = CorporaConfig().upload_max_file_size_gb * 2**30 if file_size is not None and file_size > max_file_size_gb: - raise MaxFileSizeExceededException(f"{url} exceeds the maximum allowed file size of {max_file_size_gb} Gb") + raise MaxFileSizeExceededException( + f"{manifest.anndata} exceeds the maximum allowed file size of {max_file_size_gb} Gb" + ) # Ensure that the collection exists and is not published collection = self._assert_collection_version_unpublished(collection_version_id) @@ -625,7 +709,9 @@ def ingest_dataset( # Starts the step function process if start_step_function: - self.step_function_provider.start_step_function(collection_version_id, new_dataset_version.version_id, url) + self.step_function_provider.start_step_function( + collection_version_id, new_dataset_version.version_id, manifest.model_dump_json() + ) return (new_dataset_version.version_id, new_dataset_version.dataset_id) @@ -712,7 +798,8 @@ def get_private_collection_versions_with_datasets( self, owner: str = None ) -> List[CollectionVersionWithPrivateDatasets]: """ - Returns collection versions with their datasets for private collections. Only private collections with datasets, or + Returns collection versions with their datasets for private collections. Only private collections with + datasets, or unpublished revisions with new or updated datasets are returned; unpublished revisions with no new datasets, and no changed datasets are not returned. @@ -789,14 +876,15 @@ def get_dataset_artifact_download_data( """ Returns data required for download: file size and permanent URL. 
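To make the already-ingested URI checks in ingest_dataset above concrete, here is roughly how a public asset URL is taken apart during manifest validation; the base URL and IDs below are invented.

# h5ad assets are keyed by dataset version id: {dataset_assets_base_url}/{dataset_version_id}.h5ad
anndata_url = "https://datasets.example.test/11111111-2222-3333-4444-555555555555.h5ad"
dataset_version_id, extension = anndata_url.split("/")[-1].split(".", maxsplit=1)
# dataset_version_id -> "11111111-...", extension -> "h5ad"; any other extension is rejected

# ATAC fragments are keyed by the fragment artifact id: {dataset_assets_base_url}/{artifact_id}-fragment.tsv.bgz
fragment_url = "https://datasets.example.test/66666666-7777-8888-9999-000000000000-fragment.tsv.bgz"
file_name, extension = fragment_url.split("/")[-1].split(".", 1)
artifact_id = file_name.rsplit("-", 1)[0]
# artifact_id -> "66666666-...", extension -> "tsv.bgz"; the referenced artifact must already belong to
# the canonical dataset being replaced, otherwise the manifest is rejected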
""" - artifacts = self.get_dataset_artifacts(dataset_version_id) + dataset_version = self.database_provider.get_dataset_version(dataset_version_id) + artifacts = dataset_version.artifacts artifact = next((a for a in artifacts if a.id == artifact_id), None) if not artifact: raise ArtifactNotFoundException(f"Artifact {artifact_id} not found in dataset {dataset_version_id}") file_size = self.s3_provider.get_file_size(artifact.uri) - url = self.generate_permanent_url(dataset_version_id, artifact.type) + url = self.generate_permanent_url(dataset_version, artifact.id, artifact.type) return DatasetArtifactDownloadData(file_size, url) @@ -814,7 +902,6 @@ def update_dataset_version_status( validation_message: Optional[str] = None, ) -> None: """ - TODO: split into two method, one for updating validation_message, and the other statuses. Updates the status of a dataset version. status_key can be one of: [upload, validation, cxg, rds, h5ad, processing] """ @@ -836,6 +923,10 @@ def update_dataset_version_status( self.database_provider.update_dataset_conversion_status( dataset_version_id, "h5ad_status", new_dataset_status ) + elif status_key == DatasetStatusKey.ATAC and isinstance(new_dataset_status, DatasetConversionStatus): + self.database_provider.update_dataset_conversion_status( + dataset_version_id, "atac_status", new_dataset_status + ) else: raise DatasetUpdateException( f"Invalid status update for dataset {dataset_version_id}: cannot set {status_key} to " @@ -846,18 +937,21 @@ def update_dataset_version_status( self.database_provider.update_dataset_validation_message(dataset_version_id, validation_message) def add_dataset_artifact( - self, dataset_version_id: DatasetVersionId, artifact_type: str, artifact_uri: str + self, + dataset_version_id: DatasetVersionId, + artifact_type: DatasetArtifactType, + artifact_uri: str, + artifact_id: Optional[DatasetArtifactId] = None, ) -> DatasetArtifactId: """ Registers an artifact to a dataset version. 
""" - - # TODO: we should probably validate that artifact_uri is a valid S3 URI - - if artifact_type not in [artifact.value for artifact in DatasetArtifactType]: + if not isinstance(artifact_type, DatasetArtifactType): raise DatasetIngestException(f"Wrong artifact type for {dataset_version_id}: {artifact_type}") - return self.database_provider.add_dataset_artifact(dataset_version_id, artifact_type, artifact_uri) + return self.database_provider.create_dataset_artifact( + dataset_version_id, artifact_type, artifact_uri, artifact_id + ) def update_dataset_artifact(self, artifact_id: DatasetArtifactId, artifact_uri: str) -> None: """ @@ -920,15 +1014,29 @@ def delete_collection_version(self, collection_version: CollectionVersionWithDat # Collection was never published; delete CollectionTable row self.database_provider.delete_unpublished_collection(collection_version.collection_id) - def delete_dataset_versions_from_public_bucket(self, dataset_version_ids: List[str]) -> List[str]: + def get_atac_fragment_uris_from_dataset_version(self, dataset_version: DatasetVersion) -> List[str]: + """ + get all atac fragment files associated with a dataset version from the public bucket + """ + object_keys = set() + object_keys.update( + [a.uri.rsplit("/", 1)[-1] for a in dataset_version.artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT] + ) + object_keys.update( + [a.uri.rsplit("/", 1)[-1] for a in dataset_version.artifacts if a.type == DatasetArtifactType.ATAC_INDEX] + ) + return list(object_keys) + + def delete_dataset_versions_from_public_bucket(self, dataset_versions: List[DatasetVersion]) -> List[str]: rdev_prefix = os.environ.get("REMOTE_DEV_PREFIX", "").strip("/") object_keys = set() - for d_v_id in dataset_version_ids: + for d_v in dataset_versions: for file_type in ("h5ad", "rds"): - dataset_version_s3_object_key = f"{d_v_id}.{file_type}" + dataset_version_s3_object_key = f"{d_v.version_id}.{file_type}" if rdev_prefix: dataset_version_s3_object_key = f"{rdev_prefix}/{dataset_version_s3_object_key}" object_keys.add(dataset_version_s3_object_key) + object_keys.update(self.get_atac_fragment_uris_from_dataset_version(d_v)) try: self.s3_provider.delete_files(os.getenv("DATASETS_BUCKET"), list(object_keys)) except S3DeleteException as e: @@ -940,7 +1048,7 @@ def delete_all_dataset_versions_from_public_bucket_for_collection(self, collecti Delete all associated publicly-accessible Datasets in s3 """ dataset_versions = self.database_provider.get_all_dataset_versions_for_collection(collection_id) - return self.delete_dataset_versions_from_public_bucket([dv.version_id.id for dv in dataset_versions]) + return self.delete_dataset_versions_from_public_bucket(dataset_versions) def get_unpublished_dataset_versions(self, dataset_id: DatasetId) -> List[DatasetVersion]: """ @@ -971,7 +1079,7 @@ def delete_dataset_versions(self, dataset_versions: List[DatasetVersion]) -> Non self.database_provider.delete_dataset_versions(dataset_versions) def delete_dataset_version_assets(self, dataset_versions: List[DatasetVersion]) -> None: - self.delete_dataset_versions_from_public_bucket([dv.version_id.id for dv in dataset_versions]) + self.delete_dataset_versions_from_public_bucket(dataset_versions) self.delete_artifacts(reduce(lambda artifacts, dv: artifacts + dv.artifacts, dataset_versions, [])) def tombstone_collection(self, collection_id: CollectionId) -> None: @@ -1016,7 +1124,7 @@ def resurrect_collection(self, collection_id: CollectionId) -> None: # Restore s3 public assets for dv_id in dataset_versions_to_restore: 
- for ext in (DatasetArtifactType.H5AD, DatasetArtifactType.RDS): + for ext in [ARTIFACT_TO_EXTENSION[x] for x in (DatasetArtifactType.H5AD, DatasetArtifactType.RDS)]: object_key = f"{dv_id}.{ext}" self.s3_provider.restore_object(os.getenv("DATASETS_BUCKET"), object_key) @@ -1057,29 +1165,32 @@ def publish_collection_version( if canonical_datasets != version_datasets: has_dataset_revisions = True - # Check Crossref for updates in publisher metadata since last publish of revision, or since private collection was created. - # Raise exception if DOI has moved from pre-print to published, forcing curators to re-publish the collection once corresponding + # Check Crossref for updates in publisher metadata since last publish of revision, or since private + # collection was created. + # Raise exception if DOI has moved from pre-print to published, forcing curators to re-publish the collection + # once corresponding # artifacts update is complete. last_action_at = date_of_last_publish if is_revision else version.created_at doi_update = self._update_crossref_metadata(version, last_action_at) if doi_update: raise CollectionPublishException( [ - f"DOI was updated from {doi_update[0]} to {doi_update[1]} requiring updates to corresponding artifacts. " + f"DOI was updated from {doi_update[0]} to {doi_update[1]} requiring updates to corresponding " + f"artifacts. " "Retry publish once artifact updates are complete." ] ) # Finalize Collection publication and delete any tombstoned assets is_auto_version = version.is_auto_version - dataset_version_ids_to_delete_from_s3 = self.database_provider.finalize_collection_version( + dataset_versions_to_delete_from_s3 = self.database_provider.finalize_collection_version( version.collection_id, version_id, schema_version, data_submission_policy_version, update_revised_at=has_dataset_revisions, ) - self.delete_dataset_versions_from_public_bucket(dataset_version_ids_to_delete_from_s3) + self.delete_dataset_versions_from_public_bucket(dataset_versions_to_delete_from_s3) # Handle cleanup of unpublished versions versions_to_keep = {dv.version_id.id for dv in version.datasets} @@ -1195,11 +1306,14 @@ def _update_crossref_metadata( """ Call Crossref for the latest publisher metadata and: - if a DOI has moved from pre-print to published, trigger update to collection (and artifacts), otherwise, - - if Crossref has been updated since last publish of revision or since private collection was created, update collection + - if Crossref has been updated since last publish of revision or since private collection was created, + update collection version publisher metadata. - :param collection_version_id: The collection version (either a revision or a private collection) to check publisher updates for. - :param last_action_at: The originally published at or revised at date of revision, or the created at date of a private collection. + :param collection_version_id: The collection version (either a revision or a private collection) to check + publisher updates for. + :param last_action_at: The originally published at or revised at date of revision, or the created at date of + a private collection. :return: Tuple of current DOI and DOI returned from Crossref if DOI has changed, otherwise None. """ # Get the DOI from the collection version metadata; exit if no DOI. @@ -1216,7 +1330,6 @@ def _update_crossref_metadata( # Handle change in publisher metadata from pre-print to published. 
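For context on the tombstoning path handled above, a hedged sketch of how the public-bucket object keys for one dataset version are assembled before deletion, following delete_dataset_versions_from_public_bucket and get_atac_fragment_uris_from_dataset_version earlier in this diff; the helper name is illustrative and not part of this change.

from backend.layers.common.entities import DatasetArtifactType, DatasetVersion

def public_object_keys(dataset_version: DatasetVersion, rdev_prefix: str = "") -> list[str]:
    """Illustrative only: collect the public-bucket keys tied to one dataset version."""
    keys = set()
    for file_type in ("h5ad", "rds"):
        key = f"{dataset_version.version_id}.{file_type}"
        keys.add(f"{rdev_prefix}/{key}" if rdev_prefix else key)
    # ATAC fragment and index objects are named after the fragment artifact id,
    # so their keys are taken from the basenames of the artifact URIs.
    for artifact in dataset_version.artifacts:
        if artifact.type in (DatasetArtifactType.ATAC_FRAGMENT, DatasetArtifactType.ATAC_INDEX):
            keys.add(artifact.uri.rsplit("/", 1)[-1])
    return list(keys)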
crossref_doi = f"https://doi.org/{crossref_doi_curie}" if crossref_doi != link_doi.uri: - # Set the DOI in the collection version metadata links to be the returned DOI and update collection # version (subsequently triggering update of artifacts). updated_links = [Link(link.name, link.type, crossref_doi) if link.type == "DOI" else link for link in links] diff --git a/backend/layers/business/business_interface.py b/backend/layers/business/business_interface.py index 4031793b9b1a7..2266f8c203d38 100644 --- a/backend/layers/business/business_interface.py +++ b/backend/layers/business/business_interface.py @@ -24,6 +24,7 @@ DatasetVersionId, PublishedDatasetVersion, ) +from backend.layers.common.ingestion_manifest import IngestionManifest class BusinessLogicInterface: @@ -70,10 +71,16 @@ def create_collection( ) -> CollectionVersion: pass + def get_atac_fragment_uris_from_dataset_version(self, dataset_version: DatasetVersion) -> List[str]: + """ + get all atac fragment files associated with a dataset version from the public bucket + """ + pass + def delete_artifacts(self, artifacts: List[DatasetArtifact]) -> None: pass - def delete_dataset_versions_from_public_bucket(self, dataset_version_ids: List[str]) -> List[str]: + def delete_dataset_versions_from_public_bucket(self, dataset_versions: List[DatasetVersion]) -> List[str]: pass def delete_all_dataset_versions_from_public_bucket_for_collection(self, collection_id: CollectionId) -> List[str]: @@ -107,6 +114,9 @@ def delete_collection_version(self, collection_version: CollectionVersionWithDat def publish_collection_version(self, version_id: CollectionVersionId, data_submission_policy_version: str) -> None: pass + def get_ingestion_manifest(self, dataset_version_id: DatasetVersionId) -> IngestionManifest: + pass + def ingest_dataset( self, collection_version_id: CollectionVersionId, @@ -157,10 +167,17 @@ def update_dataset_version_status( pass def add_dataset_artifact( - self, dataset_version_id: DatasetVersionId, artifact_type: str, artifact_uri: str + self, + dataset_version_id: DatasetVersionId, + artifact_type: str, + artifact_uri: str, + artifact_id: Optional[DatasetArtifactId] = None, ) -> DatasetArtifactId: pass + def update_dataset_artifact(self, artifact_id: DatasetArtifactId, artifact_uri: str) -> None: + pass + def get_dataset_status(self, dataset_version_id: DatasetVersionId) -> DatasetStatus: pass diff --git a/backend/layers/business/exceptions.py b/backend/layers/business/exceptions.py index f93348b6ba403..1adc91b2abc16 100644 --- a/backend/layers/business/exceptions.py +++ b/backend/layers/business/exceptions.py @@ -85,6 +85,17 @@ class InvalidURIException(DatasetIngestException): """ +class InvalidIngestionManifestException(DatasetIngestException): + """ + Raised when trying to ingest a dataset with an invalid ingestion manifest + """ + + def __init__(self, message: str, errors: Optional[List[str]] = None) -> None: + self.errors: List[dict] = errors if errors else [] + self.message = message + super().__init__() + + class MaxFileSizeExceededException(DatasetIngestException): """ Raised when trying to ingest a dataset that is too big diff --git a/backend/layers/common/entities.py b/backend/layers/common/entities.py index 61227dae938b0..55a0dfdfe8a00 100644 --- a/backend/layers/common/entities.py +++ b/backend/layers/common/entities.py @@ -16,6 +16,7 @@ class DatasetStatusKey(str, Enum): CXG = "cxg" RDS = "rds" H5AD = "h5ad" + ATAC = "atac" PROCESSING = "processing" @@ -58,6 +59,7 @@ class 
DatasetValidationStatus(DatasetStatusGeneric, Enum): class DatasetConversionStatus(DatasetStatusGeneric, Enum): NA = "NA" + COPIED = "COPIED" # when the artifact is copied from another dataset version CONVERTING = "CONVERTING" CONVERTED = "CONVERTED" UPLOADING = "UPLOADING" @@ -80,6 +82,18 @@ class DatasetArtifactType(str, Enum): H5AD = "h5ad" RDS = "rds" CXG = "cxg" + ATAC_FRAGMENT = "atac_fragment" + ATAC_INDEX = "atac_index" + + +ARTIFACT_TO_EXTENSION = { + DatasetArtifactType.RAW_H5AD: "h5ad", + DatasetArtifactType.H5AD: "h5ad", + DatasetArtifactType.RDS: "rds", + DatasetArtifactType.CXG: "cxg", + DatasetArtifactType.ATAC_FRAGMENT: "tsv.bgz", + DatasetArtifactType.ATAC_INDEX: "tsv.bgz.tbi", +} class Visibility(Enum): @@ -104,12 +118,13 @@ class DatasetStatus: cxg_status: Optional[DatasetConversionStatus] rds_status: Optional[DatasetConversionStatus] h5ad_status: Optional[DatasetConversionStatus] - processing_status: Optional[DatasetProcessingStatus] + atac_status: Optional[DatasetConversionStatus] = None + processing_status: Optional[DatasetProcessingStatus] = None validation_message: Optional[str] = None @staticmethod def empty(): - return DatasetStatus(None, None, None, None, None, None) + return DatasetStatus(*[None] * 7) @dataclass @@ -152,6 +167,10 @@ class DatasetArtifact: def get_file_name(self): return urlparse(self.uri).path.split("/")[-1] + @property + def extension(self): + return ARTIFACT_TO_EXTENSION[self.type] + @dataclass class OntologyTermId: diff --git a/backend/layers/common/ingestion_manifest.py b/backend/layers/common/ingestion_manifest.py new file mode 100644 index 0000000000000..69332b61d3d66 --- /dev/null +++ b/backend/layers/common/ingestion_manifest.py @@ -0,0 +1,38 @@ +import re +from typing import Optional, Union + +from pydantic import AnyUrl, BaseModel, HttpUrl + + +class S3Url(AnyUrl): + """Pydantic Model for S3 URLs + + Copied from https://gist.github.com/rajivnarayan/c38f01b89de852b3e7d459cfde067f3f + # TODO consolidate with backend/common/utils/dl_sources/uri.py + """ + + allowed_schemes = {"s3"} + pattern = re.compile( + r"^s3://" + r"(?=[a-z0-9])" # Bucket name must start with a letter or digit + r"(?!(^xn--|sthree-|sthree-configurator|.+-s3alias$))" # Bucket name must not start with xn--, sthree-, sthree-configurator or end with -s3alias + r"(?!.*\.\.)" # Bucket name must not contain two adjacent periods + r"[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]" # Bucket naming constraints + r"(? List[str]: + ) -> List[DatasetVersion]: """ Finalizes a collection version. Returns a list of ids for all Dataset Versions for any/all tombstoned Datasets. 
""" @@ -708,18 +708,19 @@ def finalize_collection_version( dataset_ids_to_tombstone.append(previous_d_id) # get all dataset versions for the datasets that are being tombstoned - dataset_version_ids_to_delete_from_s3 = [] + dataset_versions_to_delete_from_s3 = [] if dataset_ids_to_tombstone: tombstone_dataset_statement = ( update(DatasetTable).where(DatasetTable.id.in_(dataset_ids_to_tombstone)).values(tombstone=True) ) session.execute(tombstone_dataset_statement) - dataset_all_version_ids = ( - session.query(DatasetVersionTable.id) + dataset_all_versions = [ + self._hydrate_dataset_version(dv) + for dv in session.query(DatasetVersionTable) .filter(DatasetVersionTable.dataset_id.in_(dataset_ids_to_tombstone)) .all() - ) - dataset_version_ids_to_delete_from_s3.extend(str(dv_id) for dv_id in dataset_all_version_ids) + ] + dataset_versions_to_delete_from_s3.extend(dataset_all_versions) # update dataset versions for datasets that are not being tombstoned dataset_version_ids = session.query(CollectionVersionTable.datasets).filter_by(id=version_id.id).one()[0] @@ -733,7 +734,7 @@ def finalize_collection_version( if dataset.published_at is None: dataset.published_at = published_at - return dataset_version_ids_to_delete_from_s3 + return dataset_versions_to_delete_from_s3 def get_dataset_version(self, dataset_version_id: DatasetVersionId, get_tombstoned: bool = False) -> DatasetVersion: """ @@ -793,12 +794,21 @@ def get_most_recent_active_dataset_version(self, dataset_id: DatasetId) -> Optio def get_all_versions_for_dataset(self, dataset_id: DatasetId) -> List[DatasetVersion]: """ - Returns all dataset versions for a canonical dataset_id. ***AT PRESENT THIS FUNCTION IS NOT USED*** + Returns all dataset versions for a canonical dataset_id. """ with self._manage_session() as session: dataset_versions = session.query(DatasetVersionTable).filter_by(dataset_id=dataset_id.id).all() return [self._hydrate_dataset_version(dv) for dv in dataset_versions] + def check_artifact_is_part_of_dataset(self, datset_id: DatasetId, artifact_id: DatasetArtifactId) -> bool: + """ + Check if the artifact is part of any of the dataset versions associated with the dataset_id + """ + with self._manage_session() as session: + dataset_versions = session.query(DatasetVersionTable).filter_by(dataset_id=datset_id.id).all() + artifact_ids = [str(artifact_id) for dv in dataset_versions for artifact_id in dv.artifacts] + return artifact_id.id in artifact_ids + def get_all_mapped_datasets_and_collections(self) -> Tuple[List[DatasetVersion], List[CollectionVersion]]: """ Returns all mapped datasets and mapped collection versions. @@ -831,6 +841,14 @@ def get_dataset_artifacts_by_version_id(self, dataset_version_id: DatasetVersion ) return self.get_dataset_artifacts(artifact_ids[0]) + def get_artifact_by_uri_suffix(self, uri_suffix: str) -> Optional[DatasetArtifact]: + """ + Returns the artifact with the given uri suffix + """ + with self._manage_session() as session: + artifact = session.query(DatasetArtifactTable).filter(DatasetArtifactTable.uri.endswith(uri_suffix)).one() + return self._row_to_dataset_artifact(artifact) if artifact else artifact + def create_canonical_dataset(self, collection_version_id: CollectionVersionId) -> DatasetVersion: """ Initializes a canonical dataset, generating a dataset_id and a dataset_version_id. 
@@ -859,17 +877,21 @@ def create_canonical_dataset(self, collection_version_id: CollectionVersionId) - return self._row_to_dataset_version(dataset_version, CanonicalDataset(dataset_id, None, False, None), []) @retry(wait=wait_fixed(1), stop=stop_after_attempt(5)) - def add_dataset_artifact( - self, version_id: DatasetVersionId, artifact_type: DatasetArtifactType, artifact_uri: str + def create_dataset_artifact( + self, + dataset_version_id: DatasetVersionId, + artifact_type: DatasetArtifactType, + artifact_uri: str, + artifact_id: Optional[DatasetArtifactId] = None, ) -> DatasetArtifactId: """ Adds a dataset artifact to an existing dataset version. """ - artifact_id = DatasetArtifactId() - artifact = DatasetArtifactTable(id=artifact_id.id, type=artifact_type, uri=artifact_uri) + artifact_id = artifact_id if artifact_id else DatasetArtifactId() + artifact = DatasetArtifactTable(id=artifact_id.id, type=artifact_type.name, uri=artifact_uri) with self._get_serializable_session() as session: session.add(artifact) - dataset_version = session.query(DatasetVersionTable).filter_by(id=version_id.id).one() + dataset_version = session.query(DatasetVersionTable).filter_by(id=dataset_version_id.id).one() artifacts = list(dataset_version.artifacts) artifacts.append(uuid.UUID(artifact_id.id)) dataset_version.artifacts = artifacts @@ -883,6 +905,18 @@ def update_dataset_artifact(self, artifact_id: DatasetArtifactId, artifact_uri: artifact = session.query(DatasetArtifactTable).filter_by(id=artifact_id.id).one() artifact.uri = artifact_uri + def add_artifact_to_dataset_version( + self, dataset_version_id: DatasetVersionId, artifact_id: DatasetArtifactId + ) -> None: + """ + Adds an artifact to an existing dataset version + """ + with self._manage_session() as session: + dataset_version = session.query(DatasetVersionTable).filter_by(id=dataset_version_id.id).one() + artifacts = list(dataset_version.artifacts) + artifacts.append(uuid.UUID(artifact_id.id)) + dataset_version.artifacts = artifacts + @retry(wait=wait_fixed(1), stop=stop_after_attempt(5)) def update_dataset_processing_status(self, version_id: DatasetVersionId, status: DatasetProcessingStatus) -> None: """ @@ -934,7 +968,9 @@ def update_dataset_validation_message(self, version_id: DatasetVersionId, valida with self._get_serializable_session() as session: dataset_version = session.query(DatasetVersionTable).filter_by(id=version_id.id).one() dataset_version_status = deepcopy(dataset_version.status) - dataset_version_status["validation_message"] = validation_message + message = dataset_version_status.get("validation_message") + message = validation_message if message is None else "\n".join([message, validation_message]) + dataset_version_status["validation_message"] = message dataset_version.status = dataset_version_status def get_dataset_version_status(self, version_id: DatasetVersionId) -> DatasetStatus: @@ -1048,14 +1084,16 @@ def set_collection_version_datasets_order( # Confirm collection version datasets length matches given dataset version IDs length. if len(collection_version.datasets) != len(dataset_version_ids): raise ValueError( - f"Dataset Version IDs length does not match Collection Version {collection_version_id} Datasets length" + f"Dataset Version IDs length does not match Collection Version {collection_version_id} Datasets " + f"length" ) # Confirm all given dataset version IDs belong to collection version. 
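The update_dataset_validation_message change above now appends to any existing message instead of overwriting it; a tiny standalone sketch of the resulting behavior (the messages are made up).

existing = None
for incoming in ["ATAC fragment failed validation.", "AnnData failed validation."]:
    # mirrors: message = validation_message if message is None else "\n".join([message, validation_message])
    existing = incoming if existing is None else "\n".join([existing, incoming])
# existing == "ATAC fragment failed validation.\nAnnData failed validation."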
if {dv_id.id for dv_id in dataset_version_ids} != {str(d) for d in collection_version.datasets}: raise ValueError("Dataset Version IDs do not match saved Collection Version Dataset IDs") - # Replace collection version datasets with given, ordered dataset version IDs and update custom ordered flag. + # Replace collection version datasets with given, ordered dataset version IDs and update custom ordered + # flag. updated_datasets = [uuid.UUID(dv_id.id) for dv_id in dataset_version_ids] collection_version.datasets = updated_datasets collection_version.has_custom_dataset_order = True diff --git a/backend/layers/persistence/persistence_interface.py b/backend/layers/persistence/persistence_interface.py index 2945642ffcbee..b4e0957c7a7c9 100644 --- a/backend/layers/persistence/persistence_interface.py +++ b/backend/layers/persistence/persistence_interface.py @@ -154,7 +154,7 @@ def finalize_collection_version( data_submission_policy_version: str, published_at: Optional[datetime] = None, update_revised_at: bool = False, - ) -> List[str]: + ) -> List[DatasetVersion]: """ Finalizes a collection version. This is equivalent to calling: 1. update_collection_version_mapping @@ -193,6 +193,11 @@ def get_all_versions_for_dataset(self, dataset_id: DatasetId) -> List[DatasetVer Returns all dataset versions for a canonical dataset_id """ + def get_artifact_by_uri_suffix(self, uri_suffix: str) -> Optional[DatasetArtifact]: + """ + Returns a dataset artifact by its uri_suffix + """ + def get_all_mapped_datasets_and_collections( self, ) -> Tuple[List[DatasetVersion], List[CollectionVersion]]: # TODO: add filters if needed @@ -206,18 +211,41 @@ def get_dataset_artifacts_by_version_id(self, dataset_version_id: DatasetVersion Returns all the artifacts for a specific dataset version """ + def get_dataset_artifacts(self, artifact: List[DatasetArtifactId]) -> List[DatasetArtifact]: + """ + Returns a list of dataset artifacts by id + """ + def create_canonical_dataset(self, collection_version_id: CollectionVersionId) -> DatasetVersion: """ Initializes a canonical dataset, generating a dataset_id and a dataset_version_id. Returns the newly created DatasetVersion. """ - def add_dataset_artifact( - self, version_id: DatasetVersionId, artifact_type: str, artifact_uri: str + def create_dataset_artifact( + self, + dataset_version_id: DatasetVersionId, + artifact_type: str, + artifact_uri: str, + artifact_id: Optional[DatasetArtifactId] = None, ) -> DatasetArtifactId: """ - Adds a dataset artifact to an existing dataset version. + Create a dataset artifact to add to a dataset version. 
+ """ + + def update_dataset_artifact(self, artifact_id: DatasetArtifactId, artifact_uri: str) -> None: + """ + Updates a dataset artifact uri + """ + pass + + def add_artifact_to_dataset_version( + self, dataset_version_id: DatasetVersionId, artifact_id: DatasetArtifactId + ) -> None: + """ + Adds an artifact to a dataset version """ + pass def update_dataset_processing_status(self, version_id: DatasetVersionId, status: DatasetProcessingStatus) -> None: """ diff --git a/backend/layers/persistence/persistence_mock.py b/backend/layers/persistence/persistence_mock.py index 5dc252e46f902..33b10c0d2b964 100644 --- a/backend/layers/persistence/persistence_mock.py +++ b/backend/layers/persistence/persistence_mock.py @@ -310,7 +310,7 @@ def finalize_collection_version( data_submission_policy_version: str, published_at: Optional[datetime] = None, update_revised_at: bool = False, - ) -> List[str]: + ) -> List[DatasetVersion]: published_at = published_at if published_at else datetime.utcnow() dataset_ids_for_new_collection_version = [] @@ -325,7 +325,7 @@ def finalize_collection_version( dataset_ids_for_new_collection_version.append(dataset_version.dataset_id.id) previous_collection = self.collections.get(collection_id.id) - dataset_version_ids_to_delete_from_s3 = [] + dataset_versions_to_delete_from_s3 = [] if previous_collection is None: self.collections[collection_id.id] = CanonicalCollection( id=collection_id, @@ -347,7 +347,7 @@ def finalize_collection_version( self.datasets[previous_dataset_id].tombstoned = True for dataset_version in self.datasets_versions.values(): if dataset_version.dataset_id == previous_dataset_id: - dataset_version_ids_to_delete_from_s3.append(dataset_version.version_id.id) + dataset_versions_to_delete_from_s3.append(dataset_version) new_collection = copy.deepcopy(previous_collection) new_collection.version_id = version_id @@ -359,7 +359,7 @@ def finalize_collection_version( self.collections_versions[version_id.id].data_submission_policy_version = data_submission_policy_version self.collections_versions[version_id.id].is_auto_version = False - return dataset_version_ids_to_delete_from_s3 + return dataset_versions_to_delete_from_s3 # OR # def update_collection_version_mapping(self, collection_id: CollectionId, version_id: CollectionVersionId) -> None: @@ -444,7 +444,7 @@ def get_most_recent_active_dataset_version(self, dataset_id: DatasetId) -> Optio def get_all_versions_for_dataset(self, dataset_id: DatasetId) -> List[DatasetVersion]: """ - Returns all dataset versions for a canonical dataset_id. ***AT PRESENT THIS FUNCTION IS NOT USED*** + Returns all dataset versions for a canonical dataset_id. """ versions = [] for dataset_version in self.datasets_versions.values(): @@ -452,6 +452,16 @@ def get_all_versions_for_dataset(self, dataset_id: DatasetId) -> List[DatasetVer versions.append(self._update_dataset_version_with_canonical(dataset_version)) return versions + def get_artifact_by_uri_suffix(self, uri_suffix: str) -> Optional[DatasetArtifact]: + for artifact in self.dataset_artifacts.values(): + if artifact.uri.endswith(uri_suffix): + return artifact + + def check_artifact_is_part_of_dataset(self, dataset_id: DatasetId, artifact_id: DatasetArtifactId): + versions = [v for v in self.datasets_versions.values() if v.dataset_id == dataset_id] + artifacts = [a for v in versions for a in v.artifacts] + return any(a.id == artifact_id for a in artifacts) + def _get_all_datasets(self) -> Iterable[DatasetVersion]: """ Returns all the mapped datasets. 
Currently unused @@ -496,11 +506,15 @@ def add_dataset_to_collection_version_mapping( ) -> None: self.collections_versions[collection_version_id.id].datasets.append(dataset_version_id) - def add_dataset_artifact( - self, version_id: DatasetVersionId, artifact_type: str, artifact_uri: str + def create_dataset_artifact( + self, + dataset_version_id: DatasetVersionId, + artifact_type: str, + artifact_uri: str, + artifact_id: Optional[DatasetArtifactId] = None, ) -> DatasetArtifactId: - version = self.datasets_versions[version_id.id] - artifact_id = DatasetArtifactId() + version = self.datasets_versions[dataset_version_id.id] + artifact_id = artifact_id if artifact_id else DatasetArtifactId() dataset_artifact = DatasetArtifact(artifact_id, artifact_type, artifact_uri) version.artifacts.append(dataset_artifact) self.dataset_artifacts[artifact_id.id] = dataset_artifact @@ -518,6 +532,9 @@ def update_dataset_artifact(self, artifact_id: DatasetArtifactId, artifact_uri: found_artifact = True break + def add_artifact_to_dataset_version(self, version_id: DatasetVersionId, artifact_id: DatasetArtifactId) -> None: + self.datasets_versions[version_id.id].artifacts.append(self.dataset_artifacts[artifact_id.id]) + def set_dataset_metadata(self, version_id: DatasetVersionId, metadata: DatasetMetadata) -> None: version = self.datasets_versions[version_id.id] version.metadata = copy.deepcopy(metadata) @@ -543,7 +560,12 @@ def update_dataset_conversion_status( def update_dataset_validation_message(self, version_id: DatasetVersionId, validation_message: str) -> None: dataset_version = self.datasets_versions[version_id.id] - dataset_version.status.validation_message = validation_message + if dataset_version.status.validation_message is not None: + dataset_version.status.validation_message = ( + dataset_version.status.validation_message + "\n" + validation_message + ) + else: + dataset_version.status.validation_message = validation_message def add_dataset_to_collection_version(self, version_id: CollectionVersionId, dataset_id: DatasetId) -> None: # Not needed for now - create_dataset does this @@ -581,7 +603,8 @@ def replace_dataset_in_collection_version( new_dataset_version = self.get_dataset_version(new_dataset_version_id) if collection_version.collection_id != new_dataset_version.collection_id: raise ValueError( - f"Dataset version {new_dataset_version_id} does not belong to collection {collection_version.collection_id}" + f"Dataset version {new_dataset_version_id} does not belong to collection " + f"{collection_version.collection_id}" ) idx = next(i for i, e in enumerate(collection_version.datasets) if e == old_dataset_version_id) diff --git a/backend/layers/processing/dataset_metadata_update.py b/backend/layers/processing/dataset_metadata_update.py index 586aec2ee5a46..83ee6c74798b6 100644 --- a/backend/layers/processing/dataset_metadata_update.py +++ b/backend/layers/processing/dataset_metadata_update.py @@ -9,8 +9,6 @@ import scanpy import tiledb -from rpy2.robjects import StrVector -from rpy2.robjects.packages import importr from backend.common.utils.corpora_constants import CorporaConstants from backend.layers.business.business import BusinessLogic @@ -29,13 +27,10 @@ from backend.layers.processing.exceptions import ProcessingFailed from backend.layers.processing.h5ad_data_file import H5ADDataFile from backend.layers.processing.logger import configure_logging -from backend.layers.processing.process_download import ProcessDownload +from backend.layers.processing.process_validate_h5ad import 
ProcessValidateH5AD from backend.layers.thirdparty.s3_provider import S3Provider from backend.layers.thirdparty.uri_provider import UriProvider -base = importr("base") -seurat = importr("SeuratObject") - configure_logging(level=logging.INFO) # maps artifact name for metadata field to DB field name, if different @@ -43,7 +38,7 @@ FIELDS_IN_RAW_H5AD = ["title"] -class DatasetMetadataUpdaterWorker(ProcessDownload): +class DatasetMetadataUpdaterWorker(ProcessValidateH5AD): def __init__(self, artifact_bucket: str, datasets_bucket: str, spatial_deep_zoom_dir: str = None) -> None: # init each worker with business logic backed by non-shared DB connection self.business_logic = BusinessLogic( @@ -54,7 +49,7 @@ def __init__(self, artifact_bucket: str, datasets_bucket: str, spatial_deep_zoom S3Provider(), UriProvider(), ) - super().__init__(self.business_logic, self.business_logic.uri_provider, self.business_logic.s3_provider) + super().__init__(self.business_logic, self.business_logic.uri_provider, self.business_logic.s3_provider, None) self.artifact_bucket = artifact_bucket self.datasets_bucket = datasets_bucket self.spatial_deep_zoom_dir = spatial_deep_zoom_dir @@ -67,7 +62,7 @@ def update_raw_h5ad( metadata_update: DatasetArtifactMetadataUpdate, ): raw_h5ad_filename = self.download_from_source_uri( - source_uri=raw_h5ad_uri, + source_uri=str(raw_h5ad_uri), local_path=CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME, ) try: @@ -86,8 +81,8 @@ def update_raw_h5ad( DatasetArtifactType.RAW_H5AD, new_key_prefix, new_dataset_version_id, - self.artifact_bucket, DatasetStatusKey.H5AD, + self.artifact_bucket, ) self.update_processing_status(new_dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.UPLOADED) finally: @@ -102,7 +97,7 @@ def update_h5ad( metadata_update: DatasetArtifactMetadataUpdate, ): h5ad_filename = self.download_from_source_uri( - source_uri=h5ad_uri, + source_uri=str(h5ad_uri), local_path=CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME, ) try: @@ -123,8 +118,8 @@ def update_h5ad( DatasetArtifactType.H5AD, new_key_prefix, new_dataset_version_id, - self.artifact_bucket, DatasetStatusKey.H5AD, + self.artifact_bucket, datasets_bucket=self.datasets_bucket, ) self.update_processing_status( @@ -136,48 +131,6 @@ def update_h5ad( finally: os.remove(h5ad_filename) - def update_rds( - self, - rds_uri: str, - new_key_prefix: str, - new_dataset_version_id: DatasetVersionId, - metadata_update: DatasetArtifactMetadataUpdate, - ): - seurat_filename = self.download_from_source_uri( - source_uri=rds_uri, - local_path=CorporaConstants.LABELED_RDS_ARTIFACT_FILENAME, - ) - - try: - self.update_processing_status( - new_dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.CONVERTING - ) - - rds_object = base.readRDS(seurat_filename) - - for key, val in metadata_update.as_dict_without_none_values().items(): - seurat_metadata = seurat.Misc(object=rds_object) - if seurat_metadata.rx2[key]: - val = val if isinstance(val, list) else [val] - seurat_metadata[seurat_metadata.names.index(key)] = StrVector(val) - - base.saveRDS(rds_object, file=seurat_filename) - - self.create_artifact( - seurat_filename, - DatasetArtifactType.RDS, - new_key_prefix, - new_dataset_version_id, - self.artifact_bucket, - DatasetStatusKey.RDS, - datasets_bucket=self.datasets_bucket, - ) - self.update_processing_status( - new_dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.CONVERTED - ) - finally: - os.remove(seurat_filename) - def update_cxg( self, cxg_uri: str, @@ -208,7 +161,7 @@ def update_cxg( 
self.update_processing_status(new_dataset_version_id, DatasetStatusKey.CXG, DatasetConversionStatus.CONVERTED) -class DatasetMetadataUpdater(ProcessDownload): +class DatasetMetadataUpdater(ProcessValidateH5AD): def __init__( self, business_logic: BusinessLogic, @@ -217,7 +170,7 @@ def __init__( datasets_bucket: str, spatial_deep_zoom_dir: str, ) -> None: - super().__init__(business_logic, business_logic.uri_provider, business_logic.s3_provider) + super().__init__(business_logic, business_logic.uri_provider, business_logic.s3_provider, None) self.artifact_bucket = artifact_bucket self.cellxgene_bucket = cellxgene_bucket self.datasets_bucket = datasets_bucket @@ -257,19 +210,6 @@ def update_h5ad( metadata_update, ) - @staticmethod - def update_rds( - artifact_bucket: str, - datasets_bucket: str, - rds_uri: str, - new_key_prefix: str, - new_dataset_version_id: DatasetVersionId, - metadata_update: DatasetArtifactMetadataUpdate, - ): - DatasetMetadataUpdaterWorker(artifact_bucket, datasets_bucket).update_rds( - rds_uri, new_key_prefix, new_dataset_version_id, metadata_update - ) - @staticmethod def update_cxg( artifact_bucket: str, @@ -331,7 +271,8 @@ def update_metadata( ) else: self.logger.info("Main: No raw h5ad update required") - self.upload_raw_h5ad(new_dataset_version_id, raw_h5ad_uri, self.artifact_bucket) + key_prefix = self.get_key_prefix(new_dataset_version_id.id) + self.upload_raw_h5ad(new_dataset_version_id, raw_h5ad_uri, self.artifact_bucket, key_prefix) if DatasetArtifactType.H5AD in artifact_uris: self.logger.info("Main: Starting thread for h5ad update") @@ -353,28 +294,8 @@ def update_metadata( self.logger.error(f"Cannot find labeled H5AD artifact uri for {current_dataset_version_id}.") self.update_processing_status(new_dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.FAILED) - if DatasetArtifactType.RDS in artifact_uris: - self.logger.info("Main: Starting thread for rds update") - rds_job = Process( - target=DatasetMetadataUpdater.update_rds, - args=( - self.artifact_bucket, - self.datasets_bucket, - artifact_uris[DatasetArtifactType.RDS], - new_artifact_key_prefix, - new_dataset_version_id, - metadata_update, - ), - ) - artifact_jobs.append(rds_job) - rds_job.start() - elif current_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED: - self.update_processing_status(new_dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) - else: - self.logger.error( - f"Cannot find RDS artifact uri for {current_dataset_version_id}, and Conversion Status is not SKIPPED." 
- ) - self.update_processing_status(new_dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.FAILED) + # Mark all RDS conversions as skipped + self.update_processing_status(new_dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) if DatasetArtifactType.CXG in artifact_uris: self.logger.info("Main: Starting thread for cxg update") @@ -420,6 +341,10 @@ def has_valid_artifact_statuses(self, dataset_version_id: DatasetVersionId) -> b dataset_version.status.rds_status == DatasetConversionStatus.CONVERTED or dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED ) + and ( + dataset_version.status.atac_status == DatasetConversionStatus.UPLOADED + or dataset_version.status.atac_status == DatasetConversionStatus.SKIPPED + ) ) diff --git a/backend/layers/processing/exceptions.py b/backend/layers/processing/exceptions.py index 2526de0e68db5..7e19d761eae39 100644 --- a/backend/layers/processing/exceptions.py +++ b/backend/layers/processing/exceptions.py @@ -13,8 +13,18 @@ class ProcessingCanceled(ProcessingException): @dataclass -class ValidationFailed(ProcessingException): - errors: List[str] +class ValidationAnndataFailed(ProcessingException): + def __init__(self, errors: List[str]): + self.errors = errors + + +class ValidationAtacFailed(ProcessingException): + def __init__(self, errors: List[str]): + self.errors = errors + + +class AddLabelsFailed(ProcessingException): + failed_status: DatasetStatusKey = DatasetStatusKey.H5AD class ProcessingFailed(ProcessingException): diff --git a/backend/layers/processing/h5ad_data_file.py b/backend/layers/processing/h5ad_data_file.py index 5720b31a4a558..2f66742ccc68d 100644 --- a/backend/layers/processing/h5ad_data_file.py +++ b/backend/layers/processing/h5ad_data_file.py @@ -3,9 +3,10 @@ from os import path from typing import Dict, Optional -import anndata +import dask import numpy as np import tiledb +from cellxgene_schema.utils import read_h5ad from backend.common.utils.corpora_constants import CorporaConstants from backend.common.utils.cxg_constants import CxgConstants @@ -32,6 +33,8 @@ class H5ADDataFile: tile_db_ctx_config = { "sm.consolidation.buffer_size": consolidation_buffer_size(0.1), + "sm.consolidation.step_min_frags": 2, + "sm.consolidation.step_max_frags": 20, # see https://docs.tiledb.com/main/how-to/performance/performance-tips/tuning-consolidation "py.deduplicate": True, # May reduce memory requirements at cost of performance } @@ -95,19 +98,23 @@ def to_cxg( def write_anndata_x_matrices_to_cxg(self, output_cxg_directory, ctx, sparse_threshold): matrix_container = f"{output_cxg_directory}/X" - x_matrix_data = self.anndata.X - is_sparse = is_matrix_sparse(x_matrix_data, sparse_threshold) # big memory usage - logging.info(f"is_sparse: {is_sparse}") - - convert_matrices_to_cxg_arrays(matrix_container, x_matrix_data, is_sparse, ctx) # big memory usage + with dask.config.set( + { + "num_workers": 2, # match the number of workers to the number of vCPUs + "threads_per_worker": 1, + "distributed.worker.memory.limit": "6GB", + "scheduler": "threads", + } + ): + is_sparse = is_matrix_sparse(x_matrix_data, sparse_threshold) + logging.info(f"is_sparse: {is_sparse}") + convert_matrices_to_cxg_arrays(matrix_container, x_matrix_data, is_sparse, self.tile_db_ctx_config) - suffixes = ["r", "c"] if is_sparse else [""] logging.info("start consolidating") - for suffix in suffixes: - tiledb.consolidate(matrix_container + suffix, ctx=ctx) - if hasattr(tiledb, "vacuum"): - tiledb.vacuum(matrix_container + suffix) + 
tiledb.consolidate(matrix_container, ctx=ctx) + if hasattr(tiledb, "vacuum"): + tiledb.vacuum(matrix_container) def write_anndata_embeddings_to_cxg(self, output_cxg_directory, ctx): def is_valid_embedding(adata, embedding_name, embedding_array): @@ -183,7 +190,7 @@ def validate_anndata(self): def extract_anndata_elements_from_file(self): logging.info(f"Reading in AnnData dataset: {path.basename(self.input_filename)}") - self.anndata = anndata.read_h5ad(self.input_filename) + self.anndata = read_h5ad(self.input_filename, chunk_size=7500) logging.info("Completed reading in AnnData dataset!") self.obs = self.transform_dataframe_index_into_column(self.anndata.obs, "obs", self.obs_index_column_name) diff --git a/backend/layers/processing/make_seurat.R b/backend/layers/processing/make_seurat.R deleted file mode 100644 index 13792d09e93f1..0000000000000 --- a/backend/layers/processing/make_seurat.R +++ /dev/null @@ -1,21 +0,0 @@ -library(sceasy) - -require(devtools) - -h5adPath <- commandArgs(trailingOnly = TRUE)[1] - -target_uns_keys <- c("schema_version", - "title", - "batch_condition", - "default_embedding", - "X_approximate_distribution", - "citation", - "schema_reference" - ) - -sceasy::convertFormat(h5adPath, - from="anndata", - to="seurat", - outFile = gsub(".h5ad", ".rds", h5adPath), - main_layer = "data", - target_uns_keys = target_uns_keys) diff --git a/backend/layers/processing/process.py b/backend/layers/processing/process.py index 6a2507ce53d13..b1d011ea4f5e5 100644 --- a/backend/layers/processing/process.py +++ b/backend/layers/processing/process.py @@ -13,20 +13,23 @@ DatasetValidationStatus, DatasetVersionId, ) +from backend.layers.common.ingestion_manifest import IngestionManifest from backend.layers.persistence.persistence import DatabaseProvider from backend.layers.processing.exceptions import ( + AddLabelsFailed, ConversionFailed, ProcessingCanceled, ProcessingFailed, UploadFailed, - ValidationFailed, + ValidationAnndataFailed, + ValidationAtacFailed, ) from backend.layers.processing.logger import configure_logging +from backend.layers.processing.process_add_labels import ProcessAddLabels from backend.layers.processing.process_cxg import ProcessCxg -from backend.layers.processing.process_download import ProcessDownload from backend.layers.processing.process_logic import ProcessingLogic -from backend.layers.processing.process_seurat import ProcessSeurat -from backend.layers.processing.process_validate import ProcessValidate +from backend.layers.processing.process_validate_atac import ProcessValidateATAC +from backend.layers.processing.process_validate_h5ad import ProcessValidateH5AD from backend.layers.processing.schema_migration import SchemaMigrate from backend.layers.thirdparty.s3_provider import S3Provider, S3ProviderInterface from backend.layers.thirdparty.schema_validator_provider import ( @@ -43,8 +46,7 @@ class ProcessMain(ProcessingLogic): Main class for the dataset pipeline processing """ - process_validate: ProcessValidate - process_seurat: ProcessSeurat + process_validate_h5ad: ProcessValidateH5AD process_cxg: ProcessCxg def __init__( @@ -59,11 +61,15 @@ def __init__( self.uri_provider = uri_provider self.s3_provider = s3_provider self.schema_validator = schema_validator - self.process_download = ProcessDownload(self.business_logic, self.uri_provider, self.s3_provider) - self.process_validate = ProcessValidate( + self.process_validate_h5ad = ProcessValidateH5AD( + self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator + ) + 
self.process_validate_atac_seq = ProcessValidateATAC( + self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator + ) + self.process_add_labels = ProcessAddLabels( self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator ) - self.process_seurat = ProcessSeurat(self.business_logic, self.uri_provider, self.s3_provider) self.process_cxg = ProcessCxg(self.business_logic, self.uri_provider, self.s3_provider) self.schema_migrate = SchemaMigrate(self.business_logic, self.schema_validator) @@ -81,7 +87,6 @@ def log_batch_environment(self): "MAX_ATTEMPTS", "MIGRATE", "REMOTE_DEV_PREFIX", - "TASK_TOKEN", ] env_vars = dict() for var in batch_environment_variables: @@ -93,7 +98,7 @@ def process( collection_version_id: Optional[CollectionVersionId], dataset_version_id: DatasetVersionId, step_name: str, - dropbox_uri: Optional[str], + manifest: Optional[IngestionManifest], artifact_bucket: Optional[str], datasets_bucket: Optional[str], cxg_bucket: Optional[str], @@ -103,29 +108,43 @@ def process( """ self.logger.info(f"Processing dataset version {dataset_version_id}", extra={"step_name": step_name}) try: - if step_name == "download": - self.process_download.process( - dataset_version_id, dropbox_uri, artifact_bucket, os.environ.get("TASK_TOKEN", "") + if step_name == "validate_anndata": + self.process_validate_h5ad.process(dataset_version_id, manifest, artifact_bucket) + elif step_name == "validate_atac": + self.process_validate_atac_seq.process( + collection_version_id, + dataset_version_id, + manifest, + datasets_bucket, ) - elif step_name == "validate": - self.process_validate.process( + elif step_name == "add_labels": + self.process_add_labels.process( collection_version_id, dataset_version_id, artifact_bucket, datasets_bucket ) elif step_name == "cxg": self.process_cxg.process(dataset_version_id, artifact_bucket, cxg_bucket) elif step_name == "cxg_remaster": self.process_cxg.process(dataset_version_id, artifact_bucket, cxg_bucket, is_reprocess=True) - elif step_name == "seurat": - self.process_seurat.process(dataset_version_id, artifact_bucket, datasets_bucket) else: self.logger.error(f"Step function configuration error: Unexpected STEP_NAME '{step_name}'") # TODO: this could be better - maybe collapse all these exceptions and pass in the status key and value except ProcessingCanceled: pass # TODO: what's the effect of canceling a dataset now? 
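# Sketch (illustrative only, assuming pydantic v2): a stand-in for the IngestionManifest
# consumed by process() above. The real model lives in backend.layers.common.ingestion_manifest
# and may carry more fields; the steps in this diff only rely on `anndata` plus an optional
# `atac_fragment`, parsed from the MANIFEST environment variable via model_validate_json.
from typing import Optional
from pydantic import BaseModel

class IngestionManifestSketch(BaseModel):
    anndata: str  # URI of the h5ad to ingest
    atac_fragment: Optional[str] = None  # optional ATAC fragment file URI

manifest = IngestionManifestSketch.model_validate_json('{"anndata": "s3://example-bucket/raw.h5ad"}')
assert manifest.atac_fragment is None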
- except ValidationFailed as e: + except ValidationAnndataFailed as e: self.update_processing_status( - dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.INVALID, e.errors + dataset_version_id, + DatasetStatusKey.VALIDATION, + DatasetValidationStatus.INVALID, + validation_errors=e.errors, + ) + return False + except ValidationAtacFailed as e: + self.update_processing_status( + dataset_version_id, + DatasetStatusKey.VALIDATION, + DatasetValidationStatus.INVALID, + validation_errors=e.errors, ) return False except ProcessingFailed: @@ -133,6 +152,9 @@ def process( dataset_version_id, DatasetStatusKey.PROCESSING, DatasetProcessingStatus.FAILURE ) return False + except AddLabelsFailed as e: + self.update_processing_status(dataset_version_id, e.failed_status, DatasetConversionStatus.FAILED) + return False except UploadFailed: self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.FAILED) return False @@ -141,12 +163,12 @@ def process( return False except Exception as e: self.logger.exception(f"An unexpected error occurred while processing the data set: {e}") - if step_name in ["validate", "download"]: + if step_name in ["validate_anndata"]: self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.FAILED) - elif step_name == "seurat": - self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.FAILED) - elif step_name == "cxg" or step_name == "cxg_remaster": + elif step_name in ["cxg", "cxg_remaster"]: self.update_processing_status(dataset_version_id, DatasetStatusKey.CXG, DatasetConversionStatus.FAILED) + elif step_name == "add_labels": + self.update_processing_status(dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.FAILED) return False return True @@ -160,7 +182,8 @@ def main(self): else: dataset_version_id = os.environ["DATASET_VERSION_ID"] collection_version_id = os.environ.get("COLLECTION_VERSION_ID") - dropbox_uri = os.environ.get("DROPBOX_URL") + if manifest := os.environ.get("MANIFEST"): + manifest = IngestionManifest.model_validate_json(manifest) artifact_bucket = os.environ.get("ARTIFACT_BUCKET") datasets_bucket = os.environ.get("DATASETS_BUCKET") cxg_bucket = os.environ.get("CELLXGENE_BUCKET") @@ -170,7 +193,7 @@ def main(self): ), dataset_version_id=DatasetVersionId(dataset_version_id), step_name=step_name, - dropbox_uri=dropbox_uri, + manifest=manifest, artifact_bucket=artifact_bucket, datasets_bucket=datasets_bucket, cxg_bucket=cxg_bucket, diff --git a/backend/layers/processing/process_validate.py b/backend/layers/processing/process_add_labels.py similarity index 70% rename from backend/layers/processing/process_validate.py rename to backend/layers/processing/process_add_labels.py index dd3987470e156..db589383cdfbc 100644 --- a/backend/layers/processing/process_validate.py +++ b/backend/layers/processing/process_add_labels.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional +import h5py import numpy -import scanpy +from cellxgene_schema.utils import read_h5ad from backend.common.utils.corpora_constants import CorporaConstants from backend.layers.business.business_interface import BusinessLogicInterface @@ -17,7 +18,7 @@ SpatialMetadata, TissueOntologyTermId, ) -from backend.layers.processing.exceptions import ValidationFailed +from backend.layers.processing.exceptions import AddLabelsFailed from backend.layers.processing.logger import logit from 
backend.layers.processing.process_logic import ProcessingLogic from backend.layers.thirdparty.s3_provider import S3ProviderInterface @@ -25,15 +26,17 @@ from backend.layers.thirdparty.uri_provider import UriProviderInterface -class ProcessValidate(ProcessingLogic): +class ProcessAddLabels(ProcessingLogic): """ - Base class for handling the `Validate` step of the step function. + Base class for handling the `add label` step of the step function. This will: - 1. Download the original artifact from the provided URI - 2. Run the cellxgene-schema validator - 3. Save and upload a labeled copy of the original artifact (local.h5ad) - 5. Persist the dataset metadata on the database - 6. Determine if a Seurat conversion is possible (it is not if the X matrix has more than 2**32-1 nonzero values) + 1. Download the h5ad artifact + 2. Add labels to h5ad using cellxgene-schema + 3. Persist the dataset metadata on the database + 4. upload the labeled file to S3 + 5. set DatasetStatusKey.H5AD status to DatasetUploadStatus.UPLOADED + + If this step completes successfully, ProcessCxg and ProcessSeurat will start in parallel. If this step fails, the handle_failures lambda will be invoked. """ @@ -54,42 +57,21 @@ def __init__( self.schema_validator = schema_validator @logit - def validate_h5ad_file_and_add_labels( + def add_labels( self, collection_version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, local_filename: str - ) -> Tuple[str, bool]: + ) -> str: """ - Validates and labels the specified dataset file and updates the processing status in the database + labels the specified dataset file and updates the processing status in the database :param dataset_version_id: version ID of the dataset to update :param collection_version_id: version ID of the collection dataset is being uploaded to :param local_filename: file name of the dataset to validate and label :return: file name of labeled dataset, boolean indicating if seurat conversion is possible """ - # TODO: use a provider here - - self.update_processing_status( - dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALIDATING - ) - - output_filename = CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME - try: - is_valid, errors, can_convert_to_seurat = self.schema_validator.validate_and_save_labels( - local_filename, output_filename - ) - except Exception as e: - self.logger.exception("validation failed") - raise ValidationFailed([str(e)]) from None - - if not is_valid: - raise ValidationFailed(errors) - else: - self.populate_dataset_citation(collection_version_id, dataset_version_id, output_filename) - - # TODO: optionally, these could be batched into one - self.update_processing_status(dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.CONVERTED) - self.update_processing_status( - dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALID - ) - return output_filename, can_convert_to_seurat + output_filename = self.get_file_path(CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME) + self.schema_validator.add_labels(local_filename, output_filename) + self.populate_dataset_citation(collection_version_id, dataset_version_id, output_filename) + self.update_processing_status(dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.CONVERTED) + return output_filename def populate_dataset_citation( self, collection_version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, adata_path: str @@ -104,9 +86,8 @@ def populate_dataset_citation( collection = 
self.business_logic.get_collection_version(collection_version_id, get_tombstoned=False) doi = next((link.uri for link in collection.metadata.links if link.type == "DOI"), None) citation = self.business_logic.generate_dataset_citation(collection.collection_id, dataset_version_id, doi) - adata = scanpy.read_h5ad(adata_path) - adata.uns["citation"] = citation - adata.write(adata_path, compression="gzip") + with h5py.File(adata_path, "r+") as f: + f["uns"].create_dataset("citation", data=citation) def get_spatial_metadata(self, spatial_dict: Dict[str, Any]) -> Optional[SpatialMetadata]: """ @@ -128,28 +109,17 @@ def get_spatial_metadata(self, spatial_dict: Dict[str, Any]) -> Optional[Spatial def extract_metadata(self, filename) -> DatasetMetadata: """Pull metadata out of the AnnData file to insert into the dataset table.""" - adata = scanpy.read_h5ad(filename, backed="r") + adata = read_h5ad(filename) - # TODO: Concern with respect to previous use of raising error when there is no raw layer. - # This new way defaults to adata.X. layer_for_mean_genes_per_cell = adata.raw.X if adata.raw is not None and adata.raw.X is not None else adata.X # For mean_genes_per_cell, we only want the columns (genes) that have a feature_biotype of `gene`, # as opposed to `spike-in` filter_gene_vars = numpy.where(adata.var.feature_biotype == "gene")[0] - - # Calling np.count_nonzero on and h5py.Dataset appears to read the entire thing - # into memory, so we need to chunk it to be safe. - stride = 50000 - numerator, denominator = 0, 0 - for bounds in zip( - range(0, layer_for_mean_genes_per_cell.shape[0], stride), - range(stride, layer_for_mean_genes_per_cell.shape[0] + stride, stride), - strict=False, - ): - chunk = layer_for_mean_genes_per_cell[bounds[0] : bounds[1], :][:, filter_gene_vars] - numerator += chunk.nnz if hasattr(chunk, "nnz") else numpy.count_nonzero(chunk) - denominator += chunk.shape[0] + filtered_matrix = layer_for_mean_genes_per_cell[:, filter_gene_vars] + nnz_gene_exp = self.schema_validator.count_matrix_nonzero(filtered_matrix) + total_cells = layer_for_mean_genes_per_cell.shape[0] + mean_genes_per_cell = nnz_gene_exp / total_cells def _get_term_pairs(base_term) -> List[OntologyTermId]: base_term_id = base_term + "_ontology_term_id" @@ -196,7 +166,7 @@ def _get_batch_condition() -> Optional[str]: development_stage=_get_term_pairs("development_stage"), cell_count=adata.shape[0], primary_cell_count=int(adata.obs["is_primary_data"].astype("int").sum()), - mean_genes_per_cell=numerator / denominator, + mean_genes_per_cell=mean_genes_per_cell, is_primary_data=_get_is_primary_data(), cell_type=_get_term_pairs("cell_type"), x_approximate_distribution=_get_x_approximate_distribution(), @@ -232,31 +202,31 @@ def process( :param datasets_bucket: :return: """ + self.update_processing_status(dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALID) # Download the original dataset from S3 key_prefix = self.get_key_prefix(dataset_version_id.id) - original_h5ad_artifact_file_name = CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME - object_key = f"{key_prefix}/{original_h5ad_artifact_file_name}" + original_h5ad_artifact_file_name = self.get_file_path(CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME) + object_key = f"{key_prefix}/{CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME}" self.download_from_s3(artifact_bucket, object_key, original_h5ad_artifact_file_name) - # Validate and label the dataset - file_with_labels, can_convert_to_seurat = self.validate_h5ad_file_and_add_labels( - 
collection_version_id, dataset_version_id, original_h5ad_artifact_file_name - ) + # label the dataset + try: + file_with_labels = self.add_labels( + collection_version_id, dataset_version_id, original_h5ad_artifact_file_name + ) + except Exception as e: + self.logger.exception(f"An unexpected error occurred while adding labels to the data set: {e}") + raise AddLabelsFailed() from e # Process metadata metadata = self.extract_metadata(file_with_labels) self.business_logic.set_dataset_metadata(dataset_version_id, metadata) - - if not can_convert_to_seurat: - self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) - self.logger.info(f"Skipping Seurat conversion for dataset {dataset_version_id}") - # Upload the labeled dataset to the artifact bucket self.create_artifact( file_with_labels, DatasetArtifactType.H5AD, key_prefix, dataset_version_id, - artifact_bucket, DatasetStatusKey.H5AD, + artifact_bucket, datasets_bucket=datasets_bucket, ) diff --git a/backend/layers/processing/process_cxg.py b/backend/layers/processing/process_cxg.py index 32e069d335b81..2600463bc06f4 100644 --- a/backend/layers/processing/process_cxg.py +++ b/backend/layers/processing/process_cxg.py @@ -19,7 +19,7 @@ class ProcessCxg(ProcessingLogic): 1. Download the labeled h5ad artifact from S3 (uploaded by DownloadAndValidate) 2. Convert to cxg 3. Upload the cxg artifact (a directory) to S3 - If this step completes successfully, and ProcessSeurat is completed, the handle_success lambda will be invoked + If this step completes successfully, the handle_success lambda will be invoked If this step fails, the handle_failures lambda will be invoked """ @@ -48,7 +48,7 @@ def process( :return: """ - labeled_h5ad_filename = "local.h5ad" + labeled_h5ad_filename = self.get_file_path("local.h5ad") # Download the labeled dataset from the artifact bucket object_key = None @@ -91,9 +91,7 @@ def copy_cxg_files_to_cxg_bucket(self, cxg_dir, s3_uri): self.s3_provider.upload_directory(cxg_dir, s3_uri) def process_cxg(self, local_filename, dataset_version_id, cellxgene_bucket, current_artifacts=None): - cxg_dir = self.convert_file( - self.make_cxg, local_filename, "Issue creating cxg.", dataset_version_id, DatasetStatusKey.CXG - ) + cxg_dir = self.convert_file(self.make_cxg, local_filename, dataset_version_id, DatasetStatusKey.CXG) s3_uri = None if current_artifacts: existing_cxg = [artifact for artifact in current_artifacts if artifact.type == DatasetArtifactType.CXG][0] diff --git a/backend/layers/processing/process_download.py b/backend/layers/processing/process_download.py deleted file mode 100644 index 4c99acf848a13..0000000000000 --- a/backend/layers/processing/process_download.py +++ /dev/null @@ -1,203 +0,0 @@ -import json -import os -from math import ceil -from typing import Any, Dict, Optional - -import scanpy - -from backend.common.corpora_config import CorporaConfig -from backend.common.utils.corpora_constants import CorporaConstants -from backend.common.utils.dl_sources.uri import DownloadFailed -from backend.common.utils.math_utils import MB -from backend.layers.business.business_interface import BusinessLogicInterface -from backend.layers.common.entities import ( - DatasetArtifactType, - DatasetProcessingStatus, - DatasetStatusKey, - DatasetUploadStatus, - DatasetVersionId, -) -from backend.layers.processing.exceptions import UploadFailed -from backend.layers.processing.logger import logit -from backend.layers.processing.process_logic import ProcessingLogic -from 
backend.layers.thirdparty.s3_provider_interface import S3ProviderInterface -from backend.layers.thirdparty.step_function_provider import StepFunctionProvider -from backend.layers.thirdparty.uri_provider import UriProviderInterface - - -class ProcessDownload(ProcessingLogic): - """ - Base class for handling the `Download` step of the step function. - This will: - 1. Download the original artifact from the provided URI - 2. estimate memory requirements - 4. Upload a copy of the original artifact (raw.h5ad) - - """ - - def __init__( - self, - business_logic: BusinessLogicInterface, - uri_provider: UriProviderInterface, - s3_provider: S3ProviderInterface, - config: Optional[CorporaConfig] = None, - ) -> None: - super().__init__() - self.business_logic = business_logic - self.uri_provider = uri_provider - self.s3_provider = s3_provider - self.config = config or CorporaConfig() - - @logit - def download_from_source_uri(self, source_uri: str, local_path: str) -> str: - """Given a source URI, download it to local_path. - Handles fixing the url so it downloads directly. - """ - file_url = self.uri_provider.parse(source_uri) - if not file_url: - raise ValueError(f"Malformed source URI: {source_uri}") - try: - file_url.download(local_path) - except DownloadFailed as e: - raise UploadFailed(f"Failed to download file from source URI: {source_uri}") from e - return local_path - - # TODO: after upgrading to Python 3.9, replace this with removeprefix() - @staticmethod - def remove_prefix(string: str, prefix: str) -> str: - if string.startswith(prefix): - return string[len(prefix) :] - else: - return string - - @staticmethod - def get_job_definion_name(dataset_version_id: str) -> str: - if os.getenv("REMOTE_DEV_PREFIX"): - stack_name = os.environ["REMOTE_DEV_PREFIX"].replace("/", "") - prefix = f"{os.environ['DEPLOYMENT_STAGE']}-{stack_name}" - else: - prefix = f"{os.environ['DEPLOYMENT_STAGE']}" - job_definition_name = f"dp-{prefix}-ingest-process-{dataset_version_id}" - return job_definition_name - - def estimate_resource_requirements( - self, - adata: scanpy.AnnData, - memory_modifier: Optional[float] = None, - min_vcpu: Optional[int] = None, - max_vcpu: Optional[int] = None, - max_swap_memory_MB: Optional[int] = None, - swap_modifier: Optional[int] = None, - memory_per_vcpu: int = 8000, - ) -> Dict[str, int]: - """ - Estimate the resource requirements for a given dataset - - :param adata: The datasets AnnData object - :param memory_modifier: A multiplier to increase/decrease the memory requirements by - :param min_vcpu: The minimum number of vCPUs to allocate. - :param max_vcpu: The maximum number of vCPUs to allocate. - :param memory_per_vcpu: The amount of memory to allocate per vCPU. 8000 MB is what AWS uses as the ratio - :param swap_modifier: The multiplier to increase/decrease the swap memory requirements by - :param max_swap_memory_MB: The maximum amount of swap memory to allocate. - :return: A dictionary containing the resource requirements - """ - memory_modifier = memory_modifier or self.config.ingest_memory_modifier - min_vcpu = min_vcpu or self.config.ingest_min_vcpu - max_vcpu = max_vcpu or self.config.ingest_max_vcpu - max_swap_memory_MB = max_swap_memory_MB or self.config.ingest_max_swap_memory_mb - swap_modifier = swap_modifier or self.config.ingest_swap_modifier - - # Note: this is a rough estimate of the uncompressed size of the dataset. This method avoid loading the entire - # dataset into memory. 
- min_memory_MB = min_vcpu * memory_per_vcpu - max_memory_MB = max_vcpu * memory_per_vcpu - uncompressed_size_MB = adata.n_obs * adata.n_vars / MB - estimated_memory_MB = max([int(ceil(uncompressed_size_MB * memory_modifier)), min_memory_MB]) - vcpus = max_vcpu if estimated_memory_MB > max_memory_MB else int(ceil(estimated_memory_MB / memory_per_vcpu)) - memory = memory_per_vcpu * vcpus # round up to nearest memory_per_vcpu - max_swap = min([max_swap_memory_MB, memory * swap_modifier]) - self.logger.info( - { - "message": "Estimated resource requirements", - "memory_modifier": memory_modifier, - "swap_modifier": swap_modifier, - "min_vcpu": min_vcpu, - "max_vcpu": max_vcpu, - "max_swap_memory_MB": max_swap_memory_MB, - "memory_per_vcpu": memory_per_vcpu, - "uncompressed_size_MB": uncompressed_size_MB, - "max_swap": max_swap, - "memory": memory, - "vcpus": vcpus, - } - ) - - return {"Vcpus": vcpus, "Memory": memory, "MaxSwap": max_swap} - - def create_batch_job_definition_parameters(self, local_filename: str, dataset_version_id: str) -> Dict[str, Any]: - adata = scanpy.read_h5ad(local_filename, backed="r") - batch_resources = self.estimate_resource_requirements(adata) - job_definition_name = self.get_job_definion_name(dataset_version_id) - - return { # Using PascalCase to match the Batch API - "JobDefinitionName": job_definition_name, - "Vcpus": batch_resources["Vcpus"], - "Memory": batch_resources["Memory"], - "LinuxParameters": { - "Swappiness": 60, - "MaxSwap": batch_resources["MaxSwap"], - }, - } - - def upload_raw_h5ad(self, dataset_version_id: DatasetVersionId, dataset_uri: str, artifact_bucket: str) -> str: - """ - Upload raw h5ad from dataset_uri to artifact bucket - - :param dataset_version_id: - :param dataset_uri: - :param artifact_bucket: - :return: local_filename: Local filepath to raw h5ad - """ - self.update_processing_status(dataset_version_id, DatasetStatusKey.PROCESSING, DatasetProcessingStatus.PENDING) - - # Download the original dataset from Dropbox - local_filename = self.download_from_source_uri( - source_uri=dataset_uri, - local_path=CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME, - ) - - key_prefix = self.get_key_prefix(dataset_version_id.id) - # Upload the original dataset to the artifact bucket - self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.UPLOADING) - self.create_artifact( - local_filename, - DatasetArtifactType.RAW_H5AD, - key_prefix, - dataset_version_id, - artifact_bucket, - DatasetStatusKey.H5AD, - ) - self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.UPLOADED) - - return local_filename - - def process( - self, dataset_version_id: DatasetVersionId, dataset_uri: str, artifact_bucket: str, sfn_task_token: str - ): - """ - Process the download step of the step function--download raw H5AD locally, upload to artifact bucket, and report - processing job memory estimates to Step Function - - :param dataset_version_id: - :param dataset_uri: - :param artifact_bucket: - :param sfn_task_token: use to report back the memory requirements, if called in a step function - """ - local_filename = self.upload_raw_h5ad(dataset_version_id, dataset_uri, artifact_bucket) - - response = self.create_batch_job_definition_parameters(local_filename, dataset_version_id.id) - self.logger.info(response) - - sfn_client = StepFunctionProvider().client - sfn_client.send_task_success(taskToken=sfn_task_token, output=json.dumps(response)) diff --git a/backend/layers/processing/process_logic.py 
b/backend/layers/processing/process_logic.py index 338948bbf6681..ff4b635b1d2eb 100644 --- a/backend/layers/processing/process_logic.py +++ b/backend/layers/processing/process_logic.py @@ -1,17 +1,22 @@ import logging +import os +import tempfile from datetime import datetime -from os.path import basename, join +from os.path import basename from typing import Callable, List, Optional +from backend.common.utils.dl_sources.uri import DownloadFailed from backend.layers.business.business_interface import BusinessLogicInterface from backend.layers.common.entities import ( + ARTIFACT_TO_EXTENSION, + DatasetArtifactType, DatasetConversionStatus, DatasetStatusGeneric, DatasetStatusKey, DatasetVersion, DatasetVersionId, ) -from backend.layers.processing.exceptions import ConversionFailed +from backend.layers.processing.exceptions import ConversionFailed, UploadFailed from backend.layers.processing.logger import logit from backend.layers.thirdparty.s3_provider import S3ProviderInterface from backend.layers.thirdparty.uri_provider import UriProviderInterface @@ -30,6 +35,11 @@ class ProcessingLogic: # TODO: ProcessingLogicBase def __init__(self) -> None: self.logger = logging.getLogger("processing") + # Store all artifacts in the temp directory so they are cleanup up automatically + self.path = tempfile.TemporaryDirectory() + + def get_file_path(self, *args) -> str: + return os.path.join(self.path.name, *args) def update_processing_status( self, @@ -52,61 +62,63 @@ def update_processing_status( ), ) + @logit + def download_from_source_uri(self, source_uri: str, local_path: str) -> str: + """Given a source URI, download it to local_path. + Handles fixing the url so it downloads directly. + """ + file_url = self.uri_provider.parse(source_uri) + if not file_url: + raise ValueError(f"Malformed source URI: {source_uri}") + try: + file_url.download(local_path) + except DownloadFailed as e: + raise UploadFailed(f"Failed to download file from source URI: {source_uri}") from e + return local_path + def download_from_s3(self, bucket_name: str, object_key: str, local_filename: str): self.s3_provider.download_file(bucket_name, object_key, local_filename) - @staticmethod - def make_s3_uri(artifact_bucket, key_prefix, file_name): - return join("s3://", artifact_bucket, key_prefix, file_name) - - def upload_artifact( - self, - file_name: str, - key_prefix: str, - artifact_bucket: str, - ) -> str: - file_base = basename(file_name) + def upload_artifact(self, file_name: str, key: str, bucket_name: str) -> str: self.s3_provider.upload_file( file_name, - artifact_bucket, - join(key_prefix, file_base), + bucket_name, + key, extra_args={"ACL": "bucket-owner-full-control"}, ) - return self.make_s3_uri(artifact_bucket, key_prefix, file_base) + return "/".join(["s3:/", bucket_name, key]) @logit def create_artifact( self, file_name: str, - artifact_type: str, + artifact_type: DatasetArtifactType, key_prefix: str, dataset_version_id: DatasetVersionId, - artifact_bucket: str, processing_status_key: DatasetStatusKey, + artifact_bucket: str, # If provided, dataset will be uploaded to this bucket for future migrations datasets_bucket: Optional[str] = None, # If provided, dataset will be uploaded to this bucket for public access ): self.update_processing_status(dataset_version_id, processing_status_key, DatasetConversionStatus.UPLOADING) try: - s3_uri = self.upload_artifact(file_name, key_prefix, artifact_bucket) + key = "/".join([key_prefix, basename(file_name)]) + s3_uri = self.upload_artifact(file_name, key, artifact_bucket) 
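# Sketch (illustrative only): the key and S3 URI shapes produced by the reworked
# upload path above. The bucket, prefix, and filename values here are made up.
from os.path import basename

def build_artifact_key(key_prefix, file_name):
    return "/".join([key_prefix, basename(file_name)])

def build_s3_uri(bucket_name, key):
    return "/".join(["s3:/", bucket_name, key])

key = build_artifact_key("rdev-prefix/dataset-version-id", "/tmp/workdir/local.h5ad")
assert key == "rdev-prefix/dataset-version-id/local.h5ad"
assert build_s3_uri("artifact-bucket", key) == "s3://artifact-bucket/rdev-prefix/dataset-version-id/local.h5ad"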
self.logger.info(f"Uploaded [{dataset_version_id}/{file_name}] to {s3_uri}") self.business_logic.add_dataset_artifact(dataset_version_id, artifact_type, s3_uri) self.logger.info(f"Updated database with {artifact_type}.") if datasets_bucket: - key = ".".join((key_prefix, artifact_type)) - self.s3_provider.upload_file( - file_name, datasets_bucket, key, extra_args={"ACL": "bucket-owner-full-control"} - ) - datasets_s3_uri = self.make_s3_uri(datasets_bucket, key_prefix, key) - self.logger.info(f"Uploaded {dataset_version_id}.{artifact_type} to {datasets_s3_uri}") + key = ".".join([key_prefix, ARTIFACT_TO_EXTENSION[artifact_type]]) + s3_uri = self.upload_artifact(file_name, key, datasets_bucket) + self.logger.info(f"Uploaded [{dataset_version_id}/{file_name}] to {s3_uri}") self.update_processing_status(dataset_version_id, processing_status_key, DatasetConversionStatus.UPLOADED) - except Exception: + except Exception as e: + self.logger.error(e) raise ConversionFailed(processing_status_key) from None def convert_file( self, converter: Callable, local_filename: str, - error_message: str, dataset_version_id: DatasetVersionId, processing_status_key: DatasetStatusKey, ) -> str: @@ -122,11 +134,10 @@ def convert_file( return file_dir def get_key_prefix(self, identifier: str) -> str: - import os remote_dev_prefix = os.environ.get("REMOTE_DEV_PREFIX", "") if remote_dev_prefix: - return join(remote_dev_prefix, identifier).strip("/") + return "/".join([remote_dev_prefix, identifier]).strip("/") else: return identifier diff --git a/backend/layers/processing/process_seurat.py b/backend/layers/processing/process_seurat.py deleted file mode 100644 index d0a4fc7dddbf9..0000000000000 --- a/backend/layers/processing/process_seurat.py +++ /dev/null @@ -1,132 +0,0 @@ -import logging -import os -import subprocess - -import anndata - -from backend.layers.business.business_interface import BusinessLogicInterface -from backend.layers.common.entities import ( - DatasetArtifactType, - DatasetConversionStatus, - DatasetStatusKey, - DatasetVersionId, -) -from backend.layers.processing.logger import logit -from backend.layers.processing.process_logic import ProcessingLogic -from backend.layers.processing.utils.matrix_utils import enforce_canonical_format -from backend.layers.processing.utils.rds_citation_from_h5ad import rds_citation_from_h5ad -from backend.layers.thirdparty.s3_provider import S3ProviderInterface -from backend.layers.thirdparty.uri_provider import UriProviderInterface - -logger: logging.Logger = logging.getLogger("processing") - - -class ProcessSeurat(ProcessingLogic): - """ - Base class for handling the `Process Seurat` step of the step function. - This will: - 1. Determine if a Seurat conversion is possible (this is set by the previous step - DownloadAndValidate) - 2. Download the labeled h5ad artifact from S3 (uploaded by DownloadAndValidate) - 3. Convert to RDS - 4. Upload the RDS artifact to S3 - If this step completes successfully, and ProcessCxg is completed, the handle_success lambda will be invoked - If this step fails, the handle_failures lambda will be invoked - """ - - def __init__( - self, - business_logic: BusinessLogicInterface, - uri_provider: UriProviderInterface, - s3_provider: S3ProviderInterface, - ) -> None: - super().__init__() - self.business_logic = business_logic - self.uri_provider = uri_provider - self.s3_provider = s3_provider - - def process(self, dataset_version_id: DatasetVersionId, artifact_bucket: str, datasets_bucket: str): - """ - 1. 
Download the labeled dataset from the artifact bucket - 2. Convert it to Seurat format - 3. Upload the Seurat file to the artifact bucket - :param dataset_version_id: - :param artifact_bucket: - :param datasets_bucket: - :return: - """ - - # If the validator previously marked the dataset as rds_status.SKIPPED, do not start the Seurat processing - dataset = self.business_logic.get_dataset_version(dataset_version_id) - - if dataset is None: - raise Exception("Dataset not found") # TODO: maybe improve - - if dataset.status.rds_status == DatasetConversionStatus.SKIPPED: - self.logger.info("Skipping Seurat conversion") - return - - # Download h5ad locally - labeled_h5ad_filename = "local.h5ad" - key_prefix = self.get_key_prefix(dataset_version_id.id) - object_key = f"{key_prefix}/{labeled_h5ad_filename}" - self.download_from_s3(artifact_bucket, object_key, labeled_h5ad_filename) - - # Convert the citation from h5ad to RDS - adata = anndata.read_h5ad(labeled_h5ad_filename) - if "citation" in adata.uns: - adata.uns["citation"] = rds_citation_from_h5ad(adata.uns["citation"]) - - # enforce for canonical - logger.info("enforce canonical format in X") - enforce_canonical_format(adata) - if adata.raw: - logger.info("enforce canonical format in raw.X") - enforce_canonical_format(adata.raw) - - adata.write_h5ad(labeled_h5ad_filename) - - # Use Seurat to convert to RDS - seurat_filename = self.convert_file( - self.make_seurat, - labeled_h5ad_filename, - "Failed to convert dataset to Seurat format.", - dataset_version_id, - DatasetStatusKey.RDS, - ) - - self.create_artifact( - seurat_filename, - DatasetArtifactType.RDS, - key_prefix, - dataset_version_id, - artifact_bucket, - DatasetStatusKey.RDS, - datasets_bucket=datasets_bucket, - ) - - @logit - def make_seurat(self, local_filename, *args, **kwargs): - """ - Create a Seurat rds file from the AnnData file. 
- """ - try: - completed_process = subprocess.run( - ["Rscript", "-e", "\"installed.packages()[, c('Package', 'Version')]\""], capture_output=True - ) - logger.debug({"stdout": completed_process.stdout, "args": completed_process.args}) - - subprocess.run( - [ - "Rscript", - os.path.join(os.path.abspath(os.path.dirname(__file__)), "make_seurat.R"), - local_filename, - ], - capture_output=True, - check=True, - ) - except subprocess.CalledProcessError as ex: - msg = f"Seurat conversion failed: {ex.output} {ex.stderr}" - self.logger.exception(msg) - raise RuntimeError(msg) from ex - - return local_filename.replace(".h5ad", ".rds") diff --git a/backend/layers/processing/process_validate_atac.py b/backend/layers/processing/process_validate_atac.py new file mode 100644 index 0000000000000..f880e364f68df --- /dev/null +++ b/backend/layers/processing/process_validate_atac.py @@ -0,0 +1,218 @@ +import hashlib +import os + +from backend.common.utils.corpora_constants import CorporaConstants +from backend.layers.business.business_interface import BusinessLogicInterface +from backend.layers.common.entities import ( + ARTIFACT_TO_EXTENSION, + CollectionVersionId, + DatasetArtifactId, + DatasetArtifactType, + DatasetConversionStatus, + DatasetStatusKey, + DatasetValidationStatus, + DatasetVersionId, +) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.exceptions import ConversionFailed, ValidationAtacFailed +from backend.layers.processing.process_logic import ProcessingLogic +from backend.layers.thirdparty.s3_provider_interface import S3ProviderInterface +from backend.layers.thirdparty.schema_validator_provider import SchemaValidatorProviderInterface +from backend.layers.thirdparty.uri_provider import UriProviderInterface + + +class ProcessValidateATAC(ProcessingLogic): + def __init__( + self, + business_logic: BusinessLogicInterface, + uri_provider: UriProviderInterface, + s3_provider: S3ProviderInterface, + schema_validator: SchemaValidatorProviderInterface, + ) -> None: + super().__init__() + self.business_logic = business_logic + self.uri_provider = uri_provider + self.s3_provider = s3_provider + self.schema_validator = schema_validator + + def create_atac_artifact( + self, + file_name: str, + artifact_type: DatasetArtifactType, + dataset_version_id: DatasetVersionId, + datasets_bucket: str, + fragment_artifact_id: DatasetArtifactId = None, + ) -> DatasetArtifactId: + """ + Uploads the file to S3 and updates the database with the artifact using the artifact_id as the prefix for the + key. 
+ :param file_name: the local file to upload + :param artifact_type: the type of artifact to upload + :param dataset_version_id: the dataset version id + :param datasets_bucket: the bucket to upload the dataset to + :param fragment_artifact_id: the artifact id of the fragment file, to be used in the fragment index filepath for storage + :return: + """ + self.update_processing_status(dataset_version_id, DatasetStatusKey.ATAC, DatasetConversionStatus.UPLOADING) + try: + artifact_id = DatasetArtifactId() + if fragment_artifact_id: + key_prefix = self.get_key_prefix(fragment_artifact_id.id) + else: + key_prefix = self.get_key_prefix(artifact_id.id) + key = f"{key_prefix}-fragment.{ARTIFACT_TO_EXTENSION[artifact_type]}" + datasets_s3_uri = self.upload_artifact(file_name, key, datasets_bucket) + self.logger.info(f"Uploaded [{dataset_version_id}/{artifact_type}] to {datasets_s3_uri}") + self.business_logic.add_dataset_artifact(dataset_version_id, artifact_type, datasets_s3_uri, artifact_id) + self.logger.info(f"Updated database with {artifact_type}.") + return artifact_id + except Exception as e: + self.logger.error(e) + raise ConversionFailed( + DatasetStatusKey.ATAC, + ) from None + + def skip_atac_validation( + self, local_anndata_filename: str, manifest: IngestionManifest, dataset_version_id + ) -> bool: + """ + Check if atac validation should be skipped + :param local_anndata_filename: the local anndata file + :param manifest: the manifest + :param dataset_version_id: the dataset version id + :return: True if the validation should be skipped, False otherwise + """ + # check if the anndata should have a fragment file + try: + result = self.schema_validator.check_anndata_requires_fragment(local_anndata_filename) + except ValueError as e: # fragment file forbidden + self.logger.warning(f"Anndata does not support atac fragment files for the following reason: {e}") + if manifest.atac_fragment: + self.update_processing_status( + dataset_version_id, + DatasetStatusKey.VALIDATION, + DatasetValidationStatus.INVALID, + ) + raise ValidationAtacFailed(errors=[str(e), "Fragment file not allowed for non atac anndata."]) from None + self.logger.warning("Fragment validation not applicable for dataset assay type.") + self.update_processing_status( + dataset_version_id, + DatasetStatusKey.ATAC, + DatasetConversionStatus.NA, + validation_errors=[str(e)], + ) + return True + + if manifest.atac_fragment is None: + if result: # fragment file required + self.update_processing_status( + dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.INVALID + ) + raise ValidationAtacFailed(errors=["Anndata requires fragment file"]) + else: # fragment file optional + self.logger.info("Fragment is optional and not present. 
Skipping fragment validation.") + self.update_processing_status( + dataset_version_id, + DatasetStatusKey.ATAC, + DatasetConversionStatus.SKIPPED, + validation_errors=["Fragment is optional and not present."], + ) + return True + return False + + def hash_file(self, file_name: str) -> str: + """ + Hash the file + :param file_name: the file to hash + :return: the hash + """ + hashobj = hashlib.md5() + buffer = bytearray(2**18) + view = memoryview(buffer) + with open(file_name, "rb") as f: + while chunk := f.readinto(buffer): + hashobj.update(view[:chunk]) + return hashobj.hexdigest() + + def process( + self, + collection_version_id: CollectionVersionId, + dataset_version_id: DatasetVersionId, + manifest: IngestionManifest, + datasets_bucket: str, + ): + """ + + :param collection_version_id: + :param dataset_version_id: + :param manifest: + :param datasets_bucket: + :return: + """ + # Download the original dataset files from URI + local_anndata_filename = self.download_from_source_uri( + source_uri=str(manifest.anndata), + local_path=self.get_file_path(CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME), + ) + + if self.skip_atac_validation(local_anndata_filename, manifest, dataset_version_id): + return + + # Download the original fragment file from URI + local_fragment_filename = self.download_from_source_uri( + source_uri=str(manifest.atac_fragment), + local_path=self.get_file_path(CorporaConstants.ORIGINAL_ATAC_FRAGMENT_FILENAME), + ) + + # Validate the fragment with anndata file + try: + errors, fragment_index_file, fragment_file = self.schema_validator.validate_atac( + local_fragment_filename, local_anndata_filename, CorporaConstants.NEW_ATAC_FRAGMENT_FILENAME + ) + except Exception as e: + # for unexpected errors, log the exception and raise a ValidationAtacFailed exception + self.logger.exception("validation failed") + self.update_processing_status(dataset_version_id, DatasetStatusKey.ATAC, DatasetConversionStatus.FAILED) + raise ValidationAtacFailed(errors=[str(e)]) from None + + if errors: + # if the validation fails, update the processing status and raise a ValidationAtacFailed exception + self.update_processing_status(dataset_version_id, DatasetStatusKey.ATAC, DatasetConversionStatus.FAILED) + raise ValidationAtacFailed(errors=errors) + + # Changes to processing only happen during a migration. Only hash the files if the migration is set to true + in_migration = os.environ.get("MIGRATION", "").lower() == "true" + if in_migration: + # check if the new fragment is the same as the old fragment + fragment_unchanged = self.hash_file(local_fragment_filename) == self.hash_file(fragment_file) + else: + fragment_unchanged = False + + # fragment file to avoid uploading the same file multiple times + # if the fragment file is unchanged from a migration or the fragment file is already ingested, use the old fragment. 
+ if fragment_unchanged or (self.business_logic.is_already_ingested(manifest.atac_fragment) and not in_migration): + # get the artifact id of the old fragment, and add it to the new dataset + artifact_name = str(manifest.atac_fragment).split("/")[-1] + artifact = self.business_logic.database_provider.get_artifact_by_uri_suffix(artifact_name) + self.business_logic.database_provider.add_artifact_to_dataset_version(dataset_version_id, artifact.id) + # get the artifact id of the old fragment index, and add it to the new dataset + artifact = self.business_logic.database_provider.get_artifact_by_uri_suffix(artifact_name + ".tbi") + self.business_logic.database_provider.add_artifact_to_dataset_version(dataset_version_id, artifact.id) + self.update_processing_status(dataset_version_id, DatasetStatusKey.ATAC, DatasetConversionStatus.COPIED) + else: + fragment_artifact_id = self.create_atac_artifact( + fragment_file, + DatasetArtifactType.ATAC_FRAGMENT, + dataset_version_id, + datasets_bucket, + ) + self.create_atac_artifact( + fragment_index_file, + DatasetArtifactType.ATAC_INDEX, + dataset_version_id, + datasets_bucket, + fragment_artifact_id, + ) + self.update_processing_status(dataset_version_id, DatasetStatusKey.ATAC, DatasetConversionStatus.UPLOADED) + self.logger.info("Processing completed successfully") + return diff --git a/backend/layers/processing/process_validate_h5ad.py b/backend/layers/processing/process_validate_h5ad.py new file mode 100644 index 0000000000000..4bc16bd129048 --- /dev/null +++ b/backend/layers/processing/process_validate_h5ad.py @@ -0,0 +1,133 @@ +from backend.common.utils.corpora_constants import CorporaConstants +from backend.layers.business.business_interface import BusinessLogicInterface +from backend.layers.common.entities import ( + DatasetArtifactType, + DatasetConversionStatus, + DatasetProcessingStatus, + DatasetStatusKey, + DatasetUploadStatus, + DatasetValidationStatus, + DatasetVersionId, +) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.exceptions import ValidationAnndataFailed +from backend.layers.processing.logger import logit +from backend.layers.processing.process_logic import ProcessingLogic +from backend.layers.thirdparty.s3_provider import S3ProviderInterface +from backend.layers.thirdparty.schema_validator_provider import SchemaValidatorProviderInterface +from backend.layers.thirdparty.uri_provider import UriProviderInterface + + +class ProcessValidateH5AD(ProcessingLogic): + """ + Base class for handling the `Validate` step of the step function. + This will: + 1. Download the h5ad artifact + 2. upload the original file to S3 + 3. Set DatasetStatusKey.H5AD DatasetValidationStatus.VALIDATING + 4. Validate the h5ad + 5. Set DatasetStatusKey.H5AD DatasetValidationStatus.VALID + 6. 
Set the DatasetStatusKey.RDS DatasetConversionStatus.SKIPPED + """ + + schema_validator: SchemaValidatorProviderInterface + + def __init__( + self, + business_logic: BusinessLogicInterface, + uri_provider: UriProviderInterface, + s3_provider: S3ProviderInterface, + schema_validator: SchemaValidatorProviderInterface, + ) -> None: + super().__init__() + self.business_logic = business_logic + self.uri_provider = uri_provider + self.s3_provider = s3_provider + self.schema_validator = schema_validator + + def upload_raw_h5ad( + self, dataset_version_id: DatasetVersionId, anndata_uri: str, artifact_bucket: str, key_prefix: str + ) -> str: + """ + Upload raw h5ad from dataset_uri to artifact bucket + + :param dataset_version_id: + :param anndata_uri: + :param artifact_bucket: + :param key_prefix: + :return: local_filename: Local filepath to raw h5ad + """ + self.update_processing_status(dataset_version_id, DatasetStatusKey.PROCESSING, DatasetProcessingStatus.PENDING) + + # Download the original dataset from Dropbox + local_filename = self.download_from_source_uri( + source_uri=anndata_uri, local_path=self.get_file_path(CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME) + ) + + # Upload the original dataset to the artifact bucket + self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.UPLOADING) + self.create_artifact( + local_filename, + DatasetArtifactType.RAW_H5AD, + key_prefix, + dataset_version_id, + DatasetStatusKey.H5AD, + artifact_bucket, + ) + self.update_processing_status(dataset_version_id, DatasetStatusKey.UPLOAD, DatasetUploadStatus.UPLOADED) + + return local_filename + + @logit + def validate_h5ad_file(self, dataset_version_id: DatasetVersionId, local_filename: str) -> None: + """ + Validates the specified dataset file and updates the processing status in the database + :param dataset_version_id: version ID of the dataset to update + :param local_filename: file name of the dataset to validate and label + :return: boolean indicating if seurat conversion is possible + """ + self.update_processing_status( + dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALIDATING + ) + + try: + is_valid, errors, can_convert_to_seurat = self.schema_validator.validate_anndata(local_filename) + except Exception as e: + self.logger.exception("validation failed") + self.update_processing_status( + dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.INVALID + ) + raise ValidationAnndataFailed([str(e)]) from None + + if not is_valid: + self.update_processing_status( + dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.INVALID + ) + raise ValidationAnndataFailed(errors) + else: + # Skip seurat conversion + self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) + self.update_processing_status(dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.CONVERTING) + + def process( + self, + dataset_version_id: DatasetVersionId, + manifest: IngestionManifest, + artifact_bucket: str, + ): + """ + 1. Download the original dataset from URI + 2. 
Validate + + :param manifest: + :param dataset_version_id: + :param artifact_bucket: + :return: + """ + anndata_uri = str(manifest.anndata) + # validate and upload raw h5ad file to s3 + key_prefix = self.get_key_prefix(dataset_version_id.id) + local_filename = self.upload_raw_h5ad(dataset_version_id, anndata_uri, artifact_bucket, key_prefix) + + # Validate and label the dataset + self.validate_h5ad_file(dataset_version_id, local_filename) diff --git a/backend/layers/processing/schema_migration.py b/backend/layers/processing/schema_migration.py index 8742b5382efa3..ce3130fabb658 100644 --- a/backend/layers/processing/schema_migration.py +++ b/backend/layers/processing/schema_migration.py @@ -14,14 +14,10 @@ CollectionId, CollectionVersion, CollectionVersionId, - DatasetArtifactType, - DatasetConversionStatus, DatasetProcessingStatus, - DatasetStatusKey, - DatasetUploadStatus, - DatasetValidationStatus, DatasetVersionId, ) +from backend.layers.common.ingestion_manifest import S3Url from backend.layers.processing import logger from backend.layers.processing.process_logic import ProcessingLogic from backend.layers.thirdparty.schema_validator_provider import SchemaValidatorProvider @@ -32,12 +28,12 @@ class SchemaMigrate(ProcessingLogic): def __init__(self, business_logic: BusinessLogic, schema_validator: SchemaValidatorProvider): + super().__init__() self.schema_validator = schema_validator self.business_logic = business_logic self.s3_provider = business_logic.s3_provider # For compatiblity with ProcessingLogic self.artifact_bucket = os.environ.get("ARTIFACT_BUCKET", "artifact-bucket") self.execution_id = os.environ.get("EXECUTION_ID", "test-execution-arn") - self.logger = logging.getLogger("processing") self.local_path: str = "." # Used for testing self.limit_migration = os.environ.get("LIMIT_MIGRATION", 0) # Run a small migration for testing self._schema_version = None @@ -95,17 +91,12 @@ def gather_collections(self) -> Tuple[Dict[str, str], Dict[str, str]]: def dataset_migrate( self, collection_version_id: str, collection_id: str, dataset_id: str, dataset_version_id: str ) -> Dict[str, str]: - raw_h5ad_uri = [ - artifact.uri - for artifact in self.business_logic.get_dataset_artifacts(DatasetVersionId(dataset_version_id)) - if artifact.type == DatasetArtifactType.RAW_H5AD - ][0] - source_bucket_name, source_object_key = self.s3_provider.parse_s3_uri(raw_h5ad_uri) - self.s3_provider.download_file(source_bucket_name, source_object_key, "previous_schema.h5ad") - migrated_file = "migrated.h5ad" - reported_changes = self.schema_validator.migrate( - "previous_schema.h5ad", migrated_file, collection_id, dataset_id - ) + manifest = self.business_logic.get_ingestion_manifest(DatasetVersionId(dataset_version_id)) + source_bucket_name, source_object_key = self.s3_provider.parse_s3_uri(str(manifest.anndata)) + previous_file = self.get_file_path("previous_schema.h5ad") + self.s3_provider.download_file(source_bucket_name, source_object_key, previous_file) + migrated_file = self.get_file_path("migrated.h5ad") + reported_changes = self.schema_validator.migrate(previous_file, migrated_file, collection_id, dataset_id) if reported_changes: self._store_sfn_response( "report/migrate_changes", @@ -113,10 +104,13 @@ def dataset_migrate( {f"{collection_id}_{dataset_id}": reported_changes}, ) key_prefix = self.get_key_prefix(dataset_version_id) - uri = self.upload_artifact(migrated_file, key_prefix, self.artifact_bucket) + key = "/".join([key_prefix, "migrated.h5ad"]) + uri = self.upload_artifact(migrated_file, key, 
self.artifact_bucket) + manifest.anndata = S3Url(uri) + manifest_dict = manifest.model_dump() new_dataset_version_id, _ = self.business_logic.ingest_dataset( CollectionVersionId(collection_version_id), - uri, + manifest_dict, file_size=0, # TODO: this shouldn't be needed but it gets around a 404 for HeadObject current_dataset_version_id=DatasetVersionId(dataset_version_id), start_step_function=False, # The schema_migration sfn will start the ingest sfn @@ -125,7 +119,7 @@ def dataset_migrate( return { "collection_version_id": collection_version_id, "dataset_version_id": new_dataset_version_id.id, - "uri": uri, + "manifest": manifest.model_dump_json(), "sfn_name": sfn_name, "execution_id": self.execution_id, } @@ -150,6 +144,7 @@ def collection_migrate( # Generate canonical collection url collection_url = self.business_logic.get_collection_url(version.collection_id.id) + private_collection_version_id = collection_version_id if not datasets: # Handles the case were the collection has no datasets or all datasets are already migrated. @@ -171,8 +166,6 @@ def collection_migrate( CollectionId(collection_id), is_auto_version=True, ).version_id.id - else: - private_collection_version_id = collection_version_id response_for_dataset_migrate = [ { "collection_id": collection_id, @@ -196,11 +189,13 @@ def collection_migrate( response_for_sfn["execution_id"] = self.execution_id self._store_sfn_response( - "log_errors_and_cleanup", version.collection_id.id, response_for_log_errors_and_cleanup + "log_errors_and_cleanup", private_collection_version_id, response_for_log_errors_and_cleanup ) if response_for_dataset_migrate: - key_name = self._store_sfn_response("span_datasets", version.collection_id.id, response_for_dataset_migrate) + key_name = self._store_sfn_response( + "span_datasets", private_collection_version_id, response_for_dataset_migrate + ) response_for_sfn["key_name"] = key_name return (response_for_sfn, response_for_log_errors_and_cleanup, response_for_dataset_migrate) @@ -211,7 +206,7 @@ def log_errors_and_cleanup(self, collection_version_id: str) -> list: object_keys_to_delete = [] # Get the datasets that were processed - extra_info = self._retrieve_sfn_response("log_errors_and_cleanup", collection_version.collection_id.id) + extra_info = self._retrieve_sfn_response("log_errors_and_cleanup", collection_version_id) processed_datasets = {d["dataset_id"]: d["dataset_version_id"] for d in extra_info["datasets"]} # Process datasets errors @@ -232,25 +227,6 @@ def log_errors_and_cleanup(self, collection_version_id: str) -> list: key_prefix = self.get_key_prefix(previous_dataset_version_id) object_keys_to_delete.append(f"{key_prefix}/migrated.h5ad") if dataset.status.processing_status != DatasetProcessingStatus.SUCCESS: - # If only rds failure, set rds status to skipped + processing status to successful and do not rollback - if ( - dataset.status.rds_status == DatasetConversionStatus.FAILED - and dataset.status.upload_status == DatasetUploadStatus.UPLOADED - and dataset.status.validation_status == DatasetValidationStatus.VALID - and dataset.status.cxg_status == DatasetConversionStatus.UPLOADED - and dataset.status.h5ad_status == DatasetConversionStatus.UPLOADED - ): - self.business_logic.update_dataset_version_status( - dataset.version_id, - DatasetStatusKey.RDS, - DatasetConversionStatus.SKIPPED, - ) - self.business_logic.update_dataset_version_status( - dataset.version_id, - DatasetStatusKey.PROCESSING, - DatasetProcessingStatus.SUCCESS, - ) - continue error = { "message": 
dataset.status.validation_message, "dataset_status": dataset.status.to_dict(), diff --git a/backend/layers/processing/submissions/__init__.py b/backend/layers/processing/submissions/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/backend/layers/processing/submissions/app.py b/backend/layers/processing/submissions/app.py deleted file mode 100644 index 205c5545caf84..0000000000000 --- a/backend/layers/processing/submissions/app.py +++ /dev/null @@ -1,132 +0,0 @@ -import logging -import os -import re -import sys -from typing import Optional, Tuple -from urllib.parse import unquote_plus - -from pythonjsonlogger import jsonlogger - -from backend.common.logging_config import DATETIME_FORMAT, LOG_FORMAT -from backend.common.providers.crossref_provider import CrossrefProvider -from backend.common.utils.exceptions import ( - CorporaException, - NonExistentCollectionException, - NonExistentDatasetException, -) -from backend.common.utils.regex import COLLECTION_ID_REGEX, DATASET_ID_REGEX, USERNAME_REGEX -from backend.layers.business.business import BusinessLogic -from backend.layers.business.exceptions import CollectionNotFoundException, DatasetNotFoundException -from backend.layers.persistence.persistence import DatabaseProvider -from backend.layers.thirdparty.batch_job_provider import BatchJobProvider -from backend.layers.thirdparty.s3_provider import S3Provider -from backend.layers.thirdparty.step_function_provider import StepFunctionProvider -from backend.layers.thirdparty.uri_provider import UriProvider - -log_handler = logging.StreamHandler(stream=sys.stdout) -formatter = jsonlogger.JsonFormatter(LOG_FORMAT, DATETIME_FORMAT) -log_handler.setFormatter(formatter) -logger = logging.getLogger() -logger.setLevel(logging.INFO) -logger.handlers = [log_handler] - -REGEX = f"^{USERNAME_REGEX}/{COLLECTION_ID_REGEX}/{DATASET_ID_REGEX}$" - -_business_logic = None - - -def get_business_logic(): - global _business_logic - if not _business_logic: - database_provider = DatabaseProvider() - uri_provider = UriProvider() - step_function_provider = StepFunctionProvider() - s3_provider = S3Provider() - crossref_provider = CrossrefProvider() - batch_job_provider = BatchJobProvider() - _business_logic = BusinessLogic( - database_provider, batch_job_provider, crossref_provider, step_function_provider, s3_provider, uri_provider - ) - return _business_logic - - -def dataset_submissions_handler(s3_event: dict, unused_context) -> None: - """ - Lambda function invoked when a dataset is uploaded to the dataset submissions S3 bucket - :param s3_event: Lambda's event object - :param unused_context: Lambda's context object - :return: - """ - logger.info(dict(message="s3_event", **s3_event)) - logger.debug(dict(REMOTE_DEV_PREFIX=os.environ.get("REMOTE_DEV_PREFIX", ""))) - - for record in s3_event["Records"]: - bucket, key, size = parse_s3_event_record(record) - logger.debug(f"{bucket=}, {key=}, {size=}") - - parsed = parse_key(key) - if not parsed: - raise CorporaException(f"Missing Collection ID and/or Dataset ID for {key=}") - logger.debug(parsed) - - collection_id = parsed["collection_id"] - dataset_id = parsed["dataset_id"] - - business_logic = get_business_logic() - try: - collection_version, dataset_version = business_logic._get_collection_and_dataset(collection_id, dataset_id) - except CollectionNotFoundException: - raise NonExistentCollectionException(f"Collection {parsed['collection_id']} does not exist") from None - except DatasetNotFoundException: - raise 
NonExistentDatasetException(f"No Dataset with {dataset_id=} in Collection {collection_id}") from None - - collection_owner = collection_version.owner - - logger.info(dict(collection_owner=collection_owner, dataset_id=dataset_id)) - if parsed["username"] == "super": - pass - elif parsed["username"] != collection_owner: - raise CorporaException( - f"user:{parsed['username']} does not have permission to modify datasets in collection " - f"{parsed['collection_id']}." - ) - - s3_uri = f"s3://{bucket}/{key}" - - get_business_logic().ingest_dataset( - collection_version_id=collection_version.version_id, - url=s3_uri, - file_size=size, - current_dataset_version_id=dataset_version.version_id, - ) - - -def parse_s3_event_record(s3_event_record: dict) -> Tuple[str, str, int]: - """ - Parses the S3 event record and returns the bucket name, object key and object size - :param s3_event_record: - :return: - """ - bucket = s3_event_record["s3"]["bucket"]["name"] - key = unquote_plus(s3_event_record["s3"]["object"]["key"], encoding="utf-8") - size = s3_event_record["s3"]["object"]["size"] - return bucket, key, size - - -def parse_key(key: str) -> Optional[dict]: - """ - Parses the S3 object key to extract the Collection ID and Dataset ID, ignoring the REMOTE_DEV_PREFIX - - Example of key with dataset id: - s3:///// - - :param key: - :return: - """ - rdev_prefix = os.environ.get("REMOTE_DEV_PREFIX", "").strip("/") - if rdev_prefix: - key = key.replace(f"{rdev_prefix}/", "") - - matched = re.match(REGEX, key) - if matched: - return matched.groupdict() diff --git a/backend/layers/processing/upload_failures/app.py b/backend/layers/processing/upload_failures/app.py index 3a8f2621dac28..7807fe25dad3e 100644 --- a/backend/layers/processing/upload_failures/app.py +++ b/backend/layers/processing/upload_failures/app.py @@ -1,13 +1,14 @@ import json import logging import os -from typing import Optional +from typing import List, Optional from backend.common.corpora_config import CorporaConfig from backend.common.utils.aws import delete_many_from_s3 from backend.common.utils.result_notification import aws_batch_job_url_fmt_str, aws_sfn_url_fmt_str, notify_slack from backend.layers.common.entities import ( CollectionVersionId, + DatasetConversionStatus, DatasetProcessingStatus, DatasetStatusKey, DatasetVersionId, @@ -138,7 +139,9 @@ def get_failure_slack_notification_message( "type": "section", "text": { "type": "mrkdwn", - "text": f"Dataset processing job failed! Please follow the triage steps: https://docs.google.com/document/d/1n5cngEIz-Lqk9737zz3makXGTMrEKT5kN4lsofXPRso/edit#bookmark=id.3ofm47y0709y\n" + "text": f"Dataset processing job failed! Please follow the triage steps: " + f"https://docs.google.com/document/d/1n5cngEIz-Lqk9737zz3makXGTMrEKT5kN4lsofXPRso/edit" + f"#bookmark=id.3ofm47y0709y\n" f"*Owner*: {collection_owner}\n" f"*Collection URL*: {collection_url}\n" f"*Collection Version URL*: {collection_version_url}\n" @@ -181,18 +184,46 @@ def trigger_slack_notification( FAILED_ARTIFACT_CLEANUP_MESSAGE = "Failed to clean up artifacts." FAILED_DATASET_CLEANUP_MESSAGE = "Failed to clean up datasets." FAILED_CXG_CLEANUP_MESSAGE = "Failed to clean up cxgs." +FAILED_ATAC_DATASET_MESSAGE = "Failed to clean up ATAC datasets fragment files. 
artifact_id: {}" + + +def delete_atac_fragment_files(dataset_version_id: str) -> None: + """Delete all atac fragment files and index files from S3 associated with the dataset version""" + dataset = get_business_logic().get_dataset_version(DatasetVersionId(dataset_version_id)) + if not dataset: + # If the dataset is not in the db, don't worry about deleting its files + return + + if dataset.status.atac_status in [ + DatasetConversionStatus.COPIED, + DatasetConversionStatus.SKIPPED, + DatasetConversionStatus.NA, + ]: + # If the dataset is copied, we don't need to delete the files since they are part of another dataset. + # If the dataset is skipped or NA, we don't need to delete the files since they are not created. + return + + object_keys: List[str] = get_business_logic().get_atac_fragment_uris_from_dataset_version(dataset) + for ok in object_keys: + object_key = os.path.join(os.environ.get("REMOTE_DEV_PREFIX", ""), ok) + delete_and_catch_error("DATASETS_BUCKET", object_key, FAILED_ATAC_DATASET_MESSAGE.format(ok)) + + +def delete_and_catch_error(bucket_name: str, object_key: str, error_message: str) -> None: + with logger.LogSuppressed(Exception, message=error_message): + bucket_name = os.environ[bucket_name] + delete_many_from_s3(bucket_name, object_key) def cleanup_artifacts(dataset_version_id: str, error_step_name: Optional[str] = None) -> None: """Clean up artifacts""" + object_key = os.path.join(os.environ.get("REMOTE_DEV_PREFIX", ""), dataset_version_id).strip("/") - if not error_step_name or error_step_name in ["validate", "download"]: - with logger.LogSuppressed(Exception, message=FAILED_ARTIFACT_CLEANUP_MESSAGE): - artifact_bucket = os.environ["ARTIFACT_BUCKET"] - delete_many_from_s3(artifact_bucket, object_key + "/") - with logger.LogSuppressed(Exception, message=FAILED_DATASET_CLEANUP_MESSAGE): - datasets_bucket = os.environ["DATASETS_BUCKET"] - delete_many_from_s3(datasets_bucket, object_key + ".") - with logger.LogSuppressed(Exception, message=FAILED_CXG_CLEANUP_MESSAGE): - cellxgene_bucket = os.environ["CELLXGENE_BUCKET"] - delete_many_from_s3(cellxgene_bucket, object_key + ".cxg/") + + if error_step_name in ["validate_anndata", None]: + delete_and_catch_error("ARTIFACT_BUCKET", object_key + "/", FAILED_ARTIFACT_CLEANUP_MESSAGE) + if error_step_name in ["validate_atac", None]: + delete_atac_fragment_files(dataset_version_id) + + delete_and_catch_error("DATASETS_BUCKET", object_key + ".", FAILED_DATASET_CLEANUP_MESSAGE) + delete_and_catch_error("CELLXGENE_BUCKET", object_key + ".cxg/", FAILED_CXG_CLEANUP_MESSAGE) diff --git a/backend/layers/processing/upload_success/app.py b/backend/layers/processing/upload_success/app.py index 6896fd5cfa4a1..c19299090f2c3 100644 --- a/backend/layers/processing/upload_success/app.py +++ b/backend/layers/processing/upload_success/app.py @@ -3,12 +3,12 @@ from backend.layers.business.business import BusinessLogic from backend.layers.common.entities import DatasetProcessingStatus, DatasetStatusKey, DatasetVersionId from backend.layers.persistence.persistence import DatabaseProvider +from backend.layers.processing import logger from backend.layers.processing.upload_failures.app import handle_failure database_provider = DatabaseProvider() business_logic = BusinessLogic(database_provider, None, None, None, None, None) - -logger = logging.getLogger("processing") +logger.configure_logging(level=logging.INFO) def success_handler(events: dict, context) -> None: @@ -19,13 +19,11 @@ def success_handler(events: dict, context) -> None: :param events: Lambda's event object :param context: Lambda's context
object :return: """ - cxg_job, seurat_job = events["cxg_job"], events["seurat_job"] - cxg_job["execution_id"], seurat_job["execution_id"] = events["execution_id"], events["execution_id"] + cxg_job = events["cxg_job"] + cxg_job["execution_id"] = events["execution_id"] if cxg_job.get("error"): handle_failure(cxg_job, context) - elif seurat_job.get("error"): - handle_failure(seurat_job, context, delete_artifacts=False) else: business_logic.update_dataset_version_status( DatasetVersionId(cxg_job["dataset_version_id"]), diff --git a/backend/layers/processing/utils/cxg_generation_utils.py b/backend/layers/processing/utils/cxg_generation_utils.py index 366cfd024bbd8..c11829abbc437 100644 --- a/backend/layers/processing/utils/cxg_generation_utils.py +++ b/backend/layers/processing/utils/cxg_generation_utils.py @@ -1,11 +1,14 @@ import json import logging +import dask.array as da import numpy as np import pandas as pd import tiledb +from cellxgene_schema.utils import get_matrix_format from backend.common.constants import IS_SINGLE, UNS_SPATIAL_KEY +from backend.layers.processing.utils.dask_utils import TileDBSparseArrayWriteWrapper from backend.layers.processing.utils.spatial import SpatialDataProcessor from backend.layers.processing.utils.type_conversion_utils import get_dtype_and_schema_of_array @@ -154,178 +157,68 @@ def create_ndarray_array(ndarray_name, ndarray): tiledb.consolidate(ndarray_name, ctx=ctx) -def _sort_by_primary_var_and_secondary_obs(data_dict): - ix = np.argsort(data_dict["var"]) - x = data_dict["obs"][ix] - y = data_dict["var"][ix] - d = data_dict[""][ix] - - df = pd.DataFrame() - df["x"] = x - df["y"] = y - df["d"] = d - - gb = df.groupby("y") - - xs = [] - ds = [] - for k in gb.groups: - ix = np.argsort(x[gb.groups[k]]) - xs.extend(x[gb.groups[k]][ix]) - ds.extend(d[gb.groups[k]][ix]) - xs = np.array(xs) - ds = np.array(ds) - return xs, y, ds - - -def _sort_by_primary_obs_and_secondary_var(data_dict): - ix = np.argsort(data_dict["obs"]) - x = data_dict["obs"][ix] - y = data_dict["var"][ix] - d = data_dict[""][ix] - - df = pd.DataFrame() - df["x"] = x - df["y"] = y - df["d"] = d - - gb = df.groupby("x") - - ys = [] - ds = [] - for k in gb.groups: - ix = np.argsort(y[gb.groups[k]]) - ys.extend(y[gb.groups[k]][ix]) - ds.extend(d[gb.groups[k]][ix]) - ys = np.array(ys) - ds = np.array(ds) - return x, ys, ds - - -def convert_matrices_to_cxg_arrays(matrix_name, matrix, encode_as_sparse_array, ctx): +def convert_matrices_to_cxg_arrays(matrix_name: str, matrix: da.Array, encode_as_sparse_array: bool, ctx: tiledb.Ctx): """ Converts a numpy array matrix into a TileDB SparseArray of DenseArray based on whether `encode_as_sparse_array` is true or not. Note that when the matrix is encoded as a SparseArray, it only writes the values that are nonzero. This means that if you count the number of elements in the SparseArray, it will not equal the total number of elements in the matrix, only the number of nonzero elements. 
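# Illustration of the point above, using scipy purely as an example (the converter
# itself operates on dask arrays): a sparse encoding stores only the nonzero cells,
# so its element count differs from the dense shape.
import numpy as np
from scipy import sparse

dense = np.array([[0, 3, 0], [0, 0, 5]], dtype=np.float32)
coo = sparse.coo_matrix(dense)
assert dense.size == 6  # total cells in the dense layout
assert coo.nnz == 2     # cells a SparseArray would actually write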
""" - - def create_matrix_array( - matrix_name, number_of_rows, number_of_columns, encode_as_sparse_array, compression=22, row=True - ): - logging.info(f"create {matrix_name}") - dim_filters = tiledb.FilterList([tiledb.ByteShuffleFilter(), tiledb.ZstdFilter(level=compression)]) - if not encode_as_sparse_array: - attrs = [tiledb.Attr(dtype=np.float32, filters=tiledb.FilterList([tiledb.ZstdFilter(level=compression)]))] - domain = tiledb.Domain( - tiledb.Dim( - name="obs", - domain=(0, number_of_rows - 1), - tile=min(number_of_rows, 256), - dtype=np.uint32, - filters=dim_filters, - ), - tiledb.Dim( - name="var", - domain=(0, number_of_columns - 1), - tile=min(number_of_columns, 2048), - dtype=np.uint32, - filters=dim_filters, - ), - ) - schema = tiledb.ArraySchema( - domain=domain, - sparse=False, - allows_duplicates=False, - attrs=attrs, - cell_order="row-major", - tile_order="col-major", - capacity=0, - ) - tiledb.Array.create(matrix_name, schema) - else: - attrs = [tiledb.Attr(dtype=np.float32, filters=tiledb.FilterList([tiledb.ZstdFilter(level=compression)]))] - if row: - domain = tiledb.Domain( - tiledb.Dim( - name="obs", - domain=(0, number_of_rows - 1), - tile=min(number_of_rows, 256), - dtype=np.uint32, - filters=dim_filters, - ) - ) - attrs.append(tiledb.Attr(name="var", dtype=np.uint32, filters=dim_filters)) - else: - domain = tiledb.Domain( - tiledb.Dim( - name="var", - domain=(0, number_of_columns - 1), - tile=min(number_of_columns, 256), - dtype=np.uint32, - filters=dim_filters, - ), - ) - attrs.append(tiledb.Attr(name="obs", dtype=np.uint32, filters=dim_filters)) - - schema = tiledb.ArraySchema( - domain=domain, - sparse=True, - allows_duplicates=True, - attrs=attrs, - cell_order="row-major", - tile_order="row-major", - capacity=1024000, - ) - tiledb.Array.create(matrix_name, schema) - number_of_rows = matrix.shape[0] number_of_columns = matrix.shape[1] - stride_rows = min(int(np.power(10, np.around(np.log10(1e9 / number_of_columns)))), 10_000) - stride_columns = min(int(np.power(10, np.around(np.log10(1e9 / number_of_rows)))), 2_000) - - if not encode_as_sparse_array: - create_matrix_array(matrix_name, number_of_rows, number_of_columns, False) - with tiledb.open(matrix_name, mode="w", ctx=ctx) as array: - for start_row_index in range(0, number_of_rows, stride_rows): - end_row_index = min(start_row_index + stride_rows, number_of_rows) - matrix_subset = matrix[start_row_index:end_row_index, :] - if not isinstance(matrix_subset, np.ndarray): - matrix_subset = matrix_subset.toarray() - array[start_row_index:end_row_index, :] = matrix_subset + compression = 22 + + logging.info(f"create {matrix_name}") + dim_filters = tiledb.FilterList([tiledb.ByteShuffleFilter(), tiledb.ZstdFilter(level=compression)]) + attrs = [tiledb.Attr(dtype=np.float32, filters=tiledb.FilterList([tiledb.ZstdFilter(level=compression)]))] + + tiledb_obs_dim = tiledb.Dim( + name="obs", + domain=(0, number_of_rows - 1), + tile=min(number_of_rows, 256), + dtype=np.uint32, + filters=dim_filters, + ) + tiledb_var_dim = tiledb.Dim( + name="var", + domain=(0, number_of_columns - 1), + tile=min(number_of_columns, 2048), + dtype=np.uint32, + filters=dim_filters, + ) + domain = tiledb.Domain(tiledb_obs_dim, tiledb_var_dim) + + if encode_as_sparse_array: + array_schema_params = dict( + sparse=True, + allows_duplicates=True, + capacity=1024000, + ) + else: + array_schema_params = dict( + sparse=False, + allows_duplicates=False, + capacity=0, + ) + schema = tiledb.ArraySchema( + domain=domain, + attrs=attrs, + 
cell_order="row-major", + tile_order="col-major", + **array_schema_params, + ) + tiledb.Array.create(matrix_name, schema) + + if encode_as_sparse_array: + matrix_write = TileDBSparseArrayWriteWrapper(matrix_name, ctx=ctx) + matrix.store(matrix_write, lock=False, compute=True) else: - create_matrix_array(matrix_name + "r", number_of_rows, number_of_columns, True, row=True) - create_matrix_array(matrix_name + "c", number_of_rows, number_of_columns, True, row=False) - array_r = tiledb.open(matrix_name + "r", mode="w", ctx=ctx) - array_c = tiledb.open(matrix_name + "c", mode="w", ctx=ctx) - - logging.info(f"Store rows: {number_of_rows}") - for start_row_index in range(0, number_of_rows, stride_rows): - end_row_index = min(start_row_index + stride_rows, number_of_rows) - matrix_subset = matrix[start_row_index:end_row_index, :] - if not isinstance(matrix_subset, np.ndarray): - matrix_subset = matrix_subset.toarray() - - indices = np.nonzero(matrix_subset) - trow = indices[0] + start_row_index - t_data = matrix_subset[indices[0], indices[1]] - data_dict = {"obs": trow, "var": indices[1], "": t_data} - obs, var, data = _sort_by_primary_obs_and_secondary_var(data_dict) - array_r[obs] = {"var": var, "": data} - - logging.info(f"Store columns: {number_of_columns}") - for start_col_index in range(0, number_of_columns, stride_columns): - end_col_index = min(start_col_index + stride_columns, number_of_columns) - matrix_subset = matrix[:, start_col_index:end_col_index] - if not isinstance(matrix_subset, np.ndarray): - matrix_subset = matrix_subset.toarray() - - indices = np.nonzero(matrix_subset) - tcol = indices[1] + start_col_index - t_data = matrix_subset[indices[0], indices[1]] - data_dict = {"obs": indices[0], "var": tcol, "": t_data} - obs, var, data = _sort_by_primary_var_and_secondary_obs(data_dict) - array_c[var] = {"obs": obs, "": data} - - array_r.close() - array_c.close() + # if matrix is a scipy sparse matrix but encode_as_sparse_array is False, convert to dense array + if get_matrix_format(matrix) != "dense": + matrix = matrix.map_blocks( + lambda x: x.toarray().astype(np.float32), dtype=np.float32, meta=np.array([], dtype=np.float32) + ) + elif matrix.dtype != np.float32: + matrix = matrix.map_blocks(lambda x: x.astype(np.float32), dtype=np.float32) + with tiledb.open(matrix_name, "w") as A: + matrix.to_tiledb(A, storage_options={"ctx": ctx}) diff --git a/backend/layers/processing/utils/dask_utils.py b/backend/layers/processing/utils/dask_utils.py new file mode 100644 index 0000000000000..5563650c8c0f2 --- /dev/null +++ b/backend/layers/processing/utils/dask_utils.py @@ -0,0 +1,22 @@ +import logging + +import tiledb +from scipy import sparse + +logger = logging.getLogger(__name__) + + +class TileDBSparseArrayWriteWrapper: + def __init__(self, uri, *, ctx=None): + self.uri = uri + self.ctx = {} + self.ctx.update(**ctx or {}) + + def __setitem__(self, k: tuple[slice, ...], v: sparse.spmatrix): + with tiledb.scope_ctx(self.ctx): + row_slice, col_slice = k + row_offset = row_slice.start if row_slice.start is not None else 0 + col_offset = col_slice.start if col_slice.start is not None else 0 + v_coo = v.tocoo() + tiledb_array = tiledb.open(self.uri, mode="w") + tiledb_array[v_coo.row + row_offset, v_coo.col + col_offset] = v.data diff --git a/backend/layers/processing/utils/matrix_utils.py b/backend/layers/processing/utils/matrix_utils.py index 0572c63230500..fc85639e6b4f6 100644 --- a/backend/layers/processing/utils/matrix_utils.py +++ b/backend/layers/processing/utils/matrix_utils.py @@ 
-1,63 +1,25 @@ import logging -import numpy as np +import dask.array as da +from cellxgene_schema.validate import Validator logger: logging.Logger = logging.getLogger("matrix_utils") -def is_matrix_sparse(matrix: np.ndarray, sparse_threshold): +def is_matrix_sparse(matrix: da.Array, sparse_threshold: float) -> bool: """ Returns whether `matrix` is sparse or not (i.e. dense). This is determined by figuring out whether the matrix has a sparsity percentage below the sparse_threshold, returning the number of non-zeros encountered and number of - elements evaluated. This function may return before evaluating the whole matrix if it can be determined that matrix - is not sparse enough. + elements evaluated. """ - if sparse_threshold == 100.0: return True if sparse_threshold == 0.0: return False - total_number_of_rows = matrix.shape[0] - total_number_of_columns = matrix.shape[1] - total_number_of_matrix_elements = total_number_of_rows * total_number_of_columns - - # For efficiency, we count the number of non-zero elements in chunks of the matrix at a time until we hit the - # maximum number of non zero values allowed before the matrix is deemed "dense." This allows the function the - # quit early for large dense matrices. - row_stride = min(int(np.power(10, np.around(np.log10(1e9 / total_number_of_columns)))), 10_000) - - maximum_number_of_non_zero_elements_in_matrix = int( - total_number_of_rows * total_number_of_columns * sparse_threshold / 100 - ) - number_of_non_zero_elements = 0 - - for start_row_index in range(0, total_number_of_rows, row_stride): - end_row_index = min(start_row_index + row_stride, total_number_of_rows) - - matrix_subset = matrix[start_row_index:end_row_index, :] - if not isinstance(matrix_subset, np.ndarray): - matrix_subset = matrix_subset.toarray() - - number_of_non_zero_elements += np.count_nonzero(matrix_subset) - if number_of_non_zero_elements > maximum_number_of_non_zero_elements_in_matrix: - if end_row_index != total_number_of_rows: - percentage_of_non_zero_elements = ( - 100 * number_of_non_zero_elements / (end_row_index * total_number_of_columns) - ) - logging.info( - f"Matrix is not sparse. Percentage of non-zero elements (estimate): " - f"{percentage_of_non_zero_elements:6.2f}" - ) - else: - percentage_of_non_zero_elements = 100 * number_of_non_zero_elements / total_number_of_matrix_elements - logging.info( - f"Matrix is not sparse. 
Percentage of non-zero elements (exact): " - f"{percentage_of_non_zero_elements:6.2f}" - ) - return False - - is_sparse = (100.0 * number_of_non_zero_elements / total_number_of_matrix_elements) < sparse_threshold + total_number_of_matrix_elements = matrix.shape[0] * matrix.shape[1] + number_of_non_zero_elements = Validator.count_matrix_nonzero(matrix) + is_sparse = (100.0 * (number_of_non_zero_elements / total_number_of_matrix_elements)) < sparse_threshold return is_sparse diff --git a/backend/layers/thirdparty/schema_validator_provider.py b/backend/layers/thirdparty/schema_validator_provider.py index 4797c7c135b93..b7dbe6c0c78a9 100644 --- a/backend/layers/thirdparty/schema_validator_provider.py +++ b/backend/layers/thirdparty/schema_validator_provider.py @@ -1,12 +1,14 @@ -from typing import List, Protocol, Tuple +from typing import List, Optional, Protocol, Tuple from cellxgene_schema import validate from cellxgene_schema.migrate import migrate from cellxgene_schema.schema import get_current_schema_version +from backend.layers.processing.exceptions import AddLabelsFailed + class SchemaValidatorProviderInterface(Protocol): - def validate_and_save_labels(self, input_file: str, output_file: str) -> Tuple[bool, list, bool]: + def validate_anndata(self, input_file: str) -> Tuple[bool, list, bool]: pass def migrate(self, input_file: str, output_file: str, collection_id: str, dataset_id: str) -> List[str]: @@ -21,19 +23,41 @@ def get_current_schema_version(self) -> str: """ pass + def add_labels(self, input_file: str, output_file: str) -> None: + """ + Adds labels to the provided `input_file` and writes the result to `output_file`. + """ + pass + + def validate_atac(self, fragment_file, anndata_file, output_file) -> Tuple[Optional[List[str]], str, str]: + """ + Validates an ATAC fragment file against an anndata file. + + Returns a tuple that contains, in order: + 1. A List[str] with the validation errors. This is only defined if the first boolean is false + 2. The path to the index file + 3. The path to the fragment file + """ + pass + + def check_anndata_requires_fragment(self, anndata_file) -> bool: + """ + Check if an anndata file requires a fragment file + """ + pass + class SchemaValidatorProvider(SchemaValidatorProviderInterface): - def validate_and_save_labels(self, input_file: str, output_file: str) -> Tuple[bool, list, bool]: + def validate_anndata(self, input_file: str) -> Tuple[bool, list, bool]: """ - Runs `cellxgene-schema validate` on the provided `input_file`. This also saves a labeled copy - of the artifact to `output_file`. + Runs `cellxgene-schema validate` on the provided `input_file`. Returns a tuple that contains, in order: 1. A boolean that indicates whether the artifact is valid 2. A List[str] with the validation errors. This is only defined if the first boolean is false 3. A boolean that indicates whether the artifact is Seurat convertible """ - return validate.validate(input_file, output_file) + return validate.validate(input_file) def migrate(self, input_file, output_file, collection_id, dataset_id) -> List[str]: """ @@ -43,3 +67,41 @@ def migrate(self, input_file, output_file, collection_id, dataset_id) -> List[st def get_current_schema_version(self) -> str: return get_current_schema_version() + + def add_labels(self, input_file: str, output_file: str) -> None: + """ + Adds labels to the provided `input_file` and writes the result to `output_file`. 
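# Rough usage sketch of the dask-based sparsity threshold check introduced in
# matrix_utils above; da.count_nonzero stands in for Validator.count_matrix_nonzero
# purely for illustration.
import dask.array as da
import numpy as np

x = da.from_array(np.eye(4, dtype=np.float32), chunks=(2, 2))  # 4 nonzeros out of 16 cells
percent_nonzero = 100.0 * int(da.count_nonzero(x).compute()) / (x.shape[0] * x.shape[1])
assert percent_nonzero == 25.0
assert percent_nonzero < 50.0  # under a 50% sparse_threshold this matrix counts as sparse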
+ """ + from cellxgene_schema.utils import read_h5ad + from cellxgene_schema.write_labels import AnnDataLabelAppender + + adata = read_h5ad(input_file) + anndata_label_adder = AnnDataLabelAppender(adata) + if not anndata_label_adder.write_labels(output_file): + raise AddLabelsFailed(anndata_label_adder.errors) + + def count_matrix_nonzero(self, matrix): + return validate.Validator.count_matrix_nonzero(matrix) + + def validate_atac(self, fragment_file, anndata_file, output_file) -> Tuple[Optional[List[str]], str, str]: + """ + Validates an ATAC fragment file against an anndata file. + + Returns a tuple that contains, in order: + 1. A List[str] with the validation errors. This is only defined if the first boolean is false + 2. The path to the index file + 3. The path to the fragment file + """ + import cellxgene_schema.atac_seq as atac_seq + + index_file = output_file + ".tbi" + return ( + atac_seq.process_fragment(fragment_file, anndata_file, True, output_file=output_file), + index_file, + output_file, + ) + + def check_anndata_requires_fragment(self, anndata_file) -> bool: + import cellxgene_schema.atac_seq as atac_seq + + return atac_seq.check_anndata_requires_fragment(anndata_file) diff --git a/backend/layers/thirdparty/step_function_provider.py b/backend/layers/thirdparty/step_function_provider.py index cddf91428194b..73215291e437c 100644 --- a/backend/layers/thirdparty/step_function_provider.py +++ b/backend/layers/thirdparty/step_function_provider.py @@ -16,7 +16,7 @@ def sfn_name_generator(dataset_version_id: DatasetVersionId, prefix=None) -> str class StepFunctionProviderInterface: def start_step_function( - self, version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, url: str + self, version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, manifest: dict ) -> None: pass @@ -26,15 +26,15 @@ def __init__(self) -> None: self.client = boto3.client("stepfunctions") def start_step_function( - self, version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, url: str + self, version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, manifest: dict ) -> None: """ - Starts a step function that will ingest the dataset `dataset_version_id` using the artifact - located at `url` + Starts a step function that will ingest the dataset `dataset_version_id` and all artifact specified in the + manifest. 
""" input_parameters = { "collection_version_id": version_id.id, - "url": url, + "manifest": manifest, "dataset_version_id": dataset_version_id.id, } sfn_name = sfn_name_generator(dataset_version_id) diff --git a/backend/portal/api/portal-api.yml b/backend/portal/api/portal-api.yml index 0ec146eabf858..ff9d409e1075f 100644 --- a/backend/portal/api/portal-api.yml +++ b/backend/portal/api/portal-api.yml @@ -700,6 +700,18 @@ paths: enum: [VALIDATING, VALID, INVALID, NA] validation_message: type: string + atac_status: + type: string + enum: + [ + CONVERTING, + CONVERTED, + UPLOADING, + UPLOADED, + SKIPPED, + FAILED, + NA, + ] h5ad_status: type: string enum: @@ -1279,7 +1291,7 @@ components: type: string filetype: type: string - enum: [RAW_H5AD, H5AD, RDS, CXG] + enum: [RAW_H5AD, H5AD, RDS, CXG, ATAC_FRAGMENT, ATAC_INDEX] filename: type: string s3_uri: diff --git a/backend/portal/api/portal_api.py b/backend/portal/api/portal_api.py index 17442d0383802..65d18075a5e83 100644 --- a/backend/portal/api/portal_api.py +++ b/backend/portal/api/portal_api.py @@ -93,6 +93,7 @@ def _dataset_processing_status_to_response(status: DatasetStatus, dataset_id: st Converts a DatasetStatus object to an object compliant to the API specifications """ return { + "atac_status": status.atac_status or "NA", "created_at": 0, # NA "cxg_status": status.cxg_status or "NA", "dataset_id": dataset_id, @@ -800,6 +801,7 @@ def get_status(dataset_id: str, token_info: dict): version, _ = _assert_dataset_has_right_owner(DatasetVersionId(dataset_id), UserInfo(token_info)) response = { + "atac_status": version.status.atac_status or "NA", "cxg_status": version.status.cxg_status or "NA", "rds_status": version.status.rds_status or "NA", "h5ad_status": version.status.h5ad_status or "NA", diff --git a/backend/scripts/mirror_rds_data.sh b/backend/scripts/mirror_rds_data.sh index 325ffc93120c1..327e72a78565c 100755 --- a/backend/scripts/mirror_rds_data.sh +++ b/backend/scripts/mirror_rds_data.sh @@ -71,3 +71,4 @@ EOF ) fi +make db/connect ARGS="${DB_UPDATE_CMDS}" \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 486ed046a4ae8..4cdd214f8de9b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -146,45 +146,8 @@ services: corporanet: aliases: - uploadsuccess.corporanet.local - dataset_submissions: - image: "${DOCKER_REPO}dataset-submissions" - platform: linux/amd64 - profiles: - - dataset_submissions - build: - context: . 
- cache_from: - - "${DOCKER_REPO}dataset-submissions:branch-main" - dockerfile: Dockerfile.dataset_submissions - args: - - BUILDKIT_INLINE_CACHE=1 - - HAPPY_COMMIT=$HAPPY_COMMIT - - HAPPY_BRANCH=$HAPPY_BRANCH - - HAPPY_TAG - restart: "no" - ports: - - "9002:8080" - volumes: - - ./backend/portal/pipelines/submissions:/var/task - - ./backend/common:/var/task/backend/common - environment: - - PYTHONUNBUFFERED=1 - - CORPORA_LOCAL_DEV=true - - AWS_REGION=us-west-2 - - AWS_DEFAULT_REGION=us-west-2 - - AWS_ACCESS_KEY_ID=test - - AWS_SECRET_ACCESS_KEY=test - - BOTO_ENDPOINT_URL=http://localstack:4566 - - DEPLOYMENT_STAGE=test - - ARTIFACT_BUCKET=artifact-bucket - - CELLXGENE_BUCKET=cellxgene-bucket - networks: - corporanet: - aliases: - - dataset_submissions.corporanet.local processing: image: "${DOCKER_REPO}corpora-upload" - platform: linux/amd64 profiles: - processing build: diff --git a/python_dependencies/backend/common/server/requirements.txt b/python_dependencies/backend/common/server/requirements.txt index b5465d1f437b7..7145644301955 100644 --- a/python_dependencies/backend/common/server/requirements.txt +++ b/python_dependencies/backend/common/server/requirements.txt @@ -5,4 +5,4 @@ Flask-Cors~=3.0.6 flask-server-timing~=0.1.2 ddtrace~=2.8.5 python-json-logger -boto3~=1.34.114 \ No newline at end of file +boto3~=1.34.114 diff --git a/python_dependencies/processing/requirements.txt b/python_dependencies/processing/requirements.txt index 20dd6753ce6f2..a773cfed894d6 100644 --- a/python_dependencies/processing/requirements.txt +++ b/python_dependencies/processing/requirements.txt @@ -1,8 +1,8 @@ -anndata==0.10.8 +anndata==0.11.2 awscli boto3>=1.11.17 -botocore>=1.14.17 -cellxgene-schema +git+https://github.com/chanzuckerberg/single-cell-curation/@main#subdirectory=cellxgene_schema_cli +dask==2024.12.0 dataclasses-json ddtrace==2.1.4 numba==0.59.1 diff --git a/python_dependencies/upload_handler/requirements.txt b/python_dependencies/upload_handler/requirements.txt index c7284f10949c9..9ee88c9c7bdb3 100644 --- a/python_dependencies/upload_handler/requirements.txt +++ b/python_dependencies/upload_handler/requirements.txt @@ -1,5 +1,6 @@ boto3 dataclasses-json +pydantic<3 psycopg2-binary==2.* pyrsistent python-json-logger==2.0.7 diff --git a/scripts/cxg_admin.py b/scripts/cxg_admin.py index 073a0a98d993b..37ba3b70222e8 100755 --- a/scripts/cxg_admin.py +++ b/scripts/cxg_admin.py @@ -238,19 +238,7 @@ def refresh_preprint_doi(ctx): updates.refresh_preprint_doi(ctx) -# Commands to reprocess dataset artifacts (seurat or cxg) - - -@cli.command() -@click.argument("dataset_id") -@click.pass_context -def reprocess_seurat(ctx: click.Context, dataset_id: str) -> None: - """ - Reconverts the specified dataset to Seurat format in place. 
- :param ctx: command context - :param dataset_id: ID of dataset to reconvert to Seurat format - """ - reprocess_datafile.reprocess_seurat(ctx, dataset_id) +# Commands to reprocess cxg dataset artifacts @cli.command() diff --git a/scripts/cxg_admin_scripts/reprocess_datafile.py b/scripts/cxg_admin_scripts/reprocess_datafile.py index 8c0aed6039ef7..efbecf7f36269 100644 --- a/scripts/cxg_admin_scripts/reprocess_datafile.py +++ b/scripts/cxg_admin_scripts/reprocess_datafile.py @@ -1,12 +1,8 @@ -import json import logging import os import sys -from time import time import boto3 -import click -from click import Context pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "...")) # noqa sys.path.insert(0, pkg_root) # noqa @@ -33,38 +29,3 @@ def get_happy_stack_name(deployment) -> str: def cxg_remaster(ctx): """Cxg remaster v2""" pass - - -def reprocess_seurat(ctx: Context, dataset_id: str) -> None: - """ - Reconverts the specified dataset to Seurat format in place. - :param ctx: command context - :param dataset_id: ID of dataset to reconvert to Seurat format - """ - - deployment = ctx.obj["deployment"] - - click.confirm( - f"Are you sure you want to run this script? " - f"It will reconvert and replace the dataset {dataset_id} to Seurat in the {deployment} environment.", - abort=True, - ) - - aws_account_id = get_aws_account_id() - deployment = ctx.obj["deployment"] - happy_stack_name = get_happy_stack_name(deployment) - - payload = {"dataset_id": dataset_id} - - client = boto3.client("stepfunctions") - response = client.start_execution( - stateMachineArn=f"arn:aws:states:us-west-2:{aws_account_id}:stateMachine:dp-{happy_stack_name}-seurat-sfn", - name=f"{dataset_id}-{int(time())}", - input=json.dumps(payload), - ) - - click.echo( - f"Step function executing: " - f"https://us-west-2.console.aws.amazon.com/states/home?region=us-west-2#/executions/details/" - f"{response['executionArn']}" - ) diff --git a/scripts/smoke_tests/setup.py b/scripts/smoke_tests/setup.py index d39d83430fe32..f2420cd0d7105 100644 --- a/scripts/smoke_tests/setup.py +++ b/scripts/smoke_tests/setup.py @@ -1,24 +1,26 @@ #!/usr/bin/env python import json import os -import sys import threading from backend.common.constants import DATA_SUBMISSION_POLICY_VERSION from backend.common.corpora_config import CorporaAuthConfig -from tests.functional.backend.constants import API_URL, DATASET_URI +from tests.functional.backend.constants import API_URL, ATAC_SEQ_MANIFEST, DATASET_MANIFEST, VISIUM_DATASET_MANIFEST from tests.functional.backend.utils import ( get_auth_token, + get_curation_api_access_token, make_cookie, make_proxy_auth_token, make_session, - upload_and_wait, + upload_manifest_and_wait, ) # Amount to reduce chance of collision where multiple test instances select the same collection to test against NUM_TEST_DATASETS = 3 NUM_TEST_COLLECTIONS = 10 TEST_ACCT_CONTACT_NAME = "Smoke Test User" +VISIUM_ACCT_CONTACT_NAME = "Visium Test User" +ATAC_SEQ_ACCT_CONTACT_NAME = "ATAC Seq Test User" class SmokeTestsInitializer: @@ -31,42 +33,63 @@ def __init__(self): username, password = self.config.test_account_username, self.config.test_account_password auth_token = get_auth_token(username, password, self.session, self.config, self.deployment_stage) self.curator_cookie = make_cookie(auth_token) + self.curation_api_access_token = get_curation_api_access_token(self.session, self.api, self.config) + self.manifests = [DATASET_MANIFEST, VISIUM_DATASET_MANIFEST, ATAC_SEQ_MANIFEST] self.headers = {"Cookie": 
f"cxguser={self.curator_cookie}", "Content-Type": "application/json"} + self.cached_get_collections_response = None - def get_collection_count(self): - res = self.session.get(f"{self.api}/curation/v1/collections?visiblity=PUBLIC", headers=self.headers) - res.raise_for_status() - data = json.loads(res.content) + def get_collection_count(self, contact_name: str, expected_count: int) -> int: + if self.cached_get_collections_response is None: + res = self.session.get(f"{self.api}/curation/v1/collections?visiblity=PUBLIC", headers=self.headers) + res.raise_for_status() + self.cached_get_collections_response = json.loads(res.content) num_collections = 0 - for collection in data: - if collection["contact_name"] == TEST_ACCT_CONTACT_NAME: + for collection in self.cached_get_collections_response: + if collection["contact_name"] == contact_name: num_collections += 1 - if num_collections == NUM_TEST_COLLECTIONS: + if num_collections == expected_count: return num_collections return num_collections - def create_and_publish_collection(self, dropbox_url): - collection_id = self.create_collection() + def start_upload_thread(self, collection_id, manifest) -> threading.Thread: + _thread = threading.Thread( + target=upload_manifest_and_wait, + args=( + self.session, + self.api, + self.curation_api_access_token, + self.curator_cookie, + collection_id, + manifest, + ), + ) + _thread.start() + return _thread + + def create_and_publish_collection( + self, + contact_name=TEST_ACCT_CONTACT_NAME, + collection_name="test collection", + manifest=DATASET_MANIFEST, + num_datasets=NUM_TEST_DATASETS, + ): + collection_id = self.create_collection(contact_name, collection_name) _threads = [] - for _ in range(NUM_TEST_DATASETS): - _thread = threading.Thread( - target=upload_and_wait, args=(self.session, self.api, self.curator_cookie, collection_id, dropbox_url) - ) - _threads.append(_thread) - _thread.start() + for _ in range(num_datasets): + _threads.append(self.start_upload_thread(collection_id, manifest)) for _thread in _threads: _thread.join() self.publish_collection(collection_id) print(f"created and published collection {collection_id}") - def create_collection(self): + def create_collection(self, contact_name, collection_name): data = { "contact_email": "example@gmail.com", - "contact_name": TEST_ACCT_CONTACT_NAME, + "contact_name": contact_name, "curator_name": "John Smith", "description": "Well here are some words", "links": [{"link_name": "a link to somewhere", "link_type": "PROTOCOL", "link_url": "http://protocol.com"}], - "name": "test collection", + "name": collection_name, } res = self.session.post(f"{self.api}/dp/v1/collections", data=json.dumps(data), headers=self.headers) @@ -85,14 +108,34 @@ def publish_collection(self, collection_id): if __name__ == "__main__": smoke_test_init = SmokeTestsInitializer() # check whether we need to create collections - collection_count = smoke_test_init.get_collection_count() - if collection_count >= NUM_TEST_COLLECTIONS: - sys.exit(0) - num_to_create = NUM_TEST_COLLECTIONS - collection_count + test_collection_count = smoke_test_init.get_collection_count(TEST_ACCT_CONTACT_NAME, NUM_TEST_COLLECTIONS) + visium_collection_count = smoke_test_init.get_collection_count(VISIUM_ACCT_CONTACT_NAME, 1) + atac_seq_collection_count = smoke_test_init.get_collection_count(ATAC_SEQ_ACCT_CONTACT_NAME, 1) + threads = [] - for _ in range(num_to_create): - thread = threading.Thread(target=smoke_test_init.create_and_publish_collection, args=(DATASET_URI,)) + + if test_collection_count < 
NUM_TEST_COLLECTIONS: + num_to_create = NUM_TEST_COLLECTIONS - test_collection_count + + for _ in range(num_to_create): + thread = threading.Thread(target=smoke_test_init.create_and_publish_collection) + threads.append(thread) + + if visium_collection_count < 1: + thread = threading.Thread( + target=smoke_test_init.create_and_publish_collection, + args=(VISIUM_ACCT_CONTACT_NAME, "Visium Test Collection", VISIUM_DATASET_MANIFEST, 1), + ) threads.append(thread) + + if atac_seq_collection_count < 1: + thread = threading.Thread( + target=smoke_test_init.create_and_publish_collection, + args=(ATAC_SEQ_ACCT_CONTACT_NAME, "ATAC Seq Test Collection", ATAC_SEQ_MANIFEST, 1), + ) + threads.append(thread) + + for thread in threads: thread.start() for thread in threads: thread.join() diff --git a/tests/functional/backend/conftest.py b/tests/functional/backend/conftest.py index 347ebe59e07b3..be9ff9630e498 100644 --- a/tests/functional/backend/conftest.py +++ b/tests/functional/backend/conftest.py @@ -7,10 +7,12 @@ from tests.functional.backend.distributed import distributed_singleton from tests.functional.backend.utils import ( get_auth_token, + get_curation_api_access_token, make_cookie, make_proxy_auth_token, make_session, - upload_and_wait, + upload_manifest_and_wait, + upload_url_and_wait, ) @@ -67,22 +69,15 @@ def api_url(deployment_stage): @pytest.fixture(scope="session") def curation_api_access_token(session, api_url, config, tmp_path_factory, worker_id): def _curation_api_access_token() -> str: - response = session.post( - f"{api_url}/curation/v1/auth/token", - headers={"x-api-key": config.super_curator_api_key}, - ) - response.raise_for_status() - return response.json()["access_token"] + return get_curation_api_access_token(session, api_url, config) return distributed_singleton(tmp_path_factory, worker_id, _curation_api_access_token) @pytest.fixture(scope="session") def upload_dataset(session, api_url, curator_cookie, request): - def _upload_dataset(collection_id, dropbox_url, existing_dataset_id=None, skip_rds_status=False): - result = upload_and_wait( - session, api_url, curator_cookie, collection_id, dropbox_url, existing_dataset_id, skip_rds_status - ) + def _upload_dataset(collection_id, dropbox_url, existing_dataset_id=None): + result = upload_url_and_wait(session, api_url, curator_cookie, collection_id, dropbox_url, existing_dataset_id) dataset_id = result["dataset_id"] headers = {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} request.addfinalizer(lambda: session.delete(f"{api_url}/dp/v1/datasets/{dataset_id}", headers=headers)) @@ -93,6 +88,22 @@ def _upload_dataset(collection_id, dropbox_url, existing_dataset_id=None, skip_r return _upload_dataset +@pytest.fixture(scope="session") +def upload_manifest(session, api_url, curation_api_access_token, curator_cookie, request): + def _upload_manifest(collection_id: str, manifest: dict, existing_dataset_id=None): + result = upload_manifest_and_wait( + session, api_url, curation_api_access_token, curator_cookie, collection_id, manifest, existing_dataset_id + ) + dataset_id = result["dataset_id"] + headers = {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} + request.addfinalizer(lambda: session.delete(f"{api_url}/dp/v1/datasets/{dataset_id}", headers=headers)) + if result["errors"]: + raise pytest.fail(str(result["errors"])) + return dataset_id + + return _upload_manifest + + @pytest.fixture() def collection_data(request): return { diff --git a/tests/functional/backend/constants.py 
b/tests/functional/backend/constants.py index fd4ea6eb84fbb..2ca11dd1d3262 100644 --- a/tests/functional/backend/constants.py +++ b/tests/functional/backend/constants.py @@ -16,8 +16,18 @@ } DATASET_URI = ( - "https://www.dropbox.com/scl/fi/y50umqlcrbz21a6jgu99z/5_0_0_example_valid.h5ad?rlkey" - "=s7p6ybyx082hswix26hbl11pm&dl=0" + "https://www.dropbox.com/scl/fi/8yizrdfcfl02dtk3ke4sg/example_5_3_valid.h5ad?rlkey" + "=i1qc5qai9w2o9l1fithyatxdf&st=uxgudiwz&dl=0" ) -VISIUM_DATASET_URI = "https://www.dropbox.com/scl/fi/lmhue0va6ihk50ivp26da/visium_small.h5ad?rlkey=n0fo4dyi1ah7ckg9kgzwlhm8s&st=59m7g97u&dl=0" +VISIUM_DATASET_URI = ( + "https://www.dropbox.com/scl/fi/3y22olsc70of8rbb1es77/visium_small.h5ad?rlkey" + "=cgwd59ouk340zlqh6fcnthizz&st=u2nyo3xp&dl=0" +) + +DATASET_MANIFEST = {"anndata": DATASET_URI} +VISIUM_DATASET_MANIFEST = {"anndata": VISIUM_DATASET_URI} +ATAC_SEQ_MANIFEST = { + "anndata": "https://www.dropbox.com/scl/fi/rth5ol8dyn3qypmnr3w79/atac.h5ad?rlkey=lpor3wj4he2n4dkp6pq3v4c6t&st=dni608bw&dl=0", + "atac_fragment": "https://www.dropbox.com/scl/fi/p4kmriyki1xyvcc9bvwxc/fragments_sorted.tsv.gz?rlkey=hydxliidfy4yneaan2rrw2arp&dl=0", +} diff --git a/tests/functional/backend/corpora/test_api.py b/tests/functional/backend/corpora/test_api.py index 7513cd7cbc553..439b76c3ffaae 100644 --- a/tests/functional/backend/corpora/test_api.py +++ b/tests/functional/backend/corpora/test_api.py @@ -5,7 +5,7 @@ from requests import HTTPError from backend.common.constants import DATA_SUBMISSION_POLICY_VERSION -from tests.functional.backend.constants import DATASET_URI, VISIUM_DATASET_URI +from tests.functional.backend.constants import ATAC_SEQ_MANIFEST, DATASET_URI, VISIUM_DATASET_URI from tests.functional.backend.skip_reason import skip_creation_on_prod from tests.functional.backend.utils import assertStatusCode, create_test_collection @@ -164,15 +164,33 @@ def test_dataset_upload_flow_with_visium_dataset( ): headers = {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} collection_id = create_test_collection(headers, request, session, api_url, collection_data) + _verify_upload_and_delete_succeeded(collection_id, headers, VISIUM_DATASET_URI, session, api_url, upload_dataset) + + +@skip_creation_on_prod +def test_dataset_upload_flow_with_atac_seq_dataset( + session, curator_cookie, api_url, upload_manifest, request, collection_data, curation_api_access_token +): + headers = {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} + collection_id = create_test_collection( + headers, + request, + session, + api_url, + collection_data, + ) _verify_upload_and_delete_succeeded( - collection_id, headers, VISIUM_DATASET_URI, session, api_url, upload_dataset, skip_rds_status=True + collection_id, + headers, + ATAC_SEQ_MANIFEST, + session, + api_url, + upload_manifest, ) -def _verify_upload_and_delete_succeeded( - collection_id, headers, dataset_uri, session, api_url, upload_and_wait, skip_rds_status=False -): - dataset_id = upload_and_wait(collection_id, dataset_uri, skip_rds_status=skip_rds_status) +def _verify_upload_and_delete_succeeded(collection_id, headers, req_body, session, api_url, upload_and_wait): + dataset_id = upload_and_wait(collection_id, req_body) # test non owner cant retrieve status no_auth_headers = {"Content-Type": "application/json"} res = session.get(f"{api_url}/dp/v1/datasets/{dataset_id}/status", headers=no_auth_headers) diff --git a/tests/functional/backend/utils.py b/tests/functional/backend/utils.py index e1c08f660de3a..8dbf80a75c4a4 
100644 --- a/tests/functional/backend/utils.py +++ b/tests/functional/backend/utils.py @@ -56,11 +56,10 @@ def assertStatusCode(actual: int, expected_response: requests.Response): def create_test_collection(headers, request, session, api_url, body): res = session.post(f"{api_url}/dp/v1/collections", data=json.dumps(body), headers=headers) - res.raise_for_status() + assertStatusCode(requests.codes.created, res) data = json.loads(res.content) collection_id = data["collection_id"] request.addfinalizer(lambda: session.delete(f"{api_url}/dp/v1/collections/{collection_id}", headers=headers)) - assertStatusCode(requests.codes.created, res) return collection_id @@ -68,12 +67,10 @@ def create_explorer_url(dataset_id: str, deployment_stage: str) -> str: return f"https://cellxgene.{deployment_stage}.single-cell.czi.technology/e/{dataset_id}.cxg/" -def upload_and_wait( - session, api_url, curator_cookie, collection_id, dropbox_url, existing_dataset_id=None, skip_rds_status=False -): +def upload_url_and_wait(session, api_url, curator_cookie, collection_id, dropbox_url, existing_dataset_id=None): headers = {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} body = {"url": dropbox_url} - errors = [] + if existing_dataset_id is None: res = session.post( f"{api_url}/dp/v1/collections/{collection_id}/upload-links", data=json.dumps(body), headers=headers @@ -83,11 +80,44 @@ def upload_and_wait( res = session.put( f"{api_url}/dp/v1/collections/{collection_id}/upload-links", data=json.dumps(body), headers=headers ) - - res.raise_for_status() + assertStatusCode(requests.codes.accepted, res) dataset_id = json.loads(res.content)["dataset_id"] - assert res.status_code == requests.codes.accepted + return _wait_for_dataset_status(session, api_url, dataset_id, headers) + + +def upload_manifest_and_wait( + session, api_url, curation_api_access_token, curator_cookie, collection_id, manifest, existing_dataset_id=None +): + headers = {"Authorization": f"Bearer {curation_api_access_token}", "Content-Type": "application/json"} + + if not existing_dataset_id: + # Create dataset id + res = session.post(f"{api_url}/curation/v1/collections/{collection_id}/datasets", headers=headers) + assertStatusCode(201, res) + dataset_id = json.loads(res.content)["dataset_id"] + res = session.get(f"{api_url}/curation/v1/collections/{collection_id}", headers=headers) + assertStatusCode(200, res) + version_id = json.loads(res.content)["datasets"][0]["dataset_version_id"] + else: + dataset_id = existing_dataset_id + + # Upload manifest + res = session.put( + f"{api_url}/curation/v1/collections/{collection_id}/datasets/{dataset_id}/manifest", + data=json.dumps(manifest), + headers=headers, + ) + assertStatusCode(202, res) + + # Wait for dataset status + return _wait_for_dataset_status( + session, api_url, version_id, {"Cookie": f"cxguser={curator_cookie}", "Content-Type": "application/json"} + ) + + +def _wait_for_dataset_status(session, api_url, dataset_id, headers): + errors = [] keep_trying = True expected_upload_statuses = ["WAITING", "UPLOADING", "UPLOADED"] expected_conversion_statuses = ["CONVERTING", "CONVERTED", "FAILED", "UPLOADING", "UPLOADED", "NA", None] @@ -104,17 +134,21 @@ def upload_and_wait( cxg_status = data.get("cxg_status") rds_status = data.get("rds_status") h5ad_status = data.get("h5ad_status") + atac_status = data.get("atac_status") assert data.get("cxg_status") in expected_conversion_statuses if cxg_status == "FAILED": errors.append(f"CXG CONVERSION FAILED. 
Status: {data}, Check logs for dataset: {dataset_id}") if h5ad_status == "FAILED": errors.append(f"Anndata CONVERSION FAILED. Status: {data}, Check logs for dataset: {dataset_id}") + if atac_status == "FAILED": + errors.append(f"Atac CONVERSION FAILED. Status: {data}, Check logs for dataset: {dataset_id}") if rds_status == "FAILED": errors.append(f"RDS CONVERSION FAILED. Status: {data}, Check logs for dataset: {dataset_id}") if any( [ - cxg_status == rds_status == h5ad_status == "UPLOADED", - skip_rds_status and cxg_status == h5ad_status == "UPLOADED" and rds_status == "SKIPPED", + cxg_status == h5ad_status == "UPLOADED" + and rds_status == "SKIPPED" + and atac_status in ["SKIPPED", "UPLOADED", "NA", "COPIED"], errors, ] ): @@ -161,3 +195,12 @@ def make_proxy_auth_token(config, deployment_stage) -> dict: access_token = res.json()["access_token"] return {"Authorization": f"Bearer {access_token}"} return {} + + +def get_curation_api_access_token(session, api_url, config) -> str: + response = session.post( + f"{api_url}/curation/v1/auth/token", + headers={"x-api-key": config.super_curator_api_key}, + ) + response.raise_for_status() + return response.json()["access_token"] diff --git a/tests/memory/processing/test_process_cxg.py b/tests/memory/processing/test_process_cxg.py index 7f8d2d141b2cc..c8b93308a0852 100644 --- a/tests/memory/processing/test_process_cxg.py +++ b/tests/memory/processing/test_process_cxg.py @@ -10,7 +10,7 @@ from tests.unit.backend.fixtures.environment_setup import fixture_file_path if __name__ == "__main__": - file_name = "labeled_visium.h5ad" + file_name = "labeled_slide_seq.h5ad" dataset_version_id = DatasetVersionId("test_dataset_id") with tempfile.TemporaryDirectory() as tmpdirname: temp_file = "/".join([tmpdirname, file_name]) diff --git a/tests/memory/processing/test_process_seurat.py b/tests/memory/processing/test_process_seurat.py deleted file mode 100644 index 0535a38d06096..0000000000000 --- a/tests/memory/processing/test_process_seurat.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This script is used to test the ProcessCxg class. -""" - -import shutil -import tempfile - -from backend.layers.common.entities import DatasetVersionId -from backend.layers.processing.process_seurat import ProcessSeurat -from tests.unit.backend.fixtures.environment_setup import fixture_file_path - -if __name__ == "__main__": - file_name = "labeled_visium.h5ad" - dataset_version_id = DatasetVersionId("test_dataset_id") - with tempfile.TemporaryDirectory() as tmpdirname: - temp_file = "/".join([tmpdirname, file_name]) - shutil.copy(fixture_file_path(file_name), temp_file) - process = ProcessSeurat(None, None, None) - process.make_seurat(temp_file, dataset_version_id) diff --git a/tests/memory/readme.md b/tests/memory/readme.md index 2c3b4496b097b..838c195d6a54d 100644 --- a/tests/memory/readme.md +++ b/tests/memory/readme.md @@ -28,6 +28,8 @@ mprof plot docker compose --project-directory ../../ run --rm -w /single-cell-data-portal processing python -m memory_profiler ./tests/memory/processing/test_process_cxg.py ``` +See [line-by-line memory usage](https://github.com/pythonprofilers/memory_profiler?tab=readme-ov-file#line-by-line-memory-usage) in the memory-profiler documentation for measuring memory in nested functions. + ### Run memray and produce a flamegraph: Before running you will need to comment out @profile from some of the code. This is used by memory-profiler and is not supported by memray. 
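As a point of reference for the readme note above: line-by-line output from memory-profiler only covers functions that are themselves decorated, so nested helpers need their own `@profile`. A minimal sketch, assuming nothing about this repo's code — `build_matrix` and `run` and the sizes used are illustrative names only:

```python
# Minimal sketch of line-by-line profiling with memory_profiler.
# build_matrix/run are hypothetical names used only for illustration.
from memory_profiler import profile


@profile  # prints a per-line memory report when the call returns
def build_matrix(n_rows: int, n_cols: int) -> list:
    matrix = [[0.0] * n_cols for _ in range(n_rows)]  # allocation is attributed to this line
    return matrix


@profile  # nested helpers must carry their own decorator to be reported line by line
def run() -> None:
    build_matrix(1_000, 1_000)


if __name__ == "__main__":
    run()  # run directly, or via `python -m memory_profiler this_script.py`
```

Each decorated function gets its own table in the output, which is what the linked memory-profiler section describes for measuring memory in nested functions.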
diff --git a/tests/unit/backend/common/test_manifest.py b/tests/unit/backend/common/test_manifest.py new file mode 100644 index 0000000000000..5f358810e44ea --- /dev/null +++ b/tests/unit/backend/common/test_manifest.py @@ -0,0 +1,26 @@ +import pydantic +import pytest + +from backend.layers.common.ingestion_manifest import IngestionManifest + + +@pytest.mark.parametrize( + "manifest", + [ + """{"anndata": "https://example.com/dataset.h5ad"}""", + """{"anndata": "s3://bucket/dataset.h5ad"}""", + """{"anndata": "https://example.com/dataset.h5ad", "atac_fragments": "https://example.com/fragments.tsv.gz"}""", + """{"anndata": "https://example.com/dataset.h5ad", "atac_fragments": "s3://bucket/fragments.tsv.gz"}""", + ], +) +def test_anndata_validation_success(manifest): + IngestionManifest.model_validate_json(manifest) + + +@pytest.mark.parametrize( + "manifest", + ["""{"atac_fragments": "https://example.com/fragments.tsv.gz"}""", """{"anndata": 1234}"""], +) +def test_anndata_validation_failure(manifest): + with pytest.raises(pydantic.ValidationError): + IngestionManifest.model_validate_json(manifest) diff --git a/tests/unit/backend/layers/api/test_curation_api.py b/tests/unit/backend/layers/api/test_curation_api.py index 216be970ad137..9ae6809ec149e 100644 --- a/tests/unit/backend/layers/api/test_curation_api.py +++ b/tests/unit/backend/layers/api/test_curation_api.py @@ -16,6 +16,7 @@ CollectionLinkType, CollectionMetadata, CollectionVersion, + DatasetArtifactType, DatasetProcessingStatus, DatasetStatusKey, DatasetUploadStatus, @@ -665,10 +666,12 @@ def test__get_published_collection_verify_body_is_reshaped_correctly__OK(self): collection_version=collection_version, metadata=dataset_metadata, artifacts=[ - DatasetArtifactUpdate(type="h5ad", uri="http://test_filename/1234-5678-9/local.h5ad"), - DatasetArtifactUpdate(type="rds", uri="http://test_filename/1234-5678-9/local.rds"), - DatasetArtifactUpdate(type="cxg", uri="http://test_filename/1234-5678-9/local.cxg"), - DatasetArtifactUpdate(type="raw_h5ad", uri="http://test_filename/1234-5678-9/raw.h5ad"), + DatasetArtifactUpdate(type=DatasetArtifactType.H5AD, uri="http://test_filename/1234-5678-9/local.h5ad"), + DatasetArtifactUpdate(type=DatasetArtifactType.RDS, uri="http://test_filename/1234-5678-9/local.rds"), + DatasetArtifactUpdate(type=DatasetArtifactType.CXG, uri="http://test_filename/1234-5678-9/local.cxg"), + DatasetArtifactUpdate( + type=DatasetArtifactType.RAW_H5AD, uri="http://test_filename/1234-5678-9/raw.h5ad" + ), ], ) self.business_logic.publish_collection_version(collection_version.version_id) @@ -1508,7 +1511,10 @@ def _delete(self, auth, collection_id, dataset_id, query_param_str=None): """ Helper method to call the delete endpoint """ - test_url = f"/curation/v1/collections/{collection_id}/datasets/{dataset_id}{'?' + query_param_str if query_param_str else ''}" + test_url = ( + f"/curation/v1/collections/{collection_id}/datasets/{dataset_id}" + f"{'?' 
+ query_param_str if query_param_str else ''}" + ) headers = auth() if callable(auth) else auth return self.app.delete(test_url, headers=headers) @@ -1806,6 +1812,35 @@ def test_get_dataset_no_assets(self): body = response.json self.assertEqual([], body["assets"]) + def test_get_dataset_atac_assets(self): + dataset = self.generate_dataset( + artifacts=[ + DatasetArtifactUpdate(DatasetArtifactType.H5AD, "http://mock.uri/asset.h5ad"), + DatasetArtifactUpdate(DatasetArtifactType.ATAC_FRAGMENT, "http://mock.uri/atac_frags-fragment.tsv.bgz"), + DatasetArtifactUpdate( + DatasetArtifactType.ATAC_INDEX, "http://mock.uri/atac_frags-fragment.tsv.bgz.tbi" + ), + ] + ) + artifacts = self.business_logic.get_dataset_artifacts(DatasetVersionId(dataset.dataset_version_id)) + atac_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + + test_url = f"/curation/v1/collections/{dataset.collection_id}/datasets/{dataset.dataset_id}" + response = self.app.get(test_url) + body = response.json + + expected_assets = [ + {"filesize": -1, "filetype": "H5AD", "url": f"http://domain/{dataset.dataset_version_id}.h5ad"}, + {"filesize": -1, "filetype": "ATAC_FRAGMENT", "url": f"http://domain/{atac_artifact.id}-fragment.tsv.bgz"}, + { + "filesize": -1, + "filetype": "ATAC_INDEX", + "url": f"http://domain/{atac_artifact.id}-fragment.tsv.bgz.tbi", + }, + ] + + assert expected_assets == body["assets"] + def test_get_all_datasets_200(self): crossref_return_value_1 = (generate_mock_publisher_metadata(), "12.3456/j.celrep", 17169328.664) self.crossref_provider.fetch_metadata = Mock(return_value=crossref_return_value_1) @@ -1854,7 +1889,8 @@ def test_get_all_datasets_200(self): self.assertEqual(3, len(response.json)) with self.subTest( - "Contains collection_id, collection_version_id, collection_name, collection_doi, and collection_doi_label" + "Contains collection_id, collection_version_id, collection_name, collection_doi, " + "and collection_doi_label" ): collection_ids = {published_collection_1.collection_id.id, published_collection_2.collection_id.id} collection__version_ids = { @@ -2160,6 +2196,44 @@ def _validate_datasets(response_datasets, expected_dataset_ids: list[str]): self.assertIsNone(response_dataset["published_at"]) self.assertIsNone(response_dataset["revised_at"]) + def test_get_datasets_atac_seq(self): + collection = self.generate_unpublished_collection() + dataset = self.generate_dataset( + collection_version=collection, + artifacts=[ + DatasetArtifactUpdate(DatasetArtifactType.H5AD, "http://mock.uri/asset.h5ad"), + DatasetArtifactUpdate(DatasetArtifactType.ATAC_FRAGMENT, "http://mock.uri/atac_frags-fragment.tsv.bgz"), + DatasetArtifactUpdate( + DatasetArtifactType.ATAC_INDEX, "http://mock.uri/atac_frags-fragment.tsv.bgz.tbi" + ), + ], + ) + self.business_logic.publish_collection_version(collection.version_id) + artifacts = self.business_logic.get_dataset_artifacts(DatasetVersionId(dataset.dataset_version_id)) + atac_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + + response = self.app.get("/curation/v1/datasets") + body = response.json + expected_assets = [ + { + "filesize": -1, + "filetype": "H5AD", + "url": f"http://domain/{dataset.dataset_version_id}.h5ad", + }, + { + "filesize": -1, + "filetype": "ATAC_FRAGMENT", + "url": f"http://domain/{atac_artifact.id}-fragment.tsv.bgz", + }, + { + "filesize": -1, + "filetype": "ATAC_INDEX", + "url": f"http://domain/{atac_artifact.id}-fragment.tsv.bgz.tbi", + }, + ] + assert len(body) == 1, body + 
assert expected_assets == body[0]["assets"] + def test_get_private_datasets_400(self): # 400 if PRIVATE and schema version. self._fetch_datasets( @@ -2640,6 +2714,95 @@ def test_get_dataset_id_version_4xx(self): self.assertEqual(410, response.status_code) +class TestGetDatasetManifest(BaseAPIPortalTest): + def test_get_manifest_cases_ok(self): + cases = [ + { + "artifacts": { + "atac_fragment": DatasetArtifactUpdate( + DatasetArtifactType.ATAC_FRAGMENT, "http://mock.uri/atac_frags-fragment.tsv.bgz" + ) + }, + "name": "fragments_only", + }, + { + "artifacts": {"anndata": DatasetArtifactUpdate(DatasetArtifactType.H5AD, "http://mock.uri/asset.h5ad")}, + "name": "anndata_only", + }, + { + "artifacts": { + "anndata": DatasetArtifactUpdate(DatasetArtifactType.H5AD, "http://mock.uri/asset.h5ad"), + "atac_fragment": DatasetArtifactUpdate( + DatasetArtifactType.ATAC_FRAGMENT, "http://mock.uri/atac_frags-fragment.tsv.bgz" + ), + }, + "name": "anndata_and_fragments", + }, + {"artifacts": {}, "name": "no_artifacts"}, + ] + for case in cases: + with self.subTest(f"Get manifest case: {case['name']}"): + + collection = self.generate_unpublished_collection() + + dataset = self.generate_dataset( + collection_version=collection, + artifacts=list(case["artifacts"].values()), + ) + artifacts = self.business_logic.get_dataset_artifacts(DatasetVersionId(dataset.dataset_version_id)) + + assert len(artifacts) == len(case["artifacts"]) + + expected = {} + for artifact in artifacts: + if artifact.type == DatasetArtifactType.ATAC_FRAGMENT: + expected["atac_fragment"] = f"http://domain/{artifact.id}-fragment.{artifact.extension}" + elif artifact.type == DatasetArtifactType.H5AD: + expected["anndata"] = f"http://domain/{dataset.dataset_version_id}.{artifact.extension}" + + test_url = f"/curation/v1/collections/{dataset.collection_id}/datasets/{dataset.dataset_id}/manifest" + response = self.app.get(test_url) + self.assertEqual(200, response.status_code) + + assert expected == response.json + + def test__get_manifest_tombstoned__410(self): + published_collection = self.generate_published_collection() + dataset = published_collection.datasets[0] + self.business_logic.tombstone_collection(published_collection.collection_id) + with self.subTest("Returns 410 when a tombstoned canonical id is requested"): + resp = self.app.get( + f"/curation/v1/collections/{published_collection.collection_id}/datasets/{dataset.dataset_id}/manifest" + ) + self.assertEqual(410, resp.status_code) + + def test__get_manifest_by_dataset_version_id_fails(self): + collection = self.generate_unpublished_collection(add_datasets=1) + dataset = collection.datasets[0] + + test_url = f"/curation/v1/collections/{dataset.collection_id}/datasets/{dataset.version_id}/manifest" + response = self.app.get(test_url) + # TODO: I think this should be a 404 but this is also the behaviour of GET /collections/{ + # colleciton_id}/datasets/{dataset_version_id} + self.assertEqual(403, response.status_code) + + def test__get_manifest_by_missing_dataset_id_fails(self): + import uuid + + collection = self.generate_unpublished_collection(add_datasets=0) + + test_url = f"/curation/v1/collections/{collection.collection_id}/datasets/{uuid.uuid4()}/manifest" + response = self.app.get(test_url) + self.assertEqual(404, response.status_code) + + def test__get_manifest_by_missing_collection_id_fails(self): + import uuid + + test_url = f"/curation/v1/collections/{uuid.uuid4()}/datasets/{uuid.uuid4()}/manifest" + response = self.app.get(test_url) + self.assertEqual(404, 
response.status_code) + + class TestPostDataset(BaseAPIPortalTest): """ Unit test for POST /datasets, which is used to add an empty dataset to a collection version @@ -2769,12 +2932,7 @@ def test__post_revision__Super_Curator(self): return_value={"size": 1, "name": "file.h5ad"}, ) @patch("backend.layers.thirdparty.step_function_provider.StepFunctionProvider") -class TestPutLink(BaseAPIPortalTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.good_link = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" - cls.dummy_link = "https://www.dropbox.com/s/12345678901234/test.h5ad?dl=0" +class BasePutTest: def test__from_link__no_auth(self, *mocks): """ @@ -2783,11 +2941,11 @@ def test__from_link__no_auth(self, *mocks): dataset = self.generate_dataset( statuses=[DatasetStatusUpdate(DatasetStatusKey.PROCESSING, DatasetProcessingStatus.INITIALIZED)] ) - body = {"link": self.good_link} + body = self.good_request_body headers = None - for id in [dataset.dataset_version_id, dataset.dataset_id]: + for dataset_id in [dataset.dataset_version_id, dataset.dataset_id]: response = self.app.put( - f"/curation/v1/collections/{dataset.collection_id}/datasets/{id}", + self.endpoint.format(collection_version_id=dataset.collection_id, dataset_version_id=dataset_id), json=body, headers=headers, ) @@ -2803,11 +2961,11 @@ def test__from_link__Not_Public(self, *mocks): statuses=[DatasetStatusUpdate(DatasetStatusKey.PROCESSING, DatasetProcessingStatus.INITIALIZED)], publish=True, ) - body = {"link": self.good_link} + body = self.good_request_body headers = self.make_owner_header() - for id in [dataset.dataset_version_id, dataset.dataset_id]: + for dataset_id in [dataset.dataset_version_id, dataset.dataset_id]: response = self.app.put( - f"/curation/v1/collections/{dataset.collection_id}/datasets/{id}", + self.endpoint.format(collection_version_id=dataset.collection_id, dataset_version_id=dataset_id), json=body, headers=headers, ) @@ -2823,11 +2981,11 @@ def test__from_link__Not_Owner(self, *mocks): dataset = self.generate_dataset( statuses=[DatasetStatusUpdate(DatasetStatusKey.PROCESSING, DatasetProcessingStatus.INITIALIZED)], ) - body = {"link": self.dummy_link} + body = self.dummy_request_body headers = self.make_not_owner_header() - for id in [dataset.dataset_version_id, dataset.dataset_id]: + for dataset_id in [dataset.dataset_version_id, dataset.dataset_id]: response = self.app.put( - f"/curation/v1/collections/{dataset.collection_id}/datasets/{id}", + self.endpoint.format(collection_version_id=dataset.collection_id, dataset_version_id=dataset_id), json=body, headers=headers, ) @@ -2841,9 +2999,9 @@ def test__new_from_link__OK(self, *mocks): """ def _test_create(collection_id, dataset_id, headers): - body = {"link": self.good_link} + body = self.good_request_body response = self.app.put( - f"/curation/v1/collections/{collection_id}/datasets/{dataset_id}", + self.endpoint.format(collection_version_id=collection_id, dataset_version_id=dataset_id), json=body, headers=headers, ) @@ -2873,9 +3031,9 @@ def test__existing_from_link__OK(self, *mocks): """ def _test_create(collection_id, dataset_id, headers): - body = {"link": self.good_link} + body = self.good_request_body response = self.app.put( - f"/curation/v1/collections/{collection_id}/datasets/{dataset_id}", + self.endpoint.format(collection_version_id=collection_id, dataset_version_id=dataset_id), json=body, headers=headers, ) @@ -2905,10 +3063,12 @@ def test_from_link__403(self, *mocks): """ def _test_create(collection_version_id, 
dataset_version_id): - body = {"link": self.good_link} + body = self.good_request_body headers = self.make_owner_header() response = self.app.put( - f"/curation/v1/collections/{collection_version_id}/datasets/{dataset_version_id}", + self.endpoint.format( + collection_version_id=collection_version_id, dataset_version_id=dataset_version_id + ), json=body, headers=headers, ) @@ -2928,6 +3088,53 @@ def _test_create(collection_version_id, dataset_version_id): ) _test_create(dataset.collection_id, dataset.dataset_version_id) + def test_with_bad_already_ingested_anndata__400(self, *mocks): + """ + Calling Put /datasets/:dataset_id with a bad published anndata link should fail with 400 + """ + header = self.make_super_curator_header() + dataset = self.generate_dataset( + statuses=[DatasetStatusUpdate(DatasetStatusKey.PROCESSING, DatasetProcessingStatus.SUCCESS)], + ) + body = self.ingested_dataset_request_body + response = self.app.put( + self.endpoint.format(collection_version_id=dataset.collection_id, dataset_version_id=dataset.dataset_id), + json=body, + headers=header, + ) + self.assertEqual(400, response.status_code) + self.assertIn("detail", response.json.keys()) + + +class TestPutLink(BasePutTest, BaseAPIPortalTest): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.good_request_body = {"link": "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0"} + cls.dummy_request_body = {"link": "https://www.dropbox.com/s/12345678901234/test.h5ad?dl=0"} + cls.ingested_dataset_request_body = {"link": "http://domain/1234.txt"} + cls.endpoint = "/curation/v1/collections/{collection_version_id}/datasets/{dataset_version_id}" + + +class TestPutManifest(BasePutTest, BaseAPIPortalTest): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.good_request_body = {"anndata": "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0"} + cls.dummy_request_body = {"anndata": "https://www.dropbox.com/s/12345678901234/test.h5ad?dl=0"} + cls.ingested_dataset_request_body = {"anndata": "http://domain/1234.txt"} + cls.endpoint = "/curation/v1/collections/{collection_version_id}/datasets/{dataset_version_id}/manifest" + + +class TestPutManifestATAC(TestPutManifest): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.good_request_body = { + "anndata": "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0", + "atat_seq_fragment": "https://www.dropbox.com/scl/fo/kfk8ahs6e109i5puqbdhs/AIe45xJ361JqwH89fwanGwE?dl=0", + } + class TestAuthToken(BaseAPIPortalTest): @patch("backend.curation.api.v1.curation.auth.token.CorporaAuthConfig") diff --git a/tests/unit/backend/layers/api/test_portal_api.py b/tests/unit/backend/layers/api/test_portal_api.py index e3dd9969327d7..9e085f5f470a2 100644 --- a/tests/unit/backend/layers/api/test_portal_api.py +++ b/tests/unit/backend/layers/api/test_portal_api.py @@ -101,6 +101,7 @@ def test__get_collection_id__ok(self): "name": "test_dataset_name", "organism": [{"label": "test_organism_label", "ontology_term_id": "test_organism_term_id"}], "processing_status": { + "atac_status": "NA", "created_at": 0, "cxg_status": "NA", "dataset_id": mock.ANY, @@ -156,6 +157,7 @@ def test__get_collection_id__ok(self): "name": "test_dataset_name", "organism": [{"label": "test_organism_label", "ontology_term_id": "test_organism_term_id"}], "processing_status": { + "atac_status": "NA", "created_at": 0, "cxg_status": "NA", "dataset_id": mock.ANY, @@ -1804,6 +1806,7 @@ def test__get_status__ok(self): self.assertEqual(200, response.status_code) actual_body = 
json.loads(response.data) expected_body = { + "atac_status": "NA", "cxg_status": "NA", "rds_status": "NA", "h5ad_status": "NA", @@ -2001,6 +2004,9 @@ def test__get_dataset_assets(self): artifacts=[ DatasetArtifactUpdate(DatasetArtifactType.CXG, "s3://mock-bucket/mock-key.cxg"), DatasetArtifactUpdate(DatasetArtifactType.H5AD, "s3://mock-bucket/mock-key.h5ad"), + DatasetArtifactUpdate(DatasetArtifactType.RDS, "s3://mock-bucket/mock-key.rds"), + DatasetArtifactUpdate(DatasetArtifactType.ATAC_FRAGMENT, "s3://mock-bucket/mock-key.tsv.bgz"), + DatasetArtifactUpdate(DatasetArtifactType.ATAC_INDEX, "s3://mock-bucket/mock-key.tsv.bgz.tbi"), ] ) @@ -2011,9 +2017,12 @@ def test__get_dataset_assets(self): body = json.loads(response.data) self.assertIn("assets", body) assets = body["assets"] - self.assertEqual(len(assets), 2) + self.assertEqual(len(assets), 5) self.assertEqual(assets[0]["s3_uri"], "s3://mock-bucket/mock-key.cxg") self.assertEqual(assets[1]["s3_uri"], "s3://mock-bucket/mock-key.h5ad") + self.assertEqual(assets[2]["s3_uri"], "s3://mock-bucket/mock-key.rds") + self.assertEqual(assets[3]["s3_uri"], "s3://mock-bucket/mock-key.tsv.bgz") + self.assertEqual(assets[4]["s3_uri"], "s3://mock-bucket/mock-key.tsv.bgz.tbi") # ✅ def test__cancel_dataset_download__ok(self): diff --git a/tests/unit/backend/layers/business/test_business.py b/tests/unit/backend/layers/business/test_business.py index 9a2a3002a1520..588d7ca2d0792 100644 --- a/tests/unit/backend/layers/business/test_business.py +++ b/tests/unit/backend/layers/business/test_business.py @@ -31,12 +31,14 @@ DatasetInWrongStatusException, DatasetIsTombstonedException, DatasetNotFoundException, + InvalidIngestionManifestException, InvalidMetadataException, InvalidURIException, NoPreviousCollectionVersionException, NoPreviousDatasetVersionException, ) from backend.layers.common.entities import ( + ARTIFACT_TO_EXTENSION, CollectionId, CollectionMetadata, CollectionVersion, @@ -49,6 +51,7 @@ DatasetId, DatasetMetadata, DatasetProcessingStatus, + DatasetStatusKey, DatasetUploadStatus, DatasetValidationStatus, DatasetVersionId, @@ -57,6 +60,7 @@ SpatialMetadata, TissueOntologyTermId, ) +from backend.layers.common.ingestion_manifest import IngestionManifest from backend.layers.persistence.persistence import DatabaseProvider from backend.layers.persistence.persistence_mock import DatabaseProviderMock from backend.layers.thirdparty.batch_job_provider import BatchJobProviderInterface @@ -304,22 +308,30 @@ def complete_dataset_processing_with_success(self, dataset_version_id: DatasetVe Test method that "completes" a dataset processing. This is necessary since dataset ingestion is a complex process which happens asynchronously, and cannot be easily mocked. 
""" - for ext in ("h5ad", "rds"): - key = f"{dataset_version_id}.{ext}" - self.database_provider.add_dataset_artifact( - dataset_version_id, DatasetArtifactType.H5AD.value, f"s3://artifacts/{key}" - ) - self.s3_provider.upload_file(None, "artifacts", key, None) - # At present, not keeping public dataset assets as rows in DatasetArtifact table - self.s3_provider.upload_file(None, "datasets", key, None) - self.database_provider.add_dataset_artifact( - dataset_version_id, DatasetArtifactType.CXG.value, f"s3://cellxgene/{dataset_version_id}.cxg/" - ) - self.s3_provider.upload_file(None, "cellxgene", f"{dataset_version_id}.cxg/", None) + + def _add_artifact(bucket, key, key_type): + ext = ARTIFACT_TO_EXTENSION[key_type] + key_name = f"{key}.{ext}/" if key_type == DatasetArtifactType.CXG else f"{key}.{ext}" + self.database_provider.create_dataset_artifact(dataset_version_id, key_type, f"s3://{bucket}/{key_name}") + self.s3_provider.upload_file(None, bucket, key_name, None) + + _add_artifact("artifacts", f"{dataset_version_id}/raw", DatasetArtifactType.RAW_H5AD) + # At present, not keeping public dataset assets as rows in DatasetArtifact table + _add_artifact("datasets", f"{dataset_version_id}", DatasetArtifactType.H5AD) + _add_artifact("datasets", f"{dataset_version_id}", DatasetArtifactType.RDS) + _add_artifact("cellxgene", f"{dataset_version_id}", DatasetArtifactType.CXG) + + # special case for atac artifacts + artifact_id = DatasetArtifactId() + bucket = "datasets" + + key_name = f"{artifact_id}-fragment" + _add_artifact(bucket, key_name, DatasetArtifactType.ATAC_FRAGMENT) + _add_artifact(bucket, key_name, DatasetArtifactType.ATAC_INDEX) + self.database_provider.update_dataset_upload_status(dataset_version_id, DatasetUploadStatus.UPLOADED) self.database_provider.update_dataset_validation_status(dataset_version_id, DatasetValidationStatus.VALID) self.database_provider.update_dataset_processing_status(dataset_version_id, DatasetProcessingStatus.SUCCESS) - # TODO: if required, set the conversion status as well class TestCreateCollection(BaseBusinessLogicTestCase): @@ -856,6 +868,128 @@ def test_add_dataset_to_unpublished_collection_ok(self): self.step_function_provider.start_step_function.assert_called_once() + def test_reingest_published_anndata_dataset(self): + """A cellxgene public dataset url can be used to ingest a new dataset version.""" + + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version = revision.datasets[0] + url = f"https://dataset_assets_domain/{dataset_version.version_id}.h5ad" + + new_dataset_version_id, _ = self.business_logic.ingest_dataset( + revision.version_id, url, None, dataset_version.version_id + ) + new_dataset_version = self.database_provider.get_dataset_version(new_dataset_version_id) + self.assertIsNotNone(new_dataset_version) + self.assertIsNone(new_dataset_version.metadata) + self.assertEqual(new_dataset_version.collection_id, revision.collection_id) + self.assertEqual(new_dataset_version.status.upload_status, DatasetUploadStatus.WAITING) + self.assertEqual(new_dataset_version.status.processing_status, DatasetProcessingStatus.INITIALIZED) + self.step_function_provider.start_step_function.assert_called_once_with( + revision.version_id, + new_dataset_version_id, + f'{{"anndata":"s3://artifacts/{dataset_version.version_id}/raw.h5ad","atac_fragment":null}}', + ) + + def test_reingest_published_anndata_dataset__not_h5ad(self): + """A cellxgene public dataset url used for reingesting an h5ad must be an 
h5ad""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version = revision.datasets[0] + url = f"https://dataset_assets_domain/{dataset_version.version_id}.rds" + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, url, None, dataset_version.version_id) + + def test_reingest_published_anndata_dataset__not_in_found(self): + """A cellxgene public dataset url must already be uploaded""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version_id = DatasetVersionId() + url = f"https://dataset_assets_domain/{dataset_version_id.id}.h5ad" + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, url, None, dataset_version_id) + + def test_reginest_published_anndata_dataset__not_part_of_canonical_dataset(self): + """A cellxgene public dataset url must be part of a version of the canonical dataset""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=2) + dataset_version = revision.datasets[0] + other_dataset_version = revision.datasets[1] + url = f"https://dataset_assets_domain/{other_dataset_version.version_id}.h5ad" + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, url, None, dataset_version.version_id) + + def test_ingest_published_anndata_dataset_in_new_dataset__not_allowed(self): + """A cellxgene public dataset url cannot be used to create a new canonical dataset.""" + published_dataset = self.initialize_published_collection().datasets[0] + unpublished_collection = self.initialize_empty_unpublished_collection() + url = f"https://dataset_assets_domain/{published_dataset.version_id.id}.h5ad" + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(unpublished_collection.version_id, url, None, None) + + def test_reingest_published_atac_dataset(self): + """A cellxgene public dataset url can be used to ingest a new dataset version.""" + + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version = revision.datasets[0] + artifact_id = [a.id for a in dataset_version.artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + anndata_url = f"https://dataset_assets_domain/{dataset_version.version_id}.h5ad" + fragment_url = f"https://dataset_assets_domain/{artifact_id}-fragment.tsv.bgz" + manifest = {"anndata": anndata_url, "atac_fragment": fragment_url} + + new_dataset_version_id, _ = self.business_logic.ingest_dataset( + revision.version_id, manifest, None, dataset_version.version_id + ) + new_dataset_version = self.database_provider.get_dataset_version(new_dataset_version_id) + self.assertIsNotNone(new_dataset_version) + self.assertIsNone(new_dataset_version.metadata) + self.assertEqual(new_dataset_version.collection_id, revision.collection_id) + self.assertEqual(new_dataset_version.status.upload_status, DatasetUploadStatus.WAITING) + self.assertEqual(new_dataset_version.status.processing_status, DatasetProcessingStatus.INITIALIZED) + self.step_function_provider.start_step_function.assert_called_once_with( + revision.version_id, + new_dataset_version_id, + f'{{"anndata":"s3://artifacts/{dataset_version.version_id}/raw.h5ad","atac_fragment":"{fragment_url}"}}', + ) + + def test_reingest_published_atac_dataset__not_atac(self): + """A cellxgene public 
dataset url used for reingesting an h5ad must be an h5ad""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version = revision.datasets[0] + anndata_url = f"https://dataset_assets_domain/{dataset_version.version_id}.h5ad" + fragment_url = f"https://dataset_assets_domain/{dataset_version.version_id}.tsv" + manifest = {"anndata": anndata_url, "atac_fragment": fragment_url} + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, manifest, None, dataset_version.version_id) + + def test_reingest_published_atac_dataset__not_in_found(self): + """A cellxgene public dataset url must already be uploaded""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=1) + dataset_version = revision.datasets[0] + anndata_url = f"https://dataset_assets_domain/{dataset_version.version_id}.h5ad" + missing_dataset_version_id = DatasetVersionId() + fragment_url = f"https://dataset_assets_domain/{revision.datasets[0].version_id}-fragment.tsv.bgz" + manifest = {"anndata": anndata_url, "atac_fragment": fragment_url} + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, manifest, None, missing_dataset_version_id) + + def test_reginest_published_atac_dataset__not_part_of_canonical_dataset(self): + """A cellxgene public dataset url must be part of a version of the canonical dataset""" + collection, revision = self.initialize_collection_with_an_unpublished_revision(num_datasets=2) + dataset_version = revision.datasets[0] + anndata_url = f"https://dataset_assets_domain/{dataset_version.version_id}.h5ad" + other_dataset_version = revision.datasets[1] + fragment_url = f"https://dataset_assets_domain/{other_dataset_version.version_id}-fragment.tsv.bgz" + manifest = {"anndata": anndata_url, "atac_fragment": fragment_url} + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(revision.version_id, manifest, None, dataset_version.version_id) + + def test_ingest_published_atac_dataset_in_new_dataset__not_allowed(self): + """A cellxgene public dataset url cannot be used to create a new canonical dataset.""" + published_dataset = self.initialize_published_collection().datasets[0] + unpublished_dataset = self.initialize_unpublished_collection().datasets[0] + anndata_url = f"https://dataset_assets_domain/{unpublished_dataset.version_id}.h5ad" + fragment_url = f"https://dataset_assets_domain/{published_dataset.version_id}-fragment.tsv.bgz" + manifest = {"anndata": anndata_url, "atac_fragment": fragment_url} + with self.assertRaises(InvalidIngestionManifestException): + self.business_logic.ingest_dataset(unpublished_dataset.version_id, manifest, None, None) + def test_add_dataset_to_non_existing_collection_fail(self): """ Calling `ingest_dataset` on a collection that does not exist should fail @@ -893,7 +1027,7 @@ def test_add_dataset_with_invalid_link_fail(self): with self.assertRaises(DatasetIngestException) as ex: self.business_logic.ingest_dataset(version.version_id, url, None, None) - self.assertEqual(str(ex.exception), "Trying to upload invalid URI: http://bad.url") + self.assertEqual(str(ex.exception), "Trying to upload invalid URI: http://bad.url/") self.step_function_provider.start_step_function.assert_not_called() @@ -1562,13 +1696,7 @@ def test_get_dataset_artifacts_ok(self): dataset_version_id = published_version.datasets[0].version_id artifacts = 
list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(3, len(artifacts)) - expected = [ - f"s3://artifacts/{dataset_version_id}.h5ad", - f"s3://artifacts/{dataset_version_id}.rds", - f"s3://cellxgene/{dataset_version_id}.cxg/", - ] - self.assertEqual(set(expected), {a.uri for a in artifacts}) + self.assertEqual(len(published_version.datasets[0].artifacts), len(artifacts)) def test_get_dataset_artifact_download_data_ok(self): """ @@ -1601,11 +1729,26 @@ def test_get_dataset_status_for_uploaded_dataset_ok(self): self.assertEqual(status.upload_status, DatasetUploadStatus.UPLOADED) self.assertEqual(status.validation_status, DatasetValidationStatus.VALID) + def test_get_ingest_manifest(self): + published_version = self.initialize_published_collection() + dataset = published_version.datasets[0] + expected_manifest = IngestionManifest( + anndata=[artifact.uri for artifact in dataset.artifacts if artifact.type == DatasetArtifactType.RAW_H5AD][ + 0 + ], + atac_fragment=[ + artifact.uri for artifact in dataset.artifacts if artifact.type == DatasetArtifactType.ATAC_FRAGMENT + ][0], + ) + manifest = self.business_logic.get_ingestion_manifest(dataset.version_id) + self.assertEqual(expected_manifest, manifest) + class TestGetAllDatasets(BaseBusinessLogicTestCase): def test_get_all_private_datasets_ok(self): """ - Private datasets the user is authorized to view can be retrieved with `get_all_private_collection_versions_with_datasets`. + Private datasets the user is authorized to view can be retrieved with + `get_all_private_collection_versions_with_datasets`. """ # test_user_1: # - private collection (2 datasets) @@ -1660,7 +1803,8 @@ def _validate(actual: List[CollectionVersionWithDatasets], expected: List[Collec [d.version_id for d in datasets], ) - # Create the expected shape of revision_1_updated: datasets should only include the replacement dataset as well as the new dataset. + # Create the expected shape of revision_1_updated: datasets should only include the replacement dataset as + # well as the new dataset. revision_1_updated_expected = deepcopy(revision_1_updated) revision_1_updated_expected.datasets = [ self.database_provider.get_dataset_version(updated_dataset_version_id), @@ -1708,6 +1852,20 @@ def test_update_dataset_status_validation_message_ok(self): self.assertEqual(version_from_db.status.validation_status, DatasetValidationStatus.INVALID) self.assertEqual(version_from_db.status.validation_message, "Validation error!") + def test_update_dataset_status_validate_message_with_appending(self): + """New messages are appended to the existing ones""" + unpublished_collection = self.initialize_unpublished_collection(complete_dataset_ingestion=False) + dataset = unpublished_collection.datasets[0] + error_message = "Validation error!" 
+ + for _ in range(2): + self.business_logic.update_dataset_version_status( + dataset.version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.INVALID, error_message + ) + version_from_db = self.database_provider.get_dataset_version(dataset.version_id) + validation_message = version_from_db.status.validation_message.split("\n") + self.assertEqual([error_message] * 2, validation_message) + def test_add_dataset_artifact_ok(self): """ A dataset artifact can be added using `add_dataset_artifact` @@ -1716,7 +1874,9 @@ def test_add_dataset_artifact_ok(self): self.assertEqual(2, len(unpublished_collection.datasets)) for dataset in unpublished_collection.datasets: self.assertEqual(dataset.artifacts, []) - self.business_logic.add_dataset_artifact(dataset.version_id, "h5ad", "http://fake.uri/artifact.h5ad") + self.business_logic.add_dataset_artifact( + dataset.version_id, DatasetArtifactType.H5AD, "http://fake.uri/artifact.h5ad" + ) version_from_db = self.database_provider.get_dataset_version(dataset.version_id) self.assertEqual(1, len(version_from_db.artifacts)) @@ -2878,19 +3038,20 @@ def test__delete_datasets_from_public_access_bucket(self): self.complete_dataset_processing_with_success(replaced_dataset_version_id) - dataset_version_ids = [d_v.version_id.id for d_v in published_collection.datasets] + [ - replaced_dataset_version_id + dataset_versions = published_collection.datasets + [ + self.business_logic.get_dataset_version(replaced_dataset_version_id) ] expected_delete_keys = set() fake_public_bucket = "datasets" - for d_v_id in dataset_version_ids: + for d_v in dataset_versions: for file_type in ("h5ad", "rds"): - key = f"{d_v_id}.{file_type}" + key = f"{d_v.version_id}.{file_type}" self.s3_provider.upload_file(None, fake_public_bucket, key, None) # Populate s3 mock with assets self.assertTrue(self.s3_provider.uri_exists(f"s3://{fake_public_bucket}/{key}")) - expected_delete_keys.add(f"{d_v_id}.{file_type}") + expected_delete_keys.add(f"{d_v.version_id}.{file_type}") + expected_delete_keys.update(self.business_logic.get_atac_fragment_uris_from_dataset_version(d_v)) self.assertTrue(len(expected_delete_keys) > 0) - [self.assertTrue(self.s3_provider.file_exists(fake_public_bucket, key)) for key in expected_delete_keys] + self.assertTrue(all(self.s3_provider.file_exists(fake_public_bucket, key) for key in expected_delete_keys)) actual_delete_keys = set( self.business_logic.delete_all_dataset_versions_from_public_bucket_for_collection( published_collection.collection_id @@ -3073,9 +3234,9 @@ def test_concurrency(self): dataset = collection.datasets[0] def add_artifact(): - self.database_provider.add_dataset_artifact(dataset.version_id, DatasetArtifactType.H5AD, "fake_uri") + self.database_provider.create_dataset_artifact(dataset.version_id, DatasetArtifactType.H5AD, "fake_uri") - self.assertEqual(len(dataset.artifacts), 3) + self.assertEqual(len(dataset.artifacts), 6) from concurrent.futures import ThreadPoolExecutor @@ -3086,7 +3247,7 @@ def add_artifact(): dv = self.business_logic.get_dataset_version(dataset.version_id) self.assertIsNotNone(dv) if dv is not None: - self.assertEqual(len(dv.artifacts), 13) + self.assertEqual(len(dv.artifacts), 16) class TestDatasetArtifactMetadataUpdates(BaseBusinessLogicTestCase): @@ -3141,7 +3302,7 @@ def test_trigger_dataset_artifact_update__with_new_dataset_version_id(self): current_dataset_version_id = revision.datasets[0].version_id new_dataset_version_id, _ = self.business_logic.ingest_dataset( revision.version_id, - None, + "http://fake.url", 
None, current_dataset_version_id=current_dataset_version_id, start_step_function=False, diff --git a/tests/unit/backend/layers/common/base_test.py b/tests/unit/backend/layers/common/base_test.py index 026dd9467f21d..c6c36604f4977 100644 --- a/tests/unit/backend/layers/common/base_test.py +++ b/tests/unit/backend/layers/common/base_test.py @@ -10,15 +10,18 @@ from backend.common.providers.crossref_provider import CrossrefProviderInterface from backend.layers.business.business import BusinessLogic from backend.layers.common.entities import ( + ARTIFACT_TO_EXTENSION, CollectionId, CollectionMetadata, CollectionVersion, CollectionVersionWithDatasets, + DatasetArtifact, DatasetArtifactType, DatasetMetadata, DatasetStatusGeneric, DatasetStatusKey, DatasetValidationStatus, + DatasetVersion, DatasetVersionId, Link, OntologyTermId, @@ -43,6 +46,10 @@ class DatasetArtifactUpdate: type: str uri: str + @property + def extension(self): + return ARTIFACT_TO_EXTENSION[self.type] + @dataclass class DatasetData: @@ -327,6 +334,11 @@ def generate_collection_revision(self, owner="test_user_id") -> CollectionVersio published_collection = self.generate_published_collection(owner) return self.business_logic.create_collection_version(published_collection.collection_id) + def get_artifact_type_from_dataset( + self, dataset_version: DatasetVersion, artifact_type: DatasetArtifactType + ) -> DatasetArtifact: + return next(artifact for artifact in dataset_version.artifacts if artifact.type == artifact_type) + def link_class_to_api_link_dict(link: Link) -> dict: return {"link_name": link.name, "link_type": link.type, "link_url": link.uri} diff --git a/tests/unit/backend/layers/utils/test_aws.py b/tests/unit/backend/layers/utils/test_aws.py index 7d5db8d295b59..d2dda6d418c0b 100644 --- a/tests/unit/backend/layers/utils/test_aws.py +++ b/tests/unit/backend/layers/utils/test_aws.py @@ -15,12 +15,10 @@ def setUp(self) -> None: super().setUp() self.tmp_dir = tempfile.mkdtemp() self.h5ad_filename = pathlib.Path(self.tmp_dir, "test.h5ad") - self.seurat_filename = pathlib.Path(self.tmp_dir, "test.rds") self.cxg_filename = pathlib.Path(self.tmp_dir, "test.cxg") self.h5ad_filename.touch() self.cxg_filename.touch() - self.seurat_filename.touch() # Mock S3 service if we don't have a mock api already running if os.getenv("BOTO_ENDPOINT_URL"): diff --git a/tests/unit/backend/layers/utils/test_uri.py b/tests/unit/backend/layers/utils/test_uri.py index d97ce4098a02f..f60995597bbc4 100644 --- a/tests/unit/backend/layers/utils/test_uri.py +++ b/tests/unit/backend/layers/utils/test_uri.py @@ -1,12 +1,12 @@ import os import unittest from tempfile import TemporaryDirectory -from unittest import TestCase +from unittest import TestCase, mock import boto3 from moto import mock_aws -from backend.common.utils.dl_sources.uri import S3URI, S3URL, DropBoxURL, RegisteredSources, from_uri +from backend.common.utils.dl_sources.uri import S3URI, S3URL, CXGPublicURL, DropBoxURL, RegisteredSources, from_uri class TestRegisteredSources(unittest.TestCase): @@ -92,3 +92,68 @@ def test_download(self): data = f.read() # assert the contents are correct self.assertEqual(data, content) + + +class TestS3URL(TestCase): + + def create_s3_url(self, bucket_name, key, content): + s3 = boto3.client("s3") + s3.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": os.environ["AWS_DEFAULT_REGION"]} + ) + s3.put_object(Bucket=bucket_name, Key=key, Body=content) + + return s3.generate_presigned_url("get_object", Params={"Bucket": 
bucket_name, "Key": key}, ExpiresIn=3600) + + @mock_aws + def test__validate_with_valid_s3_url__ok(self): + url = self.create_s3_url("bucket", "key/file.txt", content="stuff") + s3_url = S3URL.validate(url) + + self.assertEqual("https", s3_url.scheme) + self.assertEqual("bucket.s3.amazonaws.com", s3_url.netloc) + self.assertEqual("/key/file.txt", s3_url.path) + + @mock_aws + def test__validate_with_invalid_s3_url__returns_none(self): + s3_url = S3URL.validate("http://somebucket.s3.amazonaws.com/key") + self.assertIsNone(s3_url) + + s3_url = S3URL.validate("https://somebucket/key") + self.assertIsNone(s3_url) + + @mock_aws + def test_get_file_info(self): + url = self.create_s3_url("bucket", "key/file.txt", content="stuff") + s3_url = S3URL.validate(url) + info = s3_url.file_info() + self.assertEqual("/key/file.txt", info["name"]) + self.assertEqual(5, info["size"]) + + +class TestCXGPubURL(TestCase): + + @mock.patch("backend.common.utils.dl_sources.uri.CorporaConfig") + def test__validate_with_valid_url__ok(self, config_mock): + config = mock.Mock() + config.dataset_assets_base_url = "https://datasets.test.technology" + config_mock.return_value = config + + url = "https://datasets.test.technology/key/file.txt" + url = CXGPublicURL.validate(url) + + self.assertEqual("https", url.scheme) + self.assertEqual("datasets.test.technology", url.netloc) + self.assertEqual("/key/file.txt", url.path) + + @mock.patch("backend.common.utils.dl_sources.uri.CorporaConfig") + def test__validate_with_invalid_url__returns_none(self, config_mock): + config = mock.Mock() + config.dataset_assets_base_url = "https://datasets.test.technology" + config_mock.return_value = config + + url = CXGPublicURL.validate("http://somebucket.s3.amazonaws.com/key") + self.assertIsNone(url) + + url = CXGPublicURL.validate("https://somebucket/key") + self.assertIsNone(url) diff --git a/tests/unit/processing/base_processing_test.py b/tests/unit/processing/base_processing_test.py index 54081d92513bd..adb4dd3d5a312 100644 --- a/tests/unit/processing/base_processing_test.py +++ b/tests/unit/processing/base_processing_test.py @@ -13,4 +13,5 @@ def setUp(self): self.uri_provider.get_file_info = Mock(return_value=FileInfo(1, "local.h5ad")) self.s3_provider = MockS3Provider() self.schema_validator = Mock(spec=SchemaValidatorProviderInterface) - self.schema_validator.validate_and_save_labels = Mock(return_value=(True, [], True)) + self.schema_validator.validate_anndata = Mock(return_value=(True, [], True)) + self.schema_validator.validate_atac = Mock(return_value=([], "fragment.tsv.bgz.tbi", "fragment.tsv.bgz")) diff --git a/tests/unit/processing/schema_migration/test_collection_migrate.py b/tests/unit/processing/schema_migration/test_collection_migrate.py index c0106121637d3..2898f1a352a98 100644 --- a/tests/unit/processing/schema_migration/test_collection_migrate.py +++ b/tests/unit/processing/schema_migration/test_collection_migrate.py @@ -25,10 +25,10 @@ def test_migrate_published_collection(self, schema_migrate_and_collections): published.version_id.id, ) schema_migrate._store_sfn_response.assert_any_call( - "log_errors_and_cleanup", published.collection_id.id, response_for_log_errors_and_cleanup + "log_errors_and_cleanup", collection_version_id.id, response_for_log_errors_and_cleanup ) schema_migrate._store_sfn_response.assert_any_call( - "span_datasets", published.collection_id.id, response_for_span_datasets + "span_datasets", collection_version_id.id, response_for_span_datasets ) assert 
response_for_log_errors_and_cleanup["collection_version_id"] == collection_version_id.id assert ( @@ -60,10 +60,10 @@ def test_migrate_private_collection(self, schema_migrate_and_collections): private.version_id.id, ) schema_migrate._store_sfn_response.assert_any_call( - "log_errors_and_cleanup", private.collection_id.id, response_for_log_errors_and_cleanup + "log_errors_and_cleanup", private.version_id.id, response_for_log_errors_and_cleanup ) schema_migrate._store_sfn_response.assert_any_call( - "span_datasets", private.collection_id.id, response_for_span_datasets + "span_datasets", private.version_id.id, response_for_span_datasets ) # verify response_for_log_errors_and_cleanup @@ -94,7 +94,7 @@ def test_filter_schema_version(self, schema_migrate_and_collections): published.version_id.id, ) schema_migrate._store_sfn_response.assert_called_once_with( - "log_errors_and_cleanup", published.collection_id.id, response_for_log_errors_and_cleanup + "log_errors_and_cleanup", published.version_id.id, response_for_log_errors_and_cleanup ) # verify response_for_log_errors_and_cleanup @@ -124,7 +124,7 @@ def test_no_datasets(self, schema_migrate_and_collections): published.version_id.id, ) schema_migrate._store_sfn_response.assert_called_once_with( - "log_errors_and_cleanup", published.collection_id.id, response_for_log_errors_and_cleanup + "log_errors_and_cleanup", published.version_id.id, response_for_log_errors_and_cleanup ) # verify response_for_log_errors_and_cleanup diff --git a/tests/unit/processing/schema_migration/test_dataset_migrate.py b/tests/unit/processing/schema_migration/test_dataset_migrate.py index 356d8e9a9dab7..f0dc08c0242f9 100644 --- a/tests/unit/processing/schema_migration/test_dataset_migrate.py +++ b/tests/unit/processing/schema_migration/test_dataset_migrate.py @@ -1,4 +1,5 @@ -from backend.layers.common.entities import DatasetArtifact, DatasetVersionId +from backend.layers.common.entities import DatasetVersionId +from backend.layers.common.ingestion_manifest import IngestionManifest class TestDatasetMigrate: @@ -6,9 +7,10 @@ def test_dataset_migrate_private(self, schema_migrate_and_collections): schema_migrate, collections = schema_migrate_and_collections private = collections["private"][0] schema_migrate.business_logic.s3_provider.parse_s3_uri.return_value = ("fake-bucket", "object_key.h5ad") - schema_migrate.business_logic.get_dataset_artifacts.return_value = [ - DatasetArtifact(id=None, type="raw_h5ad", uri="s3://fake-bucket/object_key.h5ad") - ] + schema_migrate.business_logic.get_ingestion_manifest.return_value = IngestionManifest( + anndata="s3://fake-bucket/object_key.h5ad" + ) + new_dataset_version_id = DatasetVersionId() schema_migrate.business_logic.ingest_dataset.return_value = ( new_dataset_version_id, @@ -21,20 +23,26 @@ def test_dataset_migrate_private(self, schema_migrate_and_collections): assert response["collection_version_id"] == private.version_id.id assert response["dataset_version_id"] == new_dataset_version_id.id assert dataset_version_id != new_dataset_version_id.id - assert response["uri"] == f"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad" + assert ( + response["manifest"] + == f'{{"anndata":"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad","atac_fragment":null}}' + ) assert response["sfn_name"].startswith("migrate_") assert new_dataset_version_id.id in response["sfn_name"] schema_migrate.schema_validator.migrate.assert_called_once_with( - "previous_schema.h5ad", "migrated.h5ad", private.collection_id.id, 
private.datasets[0].dataset_id.id + schema_migrate.get_file_path("previous_schema.h5ad"), + schema_migrate.get_file_path("migrated.h5ad"), + private.collection_id.id, + private.datasets[0].dataset_id.id, ) def test_dataset_migrate_published(self, schema_migrate_and_collections): schema_migrate, collections = schema_migrate_and_collections published = collections["published"][0] schema_migrate.business_logic.s3_provider.parse_s3_uri.return_value = ("fake-bucket", "object_key.h5ad") - schema_migrate.business_logic.get_dataset_artifacts.return_value = [ - DatasetArtifact(id=None, type="raw_h5ad", uri="s3://fake-bucket/object_key.h5ad") - ] + schema_migrate.business_logic.get_ingestion_manifest.return_value = IngestionManifest( + anndata="s3://fake-bucket/object_key.h5ad" + ) new_dataset_version_id = DatasetVersionId() schema_migrate.business_logic.ingest_dataset.return_value = ( new_dataset_version_id, @@ -47,20 +55,26 @@ def test_dataset_migrate_published(self, schema_migrate_and_collections): assert response["collection_version_id"] == published.version_id.id assert response["dataset_version_id"] == new_dataset_version_id.id assert dataset_version_id != new_dataset_version_id.id - assert response["uri"] == f"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad" + assert ( + response["manifest"] + == f'{{"anndata":"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad","atac_fragment":null}}' + ) assert response["sfn_name"].startswith("migrate_") assert new_dataset_version_id.id in response["sfn_name"] schema_migrate.schema_validator.migrate.assert_called_once_with( - "previous_schema.h5ad", "migrated.h5ad", published.collection_id.id, published.datasets[0].dataset_id.id + schema_migrate.get_file_path("previous_schema.h5ad"), + schema_migrate.get_file_path("migrated.h5ad"), + published.collection_id.id, + published.datasets[0].dataset_id.id, ) def test_dataset_migrate_revision(self, schema_migrate_and_collections): schema_migrate, collections = schema_migrate_and_collections revision = collections["revision"][0] schema_migrate.business_logic.s3_provider.parse_s3_uri.return_value = ("fake-bucket", "object_key.h5ad") - schema_migrate.business_logic.get_dataset_artifacts.return_value = [ - DatasetArtifact(id=None, type="raw_h5ad", uri="s3://fake-bucket/object_key.h5ad") - ] + schema_migrate.business_logic.get_ingestion_manifest.return_value = IngestionManifest( + anndata="s3://fake-bucket/object_key.h5ad" + ) new_dataset_version_id = DatasetVersionId() schema_migrate.business_logic.ingest_dataset.return_value = ( new_dataset_version_id, @@ -73,9 +87,15 @@ def test_dataset_migrate_revision(self, schema_migrate_and_collections): assert response["collection_version_id"] == revision.version_id.id assert response["dataset_version_id"] == new_dataset_version_id.id assert dataset_version_id != new_dataset_version_id.id - assert response["uri"] == f"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad" + assert ( + response["manifest"] + == f'{{"anndata":"s3://artifact-bucket/{dataset_version_id}/migrated.h5ad","atac_fragment":null}}' + ) assert response["sfn_name"].startswith("migrate_") assert new_dataset_version_id.id in response["sfn_name"] schema_migrate.schema_validator.migrate.assert_called_once_with( - "previous_schema.h5ad", "migrated.h5ad", revision.collection_id.id, revision.datasets[0].dataset_id.id + schema_migrate.get_file_path("previous_schema.h5ad"), + schema_migrate.get_file_path("migrated.h5ad"), + revision.collection_id.id, + revision.datasets[0].dataset_id.id, ) diff --git 
a/tests/unit/processing/schema_migration/test_log_errors_and_cleanup.py b/tests/unit/processing/schema_migration/test_log_errors_and_cleanup.py index 07d66d6911c9e..c56de9d12a701 100644 --- a/tests/unit/processing/schema_migration/test_log_errors_and_cleanup.py +++ b/tests/unit/processing/schema_migration/test_log_errors_and_cleanup.py @@ -57,7 +57,7 @@ def test_OK(self, mock_json, schema_migrate): errors = schema_migrate.log_errors_and_cleanup(collection_version.version_id.id) assert errors == [] schema_migrate.s3_provider.delete_files.assert_any_call( - "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_id.json"] + "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_version_id.json"] ) schema_migrate.s3_provider.delete_files.assert_any_call( "artifact-bucket", @@ -110,7 +110,7 @@ def test_with_errors(self, mock_json, schema_migrate): assert schema_migrate.business_logic.restore_previous_dataset_version.call_count == 1 assert schema_migrate.business_logic.delete_dataset_versions.call_count == 1 schema_migrate.s3_provider.delete_files.assert_any_call( - "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_id.json"] + "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_version_id.json"] ) schema_migrate.s3_provider.delete_files.assert_any_call( "artifact-bucket", @@ -133,6 +133,6 @@ def test_skip_unprocessed_datasets(self, mock_json, schema_migrate): assert errors == [] schema_migrate.check_dataset_is_latest_schema_version.assert_not_called() schema_migrate.s3_provider.delete_files.assert_any_call( - "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_id.json"] + "artifact-bucket", ["schema_migration/test-execution-arn/log_errors_and_cleanup/collection_version_id.json"] ) schema_migrate.s3_provider.delete_files.assert_any_call("artifact-bucket", []) diff --git a/tests/unit/processing/test_cxg_generation_utils.py b/tests/unit/processing/test_cxg_generation_utils.py index d3da650471a3f..029efeb9e91b4 100644 --- a/tests/unit/processing/test_cxg_generation_utils.py +++ b/tests/unit/processing/test_cxg_generation_utils.py @@ -6,7 +6,9 @@ import numpy as np import tiledb +from dask.array import from_array from pandas import Categorical, DataFrame, Series +from scipy import sparse from backend.layers.processing.utils.cxg_generation_utils import ( convert_dataframe_to_cxg_array, @@ -84,7 +86,7 @@ def test__convert_ndarray_to_cxg_dense_array__writes_successfully(self): ndarray = np.random.rand(3, 2) ndarray_name = f"{self.testing_cxg_temp_directory}/awesome_ndarray_{uuid4()}" - convert_ndarray_to_cxg_dense_array(ndarray_name, ndarray, tiledb.Ctx()) + convert_ndarray_to_cxg_dense_array(ndarray_name, ndarray, {}) actual_stored_array = tiledb.open(ndarray_name) @@ -93,43 +95,42 @@ def test__convert_ndarray_to_cxg_dense_array__writes_successfully(self): self.assertTrue((actual_stored_array[:, :] == ndarray).all()) def test__convert_matrices_to_cxg_arrays__dense_array_writes_successfully(self): - matrix = np.float32(np.random.rand(3, 2)) + matrix = from_array(np.float32(np.random.rand(3, 2))) matrix_name = f"{self.testing_cxg_temp_directory}/awesome_matrix_{uuid4()}" - convert_matrices_to_cxg_arrays(matrix_name, matrix, False, tiledb.Ctx()) + convert_matrices_to_cxg_arrays(matrix_name, matrix, False, {}) actual_stored_array = tiledb.open(matrix_name) self.assertTrue(path.isdir(matrix_name)) 
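A minimal sketch (an assumption, not part of this changeset) of how the single TileDB array written by convert_matrices_to_cxg_arrays can be read back now that the sparse case no longer produces the old "Xr"/"Xc" pair; read_cxg_matrix is a hypothetical helper name, not an API from this codebase.

import tiledb

def read_cxg_matrix(uri: str):
    """Return the stored matrix contents, handling both the sparse and dense layouts."""
    with tiledb.open(uri) as arr:
        if isinstance(arr, tiledb.SparseArray):
            data = arr[:]  # dict holding "obs"/"var" coordinates plus the anonymous "" value attribute
            return data["obs"], data["var"], data[""]
        return arr[:, :]  # dense layout: a plain array slice, as the assertions around this test expect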
self.assertTrue(isinstance(actual_stored_array, tiledb.DenseArray)) self.assertTrue((actual_stored_array[:, :] == matrix).all()) def test__convert_matrices_to_cxg_arrays__sparse_array_only_store_nonzeros_empty_array(self): - matrix = np.zeros([3, 2]) + matrix = from_array(sparse.csr_matrix((np.zeros([3, 2])), dtype=np.float32)) matrix_name = f"{self.testing_cxg_temp_directory}/awesome_zero_matrix_{uuid4()}" - convert_matrices_to_cxg_arrays(matrix_name, matrix, True, tiledb.Ctx()) + convert_matrices_to_cxg_arrays(matrix_name, matrix, True, {}) - for suffix in ["r", "c"]: - actual_stored_array = tiledb.open(matrix_name + suffix) - self.assertTrue(path.isdir(matrix_name + suffix)) - self.assertTrue(isinstance(actual_stored_array, tiledb.SparseArray)) - self.assertTrue(actual_stored_array[:][""].size == 0) + actual_stored_array = tiledb.open(matrix_name) + self.assertTrue(path.isdir(matrix_name)) + self.assertTrue(isinstance(actual_stored_array, tiledb.SparseArray)) + self.assertTrue(actual_stored_array[:][""].size == 0) def test__convert_matrices_to_cxg_arrays__sparse_array_only_store_nonzeros(self): - matrix = np.zeros([3, 3]) + matrix = from_array(sparse.csr_matrix((np.zeros([3, 3])), dtype=np.float32)) matrix[0, 0] = 1 matrix[1, 1] = 1 matrix[2, 2] = 2 matrix_name = f"{self.testing_cxg_temp_directory}/awesome_sparse_matrix_{uuid4()}" - convert_matrices_to_cxg_arrays(matrix_name, matrix, True, tiledb.Ctx()) + convert_matrices_to_cxg_arrays(matrix_name, matrix, True, {}) def get_value_at_coord(array, coord, attr): x, y = coord return array[x][""][array[x][attr] == y][0] - for suffix, attr_dim in zip(["r", "c"], ["var", "obs"], strict=False): - actual_stored_array = tiledb.open(matrix_name + suffix) - self.assertTrue(path.isdir(matrix_name + suffix)) + for attr_dim in ["var", "obs"]: + actual_stored_array = tiledb.open(matrix_name) + self.assertTrue(path.isdir(matrix_name)) self.assertTrue(isinstance(actual_stored_array, tiledb.SparseArray)) self.assertTrue(get_value_at_coord(actual_stored_array, (0, 0), attr_dim) == 1) self.assertTrue(get_value_at_coord(actual_stored_array, (1, 1), attr_dim) == 1) diff --git a/tests/unit/processing/test_dataset_metadata_update.py b/tests/unit/processing/test_dataset_metadata_update.py index cf791205c5323..e40440850f097 100644 --- a/tests/unit/processing/test_dataset_metadata_update.py +++ b/tests/unit/processing/test_dataset_metadata_update.py @@ -1,7 +1,5 @@ import json -import os import tempfile -from shutil import copy2 from unittest.mock import Mock, patch import pytest @@ -27,12 +25,10 @@ from backend.layers.processing.exceptions import ProcessingFailed from backend.layers.processing.utils.cxg_generation_utils import convert_dictionary_to_cxg_group from backend.layers.thirdparty.s3_provider_mock import MockS3Provider -from tests.unit.backend.fixtures.environment_setup import fixture_file_path from tests.unit.backend.layers.common.base_test import DatasetArtifactUpdate, DatasetStatusUpdate from tests.unit.processing.base_processing_test import BaseProcessingTest base = importr("base") -seurat = importr("SeuratObject") def mock_process(target, args=()): @@ -72,7 +68,7 @@ def test_update_metadata(self, mock_worker_factory, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, 
@@ -86,7 +82,6 @@ def test_update_metadata(self, mock_worker_factory, *args): # skip raw_h5ad update since no updated fields are expected fields in raw H5AD mock_worker.update_raw_h5ad.assert_not_called() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_called_once() mock_worker.update_cxg.assert_called_once() # check that collection version maps to dataset version with updated metadata @@ -99,6 +94,9 @@ def test_update_metadata(self, mock_worker_factory, *args): assert new_dataset_version.status.upload_status == DatasetUploadStatus.UPLOADED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id}/raw.h5ad") @patch("backend.common.utils.dl_sources.uri.downloader") @@ -122,7 +120,7 @@ def test_update_metadata__rds_skipped(self, mock_worker_factory, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -135,7 +133,6 @@ def test_update_metadata__rds_skipped(self, mock_worker_factory, *args): mock_worker.update_raw_h5ad.assert_not_called() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_not_called() mock_worker.update_cxg.assert_called_once() # check that collection version maps to dataset version with updated metadata @@ -148,6 +145,9 @@ def test_update_metadata__rds_skipped(self, mock_worker_factory, *args): assert new_dataset_version.status.upload_status == DatasetUploadStatus.UPLOADED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id}/raw.h5ad") @patch("backend.common.utils.dl_sources.uri.downloader") @@ -166,7 +166,7 @@ def test_update_metadata__raw_h5ad_updated(self, mock_worker_factory, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -179,13 +179,15 @@ def test_update_metadata__raw_h5ad_updated(self, mock_worker_factory, *args): mock_worker.update_raw_h5ad.assert_called_once() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_called_once() mock_worker.update_cxg.assert_called_once() # check that collection version maps to dataset version with updated metadata collection_version = self.business_logic.get_collection_version(collection_version_id) new_dataset_version = collection_version.datasets[0] + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS def test_update_metadata__current_dataset_version_bad_processing_status(self, *args): @@ -238,7 +240,7 @@ def test_update_metadata__error_if_missing_raw_h5ad(self, *args): 
current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -270,7 +272,7 @@ def test_update_metadata__missing_labeled_h5ad(self, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -284,40 +286,8 @@ def test_update_metadata__missing_labeled_h5ad(self, *args): assert new_dataset_version.status.h5ad_status == DatasetConversionStatus.FAILED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE - @patch("backend.common.utils.dl_sources.uri.downloader") - @patch("scanpy.read_h5ad") - @patch("backend.layers.processing.dataset_metadata_update.S3Provider", Mock(side_effect=MockS3Provider)) - @patch("backend.layers.processing.dataset_metadata_update.DatabaseProvider", Mock(side_effect=DatabaseProviderMock)) - @patch("backend.layers.processing.dataset_metadata_update.DatasetMetadataUpdater") - def test_update_metadata__missing_rds(self, *args): - current_dataset_version = self.generate_dataset( - artifacts=[ - DatasetArtifactUpdate(DatasetArtifactType.RAW_H5AD, "s3://fake.bucket/raw.h5ad"), - DatasetArtifactUpdate(DatasetArtifactType.H5AD, "s3://fake.bucket/local.h5ad"), - DatasetArtifactUpdate(DatasetArtifactType.CXG, "s3://fake.bucket/local.cxg"), - ], - statuses=[ - DatasetStatusUpdate(status_key=DatasetStatusKey.PROCESSING, status=DatasetProcessingStatus.SUCCESS), - DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=DatasetConversionStatus.CONVERTED), - ], - ) - collection_version_id = CollectionVersionId(current_dataset_version.collection_version_id) - current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) - new_dataset_version_id, _ = self.business_logic.ingest_dataset( - collection_version_id=collection_version_id, - url=None, - file_size=0, - current_dataset_version_id=current_dataset_version_id, - start_step_function=False, - ) - - with pytest.raises(ProcessingFailed): - self.updater.update_metadata(current_dataset_version_id, new_dataset_version_id, None) - - new_dataset_version = self.business_logic.get_dataset_version(new_dataset_version_id) - - assert new_dataset_version.status.rds_status == DatasetConversionStatus.FAILED - assert new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED @patch("backend.common.utils.dl_sources.uri.downloader") @patch("scanpy.read_h5ad") @@ -340,7 +310,7 @@ def test_update_metadata__missing_cxg(self, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -354,6 +324,9 @@ def test_update_metadata__missing_cxg(self, *args): assert new_dataset_version.status.cxg_status == DatasetConversionStatus.FAILED assert 
new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + @patch("backend.common.utils.dl_sources.uri.downloader") @patch("scanpy.read_h5ad") @patch("backend.layers.processing.dataset_metadata_update.DatasetMetadataUpdater") @@ -368,7 +341,7 @@ def test_update_metadata__invalid_artifact_status(self, *args): current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version_id, start_step_function=False, @@ -403,7 +376,7 @@ def test_update_raw_h5ad(self, mock_read_h5ad, *args): current_dataset_version = collection_version.datasets[0] new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version.version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version.version_id, start_step_function=False, @@ -448,7 +421,7 @@ def test_update_h5ad(self, mock_read_h5ad, *args): current_dataset_version = collection_version.datasets[0] new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version.version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version.version_id, start_step_function=False, @@ -507,7 +480,7 @@ def test_update_cxg(self): current_dataset_version = collection_version.datasets[0] new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version.version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version.version_id, start_step_function=False, @@ -572,7 +545,7 @@ def test_update_cxg__with_spatial_deepzoom_assets(self): current_dataset_version = collection_version.datasets[0] new_dataset_version_id, _ = self.business_logic.ingest_dataset( collection_version_id=collection_version.version_id, - url=None, + url="http://fake.url", file_size=0, current_dataset_version_id=current_dataset_version.version_id, start_step_function=False, @@ -599,44 +572,6 @@ def test_update_cxg__with_spatial_deepzoom_assets(self): f"s3://{self.updater.spatial_deep_zoom_dir}/{new_dataset_version_id.id}" ) - @patch("backend.layers.processing.dataset_metadata_update.os.remove") - def test_update_rds(self, *args): - with tempfile.TemporaryDirectory() as tempdir: - temp_path = os.path.join(tempdir, "test.rds") - copy2(fixture_file_path("test.rds"), temp_path) - self.updater.download_from_source_uri = Mock(return_value=temp_path) - - collection_version = self.generate_unpublished_collection(add_datasets=1) - current_dataset_version = collection_version.datasets[0] - new_dataset_version_id, _ = self.business_logic.ingest_dataset( - collection_version_id=collection_version.version_id, - url=None, - file_size=0, - current_dataset_version_id=current_dataset_version.version_id, - start_step_function=False, - ) - key_prefix = new_dataset_version_id.id - metadata_update_dict = DatasetArtifactMetadataUpdate(title="New Dataset Title") - - self.updater.update_rds(None, key_prefix, new_dataset_version_id, metadata_update_dict) - - # check Seurat object metadata is updated - seurat_object = base.readRDS(temp_path) - assert seurat.Misc(object=seurat_object, slot="title")[0] == "New Dataset Title" - # 
schema_version should stay the same as base fixture after update of other metadata - assert seurat.Misc(object=seurat_object, slot="schema_version")[0] == "3.1.0" - - # check new artifacts are uploaded in expected uris - assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id.id}/test.rds") - assert self.updater.s3_provider.uri_exists(f"s3://datasets_bucket/{new_dataset_version_id.id}.rds") - - # check artifacts + status updated in DB - new_dataset_version = self.business_logic.get_dataset_version(new_dataset_version_id) - artifacts = [(artifact.uri, artifact.type) for artifact in new_dataset_version.artifacts] - assert (f"s3://artifact_bucket/{new_dataset_version_id.id}/test.rds", DatasetArtifactType.RDS) in artifacts - - assert new_dataset_version.status.rds_status == DatasetConversionStatus.CONVERTED - class TestValidArtifactStatuses(BaseProcessingTest): def setUp(self): @@ -653,6 +588,7 @@ def test_has_valid_artifact_statuses(self, rds_status): DatasetStatusUpdate(status_key=DatasetStatusKey.H5AD, status=DatasetConversionStatus.CONVERTED), DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=rds_status), DatasetStatusUpdate(status_key=DatasetStatusKey.CXG, status=DatasetConversionStatus.CONVERTED), + DatasetStatusUpdate(status_key=DatasetStatusKey.ATAC, status=DatasetConversionStatus.SKIPPED), ] ) @@ -667,6 +603,7 @@ def test_has_valid_artifact_statuses__invalid_rds_status(self, rds_status): DatasetStatusUpdate(status_key=DatasetStatusKey.H5AD, status=DatasetConversionStatus.CONVERTED), DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=rds_status), DatasetStatusUpdate(status_key=DatasetStatusKey.CXG, status=DatasetConversionStatus.CONVERTED), + DatasetStatusUpdate(status_key=DatasetStatusKey.ATAC, status=DatasetConversionStatus.SKIPPED), ] ) @@ -681,6 +618,7 @@ def test_has_valid_artifact_statuses__invalid_h5ad_status(self, h5ad_status): DatasetStatusUpdate(status_key=DatasetStatusKey.H5AD, status=h5ad_status), DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=DatasetConversionStatus.CONVERTED), DatasetStatusUpdate(status_key=DatasetStatusKey.CXG, status=DatasetConversionStatus.CONVERTED), + DatasetStatusUpdate(status_key=DatasetStatusKey.ATAC, status=DatasetConversionStatus.SKIPPED), ] ) @@ -695,6 +633,22 @@ def test_has_valid_artifact_statuses__invalid_cxg_status(self, cxg_status): DatasetStatusUpdate(status_key=DatasetStatusKey.H5AD, status=DatasetConversionStatus.CONVERTED), DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=DatasetConversionStatus.CONVERTED), DatasetStatusUpdate(status_key=DatasetStatusKey.CXG, status=cxg_status), + DatasetStatusUpdate(status_key=DatasetStatusKey.ATAC, status=DatasetConversionStatus.SKIPPED), + ] + ) + + dataset_version_id = DatasetVersionId(dataset_version.dataset_version_id) + assert self.updater.has_valid_artifact_statuses(dataset_version_id) is False + + @parameterized.expand([DatasetConversionStatus.CONVERTING, DatasetConversionStatus.FAILED]) + def test_has_valid_artifact_statuses__invalid_atac_status(self, atac_status): + dataset_version = self.generate_dataset( + statuses=[ + DatasetStatusUpdate(status_key=DatasetStatusKey.PROCESSING, status=DatasetProcessingStatus.PENDING), + DatasetStatusUpdate(status_key=DatasetStatusKey.H5AD, status=DatasetConversionStatus.CONVERTED), + DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=DatasetConversionStatus.CONVERTED), + DatasetStatusUpdate(status_key=DatasetStatusKey.CXG, status=DatasetConversionStatus.CONVERTED), + 
DatasetStatusUpdate(status_key=DatasetStatusKey.ATAC, status=atac_status), ] ) diff --git a/tests/unit/processing/test_dataset_submissions.py b/tests/unit/processing/test_dataset_submissions.py deleted file mode 100644 index 9d1d714da4c59..0000000000000 --- a/tests/unit/processing/test_dataset_submissions.py +++ /dev/null @@ -1,104 +0,0 @@ -from unittest.mock import Mock, patch - -from backend.common.utils.exceptions import ( - CorporaException, - NonExistentCollectionException, - NonExistentDatasetException, -) -from backend.layers.common.entities import EntityId -from backend.layers.processing.submissions.app import dataset_submissions_handler -from tests.unit.backend.layers.common.base_test import BaseTest - - -class TestDatasetSubmissions(BaseTest): - def setUp(self) -> None: - super().setUp() - self.user_name = "test_user_id" - self.mock = patch( - "backend.layers.processing.submissions.app.get_business_logic", return_value=self.business_logic - ) - self.mock.start() - - def tearDown(self): - self.mock.stop() - - def test__missing_curator_file_name__raises_error(self): - mock_collection_id = EntityId() - s3_event = create_s3_event(key=f"{self.user_name}/{mock_collection_id}/") - with self.assertRaises(CorporaException): - dataset_submissions_handler(s3_event, None) - - def test__missing_collection_id__raises_error(self): - mock_collection_id = EntityId() - mock_dataset_id = EntityId() - - s3_event = create_s3_event(key=f"{self.user_name}/{mock_collection_id}/{mock_dataset_id}") - with self.assertRaises(NonExistentCollectionException): - dataset_submissions_handler(s3_event, None) - - def test__nonexistent_dataset__raises_error(self): - version = self.generate_unpublished_collection() - mock_dataset_id = EntityId() - - s3_event = create_s3_event(key=f"{self.user_name}/{version.version_id}/{mock_dataset_id}") - with self.assertRaises(NonExistentDatasetException): - dataset_submissions_handler(s3_event, None) - - def test__non_owner__raises_error(self): - version = self.generate_unpublished_collection(owner="someone_else") - mock_dataset_id = EntityId() - - s3_event = create_s3_event(key=f"{self.user_name}/{version.version_id}/{mock_dataset_id}") - with self.assertRaises(CorporaException): - dataset_submissions_handler(s3_event, None) - - def test__upload_update_by_dataset_id_owner__OK(self): - """ - Processing starts when an update of a dataset is uploaded by its ID by the collection owner. 
- """ - version = self.generate_unpublished_collection() - dataset_version_id = self.business_logic.create_empty_dataset(version.version_id).version_id - - mock_ingest = self.business_logic.ingest_dataset = Mock() - - s3_event = create_s3_event(key=f"{self.user_name}/{version.version_id}/{dataset_version_id}") - dataset_submissions_handler(s3_event, None) - mock_ingest.assert_called() - - def test__upload_update_by_dataset_id_super__OK(self): - """ - Processing starts when an update of a dataset is uploaded by its ID by a super curator - """ - version = self.generate_unpublished_collection() - dataset_version_id = self.business_logic.create_empty_dataset(version.version_id).version_id - - mock_ingest = self.business_logic.ingest_dataset = Mock() - - s3_event = create_s3_event(key=f"super/{version.version_id}/{dataset_version_id}") - dataset_submissions_handler(s3_event, None) - mock_ingest.assert_called() - - def test__upload_update_by_dataset_canonical_id__OK(self): - """ - Processing starts when an update of a dataset is uploaded using canonical ids - - """ - version = self.generate_unpublished_collection() - dataset_id = self.business_logic.create_empty_dataset(version.version_id).dataset_id - - mock_ingest = self.business_logic.ingest_dataset = Mock() - - s3_event = create_s3_event(key=f"{self.user_name}/{version.collection_id}/{dataset_id}") - dataset_submissions_handler(s3_event, None) - mock_ingest.assert_called() - - -def create_s3_event(bucket_name: str = "some_bucket", key: str = "", size: int = 0) -> dict: - """ - Returns an S3 event dictionary with only the keys that the dataset submissions handler cares about - :param bucket_name: - :param key: - :param size: - :return: - """ - return {"Records": [{"s3": {"bucket": {"name": bucket_name}, "object": {"key": key, "size": size}}}]} diff --git a/tests/unit/processing/test_extract_metadata.py b/tests/unit/processing/test_extract_metadata.py index 163cbea3b91e0..8ac9b1f59763b 100644 --- a/tests/unit/processing/test_extract_metadata.py +++ b/tests/unit/processing/test_extract_metadata.py @@ -1,21 +1,23 @@ +import tempfile from unittest.mock import patch import anndata import numpy as np import pandas +from dask.array import from_array from backend.layers.common.entities import OntologyTermId, SpatialMetadata, TissueOntologyTermId -from backend.layers.processing.process_validate import ProcessValidate +from backend.layers.processing.process_add_labels import ProcessAddLabels +from backend.layers.thirdparty.schema_validator_provider import SchemaValidatorProvider from tests.unit.processing.base_processing_test import BaseProcessingTest -class TestProcessingValidate(BaseProcessingTest): +class TestAddLabels(BaseProcessingTest): def setUp(self): super().setUp() - self.pdv = ProcessValidate(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + self.pal = ProcessAddLabels(self.business_logic, self.uri_provider, self.s3_provider, SchemaValidatorProvider()) - @patch("scanpy.read_h5ad") - def test_extract_metadata(self, mock_read_h5ad): + def test_extract_metadata(self): df = pandas.DataFrame( np.random.randint(10, size=(50001, 5)) * 50, columns=list("ABCDE"), index=(str(i) for i in range(50001)) ) @@ -77,7 +79,7 @@ def test_extract_metadata(self, mock_read_h5ad): uns = { "title": "my test dataset", "X_approximate_distribution": "normal", - "batch_condition": np.array({"batchA", "batchB"}), + "batch_condition": ["batchA", "batchB"], "schema_version": "3.0.0", "default_embedding": "X_umap", "citation": "Publication: 
https://doi.org/12.2345/science.abc1234 Dataset Version: " @@ -100,10 +102,11 @@ def test_extract_metadata(self, mock_read_h5ad): obsm = {"X_umap": np.zeros([50001, 2]), "X_pca": np.zeros([50001, 2])} - adata = anndata.AnnData(X=df, obs=obs, obsm=obsm, uns=uns, var=var) - mock_read_h5ad.return_value = adata + adata = anndata.AnnData(X=from_array(df.to_numpy()), obs=obs, obsm=obsm, uns=uns, var=var) - extracted_metadata = self.pdv.extract_metadata("dummy") + with tempfile.NamedTemporaryFile(suffix=".h5ad") as f: + adata.write_h5ad(f.name) + extracted_metadata = self.pal.extract_metadata(f.name) self.assertEqual(extracted_metadata.organism, [OntologyTermId("Homo sapiens", "NCBITaxon:8505")]) @@ -156,7 +159,7 @@ def test_extract_metadata(self, mock_read_h5ad): ) self.assertEqual(extracted_metadata.x_approximate_distribution, "NORMAL") - self.assertEqual(extracted_metadata.batch_condition, np.array({"batchA", "batchB"})) + self.assertCountEqual(extracted_metadata.batch_condition, ["batchA", "batchB"]) self.assertEqual(extracted_metadata.schema_version, "3.0.0") self.assertEqual(extracted_metadata.citation, uns["citation"]) @@ -177,7 +180,7 @@ def test_extract_metadata(self, mock_read_h5ad): self.assertEqual(extracted_metadata.raw_data_location, "X") self.assertEqual(extracted_metadata.spatial, None) - @patch("scanpy.read_h5ad") + @patch("cellxgene_schema.utils.read_h5ad") def test_extract_metadata_find_raw_layer(self, mock_read_h5ad): # Setup anndata to be read non_zeros_X_layer_df = pandas.DataFrame( @@ -241,7 +244,7 @@ def test_extract_metadata_find_raw_layer(self, mock_read_h5ad): uns = { "title": "my test dataset", "X_approximate_distribution": "normal", - "batch_condition": np.array({"batchA", "batchB"}), + "batch_condition": ["batchA", "batchB"], "schema_version": "3.0.0", "citation": "Publication: https://doi.org/12.2345/science.abc1234 Dataset Version: " "https://datasets.cellxgene.cziscience.com/dataset_id.h5ad curated and distributed by " @@ -262,7 +265,7 @@ def test_extract_metadata_find_raw_layer(self, mock_read_h5ad): obsm = {"X_umap": np.zeros([11, 2])} adata = anndata.AnnData( - X=non_zeros_X_layer_df, + X=from_array(non_zeros_X_layer_df.to_numpy()), obs=obs, obsm=obsm, uns=uns, @@ -272,10 +275,9 @@ def test_extract_metadata_find_raw_layer(self, mock_read_h5ad): adata_raw = anndata.AnnData(X=zeros_layer_df, obs=obs, uns=uns) adata.raw = adata_raw - mock_read_h5ad.return_value = adata - - # Run the extraction method - extracted_metadata = self.pdv.extract_metadata("dummy") + with tempfile.NamedTemporaryFile(suffix=".h5ad") as f: + adata.write_h5ad(f.name) + extracted_metadata = self.pal.extract_metadata(f.name) # Verify that the "my_awesome_wonky_layer" was read and not the default X layer. The layer contains only zeros # which should result in a mean_genes_per_cell value of 0 compared to 3 if the X layer was read. 
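A minimal sketch (an assumption, not part of this changeset) of the pattern the updated extract_metadata tests rely on: instead of patching scanpy.read_h5ad, the in-memory AnnData is written to a real temporary .h5ad file and the file path is handed to the code under test; run_on_h5ad is a hypothetical helper name, not an API from this codebase.

import tempfile

import anndata
import numpy as np

def run_on_h5ad(adata: anndata.AnnData, fn):
    """Serialize adata to a temporary .h5ad file and call fn on the file path."""
    with tempfile.NamedTemporaryFile(suffix=".h5ad") as f:
        adata.write_h5ad(f.name)
        return fn(f.name)

# Tiny self-contained demo: round-trip a 4x3 zero matrix and just return the temp path.
print(run_on_h5ad(anndata.AnnData(X=np.zeros((4, 3), dtype=np.float32)), lambda p: p))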
@@ -288,7 +290,7 @@ def test_get_spatial_metadata__is_single_and_fullres_true(self): "is_single": np.bool_(True), "dummy_library_id": {"images": {"fullres": "dummy_fullres"}}, } - self.assertEqual(self.pdv.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=True)) + self.assertEqual(self.pal.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=True)) def test_get_spatial_metadata__is_single_true_and_fullres_false(self): spatial_dict = { @@ -296,17 +298,17 @@ def test_get_spatial_metadata__is_single_true_and_fullres_false(self): "dummy_library_id": {"images": {}}, } self.assertEqual( - self.pdv.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=False) + self.pal.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=False) ) def test_get_spatial_metadata__is_single_true_and_no_library_id(self): spatial_dict = {"is_single": np.bool_(True)} self.assertEqual( - self.pdv.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=False) + self.pal.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=True, has_fullres=False) ) def test_get_spatial_metadata__is_single_false(self): spatial_dict = {"is_single": np.bool_(False)} self.assertEqual( - self.pdv.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=False, has_fullres=False) + self.pal.get_spatial_metadata(spatial_dict), SpatialMetadata(is_single=False, has_fullres=False) ) diff --git a/tests/unit/processing/test_h5ad_data_file.py b/tests/unit/processing/test_h5ad_data_file.py index 82f7dfae89e07..309fededd7a49 100644 --- a/tests/unit/processing/test_h5ad_data_file.py +++ b/tests/unit/processing/test_h5ad_data_file.py @@ -8,6 +8,7 @@ import numpy as np import tiledb from pandas import Categorical, DataFrame, Series +from scipy import sparse from backend.common.utils.corpora_constants import CorporaConstants from backend.layers.processing.h5ad_data_file import H5ADDataFile @@ -17,7 +18,10 @@ class TestH5ADDataFile(unittest.TestCase): def setUp(self): self.sample_anndata = self._create_sample_anndata_dataset() + self.sample_anndata_sparse = self.sample_anndata.copy() + self.sample_anndata_sparse.X = sparse.csr_matrix((np.random.rand(3, 4).astype(np.float32)), dtype=np.float32) self.sample_h5ad_filename = self._write_anndata_to_file(self.sample_anndata) + self.sample_sparse_h5ad_filename = self._write_anndata_to_file(self.sample_anndata_sparse) self.sample_output_directory = path.splitext(self.sample_h5ad_filename)[0] + ".cxg" self.dataset_version_id = "test_dataset_version_id" @@ -26,6 +30,9 @@ def tearDown(self): if self.sample_h5ad_filename: remove(self.sample_h5ad_filename) + if self.sample_sparse_h5ad_filename: + remove(self.sample_sparse_h5ad_filename) + if path.isdir(self.sample_output_directory): rmtree(self.sample_output_directory) @@ -40,7 +47,7 @@ def test__create_h5ad_data_file__non_h5ad_raises_exception(self): def test__create_h5ad_data_file__reads_anndata_successfully(self): h5ad_file = H5ADDataFile(self.sample_h5ad_filename) - self.assertTrue((h5ad_file.anndata.X == self.sample_anndata.X).all()) + self.assertTrue((h5ad_file.anndata.X.compute() == self.sample_anndata.X).all()) self.assertEqual( h5ad_file.anndata.obs.sort_index(inplace=True), self.sample_anndata.obs.sort_index(inplace=True) ) @@ -108,16 +115,16 @@ def test__create_h5ad_data_file__obs_and_var_index_names_specified_doesnt_exist_ self.assertIn("does not exist", str(exception_context.exception)) def 
test__to_cxg__simple_anndata_no_corpora_and_sparse(self): - h5ad_file = H5ADDataFile(self.sample_h5ad_filename) + h5ad_file = H5ADDataFile(self.sample_sparse_h5ad_filename) h5ad_file.to_cxg(self.sample_output_directory, 100, self.dataset_version_id) - self._validate_cxg_and_h5ad_content_match(self.sample_h5ad_filename, self.sample_output_directory, True) + self._validate_cxg_and_h5ad_content_match(self.sample_sparse_h5ad_filename, self.sample_output_directory, True) def test__to_cxg__simple_anndata_with_corpora_and_sparse(self): - h5ad_file = H5ADDataFile(self.sample_h5ad_filename) + h5ad_file = H5ADDataFile(self.sample_sparse_h5ad_filename) h5ad_file.to_cxg(self.sample_output_directory, 100, self.dataset_version_id) - self._validate_cxg_and_h5ad_content_match(self.sample_h5ad_filename, self.sample_output_directory, True) + self._validate_cxg_and_h5ad_content_match(self.sample_sparse_h5ad_filename, self.sample_output_directory, True) def test__to_cxg__simple_anndata_no_corpora_and_dense(self): h5ad_file = H5ADDataFile(self.sample_h5ad_filename) @@ -193,8 +200,6 @@ def _validate_cxg_and_h5ad_content_match(self, h5ad_filename, cxg_directory, is_ # Array locations metadata_array_location = f"{cxg_directory}/cxg_group_metadata" main_x_array_location = f"{cxg_directory}/X" - main_xr_array_location = f"{cxg_directory}/Xr" - main_xc_array_location = f"{cxg_directory}/Xc" embedding_array_location = f"{cxg_directory}/emb" specific_embedding_array_location = f"{self.sample_output_directory}/emb/awesome_embedding" obs_array_location = f"{cxg_directory}/obs" @@ -204,11 +209,7 @@ def _validate_cxg_and_h5ad_content_match(self, h5ad_filename, cxg_directory, is_ self.assertEqual(tiledb.object_type(cxg_directory), "group") self.assertEqual(tiledb.object_type(obs_array_location), "array") self.assertEqual(tiledb.object_type(var_array_location), "array") - if is_sparse: - self.assertEqual(tiledb.object_type(main_xr_array_location), "array") - self.assertEqual(tiledb.object_type(main_xc_array_location), "array") - else: - self.assertEqual(tiledb.object_type(main_x_array_location), "array") + self.assertEqual(tiledb.object_type(main_x_array_location), "array") self.assertEqual(tiledb.object_type(embedding_array_location), "group") self.assertEqual(tiledb.object_type(specific_embedding_array_location), "array") @@ -266,29 +267,17 @@ def _validate_cxg_and_h5ad_content_match(self, h5ad_filename, cxg_directory, is_ self.assertTrue(np.array_equal(expected_embedding_data, actual_embedding_data)) # Validate X matrix if not column shifted - if not has_column_encoding and not is_sparse: + if not has_column_encoding: expected_x_data = anndata_object.X with tiledb.open(main_x_array_location, mode="r") as x_array: if is_sparse: + expected_x_data = expected_x_data.toarray() actual_x_data = np.zeros_like(expected_x_data) - data = x_array[:] + data = x_array[:, :] actual_x_data[data["obs"], data["var"]] = data[""] else: actual_x_data = x_array[:, :] self.assertTrue(np.array_equal(expected_x_data, actual_x_data)) - elif not has_column_encoding: - expected_x_data = anndata_object.X - with tiledb.open(main_xr_array_location, mode="r") as x_array: - actual_x_data = np.zeros_like(expected_x_data) - data = x_array[:] - actual_x_data[data["obs"], data["var"]] = data[""] - self.assertTrue(np.array_equal(expected_x_data, actual_x_data)) - - with tiledb.open(main_xc_array_location, mode="r") as x_array: - actual_x_data = np.zeros_like(expected_x_data) - data = x_array[:] - actual_x_data[data["obs"], data["var"]] = data[""] - 
self.assertTrue(np.array_equal(expected_x_data, actual_x_data)) def _validate_cxg_var_index_column_match(self, cxg_directory, expected_index_name): var_array_location = f"{cxg_directory}/var" diff --git a/tests/unit/processing/test_handle_error.py b/tests/unit/processing/test_handle_error.py index b49e1d2aee595..c8af896d17be1 100644 --- a/tests/unit/processing/test_handle_error.py +++ b/tests/unit/processing/test_handle_error.py @@ -12,14 +12,20 @@ CollectionMetadata, CollectionVersionId, CollectionVersionWithDatasets, + DatasetArtifact, + DatasetArtifactId, + DatasetArtifactType, + DatasetConversionStatus, DatasetStatus, DatasetVersionId, ) from backend.layers.processing.upload_failures.app import ( FAILED_ARTIFACT_CLEANUP_MESSAGE, + FAILED_ATAC_DATASET_MESSAGE, FAILED_CXG_CLEANUP_MESSAGE, FAILED_DATASET_CLEANUP_MESSAGE, cleanup_artifacts, + delete_atac_fragment_files, get_failure_slack_notification_message, handle_failure, parse_event, @@ -47,6 +53,7 @@ def sample_slack_status_block(): "text": { "type": "mrkdwn", "text": "```{\n" + ' "atac_status": null,\n' ' "cxg_status": null,\n' ' "h5ad_status": null,\n' ' "processing_status": null,\n' @@ -104,6 +111,13 @@ def get_collection_version_mock(): ) +@pytest.fixture +def mock_business_logic(monkeypatch): + business_logic_mock = Mock() + monkeypatch.setattr(f"{module_path}.get_business_logic", Mock(return_value=business_logic_mock)) + return business_logic_mock + + def test_parse_event_with_empty_event(): ( dataset_version_id, @@ -206,11 +220,11 @@ def test_parse_event_with_invalid_error_cause(): def mock_get_dataset_version(collection_id): MockDatasetVersionId = Mock() MockDatasetVersionId.collection_id = collection_id - MockDatasetVersionId.status = DatasetStatus(None, None, None, None, None, None) + MockDatasetVersionId.status = DatasetStatus(*[None] * 7) return MockDatasetVersionId -def test_migration_event_does_not_trigger_slack(): +def test_migration_event_does_not_trigger_slack(mock_business_logic): mock_trigger_slack = Mock() mock_context = Mock() with patch("backend.layers.processing.upload_failures.app.trigger_slack_notification", mock_trigger_slack): @@ -220,11 +234,11 @@ def test_migration_event_does_not_trigger_slack(): "error": {}, "execution_id": "arn:aws:states:us-west-2:migrate_123456789012:execution:MyStateMachine", } - handle_failure(event, mock_context) + handle_failure(event, mock_context, delete_artifacts=False) mock_trigger_slack.assert_not_called() -def test_non_migration_event_triggers_slack(): +def test_non_migration_event_triggers_slack(mock_business_logic): mock_trigger_slack = Mock() mock_context = Mock() with patch("backend.layers.processing.upload_failures.app.trigger_slack_notification", mock_trigger_slack): @@ -234,7 +248,7 @@ def test_non_migration_event_triggers_slack(): "error": {}, "execution_id": "arn:aws:states:us-west-2:123456789012:execution:MyStateMachine", } - handle_failure(event, mock_context) + handle_failure(event, mock_context, delete_artifacts=False) mock_trigger_slack.assert_called_once() @@ -490,20 +504,58 @@ def dataset_version_id() -> str: return "example_dataset" +@pytest.fixture +def mock_delete_atac_fragment_files(monkeypatch) -> Mock: + mock = Mock() + monkeypatch.setattr(f"{module_path}.delete_atac_fragment_files", mock) + return mock + + +@pytest.mark.usefixtures("mock_delete_atac_fragment_files") class TestCleanupArtifacts: - @pytest.mark.parametrize("error_step", ["validate", "", None]) - def test_cleanup_artifacts__OK(self, mock_env_vars, mock_delete_many_from_s3, 
dataset_version_id, error_step): + def test_cleanup_artifacts__validate_anndata_OK( + self, mock_env_vars, mock_delete_many_from_s3, dataset_version_id, mock_delete_atac_fragment_files + ): + """Check that all artifacts are deleted for the given cases.""" + error_step = "validate_anndata" + cleanup_artifacts(dataset_version_id, error_step) + + # Assertions + mock_delete_atac_fragment_files.assert_not_called() + mock_delete_many_from_s3.assert_any_call(mock_env_vars["ARTIFACT_BUCKET"], dataset_version_id + "/") + mock_delete_many_from_s3.assert_any_call(mock_env_vars["DATASETS_BUCKET"], dataset_version_id + ".") + mock_delete_many_from_s3.assert_any_call(mock_env_vars["CELLXGENE_BUCKET"], dataset_version_id + ".cxg/") + assert mock_delete_many_from_s3.call_count == 3 + + def test_cleanup_artifacts__validate_atac_OK( + self, mock_env_vars, mock_delete_many_from_s3, dataset_version_id, mock_delete_atac_fragment_files + ): + """Check that all artifacts are deleted for the given cases.""" + error_step = "validate_atac" + cleanup_artifacts(dataset_version_id, error_step) + + # Assertions + mock_delete_atac_fragment_files.assert_called_once_with(dataset_version_id) + mock_delete_many_from_s3.assert_any_call(mock_env_vars["DATASETS_BUCKET"], dataset_version_id + ".") + mock_delete_many_from_s3.assert_any_call(mock_env_vars["CELLXGENE_BUCKET"], dataset_version_id + ".cxg/") + assert mock_delete_many_from_s3.call_count == 2 + + def test_cleanup_artifacts__None_OK( + self, mock_env_vars, mock_delete_many_from_s3, dataset_version_id, mock_delete_atac_fragment_files + ): """Check that all artifacts are deleted for the given cases.""" + error_step = None cleanup_artifacts(dataset_version_id, error_step) # Assertions + mock_delete_atac_fragment_files.assert_called_once_with(dataset_version_id) mock_delete_many_from_s3.assert_any_call(mock_env_vars["ARTIFACT_BUCKET"], dataset_version_id + "/") mock_delete_many_from_s3.assert_any_call(mock_env_vars["DATASETS_BUCKET"], dataset_version_id + ".") mock_delete_many_from_s3.assert_any_call(mock_env_vars["CELLXGENE_BUCKET"], dataset_version_id + ".cxg/") assert mock_delete_many_from_s3.call_count == 3 def test_cleanup_artifacts__not_download_validate( - self, mock_env_vars, mock_delete_many_from_s3, dataset_version_id + self, mock_env_vars, mock_delete_many_from_s3, dataset_version_id, *args ): """Check that file in the artifact bucket are not delete if error_step is not download-validate.""" cleanup_artifacts(dataset_version_id, "not_download_validate") @@ -514,7 +566,7 @@ def test_cleanup_artifacts__not_download_validate( assert mock_delete_many_from_s3.call_count == 2 @patch.dict(os.environ, clear=True) - def test_cleanup_artifacts__no_buckets(self, caplog, mock_delete_many_from_s3, dataset_version_id): + def test_cleanup_artifacts__no_buckets(self, caplog, mock_delete_many_from_s3, dataset_version_id, *args): """Check that no files are deleted if buckets are not specified.""" cleanup_artifacts(dataset_version_id) @@ -524,18 +576,77 @@ def test_cleanup_artifacts__no_buckets(self, caplog, mock_delete_many_from_s3, d assert FAILED_CXG_CLEANUP_MESSAGE in caplog.text assert FAILED_DATASET_CLEANUP_MESSAGE in caplog.text - def test_cleanup_artifacts__elete_many_from_s3_error( - self, caplog, mock_env_vars, mock_delete_many_from_s3, dataset_version_id + def test_cleanup_artifacts__delete_many_from_s3_error( + self, caplog, mock_env_vars, mock_delete_many_from_s3, dataset_version_id, *args ): """Check that delete_many_from_s3 errors are logged but do not raise 
exceptions.""" mock_delete_many_from_s3.side_effect = Exception("Boom!") cleanup_artifacts(dataset_version_id) # Assertions - mock_delete_many_from_s3.assert_any_call(mock_env_vars["ARTIFACT_BUCKET"], dataset_version_id + "/") - mock_delete_many_from_s3.assert_any_call(mock_env_vars["DATASETS_BUCKET"], dataset_version_id + ".") - mock_delete_many_from_s3.assert_any_call(mock_env_vars["CELLXGENE_BUCKET"], dataset_version_id + ".cxg/") - assert mock_delete_many_from_s3.call_count == 3 assert FAILED_ARTIFACT_CLEANUP_MESSAGE in caplog.text assert FAILED_CXG_CLEANUP_MESSAGE in caplog.text assert FAILED_DATASET_CLEANUP_MESSAGE in caplog.text + + +@pytest.fixture +def mock_dataset_version(): + dv = Mock() + dv.artifacts = [ + DatasetArtifact(id=DatasetArtifactId(), uri="s3://bucket/uri", type=DatasetArtifactType.ATAC_INDEX), + DatasetArtifact(id=DatasetArtifactId(), uri="s3://bucket/uri", type=DatasetArtifactType.ATAC_FRAGMENT), + ] + dv.status.atac_status = DatasetConversionStatus.UPLOADED + return dv + + +class TestDeleteAtacFragmentFiles: + @pytest.mark.parametrize( + "atac_status", [DatasetConversionStatus.COPIED, DatasetConversionStatus.SKIPPED, DatasetConversionStatus.NA] + ) + def test_delete_skipped(self, mock_business_logic, atac_status, mock_dataset_version): + # Arrange + dataset_version_id = "example_dataset" + mock_dataset_version.status.atac_status = atac_status + mock_business_logic.get_dataset_version.return_value = mock_dataset_version + + # Act + delete_atac_fragment_files(dataset_version_id) + + # Assert + mock_business_logic.get_atac_fragment_uris_from_dataset_version.assert_not_called() + + def test_delete_atac_fragment_files__OK( + self, mock_delete_many_from_s3, mock_env_vars, mock_business_logic, mock_dataset_version, dataset_version_id + ): + """Check that atac fragment files are deleted.""" + # Arrange + mock_business_logic.get_dataset_version.return_value = mock_dataset_version + test_uris = ["uri1", "uri2"] + mock_business_logic.get_atac_fragment_uris_from_dataset_version.return_value = ["uri1", "uri2"] + + # Act + delete_atac_fragment_files(dataset_version_id) + + # Assertions + for uri in test_uris: + mock_delete_many_from_s3.assert_any_call( + mock_env_vars["DATASETS_BUCKET"], os.path.join(os.environ.get("REMOTE_DEV_PREFIX", ""), uri) + ) + + def test_catch_errors( + self, + caplog, + mock_delete_many_from_s3, + mock_env_vars, + mock_business_logic, + mock_dataset_version, + dataset_version_id, + ): + # Arrange + mock_delete_many_from_s3.side_effect = Exception("Boom!") + mock_business_logic.get_atac_fragment_uris_from_dataset_version.return_value = ["uri1"] + # Act + delete_atac_fragment_files(dataset_version_id) + # Assert + assert FAILED_ATAC_DATASET_MESSAGE[:-3] in caplog.text diff --git a/tests/unit/processing/test_matrix_utils.py b/tests/unit/processing/test_matrix_utils.py index 58dcc74a2b202..3335b7da79646 100644 --- a/tests/unit/processing/test_matrix_utils.py +++ b/tests/unit/processing/test_matrix_utils.py @@ -4,6 +4,7 @@ import numpy as np import pytest from anndata import AnnData +from dask.array import from_array from scipy.sparse import coo_matrix from backend.layers.processing.utils.matrix_utils import enforce_canonical_format, is_matrix_sparse @@ -14,13 +15,13 @@ class TestMatrixUtils: def test__is_matrix_sparse__zero_and_one_hundred_percent_threshold(self): - matrix = np.array([1, 2, 3]) + matrix = from_array(np.array([1, 2, 3])) assert not is_matrix_sparse(matrix, 0) assert is_matrix_sparse(matrix, 100) def 
test__is_matrix_sparse__partially_populated_sparse_matrix_returns_true(self): - matrix = np.zeros([3, 4]) + matrix = from_array(np.zeros([3, 4])) matrix[2][3] = 1.0 matrix[1][1] = 2.2 @@ -31,29 +32,10 @@ def test__is_matrix_sparse__partially_populated_dense_matrix_returns_false(self) matrix[0][0] = 1.0 matrix[0][1] = 2.2 matrix[1][1] = 3.7 + matrix = from_array(matrix) assert not is_matrix_sparse(matrix, 50) - def test__is_matrix_sparse__giant_matrix_returns_false_early(self, caplog): - caplog.set_level(logging.INFO) - matrix = np.ones([20000, 20]) - - assert not is_matrix_sparse(matrix, 1) - - # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of - # non-zero elements in the matrix. - assert "Percentage of non-zero elements (estimate)" in caplog.text - - def test__is_matrix_sparse_with_column_shift_encoding__giant_matrix_returns_false_early(self, caplog): - caplog.set_level(logging.INFO) - matrix = np.random.rand(20000, 20) - - assert not is_matrix_sparse(matrix, 1) - - # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of - # non-zero elements in the matrix. - assert "Percentage of non-zero elements (estimate)" in caplog.text - @pytest.fixture def noncanonical_matrix(): diff --git a/tests/unit/processing/test_process_add_labels.py b/tests/unit/processing/test_process_add_labels.py new file mode 100644 index 0000000000000..7e711d53dded1 --- /dev/null +++ b/tests/unit/processing/test_process_add_labels.py @@ -0,0 +1,124 @@ +import tempfile +from unittest.mock import MagicMock, Mock, patch + +import anndata + +from backend.layers.common.entities import ( + DatasetArtifactType, + DatasetConversionStatus, + DatasetValidationStatus, + DatasetVersionId, + Link, +) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.process import ProcessMain +from backend.layers.processing.process_add_labels import ProcessAddLabels +from tests.unit.processing.base_processing_test import BaseProcessingTest + + +class ProcessingTest(BaseProcessingTest): + @patch("scanpy.read_h5ad") + def test_process_add_labels(self, mock_read_h5ad): + """ + ProcessAddLabels should: + 1. Download the h5ad artifact + 2. Add labels to h5ad + 3. Set upload status to UPLOADED + 4. set h5ad status to UPLOADED + 5. 
upload the labeled file to S3 + """ + dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=0" + self.crossref_provider.fetch_metadata = Mock(return_value=({}, "12.2345", 17169328.664)) + + collection = self.generate_unpublished_collection( + links=[Link(name=None, type="DOI", uri="http://doi.org/12.2345")] + ) + dataset_version_id, dataset_id = self.business_logic.ingest_dataset( + collection.version_id, dropbox_uri, None, None + ) + # This is where we're at when we start the SFN + + mock_read_h5ad.return_value = MagicMock(uns=dict()) + + # TODO: ideally use a real h5ad + processor = ProcessAddLabels(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + processor.extract_metadata = Mock() + processor.populate_dataset_citation = Mock() + processor.process(collection.version_id, dataset_version_id, "fake_bucket_name", "fake_datasets_bucket") + + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertEqual(status.validation_status, DatasetValidationStatus.VALID) + self.assertEqual(status.h5ad_status, DatasetConversionStatus.UPLOADED) + + # Verify that the labeled (local.h5ad) file is there + self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad")) + # Verify that the labeled file is uploaded to the datasets bucket + self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.h5ad")) + + artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) + self.assertEqual(1, len(artifacts)) + artifact = artifacts[0] + artifact.type = DatasetArtifactType.H5AD + artifact.uri = f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad" + + def test_populate_dataset_citation__with_publication_doi(self): + mock_adata = anndata.AnnData(X=None, obs=None, obsm=None, uns={}, var=None) + self.crossref_provider.fetch_metadata = Mock(return_value=({}, "12.2345", 17169328.664)) + collection = self.generate_unpublished_collection( + links=[Link(name=None, type="DOI", uri="https://doi.org/12.2345")] + ) + with tempfile.NamedTemporaryFile(suffix=".h5ad") as f: + mock_adata.write_h5ad(f.name) + pal = ProcessAddLabels(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + dataset_version_id = DatasetVersionId() + pal.populate_dataset_citation(collection.version_id, dataset_version_id, f.name) + citation_str = ( + f"Publication: https://doi.org/12.2345 " + f"Dataset Version: http://domain/{dataset_version_id}.h5ad curated and distributed by " + f"CZ CELLxGENE Discover in Collection: https://domain/collections/{collection.collection_id}" + ) + adata = anndata.read_h5ad(f.name) + self.assertEqual(adata.uns["citation"], citation_str) + + def test_populate_dataset_citation__no_publication_doi(self): + mock_adata = anndata.AnnData(X=None, obs=None, obsm=None, uns={}, var=None) + collection = self.generate_unpublished_collection() + with tempfile.NamedTemporaryFile(suffix=".h5ad") as f: + mock_adata.write_h5ad(f.name) + pal = ProcessAddLabels(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + dataset_version_id = DatasetVersionId() + pal.populate_dataset_citation(collection.version_id, dataset_version_id, f.name) + citation_str = ( + f"Dataset Version: http://domain/{dataset_version_id}.h5ad curated and distributed by " + f"CZ CELLxGENE Discover in Collection: https://domain/collections/{collection.collection_id}" + ) + adata = anndata.read_h5ad(f.name) + 
self.assertEqual(adata.uns["citation"], citation_str) + + def test_process_add_labels_fail(self): + """ + If the validation is not successful, the processing pipeline should: + 1. Set the processing status to INVALID + 2. Set a validation message accordingly + """ + dropbox_uri = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" + manifest = IngestionManifest(anndata=dropbox_uri) + + collection = self.generate_unpublished_collection() + dataset_version_id, dataset_id = self.business_logic.ingest_dataset( + collection.version_id, dropbox_uri, None, None + ) + self.schema_validator.add_labels = Mock(side_effect=ValueError("Add labels error")) + pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pm.process( + collection.version_id, + dataset_version_id, + "add_labels", + manifest, + "fake_bucket_name", + "fake_datasets_bucket", + "fake_cxg_bucket", + ) + + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertEqual(status.h5ad_status, DatasetConversionStatus.FAILED) diff --git a/tests/unit/processing/test_process_download.py b/tests/unit/processing/test_process_download.py deleted file mode 100644 index 6622cf71b01d4..0000000000000 --- a/tests/unit/processing/test_process_download.py +++ /dev/null @@ -1,223 +0,0 @@ -import json -from unittest.mock import MagicMock, Mock, patch - -import pytest -import scanpy - -from backend.common.utils.math_utils import GB -from backend.layers.common.entities import DatasetArtifactType, DatasetUploadStatus -from backend.layers.processing.process_download import ProcessDownload -from tests.unit.processing.base_processing_test import BaseProcessingTest - -test_environment = {"REMOTE_DEV_PREFIX": "fake-stack", "DEPLOYMENT_STAGE": "test"} - - -class TestProcessDownload(BaseProcessingTest): - @patch("backend.common.utils.dl_sources.uri.downloader") - @patch("backend.common.utils.dl_sources.uri.DropBoxURL.file_info", return_value={"size": 100, "name": "fake_name"}) - @patch("os.environ", test_environment) - @patch("backend.layers.processing.process_download.StepFunctionProvider") - @patch("scanpy.read_h5ad") - def test_process_download_success(self, mock_read_h5ad, mock_sfn_provider, *args): - """ - ProcessValidate should: - 1. Download the h5ad artifact - 2. Set upload status to UPLOADED - 3. 
upload the original file to S3 - """ - dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=0" - bucket_name = "fake_bucket_name" - stack_name = test_environment["REMOTE_DEV_PREFIX"] - deployment_stage = test_environment["DEPLOYMENT_STAGE"] - collection = self.generate_unpublished_collection() - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, dropbox_uri, None, None - ) - # Mock anndata object - mock_anndata = Mock(spec=scanpy.AnnData) - mock_anndata.n_obs = 10000 - mock_anndata.n_vars = 10000 - mock_read_h5ad.return_value = mock_anndata - - # Mock SFN client - mock_sfn = Mock() - mock_sfn_provider.return_value = mock_sfn - - # This is where we're at when we start the SFN - pdv = ProcessDownload(self.business_logic, self.uri_provider, self.s3_provider) - pdv.process(dataset_version_id, dropbox_uri, bucket_name, "fake_sfn_task_token") - - status = self.business_logic.get_dataset_status(dataset_version_id) - self.assertEqual(status.upload_status, DatasetUploadStatus.UPLOADED) - - # Assert mocks - mock_read_h5ad.assert_called_with("raw.h5ad", backed="r") - mock_sfn.client.send_task_success.assert_called_with( - taskToken="fake_sfn_task_token", - output=json.dumps( - { - "JobDefinitionName": f"dp-{deployment_stage}-{stack_name}-ingest-process-{dataset_version_id.id}", - "Vcpus": 2, - "Memory": 16000, - "LinuxParameters": {"Swappiness": 60, "MaxSwap": 0}, - } - ), - ) - - # Verify that both the original (raw.h5ad) and the labeled (local.h5ad) files are there - self.assertTrue( - self.s3_provider.uri_exists(f"s3://{bucket_name}/{stack_name}/{dataset_version_id.id}/raw.h5ad") - ) - - artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(1, len(artifacts)) - artifact = artifacts[0] - artifact.type = DatasetArtifactType.RAW_H5AD - artifact.uri = f"s3://fake_bucket_name/{stack_name}/{dataset_version_id.id}/raw.h5ad" - - @patch("backend.common.utils.dl_sources.uri.S3Provider") - @patch("backend.common.utils.dl_sources.uri.S3URI.file_info", return_value={"size": 100, "name": "fake_name"}) - def test_download_from_s3_uri(self, *arg): - """ - Call process download using an s3 uri - """ - - s3_uri = "s3://fake_bucket_name/fake_key/fake_file.h5ad" - pdv = ProcessDownload(Mock(), self.uri_provider, Mock()) - pdv.download_from_s3 = Mock() - - assert pdv.download_from_source_uri(s3_uri, "fake_local_path") == "fake_local_path" - - @patch("backend.common.utils.dl_sources.uri.downloader") - @patch("backend.common.utils.dl_sources.uri.DropBoxURL.file_info", return_value={"size": 100, "name": "fake_name"}) - def test_download_from_dropbox_uri(self, *arg): - """ - Call process download using a dropbox uri - """ - - dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=1" - pdv = ProcessDownload(Mock(), self.uri_provider, Mock()) - pdv.download = Mock() - - assert pdv.download_from_source_uri(dropbox_uri, "fake_local_path") == "fake_local_path" - - def test_download_unknown_uri(self): - """ - Call process download using unknown - """ - - uri = "fake://fake_bucket_name/fake_key/fake_file.h5ad" - pdv = ProcessDownload(Mock(), self.uri_provider, Mock()) - pdv.download_from_s3 = Mock() - with pytest.raises(ValueError, match=f"Malformed source URI: {uri}"): - pdv.download_from_source_uri(uri, "fake_local_path") - - -@pytest.fixture -def mock_ProcessDownload(): - return ProcessDownload(Mock(), Mock(), Mock()) - - -def sample_adata(n_obs: int, n_vars: int): - # Create a sample AnnData object for testing 
- adata = MagicMock(spec=scanpy.AnnData, n_obs=n_obs, n_vars=n_vars) - return adata - - -@pytest.fixture -def mock_read_h5ad(): - with patch("scanpy.read_h5ad") as mock_read_h5ad: - mock_read_h5ad.return_value = sample_adata(1, 2 * GB) - yield mock_read_h5ad - - -def memory_settings( - memory_modifier=1, - memory_per_vcpu=4000, - min_vcpu=1, - max_vcpu=16, - max_swap_memory_mb=300000, - swap_modifier=5, -) -> dict: - return dict( - memory_modifier=memory_modifier, - memory_per_vcpu=memory_per_vcpu, - min_vcpu=min_vcpu, - max_vcpu=max_vcpu, - max_swap_memory_MB=max_swap_memory_mb, - swap_modifier=swap_modifier, - ) - - -# Arrange -@pytest.mark.parametrize( - "adata, memory_settings, expected", - [ - (sample_adata(1, 2 * GB), memory_settings(), {"Vcpus": 1, "Memory": 4000, "MaxSwap": 20000}), # minimum memory - ( - sample_adata(1, 5 * GB), - memory_settings(), - {"Vcpus": 2, "Memory": 8000, "MaxSwap": 40000}, - ), # above minimum memory - ( - sample_adata(1, 5 * GB), - memory_settings(1.5), - {"Vcpus": 2, "Memory": 8000, "MaxSwap": 40000}, - ), # modifier adjusted - ( - sample_adata(1, 64 * GB), - memory_settings(), - {"Vcpus": 16, "Memory": 64000, "MaxSwap": 300000}, - ), # maximum memory - ], -) -def test_estimate_resource_requirements_positive(mock_ProcessDownload, adata, memory_settings, expected): - # Act & Assert - assert expected == mock_ProcessDownload.estimate_resource_requirements(adata, **memory_settings) - - -@pytest.mark.parametrize( - "environ,expected", - [ - ({"REMOTE_DEV_PREFIX": "/stack/", "DEPLOYMENT_STAGE": "test"}, "dp-test-stack-ingest-process-fake_dataset_id"), - ({"REMOTE_DEV_PREFIX": "stack", "DEPLOYMENT_STAGE": "test"}, "dp-test-stack-ingest-process-fake_dataset_id"), - ({"DEPLOYMENT_STAGE": "test"}, "dp-test-ingest-process-fake_dataset_id"), - ], -) -def test_get_job_definion_name(mock_ProcessDownload, environ, expected): - # Arrange - with patch("os.environ", environ): - dataset_id = "fake_dataset_id" - - # Act - result = mock_ProcessDownload.get_job_definion_name(dataset_id) - - # Assert - assert result == expected - - -def test_remove_prefix(mock_ProcessDownload): - # Act & Assert - assert mock_ProcessDownload.remove_prefix("prefixfake", "prefix") == "fake" - - -def test_create_batch_job_definition_parameters(mock_ProcessDownload, mock_read_h5ad): - # Arrange - mock_ProcessDownload.get_job_definion_name = Mock(return_value="fake_job_definition_name") - mock_ProcessDownload.estimate_resource_requirements = Mock(return_value={"Vcpus": 1, "Memory": 4000, "MaxSwap": 0}) - - # Act - resp = mock_ProcessDownload.create_batch_job_definition_parameters("local_file.h5ad", "fake_dataset_id") - - # Assert - mock_ProcessDownload.estimate_resource_requirements.assert_called_once_with(mock_read_h5ad.return_value) - mock_ProcessDownload.get_job_definion_name.assert_called_once_with("fake_dataset_id") - assert resp == { - "JobDefinitionName": "fake_job_definition_name", - "Vcpus": 1, - "Memory": 4000, - "LinuxParameters": { - "Swappiness": 60, - "MaxSwap": 0, - }, - } diff --git a/tests/unit/processing/test_process_validate.py b/tests/unit/processing/test_process_validate.py deleted file mode 100644 index 79bc6fdcaafe6..0000000000000 --- a/tests/unit/processing/test_process_validate.py +++ /dev/null @@ -1,124 +0,0 @@ -from unittest.mock import MagicMock, Mock, patch - -from backend.layers.common.entities import ( - DatasetArtifactType, - DatasetConversionStatus, - DatasetProcessingStatus, - DatasetUploadStatus, - DatasetValidationStatus, - DatasetVersionId, - Link, -) -from 
backend.layers.processing.process import ProcessMain -from backend.layers.processing.process_validate import ProcessValidate -from tests.unit.processing.base_processing_test import BaseProcessingTest - - -class ProcessingTest(BaseProcessingTest): - @patch("scanpy.read_h5ad") - def test_process_download_validate_success(self, mock_read_h5ad): - """ - ProcessDownloadValidate should: - 1. Download the h5ad artifact - 2. set validation status to VALID - 3. Set upload status to UPLOADED - 4. set h5ad status to UPLOADED - 5. upload the original file to S3 - 6. upload the labeled file to S3 - """ - dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=0" - self.crossref_provider.fetch_metadata = Mock(return_value=({}, "12.2345", 17169328.664)) - - collection = self.generate_unpublished_collection( - links=[Link(name=None, type="DOI", uri="http://doi.org/12.2345")] - ) - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, dropbox_uri, None, None - ) - # This is where we're at when we start the SFN - - status = self.business_logic.get_dataset_status(dataset_version_id) - # self.assertEqual(status.validation_status, DatasetValidationStatus.NA) - self.assertIsNone(status.validation_status) - self.assertEqual(status.processing_status, DatasetProcessingStatus.INITIALIZED) - self.assertEqual(status.upload_status, DatasetUploadStatus.WAITING) - - mock_read_h5ad.return_value = MagicMock(uns=dict()) - - # TODO: ideally use a real h5ad so that - with patch("backend.layers.processing.process_validate.ProcessValidate.extract_metadata"): - pdv = ProcessValidate(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) - pdv.process(collection.version_id, dataset_version_id, "fake_bucket_name", "fake_datasets_bucket") - citation_str = ( - f"Publication: https://doi.org/12.2345 " - f"Dataset Version: http://domain/{dataset_version_id}.h5ad curated and distributed by " - f"CZ CELLxGENE Discover in Collection: https://domain/collections/{collection.collection_id.id}" - ) - self.assertEqual(mock_read_h5ad.return_value.uns["citation"], citation_str) - status = self.business_logic.get_dataset_status(dataset_version_id) - self.assertEqual(status.validation_status, DatasetValidationStatus.VALID) - self.assertEqual(status.h5ad_status, DatasetConversionStatus.UPLOADED) - - # Verify that both the original (raw.h5ad) and the labeled (local.h5ad) files are there - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad")) - # Verify that the labeled file is uploaded to the datasets bucket - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.h5ad")) - - artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(1, len(artifacts)) - artifact = artifacts[0] - artifact.type = DatasetArtifactType.H5AD - artifact.uri = f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad" - - @patch("scanpy.read_h5ad") - def test_populate_dataset_citation__no_publication_doi(self, mock_read_h5ad): - mock_read_h5ad.return_value = MagicMock(uns=dict()) - collection = self.generate_unpublished_collection() - - pdv = ProcessValidate(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) - dataset_version_id = DatasetVersionId() - pdv.populate_dataset_citation(collection.version_id, dataset_version_id, "") - citation_str = ( - f"Dataset Version: http://domain/{dataset_version_id}.h5ad curated and distributed by " 
- f"CZ CELLxGENE Discover in Collection: https://domain/collections/{collection.collection_id.id}" - ) - self.assertEqual(mock_read_h5ad.return_value.uns["citation"], citation_str) - - def test_process_validate_fail(self): - """ - If the validation is not successful, the processing pipeline should: - 1. Set the processing status to INVALID - 2. Set a validation message accordingly - """ - dropbox_uri = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" - collection = self.generate_unpublished_collection() - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, dropbox_uri, None, None - ) - - # Set a mock failure for the schema validator - self.schema_validator.validate_and_save_labels = Mock( - return_value=(False, ["Validation error 1", "Validation error 2"], True) - ) - - collection = self.generate_unpublished_collection() - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, dropbox_uri, None, None - ) - - pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) - - for step_name in ["validate"]: - pm.process( - collection.version_id, - dataset_version_id, - step_name, - dropbox_uri, - "fake_bucket_name", - "fake_datasets_bucket", - "fake_cxg_bucket", - ) - - status = self.business_logic.get_dataset_status(dataset_version_id) - self.assertEqual(status.validation_status, DatasetValidationStatus.INVALID) - self.assertEqual(status.validation_message, "Validation error 1\nValidation error 2") diff --git a/tests/unit/processing/test_process_validate_atac.py b/tests/unit/processing/test_process_validate_atac.py new file mode 100644 index 0000000000000..241f99ea4e9ca --- /dev/null +++ b/tests/unit/processing/test_process_validate_atac.py @@ -0,0 +1,466 @@ +from typing import Tuple +from unittest.mock import Mock + +import pytest + +from backend.common.utils.corpora_constants import CorporaConstants +from backend.layers.common.entities import ( + CollectionVersion, + CollectionVersionWithDatasets, + DatasetArtifact, + DatasetArtifactId, + DatasetArtifactType, + DatasetConversionStatus, + DatasetId, + DatasetStatusKey, + DatasetValidationStatus, + DatasetVersionId, +) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.exceptions import ConversionFailed, ValidationAtacFailed +from backend.layers.processing.process_validate_atac import ProcessValidateATAC +from tests.unit.processing.base_processing_test import BaseProcessingTest + +fragment_uri_fmt = "http://domain/{artifact_id}-fragment.tsv.bgz" + + +@pytest.fixture +def setup(): + base_test = BaseProcessingTest() + base_test.setUpClass() + base_test.setUp() + base_test.schema_validator.check_anndata_requires_fragment = Mock(return_value=False) + base_test.schema_validator.validate_atac = Mock(return_value=([], "fragment.tsv.bgz", "fragment.tsv.bgz.tbi")) + return base_test + + +@pytest.fixture +def migration_set(monkeypatch): + monkeypatch.setenv("MIGRATION", "true") + + +@pytest.fixture +def unpublished_collection(setup) -> CollectionVersion: + return setup.generate_unpublished_collection() + + +@pytest.fixture +def unpublished_dataset(unpublished_collection, setup) -> Tuple[DatasetVersionId, DatasetId]: + new_dataset_version = setup.database_provider.create_canonical_dataset(unpublished_collection.version_id) + setup.database_provider.add_dataset_to_collection_version_mapping( + unpublished_collection.version_id, new_dataset_version.version_id + ) + return 
new_dataset_version.version_id, new_dataset_version.dataset_id + + +@pytest.fixture +def process_validate_atac(setup): + proc = ProcessValidateATAC(setup.business_logic, setup.uri_provider, setup.s3_provider, setup.schema_validator) + proc.download_from_source_uri = Mock( + side_effect=[CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME, CorporaConstants.ORIGINAL_ATAC_FRAGMENT_FILENAME] + ) + return proc + + +@pytest.fixture +def collection_revision_with_fragment( + setup, unpublished_collection, process_validate_atac +) -> CollectionVersionWithDatasets: + collection = setup.generate_published_collection(add_datasets=2) + fragment_dataset = collection.datasets[0] + artifact_id = process_validate_atac.create_atac_artifact( + "anything", + DatasetArtifactType.ATAC_FRAGMENT, + fragment_dataset.version_id, + "datasets", + ) + process_validate_atac.create_atac_artifact( + "anything", + DatasetArtifactType.ATAC_INDEX, + fragment_dataset.version_id, + "datasets", + artifact_id, + ) + revision = setup.business_logic.create_collection_version(collection.collection_id) + return setup.business_logic.get_collection_version(revision.version_id) + + +@pytest.fixture +def new_fragment_uri() -> str: + return "https://www.dropbox.com/s/fake_location/test.tsv.bgz?dl=0" + + +@pytest.fixture +def anndata_uri() -> str: + return "s3://fake_bucket_name/fake_key.h5ad" + + +@pytest.fixture +def manifest_with_fragment(anndata_uri, new_fragment_uri) -> IngestionManifest: + return IngestionManifest(anndata=anndata_uri, atac_fragment=new_fragment_uri) + + +@pytest.fixture +def manifest_without_fragment(anndata_uri) -> IngestionManifest: + return IngestionManifest(anndata=anndata_uri) + + +class TestProcessValidateAtac: + """These tests assume that the anndata is atac, and a fragment is provided.""" + + # collection revision + ## dataset revised with optional fragment, fragment is added + ## dataset revised without optional fragment, fragment is removed, still exists on old dataset version. 
+ ## dataset with required fragment and revised anndata, new anndata and same fragment + ## dataset with revised required fragment, new fragment, old fragment still exists + ## dataset with deleted fragment, fragment is removed + + def assert_old_fragment_replaced(self, artifacts, old_artifact_id, old_artifact_index_id, setup): + atac_frag_index_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_INDEX][0] + assert setup.s3_provider.file_exists("datasets", atac_frag_index_artifact.uri.split("/")[-1]) + assert str(atac_frag_index_artifact.id) != str(old_artifact_index_id.id) + + atac_fragment_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + assert setup.s3_provider.file_exists("datasets", atac_fragment_artifact.uri.split("/")[-1]) + assert str(atac_fragment_artifact.id) != str(old_artifact_id.id) + + def assert_new_fragment_added(self, artifacts, setup): + atac_fragment_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + assert setup.s3_provider.file_exists("datasets", atac_fragment_artifact.uri.split("/")[-1]) + assert atac_fragment_artifact.uri == f"s3://datasets/{atac_fragment_artifact.id}-fragment.tsv.bgz" + + atac_frag_index_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_INDEX][0] + assert setup.s3_provider.file_exists("datasets", atac_frag_index_artifact.uri.split("/")[-1]) + assert atac_frag_index_artifact.uri == f"s3://datasets/{atac_fragment_artifact.id}-fragment.tsv.bgz.tbi" + + def assert_artifacts_uploaded(self, setup, dataset_version_id) -> list[DatasetArtifact]: + status = setup.business_logic.get_dataset_status(dataset_version_id) + assert status.atac_status == DatasetConversionStatus.UPLOADED + + artifacts = setup.business_logic.get_dataset_version(dataset_version_id).artifacts + assert len(artifacts) == 2 + + return artifacts + + @pytest.mark.parametrize( + "anndata_uri", + [ + "s3://fake_bucket_name/fake_key.h5ad", # existing anndata + "https://www.dropbox.com/s/fake_location/test.h5ad?dl=0", # new anndata + ], + ) + def test_new_fragment( + self, manifest_with_fragment, unpublished_collection, unpublished_dataset, process_validate_atac, setup + ): + """Validation will succeed, status will be updated, and fragment artifacts will be uploaded + + This covers cases where the collection is unpublished, and the dataset is new. + It also covers cases where the fragment is optional and present or required and present. 
+ """ + # Arrange + dataset_version_id, _ = unpublished_dataset + + # Act + process_validate_atac.process( + unpublished_collection.version_id, + dataset_version_id, + manifest_with_fragment, + "datasets", + ) + + # Assert + artifacts = self.assert_artifacts_uploaded(setup, dataset_version_id) + self.assert_new_fragment_added(artifacts, setup) + + def test_old_fragment(self, anndata_uri, collection_revision_with_fragment, process_validate_atac, setup): + """A published fragment is used in the manifest, this will pass validation, the artifact will be copied to the + new dataset version.""" + # Arrange + process_validate_atac.hash_file = Mock( + return_value="fake_hash" + ) # mock the hash_file method to return the same value + dataset = collection_revision_with_fragment.datasets[0] + new_dataset_version = setup.database_provider.replace_dataset_in_collection_version( + collection_revision_with_fragment.version_id, dataset.version_id + ) + old_fragment_index_artifact_id = setup.get_artifact_type_from_dataset( + dataset, DatasetArtifactType.ATAC_INDEX + ).id + old_fragment_artifact_id = setup.get_artifact_type_from_dataset(dataset, DatasetArtifactType.ATAC_FRAGMENT).id + old_fragment_uri = fragment_uri_fmt.format(artifact_id=old_fragment_artifact_id) + + # Act + process_validate_atac.process( + collection_revision_with_fragment.version_id, + new_dataset_version.version_id, + IngestionManifest(anndata=anndata_uri, atac_fragment=old_fragment_uri), + "datasets", + ) + + # Assert + status = setup.business_logic.get_dataset_status(new_dataset_version.version_id) + assert status.atac_status == DatasetConversionStatus.COPIED + + artifacts = setup.business_logic.get_dataset_version(new_dataset_version.version_id).artifacts + assert len(artifacts) == 2 + + atac_frag_index_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_INDEX][0] + assert setup.s3_provider.file_exists("datasets", atac_frag_index_artifact.uri.split("/")[-1]) + assert str(atac_frag_index_artifact.id) == str(old_fragment_index_artifact_id.id) + + atac_fragment_artifact = [a for a in artifacts if a.type == DatasetArtifactType.ATAC_FRAGMENT][0] + assert setup.s3_provider.file_exists("datasets", atac_fragment_artifact.uri.split("/")[-1]) + assert str(atac_fragment_artifact.id) == str(old_fragment_artifact_id) + + def test_old_fragment_replaced_because_hash_difference( + self, anndata_uri, collection_revision_with_fragment, process_validate_atac, setup, migration_set + ): + """A published fragment is used in the manifest. 
This will pass validation, but the hash of the new file is + different, so a new artifact will be added to the dataset version.""" + # Arrange + process_validate_atac.hash_file = Mock( + side_effect=["abcd", "efgh"] + ) # mock the hash_file method to return different values + dataset = collection_revision_with_fragment.datasets[0] + new_dataset_version = setup.database_provider.replace_dataset_in_collection_version( + collection_revision_with_fragment.version_id, dataset.version_id + ) + old_fragment_index_artifact_id = setup.get_artifact_type_from_dataset( + dataset, DatasetArtifactType.ATAC_INDEX + ).id + old_fragment_artifact_id = setup.get_artifact_type_from_dataset(dataset, DatasetArtifactType.ATAC_FRAGMENT).id + old_fragment_uri = fragment_uri_fmt.format(artifact_id=old_fragment_artifact_id) + + # Act + process_validate_atac.process( + collection_revision_with_fragment.version_id, + new_dataset_version.version_id, + IngestionManifest(anndata=anndata_uri, atac_fragment=old_fragment_uri), + "datasets", + ) + + # Assert + artifacts = self.assert_artifacts_uploaded(setup, new_dataset_version.version_id) + self.assert_old_fragment_replaced(artifacts, old_fragment_artifact_id, old_fragment_index_artifact_id, setup) + + def test_replace_existing_fragment( + self, collection_revision_with_fragment, process_validate_atac, setup, manifest_with_fragment + ): + # Arrange + dataset = collection_revision_with_fragment.datasets[0] + new_dataset_version = setup.database_provider.replace_dataset_in_collection_version( + collection_revision_with_fragment.version_id, dataset.version_id + ) + old_fragment_index_artifact_id = setup.get_artifact_type_from_dataset( + dataset, DatasetArtifactType.ATAC_INDEX + ).id + old_fragment_artifact_id = setup.get_artifact_type_from_dataset(dataset, DatasetArtifactType.ATAC_FRAGMENT).id + + # Act + process_validate_atac.process( + collection_revision_with_fragment.version_id, + new_dataset_version.version_id, + manifest_with_fragment, + "datasets", + ) + + # Assert + artifacts = self.assert_artifacts_uploaded(setup, new_dataset_version.version_id) + self.assert_old_fragment_replaced(artifacts, old_fragment_artifact_id, old_fragment_index_artifact_id, setup) + + def test_existing_dataset_with_fragment_removed( + self, collection_revision_with_fragment, process_validate_atac, setup, manifest_without_fragment + ): + """Updating a dataset to remove the optional fragment.""" + # Arrange + process_validate_atac.schema_validator.check_anndata_requires_fragment = Mock(return_value=False) + dataset_version_id = collection_revision_with_fragment.datasets[0].version_id + new_dataset_version = setup.database_provider.replace_dataset_in_collection_version( + collection_revision_with_fragment.version_id, dataset_version_id + ) + + # Act + process_validate_atac.process( + collection_revision_with_fragment.version_id, + new_dataset_version.version_id, + manifest_without_fragment, + "datasets", + ) + + # Assert + status = setup.business_logic.get_dataset_status(new_dataset_version.version_id) + assert status.atac_status == DatasetConversionStatus.SKIPPED + + artifacts = setup.business_logic.get_dataset_version(new_dataset_version.version_id).artifacts + assert len(artifacts) == 0 + + def test_existing_dataset_with_fragment_added( + self, collection_revision_with_fragment, process_validate_atac, setup, manifest_with_fragment + ): + """Updating a dataset to add the optional fragment.""" + # Arrange + process_validate_atac.schema_validator.check_anndata_requires_fragment = 
Mock(return_value=False) + dataset_version_id = collection_revision_with_fragment.datasets[0].version_id + new_dataset_version = setup.database_provider.replace_dataset_in_collection_version( + collection_revision_with_fragment.version_id, dataset_version_id + ) + + # Act + process_validate_atac.process( + collection_revision_with_fragment.version_id, + new_dataset_version.version_id, + manifest_with_fragment, + "datasets", + ) + + # Assert + artifacts = self.assert_artifacts_uploaded(setup, new_dataset_version.version_id) + self.assert_new_fragment_added(artifacts, setup) + + +class TestSkipATACValidation: + def test_not_atac_and_no_fragment( + self, process_validate_atac, unpublished_dataset, setup, manifest_without_fragment + ): + """The anndata file is not ATAC, and no fragment file in the manifest, so validation should be skipped.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + process_validate_atac.schema_validator.check_anndata_requires_fragment = Mock(side_effect=ValueError("test")) + + # Act + assert process_validate_atac.skip_atac_validation("fake_path", manifest_without_fragment, dataset_version_id) + + # Assert + dataset_status = setup.business_logic.get_dataset_status(dataset_version_id) + assert setup.business_logic.get_dataset_status(dataset_version_id).atac_status == DatasetConversionStatus.NA + assert dataset_status.validation_message == "test" + + def test_not_atac_and_fragment(self, process_validate_atac, unpublished_dataset, setup, manifest_with_fragment): + """A manifest is provided with a fragment, but the anndata is not ATAC, so a fragment is not allowed. This will fail + validation.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + process_validate_atac.schema_validator.check_anndata_requires_fragment = Mock(side_effect=ValueError("test")) + + # Act + with pytest.raises(ValidationAtacFailed) as e: + process_validate_atac.skip_atac_validation("fake_path", manifest_with_fragment, dataset_version_id) + # Assert + assert e.value.errors == ["test", "Fragment file not allowed for non atac anndata."] + dataset_status = setup.business_logic.get_dataset_status(dataset_version_id) + assert dataset_status.validation_status == DatasetValidationStatus.INVALID + + def test_optional_and_no_fragment( + self, process_validate_atac, unpublished_dataset, setup, manifest_without_fragment + ): + """A manifest is provided without a fragment, and the anndata does not require one. This will pass + validation.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + + # Act + assert process_validate_atac.skip_atac_validation("fake_path", manifest_without_fragment, dataset_version_id) + # Assert + dataset_status = setup.business_logic.get_dataset_status(dataset_version_id) + assert ( + setup.business_logic.get_dataset_status(dataset_version_id).atac_status == DatasetConversionStatus.SKIPPED + ) + assert dataset_status.validation_message == "Fragment is optional and not present." + + def test_optional_and_fragment(self, process_validate_atac, unpublished_dataset, setup, manifest_with_fragment): + """A manifest is provided with a fragment, and the anndata has an optional fragment. 
This will pass + validation.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + + # Act + assert not process_validate_atac.skip_atac_validation("fake_path", manifest_with_fragment, dataset_version_id) + + def test_required_and_missing_fragment( + self, process_validate_atac, unpublished_dataset, setup, manifest_without_fragment + ): + """A manifest is provided without a fragment, and the anndata requires one. This will fail validation.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + process_validate_atac.schema_validator.check_anndata_requires_fragment = Mock(return_value=True) + + # Act + with pytest.raises(ValidationAtacFailed) as e: + process_validate_atac.skip_atac_validation("fake_path", manifest_without_fragment, dataset_version_id) + + # Assert + assert e.value.errors == ["Anndata requires fragment file"] + dataset_status = setup.business_logic.get_dataset_status(dataset_version_id) + assert dataset_status.validation_status == DatasetValidationStatus.INVALID + + +class TestHashFile: + def test_hash_file(self, process_validate_atac, tmpdir): + """Test that the hash_file method returns the correct hash.""" + + # Arrange + file_path = tmpdir.join("test.txt") + with open(tmpdir.join("test.txt"), "w") as f: + f.write("test") + + # Act + assert isinstance(process_validate_atac.hash_file(file_path), str) + + +class TestCreateAtacArtifact: + def test_fragment(self, process_validate_atac, unpublished_dataset, setup): + """Test that the create_atac_artifact method creates an artifact for the fragment.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + artifact_id = process_validate_atac.create_atac_artifact( + "anything", + DatasetArtifactType.ATAC_FRAGMENT, + dataset_version_id, + "datasets", + ) + + # Assert + dataset = setup.business_logic.get_dataset_version(dataset_version_id) + artifacts = dataset.artifacts + assert len(artifacts) == 1 + assert str(artifact_id.id) == str(artifacts[0].id) + assert artifacts[0].type == DatasetArtifactType.ATAC_FRAGMENT + assert artifacts[0].uri == f"s3://datasets/{artifacts[0].id}-fragment.tsv.bgz" + + def test_fragment_index(self, process_validate_atac, unpublished_dataset, setup): + """Test that the create_atac_artifact method creates an artifact for the fragment index.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + fragment_artifact_id = DatasetArtifactId("deadbeef-36da-4643-b3d5-ee20853084ba") + artifact_id = process_validate_atac.create_atac_artifact( + "anything", + DatasetArtifactType.ATAC_INDEX, + dataset_version_id, + "datasets", + fragment_artifact_id=fragment_artifact_id, + ) + + # Assert + dataset = setup.business_logic.get_dataset_version(dataset_version_id) + artifacts = dataset.artifacts + assert len(artifacts) == 1 + assert str(artifact_id.id) == str(artifacts[0].id) + assert artifacts[0].type == DatasetArtifactType.ATAC_INDEX + assert artifacts[0].uri == f"s3://datasets/{fragment_artifact_id.id}-fragment.tsv.bgz.tbi" + + def test_exception(self, process_validate_atac, unpublished_dataset, setup): + """Test that the create_atac_artifact method raises an exception when the artifact cannot be created.""" + # Arrange + dataset_version_id, _ = unpublished_dataset + process_validate_atac.business_logic.add_dataset_artifact = Mock(side_effect=ValueError("test")) + + # Act + with pytest.raises(ConversionFailed) as e: + process_validate_atac.create_atac_artifact( + "anything", + DatasetArtifactType.ATAC_FRAGMENT, + dataset_version_id, + "datasets", + ) + + assert e.value.failed_status == 
DatasetStatusKey.ATAC diff --git a/tests/unit/processing/test_process_validate_h5ad.py b/tests/unit/processing/test_process_validate_h5ad.py new file mode 100644 index 0000000000000..4ace38ef0709a --- /dev/null +++ b/tests/unit/processing/test_process_validate_h5ad.py @@ -0,0 +1,129 @@ +from unittest.mock import Mock, patch + +import pytest + +from backend.common.utils.corpora_constants import CorporaConstants +from backend.layers.common.entities import ( + DatasetConversionStatus, + DatasetProcessingStatus, + DatasetUploadStatus, + DatasetValidationStatus, +) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.process import ProcessMain +from backend.layers.processing.process_validate_h5ad import ProcessValidateH5AD +from tests.unit.processing.base_processing_test import BaseProcessingTest + + +class TestProcessDownload(BaseProcessingTest): + @patch("backend.common.utils.dl_sources.uri.S3Provider") + @patch("backend.common.utils.dl_sources.uri.S3URI.file_info", return_value={"size": 100, "name": "fake_name"}) + def test_download_from_s3_uri(self, *arg): + """ + Call process download using an s3 uri + """ + + s3_uri = "s3://fake_bucket_name/fake_key/fake_file.h5ad" + pdv = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pdv.download_from_s3 = Mock() + + assert pdv.download_from_source_uri(s3_uri, "fake_local_path") == "fake_local_path" + + @patch("backend.common.utils.dl_sources.uri.downloader") + @patch("backend.common.utils.dl_sources.uri.DropBoxURL.file_info", return_value={"size": 100, "name": "fake_name"}) + def test_download_from_dropbox_uri(self, *arg): + """ + Call process download using a dropbox uri + """ + + dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=1" + pdv = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pdv.download = Mock() + + assert pdv.download_from_source_uri(dropbox_uri, "fake_local_path") == "fake_local_path" + + def test_download_unknown_uri(self): + """ + Call process download using an unknown uri + """ + + uri = "fake://fake_bucket_name/fake_key/fake_file.h5ad" + pdv = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pdv.download_from_s3 = Mock() + with pytest.raises(ValueError, match=f"Malformed source URI: {uri}"): + pdv.download_from_source_uri(uri, "fake_local_path") + + +class TestProcessValidateH5AD(BaseProcessingTest): + def test_process_download_validate_success(self): + """ + ProcessValidateH5AD should: + 1. Download the h5ad artifact + 2. Set DatasetStatusKey.H5AD DatasetValidationStatus.VALIDATING + 3. Validate the h5ad + 4. Set DatasetStatusKey.H5AD DatasetValidationStatus.VALID + 5. Set the DatasetStatusKey.RDS DatasetConversionStatus.SKIPPED accordingly + 6. 
upload the original file to S3 + + """ + dropbox_uri = "https://www.dropbox.com/s/fake_location/test.h5ad?dl=0" + + collection = self.generate_unpublished_collection() + dataset_version_id, dataset_id = self.business_logic.ingest_dataset( + collection.version_id, dropbox_uri, None, None + ) + # This is where we're at when we start the SFN + + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertIsNone(status.validation_status) + self.assertEqual(status.processing_status, DatasetProcessingStatus.INITIALIZED) + self.assertEqual(status.upload_status, DatasetUploadStatus.WAITING) + + pdv = ProcessValidateH5AD(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pdv.download_from_source_uri = Mock(return_value=CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME) + pdv.process(dataset_version_id, IngestionManifest(anndata=dropbox_uri), "fake_bucket_name") + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertEqual(status.rds_status, DatasetConversionStatus.SKIPPED) + self.assertEqual(status.h5ad_status, DatasetConversionStatus.CONVERTING) + + raw_uri = f"s3://fake_bucket_name/{dataset_version_id.id}/raw.h5ad" + # Verify that the original (raw.h5ad) file is there + self.assertTrue(self.s3_provider.uri_exists(raw_uri)) + # Verify that the artifact uri is in the database + artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) + self.assertEqual(1, len(artifacts)) + self.assertEqual(artifacts[0].uri, raw_uri) + + def test_process_validate_fail(self): + """ + If the validation is not successful, the processing pipeline should: + 1. Set the processing status to INVALID + 2. Set a validation message accordingly + """ + dropbox_uri = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" + manifest = IngestionManifest(anndata=dropbox_uri) + collection = self.generate_unpublished_collection() + dataset_version_id, dataset_id = self.business_logic.ingest_dataset( + collection.version_id, dropbox_uri, None, None + ) + + # Set a mock failure for the schema validator + self.schema_validator.validate_anndata = Mock( + return_value=(False, ["Validation error 1", "Validation error 2"], True) + ) + pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pm.download_from_source_uri = Mock(return_value=CorporaConstants.ORIGINAL_H5AD_ARTIFACT_FILENAME) + for step_name in ["validate_anndata"]: + pm.process( + collection.version_id, + dataset_version_id, + step_name, + manifest, + "fake_bucket_name", + "fake_datasets_bucket", + "fake_cxg_bucket", + ) + + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertEqual(status.validation_status, DatasetValidationStatus.INVALID) + self.assertEqual(status.validation_message, "Validation error 1\nValidation error 2") diff --git a/tests/unit/processing/test_processing.py b/tests/unit/processing/test_processing.py index d733a8680713f..0a6661597dc2b 100644 --- a/tests/unit/processing/test_processing.py +++ b/tests/unit/processing/test_processing.py @@ -1,49 +1,24 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import Mock, patch from backend.layers.common.entities import ( + DatasetArtifactType, DatasetConversionStatus, DatasetProcessingStatus, DatasetUploadStatus, DatasetValidationStatus, ) +from backend.layers.common.ingestion_manifest import IngestionManifest +from backend.layers.processing.exceptions import ValidationAtacFailed from backend.layers.processing.process import 
ProcessMain from backend.layers.processing.process_cxg import ProcessCxg -from backend.layers.processing.process_seurat import ProcessSeurat from tests.unit.processing.base_processing_test import BaseProcessingTest class ProcessingTest(BaseProcessingTest): - @patch("anndata.read_h5ad") - @patch("backend.layers.processing.process_seurat.ProcessSeurat.make_seurat") - def test_process_seurat_success(self, mock_seurat, mock_anndata_read_h5ad): - collection = self.generate_unpublished_collection() - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, "nothing", None, None - ) - - mock_anndata = MagicMock(uns=dict(), n_obs=1000, n_vars=1000) - mock_anndata_read_h5ad.return_value = mock_anndata - - mock_seurat.return_value = "local.rds" - ps = ProcessSeurat(self.business_logic, self.uri_provider, self.s3_provider) - ps.process(dataset_version_id, "fake_bucket_name", "fake_datasets_bucket") - - status = self.business_logic.get_dataset_status(dataset_version_id) - self.assertEqual(status.rds_status, DatasetConversionStatus.UPLOADED) - - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) - - artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(1, len(artifacts)) - artifact = artifacts[0] - artifact.type = "RDS" - artifact.uri = f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds" - def test_process_cxg_success(self): collection = self.generate_unpublished_collection() dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, "nothing", None, None + collection.version_id, "http://fake.url", None, None ) with patch("backend.layers.processing.process_cxg.ProcessCxg.make_cxg") as mock: @@ -65,14 +40,14 @@ def test_process_cxg_success(self): def test_reprocess_cxg_success(self): collection = self.generate_unpublished_collection() dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, "nothing", None, None + collection.version_id, "http://fake.url", None, None ) with patch("backend.layers.processing.process_cxg.ProcessCxg.make_cxg") as mock: mock.return_value = "local.cxg" ps = ProcessCxg(self.business_logic, self.uri_provider, self.s3_provider) self.business_logic.add_dataset_artifact( - dataset_version_id, "h5ad", f"s3://fake_bucket_name/{dataset_id}/local.h5ad" + dataset_version_id, DatasetArtifactType.H5AD, f"s3://fake_bucket_name/{dataset_id}/local.h5ad" ) ps.process(dataset_version_id, "fake_bucket_name", "fake_cxg_bucket") @@ -92,56 +67,87 @@ def test_reprocess_cxg_success(self): cxg_artifact = [artifact for artifact in artifacts if artifact.type == "cxg"][0] self.assertTrue(cxg_artifact, f"s3://fake_cxg_bucket/{dataset_version_id.id}.cxg/") - @patch("backend.layers.processing.process_download.StepFunctionProvider") - @patch("scanpy.read_h5ad") - @patch("anndata.read_h5ad") - @patch("backend.layers.processing.process_validate.ProcessValidate.extract_metadata") - @patch("backend.layers.processing.process_seurat.ProcessSeurat.make_seurat") + @patch("backend.layers.processing.process_add_labels.ProcessAddLabels.populate_dataset_citation") + @patch("backend.layers.processing.process_add_labels.ProcessAddLabels.extract_metadata") @patch("backend.layers.processing.process_cxg.ProcessCxg.make_cxg") - def test_process_all( - self, mock_cxg, mock_seurat, mock_h5ad, 
mock_anndata_read_h5ad, mock_scanpy_read_h5ad, mock_sfn_provider - ): - mock_seurat.return_value = "local.rds" + def test_process_anndata(self, mock_cxg, mock_extract_h5ad, mock_dataset_citation): mock_cxg.return_value = "local.cxg" - # Mock anndata object - mock_anndata = MagicMock(uns=dict(), n_obs=1000, n_vars=1000) - mock_scanpy_read_h5ad.return_value = mock_anndata - mock_anndata_read_h5ad.return_value = mock_anndata - dropbox_uri = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" + manifest = IngestionManifest(anndata=dropbox_uri) collection = self.generate_unpublished_collection() dataset_version_id, dataset_id = self.business_logic.ingest_dataset( collection.version_id, dropbox_uri, None, None ) - + self.schema_validator.check_anndata_requires_fragment.return_value = False pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) - for step_name in ["download", "validate", "cxg", "seurat"]: + pm.download_from_source_uri = lambda x, y: y + + for step_name in ["validate_anndata", "validate_atac", "add_labels", "cxg"]: assert pm.process( collection.version_id, dataset_version_id, step_name, - dropbox_uri, + manifest, "fake_bucket_name", "fake_datasets_bucket", "fake_cxg_bucket", ) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/raw.h5ad")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad")) + dataset = self.business_logic.get_dataset_version(dataset_version_id) + + def assert_s3_matches_db(uri, artifact_type): + self.assertTrue(self.s3_provider.uri_exists(uri)) + artifact = [artifact for artifact in dataset.artifacts if artifact.type == artifact_type] + self.assertTrue(len(artifact) == 1) + self.assertEqual(artifact[0].uri, uri) + + assert_s3_matches_db(f"s3://fake_bucket_name/{dataset_version_id.id}/raw.h5ad", DatasetArtifactType.RAW_H5AD) + assert_s3_matches_db(f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad", DatasetArtifactType.H5AD) self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.h5ad")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_cxg_bucket/{dataset_version_id.id}.cxg/")) + self.assertFalse( + self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}-fragment.tsv.bgz") + ) + self.assertFalse( + self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}-fragment.tsv.bgz.tbi") + ) + self.assertFalse(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) + self.assertFalse(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) status = self.business_logic.get_dataset_status(dataset_version_id) self.assertEqual(status.cxg_status, DatasetConversionStatus.UPLOADED) - self.assertEqual(status.rds_status, DatasetConversionStatus.UPLOADED) + self.assertEqual(status.rds_status, DatasetConversionStatus.SKIPPED) self.assertEqual(status.h5ad_status, DatasetConversionStatus.UPLOADED) + self.assertEqual(status.atac_status, DatasetConversionStatus.SKIPPED) self.assertEqual(status.validation_status, DatasetValidationStatus.VALID) self.assertEqual(status.upload_status, DatasetUploadStatus.UPLOADED) self.assertEqual(status.processing_status, 
DatasetProcessingStatus.PENDING) # TODO: DatasetProcessingStatus.SUCCESS is set by a lambda that also needs to be modified. It should belong here artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(4, len(artifacts)) + self.assertEqual(3, len(artifacts)) + + def test_process_atac_ValidationAtacFailed(self): + dropbox_uri = "https://www.dropbox.com/s/ow84zm4h0wkl409/test.h5ad?dl=0" + manifest = IngestionManifest(anndata=dropbox_uri, atac_fragment=dropbox_uri) + collection = self.generate_unpublished_collection() + dataset_version_id, dataset_id = self.business_logic.ingest_dataset( + collection.version_id, dropbox_uri, None, None + ) + pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) + pm.process_validate_atac_seq.process = Mock(side_effect=ValidationAtacFailed(errors=["failure 1", "failure 2"])) + + assert not pm.process( + collection.version_id, + dataset_version_id, + "validate_atac", + manifest, + "fake_bucket_name", + "fake_datasets_bucket", + "fake_cxg_bucket", + ) + + status = self.business_logic.get_dataset_status(dataset_version_id) + self.assertEqual(status.validation_status, DatasetValidationStatus.INVALID) + self.assertEqual(status.validation_message, "\n".join(["failure 1", "failure 2"]))