diff --git a/changelog.d/20250429_163951_maria_reusable_requests_functionality.md b/changelog.d/20250429_163951_maria_reusable_requests_functionality.md
new file mode 100644
index 000000000000..d136d0f8928d
--- /dev/null
+++ b/changelog.d/20250429_163951_maria_reusable_requests_functionality.md
@@ -0,0 +1,36 @@
+### Removed
+
+- The `POST /api/consensus/merges?rq_id=rq_id` endpoint no longer supports
+  process status checking
+  ()
+- The `GET /api/projects/id/dataset?action=import_status` endpoint no longer
+  supports process status checking
+  ()
+- The `POST /api/projects/backup?rq_id=rq_id` endpoint no longer supports
+  process status checking
+  ()
+- The `POST /api/tasks/backup?rq_id=rq_id` endpoint no longer supports
+  process status checking
+  ()
+- The `PUT /api/tasks/id/annotations?rq_id=rq_id&format=format` endpoint
+  no longer supports process status checking
+  ()
+- The `PUT /api/jobs/id/annotations?rq_id=rq_id&format=format` endpoint
+  no longer supports process status checking
+  ()
+
+### Deprecated
+
+- The `GET /api/events` endpoint is deprecated in favor of the `POST /api/events/export`,
+  `GET /api/requests/rq_id`, and `GET result_url`, where `result_url` is obtained from
+  background request details
+  ()
+- The `POST /api/quality/reports?rq_id=rq_id` endpoint is deprecated in favor of
+  `GET /api/requests/rq_id`
+  ()
+
+### Changed
+- Cache files with exported events are now stored in `/data/cache/export/` instead of
+  `/data/tmp/`. These files are periodically deleted by the
+  `cleanup_export_cache_directory` cron job
+  ()
diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts
index 241d09b2a8a9..bd0dc058776b 100644
--- a/cvat-core/src/server-proxy.ts
+++ b/cvat-core/src/server-proxy.ts
@@ -612,6 +612,57 @@ const defaultRequestConfig = {
     fetchAll: false,
 };
 
+async function getRequestsList(): Promise<PaginatedResource<SerializedRequest>> {
+    const { backendAPI } = config;
+    const params = enableOrganization();
+
+    try {
+        const response = await fetchAll(`${backendAPI}/requests`, params);
+
+        return response.results;
+    } catch (errorData) {
+        throw generateError(errorData);
+    }
+}
+
+// Temporary solution for server availability problems
+const retryTimeouts = [5000, 10000, 15000];
+async function getRequestStatus(rqID: string): Promise<SerializedRequest> {
+    const { backendAPI } = config;
+    let retryCount = 0;
+    let lastError = null;
+
+    while (retryCount < 3) {
+        try {
+            const response = await Axios.get(`${backendAPI}/requests/${rqID}`);
+
+            return response.data;
+        } catch (errorData) {
+            lastError = generateError(errorData);
+            const { response } = errorData;
+            if (response && [502, 503, 504].includes(response.status)) {
+                const timeout = retryTimeouts[retryCount];
+                await new Promise((resolve) => { setTimeout(resolve, timeout); });
+                retryCount++;
+            } else {
+                throw generateError(errorData);
+            }
+        }
+    }
+
+    throw lastError;
+}
+
+async function cancelRequest(requestID): Promise<void> {
+    const { backendAPI } = config;
+
+    try {
+        await Axios.post(`${backendAPI}/requests/${requestID}/cancel`);
+    } catch (errorData) {
+        throw generateError(errorData);
+    }
+}
+
 async function serverRequest(
     url: string, data: object, requestConfig: ServerRequestConfig = defaultRequestConfig,
 ): Promise<any> {
@@ -768,30 +819,19 @@ async function deleteTask(id: number, organizationID: string | null = null): Promise<void> {
     }
 }
 
-async function mergeConsensusJobs(id: number, instanceType: string): Promise<void> {
+async function mergeConsensusJobs(id: number, instanceType: string): Promise<string> {
     const { backendAPI } = config;
     const url = `${backendAPI}/consensus/merges`;
-    const params = {
-        rq_id: null,
-    };
-    const requestBody = {
-        task_id: undefined,
-        job_id: undefined,
-    };
+    const requestBody = (instanceType === 'task') ? { task_id: id } : { job_id: id };
 
-    if (instanceType === 'task') requestBody.task_id = id;
-    else requestBody.job_id = id;
-
-    return new Promise<void>((resolve, reject) => {
+    return new Promise<string>((resolve, reject) => {
         async function request() {
             try {
-                const response = await Axios.post(url, requestBody, { params });
-                params.rq_id = response.data.rq_id;
+                const response = await Axios.post(url, requestBody);
+                const rqID = response.data.rq_id;
                 const { status } = response;
                 if (status === 202) {
-                    setTimeout(request, 3000);
-                } else if (status === 201) {
-                    resolve();
+                    resolve(rqID);
                 } else {
                     reject(generateError(response));
                 }
@@ -2304,57 +2344,6 @@ async function getAnalyticsReports(
     }
 }
 
-async function getRequestsList(): Promise<PaginatedResource<SerializedRequest>> {
-    const { backendAPI } = config;
-    const params = enableOrganization();
-
-    try {
-        const response = await fetchAll(`${backendAPI}/requests`, params);
-
-        return response.results;
-    } catch (errorData) {
-        throw generateError(errorData);
-    }
-}
-
-// Temporary solution for server availability problems
-const retryTimeouts = [5000, 10000, 15000];
-async function getRequestStatus(rqID: string): Promise<SerializedRequest> {
-    const { backendAPI } = config;
-    let retryCount = 0;
-    let lastError = null;
-
-    while (retryCount < 3) {
-        try {
-            const response = await Axios.get(`${backendAPI}/requests/${rqID}`);
-
-            return response.data;
-        } catch (errorData) {
-            lastError = generateError(errorData);
-            const { response } = errorData;
-            if (response && [502, 503, 504].includes(response.status)) {
-                const timeout = retryTimeouts[retryCount];
-                await new Promise((resolve) => { setTimeout(resolve, timeout); });
-                retryCount++;
-            } else {
-                throw generateError(errorData);
-            }
-        }
-    }
-
-    throw lastError;
-}
-
-async function cancelRequest(requestID): Promise<void> {
-    const { backendAPI } = config;
-
-    try {
-        await Axios.post(`${backendAPI}/requests/${requestID}/cancel`);
-    } catch (errorData) {
-        throw generateError(errorData);
-    }
-}
-
 const listenToCreateAnalyticsReportCallbacks: {
     job: LongProcessListener;
     task: LongProcessListener;
diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts
index 8deff9694ad4..cc5a19273d6b 100644
--- a/cvat-core/src/session.ts
+++ b/cvat-core/src/session.ts
@@ -738,7 +738,7 @@ export class Job extends Session {
         return result;
     }
 
-    async mergeConsensusJobs(): Promise<void> {
+    async mergeConsensusJobs(): Promise<string> {
         const result = await PluginRegistry.apiWrapper.call(this, Job.prototype.mergeConsensusJobs);
         return result;
     }
@@ -1204,7 +1204,7 @@ export class Task extends Session {
         return result;
     }
 
-    async mergeConsensusJobs(): Promise<void> {
+    async mergeConsensusJobs(): Promise<string> {
         const result = await PluginRegistry.apiWrapper.call(this, Task.prototype.mergeConsensusJobs);
         return result;
     }
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
index d4928b2c4781..8af3df28a3ec 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
@@ -574,7 +574,7 @@ def annotate_task(
 
     if clear_existing:
         client.tasks.api.update_annotations(
-            task_id, task_annotations_update_request=models.LabeledDataRequest(shapes=shapes)
+            task_id, labeled_data_request=models.LabeledDataRequest(shapes=shapes)
        )
    else:
        client.tasks.api.partial_update_annotations(
diff --git a/cvat-sdk/cvat_sdk/core/proxies/annotations.py
b/cvat-sdk/cvat_sdk/core/proxies/annotations.py index 6ec7434b2c99..8d3c45153a10 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/annotations.py +++ b/cvat-sdk/cvat_sdk/core/proxies/annotations.py @@ -20,17 +20,12 @@ class AnnotationUpdateAction(Enum): class AnnotationCrudMixin(ABC): # TODO: refactor - @property - def _put_annotations_data_param(self) -> str: ... - def get_annotations(self: _EntityT) -> models.ILabeledData: (annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field)) return annotations def set_annotations(self: _EntityT, data: models.ILabeledDataRequest): - self.api.update_annotations( - getattr(self, self._model_id_field), **{self._put_annotations_data_param: data} - ) + self.api.update_annotations(getattr(self, self._model_id_field), labeled_data_request=data) def update_annotations( self: _EntityT, diff --git a/cvat-sdk/cvat_sdk/core/proxies/jobs.py b/cvat-sdk/cvat_sdk/core/proxies/jobs.py index fbf02c168b24..ad21a41de60c 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/jobs.py +++ b/cvat-sdk/cvat_sdk/core/proxies/jobs.py @@ -42,7 +42,6 @@ class Job( ExportDatasetMixin, ): _model_partial_update_arg = "patched_job_write_request" - _put_annotations_data_param = "job_annotations_update_request" def import_annotations( self, diff --git a/cvat-sdk/cvat_sdk/core/proxies/tasks.py b/cvat-sdk/cvat_sdk/core/proxies/tasks.py index 7502612342e6..4714910787a4 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/tasks.py +++ b/cvat-sdk/cvat_sdk/core/proxies/tasks.py @@ -67,7 +67,6 @@ class Task( DownloadBackupMixin, ): _model_partial_update_arg = "patched_task_write_request" - _put_annotations_data_param = "task_annotations_update_request" def upload_data( self, diff --git a/cvat-ui/src/actions/consensus-actions.ts b/cvat-ui/src/actions/consensus-actions.ts index a06de6ef3bdc..c42d5334b725 100644 --- a/cvat-ui/src/actions/consensus-actions.ts +++ b/cvat-ui/src/actions/consensus-actions.ts @@ -3,7 +3,11 @@ // SPDX-License-Identifier: MIT import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; -import { Project, ProjectOrTaskOrJob } from 'cvat-core-wrapper'; +import { Project, ProjectOrTaskOrJob, getCore } from 'cvat-core-wrapper'; + +import { updateRequestProgress } from './requests-actions'; + +const core = getCore(); export enum ConsensusActionTypes { MERGE_CONSENSUS_JOBS = 'MERGE_CONSENSUS_JOBS', @@ -28,7 +32,10 @@ export const mergeConsensusJobsAsync = ( ): ThunkAction => async (dispatch) => { try { dispatch(consensusActions.mergeConsensusJobs(instance)); - await instance.mergeConsensusJobs(); + const rqID = await instance.mergeConsensusJobs(); + await core.requests.listen(rqID, { + callback: (updatedRequest) => updateRequestProgress(updatedRequest, dispatch), + }); } catch (error) { dispatch(consensusActions.mergeConsensusJobsFailed(instance, error)); return; diff --git a/cvat/apps/consensus/merging_manager.py b/cvat/apps/consensus/merging_manager.py index b6fdf0eb75c5..98a304ef4943 100644 --- a/cvat/apps/consensus/merging_manager.py +++ b/cvat/apps/consensus/merging_manager.py @@ -6,32 +6,29 @@ from typing import Type import datumaro as dm -import django_rq from django.conf import settings from django.db import transaction -from django_rq.queues import DjangoRQ as RqQueue -from rq.job import Job as RqJob -from rq.job import JobStatus as RqJobStatus +from rest_framework import serializers from cvat.apps.consensus.intersect_merge import IntersectMerge from cvat.apps.consensus.models import ConsensusSettings +from cvat.apps.consensus.rq import ConsensusRequestId 
from cvat.apps.dataset_manager.bindings import import_dm_annotations from cvat.apps.dataset_manager.task import PatchAction, patch_job_data from cvat.apps.engine.models import ( DimensionType, Job, JobType, + RequestTarget, StageChoice, StateChoice, Task, User, clear_annotations_in_jobs, ) -from cvat.apps.engine.rq import BaseRQMeta, define_dependent_job -from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import get_rq_lock_by_user from cvat.apps.profiler import silk_profile from cvat.apps.quality_control.quality_reports import ComparisonParameters, JobDataProvider +from cvat.apps.redis_handler.background import AbstractRequestManager class _TaskMerger: @@ -159,83 +156,43 @@ class MergingNotAvailable(Exception): pass -class JobAlreadyExists(MergingNotAvailable): - def __init__(self, instance: Task | Job): - super().__init__() - self.instance = instance +class MergingManager(AbstractRequestManager): + QUEUE_NAME = settings.CVAT_QUEUES.CONSENSUS.value + SUPPORTED_TARGETS = {RequestTarget.TASK, RequestTarget.JOB} - def __str__(self): - return f"Merging for this {type(self.instance).__name__.lower()} already enqueued" + @property + def job_result_ttl(self): + return 300 + def build_request_id(self) -> str: + return ConsensusRequestId( + target=self.target, + target_id=self.db_instance.pk, + ).render() -class MergingManager: - _QUEUE_CUSTOM_JOB_PREFIX = "consensus-merge-" - _JOB_RESULT_TTL = 300 + def init_callback_with_params(self): + self.callback = self._merge + self.callback_kwargs = { + "target_type": type(self.db_instance), + "target_id": self.db_instance.pk, + } - def _get_queue(self) -> RqQueue: - return django_rq.get_queue(settings.CVAT_QUEUES.CONSENSUS.value) + def validate_request(self): + super().validate_request() + # FUTURE-FIXME: check that there is no indirectly dependent RQ jobs: + # e.g merge whole task and merge a particular job from the task + task, job = self._split_to_task_and_job() - def _make_job_id(self, task_id: int, job_id: int | None, user_id: int) -> str: - key = f"{self._QUEUE_CUSTOM_JOB_PREFIX}task-{task_id}" - if job_id: - key += f"-job-{job_id}" - key += f"-user-{user_id}" # TODO: remove user id, add support for non owners to get status - return key + try: + _TaskMerger(task=task).check_merging_available(parent_job_id=job.pk if job else None) + except MergingNotAvailable as ex: + raise serializers.ValidationError(str(ex)) from ex - def _check_merging_available(self, task: Task, job: Job | None): - _TaskMerger(task=task).check_merging_available(parent_job_id=job.id if job else None) + def _split_to_task_and_job(self) -> tuple[Task, Job | None]: + if isinstance(self.db_instance, Job): + return self.db_instance.segment.task, self.db_instance - def schedule_merge(self, target: Task | Job, *, request: ExtendedRequest) -> str: - if isinstance(target, Job): - target_task = target.segment.task - target_job = target - else: - target_task = target - target_job = None - - self._check_merging_available(target_task, target_job) - - queue = self._get_queue() - - user_id = request.user.id - with get_rq_lock_by_user(queue, user_id=user_id): - rq_id = self._make_job_id( - task_id=target_task.id, - job_id=target_job.id if target_job else None, - user_id=user_id, - ) - rq_job = queue.fetch_job(rq_id) - if rq_job: - if rq_job.get_status(refresh=False) in ( - RqJobStatus.QUEUED, - RqJobStatus.STARTED, - RqJobStatus.SCHEDULED, - RqJobStatus.DEFERRED, - ): - raise JobAlreadyExists(target) - - rq_job.delete() - - dependency = define_dependent_job( - queue, 
user_id=user_id, rq_id=rq_id, should_be_dependent=True - ) - - queue.enqueue( - self._merge, - target_type=type(target), - target_id=target.id, - job_id=rq_id, - meta=BaseRQMeta.build(request=request, db_obj=target), - result_ttl=self._JOB_RESULT_TTL, - failure_ttl=self._JOB_RESULT_TTL, - depends_on=dependency, - ) - - return rq_id - - def get_job(self, rq_id: str) -> RqJob | None: - queue = self._get_queue() - return queue.fetch_job(rq_id) + return self.db_instance, None @classmethod @silk_profile() diff --git a/cvat/apps/consensus/permissions.py b/cvat/apps/consensus/permissions.py index 3eb24f59cb09..9c160d4034fc 100644 --- a/cvat/apps/consensus/permissions.py +++ b/cvat/apps/consensus/permissions.py @@ -16,43 +16,25 @@ class ConsensusMergePermission(OpenPolicyAgentPermission): - rq_job_owner_id: int | None task_id: int | None class Scopes(StrEnum): CREATE = "create" - VIEW_STATUS = "view:status" @classmethod - def create_scope_check_status( - cls, request: ExtendedRequest, rq_job_owner_id: int, iam_context=None - ): - if not iam_context and request: - iam_context = get_iam_context(request, None) - return cls(**iam_context, scope=cls.Scopes.VIEW_STATUS, rq_job_owner_id=rq_job_owner_id) - - @classmethod - def create(cls, request, view, obj, iam_context): + def create(cls, request: ExtendedRequest, view, obj, iam_context): Scopes = __class__.Scopes permissions = [] if view.basename == "consensus_merges": for scope in cls.get_scopes(request, view, obj): if scope == Scopes.CREATE: - # Note: POST /api/consensus/merges is used to initiate report creation - # and to check the operation status - rq_id = request.query_params.get("rq_id") + # FUTURE-FIXME: use serializers for validation task_id = request.data.get("task_id") job_id = request.data.get("job_id") - if not (task_id or job_id or rq_id): - raise PermissionDenied( - "Either task_id or job_id or rq_id must be specified" - ) - - if rq_id: - # There will be another check for this case during request processing - continue + if not (task_id or job_id): + raise PermissionDenied("Either task_id or job_id must be specified") # merge is always at least at the task level, even for specific jobs if task_id is not None or job_id is not None: @@ -90,9 +72,6 @@ def create(cls, request, view, obj, iam_context): return permissions def __init__(self, **kwargs): - if "rq_job_owner_id" in kwargs: - self.rq_job_owner_id = int(kwargs.pop("rq_job_owner_id")) - super().__init__(**kwargs) self.url = settings.IAM_OPA_DATA_URL + "/consensus_merges/allow" @@ -143,8 +122,6 @@ def get_resource(self): else None ), } - elif self.scope == self.Scopes.VIEW_STATUS: - data = {"owner": {"id": self.rq_job_owner_id}} return data diff --git a/cvat/apps/consensus/rq.py b/cvat/apps/consensus/rq.py new file mode 100644 index 000000000000..8243aa9b61bf --- /dev/null +++ b/cvat/apps/consensus/rq.py @@ -0,0 +1,14 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import ClassVar + +from cvat.apps.redis_handler.rq import RequestId + + +class ConsensusRequestId(RequestId): + ACTION_DEFAULT_VALUE: ClassVar[str] = "merge" + ACTION_ALLOWED_VALUES: ClassVar[tuple[str]] = (ACTION_DEFAULT_VALUE,) + + QUEUE_SELECTORS: ClassVar[tuple[str]] = ACTION_ALLOWED_VALUES diff --git a/cvat/apps/consensus/rules/consensus_merges.rego b/cvat/apps/consensus/rules/consensus_merges.rego index 113ff7885595..ef0618ece2b7 100644 --- a/cvat/apps/consensus/rules/consensus_merges.rego +++ b/cvat/apps/consensus/rules/consensus_merges.rego @@ -7,7 +7,7 @@ import 
data.organizations import data.quality_utils # input: { -# "scope": <"create"|"view"|"view:status"|"list"> or null, +# "scope": <"create"|"view"|"list"> or null, # "auth": { # "user": { # "id": , @@ -57,11 +57,6 @@ allow if { organizations.is_member } -allow if { - input.scope == utils.VIEW_STATUS - utils.is_resource_owner -} - allow if { input.scope in {utils.CREATE, utils.VIEW} utils.is_sandbox diff --git a/cvat/apps/consensus/views.py b/cvat/apps/consensus/views.py index 4a7032215fdb..c1c4c8f00b54 100644 --- a/cvat/apps/consensus/views.py +++ b/cvat/apps/consensus/views.py @@ -11,24 +11,21 @@ extend_schema, extend_schema_view, ) -from rest_framework import mixins, status, viewsets -from rest_framework.exceptions import NotFound, ValidationError -from rest_framework.response import Response -from rq.job import JobStatus as RqJobStatus +from rest_framework import mixins, viewsets +from rest_framework.exceptions import NotFound from cvat.apps.consensus import merging_manager as merging from cvat.apps.consensus.models import ConsensusSettings -from cvat.apps.consensus.permissions import ConsensusMergePermission, ConsensusSettingPermission +from cvat.apps.consensus.permissions import ConsensusSettingPermission from cvat.apps.consensus.serializers import ( ConsensusMergeCreateSerializer, ConsensusSettingsSerializer, ) from cvat.apps.engine.mixins import PartialUpdateModelMixin from cvat.apps.engine.models import Job, Task -from cvat.apps.engine.rq import BaseRQMeta -from cvat.apps.engine.serializers import RqIdSerializer from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import process_failed_job +from cvat.apps.engine.view_utils import get_410_response_when_checking_process_status +from cvat.apps.redis_handler.serializers import RqIdSerializer @extend_schema(tags=["consensus"]) @@ -38,31 +35,15 @@ class ConsensusMergesViewSet(viewsets.GenericViewSet): @extend_schema( operation_id="consensus_create_merge", summary="Create a consensus merge", - parameters=[ - OpenApiParameter( - CREATE_MERGE_RQ_ID_PARAMETER, - type=str, - description=textwrap.dedent( - """\ - The consensus merge request id. Can be specified to check operation status. - """ - ), - ) - ], - request=ConsensusMergeCreateSerializer(required=False), + request=ConsensusMergeCreateSerializer, responses={ - "201": None, "202": OpenApiResponse( RqIdSerializer, description=textwrap.dedent( """\ A consensus merge request has been enqueued, the request id is returned. - The request status can be checked at this endpoint by passing the {} - as the query parameter. If the request id is specified, this response - means the consensus merge request is queued or is being processed. 
- """.format( - CREATE_MERGE_RQ_ID_PARAMETER - ) + The request status can be checked by using common requests API: GET /api/requests/ + """ ), ), "400": OpenApiResponse( @@ -73,72 +54,27 @@ class ConsensusMergesViewSet(viewsets.GenericViewSet): def create(self, request: ExtendedRequest, *args, **kwargs): rq_id = request.query_params.get(self.CREATE_MERGE_RQ_ID_PARAMETER, None) - if rq_id is None: - input_serializer = ConsensusMergeCreateSerializer(data=request.data) - input_serializer.is_valid(raise_exception=True) - - task_id = input_serializer.validated_data.get("task_id", None) - job_id = input_serializer.validated_data.get("job_id", None) - if task_id: - try: - instance = Task.objects.get(pk=task_id) - except Task.DoesNotExist as ex: - raise NotFound(f"Task {task_id} does not exist") from ex - elif job_id: - try: - instance = Job.objects.select_related("segment").get(pk=job_id) - except Job.DoesNotExist as ex: - raise NotFound(f"Jobs {job_id} do not exist") from ex + if rq_id: + return get_410_response_when_checking_process_status("merge") + input_serializer = ConsensusMergeCreateSerializer(data=request.data) + input_serializer.is_valid(raise_exception=True) + + task_id = input_serializer.validated_data.get("task_id", None) + job_id = input_serializer.validated_data.get("job_id", None) + if task_id: + try: + instance = Task.objects.get(pk=task_id) + except Task.DoesNotExist as ex: + raise NotFound(f"Task {task_id} does not exist") from ex + elif job_id: try: - manager = merging.MergingManager() - rq_id = manager.schedule_merge(instance, request=request) - serializer = RqIdSerializer({"rq_id": rq_id}) - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - except merging.MergingNotAvailable as ex: - raise ValidationError(str(ex)) - else: - serializer = RqIdSerializer(data={"rq_id": rq_id}) - serializer.is_valid(raise_exception=True) - rq_id = serializer.validated_data["rq_id"] - - manager = merging.MergingManager() - rq_job = manager.get_job(rq_id) - if ( - not rq_job - or not ConsensusMergePermission.create_scope_check_status( - request, rq_job_owner_id=BaseRQMeta.for_job(rq_job).user.id - ) - .check_access() - .allow - ): - # We should not provide job existence information to unauthorized users - raise NotFound("Unknown request id") - - rq_job_status = rq_job.get_status(refresh=False) - if rq_job_status == RqJobStatus.FAILED: - exc_info = process_failed_job(rq_job) - - exc_name_pattern = f"{merging.MergingNotAvailable.__name__}: " - if (exc_pos := exc_info.find(exc_name_pattern)) != -1: - return Response( - data=exc_info[exc_pos + len(exc_name_pattern) :].strip(), - status=status.HTTP_400_BAD_REQUEST, - ) - - return Response(data=str(exc_info), status=status.HTTP_500_INTERNAL_SERVER_ERROR) - elif rq_job_status in ( - RqJobStatus.QUEUED, - RqJobStatus.STARTED, - RqJobStatus.SCHEDULED, - RqJobStatus.DEFERRED, - ): - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - elif rq_job_status == RqJobStatus.FINISHED: - rq_job.delete() - return Response(status=status.HTTP_201_CREATED) - - raise AssertionError(f"Unexpected rq job '{rq_id}' status '{rq_job_status}'") + instance = Job.objects.select_related("segment").get(pk=job_id) + except Job.DoesNotExist as ex: + raise NotFound(f"Jobs {job_id} do not exist") from ex + + manager = merging.MergingManager(request=request, db_instance=instance) + return manager.enqueue_job() @extend_schema(tags=["consensus"]) diff --git a/cvat/apps/dataset_manager/cron.py b/cvat/apps/dataset_manager/cron.py index 
e2918de985ba..0bb336af93e4 100644 --- a/cvat/apps/dataset_manager/cron.py +++ b/cvat/apps/dataset_manager/cron.py @@ -11,11 +11,11 @@ from pathlib import Path from typing import ClassVar, Type -from django.conf import settings from django.utils import timezone from cvat.apps.dataset_manager.util import ( CacheFileOrDirPathParseError, + ConstructedFileId, ExportCacheManager, TmpDirManager, get_export_cache_lock, @@ -39,7 +39,10 @@ def clear_export_cache(file_path: Path) -> bool: ttl=EXPORT_CACHE_LOCK_TTL, ): parsed_filename = ExportCacheManager.parse_filename(file_path.name) - cache_ttl = get_export_cache_ttl(parsed_filename.instance_type) + if isinstance(parsed_filename.file_id, ConstructedFileId): + cache_ttl = get_export_cache_ttl(parsed_filename.file_id.instance_type) + else: + cache_ttl = get_export_cache_ttl() # use common default cache TTL if timezone.now().timestamp() <= file_path.stat().st_mtime + cache_ttl.total_seconds(): logger.debug(f"Export cache file {file_path.name!r} was recently accessed") @@ -100,7 +103,7 @@ class ExportCacheDirectoryCleaner(BaseCleaner): task_description: ClassVar[str] = "export cache directory cleanup" def do_cleanup(self) -> None: - export_cache_dir_path = settings.EXPORT_CACHE_ROOT + export_cache_dir_path = ExportCacheManager.ROOT assert os.path.exists(export_cache_dir_path) for child in os.scandir(export_cache_dir_path): diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index 9301e3652eaa..0b305bcf990f 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -20,7 +20,8 @@ from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.rq import ImportRQMeta from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer -from cvat.apps.engine.task import _create_thread as create_task +from cvat.apps.engine.task import create_thread as create_task +from cvat.apps.engine.utils import av_scan_paths from .annotation import AnnotationIR from .bindings import CvatDatasetNotFoundError, CvatImportError, ProjectData, load_dataset_data @@ -206,6 +207,8 @@ def import_dataset_as_project(src_file, project_id, format_name, conv_mask_to_po rq_job_meta.progress = 0. 
rq_job_meta.save() + av_scan_paths(src_file) + project = ProjectAnnotationAndData(project_id) importer = make_importer(format_name) diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index e45960010191..376f91c8d14f 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -30,7 +30,7 @@ from cvat.apps.engine.log import DatasetLogManager from cvat.apps.engine.model_utils import add_prefetch_fields, bulk_create, get_cached from cvat.apps.engine.plugins import plugin_decorator -from cvat.apps.engine.utils import take_by +from cvat.apps.engine.utils import av_scan_paths, take_by from cvat.apps.events.handlers import handle_annotations_change from cvat.apps.profiler import silk_profile @@ -1128,6 +1128,7 @@ def export_task( @transaction.atomic def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly): + av_scan_paths(src_file) task = TaskAnnotation(task_id, write_only=True) importer = make_importer(format_name) @@ -1140,6 +1141,7 @@ def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly): @transaction.atomic def import_job_annotations(src_file, job_id, format_name, conv_mask_to_poly): + av_scan_paths(src_file) job = JobAnnotation(job_id, prefetch_images=True) importer = make_importer(format_name) diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index 89450a7bd866..ff39e285acf3 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -44,7 +44,12 @@ from cvat.apps.dataset_manager.util import get_export_cache_lock from cvat.apps.dataset_manager.views import export from cvat.apps.engine.models import Task -from cvat.apps.engine.tests.utils import ExportApiTestBase, ForceLogin, get_paginated_collection +from cvat.apps.engine.tests.utils import ( + ExportApiTestBase, + ForceLogin, + ImportApiTestBase, + get_paginated_collection, +) projects_path = osp.join(osp.dirname(__file__), 'assets', 'projects.json') with open(projects_path) as file: @@ -143,7 +148,7 @@ def compare_datasets(expected: Dataset, actual: Dataset): ) -class _DbTestBase(ExportApiTestBase): +class _DbTestBase(ExportApiTestBase, ImportApiTestBase): @classmethod def setUpTestData(cls): cls.create_db_users() @@ -309,12 +314,6 @@ def _create_annotations_in_job(self, task, job_id, name_ann, key_get_values): response = self._put_api_v2_job_id_annotations(job_id, tmp_annotations) self.assertEqual(response.status_code, status.HTTP_200_OK, msg=response.json()) - def _upload_file(self, url, data, user): - response = self._put_request(url, user, data={"annotation_file": data}, format="multipart") - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - response = self._put_request(url, user) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - def _check_downloaded_file(self, file_name): if not osp.exists(file_name): raise FileNotFoundError(f"File '{file_name}' was not downloaded") @@ -322,15 +321,6 @@ def _check_downloaded_file(self, file_name): def _generate_url_remove_tasks_annotations(self, task_id): return f"/api/tasks/{task_id}/annotations" - def _generate_url_upload_tasks_annotations(self, task_id, upload_format_name): - return f"/api/tasks/{task_id}/annotations?format={upload_format_name}" - - def _generate_url_upload_job_annotations(self, job_id, upload_format_name): - return f"/api/jobs/{job_id}/annotations?format={upload_format_name}" - - def 
_generate_url_upload_project_dataset(self, project_id, format_name): - return f"/api/projects/{project_id}/dataset?format={format_name}" - def _remove_annotations(self, url, user): response = self._delete_request(url, user) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -357,12 +347,9 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): dump_formats = dm.views.get_export_formats() upload_formats = dm.views.get_import_formats() expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED,'file_exists': True, 'annotation_loaded': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True, 'annotation_loaded': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False, 'annotation_loaded': False}, + self.admin: {'name': 'admin', 'file_exists': True, 'annotation_loaded': True}, + self.user: {'name': 'user', 'file_exists': True, 'annotation_loaded': True}, + None: {'name': 'none', 'file_exists': False, 'annotation_loaded': False}, } with TestDir() as test_dir: @@ -438,18 +425,15 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): else: task = self._create_task(tasks["main"], images) task_id = task["id"] - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) + + expected_4xx_status_code = None if user else status.HTTP_401_UNAUTHORIZED with open(file_zip_name, 'rb') as binary_file: - response = self._put_request( - url, - user, - data={"annotation_file": binary_file}, - format="multipart", + self._import_task_annotations( + user, task_id, binary_file, + query_params={"format": upload_format_name}, + expected_4xx_status_code=expected_4xx_status_code ) - self.assertEqual(response.status_code, edata['accept code']) - response = self._put_request(url, user) - self.assertEqual(response.status_code, edata['create code']) def test_api_v2_dump_annotations_with_objects_type_is_track(self): test_name = self._testMethodName @@ -457,12 +441,9 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): dump_formats = dm.views.get_export_formats() upload_formats = dm.views.get_import_formats() expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True, 'annotation_loaded': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True, 'annotation_loaded': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False, 'annotation_loaded': False}, + self.admin: {'name': 'admin', 'file_exists': True, 'annotation_loaded': True}, + self.user: {'name': 'user', 'file_exists': True, 'annotation_loaded': True}, + None: {'name': 'none', 'file_exists': False, 'annotation_loaded': False}, } with TestDir() as test_dir: @@ -535,29 +516,21 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): else: task = self._create_task(tasks["main"], video) task_id = task["id"] - url = self._generate_url_upload_tasks_annotations(task_id, 
upload_format_name) with open(file_zip_name, 'rb') as binary_file: - response = self._put_request( - url, - user, - data={"annotation_file": binary_file}, - format="multipart", + self._import_task_annotations( + user, task_id, binary_file, + query_params={"format": upload_format_name}, + expected_4xx_status_code=None if user else status.HTTP_401_UNAUTHORIZED ) - self.assertEqual(response.status_code, edata['accept code']) - response = self._put_request(url, user) - self.assertEqual(response.status_code, edata['create code']) def test_api_v2_dump_tag_annotations(self): dump_format_name = "CVAT for images 1.1" test_cases = ['all', 'first'] expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False}, + self.admin: {'name': 'admin', 'file_exists': True}, + self.user: {'name': 'user', 'file_exists': True}, + None: {'name': 'none', 'file_exists': False}, } export_params = { "format": dump_format_name, @@ -615,18 +588,24 @@ def test_api_v2_dump_and_upload_annotations_with_objects_are_different_images(se url = self._generate_url_remove_tasks_annotations(task_id) self._remove_annotations(url, self.admin) + if upload_type == "task": - url_upload = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") + with open(file_zip_name, 'rb') as binary_file: + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) else: jobs = self._get_jobs(task_id) - url_upload = self._generate_url_upload_job_annotations(jobs[0]["id"], "CVAT 1.1") - - with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url_upload, binary_file, self.admin) + with open(file_zip_name, 'rb') as binary_file: + self._import_job_annotations( + self.admin, jobs[0]["id"], binary_file, + query_params={"format": "CVAT 1.1"}, + ) - response = self._get_request(f"/api/tasks/{task_id}/annotations", self.admin) - self.assertEqual(len(response.data["shapes"]), 2) - self.assertEqual(len(response.data["tracks"]), 0) + response = self._get_request(f"/api/tasks/{task_id}/annotations", self.admin) + self.assertEqual(len(response.data["shapes"]), 2) + self.assertEqual(len(response.data["tracks"]), 0) def test_api_v2_dump_and_upload_annotations_with_objects_are_different_video(self): test_name = self._testMethodName @@ -657,18 +636,23 @@ def test_api_v2_dump_and_upload_annotations_with_objects_are_different_video(sel url = self._generate_url_remove_tasks_annotations(task_id) self._remove_annotations(url, self.admin) if upload_type == "task": - url_upload = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") + with open(file_zip_name, 'rb') as binary_file: + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) else: jobs = self._get_jobs(task_id) - url_upload = self._generate_url_upload_job_annotations(jobs[0]["id"], "CVAT 1.1") - - with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url_upload, binary_file, self.admin) - self.assertEqual(osp.exists(file_zip_name), True) + with open(file_zip_name, 'rb') as binary_file: + self._import_job_annotations( + 
self.admin, jobs[0]["id"], binary_file, + query_params={"format": "CVAT 1.1"}, + ) - response = self._get_request(f"/api/tasks/{task_id}/annotations", self.admin) - self.assertEqual(len(response.data["shapes"]), 0) - self.assertEqual(len(response.data["tracks"]), 2) + self.assertEqual(osp.exists(file_zip_name), True) + response = self._get_request(f"/api/tasks/{task_id}/annotations", self.admin) + self.assertEqual(len(response.data["shapes"]), 0) + self.assertEqual(len(response.data["tracks"]), 2) def test_api_v2_dump_and_upload_with_objects_type_is_track_and_outside_property(self): test_name = self._testMethodName @@ -687,8 +671,10 @@ def test_api_v2_dump_and_upload_with_objects_type_is_track_and_outside_property( self.assertEqual(osp.exists(file_zip_name), True) with open(file_zip_name, 'rb') as binary_file: - url = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) def test_api_v2_dump_and_upload_with_objects_type_is_track_and_keyframe_property(self): test_name = self._testMethodName @@ -709,8 +695,10 @@ def test_api_v2_dump_and_upload_with_objects_type_is_track_and_keyframe_property self.assertEqual(osp.exists(file_zip_name), True) with open(file_zip_name, 'rb') as binary_file: - url = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) def test_api_v2_dump_upload_annotations_from_several_jobs(self): test_name = self._testMethodName @@ -734,9 +722,11 @@ def test_api_v2_dump_upload_annotations_from_several_jobs(self): # remove annotations url = self._generate_url_remove_tasks_annotations(task_id) self._remove_annotations(url, self.admin) - url = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) def test_api_v2_dump_annotations_from_several_jobs(self): test_name = self._testMethodName @@ -768,21 +758,20 @@ def test_api_v2_dump_annotations_from_several_jobs(self): # remove annotations url = self._generate_url_remove_tasks_annotations(task_id) self._remove_annotations(url, self.admin) - url = self._generate_url_upload_tasks_annotations(task_id, "CVAT 1.1") with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": "CVAT 1.1"}, + ) def test_api_v2_export_dataset(self): test_name = self._testMethodName dump_formats = dm.views.get_export_formats() expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False}, + self.admin: {'name': 'admin', 'file_exists': True}, + self.user: {'name': 'user', 'file_exists': True}, + None: {'name': 'none', 
'file_exists': False}, } with TestDir() as test_dir: @@ -858,19 +847,11 @@ def test_api_v2_dump_empty_frames(self): task = self._create_task(tasks["no attributes"], images) task_id = task["id"] - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) - with open(file_zip_name, 'rb') as binary_file: - response = self._put_request( - url, - self.admin, - data={"annotation_file": binary_file}, - format="multipart", + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": upload_format_name}, ) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - response = self._put_request(url, self.admin) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertIsNone(response.data) def test_api_v2_rewriting_annotations(self): test_name = self._testMethodName @@ -926,10 +907,12 @@ def test_api_v2_rewriting_annotations(self): dump_format_name = "CVAT 1.1" elif dump_format_name == "Ultralytics YOLO Detection Track 1.0": dump_format_name = "Ultralytics YOLO Detection 1.0" - url = self._generate_url_upload_tasks_annotations(task_id, dump_format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": dump_format_name}, + ) task_ann = TaskAnnotation(task_id) task_ann.init_from_db() @@ -967,9 +950,11 @@ def test_api_v2_tasks_annotations_dump_and_upload_many_jobs_with_datumaro(self): self._remove_annotations(url, self.admin) # upload annotations - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": upload_format_name}, + ) # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) @@ -1043,9 +1028,12 @@ def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self): upload_format_name = 'Ultralytics YOLO Detection 1.0' else: upload_format_name = dump_format_name - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) + with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": upload_format_name}, + ) # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) @@ -1108,9 +1096,11 @@ def test_api_v2_check_widerface_with_all_attributes(self): self._remove_annotations(url, self.admin) # upload annotations - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": upload_format_name}, + ) # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) @@ -1144,9 +1134,11 @@ def test_api_v2_check_mot_with_shapes_only(self): self._remove_annotations(url, self.admin) # upload annotations - url = self._generate_url_upload_tasks_annotations(task_id, format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": format_name}, + ) # equals 
annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) @@ -1181,9 +1173,11 @@ def test_api_v2_check_attribute_import_in_tracks(self): self._remove_annotations(url, self.admin) # upload annotations - url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": upload_format_name}, + ) # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) @@ -1227,9 +1221,11 @@ def test_api_v2_check_skeleton_tracks_with_missing_shapes(self): self._remove_annotations(url, self.admin) # upload annotations - url = self._generate_url_upload_tasks_annotations(task_id, format_name) with open(file_zip_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, + query_params={"format": format_name}, + ) class ExportBehaviorTest(_DbTestBase): @@ -1605,10 +1601,10 @@ def patched_osp_exists(path: str): with ( patch( - "cvat.apps.engine.background.get_export_cache_lock", + "cvat.apps.redis_handler.background.get_export_cache_lock", new=self.patched_get_export_cache_lock, ), - patch("cvat.apps.engine.background.osp.exists") as mock_osp_exists, + patch("cvat.apps.redis_handler.background.osp.exists") as mock_osp_exists, TemporaryDirectory() as temp_dir, ): mock_osp_exists.side_effect = patched_osp_exists @@ -2053,12 +2049,9 @@ def test_api_v2_export_import_dataset(self): upload_formats = dm.views.get_import_formats() expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False}, + self.admin: {'name': 'admin', 'file_exists': True}, + self.user: {'name': 'user', 'file_exists': True}, + None: {'name': 'none', 'file_exists': False}, } with TestDir() as test_dir: @@ -2112,35 +2105,33 @@ def test_api_v2_export_import_dataset(self): ]: # TO-DO: fix bug for this formats continue + if upload_format_name == "Ultralytics YOLO Classification 1.0": + # FUTURE-FIXME: + # cvat.apps.dataset_manager.bindings.CvatImportError: + # Could not match item id: \'image_1\' with any task frame + continue for user, edata in list(expected.items()): project = copy.deepcopy(projects['main']) if upload_format_name in tasks: project['labels'] = tasks[upload_format_name]['labels'] project = self._create_project(project) file_zip_name = osp.join(test_dir, f"{test_name}_{edata['name']}_{upload_format_name}.zip") - url = self._generate_url_upload_project_dataset(project['id'], upload_format_name) if osp.exists(file_zip_name): with open(file_zip_name, 'rb') as binary_file: - response = self._post_request( - url, - user, - data={"dataset_file": binary_file}, - format="multipart", + self._import_project_dataset( + self.admin, project['id'], binary_file, + query_params={"format": upload_format_name}, ) - self.assertEqual(response.status_code, edata['accept code']) def test_api_v2_export_annotations(self): test_name = self._testMethodName 
dump_formats = dm.views.get_export_formats() expected = { - self.admin: {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - self.user: {'name': 'user', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED, - 'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True}, - None: {'name': 'none', 'code': status.HTTP_401_UNAUTHORIZED, 'create code': status.HTTP_401_UNAUTHORIZED, - 'accept code': status.HTTP_401_UNAUTHORIZED, 'file_exists': False}, + self.admin: {'name': 'admin', 'file_exists': True}, + self.user: {'name': 'user', 'file_exists': True}, + None: {'name': 'none', 'file_exists': False}, } with TestDir() as test_dir: @@ -2211,16 +2202,12 @@ def test_api_v2_dump_upload_annotations_with_objects_type_is_track(self): # Upload annotations with objects type is track project = self._create_project(project_dict) - url = self._generate_url_upload_project_dataset(project["id"], upload_format_name) with open(file_zip_name, 'rb') as binary_file: - response = self._post_request( - url, - user, - data={"dataset_file": binary_file}, - format="multipart", + self._import_project_dataset( + user, project["id"], binary_file, + query_params={"format": upload_format_name}, ) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) # equals annotations new_task = self._get_tasks(project["id"])[0] diff --git a/cvat/apps/dataset_manager/util.py b/cvat/apps/dataset_manager/util.py index 6ab9ed71ce08..49d55ba885b6 100644 --- a/cvat/apps/dataset_manager/util.py +++ b/cvat/apps/dataset_manager/util.py @@ -15,7 +15,8 @@ from datetime import timedelta from enum import Enum from threading import Lock -from typing import Any +from typing import Any, Protocol +from uuid import UUID import attrs import django_rq @@ -108,16 +109,23 @@ def get_export_cache_lock( class OperationType(str, Enum): EXPORT = "export" + def __str__(self): + return self.value + class ExportFileType(str, Enum): ANNOTATIONS = "annotations" BACKUP = "backup" DATASET = "dataset" + EVENTS = "events" @classmethod def values(cls) -> list[str]: return list(map(lambda x: x.value, cls)) + def __str__(self): + return self.value + class InstanceType(str, Enum): PROJECT = "project" TASK = "task" @@ -127,23 +135,32 @@ class InstanceType(str, Enum): def values(cls) -> list[str]: return list(map(lambda x: x.value, cls)) -@attrs.frozen -class _ParsedExportFilename: - file_type: ExportFileType - file_ext: str + def __str__(self): + return self.value + +class FileId(Protocol): + value: str + +@attrs.frozen(kw_only=True) +class SimpleFileId(FileId): + value: str = attrs.field() + +@attrs.frozen(kw_only=True) +class ConstructedFileId(FileId): instance_type: InstanceType = attrs.field(converter=InstanceType) - instance_id: int + instance_id: int = attrs.field(converter=int) instance_timestamp: float = attrs.field(converter=float) - -@attrs.frozen -class ParsedDatasetFilename(_ParsedExportFilename): - format_repr: str + @property + def value(self): + return "-".join(map(str, [self.instance_type, self.instance_id, self.instance_timestamp])) -@attrs.frozen -class ParsedBackupFilename(_ParsedExportFilename): - pass +@attrs.frozen(kw_only=True) +class ParsedExportFilename: + file_type: ExportFileType = attrs.field(converter=ExportFileType) + file_ext: str + file_id: FileId class TmpDirManager: @@ -191,9 +208,11 @@ def get_tmp_directory_for_export( class ExportCacheManager: + ROOT = settings.EXPORT_CACHE_ROOT + SPLITTER = "-" INSTANCE_PREFIX = 
"instance" - FILE_NAME_TEMPLATE = SPLITTER.join([ + FILE_NAME_TEMPLATE_WITH_INSTANCE = SPLITTER.join([ "{instance_type}", "{instance_id}", "{file_type}", INSTANCE_PREFIX + # store the instance timestamp in the file name to reliably get this information # ctime / mtime do not return file creation time on linux @@ -201,6 +220,15 @@ class ExportCacheManager: "{instance_timestamp}{optional_suffix}.{file_ext}" ]) + FILE_NAME_TEMPLATE_WITHOUT_INSTANCE = SPLITTER.join([ + "{file_type}", "{file_id}.{file_ext}" + ]) + + @classmethod + def file_types_with_general_template(cls): + return (ExportFileType.EVENTS,) + + @classmethod def make_dataset_file_path( cls, @@ -219,7 +247,7 @@ def make_dataset_file_path( file_type = ExportFileType.DATASET if save_images else ExportFileType.ANNOTATIONS normalized_format_name = make_file_name(to_snake_case(format_name)) - filename = cls.FILE_NAME_TEMPLATE.format_map( + filename = cls.FILE_NAME_TEMPLATE_WITH_INSTANCE.format_map( { "instance_type": instance_type, "instance_id": instance_id, @@ -230,7 +258,7 @@ def make_dataset_file_path( } ) - return osp.join(settings.EXPORT_CACHE_ROOT, filename) + return osp.join(cls.ROOT, filename) @classmethod def make_backup_file_path( @@ -241,7 +269,7 @@ def make_backup_file_path( instance_timestamp: float, ) -> str: instance_type = InstanceType(instance_type.lower()) - filename = cls.FILE_NAME_TEMPLATE.format_map( + filename = cls.FILE_NAME_TEMPLATE_WITH_INSTANCE.format_map( { "instance_type": instance_type, "instance_id": instance_id, @@ -251,53 +279,72 @@ def make_backup_file_path( "file_ext": "zip", } ) - return osp.join(settings.EXPORT_CACHE_ROOT, filename) + return osp.join(cls.ROOT, filename) + + @classmethod + def make_file_path( + cls, + *, + file_type: str, + file_id: UUID, + file_ext: str, + ) -> str: + filename = cls.FILE_NAME_TEMPLATE_WITHOUT_INSTANCE.format_map({ + "file_type": ExportFileType(file_type), # convert here to be sure only expected types are used + "file_id": file_id, + "file_ext": file_ext, + }) + return osp.join(cls.ROOT, filename) @classmethod def parse_filename( cls, filename: str, - ) -> ParsedDatasetFilename | ParsedBackupFilename: + ) -> ParsedExportFilename: basename, file_ext = osp.splitext(filename) file_ext = file_ext.strip(".").lower() - basename_match = re.fullmatch( - ( - rf"^(?P{'|'.join(InstanceType.values())})" - rf"{cls.SPLITTER}(?P\d+)" - rf"{cls.SPLITTER}(?P{'|'.join(ExportFileType.values())})" - rf"{cls.SPLITTER}(?P.+)$" - ), - basename, - ) - if not basename_match: - raise CacheFileOrDirPathParseError(f"Couldn't parse file name: {basename!r}") - - fragments = basename_match.groupdict() - fragments["instance_id"] = int(fragments["instance_id"]) - - unparsed = fragments.pop("unparsed")[len(cls.INSTANCE_PREFIX):] - specific_params = {} + try: + for exp_file_type in cls.file_types_with_general_template(): + if basename.startswith(exp_file_type): + file_type, file_id = basename.split(cls.SPLITTER, maxsplit=1) + + return ParsedExportFilename( + file_type=file_type, + file_id=SimpleFileId(value=file_id), + file_ext=file_ext + ) + + basename_match = re.fullmatch( + ( + rf"^(?P{'|'.join(InstanceType.values())})" + rf"{cls.SPLITTER}(?P\d+)" + rf"{cls.SPLITTER}(?P{'|'.join(ExportFileType.values())})" + rf"{cls.SPLITTER}(?P.+)$" + ), + basename, + ) - if fragments["file_type"] in (ExportFileType.DATASET, ExportFileType.ANNOTATIONS): - try: - instance_timestamp, format_repr = unparsed.split(cls.SPLITTER, maxsplit=1) - except ValueError: - raise CacheFileOrDirPathParseError(f"Couldn't parse file 
name: {basename!r}") + if not basename_match: + assert False # will be handled - specific_params["format_repr"] = format_repr - ParsedFileNameClass = ParsedDatasetFilename - else: + fragments = basename_match.groupdict() + unparsed = fragments.pop("unparsed")[len(cls.INSTANCE_PREFIX):] instance_timestamp = unparsed - ParsedFileNameClass = ParsedBackupFilename - try: - parsed_file_name = ParsedFileNameClass( + if fragments["file_type"] in (ExportFileType.DATASET, ExportFileType.ANNOTATIONS): + # The "format" is a part of file id, but there is actually + # no need to use it after filename parsing, so just drop it. + instance_timestamp, _ = unparsed.split(cls.SPLITTER, maxsplit=1) + + parsed_file_name = ParsedExportFilename( + file_type=fragments.pop("file_type"), + file_id=ConstructedFileId( + instance_timestamp=instance_timestamp, + **fragments, + ), file_ext=file_ext, - instance_timestamp=instance_timestamp, - **fragments, - **specific_params, ) - except ValueError as ex: + except Exception as ex: raise CacheFileOrDirPathParseError(f"Couldn't parse file name: {basename!r}") from ex return parsed_file_name diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py index 67ff2143e4e2..0204ce7de425 100644 --- a/cvat/apps/dataset_manager/views.py +++ b/cvat/apps/dataset_manager/views.py @@ -59,7 +59,10 @@ def log_exception(logger: logging.Logger | None = None, exc_info: bool = True): EXPORT_LOCKED_RETRY_INTERVAL = timedelta(seconds=settings.EXPORT_LOCKED_RETRY_INTERVAL) -def get_export_cache_ttl(db_instance: str | Project | Task | Job) -> timedelta: +def get_export_cache_ttl(db_instance: str | Project | Task | Job | None = None) -> timedelta: + if not db_instance: + return DEFAULT_CACHE_TTL + if isinstance(db_instance, (Project, Task, Job)): db_instance = db_instance.__class__.__name__ diff --git a/cvat/apps/engine/background.py b/cvat/apps/engine/background.py index 43daad8f7427..8d9e49057089 100644 --- a/cvat/apps/engine/background.py +++ b/cvat/apps/engine/background.py @@ -2,46 +2,63 @@ # # SPDX-License-Identifier: MIT -import os.path as osp -from abc import ABC, abstractmethod +from abc import abstractmethod +from dataclasses import asdict as dataclass_asdict from dataclasses import dataclass -from datetime import datetime -from typing import Any, ClassVar, Optional, Union -from urllib.parse import quote +from pathlib import Path +from tempfile import NamedTemporaryFile +from uuid import uuid4 -import django_rq +from attrs.converters import to_bool from django.conf import settings -from django.http.response import HttpResponseBadRequest -from django_rq.queues import DjangoRQ, DjangoScheduler -from rest_framework import serializers, status -from rest_framework.response import Response +from django.db.models import Model +from rest_framework import serializers +from rest_framework.exceptions import MethodNotAllowed from rest_framework.reverse import reverse from rq.job import Job as RQJob -from rq.job import JobStatus as RQJobStatus import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.formats.registry import EXPORT_FORMATS -from cvat.apps.dataset_manager.util import get_export_cache_lock -from cvat.apps.dataset_manager.views import get_export_cache_ttl, get_export_callback -from cvat.apps.engine import models -from cvat.apps.engine.backup import ProjectExporter, TaskExporter, create_backup -from cvat.apps.engine.cloud_provider import export_resource_to_cloud_storage -from cvat.apps.engine.location import StorageType, get_location_configuration 
+from cvat.apps.dataset_manager.util import TmpDirManager +from cvat.apps.dataset_manager.views import get_export_callback +from cvat.apps.engine.backup import ( + ProjectExporter, + TaskExporter, + create_backup, + import_project, + import_task, +) +from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage +from cvat.apps.engine.location import LocationConfig, StorageType, get_location_configuration from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.models import Location, RequestAction, RequestSubresource, RequestTarget, Task +from cvat.apps.engine.models import ( + Data, + Job, + Location, + Project, + RequestAction, + RequestSubresource, + RequestTarget, + Task, +) from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export -from cvat.apps.engine.rq import ExportRQMeta, RQId, define_dependent_job -from cvat.apps.engine.serializers import RqIdSerializer +from cvat.apps.engine.rq import ExportRequestId, ImportRequestId +from cvat.apps.engine.serializers import ( + AnnotationFileSerializer, + DatasetFileSerializer, + ProjectFileSerializer, + TaskFileSerializer, +) +from cvat.apps.engine.task import create_thread as create_task from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import ( build_annotations_file_name, build_backup_file_name, - get_rq_lock_by_user, - get_rq_lock_for_job, + import_resource_with_clean_up_after, is_dataset_export, - sendfile, ) -from cvat.apps.events.handlers import handle_dataset_export +from cvat.apps.events.handlers import handle_dataset_export, handle_dataset_import +from cvat.apps.redis_handler.background import AbstractExporter, AbstractRequestManager slogger = ServerLogManager(__name__) @@ -51,179 +68,6 @@ LOCK_ACQUIRE_TIMEOUT = LOCK_TTL - 5 -class ResourceExportManager(ABC): - QUEUE_NAME = settings.CVAT_QUEUES.EXPORT_DATA.value - SUPPORTED_RESOURCES: ClassVar[set[RequestSubresource]] - SUPPORTED_SUBRESOURCES: ClassVar[set[RequestSubresource]] - - def __init__( - self, - db_instance: Union[models.Project, models.Task, models.Job], - request: ExtendedRequest, - ) -> None: - """ - Args: - db_instance (Union[models.Project, models.Task, models.Job]): Model instance - request (ExtendedRequest): Incoming HTTP request - """ - self.db_instance = db_instance - self.request = request - self.resource = db_instance.__class__.__name__.lower() - if self.resource not in self.SUPPORTED_RESOURCES: - raise ValueError("Unexpected type of db_instance: {}".format(type(db_instance))) - - ### Initialization logic ### - - @abstractmethod - def initialize_export_args(self) -> None: ... - - @abstractmethod - def validate_export_args(self) -> Response | None: ... - - @abstractmethod - def build_rq_id(self) -> str: ... 
- - def handle_existing_rq_job( - self, rq_job: Optional[RQJob], queue: DjangoRQ - ) -> Optional[Response]: - if not rq_job: - return None - - rq_job_status = rq_job.get_status(refresh=False) - - if rq_job_status in {RQJobStatus.STARTED, RQJobStatus.QUEUED}: - return Response( - data="Export request is being processed", - status=status.HTTP_409_CONFLICT, - ) - - if rq_job_status == RQJobStatus.DEFERRED: - rq_job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) - - if rq_job_status == RQJobStatus.SCHEDULED: - scheduler: DjangoScheduler = django_rq.get_scheduler(queue.name, queue=queue) - # remove the job id from the set with scheduled keys - scheduler.cancel(rq_job) - rq_job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) - - rq_job.delete() - return None - - @abstractmethod - def get_download_api_endpoint_view_name(self) -> str: ... - - def make_result_url(self, *, rq_id: str) -> str: - view_name = self.get_download_api_endpoint_view_name() - result_url = reverse(view_name, args=[self.db_instance.pk], request=self.request) - - return result_url + f"?rq_id={quote(rq_id)}" - - def get_updated_date_timestamp(self) -> str: - # use only updated_date for the related resource, don't check children objects - # because every child update should touch the updated_date of the parent resource - return datetime.strftime(self.db_instance.updated_date, "%Y_%m_%d_%H_%M_%S") - - @abstractmethod - def get_result_filename(self) -> str: ... - - @abstractmethod - def send_events(self) -> None: ... - - @abstractmethod - def setup_background_job(self, queue: DjangoRQ, rq_id: str) -> None: ... - - def export(self) -> Response: - self.initialize_export_args() - - if invalid_response := self.validate_export_args(): - return invalid_response - - queue: DjangoRQ = django_rq.get_queue(self.QUEUE_NAME) - rq_id = self.build_rq_id() - - # ensure that there is no race condition when processing parallel requests - with get_rq_lock_for_job(queue, rq_id): - rq_job = queue.fetch_job(rq_id) - if response := self.handle_existing_rq_job(rq_job, queue): - return response - self.setup_background_job(queue, rq_id) - - self.send_events() - - serializer = RqIdSerializer({"rq_id": rq_id}) - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - - ### Logic related to prepared file downloading ### - - def validate_rq_id(self, rq_id: str) -> None: - parsed_rq_id = RQId.parse(rq_id) - - if ( - parsed_rq_id.action != RequestAction.EXPORT - or parsed_rq_id.target != RequestTarget(self.resource) - or parsed_rq_id.identifier != self.db_instance.pk - or parsed_rq_id.subresource not in self.SUPPORTED_SUBRESOURCES - ): - raise ValueError("The provided request id does not match exported target or resource") - - def download_file(self) -> Response: - queue: DjangoRQ = django_rq.get_queue(self.QUEUE_NAME) - rq_id = self.request.query_params.get("rq_id") - - if not rq_id: - return HttpResponseBadRequest("Missing request id in the query parameters") - - try: - self.validate_rq_id(rq_id) - except ValueError: - return HttpResponseBadRequest("Invalid export request id") - - # ensure that there is no race condition when processing parallel requests - with get_rq_lock_for_job(queue, rq_id): - rq_job = queue.fetch_job(rq_id) - - if not rq_job: - return HttpResponseBadRequest("Unknown export request id") - - # define status once to avoid refreshing it on each check - # FUTURE-TODO: get_status will raise InvalidJobOperation exception instead of returning None in one of the next releases - rq_job_status 
= rq_job.get_status(refresh=False) - - if rq_job_status != RQJobStatus.FINISHED: - return HttpResponseBadRequest("The export process is not finished") - - rq_job_meta = ExportRQMeta.for_job(rq_job) - file_path = rq_job.return_value() - - if not file_path: - return ( - Response( - "A result for exporting job was not found for finished RQ job", - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if rq_job_meta.result_url # user tries to download a final file locally while the export is made to cloud storage - else HttpResponseBadRequest( - "The export process has no result file to be downloaded locally" - ) - ) - - with get_export_cache_lock( - file_path, ttl=LOCK_TTL, acquire_timeout=LOCK_ACQUIRE_TIMEOUT - ): - if not osp.exists(file_path): - return Response( - "The exported file has expired, please retry exporting", - status=status.HTTP_404_NOT_FOUND, - ) - - return sendfile( - self.request, - file_path, - attachment=True, - attachment_filename=rq_job_meta.result_filename, - ) - - def cancel_and_delete(rq_job: RQJob) -> None: # In the case the server is configured with ONE_RUNNING_JOB_IN_QUEUE_PER_USER # we have to enqueue dependent jobs after canceling one. @@ -231,87 +75,73 @@ def cancel_and_delete(rq_job: RQJob) -> None: rq_job.delete() -class DatasetExportManager(ResourceExportManager): - SUPPORTED_RESOURCES = {RequestTarget.PROJECT, RequestTarget.TASK, RequestTarget.JOB} - SUPPORTED_SUBRESOURCES = {RequestSubresource.DATASET, RequestSubresource.ANNOTATIONS} +class DatasetExporter(AbstractExporter): + SUPPORTED_TARGETS = {RequestTarget.PROJECT, RequestTarget.TASK, RequestTarget.JOB} @dataclass - class ExportArgs: + class ExportArgs(AbstractExporter.ExportArgs): format: str - filename: str save_images: bool - location_config: dict[str, Any] - - @property - def location(self) -> Location: - return self.location_config["location"] - def initialize_export_args(self) -> None: + def init_request_args(self) -> None: + super().init_request_args() save_images = is_dataset_export(self.request) - self.export_callback = get_export_callback(self.db_instance, save_images=save_images) - format_name = self.request.query_params.get("format", "") - filename = self.request.query_params.get("filename", "") - try: - location_config = get_location_configuration( - db_instance=self.db_instance, - query_params=self.request.query_params, - field_name=StorageType.TARGET, - ) - except ValueError as ex: - raise serializers.ValidationError(str(ex)) from ex - - location = location_config["location"] - - if location not in Location.list(): - raise serializers.ValidationError( - f"Unexpected location {location} specified for the request" - ) - - self.export_args = self.ExportArgs( + self.export_args: DatasetExporter.ExportArgs = self.ExportArgs( + **self.export_args.to_dict(), format=format_name, - filename=filename, save_images=save_images, - location_config=location_config, ) - def validate_export_args(self): + def validate_request(self): + super().validate_request() + format_desc = {f.DISPLAY_NAME: f for f in dm.views.get_export_formats()}.get( self.export_args.format ) if format_desc is None: raise serializers.ValidationError("Unknown format specified for the request") elif not format_desc.ENABLED: - return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) - - def build_rq_id(self): - return RQId( - RequestAction.EXPORT, - RequestTarget(self.resource), - self.db_instance.pk, + raise MethodNotAllowed(self.request.method, detail="Format is disabled") + + def build_request_id(self): + return ExportRequestId( + 
action=RequestAction.EXPORT, + target=RequestTarget(self.target), + target_id=self.db_instance.pk, + user_id=self.user_id, subresource=( RequestSubresource.DATASET if self.export_args.save_images else RequestSubresource.ANNOTATIONS ), format=self.export_args.format, - user_id=self.request.user.id, ).render() - def send_events(self): - handle_dataset_export( - self.db_instance, - format_name=self.export_args.format, - cloud_storage_id=self.export_args.location_config.get("storage_id"), - save_images=self.export_args.save_images, + def validate_request_id(self, request_id, /) -> None: + # FUTURE-TODO: optimize, request_id is parsed 2 times (first one when checking permissions) + parsed_request_id: ExportRequestId = ExportRequestId.parse_and_validate_queue( + request_id, expected_queue=self.QUEUE_NAME, try_legacy_format=True ) - def setup_background_job( - self, - queue: DjangoRQ, - rq_id: str, - ) -> None: + if ( + parsed_request_id.action != RequestAction.EXPORT + or parsed_request_id.target != RequestTarget(self.target) + or parsed_request_id.target_id != self.db_instance.pk + or parsed_request_id.subresource + not in {RequestSubresource.DATASET, RequestSubresource.ANNOTATIONS} + ): + raise ValueError( + "The provided request id does not match exported target or subresource" + ) + + def _init_callback_with_params(self): + self.callback = get_export_callback( + self.db_instance, save_images=self.export_args.save_images + ) + self.callback_args = (self.db_instance.pk, self.export_args.format) + try: if self.request.scheme: server_address = self.request.scheme + "://" @@ -319,68 +149,27 @@ def setup_background_job( except Exception: server_address = None - cache_ttl = get_export_cache_ttl(self.db_instance) + self.callback_kwargs = { + "server_url": server_address, + } - user_id = self.request.user.id - - func = self.export_callback - func_args = (self.db_instance.id, self.export_args.format) - result_url = None - - if self.export_args.location == Location.CLOUD_STORAGE: - try: - storage_id = self.export_args.location_config["storage_id"] - except KeyError: - raise serializers.ValidationError( - "Cloud storage location was selected as the destination," - " but cloud storage id was not specified" - ) - - db_storage = get_cloud_storage_for_import_or_export( - storage_id=storage_id, - request=self.request, - is_default=self.export_args.location_config["is_default"], - ) - - func = export_resource_to_cloud_storage - func_args = ( - db_storage, - self.export_callback, - ) + func_args - else: - db_storage = None - result_url = self.make_result_url(rq_id=rq_id) - - with get_rq_lock_by_user(queue, user_id): - result_filename = self.get_result_filename() - meta = ExportRQMeta.build_for( - request=self.request, - db_obj=self.db_instance, - result_url=result_url, - result_filename=result_filename, - ) - queue.enqueue_call( - func=func, - args=func_args, - kwargs={ - "server_url": server_address, - }, - job_id=rq_id, - meta=meta, - depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), - result_ttl=cache_ttl.total_seconds(), - failure_ttl=cache_ttl.total_seconds(), - ) + def finalize_request(self): + handle_dataset_export( + self.db_instance, + format_name=self.export_args.format, + cloud_storage_id=self.export_args.location_config.cloud_storage_id, + save_images=self.export_args.save_images, + ) def get_result_filename(self) -> str: filename = self.export_args.filename if not filename: - instance_timestamp = self.get_updated_date_timestamp() + timestamp = self.get_file_timestamp() filename = 
build_annotations_file_name( - class_name=self.resource, - identifier=self.db_instance.id, - timestamp=instance_timestamp, + class_name=self.target, + identifier=self.db_instance.pk, + timestamp=timestamp, format_name=self.export_args.format, is_annotation_file=not self.export_args.save_images, extension=(EXPORT_FORMATS[self.export_args.format].EXT).lower(), @@ -388,130 +177,368 @@ def get_result_filename(self) -> str: return filename - def get_download_api_endpoint_view_name(self) -> str: - return f"{self.resource}-download-dataset" + def get_result_endpoint_url(self) -> str: + return reverse( + f"{self.target}-download-dataset", args=[self.db_instance.pk], request=self.request + ) -class BackupExportManager(ResourceExportManager): - SUPPORTED_RESOURCES = {RequestTarget.PROJECT, RequestTarget.TASK} - SUPPORTED_SUBRESOURCES = {RequestSubresource.BACKUP} +class BackupExporter(AbstractExporter): + SUPPORTED_TARGETS = {RequestTarget.PROJECT, RequestTarget.TASK} - @dataclass - class ExportArgs: - filename: str - location_config: dict[str, Any] - - @property - def location(self) -> Location: - return self.location_config["location"] - - def initialize_export_args(self) -> None: - self.export_callback = create_backup - filename = self.request.query_params.get("filename", "") - - location_config = get_location_configuration( - db_instance=self.db_instance, - query_params=self.request.query_params, - field_name=StorageType.TARGET, + def validate_request(self): + super().validate_request() + + # do not add this check when a project is backed up, as empty tasks are skipped + if isinstance(self.db_instance, Task) and not self.db_instance.data: + raise serializers.ValidationError("Backup of a task without data is not allowed") + + def validate_request_id(self, request_id, /) -> None: + # FUTURE-TODO: optimize, request_id is parsed 2 times (first one when checking permissions) + parsed_request_id: ExportRequestId = ExportRequestId.parse_and_validate_queue( + request_id, expected_queue=self.QUEUE_NAME, try_legacy_format=True ) - self.export_args = self.ExportArgs(filename, location_config) - def validate_export_args(self): - return + if ( + parsed_request_id.action != RequestAction.EXPORT + or parsed_request_id.target != RequestTarget(self.target) + or parsed_request_id.target_id != self.db_instance.pk + or parsed_request_id.subresource != RequestSubresource.BACKUP + ): + raise ValueError( + "The provided request id does not match exported target or subresource" + ) - def get_result_filename(self) -> str: + def _init_callback_with_params(self): + self.callback = create_backup + + if isinstance(self.db_instance, Task): + logger = slogger.task[self.db_instance.pk] + Exporter = TaskExporter + else: + logger = slogger.project[self.db_instance.pk] + Exporter = ProjectExporter + + self.callback_args = ( + self.db_instance.pk, + Exporter, + logger, + self.job_result_ttl, + ) + + def get_result_filename(self): filename = self.export_args.filename if not filename: - instance_timestamp = self.get_updated_date_timestamp() + instance_timestamp = self.get_file_timestamp() filename = build_backup_file_name( - class_name=self.resource, + class_name=self.target, identifier=self.db_instance.name, timestamp=instance_timestamp, ) return filename - def build_rq_id(self): - return RQId( - RequestAction.EXPORT, - RequestTarget(self.resource), - self.db_instance.pk, + def build_request_id(self): + return ExportRequestId( + action=RequestAction.EXPORT, + target=RequestTarget(self.target), + target_id=self.db_instance.pk, + 
user_id=self.user_id, subresource=RequestSubresource.BACKUP, - user_id=self.request.user.id, ).render() - # FUTURE-TODO: move into ResourceExportManager - def setup_background_job( + def get_result_endpoint_url(self) -> str: + return reverse( + f"{self.target}-download-backup", args=[self.db_instance.pk], request=self.request + ) + + def finalize_request(self): + # FUTURE-TODO: send events to event store + pass + + +class ResourceImporter(AbstractRequestManager): + QUEUE_NAME = settings.CVAT_QUEUES.IMPORT_DATA.value + + @dataclass + class ImportArgs: + location_config: LocationConfig + file_path: str | None + + def to_dict(self): + return dataclass_asdict(self) + + import_args: ImportArgs | None + + def __init__(self, *, request: ExtendedRequest, db_instance: Model | None, tmp_dir: Path): + super().__init__(request=request, db_instance=db_instance) + self.tmp_dir = tmp_dir + + @property + def job_result_ttl(self): + return int(settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds()) + + @property + def job_failed_ttl(self): + return int(settings.IMPORT_CACHE_FAILED_TTL.total_seconds()) + + def init_request_args(self): + file_path: str | None = None + + try: + location_config = get_location_configuration( + db_instance=self.db_instance, + query_params=self.request.query_params, + field_name=StorageType.SOURCE, + ) + except ValueError as ex: + raise serializers.ValidationError(str(ex)) from ex + + if filename := self.request.query_params.get("filename"): + file_path = ( + str(self.tmp_dir / filename) + if location_config.location != Location.CLOUD_STORAGE + else filename + ) + + self.import_args = ResourceImporter.ImportArgs( + location_config=location_config, + file_path=file_path, + ) + + def validate_request(self): + super().validate_request() + + if ( + self.import_args.location_config.location == Location.CLOUD_STORAGE + and not self.import_args.file_path + ): + raise serializers.ValidationError("The filename was not specified") + + def _handle_cloud_storage_file_upload(self): + storage_id = self.import_args.location_config.cloud_storage_id + db_storage = get_cloud_storage_for_import_or_export( + storage_id=storage_id, + request=self.request, + is_default=self.import_args.location_config.is_default, + ) + + key = self.import_args.file_path + with NamedTemporaryFile(prefix="cvat_", dir=TmpDirManager.TMP_ROOT, delete=False) as tf: + self.import_args.file_path = tf.name + return db_storage, key + + @abstractmethod + def _get_payload_file(self): ... + + def _handle_non_tus_file_upload(self): + payload_file = self._get_payload_file() + + with NamedTemporaryFile(prefix="cvat_", dir=TmpDirManager.TMP_ROOT, delete=False) as tf: + self.import_args.file_path = tf.name + for chunk in payload_file.chunks(): + tf.write(chunk) + + @abstractmethod + def _init_callback_with_params(self): ... 
+ + def init_callback_with_params(self): + # Note: self.import_args is changed here + if self.import_args.location_config.location == Location.CLOUD_STORAGE: + db_storage, key = self._handle_cloud_storage_file_upload() + elif not self.import_args.file_path: + self._handle_non_tus_file_upload() + + self._init_callback_with_params() + + # redefine here callback and callback args in order to: + # - (optional) download file from cloud storage + # - remove uploaded file at the end + if self.import_args.location_config.location == Location.CLOUD_STORAGE: + self.callback_args = ( + self.callback_args[0], + db_storage, + key, + self.callback, + *self.callback_args[1:], + ) + self.callback = import_resource_from_cloud_storage + + self.callback_args = (self.callback, *self.callback_args) + self.callback = import_resource_with_clean_up_after + + +class DatasetImporter(ResourceImporter): + SUPPORTED_TARGETS = {RequestTarget.PROJECT, RequestTarget.TASK, RequestTarget.JOB} + + @dataclass + class ImportArgs(ResourceImporter.ImportArgs): + format: str + conv_mask_to_poly: bool + + def __init__( self, - queue: DjangoRQ, - rq_id: str, - ) -> None: - cache_ttl = get_export_cache_ttl(self.db_instance) - user_id = self.request.user.id + *, + request: ExtendedRequest, + db_instance: Project | Task | Job, + ): + super().__init__( + request=request, db_instance=db_instance, tmp_dir=Path(db_instance.get_tmp_dirname()) + ) - if isinstance(self.db_instance, Task): - logger = slogger.task[self.db_instance.pk] - Exporter = TaskExporter + def init_request_args(self) -> None: + super().init_request_args() + format_name = self.request.query_params.get("format", "") + conv_mask_to_poly = to_bool(self.request.query_params.get("conv_mask_to_poly", True)) + + self.import_args: DatasetImporter.ImportArgs = self.ImportArgs( + **self.import_args.to_dict(), + format=format_name, + conv_mask_to_poly=conv_mask_to_poly, + ) + + def _get_payload_file(self): + # Common serializer is not used to not break API + if isinstance(self.db_instance, Project): + serializer_class = DatasetFileSerializer + file_field = "dataset_file" else: - logger = slogger.project[self.db_instance.pk] - Exporter = ProjectExporter + serializer_class = AnnotationFileSerializer + file_field = "annotation_file" + + file_serializer = serializer_class(data=self.request.data) + file_serializer.is_valid(raise_exception=True) + return file_serializer.validated_data[file_field] + + def _init_callback_with_params(self): + if isinstance(self.db_instance, Project): + self.callback = dm.project.import_dataset_as_project + elif isinstance(self.db_instance, Task): + self.callback = dm.task.import_task_annotations + else: + assert isinstance(self.db_instance, Job) + self.callback = dm.task.import_job_annotations - func = self.export_callback - func_args = ( - self.db_instance.id, - Exporter, - logger, - cache_ttl, + self.callback_args = ( + self.import_args.file_path, + self.db_instance.pk, + self.import_args.format, + self.import_args.conv_mask_to_poly, ) - result_url = None - - if self.export_args.location == Location.CLOUD_STORAGE: - try: - storage_id = self.export_args.location_config["storage_id"] - except KeyError: - raise serializers.ValidationError( - "Cloud storage location was selected as the destination," - " but cloud storage id was not specified" - ) - - db_storage = get_cloud_storage_for_import_or_export( - storage_id=storage_id, - request=self.request, - is_default=self.export_args.location_config["is_default"], - ) - func = export_resource_to_cloud_storage - 
func_args = ( - db_storage, - self.export_callback, - ) + func_args + def validate_request(self): + super().validate_request() + + format_desc = {f.DISPLAY_NAME: f for f in dm.views.get_import_formats()}.get( + self.import_args.format + ) + if format_desc is None: + raise serializers.ValidationError(f"Unknown input format {self.import_args.format!r}") + elif not format_desc.ENABLED: + raise MethodNotAllowed(self.request.method, detail="Format is disabled") + + def build_request_id(self): + return ImportRequestId( + action=RequestAction.IMPORT, + target=RequestTarget(self.target), + target_id=self.db_instance.pk, + subresource=( + RequestSubresource.DATASET + if isinstance(self.db_instance, Project) + else RequestSubresource.ANNOTATIONS + ), + ).render() + + def finalize_request(self): + handle_dataset_import( + self.db_instance, + format_name=self.import_args.format, + cloud_storage_id=self.import_args.location_config.cloud_storage_id, + ) + + +class BackupImporter(ResourceImporter): + SUPPORTED_TARGETS = {RequestTarget.PROJECT, RequestTarget.TASK} + + @dataclass + class ImportArgs(ResourceImporter.ImportArgs): + org_id: int | None + + def __init__( + self, + *, + request: ExtendedRequest, + target: RequestTarget, + ): + super().__init__(request=request, db_instance=None, tmp_dir=Path(TmpDirManager.TMP_ROOT)) + assert target in self.SUPPORTED_TARGETS, f"Unsupported target: {target}" + self.target = target + + def init_request_args(self) -> None: + super().init_request_args() + + self.import_args: BackupImporter.ImportArgs = self.ImportArgs( + **self.import_args.to_dict(), + org_id=getattr(self.request.iam_context["organization"], "id", None), + ) + + def build_request_id(self): + return ImportRequestId( + action=RequestAction.IMPORT, + target=self.target, + id=uuid4(), + subresource=RequestSubresource.BACKUP, + ).render() + + def _get_payload_file(self): + # Common serializer is not used to not break API + if self.target == RequestTarget.PROJECT: + serializer_class = ProjectFileSerializer + file_field = "project_file" else: - result_url = self.make_result_url(rq_id=rq_id) - - with get_rq_lock_by_user(queue, user_id): - result_filename = self.get_result_filename() - meta = ExportRQMeta.build_for( - request=self.request, - db_obj=self.db_instance, - result_url=result_url, - result_filename=result_filename, - ) + serializer_class = TaskFileSerializer + file_field = "task_file" - queue.enqueue_call( - func=func, - args=func_args, - job_id=rq_id, - meta=meta, - depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), - result_ttl=cache_ttl.total_seconds(), - failure_ttl=cache_ttl.total_seconds(), - ) + file_serializer = serializer_class(data=self.request.data) + file_serializer.is_valid(raise_exception=True) + return file_serializer.validated_data[file_field] - def get_download_api_endpoint_view_name(self) -> str: - return f"{self.resource}-download-backup" + def _init_callback_with_params(self): + self.callback = import_project if self.target == RequestTarget.PROJECT else import_task + self.callback_args = (self.import_args.file_path, self.user_id, self.import_args.org_id) - def send_events(self): - # FUTURE-TODO: send events to event store + def finalize_request(self): + # FUTURE-TODO: send logs to event store pass + + +class TaskCreator(AbstractRequestManager): + QUEUE_NAME = settings.CVAT_QUEUES.IMPORT_DATA.value + SUPPORTED_TARGETS = {RequestTarget.TASK} + + def __init__( + self, + *, + request: ExtendedRequest, + db_instance: Task, + db_data: Data, + ): + 
super().__init__(request=request, db_instance=db_instance) + self.db_data = db_data + + @property + def job_failure_ttl(self): + return int(settings.IMPORT_CACHE_FAILED_TTL.total_seconds()) + + def build_request_id(self): + return ImportRequestId( + action=RequestAction.CREATE, + target=RequestTarget.TASK, + target_id=self.db_instance.pk, + ).render() + + def init_callback_with_params(self): + self.callback = create_task + self.callback_args = (self.db_instance.pk, self.db_data) diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 3218b9e5264a..1746c18ae5f7 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -9,31 +9,25 @@ import re import shutil import tempfile -import uuid from abc import ABCMeta, abstractmethod from collections.abc import Collection, Iterable from copy import deepcopy from datetime import timedelta from enum import Enum from logging import Logger -from tempfile import NamedTemporaryFile from typing import Any, ClassVar, Optional, Type, Union from zipfile import ZipFile -import django_rq import rapidjson from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.utils import timezone -from rest_framework import serializers, status from rest_framework.exceptions import ValidationError from rest_framework.parsers import JSONParser from rest_framework.renderers import JSONRenderer -from rest_framework.response import Response import cvat.apps.dataset_manager as dm -from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.util import ( ExportCacheManager, TmpDirManager, @@ -49,23 +43,9 @@ retry_current_rq_job, ) from cvat.apps.engine import models -from cvat.apps.engine.cloud_provider import ( - db_storage_to_storage_instance, - import_resource_from_cloud_storage, -) -from cvat.apps.engine.location import StorageType, get_location_configuration +from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.models import ( - DataChoice, - Location, - RequestAction, - RequestSubresource, - RequestTarget, - StorageChoice, - StorageMethodChoice, -) -from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export -from cvat.apps.engine.rq import ImportRQMeta, RQId, define_dependent_job +from cvat.apps.engine.models import DataChoice, StorageChoice, StorageMethodChoice from cvat.apps.engine.serializers import ( AnnotationGuideWriteSerializer, AssetWriteSerializer, @@ -74,23 +54,15 @@ JobWriteSerializer, LabeledDataSerializer, LabelSerializer, - ProjectFileSerializer, ProjectReadSerializer, - RqIdSerializer, SegmentSerializer, SimpleJobSerializer, - TaskFileSerializer, TaskReadSerializer, ValidationParamsSerializer, ) -from cvat.apps.engine.task import JobFileMapping, _create_thread -from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import ( - av_scan_paths, - get_rq_lock_by_user, - import_resource_with_clean_up_after, - process_failed_job, -) +from cvat.apps.engine.task import JobFileMapping +from cvat.apps.engine.task import create_thread as create_task +from cvat.apps.engine.utils import av_scan_paths slogger = ServerLogManager(__name__) @@ -914,7 +886,7 @@ def _import_task(self): if validation_params: data['validation_params'] = validation_params - _create_thread(self._db_task.pk, data.copy(), is_backup_restore=True) + create_task(self._db_task.pk, data.copy(), is_backup_restore=True) 
self._db_task.refresh_from_db() db_data.refresh_from_db() @@ -971,7 +943,7 @@ def import_task(self): return self._db_task @transaction.atomic -def _import_task(filename, user, org_id): +def import_task(filename, user, org_id): av_scan_paths(filename) task_importer = TaskImporter(filename, user, org_id) db_task = task_importer.import_task() @@ -1107,7 +1079,7 @@ def import_project(self): return self._db_project @transaction.atomic -def _import_project(filename, user, org_id): +def import_project(filename, user, org_id): av_scan_paths(filename) project_importer = ProjectImporter(filename, user, org_id) db_project = project_importer.import_project() @@ -1175,158 +1147,5 @@ def create_backup( log_exception(logger) raise - -def _import( - importer: TaskImporter | ProjectImporter, - request: ExtendedRequest, - queue: django_rq.queues.DjangoRQ, - rq_id: str, - Serializer: type[TaskFileSerializer] | type[ProjectFileSerializer], - file_field_name: str, - location_conf: dict, - filename: str | None = None, -): - rq_job = queue.fetch_job(rq_id) - - if not rq_job: - org_id = getattr(request.iam_context['organization'], 'id', None) - location = location_conf.get('location') - - if location == Location.LOCAL: - if not filename: - serializer = Serializer(data=request.data) - serializer.is_valid(raise_exception=True) - payload_file = serializer.validated_data[file_field_name] - with NamedTemporaryFile( - prefix='cvat_', - dir=settings.TMP_FILES_ROOT, - delete=False) as tf: - filename = tf.name - for chunk in payload_file.chunks(): - tf.write(chunk) - else: - file_name = request.query_params.get('filename') - assert file_name, "The filename wasn't specified" - try: - storage_id = location_conf['storage_id'] - except KeyError: - raise serializers.ValidationError( - 'Cloud storage location was selected as the source,' - ' but cloud storage id was not specified') - - db_storage = get_cloud_storage_for_import_or_export( - storage_id=storage_id, request=request, - is_default=location_conf['is_default']) - - key = filename - with NamedTemporaryFile(prefix='cvat_', dir=settings.TMP_FILES_ROOT, delete=False) as tf: - filename = tf.name - - func = import_resource_with_clean_up_after - func_args = (importer, filename, request.user.id, org_id) - - if location == Location.CLOUD_STORAGE: - func_args = (db_storage, key, func) + func_args - func = import_resource_from_cloud_storage - - user_id = request.user.id - - with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build_for( - request=request, - db_obj=None, - tmp_file=filename, - ) - rq_job = queue.enqueue_call( - func=func, - args=func_args, - job_id=rq_id, - meta=meta, - depends_on=define_dependent_job(queue, user_id), - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) - else: - rq_job_meta = ImportRQMeta.for_job(rq_job) - if rq_job_meta.user.id != request.user.id: - return Response(status=status.HTTP_403_FORBIDDEN) - - if rq_job.is_finished: - project_id = rq_job.return_value() - rq_job.delete() - return Response({'id': project_id}, status=status.HTTP_201_CREATED) - elif rq_job.is_failed: - exc_info = process_failed_job(rq_job) - # RQ adds a prefix with exception class name - import_error_prefix = '{}.{}'.format( - CvatImportError.__module__, CvatImportError.__name__) - if exc_info.startswith(import_error_prefix): - exc_info = exc_info.replace(import_error_prefix + ': ', '') - return Response(data=exc_info, - status=status.HTTP_400_BAD_REQUEST) - else: - return 
Response(data=exc_info, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - serializer = RqIdSerializer(data={'rq_id': rq_id}) - serializer.is_valid(raise_exception=True) - - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - def get_backup_dirname(): return TmpDirManager.TMP_ROOT - -def import_project(request: ExtendedRequest, queue_name: str, filename: str | None = None): - if 'rq_id' in request.data: - rq_id = request.data['rq_id'] - else: - rq_id = RQId( - RequestAction.IMPORT, RequestTarget.PROJECT, uuid.uuid4(), - subresource=RequestSubresource.BACKUP, - ).render() - Serializer = ProjectFileSerializer - file_field_name = 'project_file' - - location_conf = get_location_configuration( - query_params=request.query_params, - field_name=StorageType.SOURCE, - ) - - queue = django_rq.get_queue(queue_name) - - return _import( - importer=_import_project, - request=request, - queue=queue, - rq_id=rq_id, - Serializer=Serializer, - file_field_name=file_field_name, - location_conf=location_conf, - filename=filename - ) - -def import_task(request: ExtendedRequest, queue_name: str, filename: str | None = None): - rq_id = request.data.get('rq_id', RQId( - RequestAction.IMPORT, RequestTarget.TASK, uuid.uuid4(), - subresource=RequestSubresource.BACKUP, - ).render()) - Serializer = TaskFileSerializer - file_field_name = 'task_file' - - location_conf = get_location_configuration( - query_params=request.query_params, - field_name=StorageType.SOURCE - ) - - queue = django_rq.get_queue(queue_name) - - return _import( - importer=_import_task, - request=request, - queue=queue, - rq_id=rq_id, - Serializer=Serializer, - file_field_name=file_field_name, - location_conf=location_conf, - filename=filename - ) diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index 56fb43e573da..e55ff3b8786c 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -1005,18 +1005,17 @@ def db_storage_to_storage_instance(db_storage): T = TypeVar('T', Callable[[str, int, int], int], Callable[[str, int, str, bool], None]) def import_resource_from_cloud_storage( + filename: str, db_storage: Any, key: str, - cleanup_func: Callable[[T, str,], Any], import_func: T, - filename: str, *args, **kwargs, ) -> Any: storage = db_storage_to_storage_instance(db_storage) storage.download_file(key, filename) - return cleanup_func(import_func, filename, *args, **kwargs) + return import_func(filename, *args, **kwargs) def export_resource_to_cloud_storage( db_storage: Any, diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py index 35d903369186..ffc7ae39ea5f 100644 --- a/cvat/apps/engine/location.py +++ b/cvat/apps/engine/location.py @@ -5,7 +5,9 @@ from enum import Enum from typing import Any, Optional, Union -from cvat.apps.engine.models import Job, Location, Project, Task +import attrs + +from cvat.apps.engine.models import Job, Location, Project, Storage, Task class StorageType(str, Enum): @@ -16,47 +18,51 @@ def __str__(self): return self.value +@attrs.frozen(kw_only=True) +class LocationConfig: + is_default: bool = attrs.field(validator=attrs.validators.instance_of(bool)) + location: Location = attrs.field(converter=Location) + cloud_storage_id: int | None = attrs.field( + converter=lambda x: x if x is None else int(x), default=None + ) + + def __attrs_post_init__(self): + if self.location == Location.CLOUD_STORAGE and not self.cloud_storage_id: + raise ValueError( + "Trying to use undefined cloud storage (cloud_storage_id was not provided)" + 
) + + def get_location_configuration( query_params: dict[str, Any], field_name: str, *, db_instance: Optional[Union[Project, Task, Job]] = None, -) -> dict[str, Any]: +) -> LocationConfig: location = query_params.get("location") - # handle resource import + # handle backup imports if not location and not db_instance: location = Location.LOCAL use_default_settings = location is None - location_conf = {"is_default": use_default_settings} - if use_default_settings: - storage = ( + storage: Storage = ( getattr(db_instance, field_name) if not isinstance(db_instance, Job) else getattr(db_instance.segment.task, field_name) ) - if storage is None: - location_conf["location"] = Location.LOCAL - else: - location_conf["location"] = storage.location - if cloud_storage_id := storage.cloud_storage_id: - location_conf["storage_id"] = cloud_storage_id - else: - if location not in Location.list(): - raise ValueError(f"The specified location {location} is not supported") - - cloud_storage_id = query_params.get("cloud_storage_id") - - if location == Location.CLOUD_STORAGE and not cloud_storage_id: - raise ValueError( - "Cloud storage was selected as location but cloud_storage_id was not specified" + return ( + LocationConfig(is_default=True, location=Location.LOCAL) + if storage is None + else LocationConfig( + is_default=True, + location=storage.location, + cloud_storage_id=storage.cloud_storage_id, ) + ) - location_conf["location"] = location - if cloud_storage_id: - location_conf["storage_id"] = int(cloud_storage_id) - - return location_conf + return LocationConfig( + is_default=False, location=location, cloud_storage_id=query_params.get("cloud_storage_id") + ) diff --git a/cvat/apps/engine/mixins.py b/cvat/apps/engine/mixins.py index 6d5c3addb68e..ab1cb10dee7d 100644 --- a/cvat/apps/engine/mixins.py +++ b/cvat/apps/engine/mixins.py @@ -12,12 +12,10 @@ from pathlib import Path from tempfile import NamedTemporaryFile from textwrap import dedent -from typing import Callable from unittest import mock from urllib.parse import urljoin import django_rq -from attr.converters import to_bool from django.conf import settings from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema @@ -25,22 +23,14 @@ from rest_framework.decorators import action from rest_framework.response import Response -from cvat.apps.engine.background import BackupExportManager, DatasetExportManager +from cvat.apps.engine.background import BackupExporter, DatasetExporter from cvat.apps.engine.handlers import clear_import_cache -from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.models import ( - Job, - Location, - Project, - RequestAction, - RequestSubresource, - RequestTarget, - Task, -) -from cvat.apps.engine.rq import RQId -from cvat.apps.engine.serializers import DataSerializer, RqIdSerializer +from cvat.apps.engine.models import Location, RequestAction, RequestTarget +from cvat.apps.engine.rq import RequestId +from cvat.apps.engine.serializers import DataSerializer from cvat.apps.engine.types import ExtendedRequest +from cvat.apps.redis_handler.serializers import RqIdSerializer slogger = ServerLogManager(__name__) @@ -279,9 +269,11 @@ def init_tus_upload(self, request: ExtendedRequest): if file_exists: # check whether the rq_job is in progress or has been finished/failed object_class_name = self._object.__class__.__name__.lower() - template = RQId( - 
RequestAction.IMPORT, RequestTarget(object_class_name), self._object.pk, - subresource=RequestSubresource(import_type) + template = RequestId( + action=RequestAction.IMPORT, + target=RequestTarget(object_class_name), + target_id=self._object.pk, + subresource=import_type, ).render() queue = django_rq.get_queue(settings.CVAT_QUEUES.IMPORT_DATA.value) finished_job_ids = queue.finished_job_registry.get_job_ids() @@ -453,8 +445,8 @@ class DatasetMixin: def initiate_dataset_export(self, request: ExtendedRequest, pk: int): self._object = self.get_object() # force call of check_object_permissions() - export_manager = DatasetExportManager(self._object, request) - return export_manager.export() + export_manager = DatasetExporter(request=request, db_instance=self._object) + return export_manager.enqueue_job() @extend_schema(summary='Download a prepared dataset file', parameters=[ @@ -469,60 +461,12 @@ def initiate_dataset_export(self, request: ExtendedRequest, pk: int): @action(methods=['GET'], detail=True, url_path='dataset/download') def download_dataset(self, request: ExtendedRequest, pk: int): obj = self.get_object() # force to call check_object_permissions - export_manager = DatasetExportManager(obj, request) - return export_manager.download_file() - - # FUTURE-TODO: migrate to new API - def import_annotations( - self, - request: ExtendedRequest, - db_obj: Project | Task | Job, - import_func: Callable[..., None], - rq_func: Callable[..., None], - rq_id_factory: Callable[..., RQId], - ): - is_tus_request = request.headers.get('Upload-Length', None) is not None or \ - request.method == 'OPTIONS' - if is_tus_request: - return self.init_tus_upload(request) - - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - location_conf = get_location_configuration( - db_instance=db_obj, - query_params=request.query_params, - field_name=StorageType.SOURCE, - ) - if location_conf['location'] == Location.CLOUD_STORAGE: - format_name = request.query_params.get('format') - file_name = request.query_params.get('filename') - - return import_func( - request=request, - rq_id_factory=rq_id_factory, - rq_func=rq_func, - db_obj=self._object, - format_name=format_name, - location_conf=location_conf, - filename=file_name, - conv_mask_to_poly=conv_mask_to_poly, - ) - - return self.upload_data(request) + downloader = DatasetExporter(request=request, db_instance=obj).get_downloader() + return downloader.download_file() class BackupMixin: - # FUTURE-TODO: migrate to new API - def import_backup_v1(self, request: ExtendedRequest, import_func: Callable) -> Response: - location = request.query_params.get("location", Location.LOCAL) - if location == Location.CLOUD_STORAGE: - file_name = request.query_params.get("filename", "") - return import_func( - request, - queue_name=settings.CVAT_QUEUES.IMPORT_DATA.value, - filename=file_name, - ) - return self.upload_data(request) @extend_schema(summary='Initiate process to backup resource', description=dedent("""\ @@ -549,8 +493,8 @@ def import_backup_v1(self, request: ExtendedRequest, import_func: Callable) -> R @action(detail=True, methods=['POST'], serializer_class=None, url_path='backup/export') def initiate_backup_export(self, request: ExtendedRequest, pk: int): db_object = self.get_object() # force to call check_object_permissions - export_manager = BackupExportManager(db_object, request) - return export_manager.export() + export_manager = BackupExporter(request=request, db_instance=db_object) + return export_manager.enqueue_job() 
@extend_schema(summary='Download a prepared backup file', @@ -566,5 +510,6 @@ def initiate_backup_export(self, request: ExtendedRequest, pk: int): @action(methods=['GET'], detail=True, url_path='backup/download') def download_backup(self, request: ExtendedRequest, pk: int): obj = self.get_object() # force to call check_object_permissions - export_manager = BackupExportManager(obj, request) - return export_manager.download_file() + + downloader = BackupExporter(request=request, db_instance=obj).get_downloader() + return downloader.download_file() diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index c915c0a47de2..b7c043b1b458 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -1260,6 +1260,10 @@ def __str__(self): def list(cls): return [i.value for i in cls] + @classmethod + def _missing_(cls, value): + raise ValueError(f"The specified location {value!r} is not supported") + class CloudStorage(TimestampedModel): # restrictions: # AWS bucket name, Azure container name - 63, Google bucket name - 63 without dots and 222 with dots @@ -1346,12 +1350,6 @@ def organization_id(self): def get_asset_dir(self): return os.path.join(settings.ASSETS_ROOT, str(self.uuid)) -class RequestStatus(TextChoices): - QUEUED = "queued" - STARTED = "started" - FAILED = "failed" - FINISHED = "finished" - class RequestAction(TextChoices): AUTOANNOTATE = "autoannotate" CREATE = "create" diff --git a/cvat/apps/engine/permissions.py b/cvat/apps/engine/permissions.py index 596381b60915..a39b6a8cab7c 100644 --- a/cvat/apps/engine/permissions.py +++ b/cvat/apps/engine/permissions.py @@ -12,9 +12,8 @@ from django.conf import settings from django.shortcuts import get_object_or_404 from rest_framework.exceptions import PermissionDenied, ValidationError -from rq.job import Job as RQJob -from cvat.apps.engine.rq import RQId, is_rq_job_owner +from cvat.apps.engine.rq import ExportRequestId from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import is_dataset_export from cvat.apps.iam.permissions import ( @@ -48,16 +47,19 @@ def _get_key(d: dict[str, Any], key_path: Union[str, Sequence[str]]) -> Optional return d class DownloadExportedExtension: - rq_job_id: RQId | None + rq_job_id: ExportRequestId | None class Scopes(StrEnum): DOWNLOAD_EXPORTED_FILE = 'download:exported_file' @staticmethod def extend_params_with_rq_job_details(*, request: ExtendedRequest, params: dict[str, Any]) -> None: + # prevent importing from partially initialized module + from cvat.apps.redis_handler.background import AbstractExporter + if rq_id := request.query_params.get("rq_id"): try: - params["rq_job_id"] = RQId.parse(rq_id) + params["rq_job_id"] = ExportRequestId.parse_and_validate_queue(rq_id, expected_queue=AbstractExporter.QUEUE_NAME, try_legacy_format=True) return except Exception: raise ValidationError("Unexpected request id format") @@ -1251,41 +1253,6 @@ def get_scopes(request: ExtendedRequest, view: ViewSet, obj: AnnotationGuide | N }[view.action]] -class RequestPermission(OpenPolicyAgentPermission): - class Scopes(StrEnum): - LIST = 'list' - VIEW = 'view' - CANCEL = 'cancel' - - @classmethod - def create(cls, request: ExtendedRequest, view: ViewSet, obj: RQJob | None, iam_context: dict) -> list[OpenPolicyAgentPermission]: - permissions = [] - if view.basename == 'request': - for scope in cls.get_scopes(request, view, obj): - if scope != cls.Scopes.LIST: - user_id = request.user.id - if not is_rq_job_owner(obj, user_id): - raise PermissionDenied('You don\'t have permission 
to perform this action') - - return permissions - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/requests/allow' - - @staticmethod - def get_scopes(request: ExtendedRequest, view: ViewSet, obj: RQJob | None) -> list[Scopes]: - Scopes = __class__.Scopes - return [{ - ('list', 'GET'): Scopes.LIST, - ('retrieve', 'GET'): Scopes.VIEW, - ('cancel', 'POST'): Scopes.CANCEL, - }[(view.action, request.method)]] - - - def get_resource(self): - return None - def get_cloud_storage_for_import_or_export( storage_id: int, *, request: ExtendedRequest, is_default: bool = False ) -> CloudStorage: diff --git a/cvat/apps/engine/rq.py b/cvat/apps/engine/rq.py index a7d490d0b9b0..d070a0e51355 100644 --- a/cvat/apps/engine/rq.py +++ b/cvat/apps/engine/rq.py @@ -5,8 +5,8 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union -from uuid import UUID +from types import NoneType +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Protocol import attrs from django.conf import settings @@ -18,8 +18,7 @@ from rq.registry import BaseRegistry as RQBaseRegistry from cvat.apps.engine.types import ExtendedRequest - -from .models import RequestAction, RequestSubresource, RequestTarget +from cvat.apps.redis_handler.rq import RequestId, RequestIdWithOptionalSubresource if TYPE_CHECKING: from django.contrib.auth.models import User @@ -39,6 +38,7 @@ class RequestField: FORMATTED_EXCEPTION = "formatted_exception" EXCEPTION_TYPE = "exc_type" EXCEPTION_ARGS = "exc_args" + # common fields REQUEST = "request" USER = "user" @@ -49,12 +49,14 @@ class RequestField: JOB_ID = "job_id" STATUS = "status" PROGRESS = "progress" - # import fields - TMP_FILE = "tmp_file" + + # import-specific fields TASK_PROGRESS = "task_progress" - # export fields + + # export specific fields RESULT_URL = "result_url" RESULT_FILENAME = "result_filename" + # lambda fields LAMBDA = "lambda" FUNCTION_ID = "function_id" @@ -307,11 +309,6 @@ def build_for( class ImportRQMeta(BaseRQMeta): - # immutable && optional fields - tmp_file: str | None = ImmutableRQMetaAttribute( - RQJobMetaField.TMP_FILE, optional=True - ) # used only when importing annotations|datasets|backups - # mutable fields task_progress: float | None = MutableRQMetaAttribute( RQJobMetaField.TASK_PROGRESS, validator=lambda x: isinstance(x, float), optional=True @@ -323,18 +320,6 @@ def _get_resettable_fields() -> list[str]: return base_fields + [RQJobMetaField.TASK_PROGRESS] - @classmethod - def build_for( - cls, - *, - request: ExtendedRequest, - db_obj: Model | None, - tmp_file: str | None = None, - ): - base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) - - return {**base_meta, RQJobMetaField.TMP_FILE: tmp_file} - def is_rq_job_owner(rq_job: RQJob, user_id: int) -> bool: if user := BaseRQMeta.for_job(rq_job).user: @@ -343,118 +328,47 @@ def is_rq_job_owner(rq_job: RQJob, user_id: int) -> bool: return False -@attrs.frozen() -class RQId: - action: RequestAction = attrs.field(validator=attrs.validators.instance_of(RequestAction)) - target: RequestTarget = attrs.field(validator=attrs.validators.instance_of(RequestTarget)) - identifier: Union[int, UUID] = attrs.field(validator=attrs.validators.instance_of((int, UUID))) - subresource: Optional[RequestSubresource] = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(RequestSubresource)), - kw_only=True, - default=None, - ) - user_id: Optional[int] 
= attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(int)), - kw_only=True, - default=None, +@attrs.frozen(kw_only=True, slots=False) +class RequestIdWithOptionalFormat(RequestId): + format: str | None = attrs.field( + validator=attrs.validators.instance_of((str, NoneType)), default=None ) - format: Optional[str] = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(str)), - kw_only=True, - default=None, + + +@attrs.frozen(kw_only=True, slots=False) +class ExportRequestId( + RequestIdWithOptionalSubresource, # subresource is optional because export queue works also with events + RequestIdWithOptionalFormat, +): + ACTION_DEFAULT_VALUE: ClassVar[str] = "export" + ACTION_ALLOWED_VALUES: ClassVar[tuple[str]] = (ACTION_DEFAULT_VALUE,) + + SUBRESOURCE_ALLOWED_VALUES: ClassVar[tuple[str]] = ("backup", "dataset", "annotations") + QUEUE_SELECTORS: ClassVar[tuple[str]] = ACTION_ALLOWED_VALUES + + # will be deleted after several releases + LEGACY_FORMAT_PATTERNS: ClassVar[tuple[str]] = ( + r"export:(?P<target>(task|project))-(?P<target_id>\d+)-(?P<subresource>backup)-by-(?P<user_id>\d+)", + r"export:(?P<target>(project|task|job))-(?P<target_id>\d+)-(?P<subresource>(annotations|dataset))" + + r"-in-(?P<format>[\w@]+)-format-by-(?P<user_id>\d+)", ) - _OPTIONAL_FIELD_REQUIREMENTS = { - RequestAction.AUTOANNOTATE: {"subresource": False, "format": False, "user_id": False}, - RequestAction.CREATE: {"subresource": False, "format": False, "user_id": False}, - RequestAction.EXPORT: {"subresource": True, "user_id": True}, - RequestAction.IMPORT: {"subresource": True, "format": False, "user_id": False}, - } - - def __attrs_post_init__(self) -> None: - for field, req in self._OPTIONAL_FIELD_REQUIREMENTS[self.action].items(): - if req: - if getattr(self, field) is None: - raise ValueError(f"{field} is required for the {self.action} action") - else: - if getattr(self, field) is not None: - raise ValueError(f"{field} is not allowed for the {self.action} action") - - # RQ ID templates: - # autoannotate:task-<id> - # import:<target>-<id>-<subresource> - # create:task-<id> - # export:<target>-<id>-<subresource>-in-<format>-format-by-<user_id> - # export:<target>-<id>-backup-by-<user_id> - - def render( - self, - ) -> str: - common_prefix = f"{self.action}:{self.target}-{self.identifier}" - - if RequestAction.IMPORT == self.action: - return f"{common_prefix}-{self.subresource}" - elif RequestAction.EXPORT == self.action: - if self.format is None: - return f"{common_prefix}-{self.subresource}-by-{self.user_id}" - - format_to_be_used_in_urls = self.format.replace(" ", "_").replace(".", "@") - return f"{common_prefix}-{self.subresource}-in-{format_to_be_used_in_urls}-format-by-{self.user_id}" - elif self.action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: - return common_prefix - else: - assert False, f"Unsupported action {self.action!r} was found" - @staticmethod - def parse(rq_id: str) -> RQId: - identifier: Optional[Union[UUID, int]] = None - subresource: Optional[RequestSubresource] = None - user_id: Optional[int] = None - anno_format: Optional[str] = None - - try: - action_and_resource, unparsed = rq_id.split("-", maxsplit=1) - action_str, target_str = action_and_resource.split(":") - action = RequestAction(action_str) - target = RequestTarget(target_str) - - if action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: - identifier = unparsed - elif RequestAction.IMPORT == action: - identifier, subresource_str = unparsed.rsplit("-", maxsplit=1) - subresource = RequestSubresource(subresource_str) - else: # action == export - identifier, subresource_str, unparsed = unparsed.split("-", maxsplit=2) - subresource = 
RequestSubresource(subresource_str) - - if RequestSubresource.BACKUP == subresource: - _, user_id = unparsed.split("-") - else: - unparsed, _, user_id = unparsed.rsplit("-", maxsplit=2) - # remove prefix(in-), suffix(-format) and restore original format name - # by replacing special symbols: "_" -> " ", "@" -> "." - anno_format = unparsed[3:-7].replace("_", " ").replace("@", ".") - - if identifier is not None: - if identifier.isdigit(): - identifier = int(identifier) - else: - identifier = UUID(identifier) - - if user_id is not None: - user_id = int(user_id) - - return RQId( - action=action, - target=target, - identifier=identifier, - subresource=subresource, - user_id=user_id, - format=anno_format, - ) - - except Exception as ex: - raise ValueError(f"The {rq_id!r} RQ ID cannot be parsed: {str(ex)}") from ex +@attrs.frozen(kw_only=True, slots=False) +class ImportRequestId( + RequestIdWithOptionalSubresource, # subresource is optional because import queue works also with task creation jobs + RequestIdWithOptionalFormat, +): + ACTION_ALLOWED_VALUES: ClassVar[tuple[str]] = ("create", "import") + SUBRESOURCE_ALLOWED_VALUES: ClassVar[tuple[str]] = ("backup", "dataset", "annotations") + QUEUE_SELECTORS: ClassVar[tuple[str]] = ACTION_ALLOWED_VALUES + + # will be deleted after several releases + LEGACY_FORMAT_PATTERNS = ( + r"(?P<action>create):(?P<target>task)-(?P<target_id>\d+)", + r"(?P<action>import):(?P<target>(task|project|job))-(?P<target_id>\d+)-(?P<subresource>(annotations|dataset))", + r"(?P<action>import):(?P<target>(task|project))-(?P<target_id>[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})-(?P<subresource>backup)", + ) def define_dependent_job( diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 3c9960a384be..b66dff9edae1 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -15,14 +15,11 @@ from collections.abc import Iterable, Sequence from contextlib import closing from copy import copy -from datetime import timedelta -from decimal import Decimal from inspect import isclass from tempfile import NamedTemporaryFile from typing import Any, Optional, Union import django_rq -import rq.defaults as rq_defaults from django.conf import settings from django.contrib.auth.models import Group, User from django.db import transaction @@ -31,8 +28,6 @@ from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer from numpy import random from rest_framework import exceptions, serializers -from rq.job import Job as RQJob -from rq.job import JobStatus as RQJobStatus from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine import field_validation, models @@ -41,7 +36,6 @@ from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.permissions import TaskPermission -from cvat.apps.engine.rq import BaseRQMeta, ExportRQMeta, ImportRQMeta, RequestAction, RQId from cvat.apps.engine.task_validation import HoneypotFrameSelector from cvat.apps.engine.utils import ( CvatChunkTimestampMismatchError, @@ -49,12 +43,10 @@ format_list, get_list_view_name, grouped, - parse_exception_message, parse_specific_attributes, reverse, take_by, ) -from cvat.apps.lambda_manager.rq import LambdaRQMeta from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) @@ -1814,12 +1806,9 @@ class RqStatusSerializer(serializers.Serializer): def __init__(self, instance=None, data=..., **kwargs): warnings.warn("RqStatusSerializer is deprecated, " - "use 
cvat.apps.engine.serializers.RequestSerializer instead", DeprecationWarning) + "use cvat.apps.redis_handler.serializers.RequestSerializer instead", DeprecationWarning) super().__init__(instance, data, **kwargs) -class RqIdSerializer(serializers.Serializer): - rq_id = serializers.CharField(help_text="Request id") - class JobFiles(serializers.ListField): """ @@ -2963,6 +2952,7 @@ class TaskFileSerializer(serializers.Serializer): class ProjectFileSerializer(serializers.Serializer): project_file = serializers.FileField() + class CommentReadSerializer(serializers.ModelSerializer): owner = BasicUserSerializer(allow_null=True, required=False) @@ -3488,130 +3478,3 @@ def create(self, validated_data): class Meta: model = models.AnnotationGuide fields = ('id', 'task_id', 'project_id', 'markdown', ) - -class UserIdentifiersSerializer(BasicUserSerializer): - class Meta(BasicUserSerializer.Meta): - fields = ( - "id", - "username", - ) - - -class RequestDataOperationSerializer(serializers.Serializer): - type = serializers.CharField() - target = serializers.ChoiceField(choices=models.RequestTarget.choices) - project_id = serializers.IntegerField(required=False, allow_null=True) - task_id = serializers.IntegerField(required=False, allow_null=True) - job_id = serializers.IntegerField(required=False, allow_null=True) - format = serializers.CharField(required=False, allow_null=True) - function_id = serializers.CharField(required=False, allow_null=True) - - def to_representation(self, rq_job: RQJob) -> dict[str, Any]: - parsed_rq_id: RQId = rq_job.parsed_rq_id - - base_rq_job_meta = BaseRQMeta.for_job(rq_job) - representation = { - "type": ":".join( - [ - parsed_rq_id.action, - parsed_rq_id.subresource or parsed_rq_id.target, - ] - ), - "target": parsed_rq_id.target, - "project_id": base_rq_job_meta.project_id, - "task_id": base_rq_job_meta.task_id, - "job_id": base_rq_job_meta.job_id, - } - if parsed_rq_id.action == RequestAction.AUTOANNOTATE: - representation["function_id"] = LambdaRQMeta.for_job(rq_job).function_id - elif parsed_rq_id.action in (RequestAction.IMPORT, RequestAction.EXPORT): - representation["format"] = parsed_rq_id.format - - return representation - -class RequestSerializer(serializers.Serializer): - # SerializerMethodField is not used here to mark "status" field as required and fix schema generation. - # Marking them as read_only leads to generating type as allOf with one reference to RequestStatus component. 
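
Note on the ImportRequestId / ExportRequestId classes introduced earlier in this diff: the LEGACY_FORMAT_PATTERNS keep the old rq_id layout matchable for a few releases, and a legacy export id still encodes the format name with "_" for spaces and "@" for dots, exactly as the removed RQId.render()/parse() pair did. The sketch below is illustrative only: the named groups and the helper are mine (the real group names sit in the placeholders of the patterns above, which were stripped here), and the sample id is invented.

import re

# Reconstruction of the second legacy export pattern from ExportRequestId.LEGACY_FORMAT_PATTERNS;
# group names are illustrative.
LEGACY_EXPORT_PATTERN = (
    r"export:(?P<target>(project|task|job))-(?P<target_id>\d+)-(?P<subresource>(annotations|dataset))"
    r"-in-(?P<format>[\w@]+)-format-by-(?P<user_id>\d+)"
)

def parse_legacy_export_id(rq_id: str) -> dict:
    """Hypothetical helper: match a legacy export rq_id and restore the original format name."""
    match = re.fullmatch(LEGACY_EXPORT_PATTERN, rq_id)
    if match is None:
        raise ValueError(f"not a legacy export rq_id: {rq_id!r}")
    fields = match.groupdict()
    # The removed RQId.render() replaced " " -> "_" and "." -> "@", so reverse it here.
    fields["format"] = fields["format"].replace("_", " ").replace("@", ".")
    return fields

# Invented sample id: a task annotations export in the "CVAT for images 1.1" format by user 42.
print(parse_legacy_export_id("export:task-123-annotations-in-CVAT_for_images_1@1-format-by-42"))
# {'target': 'task', 'target_id': '123', 'subresource': 'annotations',
#  'format': 'CVAT for images 1.1', 'user_id': '42'}
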
- # The client generated using openapi-generator from such a schema contains wrong type like: - # status (bool, date, datetime, dict, float, int, list, str, none_type): [optional] - status = serializers.ChoiceField(source="get_status", choices=models.RequestStatus.choices) - message = serializers.SerializerMethodField() - id = serializers.CharField() - operation = RequestDataOperationSerializer(source="*") - progress = serializers.SerializerMethodField() - created_date = serializers.DateTimeField(source="created_at") - started_date = serializers.DateTimeField( - required=False, allow_null=True, source="started_at", - ) - finished_date = serializers.DateTimeField( - required=False, allow_null=True, source="ended_at", - ) - expiry_date = serializers.SerializerMethodField() - owner = serializers.SerializerMethodField() - result_url = serializers.URLField(required=False, allow_null=True) - result_id = serializers.IntegerField(required=False, allow_null=True) - - def __init__(self, *args, **kwargs): - self._base_rq_job_meta: BaseRQMeta | None = None - super().__init__(*args, **kwargs) - - @extend_schema_field(UserIdentifiersSerializer()) - def get_owner(self, rq_job: RQJob) -> dict[str, Any]: - assert self._base_rq_job_meta - return UserIdentifiersSerializer(self._base_rq_job_meta.user).data - - @extend_schema_field( - serializers.FloatField(min_value=0, max_value=1, required=False, allow_null=True) - ) - def get_progress(self, rq_job: RQJob) -> Decimal: - rq_job_meta = ImportRQMeta.for_job(rq_job) - # progress of task creation is stored in "task_progress" field - # progress of project import is stored in "progress" field - return Decimal(rq_job_meta.progress or rq_job_meta.task_progress or 0.) - - @extend_schema_field(serializers.DateTimeField(required=False, allow_null=True)) - def get_expiry_date(self, rq_job: RQJob) -> Optional[str]: - delta = None - if rq_job.is_finished: - delta = rq_job.result_ttl or rq_defaults.DEFAULT_RESULT_TTL - elif rq_job.is_failed: - delta = rq_job.failure_ttl or rq_defaults.DEFAULT_FAILURE_TTL - - if rq_job.ended_at and delta: - expiry_date = rq_job.ended_at + timedelta(seconds=delta) - return expiry_date.replace(tzinfo=timezone.utc) - - return None - - @extend_schema_field(serializers.CharField(allow_blank=True)) - def get_message(self, rq_job: RQJob) -> str: - assert self._base_rq_job_meta - rq_job_status = rq_job.get_status() - message = '' - - if RQJobStatus.STARTED == rq_job_status: - message = self._base_rq_job_meta.status or message - elif RQJobStatus.FAILED == rq_job_status: - message = self._base_rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error")) - - return message - - def to_representation(self, rq_job: RQJob) -> dict[str, Any]: - self._base_rq_job_meta = BaseRQMeta.for_job(rq_job) - representation = super().to_representation(rq_job) - - # FUTURE-TODO: support such statuses on UI - if representation["status"] in (RQJobStatus.DEFERRED, RQJobStatus.SCHEDULED): - representation["status"] = RQJobStatus.QUEUED - - if representation["status"] == RQJobStatus.FINISHED: - if rq_job.parsed_rq_id.action == models.RequestAction.EXPORT: - representation["result_url"] = ExportRQMeta.for_job(rq_job).result_url - - if ( - rq_job.parsed_rq_id.action == models.RequestAction.IMPORT - and rq_job.parsed_rq_id.subresource == models.RequestSubresource.BACKUP - ): - representation["result_id"] = rq_job.return_value() - - return representation diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 
ca6819304626..ee13c8fc8365 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -20,7 +20,6 @@ import attrs import av -import django_rq import rq from django.conf import settings from django.db import transaction @@ -46,11 +45,9 @@ sort, ) from cvat.apps.engine.model_utils import bulk_create -from cvat.apps.engine.models import RequestAction, RequestTarget -from cvat.apps.engine.rq import ImportRQMeta, RQId, define_dependent_job +from cvat.apps.engine.rq import ImportRQMeta from cvat.apps.engine.task_validation import HoneypotFrameSelector -from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import av_scan_paths, format_list, get_rq_lock_by_user, take_by +from cvat.apps.engine.utils import av_scan_paths, format_list, take_by from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session from utils.dataset_manifest import ImageManifestManager, VideoManifestManager, is_manifest from utils.dataset_manifest.core import VideoManifestValidator, is_dataset_manifest @@ -60,32 +57,6 @@ slogger = ServerLogManager(__name__) -############################# Low Level server API - -def create( - db_task: models.Task, - data: models.Data, - request: ExtendedRequest, -) -> str: - """Schedule a background job to create a task and return that job's identifier""" - q = django_rq.get_queue(settings.CVAT_QUEUES.IMPORT_DATA.value) - user_id = request.user.id - rq_id = RQId(RequestAction.CREATE, RequestTarget.TASK, db_task.pk).render() - - with get_rq_lock_by_user(q, user_id): - q.enqueue_call( - func=_create_thread, - args=(db_task.pk, data), - job_id=rq_id, - meta=ImportRQMeta.build_for(request=request, db_obj=db_task), - depends_on=define_dependent_job(q, user_id), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds(), - ) - - return rq_id - -############################# Internal implementation for server API - JobFileMapping = list[list[str]] class SegmentParams(NamedTuple): @@ -574,7 +545,7 @@ def _create_task_manifest_from_cloud_data( manifest.create() @transaction.atomic -def _create_thread( +def create_thread( db_task: Union[int, models.Task], data: dict[str, Any], *, diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index f2c9df7d8f77..0bb2f2297355 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -66,6 +66,7 @@ ApiTestBase, ExportApiTestBase, ForceLogin, + ImportApiTestBase, generate_image_file, generate_video_file, get_paginated_collection, @@ -1322,7 +1323,7 @@ def test_api_v2_projects_id_tasks_no_auth(self): self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) -class ProjectBackupAPITestCase(ExportApiTestBase): +class ProjectBackupAPITestCase(ExportApiTestBase, ImportApiTestBase): @classmethod def setUpTestData(cls): create_db_users(cls) @@ -1620,12 +1621,6 @@ def _create_projects(cls): cls._create_tasks(db_project) cls.projects.append(db_project) - def _run_api_v2_projects_import(self, user, data): - with ForceLogin(user, self.client): - response = self.client.post('/api/projects/backup', data=data, format="multipart") - - return response - def _run_api_v2_projects_id(self, pid, user): with ForceLogin(user, self.client): response = self.client.get('/api/projects/{}'.format(pid), format="json") @@ -1654,18 +1649,13 @@ def _run_api_v2_projects_id_export_import(self, user): self.assertTrue(response.streaming) content = io.BytesIO(b"".join(response.streaming_content)) content.seek(0) + content.name = "file.zip" - 
uploaded_data = { - "project_file": content, - } - response = self._run_api_v2_projects_import(user, uploaded_data) - self.assertEqual(response.status_code, expected_4xx_status_code or status.HTTP_202_ACCEPTED) - if response.status_code == status.HTTP_202_ACCEPTED: - rq_id = response.data["rq_id"] - response = self._run_api_v2_projects_import(user, {"rq_id": rq_id}) - self.assertEqual(response.status_code, expected_4xx_status_code or status.HTTP_201_CREATED) + created_project_id = self._import_project_backup(user, content, expected_4xx_status_code=expected_4xx_status_code) + + if not expected_4xx_status_code: original_project = self._run_api_v2_projects_id(pid, user) - imported_project = self._run_api_v2_projects_id(response.data["id"], user) + imported_project = self._run_api_v2_projects_id(created_project_id, user) compare_objects( self=self, obj1=original_project, @@ -1882,7 +1872,7 @@ def test_api_v2_projects_remove_task_export(self): self._check_xml(pid, user, 3) -class ProjectImportExportAPITestCase(ExportApiTestBase): +class ProjectImportExportAPITestCase(ExportApiTestBase, ImportApiTestBase): def setUp(self) -> None: super().setUp() self.tasks = [] @@ -1999,16 +1989,6 @@ def _create_project(project_data): for data in project_data: _create_project(data) - def _run_api_v2_projects_id_dataset_import(self, pid, user, data, f): - with ForceLogin(user, self.client): - response = self.client.post("/api/projects/{}/dataset?format={}".format(pid, f), data=data, format="multipart") - return response - - def _run_api_v2_projects_id_dataset_import_status(self, pid, user, rq_id): - with ForceLogin(user, self.client): - response = self.client.get("/api/projects/{}/dataset?action=import_status&rq_id={}".format(pid, rq_id), format="json") - return response - def test_api_v2_projects_id_export_import(self): self._create_projects() self._create_tasks() @@ -2025,16 +2005,8 @@ def test_api_v2_projects_id_export_import(self): tmp_file.write(b"".join(response.streaming_content)) tmp_file.seek(0) - import_data = { - "dataset_file": tmp_file, - } - - response = self._run_api_v2_projects_id_dataset_import(pid_import, self.owner, import_data, "CVAT 1.1") - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self._import_project_dataset(self.owner, pid_import, tmp_file, query_params={"format": "CVAT 1.1"}) - rq_id = response.data.get('rq_id') - response = self._run_api_v2_projects_id_dataset_import_status(pid_import, self.owner, rq_id) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) def tearDown(self): for task in self.tasks: @@ -2797,7 +2769,7 @@ def test_api_v2_tasks_no_auth(self): } self._check_api_v2_tasks(None, data) -class TaskImportExportAPITestCase(ExportApiTestBase): +class TaskImportExportAPITestCase(ExportApiTestBase, ImportApiTestBase): def setUp(self): super().setUp() self.tasks = [] @@ -3110,11 +3082,6 @@ def _create_task(task_data, media_data): for media in self.media_data: _create_task(data, media) - def _run_api_v2_tasks_id_import(self, user, data): - with ForceLogin(user, self.client): - response = self.client.post('/api/tasks/backup', data=data, format="multipart") - - return response def _run_api_v2_tasks_id(self, tid, user): with ForceLogin(user, self.client): @@ -3140,18 +3107,13 @@ def _run_api_v2_tasks_id_export_import(self, user): self.assertTrue(response.streaming) content = io.BytesIO(b"".join(response.streaming_content)) content.seek(0) + content.name = "file.zip" + + created_task_id = self._import_task_backup(user, content, 
expected_4xx_status_code=expected_4xx_status_code) - uploaded_data = { - "task_file": content, - } - response = self._run_api_v2_tasks_id_import(user, uploaded_data) - self.assertEqual(response.status_code, expected_4xx_status_code or status.HTTP_202_ACCEPTED) if user is not self.somebody and user is not self.user and user is not self.annotator: - rq_id = response.data["rq_id"] - response = self._run_api_v2_tasks_id_import(user, {"rq_id": rq_id}) - self.assertEqual(response.status_code, expected_4xx_status_code or status.HTTP_201_CREATED) original_task = self._run_api_v2_tasks_id(tid, user) - imported_task = self._run_api_v2_tasks_id(response.data["id"], user) + imported_task = self._run_api_v2_tasks_id(created_task_id, user) compare_objects( self=self, obj1=original_task, @@ -5495,7 +5457,7 @@ def test_api_v2_jobs_id_annotations_somebody(self): def test_api_v2_jobs_id_annotations_no_auth(self): self._run_api_v2_jobs_id_annotations(self.user, self.user, None) -class TaskAnnotationAPITestCase(ExportApiTestBase, JobAnnotationAPITestCase): +class TaskAnnotationAPITestCase(ExportApiTestBase, ImportApiTestBase, JobAnnotationAPITestCase): def _put_api_v2_tasks_id_annotations(self, pk, user, data): with ForceLogin(user, self.client): response = self.client.put("/api/tasks/{}/annotations".format(pk), @@ -5524,16 +5486,6 @@ def _patch_api_v2_tasks_id_annotations(self, pk, user, action, data): return response - def _upload_api_v2_tasks_id_annotations(self, pk, user, data, query_params=""): - with ForceLogin(user, self.client): - response = self.client.put( - path="/api/tasks/{0}/annotations?{1}".format(pk, query_params), - data=data, - format="multipart", - ) - - return response - def _get_formats(self, user): with ForceLogin(user, self.client): response = self.client.get( @@ -5943,13 +5895,9 @@ def _run_api_v2_tasks_id_annotations_dump_load(self, owner): if owner: HTTP_200_OK = status.HTTP_200_OK HTTP_204_NO_CONTENT = status.HTTP_204_NO_CONTENT - HTTP_202_ACCEPTED = status.HTTP_202_ACCEPTED - HTTP_201_CREATED = status.HTTP_201_CREATED else: HTTP_200_OK = status.HTTP_401_UNAUTHORIZED HTTP_204_NO_CONTENT = status.HTTP_401_UNAUTHORIZED - HTTP_202_ACCEPTED = status.HTTP_401_UNAUTHORIZED - HTTP_201_CREATED = status.HTTP_401_UNAUTHORIZED def _get_initial_annotation(annotation_format): if annotation_format not in ["Market-1501 1.0", "ICDAR Recognition 1.0", @@ -6539,18 +6487,9 @@ def _get_initial_annotation(annotation_format): if not import_format: continue - uploaded_data = { - "annotation_file": content, - } - response = self._upload_api_v2_tasks_id_annotations( - task["id"], owner, uploaded_data, - "format={}".format(import_format)) - self.assertEqual(response.status_code, HTTP_202_ACCEPTED) - - response = self._upload_api_v2_tasks_id_annotations( - task["id"], owner, {}, - "format={}".format(import_format)) - self.assertEqual(response.status_code, HTTP_201_CREATED) + self._import_task_annotations( + owner, task["id"], content, query_params={"format": import_format} + ) # 7. 
check annotation if export_format in {"Segmentation mask 1.1", "MOTS PNG 1.0", @@ -6669,18 +6608,9 @@ def generate_coco_anno(): content = io.BytesIO(generate_coco_anno()) content.seek(0) - format_name = "COCO 1.0" - uploaded_data = { - "annotation_file": content, - } - response = self._upload_api_v2_tasks_id_annotations( - task["id"], user, uploaded_data, - "format={}".format(format_name)) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - response = self._upload_api_v2_tasks_id_annotations( - task["id"], user, {}, "format={}".format(format_name)) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self._import_task_annotations( + user, task["id"], content, query_params={"format": "COCO 1.0"} + ) response = self._get_api_v2_tasks_id_annotations(task["id"], user) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py index 556187904550..fa6b2a66c16b 100644 --- a/cvat/apps/engine/tests/test_rest_api_3D.py +++ b/cvat/apps/engine/tests/test_rest_api_3D.py @@ -21,14 +21,19 @@ from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.dataset_manager.tests.utils import TestDir from cvat.apps.engine.media_extractors import ValidateDimension -from cvat.apps.engine.tests.utils import ExportApiTestBase, ForceLogin, get_paginated_collection +from cvat.apps.engine.tests.utils import ( + ExportApiTestBase, + ForceLogin, + ImportApiTestBase, + get_paginated_collection, +) CREATE_ACTION = "create" UPDATE_ACTION = "update" DELETE_ACTION = "delete" -class _DbTestBase(ExportApiTestBase): +class _DbTestBase(ExportApiTestBase, ImportApiTestBase): @classmethod def setUpTestData(cls): cls.create_db_users() @@ -136,18 +141,6 @@ def _get_jobs(self, task_id): ) return values - def _upload_file(self, url, data, user): - response = self._put_request(url, user, data={"annotation_file": data}, format="multipart") - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - response = self._put_request(url, user) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - def _generate_url_upload_tasks_annotations(self, task_id, upload_format_name): - return f"/api/tasks/{task_id}/annotations?format={upload_format_name}" - - def _generate_url_upload_job_annotations(self, job_id, upload_format_name): - return f"/api/jobs/{job_id}/annotations?format={upload_format_name}" - def _remove_annotations(self, tid): response = self._delete_request(f"/api/tasks/{tid}/annotations", self.admin) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -499,10 +492,11 @@ def test_api_v2_dump_and_upload_annotation(self): self._remove_annotations(task_id) with self.subTest(format=f"{format_name}_upload"): file_name = osp.join(test_dir, f"{format_name}_admin.zip") - url = self._generate_url_upload_tasks_annotations(task_id, format_name) with open(file_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task["id"], binary_file, query_params={"format": format_name} + ) task_ann = TaskAnnotation(task_id) task_ann.init_from_db() @@ -538,10 +532,10 @@ def test_api_v2_rewrite_annotation(self): self.assertEqual(response.status_code, status.HTTP_200_OK) file_name = osp.join(test_dir, f"{format_name}.zip") - url = self._generate_url_upload_tasks_annotations(task_id, format_name) - with open(file_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + 
self._import_task_annotations( + self.admin, task["id"], binary_file, query_params={"format": format_name} + ) task_ann = TaskAnnotation(task_id) task_ann.init_from_db() @@ -568,10 +562,11 @@ def test_api_v2_dump_and_upload_empty_annotation(self): self.assertTrue(osp.exists(file_name)) file_name = osp.join(test_dir, f"{format_name}.zip") - url = self._generate_url_upload_tasks_annotations(task_id, format_name) with open(file_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, query_params={"format": format_name} + ) task_ann = TaskAnnotation(task_id) task_ann.init_from_db() @@ -631,10 +626,11 @@ def test_api_v2_upload_annotation_with_attributes(self): self._remove_annotations(task_id) with self.subTest(format=f"{format_name}_upload"): file_name = osp.join(test_dir, f"{format_name}.zip") - url = self._generate_url_upload_tasks_annotations(task_id, format_name) with open(file_name, 'rb') as binary_file: - self._upload_file(url, binary_file, self.admin) + self._import_task_annotations( + self.admin, task_id, binary_file, query_params={"format": format_name} + ) task_ann = TaskAnnotation(task_id) task_ann.init_from_db() diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py index 820a815afbd4..0c22feadf1fa 100644 --- a/cvat/apps/engine/tests/utils.py +++ b/cvat/apps/engine/tests/utils.py @@ -189,6 +189,84 @@ def _check_request_status( assert request_status == "finished", f"The last request status was {request_status}" return response +class ImportApiTestBase(ApiTestBase): + def _import( + self, + user: str, + api_path: str, + file_content: BytesIO, + *, + through_field: str, + query_params: dict[str, Any] | None = None, + expected_4xx_status_code: int | None = None, + ): + response = self._post_request( + api_path, user, + data={through_field: file_content}, + format="multipart", + query_params=query_params, + ) + self.assertEqual(response.status_code, expected_4xx_status_code or status.HTTP_202_ACCEPTED) + + if not expected_4xx_status_code: + rq_id = response.json().get("rq_id") + assert rq_id, "The rq_id param was not found in the server response" + response = self._check_request_status(user, rq_id) + + return response + + def _import_project_dataset( + self, user: str, projetc_id: int, file_content: BytesIO, query_params: str = None, + expected_4xx_status_code: int | None = None + ): + return self._import( + user, f"/api/projects/{projetc_id}/dataset", file_content, through_field="dataset_file", + query_params=query_params, expected_4xx_status_code=expected_4xx_status_code + ) + + def _import_task_annotations( + self, user: str, task_id: int, file_content: BytesIO, query_params: str = None, + expected_4xx_status_code: int | None = None + ): + return self._import( + user, f"/api/tasks/{task_id}/annotations", file_content, through_field="annotation_file", + query_params=query_params, expected_4xx_status_code=expected_4xx_status_code + ) + + def _import_job_annotations( + self, user: str, job_id: int, file_content: BytesIO, query_params: str = None, + expected_4xx_status_code: int | None = None + ): + return self._import( + user, f"/api/jobs/{job_id}/annotations", file_content, through_field="annotation_file", + query_params=query_params, expected_4xx_status_code=expected_4xx_status_code + ) + + def _import_project_backup( + self, user: str, file_content: BytesIO, query_params: str = None, + expected_4xx_status_code: int | None = None + ) -> int | None: + response = 
self._import( + user, "/api/projects/backup", file_content, through_field="project_file", + query_params=query_params, expected_4xx_status_code=expected_4xx_status_code + ) + if expected_4xx_status_code: + return None + + return response.json()["result_id"] + + def _import_task_backup( + self, user: str, file_content: BytesIO, query_params: str = None, + expected_4xx_status_code: int | None = None + ) -> int | None: + response = self._import( + user, "/api/tasks/backup", file_content, through_field="task_file", + query_params=query_params, expected_4xx_status_code=expected_4xx_status_code + ) + if expected_4xx_status_code: + return None + + return response.json()["result_id"] class ExportApiTestBase(ApiTestBase): def _export( diff --git a/cvat/apps/engine/urls.py b/cvat/apps/engine/urls.py index 8b8c38a4e5ca..0c294aa5ab8f 100644 --- a/cvat/apps/engine/urls.py +++ b/cvat/apps/engine/urls.py @@ -23,7 +23,6 @@ router.register("cloudstorages", views.CloudStorageViewSet) router.register("assets", views.AssetsViewSet) router.register("guides", views.AnnotationGuidesViewSet) -router.register("requests", views.RequestViewSet, basename="request") urlpatterns = [ # Entry point for a client diff --git a/cvat/apps/engine/view_utils.py b/cvat/apps/engine/view_utils.py index 3d27deedd039..2503c0f0b226 100644 --- a/cvat/apps/engine/view_utils.py +++ b/cvat/apps/engine/view_utils.py @@ -4,9 +4,12 @@ # NOTE: importing in the utils.py header leads to circular importing +import textwrap +from datetime import datetime from typing import Optional from django.db.models.query import QuerySet +from django.http import HttpResponseGone from django.http.response import HttpResponse from drf_spectacular.utils import extend_schema from rest_framework.decorators import action @@ -92,3 +95,24 @@ def decorator(f): return f return decorator + +def get_410_response_for_export_api(path: str) -> HttpResponseGone: + return HttpResponseGone(textwrap.dedent(f"""\ + This endpoint is no longer supported. + To initiate the export process, use POST {path}. + To check the process status, use GET /api/requests/rq_id, + where rq_id is obtained from the response of the previous request. + To download the prepared file, use the result_url obtained from the response of the previous request. + """)) + +def get_410_response_when_checking_process_status(process_type: str, /) -> HttpResponseGone: + return HttpResponseGone(textwrap.dedent(f"""\ + This endpoint no longer supports checking the status of the {process_type} process. + The common requests API should be used instead: GET /api/requests/rq_id, + where rq_id is obtained from the response of the initializing request. 
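
To show how the ImportApiTestBase helpers added to tests/utils.py above are meant to be used, here is a minimal, illustrative test sketch. The fixtures (self.admin, self.task_id) and the export wrapper name are assumptions; the refactored test cases elsewhere in this diff show the real usage.

import io

from cvat.apps.engine.tests.utils import ExportApiTestBase, ImportApiTestBase


class TaskBackupRoundTripExample(ExportApiTestBase, ImportApiTestBase):
    def test_backup_round_trip(self):
        # Export a backup first (wrapper name assumed; ExportApiTestBase provides the export helpers).
        response = self._export_task_backup(self.admin, self.task_id)
        content = io.BytesIO(b"".join(response.streaming_content))
        content.seek(0)
        content.name = "file.zip"  # the multipart upload needs a file name

        # POST /api/tasks/backup -> 202 with rq_id; the helper polls GET /api/requests/<rq_id>
        # for us and returns result_id, i.e. the id of the newly created task.
        created_task_id = self._import_task_backup(self.admin, content)
        self.assertIsNotNone(created_task_id)
        # Pass expected_4xx_status_code=... instead to assert that an import is rejected.
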
+ """)) + +def deprecate_response(response: Response, *, deprecation_date: datetime) -> None: + # https://www.rfc-editor.org/rfc/rfc9745 + deprecation_timestamp = int(deprecation_date.timestamp()) + response.headers["Deprecation"] = f"@{deprecation_timestamp}" diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 9d4bdd1c4f58..3e5f011b579f 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -3,7 +3,6 @@ # # SPDX-License-Identifier: MIT -import functools import itertools import os import os.path as osp @@ -13,15 +12,12 @@ import traceback import zlib from abc import ABCMeta, abstractmethod -from collections import namedtuple -from collections.abc import Iterable from contextlib import suppress from copy import copy from datetime import datetime from pathlib import Path -from tempfile import NamedTemporaryFile from types import SimpleNamespace -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Optional, Union, cast import django_rq from attr.converters import to_bool @@ -32,11 +28,8 @@ from django.db import models as django_models from django.db import transaction from django.db.models.query import Prefetch -from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseGone, HttpResponseNotFound +from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotFound from django.utils import timezone -from django.utils.decorators import method_decorator -from django.views.decorators.cache import never_cache -from django_rq.queues import DjangoRQ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( OpenApiExample, @@ -47,7 +40,6 @@ extend_schema_view, ) from PIL import Image -from redis.exceptions import ConnectionError as RedisConnectionError from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.exceptions import APIException, NotFound, PermissionDenied, ValidationError @@ -56,23 +48,14 @@ from rest_framework.response import Response from rest_framework.settings import api_settings from rq.job import Job as RQJob -from rq.job import JobStatus as RQJobStatus import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager.views # pylint: disable=unused-import -from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer from cvat.apps.engine import backup +from cvat.apps.engine.background import BackupImporter, DatasetImporter, TaskCreator from cvat.apps.engine.cache import CvatChunkTimestampMismatchError, LockError, MediaCache -from cvat.apps.engine.cloud_provider import ( - db_storage_to_storage_instance, - import_resource_from_cloud_storage, -) -from cvat.apps.engine.filters import ( - NonModelJsonLogicFilter, - NonModelOrderingFilter, - NonModelSimpleFilter, -) +from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance from cvat.apps.engine.frame_provider import ( DataWithMeta, FrameQuality, @@ -80,7 +63,6 @@ JobFrameProvider, TaskFrameProvider, ) -from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.media_extractors import get_mime from cvat.apps.engine.mixins import BackupMixin, DatasetMixin, PartialUpdateModelMixin, UploadMixin from cvat.apps.engine.model_utils import bulk_create @@ -96,8 +78,6 @@ Location, Project, RequestAction, - RequestStatus, - RequestSubresource, RequestTarget, StorageChoice, StorageMethodChoice, @@ -113,16 +93,9 @@ 
ProjectPermission, TaskPermission, UserPermission, - get_cloud_storage_for_import_or_export, get_iam_context, ) -from cvat.apps.engine.rq import ( - ImportRQMeta, - RQId, - RQMetaWithFailureInfo, - define_dependent_job, - is_rq_job_owner, -) +from cvat.apps.engine.rq import ImportRequestId, ImportRQMeta, RQMetaWithFailureInfo from cvat.apps.engine.serializers import ( AboutSerializer, AnnotationFileSerializer, @@ -154,8 +127,6 @@ ProjectFileSerializer, ProjectReadSerializer, ProjectWriteSerializer, - RequestSerializer, - RqIdSerializer, RqStatusSerializer, TaskFileSerializer, TaskReadSerializer, @@ -165,22 +136,18 @@ UserSerializer, ) from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import ( - av_scan_paths, - get_rq_lock_by_user, - get_rq_lock_for_job, - import_resource_with_clean_up_after, - parse_exception_message, - process_failed_job, - sendfile, +from cvat.apps.engine.utils import parse_exception_message, sendfile +from cvat.apps.engine.view_utils import ( + get_410_response_for_export_api, + get_410_response_when_checking_process_status, + tus_chunk_action, ) -from cvat.apps.engine.view_utils import tus_chunk_action -from cvat.apps.events.handlers import handle_dataset_import from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.iam.permissions import IsAuthenticatedOrReadPublicResource, PolicyEnforcer +from cvat.apps.redis_handler.serializers import RqIdSerializer from utils.dataset_manifest import ImageManifestManager -from . import models, task +from . import models from .log import ServerLogManager slogger = ServerLogManager(__name__) @@ -191,14 +158,6 @@ _DATA_UPDATED_DATE_HEADER_NAME = 'X-Updated-Date' _RETRY_AFTER_TIMEOUT = 10 -def get_410_response_for_export_api(path: str) -> HttpResponseGone: - return HttpResponseGone(textwrap.dedent(f"""\ - This endpoint is no longer supported. - To initiate the export process, use POST {path}. - To check the process status, use GET /api/requests/rq_id, - where rq_id is obtained from the response of the previous request. - To download the prepared file, use the result_url obtained from the response of the previous request. - """)) @extend_schema(tags=['server']) class ServerViewSet(viewsets.ViewSet): @@ -364,9 +323,6 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, ordering = "-id" lookup_fields = {'owner': 'owner__username', 'assignee': 'assignee__username'} iam_organization_field = 'organization' - IMPORT_RQ_ID_FACTORY = functools.partial(RQId, - RequestAction.IMPORT, RequestTarget.PROJECT, subresource=RequestSubresource.DATASET - ) def get_serializer_class(self): if self.request.method in SAFE_METHODS: @@ -395,50 +351,7 @@ def perform_create(self, serializer, **kwargs): # Required for the extra summary information added in the queryset serializer.instance = self.get_queryset().get(pk=serializer.instance.pk) - @extend_schema(methods=['GET'], summary='Check dataset import status', - description=textwrap.dedent(""" - Utilizing this endpoint to check the status of the process - of importing a project dataset from a file is deprecated. - In addition, this endpoint no longer handles the project dataset export process. 
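
With the deprecated status-checking branches in this view now answered by 410 responses, the only supported way for a client to follow an import is the requests API those messages point to. A rough client-side sketch using the plain requests library; the host, credentials, project id, file name, and polling interval are all assumptions (the CVAT SDK wraps this flow, but the raw calls look roughly like this):

import time

import requests

BASE = "http://localhost:8080/api"   # assumed host
AUTH = ("user", "password")          # assumed credentials
project_id = 1                       # assumed project

# 1. Start the import. The 202 body is described by RqIdSerializer, i.e. {"rq_id": "..."}.
with open("dataset.zip", "rb") as f:  # assumed file
    response = requests.post(
        f"{BASE}/projects/{project_id}/dataset",
        params={"format": "CVAT 1.1"},
        files={"dataset_file": f},
        auth=AUTH,
    )
response.raise_for_status()           # expect 202 Accepted
rq_id = response.json()["rq_id"]

# 2. Poll the common requests API instead of the removed ?action=import_status branch.
while True:
    details = requests.get(f"{BASE}/requests/{rq_id}", auth=AUTH).json()
    if details["status"] in ("finished", "failed"):
        break
    time.sleep(5)
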
- - Consider using new API: - - `POST /api/projects//dataset/export/?save_images=True` to initiate export process - - `GET /api/requests/` to check process status - - `GET result_url` to download a prepared file - - Where: - - `rq_id` can be found in the response on initializing request - - `result_url` can be found in the response on checking status request - """), - parameters=[ - OpenApiParameter('format', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - deprecated=True - ), - OpenApiParameter('filename', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - deprecated=True - ), - OpenApiParameter('action', description='Used to check the import status', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['import_status'], - deprecated=True - ), - OpenApiParameter('location', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list(), - deprecated=True - ), - OpenApiParameter('cloud_storage_id', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - deprecated=True - ), - OpenApiParameter('rq_id', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=True), - ], - deprecated=True, - responses={ - '410': OpenApiResponse(description='API endpoint no longer supports exporting datasets'), - }) + @extend_schema(methods=['GET'], exclude=True) @extend_schema(methods=['POST'], summary='Import a dataset into a project', description=textwrap.dedent(""" @@ -455,17 +368,10 @@ def perform_create(self, serializer, **kwargs): enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False), - OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import annotations', - location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, - default=True, deprecated=True), OpenApiParameter('filename', description='Dataset file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), ], - request=PolymorphicProxySerializer('DatasetWrite', - # TODO: refactor to use required=False when possible - serializers=[DatasetFileSerializer, OpenApiTypes.NONE], - resource_type_field_name=None - ), + request=DatasetFileSerializer(required=False), responses={ '202': OpenApiResponse(RqIdSerializer, description='Importing has been started'), '400': OpenApiResponse(description='Failed to import dataset'), @@ -485,61 +391,13 @@ def dataset(self, request: ExtendedRequest, pk: int): # depends on rq job status (like 201 - finished), # while GET /api/requests/rq_id returns a 200 status code # if such a request exists regardless of job status. 
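
One detail worth keeping in mind while reading the removed block below: it attached an RFC 9745 Deprecation header by hand, and the deprecate_response() helper added to view_utils.py earlier in this diff produces the same value. A minimal sketch of the equivalence, using the date from the removed code:

from datetime import datetime, timezone

# What the removed inline code computed for its response headers:
deprecation_date = datetime(2025, 2, 27, tzinfo=timezone.utc)
header_value = f"@{int(deprecation_date.timestamp())}"
print(header_value)  # "@1740614400" - a Unix timestamp prefixed with "@", per RFC 9745

# The new helper sets the same header on a DRF Response:
#   deprecate_response(response, deprecation_date=deprecation_date)
#   assert response.headers["Deprecation"] == header_value
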
- - deprecation_timestamp = int(datetime(2025, 2, 27, tzinfo=timezone.utc).timestamp()) - response_headers = { - "Deprecation": f"@{deprecation_timestamp}" - } - - queue = django_rq.get_queue(settings.CVAT_QUEUES.IMPORT_DATA.value) - rq_id = request.query_params.get('rq_id') - if not rq_id: - return Response( - 'The rq_id param should be specified in the query parameters', - status=status.HTTP_400_BAD_REQUEST, - headers=response_headers, - ) - - rq_job = queue.fetch_job(rq_id) - - if rq_job is None: - return Response(status=status.HTTP_404_NOT_FOUND, headers=response_headers) - # check that the user has access to the current rq_job - elif not is_rq_job_owner(rq_job, request.user.id): - return Response(status=status.HTTP_403_FORBIDDEN, headers=response_headers) - - if rq_job.is_finished: - rq_job.delete() - return Response(status=status.HTTP_201_CREATED, headers=response_headers) - elif rq_job.is_failed: - exc_info = process_failed_job(rq_job) - - return Response( - data=str(exc_info), - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - headers=response_headers - ) - else: - return Response( - data=self._get_rq_response( - settings.CVAT_QUEUES.IMPORT_DATA.value, - rq_id, - ), - status=status.HTTP_202_ACCEPTED, - headers=response_headers - ) + return get_410_response_when_checking_process_status("import") # we cannot redirect to the new API here since this endpoint used not only to check the status # of exporting process|download a result file, but also to initiate export process return get_410_response_for_export_api("/api/projects/id/dataset/export?save_images=True") - return self.import_annotations( - request=request, - db_obj=self._object, - import_func=_import_project_dataset, - rq_func=dm.project.import_dataset_as_project, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - ) + return self.upload_data(request) @tus_chunk_action(detail=True, suffix_base="dataset") @@ -556,37 +414,13 @@ def get_upload_dir(self): def upload_finished(self, request: ExtendedRequest): if self.action == 'dataset': - format_name = request.query_params.get("format", "") - filename = request.query_params.get("filename", "") - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - tmp_dir = self._object.get_tmp_dirname() - uploaded_file = os.path.join(tmp_dir, filename) - if not os.path.isfile(uploaded_file): - uploaded_file = None - - return _import_project_dataset( - request=request, - filename=uploaded_file, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - rq_func=dm.project.import_dataset_as_project, - db_obj=self._object, - format_name=format_name, - conv_mask_to_poly=conv_mask_to_poly - ) + importer = DatasetImporter(request=request, db_instance=self._object) + return importer.enqueue_job() + elif self.action == 'import_backup': - filename = request.query_params.get("filename", "") - if filename: - tmp_dir = backup.get_backup_dirname() - backup_file = os.path.join(tmp_dir, filename) - if os.path.isfile(backup_file): - return backup.import_project( - request, - settings.CVAT_QUEUES.IMPORT_DATA.value, - filename=backup_file, - ) - return Response(data='No such file were uploaded', - status=status.HTTP_400_BAD_REQUEST) - return backup.import_project(request, settings.CVAT_QUEUES.IMPORT_DATA.value) + importer = BackupImporter(request=request, target=RequestTarget.PROJECT) + return importer.enqueue_job() + return Response(data='Unknown upload was finished', status=status.HTTP_400_BAD_REQUEST) @@ -604,14 +438,14 @@ def export_backup(self, request: ExtendedRequest, pk: int): 
description=textwrap.dedent(""" The backup import process is as follows: - The first request POST /api/projects/backup will initiate file upload and will create - the rq job on the server in which the process of a project creating from an uploaded backup - will be carried out. + The first request POST /api/projects/backup schedules a background job on the server + in which the process of creating a project from the uploaded backup is carried out. + + To check the status of the import process, use GET /api/requests/rq_id, + where rq_id is the request ID obtained from the response to the previous request. - After initiating the backup upload, you will receive an rq_id parameter. - Make sure to include this parameter as a query parameter in your subsequent requests - to track the status of the project creation. - Once the project has been successfully created, the server will return the id of the newly created project. + Once the import completes successfully, the response will contain the ID + of the newly created project in the result_id field. """), parameters=[ *ORGANIZATION_OPEN_API_PARAMETERS, @@ -622,27 +456,19 @@ def export_backup(self, request: ExtendedRequest, pk: int): location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False), OpenApiParameter('filename', description='Backup file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), - OpenApiParameter('rq_id', description='rq id', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), ], - request=PolymorphicProxySerializer('BackupWrite', - # TODO: refactor to use required=False when possible - serializers=[ProjectFileSerializer, OpenApiTypes.NONE], - resource_type_field_name=None - ), - # TODO: for some reason the code generated by the openapi generator from schema with different serializers - # contains only one serializer, need to fix that. 
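
To make the new backup-import contract described above concrete: this endpoint now only answers the initial upload, and the created object id is read from the request details rather than from a second call here. An illustrative sketch of the two payloads a client deals with; all values are invented and only the fields named in this diff are shown (the full request representation carries more fields, such as operation, owner, and dates):

# 202 Accepted body of POST /api/projects/backup (shape of RqIdSerializer):
accepted_body = {
    "rq_id": "import:project-<uuid>-backup",  # opaque request id; treat the exact format as internal
}

# Later, GET /api/requests/<rq_id> for a successfully finished import:
finished_request = {
    "status": "finished",
    "result_id": 42,  # id of the newly created project, per the description above
}
new_project_id = finished_request["result_id"]
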
- # https://github.com/OpenAPITools/openapi-generator/issues/6126 - responses={ - # 201: OpenApiResponse(inline_serializer("ImportedProjectIdSerializer", fields={"id": serializers.IntegerField(required=True)}) - '201': OpenApiResponse(description='The project has been imported'), - '202': OpenApiResponse(RqIdSerializer, description='Importing a backup file has been started'), + request=ProjectFileSerializer(required=False), + responses={ + '202': OpenApiResponse(RqIdSerializer, description='Import of the backup file has started'), }) @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=None, parser_classes=_UPLOAD_PARSER_CLASSES) def import_backup(self, request: ExtendedRequest): - return self.import_backup_v1(request, backup.import_project) + if "rq_id" in request.query_params: + return get_410_response_when_checking_process_status("import") + + return self.upload_data(request) @tus_chunk_action(detail=False, suffix_base="backup") def append_backup_chunk(self, request: ExtendedRequest, file_id: str): @@ -938,9 +764,6 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, ordering_fields = list(filter_fields) ordering = "-id" iam_organization_field = 'organization' - IMPORT_RQ_ID_FACTORY = functools.partial(RQId, - RequestAction.IMPORT, RequestTarget.TASK, subresource=RequestSubresource.ANNOTATIONS, - ) def get_serializer_class(self): if self.request.method in SAFE_METHODS: @@ -963,14 +786,14 @@ def get_queryset(self): description=textwrap.dedent(""" The backup import process is as follows: - The first request POST /api/tasks/backup will initiate file upload and will create - the rq job on the server in which the process of a task creating from an uploaded backup - will be carried out. + The first request POST /api/tasks/backup creates a background job on the server + in which the process of a task creating from an uploaded backup is carried out. - After initiating the backup upload, you will receive an rq_id parameter. - Make sure to include this parameter as a query parameter in your subsequent requests - to track the status of the task creation. - Once the task has been successfully created, the server will return the id of the newly created task. + To check the status of the import process, use GET /api/requests/rq_id, + where rq_id is the request ID obtained from the response to the previous request. + + Once the import completes successfully, the response will contain the ID + of the newly created task in the result_id field. """), parameters=[ *ORGANIZATION_OPEN_API_PARAMETERS, @@ -981,24 +804,20 @@ def get_queryset(self): location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False), OpenApiParameter('filename', description='Backup file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), - OpenApiParameter('rq_id', description='rq id', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), ], request=TaskFileSerializer(required=False), - # TODO: for some reason the code generated by the openapi generator from schema with different serializers - # contains only one serializer, need to fix that. 
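
The same guard appears in the task backup endpoint below: a POST that still carries the old rq_id query parameter is refused with 410 Gone and pointed at the requests API. A hedged test-style sketch of that behavior, written against the _post_request helper used elsewhere in this diff (the user fixture and the sample rq_id are assumptions):

# Inside a test case deriving from ApiTestBase / ImportApiTestBase:
response = self._post_request(
    "/api/tasks/backup", self.admin,           # assumed user fixture
    query_params={"rq_id": "some-legacy-id"},  # any value triggers the guard
)
self.assertEqual(response.status_code, 410)    # status.HTTP_410_GONE
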
- # https://github.com/OpenAPITools/openapi-generator/issues/6126 responses={ - # 201: OpenApiResponse(inline_serializer("ImportedTaskIdSerializer", fields={"id": serializers.IntegerField(required=True)}) - '201': OpenApiResponse(description='The task has been imported'), - '202': OpenApiResponse(RqIdSerializer, description='Importing a backup file has been started'), + '202': OpenApiResponse(RqIdSerializer, description='Import of the backup file has started'), }) @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=None, parser_classes=_UPLOAD_PARSER_CLASSES) def import_backup(self, request: ExtendedRequest): - return self.import_backup_v1(request, backup.import_task) + if "rq_id" in request.query_params: + return get_410_response_when_checking_process_status("import") + + return self.upload_data(request) @tus_chunk_action(detail=False, suffix_base="backup") def append_backup_chunk(self, request: ExtendedRequest, file_id: str): @@ -1125,23 +944,8 @@ def append_files(self, request): def upload_finished(self, request: ExtendedRequest): @transaction.atomic def _handle_upload_annotations(request: ExtendedRequest): - format_name = request.query_params.get("format", "") - filename = request.query_params.get("filename", "") - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - tmp_dir = self._object.get_tmp_dirname() - annotation_file = os.path.join(tmp_dir, filename) - if os.path.isfile(annotation_file): - return _import_annotations( - request=request, - filename=annotation_file, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - rq_func=dm.task.import_task_annotations, - db_obj=self._object, - format_name=format_name, - conv_mask_to_poly=conv_mask_to_poly, - ) - return Response(data='No such file were uploaded', - status=status.HTTP_400_BAD_REQUEST) + importer = DatasetImporter(request=request, db_instance=self._object) + return importer.enqueue_job() def _handle_upload_data(request: ExtendedRequest): with transaction.atomic(): @@ -1204,27 +1008,13 @@ def _handle_upload_data(request: ExtendedRequest): data['stop_frame'] = None # Need to process task data when the transaction is committed - rq_id = task.create(self._object, data, request) - rq_id_serializer = RqIdSerializer(data={'rq_id': rq_id}) - rq_id_serializer.is_valid(raise_exception=True) - - return Response(rq_id_serializer.data, status=status.HTTP_202_ACCEPTED) + creator = TaskCreator(request=request, db_instance=self._object, db_data=data) + return creator.enqueue_job() @transaction.atomic def _handle_upload_backup(request: ExtendedRequest): - filename = request.query_params.get("filename", "") - if filename: - tmp_dir = backup.get_backup_dirname() - backup_file = os.path.join(tmp_dir, filename) - if os.path.isfile(backup_file): - return backup.import_task( - request, - settings.CVAT_QUEUES.IMPORT_DATA.value, - filename=backup_file, - ) - return Response(data='No such file were uploaded', - status=status.HTTP_400_BAD_REQUEST) - return backup.import_task(request, settings.CVAT_QUEUES.IMPORT_DATA.value) + importer = BackupImporter(request=request, target=RequestTarget.TASK) + return importer.enqueue_job() if self.action == 'annotations': return _handle_upload_annotations(request) @@ -1387,80 +1177,12 @@ def append_data_chunk(self, request: ExtendedRequest, pk: int, file_id: str): return self.append_tus_chunk(request, file_id) @extend_schema(methods=['GET'], summary='Get task annotations', - description=textwrap.dedent("""\ - Deprecation warning: - - Utilizing this endpoint 
to export annotations as a dataset in - a specific format is no longer possible. - - Consider using new API: - - `POST /api/tasks//dataset/export?save_images=False` to initiate export process - - `GET /api/requests/` to check process status, - where `rq_id` is request id returned on initializing request - - `GET result_url` to download a prepared file, - where `result_url` can be found in the response on checking status request - """), - parameters=[ - # FUTURE-TODO: the following parameters should be removed after a few releases - OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description="This parameter is no longer supported", - deprecated=True - ), - OpenApiParameter('filename', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - deprecated=True - ), - OpenApiParameter('action', location=OpenApiParameter.QUERY, - description='This parameter is no longer supported', - type=OpenApiTypes.STR, required=False, enum=['download'], - deprecated=True - ), - OpenApiParameter('location', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list(), - deprecated=True - ), - OpenApiParameter('cloud_storage_id', description='This parameter is no longer supported', - location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - deprecated=True - ), - ], responses={ '200': OpenApiResponse(LabeledDataSerializer), '400': OpenApiResponse(description="Exporting without data is not allowed"), '410': OpenApiResponse(description="API endpoint no longer handles exporting process"), }) - @extend_schema(methods=['PUT'], summary='Replace task annotations / Get annotation import status', - description=textwrap.dedent(""" - Utilizing this endpoint to check status of the import process is deprecated - in favor of the new requests API: - GET /api/requests/, where `rq_id` parameter is returned in the response - on initializing request. 
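
With the import branch removed from PUT in the schema below, the verbs become unambiguous: POST (plus the TUS upload flow) imports an annotation file and returns an rq_id, while PUT only replaces annotations from a JSON body validated by LabeledDataSerializer, and any leftover format/rq_id query parameter gets 410 Gone. A rough sketch of the new PUT usage; the host, credentials, task id, and the minimal payload keys are assumptions based on the serializer's usual shape rather than taken from this diff:

import requests

BASE = "http://localhost:8080/api"   # assumed host
AUTH = ("user", "password")          # assumed credentials
task_id = 1                          # assumed task

# Note: adding ?format=... or ?rq_id=... to this URL now yields 410 Gone instead of
# resuming an import, so the JSON body below is the only thing PUT accepts.
payload = {
    "tags": [],    # assumed LabeledData-like keys
    "shapes": [],
    "tracks": [],
}
response = requests.put(f"{BASE}/tasks/{task_id}/annotations", json=payload, auth=AUTH)
response.raise_for_status()  # expect 200, "Annotations have been replaced"
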
- """), - parameters=[ - # deprecated parameters - OpenApiParameter( - 'format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats', - deprecated=True, - ), - OpenApiParameter( - 'rq_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description='rq id', - deprecated=True, - ), - ], - request=PolymorphicProxySerializer('TaskAnnotationsUpdate', - # TODO: refactor to use required=False when possible - serializers=[LabeledDataSerializer, AnnotationFileSerializer, OpenApiTypes.NONE], - resource_type_field_name=None - ), - responses={ - '201': OpenApiResponse(description='Import has finished'), - '202': OpenApiResponse(description='Import is in progress'), - '405': OpenApiResponse(description='Format is not available'), - }) @extend_schema(methods=['POST'], summary="Import annotations into a task", description=textwrap.dedent(""" @@ -1482,16 +1204,17 @@ def append_data_chunk(self, request: ExtendedRequest, pk: int, file_id: str): OpenApiParameter('filename', description='Annotation file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), ], - request=PolymorphicProxySerializer('TaskAnnotationsWrite', - # TODO: refactor to use required=False when possible - serializers=[AnnotationFileSerializer, OpenApiTypes.NONE], - resource_type_field_name=None - ), + request=AnnotationFileSerializer(required=False), responses={ '201': OpenApiResponse(description='Uploading has finished'), '202': OpenApiResponse(RqIdSerializer, description='Uploading has been started'), '405': OpenApiResponse(description='Format is not available'), }) + @extend_schema(methods=['PUT'], summary='Replace task annotations', + request=LabeledDataSerializer, + responses={ + '200': OpenApiResponse(description='Annotations have been replaced'), + }) @extend_schema(methods=['PATCH'], summary='Update task annotations', parameters=[ OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True, @@ -1523,38 +1246,17 @@ def annotations(self, request: ExtendedRequest, pk: int): return Response(data) elif request.method == 'POST' or request.method == 'OPTIONS': - # NOTE: initialization process of annotations import - format_name = request.query_params.get('format', '') - return self.import_annotations( - request=request, - db_obj=self._object, - import_func=_import_annotations, - rq_func=dm.task.import_task_annotations, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - ) + return self.upload_data(request) + elif request.method == 'PUT': - format_name = request.query_params.get('format', '') - # deprecated logic, will be removed in one of the next releases - if format_name: - # NOTE: continue process of import annotations - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - location_conf = get_location_configuration( - db_instance=self._object, query_params=request.query_params, field_name=StorageType.SOURCE - ) - return _import_annotations( - request=request, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - rq_func=dm.task.import_task_annotations, - db_obj=self._object, - format_name=format_name, - location_conf=location_conf, - conv_mask_to_poly=conv_mask_to_poly - ) - else: - serializer = LabeledDataSerializer(data=request.data) - if serializer.is_valid(raise_exception=True): - data = dm.task.put_task_data(pk, serializer.validated_data) - return Response(data) + if {"format", "rq_id"} & set(request.query_params.keys()): + 
return get_410_response_when_checking_process_status("import") + + serializer = LabeledDataSerializer(data=request.data) + if serializer.is_valid(raise_exception=True): + data = dm.task.put_task_data(pk, serializer.validated_data) + return Response(data) + elif request.method == 'DELETE': dm.task.delete_task_data(pk) return Response(status=status.HTTP_204_NO_CONTENT) @@ -1592,7 +1294,11 @@ def status(self, request, pk): task = self.get_object() # force call of check_object_permissions() response = self._get_rq_response( queue=settings.CVAT_QUEUES.IMPORT_DATA.value, - job_id=RQId(RequestAction.CREATE, RequestTarget.TASK, task.id).render() + job_id=ImportRequestId( + action=RequestAction.CREATE, + target=RequestTarget.TASK, + target_id=task.id + ).render() ) serializer = RqStatusSerializer(data=response) @@ -1604,7 +1310,6 @@ def status(self, request, pk): def _get_rq_response(queue, job_id): queue = django_rq.get_queue(queue) job = queue.fetch_job(job_id) - rq_job_meta = ImportRQMeta.for_job(job) response = {} if job is None or job.is_finished: response = { "state": "Finished" } @@ -1617,6 +1322,7 @@ def _get_rq_response(queue, job_id): # https://github.com/cvat-ai/cvat/issues/5215 response = { "state": "Failed", "message": parse_exception_message(job.exc_info or "Unknown error") } else: + rq_job_meta = ImportRQMeta.for_job(job) response = { "state": "Started" } if rq_job_meta.status: response['message'] = rq_job_meta.status @@ -1849,9 +1555,6 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, mixins.CreateMo 'project_name': 'segment__task__project__name', 'assignee': 'assignee__username' } - IMPORT_RQ_ID_FACTORY = functools.partial(RQId, - RequestAction.IMPORT, RequestTarget.JOB, subresource=RequestSubresource.ANNOTATIONS - ) def get_queryset(self): queryset = super().get_queryset() @@ -1902,24 +1605,9 @@ def get_upload_dir(self): # UploadMixin method def upload_finished(self, request: ExtendedRequest): if self.action == 'annotations': - format_name = request.query_params.get("format", "") - filename = request.query_params.get("filename", "") - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - tmp_dir = self.get_upload_dir() - annotation_file = os.path.join(tmp_dir, filename) - if os.path.isfile(annotation_file): - return _import_annotations( - request=request, - filename=annotation_file, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - rq_func=dm.task.import_job_annotations, - db_obj=self._object, - format_name=format_name, - conv_mask_to_poly=conv_mask_to_poly, - ) - else: - return Response(data='No such file were uploaded', - status=status.HTTP_400_BAD_REQUEST) + importer = DatasetImporter(request=request, db_instance=self._object) + return importer.enqueue_job() + return Response(data='Unknown upload was finished', status=status.HTTP_400_BAD_REQUEST) @@ -1994,47 +1682,12 @@ def upload_finished(self, request: ExtendedRequest): '202': OpenApiResponse(RqIdSerializer, description='Uploading has been started'), '405': OpenApiResponse(description='Format is not available'), }) - @extend_schema(methods=['PUT'], - summary='Replace job annotations / Get annotation import status', - description=textwrap.dedent(""" - Utilizing this endpoint to check status of the import process is deprecated - in favor of the new requests API: - GET /api/requests/, where `rq_id` parameter is returned in the response - on initializing request. 
- """), - parameters=[ - - OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats', - deprecated=True, - ), - OpenApiParameter('location', description='where to import the annotation from', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list(), - deprecated=True, - ), - OpenApiParameter('cloud_storage_id', description='Storage id', - location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - deprecated=True, - ), - OpenApiParameter('filename', description='Annotation file name', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - deprecated=True, - ), - OpenApiParameter('rq_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description='rq id', - deprecated=True, - ), - ], - request=PolymorphicProxySerializer( - component_name='JobAnnotationsUpdate', - serializers=[LabeledDataSerializer, AnnotationFileSerializer(required=False)], - resource_type_field_name=None - ), - responses={ - '201': OpenApiResponse(description='Import has finished'), - '202': OpenApiResponse(description='Import is in progress'), - '405': OpenApiResponse(description='Format is not available'), + @extend_schema( + methods=['PUT'], + summary='Replace job annotations', + request=LabeledDataSerializer, + responses={ + '200': OpenApiResponse(description='Annotations have been replaced'), }) @extend_schema(methods=['PATCH'], summary='Update job annotations', parameters=[ @@ -2065,40 +1718,19 @@ def annotations(self, request: ExtendedRequest, pk: int): return Response(annotations) elif request.method == 'POST' or request.method == 'OPTIONS': - format_name = request.query_params.get('format', '') - return self.import_annotations( - request=request, - db_obj=self._object, - import_func=_import_annotations, - rq_func=dm.task.import_job_annotations, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - ) + return self.upload_data(request) elif request.method == 'PUT': - format_name = request.query_params.get('format', '') - if format_name: - # deprecated logic, will be removed in one of the next releases - conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True)) - location_conf = get_location_configuration( - db_instance=self._object, query_params=request.query_params, field_name=StorageType.SOURCE - ) - return _import_annotations( - request=request, - rq_id_factory=self.IMPORT_RQ_ID_FACTORY, - rq_func=dm.task.import_job_annotations, - db_obj=self._object, - format_name=format_name, - location_conf=location_conf, - conv_mask_to_poly=conv_mask_to_poly - ) - else: - serializer = LabeledDataSerializer(data=request.data) - if serializer.is_valid(raise_exception=True): - try: - data = dm.task.put_job_data(pk, serializer.validated_data) - except (AttributeError, IntegrityError) as e: - return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST) - return Response(data) + if {"format", "rq_id"} & set(request.query_params.keys()): + return get_410_response_when_checking_process_status("import") + + serializer = LabeledDataSerializer(data=request.data) + if serializer.is_valid(raise_exception=True): + try: + data = dm.task.put_job_data(pk, serializer.validated_data) + except (AttributeError, IntegrityError) as e: + return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST) + return Response(data) elif request.method == 'DELETE': dm.task.delete_job_data(pk) 
return Response(status=status.HTTP_204_NO_CONTENT) @@ -3208,450 +2840,3 @@ def rq_exception_handler(rq_job: RQJob, exc_type: type[Exception], exc_value: Ex rq_job_meta.save() return True - -def _import_annotations( - request: ExtendedRequest, - rq_id_factory: Callable[..., RQId], - rq_func: Callable[..., None], - db_obj: Task | Job, - format_name: str, - filename: str = None, - location_conf: dict[str, Any] | None = None, - conv_mask_to_poly: bool = True, -): - - format_desc = {f.DISPLAY_NAME: f - for f in dm.views.get_import_formats()}.get(format_name) - if format_desc is None: - raise serializers.ValidationError( - "Unknown input format '{}'".format(format_name)) - elif not format_desc.ENABLED: - return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) - - rq_id = request.query_params.get('rq_id') - rq_id_should_be_checked = bool(rq_id) - if not rq_id: - rq_id = rq_id_factory(db_obj.pk).render() - - queue = django_rq.get_queue(settings.CVAT_QUEUES.IMPORT_DATA.value) - - # ensure that there is no race condition when processing parallel requests - with get_rq_lock_for_job(queue, rq_id): - rq_job = queue.fetch_job(rq_id) - - if rq_job: - if rq_id_should_be_checked and not is_rq_job_owner(rq_job, request.user.id): - return Response(status=status.HTTP_403_FORBIDDEN) - - if request.method == 'POST': - if rq_job.get_status(refresh=False) not in (RQJobStatus.FINISHED, RQJobStatus.FAILED): - return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists') - - rq_job.delete() - rq_job = None - - if not rq_job: - # If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem - # Then we dont need to create temporary file - # Or filename specify key in cloud storage so we need to download file - location = location_conf.get('location') if location_conf else Location.LOCAL - db_storage = None - - if not filename or location == Location.CLOUD_STORAGE: - if location != Location.CLOUD_STORAGE: - serializer = AnnotationFileSerializer(data=request.data) - if serializer.is_valid(raise_exception=True): - anno_file = serializer.validated_data['annotation_file'] - with NamedTemporaryFile( - prefix='cvat_{}'.format(db_obj.pk), - dir=settings.TMP_FILES_ROOT, - delete=False) as tf: - filename = tf.name - for chunk in anno_file.chunks(): - tf.write(chunk) - else: - assert filename, 'The filename was not specified' - - try: - storage_id = location_conf['storage_id'] - except KeyError: - raise serializers.ValidationError( - 'Cloud storage location was selected as the source,' - ' but cloud storage id was not specified') - db_storage = get_cloud_storage_for_import_or_export( - storage_id=storage_id, request=request, - is_default=location_conf['is_default']) - - key = filename - with NamedTemporaryFile( - prefix='cvat_{}'.format(db_obj.pk), - dir=settings.TMP_FILES_ROOT, - delete=False) as tf: - filename = tf.name - - func = import_resource_with_clean_up_after - func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly) - - if location == Location.CLOUD_STORAGE: - func_args = (db_storage, key, func) + func_args - func = import_resource_from_cloud_storage - - av_scan_paths(filename) - user_id = request.user.id - - with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) - queue.enqueue_call( - func=func, - args=func_args, - job_id=rq_id, - depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), - meta=meta, - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - 
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) - - # log events after releasing Redis lock - if not rq_job: - handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None) - - serializer = RqIdSerializer(data={'rq_id': rq_id}) - serializer.is_valid(raise_exception=True) - - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - - # Deprecated logic, /api/requests API should be used instead - # https://greenbytes.de/tech/webdav/draft-ietf-httpapi-deprecation-header-latest.html#the-deprecation-http-response-header-field - deprecation_timestamp = int(datetime(2025, 2, 14, tzinfo=timezone.utc).timestamp()) - response_headers = { - "Deprecation": f"@{deprecation_timestamp}" - } - - rq_job_status = rq_job.get_status(refresh=False) - if RQJobStatus.FINISHED == rq_job_status: - rq_job.delete() - return Response(status=status.HTTP_201_CREATED, headers=response_headers) - elif RQJobStatus.FAILED == rq_job_status: - exc_info = process_failed_job(rq_job) - - import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:' - if exc_info.startswith("Traceback") and import_error_prefix in exc_info: - exc_message = exc_info.split(import_error_prefix)[-1].strip() - return Response(data=exc_message, status=status.HTTP_400_BAD_REQUEST, headers=response_headers) - else: - return Response(data=exc_info, - status=status.HTTP_500_INTERNAL_SERVER_ERROR, headers=response_headers) - - return Response(status=status.HTTP_202_ACCEPTED, headers=response_headers) - -def _import_project_dataset( - request: ExtendedRequest, - rq_id_factory: Callable[..., RQId], - rq_func: Callable[..., None], - db_obj: Project, - format_name: str, - filename: str | None = None, - conv_mask_to_poly: bool = True, - location_conf: dict[str, Any] | None = None -): - format_desc = {f.DISPLAY_NAME: f - for f in dm.views.get_import_formats()}.get(format_name) - if format_desc is None: - raise serializers.ValidationError( - "Unknown input format '{}'".format(format_name)) - elif not format_desc.ENABLED: - return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) - - rq_id = rq_id_factory(db_obj.pk).render() - - queue: DjangoRQ = django_rq.get_queue(settings.CVAT_QUEUES.IMPORT_DATA.value) - - # ensure that there is no race condition when processing parallel requests - with get_rq_lock_for_job(queue, rq_id): - rq_job = queue.fetch_job(rq_id) - - if rq_job: - rq_job_status = rq_job.get_status(refresh=False) - if rq_job_status not in (RQJobStatus.FINISHED, RQJobStatus.FAILED): - return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists') - - # for some reason the previous job has not been deleted - # (e.g the user closed the browser tab when job has been created - # but no one requests for checking status were not made) - rq_job.delete() - rq_job = None - - location = location_conf.get('location') if location_conf else None - db_storage = None - - if not filename and location != Location.CLOUD_STORAGE: - serializer = DatasetFileSerializer(data=request.data) - if serializer.is_valid(raise_exception=True): - dataset_file = serializer.validated_data['dataset_file'] - with NamedTemporaryFile( - prefix='cvat_{}'.format(db_obj.pk), - dir=settings.TMP_FILES_ROOT, - delete=False) as tf: - filename = tf.name - for chunk in dataset_file.chunks(): - tf.write(chunk) - - elif location == Location.CLOUD_STORAGE: - assert filename, 'The filename was not specified' - try: - storage_id = location_conf['storage_id'] - except KeyError: - raise 
serializers.ValidationError( - 'Cloud storage location was selected as the source,' - ' but cloud storage id was not specified') - db_storage = get_cloud_storage_for_import_or_export( - storage_id=storage_id, request=request, - is_default=location_conf['is_default']) - - key = filename - with NamedTemporaryFile( - prefix='cvat_{}'.format(db_obj.pk), - dir=settings.TMP_FILES_ROOT, - delete=False) as tf: - filename = tf.name - - func = import_resource_with_clean_up_after - func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly) - - if location == Location.CLOUD_STORAGE: - func_args = (db_storage, key, func) + func_args - func = import_resource_from_cloud_storage - - user_id = request.user.id - - with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) - queue.enqueue_call( - func=func, - args=func_args, - job_id=rq_id, - meta=meta, - depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) - - - handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None) - - serializer = RqIdSerializer(data={'rq_id': rq_id}) - serializer.is_valid(raise_exception=True) - - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - -@extend_schema(tags=['requests']) -@extend_schema_view( - list=extend_schema( - summary='List requests', - responses={ - '200': RequestSerializer(many=True), - } - ), - retrieve=extend_schema( - summary='Get request details', - responses={ - '200': RequestSerializer, - } - ), -) -class RequestViewSet(viewsets.GenericViewSet): - # FUTURE-TODO: support re-enqueue action - SUPPORTED_QUEUES = ( - settings.CVAT_QUEUES.IMPORT_DATA.value, - settings.CVAT_QUEUES.EXPORT_DATA.value, - ) - - serializer_class = RequestSerializer - iam_organization_field = None - filter_backends = [ - NonModelSimpleFilter, - NonModelJsonLogicFilter, - NonModelOrderingFilter, - ] - - ordering_fields = ['created_date', 'status', 'action'] - ordering = '-created_date' - - filter_fields = [ - # RQ job fields - 'status', - # derivatives fields (from meta) - 'project_id', - 'task_id', - 'job_id', - # derivatives fields (from parsed rq_id) - 'action', - 'target', - 'subresource', - 'format', - ] - - simple_filters = filter_fields + ['org'] - - lookup_fields = { - 'created_date': 'created_at', - 'action': 'parsed_rq_id.action', - 'target': 'parsed_rq_id.target', - 'subresource': 'parsed_rq_id.subresource', - 'format': 'parsed_rq_id.format', - 'status': 'get_status', - 'project_id': 'meta.project_id', - 'task_id': 'meta.task_id', - 'job_id': 'meta.job_id', - 'org': 'meta.org_slug', - } - - SchemaField = namedtuple('SchemaField', ['type', 'choices'], defaults=(None,)) - - simple_filters_schema = { - 'status': SchemaField('string', RequestStatus.choices), - 'project_id': SchemaField('integer'), - 'task_id': SchemaField('integer'), - 'job_id': SchemaField('integer'), - 'action': SchemaField('string', RequestAction.choices), - 'target': SchemaField('string', RequestTarget.choices), - 'subresource': SchemaField('string', RequestSubresource.choices), - 'format': SchemaField('string'), - 'org': SchemaField('string'), - } - - def get_queryset(self): - return None - - @property - def queues(self) -> Iterable[DjangoRQ]: - return (django_rq.get_queue(queue_name) for queue_name in self.SUPPORTED_QUEUES) - - def _get_rq_jobs_from_queue(self, queue: 
DjangoRQ, user_id: int) -> list[RQJob]: - job_ids = set(queue.get_job_ids() + - queue.started_job_registry.get_job_ids() + - queue.finished_job_registry.get_job_ids() + - queue.failed_job_registry.get_job_ids() + - queue.deferred_job_registry.get_job_ids() - ) - jobs = [] - for job in queue.job_class.fetch_many(job_ids, queue.connection): - if job and is_rq_job_owner(job, user_id): - try: - parsed_rq_id = RQId.parse(job.id) - except Exception: # nosec B112 - continue - job.parsed_rq_id = parsed_rq_id - jobs.append(job) - - return jobs - - - def _get_rq_jobs(self, user_id: int) -> list[RQJob]: - """ - Get all RQ jobs for a specific user and return them as a list of RQJob objects. - - Parameters: - user_id (int): The ID of the user for whom to retrieve jobs. - - Returns: - List[RQJob]: A list of RQJob objects representing all jobs for the specified user. - """ - all_jobs = [] - for queue in self.queues: - jobs = self._get_rq_jobs_from_queue(queue, user_id) - all_jobs.extend(jobs) - - return all_jobs - - def _get_rq_job_by_id(self, rq_id: str) -> Optional[RQJob]: - """ - Get a RQJob by its ID from the queues. - - Args: - rq_id (str): The ID of the RQJob to retrieve. - - Returns: - Optional[RQJob]: The retrieved RQJob, or None if not found. - """ - try: - parsed_rq_id = RQId.parse(rq_id) - except Exception: - return None - - job: Optional[RQJob] = None - - for queue in self.queues: - job = queue.fetch_job(rq_id) - if job: - job.parsed_rq_id = parsed_rq_id - break - - return job - - def _handle_redis_exceptions(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except RedisConnectionError as ex: - msg = 'Redis service is not available' - slogger.glob.exception(f'{msg}: {str(ex)}') - return Response(msg, status=status.HTTP_503_SERVICE_UNAVAILABLE) - return wrapper - - @method_decorator(never_cache) - @_handle_redis_exceptions - def retrieve(self, request: ExtendedRequest, pk: str): - job = self._get_rq_job_by_id(pk) - - if not job: - return HttpResponseNotFound("There is no request with specified id") - - self.check_object_permissions(request, job) - - serializer = self.get_serializer(job, context={'request': request}) - return Response(data=serializer.data, status=status.HTTP_200_OK) - - @method_decorator(never_cache) - @_handle_redis_exceptions - def list(self, request: ExtendedRequest): - user_id = request.user.id - user_jobs = self._get_rq_jobs(user_id) - - filtered_jobs = self.filter_queryset(user_jobs) - - page = self.paginate_queryset(filtered_jobs) - if page is not None: - serializer = self.get_serializer(page, many=True, context={'request': request}) - return self.get_paginated_response(serializer.data) - - serializer = self.get_serializer(filtered_jobs, many=True, context={'request': request}) - return Response(data=serializer.data, status=status.HTTP_200_OK) - - @extend_schema( - summary='Cancel request', - request=None, - responses={ - '200': OpenApiResponse(description='The request has been cancelled'), - }, - ) - @method_decorator(never_cache) - @action(detail=True, methods=['POST'], url_path='cancel') - @_handle_redis_exceptions - def cancel(self, request: ExtendedRequest, pk: str): - rq_job = self._get_rq_job_by_id(pk) - - if not rq_job: - return HttpResponseNotFound("There is no request with specified id") - - self.check_object_permissions(request, rq_job) - - if rq_job.get_status(refresh=False) not in {RQJobStatus.QUEUED, RQJobStatus.DEFERRED}: - return HttpResponseBadRequest("Only requests that have not yet been started can 
be cancelled") - - # FUTURE-TODO: race condition is possible here - rq_job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) - rq_job.delete() - - return Response(status=status.HTTP_200_OK) diff --git a/cvat/apps/events/export.py b/cvat/apps/events/export.py index 382949640270..9ac4d383476d 100644 --- a/cvat/apps/events/export.py +++ b/cvat/apps/events/export.py @@ -5,27 +5,34 @@ import csv import os import uuid -from datetime import datetime, timedelta, timezone -from logging import Logger +from datetime import datetime, timedelta import clickhouse_connect -import django_rq from dateutil import parser from django.conf import settings +from django.utils import timezone from rest_framework import serializers, status from rest_framework.response import Response +from rest_framework.reverse import reverse +from cvat.apps.dataset_manager.util import ExportCacheManager from cvat.apps.dataset_manager.views import log_exception from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.rq import BaseRQMeta, RQMetaWithFailureInfo +from cvat.apps.engine.models import RequestAction +from cvat.apps.engine.rq import ExportRequestId, RQMetaWithFailureInfo +from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import sendfile +from cvat.apps.engine.view_utils import deprecate_response +from cvat.apps.events.permissions import EventsPermission +from cvat.apps.redis_handler.background import AbstractExporter slogger = ServerLogManager(__name__) DEFAULT_CACHE_TTL = timedelta(hours=1) +TARGET = "events" -def _create_csv(query_params, output_filename, cache_ttl): +def _create_csv(query_params: dict, output_filename: str): try: clickhouse_settings = settings.CLICKHOUSE["events"] @@ -70,74 +77,120 @@ def _create_csv(query_params, output_filename, cache_ttl): writer.writerow(result.column_names) writer.writerows(result.result_rows) - archive_ctime = os.path.getctime(output_filename) - scheduler = django_rq.get_scheduler(settings.CVAT_QUEUES.EXPORT_DATA.value) - cleaning_job = scheduler.enqueue_in( - time_delta=cache_ttl, - func=_clear_export_cache, - file_path=output_filename, - file_ctime=archive_ctime, - logger=slogger.glob, - ) - slogger.glob.info( - f"The {output_filename} is created " - f"and available for downloading for the next {cache_ttl}. 
" - f"Export cache cleaning job is enqueued, id '{cleaning_job.id}'" - ) return output_filename except Exception: log_exception(slogger.glob) raise -def export(request, filter_query, queue_name): - action = request.query_params.get("action", None) - filename = request.query_params.get("filename", None) +class EventsExporter(AbstractExporter): - query_params = { - "org_id": filter_query.get("org_id", None), - "project_id": filter_query.get("project_id", None), - "task_id": filter_query.get("task_id", None), - "job_id": filter_query.get("job_id", None), - "user_id": filter_query.get("user_id", None), - "from": filter_query.get("from", None), - "to": filter_query.get("to", None), - } + def __init__( + self, + *, + request: ExtendedRequest, + ) -> None: + super().__init__(request=request) - try: - if query_params["from"]: - query_params["from"] = parser.parse(query_params["from"]).timestamp() - except parser.ParserError: - raise serializers.ValidationError( - f"Cannot parse 'from' datetime parameter: {query_params['from']}" + # temporary arg + if query_id := self.request.query_params.get("query_id"): + self.query_id = uuid.UUID(query_id) + else: + self.query_id = uuid.uuid4() + + def build_request_id(self): + return ExportRequestId( + target=TARGET, + id=self.query_id, + user_id=self.user_id, + ).render() + + def validate_request_id(self, request_id, /) -> None: + parsed_request_id: ExportRequestId = ExportRequestId.parse_and_validate_queue( + request_id, + expected_queue=self.QUEUE_NAME, # try_legacy_format is not set here since deprecated API accepts query_id, not the whole Request ID ) - try: - if query_params["to"]: - query_params["to"] = parser.parse(query_params["to"]).timestamp() - except parser.ParserError: - raise serializers.ValidationError( - f"Cannot parse 'to' datetime parameter: {query_params['to']}" + + if parsed_request_id.action != RequestAction.EXPORT or parsed_request_id.target != TARGET: + raise ValueError("The provided request id does not match exported target") + + def init_request_args(self): + super().init_request_args() + perm = EventsPermission.create_scope_list(self.request) + self.filter_query = perm.filter(self.request.query_params) + + def _init_callback_with_params(self): + self.callback = _create_csv + + query_params = { + "org_id": self.filter_query.get("org_id", None), + "project_id": self.filter_query.get("project_id", None), + "task_id": self.filter_query.get("task_id", None), + "job_id": self.filter_query.get("job_id", None), + "user_id": self.filter_query.get("user_id", None), + "from": self.filter_query.get("from", None), + "to": self.filter_query.get("to", None), + } + + try: + if query_params["from"]: + query_params["from"] = parser.parse(query_params["from"]).timestamp() + except parser.ParserError: + raise serializers.ValidationError( + f"Cannot parse 'from' datetime parameter: {query_params['from']}" + ) + try: + if query_params["to"]: + query_params["to"] = parser.parse(query_params["to"]).timestamp() + except parser.ParserError: + raise serializers.ValidationError( + f"Cannot parse 'to' datetime parameter: {query_params['to']}" + ) + + if ( + query_params["from"] + and query_params["to"] + and query_params["from"] > query_params["to"] + ): + raise serializers.ValidationError("'from' must be before than 'to'") + + # Set the default time interval to last 30 days + if not query_params["from"] and not query_params["to"]: + query_params["to"] = datetime.now(timezone.utc) + query_params["from"] = query_params["to"] - timedelta(days=30) + + 
output_filename = ExportCacheManager.make_file_path( + file_type="events", file_id=self.query_id, file_ext="csv" ) + self.callback_args = (query_params, output_filename) + + def get_result_endpoint_url(self) -> str: + return reverse("events-download-file", request=self.request) + + def get_result_filename(self): + if self.export_args.filename: + return self.export_args.filename - if query_params["from"] and query_params["to"] and query_params["from"] > query_params["to"]: - raise serializers.ValidationError("'from' must be before than 'to'") + timestamp = self.get_file_timestamp() + return f"logs_{timestamp}.csv" - # Set the default time interval to last 30 days - if not query_params["from"] and not query_params["to"]: - query_params["to"] = datetime.now(timezone.utc) - query_params["from"] = query_params["to"] - timedelta(days=30) +# FUTURE-TODO: delete deprecated function after several releases +def export(request: ExtendedRequest): + action = request.query_params.get("action") if action not in (None, "download"): raise serializers.ValidationError("Unexpected action specified for the request") - query_id = request.query_params.get("query_id", None) or uuid.uuid4() - rq_id = f"export:csv-logs-{query_id}-by-{request.user}" + filename = request.query_params.get("filename") + manager = EventsExporter(request=request) + request_id = manager.build_request_id() + queue = manager.get_queue() + response_data = { - "query_id": query_id, + "query_id": manager.query_id, } - - queue: django_rq.queues.DjangoRQ = django_rq.get_queue(queue_name) - rq_job = queue.fetch_job(rq_id) + deprecation_date = datetime(2025, 3, 17, tzinfo=timezone.utc) + rq_job = queue.fetch_job(request_id) if rq_job: if rq_job.is_finished: @@ -150,35 +203,34 @@ def export(request, filter_query, queue_name): return sendfile(request, file_path, attachment=True, attachment_filename=filename) else: if os.path.exists(file_path): - return Response(status=status.HTTP_201_CREATED) + response = Response(status=status.HTTP_201_CREATED) + deprecate_response(response, deprecation_date=deprecation_date) + return response + elif rq_job.is_failed: rq_job_meta = RQMetaWithFailureInfo.for_job(rq_job) exc_info = rq_job_meta.formatted_exception or str(rq_job.exc_info) rq_job.delete() - return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + response = Response( + exc_info, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + deprecate_response(response, deprecation_date=deprecation_date) + return response else: - return Response(data=response_data, status=status.HTTP_202_ACCEPTED) - - ttl = DEFAULT_CACHE_TTL.total_seconds() - output_filename = os.path.join(settings.TMP_FILES_ROOT, f"{query_id}.csv") - queue.enqueue_call( - func=_create_csv, - args=(query_params, output_filename, DEFAULT_CACHE_TTL), - job_id=rq_id, - meta=BaseRQMeta.build(request=request, db_obj=None), - result_ttl=ttl, - failure_ttl=ttl, - ) - - return Response(data=response_data, status=status.HTTP_202_ACCEPTED) - - -def _clear_export_cache(file_path: str, file_ctime: float, logger: Logger) -> None: - try: - if os.path.exists(file_path) and os.path.getctime(file_path) == file_ctime: - os.remove(file_path) - - logger.info("Export cache file '{}' successfully removed".format(file_path)) - except Exception: - log_exception(logger) - raise + response = Response( + data=response_data, + status=status.HTTP_202_ACCEPTED, + ) + deprecate_response(response, deprecation_date=deprecation_date) + return response + + manager.init_request_args() + # request validation is 
missed here since exporting to a cloud_storage is disabled + manager._set_default_callback_params() + manager.init_callback_with_params() + manager.setup_new_job(queue, request_id) + + response = Response(data=response_data, status=status.HTTP_202_ACCEPTED) + deprecate_response(response, deprecation_date=deprecation_date) + return response diff --git a/cvat/apps/events/permissions.py b/cvat/apps/events/permissions.py index c5fa706e7f56..e23f35224470 100644 --- a/cvat/apps/events/permissions.py +++ b/cvat/apps/events/permissions.py @@ -3,24 +3,34 @@ # # SPDX-License-Identifier: MIT +from typing import Any + from django.conf import settings from rest_framework.exceptions import PermissionDenied +from cvat.apps.engine.permissions import DownloadExportedExtension +from cvat.apps.engine.types import ExtendedRequest from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum from cvat.utils.http import make_requests_session -class EventsPermission(OpenPolicyAgentPermission): +class EventsPermission(OpenPolicyAgentPermission, DownloadExportedExtension): class Scopes(StrEnum): SEND_EVENTS = "send:events" DUMP_EVENTS = "dump:events" @classmethod - def create(cls, request, view, obj, iam_context): + def create( + cls, request: ExtendedRequest, view, obj: None, iam_context: dict[str, Any] + ) -> list[OpenPolicyAgentPermission]: permissions = [] if view.basename == "events": for scope in cls.get_scopes(request, view, obj): - self = cls.create_base_perm(request, view, scope, iam_context, obj) + scope_params = {} + if DownloadExportedExtension.Scopes.DOWNLOAD_EXPORTED_FILE == scope: + cls.extend_params_with_rq_job_details(request=request, params=scope_params) + + self = cls.create_base_perm(request, view, scope, iam_context, obj, **scope_params) permissions.append(self) return permissions @@ -29,7 +39,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.url = settings.IAM_OPA_DATA_URL + "/events/allow" - def filter(self, query_params): + def filter(self, query_params: dict[str, Any]): url = self.url.replace("/allow", "/filter") with make_requests_session() as session: @@ -47,14 +57,23 @@ def filter(self, query_params): return filter_params @staticmethod - def get_scopes(request, view, obj): + def get_scopes(request: ExtendedRequest, view, obj: None): Scopes = __class__.Scopes return [ { ("create", "POST"): Scopes.SEND_EVENTS, + ("initiate_export", "POST"): Scopes.DUMP_EVENTS, + ("download_file", "GET"): DownloadExportedExtension.Scopes.DOWNLOAD_EXPORTED_FILE, + # deprecated permissions: ("list", "GET"): Scopes.DUMP_EVENTS, }[(view.action, request.method)] ] def get_resource(self): - return None + data = None + + if DownloadExportedExtension.Scopes.DOWNLOAD_EXPORTED_FILE == self.scope: + data = {} + self.extend_resource_with_rq_job_details(data) + + return data diff --git a/cvat/apps/events/rules/events.rego b/cvat/apps/events/rules/events.rego index 58ec43763b2f..dcef46c7c251 100644 --- a/cvat/apps/events/rules/events.rego +++ b/cvat/apps/events/rules/events.rego @@ -6,7 +6,7 @@ import data.utils import data.organizations # input: { -# "scope": <"send:events","dump:events"> or null, +# "scope": <"send:events","dump:events","download:exported_file"> or null, # "auth": { # "user": { # "id": , @@ -22,6 +22,9 @@ import data.organizations # } # } or null, # } +# "resource": { +# "rq_job": { "owner": { "id": } } or null, +# } or null, # } default allow := false @@ -46,6 +49,12 @@ allow if { organizations.has_perm(organizations.WORKER) } +allow if { + input.scope == 
utils.DOWNLOAD_EXPORTED_FILE + input.auth.user.id == input.resource.rq_job.owner.id +} + + filter := [] if { utils.is_admin utils.is_sandbox diff --git a/cvat/apps/events/views.py b/cvat/apps/events/views.py index 9ed25d71e1fd..22de4cfb76eb 100644 --- a/cvat/apps/events/views.py +++ b/cvat/apps/events/views.py @@ -2,22 +2,84 @@ # # SPDX-License-Identifier: MIT -from django.conf import settings from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema -from rest_framework import status, viewsets +from rest_framework import serializers, status, viewsets +from rest_framework.decorators import action from rest_framework.renderers import JSONRenderer from rest_framework.response import Response +from cvat.apps.engine.location import Location from cvat.apps.engine.log import vlogger -from cvat.apps.events.permissions import EventsPermission +from cvat.apps.engine.types import ExtendedRequest +from cvat.apps.events.export import EventsExporter from cvat.apps.events.serializers import ClientEventsSerializer from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS +from cvat.apps.redis_handler.serializers import RqIdSerializer from .const import USER_ACTIVITY_SCOPE from .export import export from .handlers import handle_client_events_push +api_filter_parameters = ( + OpenApiParameter( + "org_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by organization ID", + ), + OpenApiParameter( + "project_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by project ID", + ), + OpenApiParameter( + "task_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by task ID", + ), + OpenApiParameter( + "job_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by job ID", + ), + OpenApiParameter( + "user_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by user ID", + ), + OpenApiParameter( + "from", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.DATETIME, + required=False, + description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.", + ), + OpenApiParameter( + "to", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.DATETIME, + required=False, + description="Filter events before the datetime. 
If no 'from' or 'to' parameters are passed, the last 30 days will be set.", + ), + OpenApiParameter( + "filename", + description="Desired output file name", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + required=False, + ), +) + class EventsViewSet(viewsets.ViewSet): serializer_class = None @@ -51,94 +113,100 @@ def create(self, request): return Response(serializer.validated_data, status=status.HTTP_201_CREATED) + # FUTURE-TODO: remove deprecated API endpoint after several releases @extend_schema( summary="Get an event log", methods=["GET"], description="The log is returned in the CSV format.", parameters=[ + *api_filter_parameters, OpenApiParameter( - "org_id", - location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, - required=False, - description="Filter events by organization ID", - ), - OpenApiParameter( - "project_id", - location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, - required=False, - description="Filter events by project ID", - ), - OpenApiParameter( - "task_id", - location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, - required=False, - description="Filter events by task ID", - ), - OpenApiParameter( - "job_id", - location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, - required=False, - description="Filter events by job ID", - ), - OpenApiParameter( - "user_id", - location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, - required=False, - description="Filter events by user ID", - ), - OpenApiParameter( - "from", + "action", location=OpenApiParameter.QUERY, - type=OpenApiTypes.DATETIME, + description="Used to start downloading process after annotation file had been created", + type=OpenApiTypes.STR, required=False, - description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.", + enum=["download"], ), OpenApiParameter( - "to", + "query_id", location=OpenApiParameter.QUERY, - type=OpenApiTypes.DATETIME, + type=OpenApiTypes.STR, required=False, - description="Filter events before the datetime. 
If no 'from' or 'to' parameters are passed, the last 30 days will be set.",
+                description="ID of the query request to check or download",
             ),
         ],
         responses={
             "200": OpenApiResponse(description="Download of file started"),
             "201": OpenApiResponse(description="CSV log file is ready for downloading"),
             "202": OpenApiResponse(description="Creating a CSV log file has been started"),
         },
+        deprecated=True,
+    )
+    def list(self, request: ExtendedRequest):
+        self.check_permissions(request)
+
+        if (
+            request.query_params.get("cloud_storage_id")
+            or request.query_params.get("location") == Location.CLOUD_STORAGE
+        ):
+            raise serializers.ValidationError(
+                "This endpoint does not support exporting events to cloud storage"
+            )
+
+        return export(request=request)
+
+    @extend_schema(
+        summary="Initiate a process to export events",
+        request=None,
+        parameters=[
+            *api_filter_parameters,
             OpenApiParameter(
-                "filename",
-                description="Desired output file name",
+                "location",
+                description="Where to save the events file",
                 location=OpenApiParameter.QUERY,
                 type=OpenApiTypes.STR,
                 required=False,
+                enum=Location.list(),
             ),
             OpenApiParameter(
-                "action",
+                "cloud_storage_id",
+                description="Storage id",
                 location=OpenApiParameter.QUERY,
-                description="Used to start downloading process after annotation file had been created",
-                type=OpenApiTypes.STR,
+                type=OpenApiTypes.INT,
                 required=False,
-                enum=["download"],
             ),
+        ],
+        responses={
+            "202": OpenApiResponse(RqIdSerializer),
+        },
+    )
+    @action(detail=False, methods=["POST"], url_path="export")
+    def initiate_export(self, request: ExtendedRequest):
+        self.check_permissions(request)
+        exporter = EventsExporter(request=request)
+        return exporter.enqueue_job()
+
+    @extend_schema(
+        summary="Download a prepared file with events",
+        request=None,
+        parameters=[
             OpenApiParameter(
-                "query_id",
+                "rq_id",
+                description="Request ID",
                 location=OpenApiParameter.QUERY,
                 type=OpenApiTypes.STR,
-                required=False,
-                description="ID of query request that need to check or download",
+                required=True,
             ),
         ],
         responses={
             "200": OpenApiResponse(description="Download of file started"),
-            "201": OpenApiResponse(description="CSV log file is ready for downloading"),
-            "202": OpenApiResponse(description="Creating a CSV log file has been started"),
         },
+        exclude=True,  # private API endpoint that should be used only as result_url
     )
-    def list(self, request):
-        perm = EventsPermission.create_scope_list(request)
-        filter_query = perm.filter(request.query_params)
-        return export(
-            request=request,
-            filter_query=filter_query,
-            queue_name=settings.CVAT_QUEUES.EXPORT_DATA.value,
-        )
+    @action(detail=False, methods=["GET"], url_path="download")
+    def download_file(self, request: ExtendedRequest):
+        self.check_permissions(request)
+
+        downloader = EventsExporter(request=request).get_downloader()
+        return downloader.download_file()
diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py
index ff9ad5c9df16..809e907123bb 100644
--- a/cvat/apps/lambda_manager/views.py
+++ b/cvat/apps/lambda_manager/views.py
@@ -46,7 +46,7 @@
     SourceType,
     Task,
 )
-from cvat.apps.engine.rq import RQId, define_dependent_job
+from cvat.apps.engine.rq import RequestId, define_dependent_job
 from cvat.apps.engine.serializers import LabeledDataSerializer
 from cvat.apps.engine.types import ExtendedRequest
 from cvat.apps.engine.utils import get_rq_lock_by_user, get_rq_lock_for_job
@@ -620,7 +620,9 @@ def enqueue(
         job: Optional[int] = None,
     ) -> LambdaJob:
         queue = self._get_queue()
-        rq_id = 
RQId(RequestAction.AUTOANNOTATE, RequestTarget.TASK, task).render() + rq_id = RequestId( + action=RequestAction.AUTOANNOTATE, target=RequestTarget.TASK, target_id=task + ).render() # Ensure that there is no race condition when processing parallel requests. # Enqueuing an RQ job with (queue, user) lock but without (queue, rq_id) lock diff --git a/cvat/apps/quality_control/permissions.py b/cvat/apps/quality_control/permissions.py index 25677b2d0480..c95de30d37c2 100644 --- a/cvat/apps/quality_control/permissions.py +++ b/cvat/apps/quality_control/permissions.py @@ -24,6 +24,7 @@ class Scopes(StrEnum): LIST = "list" CREATE = "create" VIEW = "view" + # FUTURE-TODO: deprecated scope, should be removed when related API is removed VIEW_STATUS = "view:status" @classmethod @@ -61,7 +62,9 @@ def create(cls, request, view, obj, iam_context): permissions.append(TaskPermission.create_scope_view(request, task=obj)) elif scope == Scopes.CREATE: # Note: POST /api/quality/reports is used to initiate report creation and to check the process status + # FUTURE-TODO: delete after several releases rq_id = request.query_params.get("rq_id") + # FUTURE-FIXME: use serializers for validation task_id = request.data.get("task_id") if not (task_id or rq_id): diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index fa2030410dcd..c047e7bc660e 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -10,22 +10,19 @@ from collections.abc import Hashable, Sequence from copy import deepcopy from functools import cached_property, partial -from typing import Any, Callable, Optional, TypeVar, Union, cast +from typing import Any, Callable, ClassVar, Optional, TypeVar, Union, cast import datumaro as dm import datumaro.components.annotations.matcher import datumaro.components.comparator import datumaro.util.annotation_util import datumaro.util.mask_tools -import django_rq import numpy as np -import rq from attrs import asdict, define, fields_dict from datumaro.util import dump_json, parse_json from django.conf import settings from django.db import transaction -from django_rq.queues import DjangoRQ as RqQueue -from rq.job import Job as RqJob +from rest_framework import serializers from scipy.optimize import linear_sum_assignment from cvat.apps.dataset_manager.bindings import ( @@ -45,6 +42,7 @@ Image, Job, JobType, + RequestTarget, ShapeType, StageChoice, StatusChoice, @@ -52,9 +50,6 @@ User, ValidationMode, ) -from cvat.apps.engine.rq import BaseRQMeta, define_dependent_job -from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import get_rq_lock_by_user, get_rq_lock_for_job from cvat.apps.profiler import silk_profile from cvat.apps.quality_control import models from cvat.apps.quality_control.models import ( @@ -62,6 +57,8 @@ AnnotationConflictType, AnnotationType, ) +from cvat.apps.quality_control.rq import QualityRequestId +from cvat.apps.redis_handler.background import AbstractRequestManager class Serializable: @@ -2264,96 +2261,54 @@ def generate_report(self) -> ComparisonReport: ) -class QualityReportUpdateManager: - _QUEUE_CUSTOM_JOB_PREFIX = "quality-check-" - _RQ_CUSTOM_QUALITY_CHECK_JOB_TYPE = "custom_quality_check" - _JOB_RESULT_TTL = 120 +class QualityReportRQJobManager(AbstractRequestManager): + QUEUE_NAME = settings.CVAT_QUEUES.QUALITY_REPORTS.value + SUPPORTED_TARGETS: ClassVar[set[RequestTarget]] = {RequestTarget.TASK} + + @property + def job_result_ttl(self): + return 120 - def 
_get_queue(self) -> RqQueue: - return django_rq.get_queue(settings.CVAT_QUEUES.QUALITY_REPORTS.value) + def get_job_by_id(self, id_, /): + try: + id_ = QualityRequestId.parse_and_validate_queue( + id_, expected_queue=self.QUEUE_NAME, try_legacy_format=True + ).render() + except ValueError: + raise serializers.ValidationError("Provided request ID is invalid") - def _make_custom_quality_check_job_id(self, task_id: int, user_id: int) -> str: - # FUTURE-TODO: it looks like job ID template should not include user_id because: - # 1. There is no need to compute quality reports several times for different users - # 2. Each user (not only rq job owner) that has permission to access a task should - # be able to check the status of the computation process - return f"{self._QUEUE_CUSTOM_JOB_PREFIX}task-{task_id}-user-{user_id}" + return super().get_job_by_id(id_) - class QualityReportsNotAvailable(Exception): - pass + def build_request_id(self): + return QualityRequestId( + target=self.target, + target_id=self.db_instance.pk, + ).render() - def _check_quality_reporting_available(self, task: Task): - if task.dimension != DimensionType.DIM_2D: - raise self.QualityReportsNotAvailable("Quality reports are only supported in 2d tasks") + def validate_request(self): + super().validate_request() - gt_job = task.gt_job + if self.db_instance.dimension != DimensionType.DIM_2D: + raise serializers.ValidationError("Quality reports are only supported in 2d tasks") + + gt_job = self.db_instance.gt_job if gt_job is None or not ( gt_job.stage == StageChoice.ACCEPTANCE and gt_job.state == StatusChoice.COMPLETED ): - raise self.QualityReportsNotAvailable( + raise serializers.ValidationError( "Quality reports require a Ground Truth job in the task " f"at the {StageChoice.ACCEPTANCE} stage " f"and in the {StatusChoice.COMPLETED} state" ) - class JobAlreadyExists(QualityReportsNotAvailable): - def __str__(self): - return "Quality computation job for this task already enqueued" - - def schedule_custom_quality_check_job( - self, request: ExtendedRequest, task: Task, *, user_id: int - ) -> str: - """ - Schedules a quality report computation job, supposed for updates by a request. 
- """ - - self._check_quality_reporting_available(task) - - queue = self._get_queue() - rq_id = self._make_custom_quality_check_job_id(task_id=task.id, user_id=user_id) - - # ensure that there is no race condition when processing parallel requests - with get_rq_lock_for_job(queue, rq_id): - if rq_job := queue.fetch_job(rq_id): - if rq_job.get_status(refresh=False) in ( - rq.job.JobStatus.QUEUED, - rq.job.JobStatus.STARTED, - rq.job.JobStatus.SCHEDULED, - rq.job.JobStatus.DEFERRED, - ): - raise self.JobAlreadyExists() - - rq_job.delete() - - with get_rq_lock_by_user(queue, user_id=user_id): - dependency = define_dependent_job( - queue, user_id=user_id, rq_id=rq_id, should_be_dependent=True - ) - - queue.enqueue( - self._check_task_quality, - task_id=task.id, - job_id=rq_id, - meta=BaseRQMeta.build(request=request, db_obj=task), - result_ttl=self._JOB_RESULT_TTL, - failure_ttl=self._JOB_RESULT_TTL, - depends_on=dependency, - ) - - return rq_id - - def get_quality_check_job(self, rq_id: str) -> Optional[RqJob]: - queue = self._get_queue() - rq_job = queue.fetch_job(rq_id) - - if rq_job and not self.is_custom_quality_check_job(rq_job): - rq_job = None - - return rq_job + def init_callback_with_params(self): + self.callback = QualityReportUpdateManager._check_task_quality + self.callback_kwargs = { + "task_id": self.db_instance.pk, + } - def is_custom_quality_check_job(self, rq_job: RqJob) -> bool: - return isinstance(rq_job.id, str) and rq_job.id.startswith(self._QUEUE_CUSTOM_JOB_PREFIX) +class QualityReportUpdateManager: @classmethod @silk_profile() def _check_task_quality(cls, *, task_id: int) -> int: diff --git a/cvat/apps/quality_control/rq.py b/cvat/apps/quality_control/rq.py new file mode 100644 index 000000000000..9242db551205 --- /dev/null +++ b/cvat/apps/quality_control/rq.py @@ -0,0 +1,27 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import ClassVar + +import attrs + +from cvat.apps.redis_handler.rq import RequestIdWithSubresource + + +@attrs.frozen(kw_only=True, slots=False) +class QualityRequestId(RequestIdWithSubresource): + ACTION_DEFAULT_VALUE: ClassVar[str] = "calculate" + ACTION_ALLOWED_VALUES: ClassVar[tuple[str]] = (ACTION_DEFAULT_VALUE,) + + SUBRESOURCE_DEFAULT_VALUE: ClassVar[str] = "quality" + SUBRESOURCE_ALLOWED_VALUES: ClassVar[tuple[str]] = (SUBRESOURCE_DEFAULT_VALUE,) + + QUEUE_SELECTORS: ClassVar[tuple[tuple[str, str]]] = ( + (ACTION_DEFAULT_VALUE, SUBRESOURCE_DEFAULT_VALUE), + ) + + # will be deleted after several releases + LEGACY_FORMAT_PATTERNS = ( + r"quality-check-(?Ptask)-(?P\d+)-user-(\d+)", # user id is excluded in the new format + ) diff --git a/cvat/apps/quality_control/views.py b/cvat/apps/quality_control/views.py index 239ff6de8144..8f60bf5ee088 100644 --- a/cvat/apps/quality_control/views.py +++ b/cvat/apps/quality_control/views.py @@ -3,9 +3,11 @@ # SPDX-License-Identifier: MIT import textwrap +from datetime import datetime from django.db.models import Q from django.http import HttpResponse +from django.utils import timezone from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( OpenApiParameter, @@ -22,8 +24,9 @@ from cvat.apps.engine.mixins import PartialUpdateModelMixin from cvat.apps.engine.models import Task from cvat.apps.engine.rq import BaseRQMeta -from cvat.apps.engine.serializers import RqIdSerializer +from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import get_server_url +from cvat.apps.engine.view_utils import deprecate_response from 
cvat.apps.quality_control import quality_reports as qc from cvat.apps.quality_control.models import ( AnnotationConflict, @@ -42,6 +45,7 @@ QualityReportSerializer, QualitySettingsSerializer, ) +from cvat.apps.redis_handler.serializers import RqIdSerializer @extend_schema(tags=["quality"]) @@ -224,6 +228,12 @@ def get_queryset(self): @extend_schema( operation_id="quality_create_report", summary="Create a quality report", + description=textwrap.dedent( + """\ + Deprecation warning: Utilizing this endpoint to check the computation status is no longer possible. + Consider using common requests API: GET /api/requests/ + """ + ), parameters=[ OpenApiParameter( CREATE_REPORT_RQ_ID_PARAMETER, @@ -234,6 +244,7 @@ def get_queryset(self): creation status. """ ), + deprecated=True, ) ], request=QualityReportCreateSerializer(required=False), @@ -257,7 +268,7 @@ def get_queryset(self): ), }, ) - def create(self, request, *args, **kwargs): + def create(self, request: ExtendedRequest, *args, **kwargs): self.check_permissions(request) rq_id = request.query_params.get(self.CREATE_REPORT_RQ_ID_PARAMETER, None) @@ -273,22 +284,16 @@ def create(self, request, *args, **kwargs): except Task.DoesNotExist as ex: raise NotFound(f"Task {task_id} does not exist") from ex - try: - rq_id = qc.QualityReportUpdateManager().schedule_custom_quality_check_job( - request=request, task=task, user_id=request.user.id - ) - serializer = RqIdSerializer({"rq_id": rq_id}) - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) - except qc.QualityReportUpdateManager.QualityReportsNotAvailable as ex: - raise ValidationError(str(ex)) + manager = qc.QualityReportRQJobManager(request=request, db_instance=task) + return manager.enqueue_job() else: + deprecation_date = datetime(2025, 3, 17, tzinfo=timezone.utc) serializer = RqIdSerializer(data={"rq_id": rq_id}) serializer.is_valid(raise_exception=True) rq_id = serializer.validated_data["rq_id"] + rq_job = qc.QualityReportRQJobManager(request=request).get_job_by_id(rq_id) - report_manager = qc.QualityReportUpdateManager() - rq_job = report_manager.get_quality_check_job(rq_id) # FUTURE-TODO: move into permissions # and allow not only rq job owner to check the status if ( @@ -300,36 +305,60 @@ def create(self, request, *args, **kwargs): .allow ): # We should not provide job existence information to unauthorized users - raise NotFound("Unknown request id") + response = Response( + "Unknown request id", + status=status.HTTP_404_NOT_FOUND, + ) + deprecate_response(response, deprecation_date=deprecation_date) + return response rq_job_status = rq_job.get_status(refresh=False) if rq_job_status == RqJobStatus.FAILED: message = str(rq_job.exc_info) rq_job.delete() - raise ValidationError(message) + response = Response( + message, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + deprecate_response(response, deprecation_date=deprecation_date) + return response + elif rq_job_status in ( RqJobStatus.QUEUED, RqJobStatus.STARTED, RqJobStatus.SCHEDULED, RqJobStatus.DEFERRED, ): - return Response(serializer.data, status=status.HTTP_202_ACCEPTED) + response = Response( + serializer.data, + status=status.HTTP_202_ACCEPTED, + ) + deprecate_response(response, deprecation_date=deprecation_date) + return response + elif rq_job_status == RqJobStatus.FINISHED: return_value = rq_job.return_value() rq_job.delete() if not return_value: - raise ValidationError("No report has been computed") + response = Response( + "No report has been computed", + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + 
deprecate_response(response, deprecation_date=deprecation_date) + return response report = self.get_queryset().get(pk=return_value) report_serializer = QualityReportSerializer( instance=report, context={"request": request} ) - return Response( + response = Response( data=report_serializer.data, status=status.HTTP_201_CREATED, headers=self.get_success_headers(report_serializer.data), ) + deprecate_response(response, deprecation_date=deprecation_date) + return response raise AssertionError(f"Unexpected rq job '{rq_id}' status '{rq_job_status}'") diff --git a/cvat/apps/redis_handler/apps.py b/cvat/apps/redis_handler/apps.py index a00543165e7f..f6edc97d9311 100644 --- a/cvat/apps/redis_handler/apps.py +++ b/cvat/apps/redis_handler/apps.py @@ -3,8 +3,76 @@ # SPDX-License-Identifier: MIT +from contextlib import suppress +from typing import cast + from django.apps import AppConfig +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured +from django.utils.module_loading import import_string + + +class LayeredKeyDict(dict): + def __getitem__(self, key: str | tuple) -> str: + if isinstance(key, tuple) and (len(key) == 3): # action, target, subresource + with suppress(KeyError): + return self.__getitem__(key[0]) + return self.__getitem__((key[0], key[2])) # (action, subresource) + return super().__getitem__(key) + + +SELECTOR_TO_QUEUE = LayeredKeyDict() +QUEUE_TO_PARSED_JOB_ID_CLS = {} + +REQUEST_ID_SUBCLASSES = set() + + +def initialize_mappings(): + from cvat.apps.redis_handler.rq import RequestId + + def init_subclasses(cur_cls: type[RequestId] = RequestId): + for subclass in cur_cls.__subclasses__(): + REQUEST_ID_SUBCLASSES.add(subclass) + init_subclasses(subclass) + + for queue_name, queue_conf in settings.RQ_QUEUES.items(): + if path_to_parsed_job_id_cls := queue_conf.get("PARSED_JOB_ID_CLASS"): + parsed_job_id_cls = import_string(path_to_parsed_job_id_cls) + + if not issubclass(parsed_job_id_cls, RequestId): + raise ImproperlyConfigured( + f"The {path_to_parsed_job_id_cls!r} must be inherited from the RequestId class" + ) + + for queue_selector in parsed_job_id_cls.QUEUE_SELECTORS: + if not isinstance(queue_selector, (tuple, str)): + raise ImproperlyConfigured("Wrong queue selector, must be either tuple or str") + SELECTOR_TO_QUEUE[queue_selector] = queue_name + + QUEUE_TO_PARSED_JOB_ID_CLS[queue_name] = parsed_job_id_cls + + init_subclasses() + # check that each subclass that has QUEUE_SELECTORS can be used to determine the queue + for subclass in REQUEST_ID_SUBCLASSES: + subclass = cast(RequestId, subclass) + if subclass.LEGACY_FORMAT_PATTERNS and not subclass.QUEUE_SELECTORS: + raise ImproperlyConfigured( + f"Subclass {subclass.__name__} has LEGACY_FORMAT_PATTERNS - QUEUE_SELECTORS must be defined" + ) + + if subclass.QUEUE_SELECTORS: + for queue_selector in subclass.QUEUE_SELECTORS: + if not SELECTOR_TO_QUEUE.get(queue_selector): + raise ImproperlyConfigured( + f"Queue selector {queue_selector!r} for the class {subclass.__name__!r} is missed in the queue configuration" + ) class RedisHandlerConfig(AppConfig): name = "cvat.apps.redis_handler" + + def ready(self) -> None: + from cvat.apps.iam.permissions import load_app_permissions + + load_app_permissions(self) + initialize_mappings() diff --git a/cvat/apps/redis_handler/background.py b/cvat/apps/redis_handler/background.py new file mode 100644 index 000000000000..29a717263cf3 --- /dev/null +++ b/cvat/apps/redis_handler/background.py @@ -0,0 +1,371 @@ +# Copyright (C) CVAT.ai Corporation +# +# 
SPDX-License-Identifier: MIT + +import os.path as osp +from abc import ABCMeta, abstractmethod +from dataclasses import asdict as dataclass_asdict +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Callable, ClassVar +from urllib.parse import quote + +import django_rq +from django.conf import settings +from django.db.models import Model +from django.http.response import HttpResponseBadRequest +from django.utils import timezone +from django_rq.queues import DjangoRQ, DjangoScheduler +from rest_framework import serializers, status +from rest_framework.response import Response +from rq.job import Job as RQJob +from rq.job import JobStatus as RQJobStatus + +from cvat.apps.dataset_manager.util import get_export_cache_lock +from cvat.apps.engine.cloud_provider import export_resource_to_cloud_storage +from cvat.apps.engine.location import ( + Location, + LocationConfig, + StorageType, + get_location_configuration, +) +from cvat.apps.engine.log import ServerLogManager +from cvat.apps.engine.models import RequestTarget +from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export +from cvat.apps.engine.rq import BaseRQMeta, ExportRQMeta, define_dependent_job +from cvat.apps.engine.types import ExtendedRequest +from cvat.apps.engine.utils import get_rq_lock_by_user, get_rq_lock_for_job, sendfile +from cvat.apps.redis_handler.serializers import RqIdSerializer + +slogger = ServerLogManager(__name__) + +REQUEST_TIMEOUT = 60 +# it's better to return LockNotAvailableError instead of response with 504 status +LOCK_TTL = REQUEST_TIMEOUT - 5 +LOCK_ACQUIRE_TIMEOUT = LOCK_TTL - 5 + + +class AbstractRequestManager(metaclass=ABCMeta): + SUPPORTED_TARGETS: ClassVar[set[RequestTarget] | None] = None + QUEUE_NAME: ClassVar[str] + REQUEST_ID_KEY = "rq_id" + + callback: Callable + callback_args: tuple | None + callback_kwargs: dict[str, Any] | None + + def __init__( + self, + *, + request: ExtendedRequest, + db_instance: Model | None = None, + ) -> None: + self.request = request + self.user_id = request.user.id + self.db_instance = db_instance + + if db_instance: + assert self.SUPPORTED_TARGETS, "Should be defined" + self.target = RequestTarget(db_instance.__class__.__name__.lower()) + assert self.target in self.SUPPORTED_TARGETS, f"Unsupported target: {self.target}" + + @classmethod + def get_queue(cls) -> DjangoRQ: + return django_rq.get_queue(cls.QUEUE_NAME) + + @property + def job_result_ttl(self) -> int | None: + """ + Time to live for successful job result in seconds, + if not set, the default result TTL will be used + """ + return None + + @property + def job_failed_ttl(self) -> int | None: + """ + Time to live for failures in seconds, + if not set, the default failure TTL will be used + """ + return None + + @abstractmethod + def build_request_id(self): ... + + def validate_request_id(self, request_id: str, /) -> None: ... + + def get_job_by_id(self, id_: str, /) -> RQJob | None: + try: + self.validate_request_id(id_) + except Exception: + return None + + queue = self.get_queue() + return queue.fetch_job(id_) + + def init_request_args(self): + """ + Hook to initialize operation args based on the request + """ + + @abstractmethod + def init_callback_with_params(self) -> None: + """ + Method should initialize callback function with its args/kwargs: + + self.callback = ... + (optional) self.callback_args = ... + (optional) self.callback_kwargs = ... 
+ """ + + def _set_default_callback_params(self): + self.callback_args = None + self.callback_kwargs = None + + def validate_request(self) -> Response | None: + """Hook to run some validations before processing a request""" + + # prevent architecture bugs + assert ( + "POST" == self.request.method + ), "Only POST requests can be used to initiate a background process" + + def handle_existing_job(self, job: RQJob | None, queue: DjangoRQ) -> Response | None: + if not job: + return None + + job_status = job.get_status(refresh=False) + + if job_status in {RQJobStatus.STARTED, RQJobStatus.QUEUED}: + return Response( + data="Request is being processed", + status=status.HTTP_409_CONFLICT, + ) + + if job_status == RQJobStatus.DEFERRED: + job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) + + if job_status == RQJobStatus.SCHEDULED: + scheduler: DjangoScheduler = django_rq.get_scheduler(queue.name, queue=queue) + # remove the job id from the set with scheduled keys + scheduler.cancel(job) + job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) + + job.delete() + return None + + def build_meta(self, *, request_id: str) -> dict[str, Any]: + return BaseRQMeta.build(request=self.request, db_obj=self.db_instance) + + def setup_new_job(self, queue: DjangoRQ, request_id: str, /, **kwargs): + with get_rq_lock_by_user(queue, self.user_id): + queue.enqueue_call( + func=self.callback, + args=self.callback_args, + kwargs=self.callback_kwargs, + job_id=request_id, + meta=self.build_meta(request_id=request_id), + depends_on=define_dependent_job(queue, self.user_id, rq_id=request_id), + result_ttl=self.job_result_ttl, + failure_ttl=self.job_failed_ttl, + **kwargs, + ) + + def finalize_request(self) -> None: + """Hook to run some actions (e.g. 
collect events) after processing a request""" + + def get_response(self, request_id: str) -> Response: + serializer = RqIdSerializer({"rq_id": request_id}) + return Response(serializer.data, status=status.HTTP_202_ACCEPTED) + + def enqueue_job(self) -> Response: + self.init_request_args() + self.validate_request() + self._set_default_callback_params() + self.init_callback_with_params() + + queue: DjangoRQ = django_rq.get_queue(self.QUEUE_NAME) + request_id = self.build_request_id() + + # ensure that there is no race condition when processing parallel requests + with get_rq_lock_for_job(queue, request_id): + job = queue.fetch_job(request_id) + + if response := self.handle_existing_job(job, queue): + return response + + self.setup_new_job(queue, request_id) + + self.finalize_request() + return self.get_response(request_id) + + +class AbstractExporter(AbstractRequestManager): + + class Downloader: + def __init__( + self, + *, + request: ExtendedRequest, + queue: DjangoRQ, + request_id: str, + ): + self.request = request + self.queue = queue + self.request_id = request_id + + def validate_request(self): + # prevent architecture bugs + assert self.request.method in ( + "GET", + "HEAD", + ), "Only GET/HEAD requests can be used to download a file" + + def download_file(self) -> Response: + self.validate_request() + + # ensure that there is no race condition when processing parallel requests + with get_rq_lock_for_job(self.queue, self.request_id): + job = self.queue.fetch_job(self.request_id) + + if not job: + return HttpResponseBadRequest("Unknown export request id") + + # define status once to avoid refreshing it on each check + # FUTURE-TODO: get_status will raise InvalidJobOperation exception instead of returning None in one of the next releases + job_status = job.get_status(refresh=False) + + if job_status != RQJobStatus.FINISHED: + return HttpResponseBadRequest("The export process is not finished") + + job_meta = ExportRQMeta.for_job(job) + file_path = job.return_value() + + if not file_path: + return ( + Response( + "A result for exporting job was not found for finished RQ job", + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if job_meta.result_url + # user tries to download a final file locally while the export is made to cloud storage + else HttpResponseBadRequest( + "The export process has no result file to be downloaded locally" + ) + ) + + with get_export_cache_lock( + file_path, ttl=LOCK_TTL, acquire_timeout=LOCK_ACQUIRE_TIMEOUT + ): + if not osp.exists(file_path): + return Response( + "The exported file has expired, please retry exporting", + status=status.HTTP_404_NOT_FOUND, + ) + + return sendfile( + self.request, + file_path, + attachment=True, + attachment_filename=job_meta.result_filename, + ) + + @dataclass + class ExportArgs: + filename: str | None + location_config: LocationConfig + + def to_dict(self): + return dataclass_asdict(self) + + QUEUE_NAME = settings.CVAT_QUEUES.EXPORT_DATA.value + + export_args: ExportArgs | None + + @property + def job_result_ttl(self): + from cvat.apps.dataset_manager.views import get_export_cache_ttl + + return int(get_export_cache_ttl(self.db_instance).total_seconds()) + + @property + def job_failed_ttl(self): + return self.job_result_ttl + + @abstractmethod + def get_result_filename(self) -> str: ... + + @abstractmethod + def get_result_endpoint_url(self) -> str: ... 
+ + def make_result_url(self, *, request_id: str) -> str: + return self.get_result_endpoint_url() + f"?{self.REQUEST_ID_KEY}={quote(request_id)}" + + def get_file_timestamp(self) -> str: + # use only updated_date for the related resource, don't check children objects + # because every child update should touch the updated_date of the parent resource + date = self.db_instance.updated_date if self.db_instance else timezone.now() + return datetime.strftime(date, "%Y_%m_%d_%H_%M_%S") + + def init_request_args(self) -> None: + try: + location_config = get_location_configuration( + db_instance=self.db_instance, + query_params=self.request.query_params, + field_name=StorageType.TARGET, + ) + except ValueError as ex: + raise serializers.ValidationError(str(ex)) from ex + + self.export_args = AbstractExporter.ExportArgs( + location_config=location_config, filename=self.request.query_params.get("filename") + ) + + @abstractmethod + def _init_callback_with_params(self): + """ + Private method that should initialize callback function with its args/kwargs + like the init_callback_with_params method in the parent class. + """ + + def init_callback_with_params(self): + """ + Method should not be overridden + """ + self._init_callback_with_params() + + if self.export_args.location_config.location == Location.CLOUD_STORAGE: + storage_id = self.export_args.location_config.cloud_storage_id + db_storage = get_cloud_storage_for_import_or_export( + storage_id=storage_id, + request=self.request, + is_default=self.export_args.location_config.is_default, + ) + + self.callback_args = (db_storage, self.callback) + self.callback_args + self.callback = export_resource_to_cloud_storage + + def build_meta(self, *, request_id): + return ExportRQMeta.build_for( + request=self.request, + db_obj=self.db_instance, + result_url=( + self.make_result_url(request_id=request_id) + if self.export_args.location_config.location != Location.CLOUD_STORAGE + else None + ), + result_filename=self.get_result_filename(), + ) + + def get_downloader(self): + request_id = self.request.query_params.get(self.REQUEST_ID_KEY) + + if not request_id: + raise serializers.ValidationError("Missing request id in the query parameters") + + try: + self.validate_request_id(request_id) + except ValueError: + raise serializers.ValidationError("Invalid export request id") + + return self.Downloader(request=self.request, queue=self.get_queue(), request_id=request_id) diff --git a/cvat/apps/redis_handler/permissions.py b/cvat/apps/redis_handler/permissions.py new file mode 100644 index 000000000000..51b3d047a37a --- /dev/null +++ b/cvat/apps/redis_handler/permissions.py @@ -0,0 +1,99 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.conf import settings +from rq.job import Job as RQJob + +from cvat.apps.engine.types import ExtendedRequest +from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum + +if TYPE_CHECKING: + from rest_framework.viewsets import ViewSet + +from cvat.apps.engine.models import RequestTarget +from cvat.apps.engine.permissions import JobPermission, ProjectPermission, TaskPermission +from cvat.apps.engine.rq import BaseRQMeta +from cvat.apps.redis_handler.rq import CustomRQJob + + +class RequestPermission(OpenPolicyAgentPermission): + + class Scopes(StrEnum): + LIST = "list" + VIEW = "view" + DELETE = "delete" + + @classmethod + def create( + cls, request: ExtendedRequest, view: ViewSet, obj: CustomRQJob | None, 
iam_context: dict + ) -> list[OpenPolicyAgentPermission]: + permissions = [] + if view.basename == "request": + for scope in cls.get_scopes(request, view, obj): + if scope == cls.Scopes.LIST: + continue + elif scope == cls.Scopes.VIEW: + parsed_request_id = obj.parsed_id + + # In case when background job is unique for a user, status check should be available only for this user/admin + # In other cases, status check should be available for all users that have target resource VIEW permission + if parsed_request_id.user_id: + job_owner = BaseRQMeta.for_job(obj).user + assert job_owner and job_owner.id == parsed_request_id.user_id + + elif parsed_request_id.target_id is not None: + if parsed_request_id.target == RequestTarget.PROJECT.value: + permissions.append( + ProjectPermission.create_scope_view( + request, parsed_request_id.target_id + ) + ) + continue + elif parsed_request_id.target == RequestTarget.TASK.value: + permissions.append( + TaskPermission.create_scope_view( + request, parsed_request_id.target_id + ) + ) + continue + elif parsed_request_id.target == RequestTarget.JOB.value: + permissions.append( + JobPermission.create_scope_view( + request, parsed_request_id.target_id + ) + ) + continue + assert False, "Unsupported operation on resource" + + self = cls.create_base_perm(request, view, scope, iam_context, obj) + permissions.append(self) + + return permissions + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.url = settings.IAM_OPA_DATA_URL + "/requests/allow" + + @staticmethod + def get_scopes(request: ExtendedRequest, view: ViewSet, obj: RQJob | None) -> list[Scopes]: + return [ + { + ("list", "GET"): __class__.Scopes.LIST, + ("retrieve", "GET"): __class__.Scopes.VIEW, + ("cancel", "POST"): __class__.Scopes.DELETE, + }[(view.action, request.method)] + ] + + def get_resource(self): + if self.obj and (owner := BaseRQMeta.for_job(self.obj).user): + return { + "owner": { + "id": owner.id, + }, + } + return None diff --git a/cvat/apps/redis_handler/rq.py b/cvat/apps/redis_handler/rq.py new file mode 100644 index 000000000000..15b82d50b4d1 --- /dev/null +++ b/cvat/apps/redis_handler/rq.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import re +import urllib.parse +from types import NoneType +from typing import Any, ClassVar, Protocol +from uuid import UUID + +import attrs +from django.utils.html import escape +from rq.job import Job as RQJob + +from cvat.apps.redis_handler.apps import ( + QUEUE_TO_PARSED_JOB_ID_CLS, + REQUEST_ID_SUBCLASSES, + SELECTOR_TO_QUEUE, +) + + +class IncorrectRequestIdError(ValueError): + pass + + +def _default_from_class_attr(attr_name: str): + def factory(self): + cls = type(self) + if attrs_value := getattr(cls, attr_name, None): + return attrs_value + raise AttributeError( + f"[{cls.__name__}] Unable to set default value for the {attr_name} attribute" + ) + + return attrs.Factory(factory, takes_self=True) + + +@attrs.frozen(kw_only=True, slots=False) # to be able to inherit from RequestId +class RequestId: + # https://datatracker.ietf.org/doc/html/rfc3986#section-2.3 - ALPHA / DIGIT / "-" / "." 
/ "_" / "~" + UNRESERVED_BY_RFC3986_SPECIAL_CHARACTERS: ClassVar[tuple[str]] = ("-", ".", "_", "~") + ENCODE_MAPPING = { + ".": "~", # dot is a default DRF path parameter pattern + " ": "_", + } + + # "&" and "=" characters are reserved sub-delims symbols, but request ID is going to be used only as path parameter + FIELD_SEP: ClassVar[str] = "&" + KEY_VAL_SEP: ClassVar[str] = "=" + TYPE_SEP: ClassVar[str] = ":" # used in serialization logic + + STR_WITH_UNRESERVED_SPECIAL_CHARACTERS: ClassVar[str] = "".join( + re.escape(c) + for c in ( + set(UNRESERVED_BY_RFC3986_SPECIAL_CHARACTERS) + - {FIELD_SEP, KEY_VAL_SEP, *ENCODE_MAPPING.values()} + ) + ) + VALIDATION_PATTERN: ClassVar[str] = rf"[\w{STR_WITH_UNRESERVED_SPECIAL_CHARACTERS}]+" + + action: str = attrs.field( + validator=attrs.validators.instance_of(str), + default=_default_from_class_attr("ACTION_DEFAULT_VALUE"), + ) + ACTION_ALLOWED_VALUES: ClassVar[tuple[str]] + QUEUE_SELECTORS: ClassVar[tuple] = () + + @action.validator + def validate_action(self, attribute: attrs.Attribute, value: Any): + if hasattr(self, "ACTION_ALLOWED_VALUES") and value not in self.ACTION_ALLOWED_VALUES: + raise ValueError(f"Action must be one of {self.ACTION_ALLOWED_VALUES!r}") + + target: str = attrs.field(validator=attrs.validators.instance_of(str)) + target_id: int | None = attrs.field( + converter=lambda x: x if x is None else int(x), default=None + ) + + id: UUID | None = attrs.field( + converter=lambda x: x if isinstance(x, (NoneType, UUID)) else UUID(x), + default=None, + ) # operation id + user_id: int | None = attrs.field(converter=lambda x: x if x is None else int(x), default=None) + + # FUTURE-TODO: remove after several releases + # backward compatibility with previous ID formats + LEGACY_FORMAT_PATTERNS: ClassVar[tuple[str]] = () + + def __attrs_post_init__(self): + assert ( + sum(1 for i in (self.target_id, self.id) if i) == 1 + ), "Only one of target_id or id should be set" + + @property + def type(self) -> str: + return self.TYPE_SEP.join([self.action, self.target]) + + def to_dict(self) -> dict[str, Any]: + return attrs.asdict(self, filter=lambda _, v: bool(v)) + + @classmethod + def normalize(cls, repr_: dict[str, Any]) -> None: + for key, value in repr_.items(): + str_value = str(value) + if not re.match(cls.VALIDATION_PATTERN, str_value): + raise IncorrectRequestIdError( + f"{key} does not match allowed format: {cls.VALIDATION_PATTERN}" + ) + + for from_char, to_char in cls.ENCODE_MAPPING.items(): + str_value = str_value.replace(from_char, to_char) + + repr_[key] = str_value + + def render(self) -> str: + rq_id_repr = self.to_dict() + + # rq_id is going to be used in urls as path parameter, so it should be URL safe. + self.normalize(rq_id_repr) + # urllib.parse.quote/urllib.parse.urlencode are not used here because: + # - it's client logic to encode request ID + # - return value is used as RQ job ID and should be + # a. in a decoded state + # b. 
readable + return self.FIELD_SEP.join([f"{k}{self.KEY_VAL_SEP}{v}" for k, v in rq_id_repr.items()]) + + @classmethod + def parse( + cls, + request_id: str, + /, + *, + try_legacy_format: bool = False, + ) -> tuple[RequestId, str]: + + actual_cls = cls + queue: str | None = None + dict_repr = {} + fragments = {} + + try: + # try to parse ID as key=value pairs (newly introduced format) + fragments = dict(urllib.parse.parse_qsl(request_id)) + + if not fragments: + # try to use legacy format + if not try_legacy_format: + raise IncorrectRequestIdError( + f"Unable to parse request ID: {escape(request_id)!r}" + ) + + match: re.Match | None = None + + for subclass in REQUEST_ID_SUBCLASSES if cls is RequestId else (cls,): + for pattern in subclass.LEGACY_FORMAT_PATTERNS: + match = re.match(pattern, request_id) + if match: + actual_cls = subclass + break + if match: + break + else: + raise IncorrectRequestIdError( + f"Unable to parse request ID: {escape(request_id)!r}" + ) + + queue = SELECTOR_TO_QUEUE[ + actual_cls.QUEUE_SELECTORS[0] + ] # each selector match the same queue + fragments = match.groupdict() + # "." was replaced with "@" in previous format + if "format" in fragments: + fragments["format"] = fragments["format"].replace("@", cls.ENCODE_MAPPING["."]) + + # init dict representation for request ID + for key, value in fragments.items(): + for to_char, from_char in cls.ENCODE_MAPPING.items(): + value = value.replace(from_char, to_char) + + dict_repr[key] = value + + if not queue: + # try to define queue dynamically based on action/target/subresource + queue = SELECTOR_TO_QUEUE[ + (dict_repr["action"], dict_repr["target"], dict_repr.get("subresource")) + ] + + # queue that could be determined using SELECTOR_TO_QUEUE + # must also be included into QUEUE_TO_PARSED_JOB_ID_CLS + assert queue in QUEUE_TO_PARSED_JOB_ID_CLS + actual_cls = QUEUE_TO_PARSED_JOB_ID_CLS[queue] + + assert issubclass(actual_cls, cls) + result = actual_cls(**dict_repr) + + return (result, queue) + except AssertionError: + raise + except Exception as ex: + raise IncorrectRequestIdError from ex + + @classmethod + def parse_and_validate_queue( + cls, + request_id: str, + /, + *, + expected_queue: str, + try_legacy_format: bool = False, + ) -> RequestId: + parsed_request_id, queue = cls.parse(request_id, try_legacy_format=try_legacy_format) + assert queue == expected_queue + return parsed_request_id + + +@attrs.frozen(kw_only=True, slots=False) +class RequestIdWithSubresource(RequestId): + SUBRESOURCE_ALLOWED_VALUES: ClassVar[tuple[str]] + + subresource: str = attrs.field( + validator=attrs.validators.instance_of(str), + default=_default_from_class_attr("SUBRESOURCE_DEFAULT_VALUE"), + ) + + @subresource.validator + def validate_subresource(self, attribute: attrs.Attribute, value: Any): + if value not in self.SUBRESOURCE_ALLOWED_VALUES: + raise ValueError(f"Subresource must be one of {self.SUBRESOURCE_ALLOWED_VALUES!r}") + + @property + def type(self) -> str: + return self.TYPE_SEP.join([self.action, self.subresource]) + + +@attrs.frozen(kw_only=True, slots=False) +class RequestIdWithOptionalSubresource(RequestIdWithSubresource): + subresource: str | None = attrs.field( + validator=attrs.validators.instance_of((str, NoneType)), default=None + ) + + @subresource.validator + def validate_subresource(self, attribute: attrs.Attribute, value: Any): + if value is not None: + super().validate_subresource(attribute, value) + + @property + def type(self) -> str: + return self.TYPE_SEP.join([self.action, self.subresource or self.target]) 
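A quick way to see what the new composite request IDs look like on the wire: the standalone snippet below mimics `RequestId.render()` and the `parse()` decoding step with plain `urllib.parse`, using the `FIELD_SEP`/`KEY_VAL_SEP`/`ENCODE_MAPPING` values defined above. It is an illustration only, not code from this patch.

```python
# Standalone illustration of the ID format handled by RequestId above.
import urllib.parse

ENCODE_MAPPING = {".": "~", " ": "_"}  # mirrors RequestId.ENCODE_MAPPING


def render(fields: dict) -> str:
    parts = []
    for key, value in fields.items():
        value = str(value)
        for from_char, to_char in ENCODE_MAPPING.items():
            value = value.replace(from_char, to_char)
        parts.append(f"{key}={value}")
    return "&".join(parts)


rq_id = render({
    "action": "export",
    "target": "task",
    "target_id": 42,
    "subresource": "dataset",
    "format": "CVAT for images 1.1",
    "user_id": 1,
})
print(rq_id)
# action=export&target=task&target_id=42&subresource=dataset&format=CVAT_for_images_1~1&user_id=1

# The server parses it back with urllib.parse.parse_qsl and reverses the mapping:
parsed = dict(urllib.parse.parse_qsl(rq_id))
print(parsed["format"].replace("_", " ").replace("~", "."))  # -> CVAT for images 1.1
```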
+ + +class _WithParsedId(Protocol): + parsed_id: RequestId + + +class CustomRQJob(RQJob, _WithParsedId): + pass diff --git a/cvat/apps/redis_handler/rules/requests.rego b/cvat/apps/redis_handler/rules/requests.rego new file mode 100644 index 000000000000..eaf7d8114052 --- /dev/null +++ b/cvat/apps/redis_handler/rules/requests.rego @@ -0,0 +1,39 @@ +package requests + +import rego.v1 + +import data.utils +import data.organizations + +# input: { +# "scope": <"view"|"delete"> or null, +# "auth": { +# "user": { +# "id": , +# "privilege": <"admin"|"user"|"worker"> or null +# }, +# "organization": { +# "id": , +# "owner": { +# "id": +# }, +# "user": { +# "role": <"owner"|"maintainer"|"supervisor"|"worker"> or null +# } +# } or null, +# }, +# "resource": { +# "owner": { "id": } or null, +# } +# } + +default allow := false + +allow if { + utils.is_admin +} + +allow if { + input.scope in {utils.VIEW, utils.DELETE} + input.auth.user.id == input.resource.owner.id +} diff --git a/cvat/apps/redis_handler/serializers.py b/cvat/apps/redis_handler/serializers.py new file mode 100644 index 000000000000..f7ebd9b1fcf6 --- /dev/null +++ b/cvat/apps/redis_handler/serializers.py @@ -0,0 +1,168 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from datetime import timedelta +from decimal import Decimal +from typing import Any +from uuid import UUID + +import rq.defaults as rq_defaults +from django.db.models import TextChoices +from django.utils import timezone +from drf_spectacular.utils import extend_schema_field +from rest_framework import serializers +from rq.job import JobStatus as RQJobStatus + +from cvat.apps.engine import models +from cvat.apps.engine.log import ServerLogManager +from cvat.apps.engine.models import RequestAction +from cvat.apps.engine.rq import BaseRQMeta, ExportRQMeta, ImportRQMeta, RequestIdWithOptionalFormat +from cvat.apps.engine.serializers import BasicUserSerializer +from cvat.apps.engine.utils import parse_exception_message +from cvat.apps.lambda_manager.rq import LambdaRQMeta +from cvat.apps.redis_handler.rq import CustomRQJob, RequestId + +slogger = ServerLogManager(__name__) + + +class RequestStatus(TextChoices): + QUEUED = "queued" + STARTED = "started" + FAILED = "failed" + FINISHED = "finished" + + +class RqIdSerializer(serializers.Serializer): + rq_id = serializers.CharField(help_text="Request id") + + +class UserIdentifiersSerializer(BasicUserSerializer): + class Meta(BasicUserSerializer.Meta): + fields = ( + "id", + "username", + ) + + +class RequestDataOperationSerializer(serializers.Serializer): + type = serializers.CharField() + target = serializers.CharField() + project_id = serializers.IntegerField(required=False, allow_null=True) + task_id = serializers.IntegerField(required=False, allow_null=True) + job_id = serializers.IntegerField(required=False, allow_null=True) + format = serializers.CharField(required=False, allow_null=True) + function_id = serializers.CharField(required=False, allow_null=True) + + def to_representation(self, rq_job: CustomRQJob) -> dict[str, Any]: + parsed_request_id: RequestId = rq_job.parsed_id + + base_rq_job_meta = BaseRQMeta.for_job(rq_job) + representation = { + "type": parsed_request_id.type, + "target": parsed_request_id.target, + "project_id": base_rq_job_meta.project_id, + "task_id": base_rq_job_meta.task_id, + "job_id": base_rq_job_meta.job_id, + } + if parsed_request_id.action == RequestAction.AUTOANNOTATE: + representation["function_id"] = 
LambdaRQMeta.for_job(rq_job).function_id + elif isinstance(parsed_request_id, RequestIdWithOptionalFormat): + representation["format"] = parsed_request_id.format + + return representation + + +class RequestSerializer(serializers.Serializer): + # SerializerMethodField is not used here to mark "status" field as required and fix schema generation. + # Marking them as read_only leads to generating type as allOf with one reference to RequestStatus component. + # The client generated using openapi-generator from such a schema contains wrong type like: + # status (bool, date, datetime, dict, float, int, list, str, none_type): [optional] + status = serializers.ChoiceField(source="get_status", choices=RequestStatus.choices) + message = serializers.SerializerMethodField() + id = serializers.CharField() + operation = RequestDataOperationSerializer(source="*") + progress = serializers.SerializerMethodField() + created_date = serializers.DateTimeField(source="created_at") + started_date = serializers.DateTimeField( + required=False, + allow_null=True, + source="started_at", + ) + finished_date = serializers.DateTimeField( + required=False, + allow_null=True, + source="ended_at", + ) + expiry_date = serializers.SerializerMethodField() + owner = serializers.SerializerMethodField() + result_url = serializers.URLField(required=False, allow_null=True) + result_id = serializers.IntegerField(required=False, allow_null=True) + + def __init__(self, *args, **kwargs): + self._base_rq_job_meta: BaseRQMeta | None = None + super().__init__(*args, **kwargs) + + @extend_schema_field(UserIdentifiersSerializer()) + def get_owner(self, rq_job: CustomRQJob) -> dict[str, Any]: + assert self._base_rq_job_meta + return UserIdentifiersSerializer(self._base_rq_job_meta.user).data + + @extend_schema_field( + serializers.FloatField(min_value=0, max_value=1, required=False, allow_null=True) + ) + def get_progress(self, rq_job: CustomRQJob) -> Decimal: + rq_job_meta = ImportRQMeta.for_job(rq_job) + # progress of task creation is stored in "task_progress" field + # progress of project import is stored in "progress" field + return Decimal(rq_job_meta.progress or rq_job_meta.task_progress or 0.0) + + @extend_schema_field(serializers.DateTimeField(required=False, allow_null=True)) + def get_expiry_date(self, rq_job: CustomRQJob) -> str | None: + delta = None + if rq_job.is_finished: + delta = rq_job.result_ttl or rq_defaults.DEFAULT_RESULT_TTL + elif rq_job.is_failed: + delta = rq_job.failure_ttl or rq_defaults.DEFAULT_FAILURE_TTL + + if rq_job.ended_at and delta: + expiry_date = rq_job.ended_at + timedelta(seconds=delta) + return expiry_date.replace(tzinfo=timezone.utc) + + return None + + @extend_schema_field(serializers.CharField(allow_blank=True)) + def get_message(self, rq_job: CustomRQJob) -> str: + assert self._base_rq_job_meta + rq_job_status = rq_job.get_status() + message = "" + + if RQJobStatus.STARTED == rq_job_status: + message = self._base_rq_job_meta.status or message + elif RQJobStatus.FAILED == rq_job_status: + + message = self._base_rq_job_meta.formatted_exception or parse_exception_message( + str(rq_job.exc_info or "Unknown error") + ) + + return message + + def to_representation(self, rq_job: CustomRQJob) -> dict[str, Any]: + self._base_rq_job_meta = BaseRQMeta.for_job(rq_job) + representation = super().to_representation(rq_job) + + # FUTURE-TODO: support such statuses on UI + if representation["status"] in (RQJobStatus.DEFERRED, RQJobStatus.SCHEDULED): + representation["status"] = RQJobStatus.QUEUED + + if 
representation["status"] == RQJobStatus.FINISHED: + if rq_job.parsed_id.action == models.RequestAction.EXPORT: + representation["result_url"] = ExportRQMeta.for_job(rq_job).result_url + else: + return_value = rq_job.return_value() + if isinstance(return_value, (int, UUID)): + representation["result_id"] = return_value + + return representation diff --git a/cvat/apps/redis_handler/urls.py b/cvat/apps/redis_handler/urls.py new file mode 100644 index 000000000000..7f742986978c --- /dev/null +++ b/cvat/apps/redis_handler/urls.py @@ -0,0 +1,15 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from django.urls import include, path +from rest_framework import routers + +from . import views + +router = routers.DefaultRouter(trailing_slash=False) +router.register("requests", views.RequestViewSet, basename="request") + +urlpatterns = [ + path("api/", include(router.urls)), +] diff --git a/cvat/apps/redis_handler/views.py b/cvat/apps/redis_handler/views.py new file mode 100644 index 000000000000..d73bab153db4 --- /dev/null +++ b/cvat/apps/redis_handler/views.py @@ -0,0 +1,246 @@ +import functools +from collections import namedtuple +from collections.abc import Iterable +from typing import cast + +import django_rq +from django.conf import settings +from django.http import HttpResponseBadRequest, HttpResponseNotFound +from django.utils.decorators import method_decorator +from django.views.decorators.cache import never_cache +from django_rq.queues import DjangoRQ +from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view +from redis.exceptions import ConnectionError as RedisConnectionError +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.response import Response +from rq.job import Job as RQJob +from rq.job import JobStatus as RQJobStatus + +from cvat.apps.engine.filters import ( + NonModelJsonLogicFilter, + NonModelOrderingFilter, + NonModelSimpleFilter, +) +from cvat.apps.engine.log import ServerLogManager +from cvat.apps.engine.rq import is_rq_job_owner +from cvat.apps.engine.types import ExtendedRequest +from cvat.apps.redis_handler.apps import SELECTOR_TO_QUEUE +from cvat.apps.redis_handler.rq import CustomRQJob, RequestId +from cvat.apps.redis_handler.serializers import RequestSerializer, RequestStatus + +slogger = ServerLogManager(__name__) + + +@extend_schema(tags=["requests"]) +@extend_schema_view( + list=extend_schema( + summary="List requests", + responses={ + "200": RequestSerializer(many=True), + }, + ), + retrieve=extend_schema( + summary="Get request details", + responses={ + "200": RequestSerializer, + }, + ), +) +class RequestViewSet(viewsets.GenericViewSet): + serializer_class = RequestSerializer + iam_organization_field = None + filter_backends = [ + NonModelSimpleFilter, + NonModelJsonLogicFilter, + NonModelOrderingFilter, + ] + + ordering_fields = ["created_date", "status", "action"] + ordering = "-created_date" + + filter_fields = [ + # RQ job fields + "status", + # derivatives fields (from meta) + "project_id", + "task_id", + "job_id", + # derivatives fields (from parsed rq_id) + "action", + "target", + "subresource", + "format", + ] + + simple_filters = filter_fields + ["org"] + + lookup_fields = { + "created_date": "created_at", + "action": "parsed_id.action", + "target": "parsed_id.target", + "subresource": "parsed_id.subresource", + "format": "parsed_id.format", + "status": "get_status", + "project_id": "meta.project_id", + "task_id": "meta.task_id", + 
"job_id": "meta.job_id", + "org": "meta.org_slug", + } + + SchemaField = namedtuple("SchemaField", ["type", "choices"], defaults=(None,)) + + simple_filters_schema = { + "status": SchemaField("string", RequestStatus.choices), + "project_id": SchemaField("integer"), + "task_id": SchemaField("integer"), + "job_id": SchemaField("integer"), + "action": SchemaField("string"), + "target": SchemaField("string"), + "subresource": SchemaField("string"), + "format": SchemaField("string"), + "org": SchemaField("string"), + } + + def get_queryset(self): + return None + + @property + def queues(self) -> Iterable[DjangoRQ]: + return (django_rq.get_queue(queue_name) for queue_name in set(SELECTOR_TO_QUEUE.values())) + + def _get_rq_jobs_from_queue(self, queue: DjangoRQ, user_id: int) -> list[RQJob]: + job_ids = set( + queue.get_job_ids() + + queue.started_job_registry.get_job_ids() + + queue.finished_job_registry.get_job_ids() + + queue.failed_job_registry.get_job_ids() + + queue.deferred_job_registry.get_job_ids() + ) + jobs = [] + + for job in queue.job_class.fetch_many(job_ids, queue.connection): + if job and is_rq_job_owner(job, user_id): + job = cast(CustomRQJob, job) + try: + parsed_request_id = RequestId.parse_and_validate_queue( + job.id, expected_queue=queue.name + ) + except Exception: # nosec B112 + continue + + job.parsed_id = parsed_request_id + jobs.append(job) + + return jobs + + def _get_rq_jobs(self, user_id: int) -> list[RQJob]: + """ + Get all RQ jobs for a specific user and return them as a list of RQJob objects. + + Parameters: + user_id (int): The ID of the user for whom to retrieve jobs. + + Returns: + List[RQJob]: A list of RQJob objects representing all jobs for the specified user. + """ + all_jobs = [] + for queue in self.queues: + jobs = self._get_rq_jobs_from_queue(queue, user_id) + all_jobs.extend(jobs) + + return all_jobs + + def _get_rq_job_by_id(self, rq_id: str) -> RQJob | None: + """ + Get a RQJob by its ID from the queues. + + Args: + rq_id (str): The ID of the RQJob to retrieve. + + Returns: + Optional[RQJob]: The retrieved RQJob, or None if not found. 
+ """ + try: + parsed_request_id, queue_name = RequestId.parse(rq_id, try_legacy_format=True) + rq_id = parsed_request_id.render() + except Exception: + return None + + queue: DjangoRQ = django_rq.get_queue(queue_name) + job: CustomRQJob | None = queue.fetch_job(rq_id) + + if job: + job.parsed_id = parsed_request_id + + return job + + def _handle_redis_exceptions(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RedisConnectionError as ex: + msg = "Redis service is not available" + slogger.glob.exception(f"{msg}: {str(ex)}") + return Response(msg, status=status.HTTP_503_SERVICE_UNAVAILABLE) + + return wrapper + + @method_decorator(never_cache) + @_handle_redis_exceptions + def retrieve(self, request: ExtendedRequest, pk: str): + job = self._get_rq_job_by_id(pk) + + if not job: + return HttpResponseNotFound("There is no request with specified id") + + self.check_object_permissions(request, job) + + serializer = self.get_serializer(job, context={"request": request}) + return Response(data=serializer.data, status=status.HTTP_200_OK) + + @method_decorator(never_cache) + @_handle_redis_exceptions + def list(self, request: ExtendedRequest): + user_id = request.user.id + user_jobs = self._get_rq_jobs(user_id) + + filtered_jobs = self.filter_queryset(user_jobs) + + page = self.paginate_queryset(filtered_jobs) + if page is not None: + serializer = self.get_serializer(page, many=True, context={"request": request}) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(filtered_jobs, many=True, context={"request": request}) + return Response(data=serializer.data, status=status.HTTP_200_OK) + + @extend_schema( + summary="Cancel request", + request=None, + responses={ + "200": OpenApiResponse(description="The request has been cancelled"), + }, + ) + @method_decorator(never_cache) + @action(detail=True, methods=["POST"], url_path="cancel") + @_handle_redis_exceptions + def cancel(self, request: ExtendedRequest, pk: str): + rq_job = self._get_rq_job_by_id(pk) + + if not rq_job: + return HttpResponseNotFound("There is no request with specified id") + + self.check_object_permissions(request, rq_job) + + if rq_job.get_status(refresh=False) not in {RQJobStatus.QUEUED, RQJobStatus.DEFERRED}: + return HttpResponseBadRequest( + "Only requests that have not yet been started can be cancelled" + ) + + # FUTURE-TODO: race condition is possible here + rq_job.cancel(enqueue_dependents=settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER) + rq_job.delete() + + return Response(status=status.HTTP_200_OK) diff --git a/cvat/schema.yml b/cvat/schema.yml index 5805fa989f8e..581e80eb9d02 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -964,13 +964,6 @@ paths: post: operationId: consensus_create_merge summary: Create a consensus merge - parameters: - - in: query - name: rq_id - schema: - type: string - description: | - The consensus merge request id. Can be specified to check operation status. tags: - consensus requestBody: @@ -985,8 +978,6 @@ paths: - signatureAuth: [] - basicAuth: [] responses: - '201': - description: No response body '202': content: application/vnd.cvat+json: @@ -994,9 +985,7 @@ paths: $ref: '#/components/schemas/RqId' description: | A consensus merge request has been enqueued, the request id is returned. - The request status can be checked at this endpoint by passing the rq_id - as the query parameter. 
If the request id is specified, this response - means the consensus merge request is queued or is being processed. + The request status can be checked by using common requests API: GET /api/requests/ '400': description: Invalid or failed request, check the response data for details /api/consensus/settings: @@ -1201,6 +1190,7 @@ paths: tokenAuth: [] - signatureAuth: [] - basicAuth: [] + deprecated: true responses: '200': description: Download of file started @@ -1249,6 +1239,83 @@ paths: schema: $ref: '#/components/schemas/ClientEvents' description: '' + /api/events/export: + post: + operationId: events_create_export + summary: Initiate a process to export events + parameters: + - in: query + name: cloud_storage_id + schema: + type: integer + description: Storage id + - in: query + name: filename + schema: + type: string + description: Desired output file name + - in: query + name: from + schema: + type: string + format: date-time + description: Filter events after the datetime. If no 'from' or 'to' parameters + are passed, the last 30 days will be set. + - in: query + name: job_id + schema: + type: integer + description: Filter events by job ID + - in: query + name: location + schema: + type: string + enum: + - cloud_storage + - local + description: Where need to save events file + - in: query + name: org_id + schema: + type: integer + description: Filter events by organization ID + - in: query + name: project_id + schema: + type: integer + description: Filter events by project ID + - in: query + name: task_id + schema: + type: integer + description: Filter events by task ID + - in: query + name: to + schema: + type: string + format: date-time + description: Filter events before the datetime. If no 'from' or 'to' parameters + are passed, the last 30 days will be set. + - in: query + name: user_id + schema: + type: integer + description: Filter events by user ID + tags: + - events + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '202': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/RqId' + description: '' /api/guides: post: operationId: guides_create @@ -2294,66 +2361,24 @@ paths: description: Format is not available put: operationId: jobs_update_annotations - description: |2 - - Utilizing this endpoint to check status of the import process is deprecated - in favor of the new requests API: - GET /api/requests/, where `rq_id` parameter is returned in the response - on initializing request. - summary: Replace job annotations / Get annotation import status + summary: Replace job annotations parameters: - - in: query - name: cloud_storage_id - schema: - type: integer - description: Storage id - deprecated: true - - in: query - name: filename - schema: - type: string - description: Annotation file name - deprecated: true - - in: query - name: format - schema: - type: string - description: |- - Input format name - You can get the list of supported formats at: - /server/annotation/formats - deprecated: true - in: path name: id schema: type: integer description: A unique integer value identifying this job. 
required: true - - in: query - name: location - schema: - type: string - enum: - - cloud_storage - - local - description: where to import the annotation from - deprecated: true - - in: query - name: rq_id - schema: - type: string - description: rq id - deprecated: true tags: - jobs requestBody: content: application/json: schema: - $ref: '#/components/schemas/JobAnnotationsUpdateRequest' + $ref: '#/components/schemas/LabeledDataRequest' multipart/form-data: schema: - $ref: '#/components/schemas/JobAnnotationsUpdateRequest' + $ref: '#/components/schemas/LabeledDataRequest' security: - sessionAuth: [] csrfAuth: [] @@ -2361,12 +2386,8 @@ paths: - signatureAuth: [] - basicAuth: [] responses: - '201': - description: Import has finished - '202': - description: Import is in progress - '405': - description: Format is not available + '200': + description: Annotations have been replaced patch: operationId: jobs_partial_update_annotations summary: Update job annotations @@ -3709,83 +3730,6 @@ paths: description: The backup process has already been initiated and is not yet finished /api/projects/{id}/dataset/: - get: - operationId: projects_retrieve_dataset - description: |2 - - Utilizing this endpoint to check the status of the process - of importing a project dataset from a file is deprecated. - In addition, this endpoint no longer handles the project dataset export process. - - Consider using new API: - - `POST /api/projects//dataset/export/?save_images=True` to initiate export process - - `GET /api/requests/` to check process status - - `GET result_url` to download a prepared file - - Where: - - `rq_id` can be found in the response on initializing request - - `result_url` can be found in the response on checking status request - summary: Check dataset import status - parameters: - - in: query - name: action - schema: - type: string - enum: - - import_status - description: Used to check the import status - deprecated: true - - in: query - name: cloud_storage_id - schema: - type: integer - description: This parameter is no longer supported - deprecated: true - - in: query - name: filename - schema: - type: string - description: This parameter is no longer supported - deprecated: true - - in: query - name: format - schema: - type: string - description: This parameter is no longer supported - deprecated: true - - in: path - name: id - schema: - type: integer - description: A unique integer value identifying this project. 
- required: true - - in: query - name: location - schema: - type: string - enum: - - cloud_storage - - local - description: This parameter is no longer supported - deprecated: true - - in: query - name: rq_id - schema: - type: string - description: This parameter is no longer supported - required: true - tags: - - projects - security: - - sessionAuth: [] - csrfAuth: [] - tokenAuth: [] - - signatureAuth: [] - - basicAuth: [] - deprecated: true - responses: - '410': - description: API endpoint no longer supports exporting datasets post: operationId: projects_create_dataset description: |2 @@ -3827,24 +3771,16 @@ paths: - cloud_storage - local description: Where to import the dataset from - - in: query - name: use_default_location - schema: - type: boolean - default: true - description: Use the location that was configured in the project to import - annotations - deprecated: true tags: - projects requestBody: content: application/json: schema: - $ref: '#/components/schemas/DatasetWriteRequest' + $ref: '#/components/schemas/DatasetFileRequest' multipart/form-data: schema: - $ref: '#/components/schemas/DatasetWriteRequest' + $ref: '#/components/schemas/DatasetFileRequest' security: - sessionAuth: [] csrfAuth: [] @@ -3960,14 +3896,14 @@ paths: The backup import process is as follows: - The first request POST /api/projects/backup will initiate file upload and will create - the rq job on the server in which the process of a project creating from an uploaded backup - will be carried out. + The first request POST /api/projects/backup schedules a background job on the server + in which the process of creating a project from the uploaded backup is carried out. + + To check the status of the import process, use GET /api/requests/rq_id, + where rq_id is the request ID obtained from the response to the previous request. - After initiating the backup upload, you will receive an rq_id parameter. - Make sure to include this parameter as a query parameter in your subsequent requests - to track the status of the project creation. - Once the project has been successfully created, the server will return the id of the newly created project. + Once the import completes successfully, the response will contain the ID + of the newly created project in the result_id field. summary: Recreate a project from a backup parameters: - in: header @@ -4004,21 +3940,16 @@ paths: schema: type: integer description: Organization identifier - - in: query - name: rq_id - schema: - type: string - description: rq id tags: - projects requestBody: content: application/json: schema: - $ref: '#/components/schemas/BackupWriteRequest' + $ref: '#/components/schemas/ProjectFileRequest' multipart/form-data: schema: - $ref: '#/components/schemas/BackupWriteRequest' + $ref: '#/components/schemas/ProjectFileRequest' security: - sessionAuth: [] csrfAuth: [] @@ -4026,14 +3957,12 @@ paths: - signatureAuth: [] - basicAuth: [] responses: - '201': - description: The project has been imported '202': content: application/vnd.cvat+json: schema: $ref: '#/components/schemas/RqId' - description: Importing a backup file has been started + description: Import of the backup file has started /api/quality/conflicts: get: operationId: quality_list_conflicts @@ -4239,6 +4168,9 @@ paths: description: '' post: operationId: quality_create_report + description: | + Deprecation warning: Utilizing this endpoint to check the computation status is no longer possible. 
+ Consider using common requests API: GET /api/requests/ summary: Create a quality report parameters: - in: query @@ -4248,6 +4180,7 @@ paths: description: | The report creation request id. Can be specified to check the report creation status. + deprecated: true tags: - quality requestBody: @@ -4473,11 +4406,6 @@ paths: description: A simple equality filter for the action field schema: type: string - enum: - - autoannotate - - create - - import - - export - name: filter required: false in: query @@ -4544,19 +4472,11 @@ paths: description: A simple equality filter for the subresource field schema: type: string - enum: - - annotations - - dataset - - backup - name: target in: query description: A simple equality filter for the target field schema: type: string - enum: - - project - - task - - job - name: task_id in: query description: A simple equality filter for the task_id field @@ -5131,61 +5051,14 @@ paths: /api/tasks/{id}/annotations/: get: operationId: tasks_retrieve_annotations - description: | - Deprecation warning: - - Utilizing this endpoint to export annotations as a dataset in - a specific format is no longer possible. - - Consider using new API: - - `POST /api/tasks//dataset/export?save_images=False` to initiate export process - - `GET /api/requests/` to check process status, - where `rq_id` is request id returned on initializing request - - `GET result_url` to download a prepared file, - where `result_url` can be found in the response on checking status request summary: Get task annotations parameters: - - in: query - name: action - schema: - type: string - enum: - - download - description: This parameter is no longer supported - deprecated: true - - in: query - name: cloud_storage_id - schema: - type: integer - description: This parameter is no longer supported - deprecated: true - - in: query - name: filename - schema: - type: string - description: This parameter is no longer supported - deprecated: true - - in: query - name: format - schema: - type: string - description: This parameter is no longer supported - deprecated: true - in: path name: id schema: type: integer description: A unique integer value identifying this task. required: true - - in: query - name: location - schema: - type: string - enum: - - cloud_storage - - local - description: This parameter is no longer supported - deprecated: true tags: - tasks security: @@ -5259,10 +5132,10 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/TaskAnnotationsWriteRequest' + $ref: '#/components/schemas/AnnotationFileRequest' multipart/form-data: schema: - $ref: '#/components/schemas/TaskAnnotationsWriteRequest' + $ref: '#/components/schemas/AnnotationFileRequest' security: - sessionAuth: [] csrfAuth: [] @@ -5282,46 +5155,24 @@ paths: description: Format is not available put: operationId: tasks_update_annotations - description: |2 - - Utilizing this endpoint to check status of the import process is deprecated - in favor of the new requests API: - - GET /api/requests/, where `rq_id` parameter is returned in the response - on initializing request. - summary: Replace task annotations / Get annotation import status + summary: Replace task annotations parameters: - - in: query - name: format - schema: - type: string - description: |- - Input format name - You can get the list of supported formats at: - /server/annotation/formats - deprecated: true - in: path name: id schema: type: integer description: A unique integer value identifying this task. 
required: true - - in: query - name: rq_id - schema: - type: string - description: rq id - deprecated: true tags: - tasks requestBody: content: application/json: schema: - $ref: '#/components/schemas/TaskAnnotationsUpdateRequest' + $ref: '#/components/schemas/LabeledDataRequest' multipart/form-data: schema: - $ref: '#/components/schemas/TaskAnnotationsUpdateRequest' + $ref: '#/components/schemas/LabeledDataRequest' security: - sessionAuth: [] csrfAuth: [] @@ -5329,12 +5180,8 @@ paths: - signatureAuth: [] - basicAuth: [] responses: - '201': - description: Import has finished - '202': - description: Import is in progress - '405': - description: Format is not available + '200': + description: Annotations have been replaced patch: operationId: tasks_partial_update_annotations summary: Update task annotations @@ -5877,14 +5724,14 @@ paths: The backup import process is as follows: - The first request POST /api/tasks/backup will initiate file upload and will create - the rq job on the server in which the process of a task creating from an uploaded backup - will be carried out. + The first request POST /api/tasks/backup creates a background job on the server + in which the process of a task creating from an uploaded backup is carried out. - After initiating the backup upload, you will receive an rq_id parameter. - Make sure to include this parameter as a query parameter in your subsequent requests - to track the status of the task creation. - Once the task has been successfully created, the server will return the id of the newly created task. + To check the status of the import process, use GET /api/requests/rq_id, + where rq_id is the request ID obtained from the response to the previous request. + + Once the import completes successfully, the response will contain the ID + of the newly created task in the result_id field. 
summary: Recreate a task from a backup parameters: - in: header @@ -5921,11 +5768,6 @@ paths: schema: type: integer description: Organization identifier - - in: query - name: rq_id - schema: - type: string - description: rq id tags: - tasks requestBody: @@ -5943,14 +5785,12 @@ paths: - signatureAuth: [] - basicAuth: [] responses: - '201': - description: The task has been imported '202': content: application/vnd.cvat+json: schema: $ref: '#/components/schemas/RqId' - description: Importing a backup file has been started + description: Import of the backup file has started /api/users: get: operationId: users_list @@ -6788,10 +6628,6 @@ components: required: - spec_id - value - BackupWriteRequest: - oneOf: - - $ref: '#/components/schemas/ProjectFileRequest' - nullable: true BasicOrganization: type: object properties: @@ -7364,10 +7200,6 @@ components: required: - exporters - importers - DatasetWriteRequest: - oneOf: - - $ref: '#/components/schemas/DatasetFileRequest' - nullable: true Event: type: object properties: @@ -7835,10 +7667,6 @@ components: count: type: integer readOnly: true - JobAnnotationsUpdateRequest: - oneOf: - - $ref: '#/components/schemas/LabeledDataRequest' - - $ref: '#/components/schemas/AnnotationFileRequest' JobRead: type: object properties: @@ -10098,7 +9926,7 @@ components: type: type: string target: - $ref: '#/components/schemas/RequestDataOperationTargetEnum' + type: string project_id: type: integer nullable: true @@ -10117,16 +9945,6 @@ components: required: - target - type - RequestDataOperationTargetEnum: - enum: - - project - - task - - job - type: string - description: |- - * `project` - Project - * `task` - Task - * `job` - Job RequestStatus: enum: - queued @@ -10545,15 +10363,6 @@ components: * `accuracy` - ACCURACY * `precision` - PRECISION * `recall` - RECALL - TaskAnnotationsUpdateRequest: - oneOf: - - $ref: '#/components/schemas/LabeledDataRequest' - - $ref: '#/components/schemas/AnnotationFileRequest' - nullable: true - TaskAnnotationsWriteRequest: - oneOf: - - $ref: '#/components/schemas/AnnotationFileRequest' - nullable: true TaskFileRequest: type: object properties: diff --git a/cvat/settings/base.py b/cvat/settings/base.py index b09aa07cec63..976fc58d629f 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -295,10 +295,14 @@ class CVAT_QUEUES(Enum): CVAT_QUEUES.IMPORT_DATA.value: { **REDIS_INMEM_SETTINGS, "DEFAULT_TIMEOUT": "4h", + # custom fields + "PARSED_JOB_ID_CLASS": "cvat.apps.engine.rq.ImportRequestId", }, CVAT_QUEUES.EXPORT_DATA.value: { **REDIS_INMEM_SETTINGS, "DEFAULT_TIMEOUT": "4h", + # custom fields + "PARSED_JOB_ID_CLASS": "cvat.apps.engine.rq.ExportRequestId", }, CVAT_QUEUES.AUTO_ANNOTATION.value: { **REDIS_INMEM_SETTINGS, @@ -315,6 +319,8 @@ class CVAT_QUEUES(Enum): CVAT_QUEUES.QUALITY_REPORTS.value: { **REDIS_INMEM_SETTINGS, "DEFAULT_TIMEOUT": "1h", + # custom fields + "PARSED_JOB_ID_CLASS": "cvat.apps.quality_control.rq.QualityRequestId", }, CVAT_QUEUES.CLEANING.value: { **REDIS_INMEM_SETTINGS, @@ -327,6 +333,8 @@ class CVAT_QUEUES(Enum): CVAT_QUEUES.CONSENSUS.value: { **REDIS_INMEM_SETTINGS, "DEFAULT_TIMEOUT": "1h", + # custom fields + "PARSED_JOB_ID_CLASS": "cvat.apps.consensus.rq.ConsensusRequestId", }, } @@ -665,7 +673,7 @@ class CVAT_QUEUES(Enum): "SortingMethod": "cvat.apps.engine.models.SortingMethod", "WebhookType": "cvat.apps.webhooks.models.WebhookTypeChoice", "WebhookContentType": "cvat.apps.webhooks.models.WebhookContentTypeChoice", - "RequestStatus": "cvat.apps.engine.models.RequestStatus", + "RequestStatus": 
"cvat.apps.redis_handler.serializers.RequestStatus", "ValidationMode": "cvat.apps.engine.models.ValidationMode", "FrameSelectionMethod": "cvat.apps.engine.models.JobFrameSelectionMethod", }, diff --git a/cvat/urls.py b/cvat/urls.py index cdc73b5e8568..a72331a5b7a5 100644 --- a/cvat/urls.py +++ b/cvat/urls.py @@ -25,6 +25,7 @@ urlpatterns = [ path("admin/", admin.site.urls), path("", include("cvat.apps.engine.urls")), + path("", include("cvat.apps.redis_handler.urls")), path("django-rq/", include("django_rq.urls")), ] diff --git a/tests/python/rest_api/test_analytics.py b/tests/python/rest_api/test_analytics.py index 22d3520362e0..f1381091b160 100644 --- a/tests/python/rest_api/test_analytics.py +++ b/tests/python/rest_api/test_analytics.py @@ -11,14 +11,17 @@ from http import HTTPStatus from io import StringIO from time import sleep +from typing import Optional import pytest +from cvat_sdk.api_client import ApiClient from dateutil import parser as datetime_parser -from shared.utils.config import delete_method, make_api_client, server_get +import shared.utils.s3 as s3 +from shared.utils.config import delete_method, get_method, make_api_client, server_get from shared.utils.helpers import generate_image_files -from .utils import create_task +from .utils import create_task, wait_and_download_v2, wait_background_request class TestGetAnalytics: @@ -148,27 +151,53 @@ def _wait_for_request_ids(self, event_filters): assert False, "Could not wait for expected request IDs" @staticmethod - def _export_events(endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs): - query_id = "" - for _ in range(max_retries): + def _export_events( + api_client: ApiClient, + *, + api_version: int, + max_retries: int = 20, + interval: float = 0.1, + **kwargs, + ) -> Optional[bytes]: + if api_version == 1: + endpoint = api_client.events_api.list_endpoint + query_id = "" + for _ in range(max_retries): + (_, response) = endpoint.call_with_http_info( + **kwargs, query_id=query_id, _parse_response=False + ) + if response.status == HTTPStatus.CREATED: + break + assert response.status == HTTPStatus.ACCEPTED + if not query_id: + response_json = json.loads(response.data) + query_id = response_json["query_id"] + sleep(interval) + + assert response.status == HTTPStatus.CREATED + (_, response) = endpoint.call_with_http_info( - **kwargs, query_id=query_id, _parse_response=False + **kwargs, query_id=query_id, action="download", _parse_response=False ) - if response.status == HTTPStatus.CREATED: - break - assert response.status == HTTPStatus.ACCEPTED - if not query_id: - response_json = json.loads(response.data) - query_id = response_json["query_id"] - sleep(interval) - assert response.status == HTTPStatus.CREATED - - (_, response) = endpoint.call_with_http_info( - **kwargs, query_id=query_id, action="download", _parse_response=False - ) - assert response.status == HTTPStatus.OK + assert response.status == HTTPStatus.OK + + return response.data + + assert api_version == 2 + + request_id, response = api_client.events_api.create_export(**kwargs, _check_status=False) + assert response.status == HTTPStatus.ACCEPTED + + if "location" in kwargs and "cloud_storage_id" in kwargs: + background_request, response = wait_background_request( + api_client, rq_id=request_id.rq_id, max_retries=max_retries, interval=interval + ) + assert background_request.result_url is None + return None - return response.data + return wait_and_download_v2( + api_client, rq_id=request_id.rq_id, max_retries=max_retries, interval=interval + ) 
@staticmethod def _csv_to_dict(csv_data): @@ -190,11 +219,12 @@ def _filter_events(events, filters): return res - def _test_get_audit_logs_as_csv(self, **kwargs): + def _test_get_audit_logs_as_csv(self, *, api_version: int = 2, **kwargs): with make_api_client(self._USERNAME) as api_client: - return self._export_events(api_client.events_api.list_endpoint, **kwargs) + return self._export_events(api_client, api_version=api_version, **kwargs) - def test_entry_to_time_interval(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_entry_to_time_interval(self, api_version: int): now = datetime.now(timezone.utc) to_datetime = now from_datetime = now - timedelta(minutes=3) @@ -204,7 +234,7 @@ def test_entry_to_time_interval(self): "to": to_datetime.isoformat(), } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) assert len(events) @@ -212,12 +242,13 @@ def test_entry_to_time_interval(self): event_timestamp = datetime_parser.isoparse(event["timestamp"]) assert from_datetime <= event_timestamp <= to_datetime - def test_filter_by_project(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_filter_by_project(self, api_version: int): query_params = { "project_id": self.project_id, } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) filtered_events = self._filter_events(events, [("project_id", [str(self.project_id)])]) @@ -229,13 +260,14 @@ def test_filter_by_project(self): assert event_count["create:task"] == 2 assert event_count["create:job"] == 4 - def test_filter_by_task(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_filter_by_task(self, api_version: int): for task_id in self.task_ids: query_params = { "task_id": task_id, } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) filtered_events = self._filter_events(events, [("task_id", [str(task_id)])]) @@ -246,20 +278,22 @@ def test_filter_by_task(self): assert event_count["create:task"] == 1 assert event_count["create:job"] == 2 - def test_filter_by_non_existent_project(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_filter_by_non_existent_project(self, api_version: int): query_params = { "project_id": self.project_id + 100, } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) assert len(events) == 0 - def test_user_and_request_id_not_empty(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_user_and_request_id_not_empty(self, api_version: int): query_params = { "project_id": self.project_id, } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) for event in events: @@ -272,7 +306,8 @@ def test_user_and_request_id_not_empty(self): assert request_id uuid.UUID(request_id) - def test_delete_project(self): + @pytest.mark.parametrize("api_version", [1, 2]) + def test_delete_project(self, api_version: int): response = delete_method("admin1", f"projects/{self.project_id}") assert response.status_code == HTTPStatus.NO_CONTENT @@ -299,7 +334,7 @@ def 
test_delete_project(self): "project_id": self.project_id, } - data = self._test_get_audit_logs_as_csv(**query_params) + data = self._test_get_audit_logs_as_csv(api_version=api_version, **query_params) events = self._csv_to_dict(data) filtered_events = self._filter_events(events, [("project_id", [str(self.project_id)])]) @@ -310,3 +345,31 @@ def test_delete_project(self): assert event_count["delete:project"] == 1 assert event_count["delete:task"] == 2 assert event_count["delete:job"] == 4 + + @pytest.mark.with_external_services + @pytest.mark.parametrize("api_version, allowed", [(1, False), (2, True)]) + @pytest.mark.parametrize("cloud_storage_id", [3]) # import/export bucket + def test_export_to_cloud( + self, api_version: int, allowed: bool, cloud_storage_id: int, cloud_storages + ): + query_params = { + "api_version": api_version, + "location": "cloud_storage", + "cloud_storage_id": cloud_storage_id, + "filename": "test.csv", + "task_id": self.task_ids[0], + } + if allowed: + data = self._test_get_audit_logs_as_csv(**query_params) + assert data is None + s3_client = s3.make_client(bucket=cloud_storages[cloud_storage_id]["resource"]) + data = s3_client.download_fileobj(query_params["filename"]) + events = self._csv_to_dict(data) + assert len(events) + else: + response = get_method(self._USERNAME, "events", **query_params) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert ( + response.json()[0] + == "This endpoint does not support exporting events to cloud storage" + ) diff --git a/tests/python/rest_api/test_consensus.py b/tests/python/rest_api/test_consensus.py index 2294990ee601..81c6bbfa937c 100644 --- a/tests/python/rest_api/test_consensus.py +++ b/tests/python/rest_api/test_consensus.py @@ -6,6 +6,7 @@ from copy import deepcopy from functools import partial from http import HTTPStatus +from itertools import product from typing import Any, Dict, Optional, Tuple import pytest @@ -17,7 +18,13 @@ from shared.utils.config import make_api_client -from .utils import CollectionSimpleFilterTestBase, compare_annotations +from .utils import ( + CollectionSimpleFilterTestBase, + compare_annotations, + invite_user_to_org, + register_new_user, + wait_background_request, +) class _PermissionTestBase: @@ -49,16 +56,13 @@ def merge( return response assert response.status == HTTPStatus.ACCEPTED - rq_id = json.loads(response.data)["rq_id"] - - while wait_result: - (_, response) = api_client.consensus_api.create_merge( - rq_id=rq_id, _parse_response=False + if wait_result: + rq_id = json.loads(response.data)["rq_id"] + background_request, _ = wait_background_request(api_client, rq_id) + assert ( + background_request.status.value + == models.RequestStatus.allowed_values[("value",)]["FINISHED"] ) - assert response.status in [HTTPStatus.CREATED, HTTPStatus.ACCEPTED] - - if response.status == HTTPStatus.CREATED: - break return response @@ -194,14 +198,14 @@ class TestPostConsensusMerge(_PermissionTestBase): def test_can_merge_task_with_consensus_jobs(self, admin_user, tasks): task_id = next(t["id"] for t in tasks if t["consensus_enabled"]) - assert self.merge(user=admin_user, task_id=task_id).status == HTTPStatus.CREATED + self.merge(user=admin_user, task_id=task_id) def test_can_merge_consensus_job(self, admin_user, jobs): job_id = next( j["id"] for j in jobs if j["type"] == "annotation" and j["consensus_replicas"] > 0 ) - assert self.merge(user=admin_user, job_id=job_id).status == HTTPStatus.CREATED + self.merge(user=admin_user, job_id=job_id) def 
test_cannot_merge_task_without_consensus_jobs(self, admin_user, tasks): task_id = next(t["id"] for t in tasks if not t["consensus_enabled"]) @@ -280,70 +284,111 @@ def test_user_merge_in_org_task( else: self._test_merge_403(user["username"], task_id=task["id"]) - # only rq job owner or admin now has the right to check status of report creation - def _test_check_merge_status_by_non_rq_job_owner( + # users with task:view rights can check status of report creation + def _test_check_merge_status( self, rq_id: str, *, staff_user: str, - other_user: str, + another_user: str, + another_user_status: int = HTTPStatus.FORBIDDEN, ): - with make_api_client(other_user) as api_client: - (_, response) = api_client.consensus_api.create_merge( - rq_id=rq_id, _parse_response=False, _check_status=False + with make_api_client(another_user) as api_client: + (_, response) = api_client.requests_api.retrieve( + rq_id, _parse_response=False, _check_status=False ) - assert response.status == HTTPStatus.NOT_FOUND - assert json.loads(response.data)["detail"] == "Unknown request id" + assert response.status == another_user_status with make_api_client(staff_user) as api_client: - (_, response) = api_client.consensus_api.create_merge( - rq_id=rq_id, _parse_response=False, _check_status=False - ) - assert response.status in {HTTPStatus.ACCEPTED, HTTPStatus.CREATED} + wait_background_request(api_client, rq_id) - def test_non_rq_job_owner_cannot_check_status_of_merge_in_sandbox( + def test_user_without_rights_cannot_check_status_of_merge_in_sandbox( self, find_sandbox_task_with_consensus, users, ): task, task_staff = find_sandbox_task_with_consensus(is_staff=True) - other_user = next( + another_user = next( u for u in users if ( u["id"] != task_staff["id"] and not u["is_superuser"] and u["id"] != task["owner"]["id"] + and u["id"] != (task["assignee"] or {}).get("id") ) ) rq_id = self.request_merge(task_id=task["id"], user=task_staff["username"]) - self._test_check_merge_status_by_non_rq_job_owner( - rq_id, staff_user=task_staff["username"], other_user=other_user["username"] + self._test_check_merge_status( + rq_id, staff_user=task_staff["username"], another_user=another_user["username"] ) - @pytest.mark.parametrize("role", _PermissionTestBase._default_org_roles) - def test_non_rq_job_owner_cannot_check_status_of_merge_in_org( + @pytest.mark.parametrize( + "same_org, role", + [ + pair + for pair in product([True, False], _PermissionTestBase._default_org_roles) + if not (pair[0] and pair[1] in ["owner", "maintainer"]) + ], + ) + def test_user_without_rights_cannot_check_status_of_merge_in_org( self, find_org_task_with_consensus, - find_users, + same_org: bool, role: str, + organizations, ): task, task_staff = find_org_task_with_consensus(is_staff=True, user_org_role="supervisor") - other_user = next( - u - for u in find_users(role=role, org=task["organization"]) - if ( - u["id"] != task_staff["id"] - and not u["is_superuser"] - and u["id"] != task["owner"]["id"] - ) + # create a new user that passes the requirements + another_user = register_new_user(f"{same_org}{role}") + org_id = ( + task["organization"] + if same_org + else next(o for o in organizations if o["id"] != task["organization"])["id"] ) + invite_user_to_org(another_user["email"], org_id, role) + rq_id = self.request_merge(task_id=task["id"], user=task_staff["username"]) - self._test_check_merge_status_by_non_rq_job_owner( - rq_id, staff_user=task_staff["username"], other_user=other_user["username"] + self._test_check_merge_status( + rq_id, 
staff_user=task_staff["username"], another_user=another_user["username"] + ) + + @pytest.mark.parametrize( + "role", + # owner and maintainer have rights even without being assigned to a task + ("supervisor", "worker"), + ) + def test_task_assignee_can_check_status_of_merge_in_org( + self, + find_org_task_with_consensus, + role: str, + ): + task, another_user = find_org_task_with_consensus(is_staff=False, user_org_role=role) + task_owner = task["owner"] + + rq_id = self.request_merge(task_id=task["id"], user=task_owner["username"]) + self._test_check_merge_status( + rq_id, + staff_user=task_owner["username"], + another_user=another_user["username"], + ) + + with make_api_client(task_owner["username"]) as api_client: + api_client.tasks_api.partial_update( + task["id"], + patched_task_write_request=models.PatchedTaskWriteRequest( + assignee_id=another_user["id"] + ), + ) + + self._test_check_merge_status( + rq_id, + staff_user=task_owner["username"], + another_user=another_user["username"], + another_user_status=HTTPStatus.OK, ) @pytest.mark.parametrize("is_sandbox", (True, False)) @@ -373,10 +418,7 @@ def test_admin_can_check_status_of_merge( rq_id = self.request_merge(task_id=task["id"], user=task_staff["username"]) with make_api_client(admin["username"]) as api_client: - (_, response) = api_client.consensus_api.create_merge( - rq_id=rq_id, _parse_response=False - ) - assert response.status in {HTTPStatus.ACCEPTED, HTTPStatus.CREATED} + wait_background_request(api_client, rq_id) class TestSimpleConsensusSettingsFilters(CollectionSimpleFilterTestBase): @@ -667,13 +709,11 @@ def test_quorum_is_applied(self, admin_user, jobs, labels, consensus_settings, t api_client.jobs_api.update_annotations( replicas[0]["id"], - job_annotations_update_request=models.JobAnnotationsUpdateRequest(shapes=[bbox1]), + labeled_data_request=models.LabeledDataRequest(shapes=[bbox1]), ) api_client.jobs_api.update_annotations( replicas[1]["id"], - job_annotations_update_request=models.JobAnnotationsUpdateRequest( - shapes=[bbox1, bbox2] - ), + labeled_data_request=models.LabeledDataRequest(shapes=[bbox1, bbox2]), ) self.merge(job_id=parent_job["id"], user=admin_user) diff --git a/tests/python/rest_api/test_projects.py b/tests/python/rest_api/test_projects.py index d049f46a8ef1..5bb5f2b1723c 100644 --- a/tests/python/rest_api/test_projects.py +++ b/tests/python/rest_api/test_projects.py @@ -43,6 +43,7 @@ export_dataset, export_project_backup, export_project_dataset, + import_project_backup, ) @@ -389,15 +390,7 @@ def test_admin_can_get_project_backup_and_create_project_by_backup(self, admin_u tmp_file = io.BytesIO(backup) tmp_file.name = "dataset.zip" - import_data = { - "project_file": tmp_file, - } - - with make_api_client(admin_user) as api_client: - (_, response) = api_client.projects_api.create_backup( - backup_write_request=deepcopy(import_data), _content_type="multipart/form-data" - ) - assert response.status == HTTPStatus.ACCEPTED + import_project_backup(admin_user, tmp_file) @pytest.mark.usefixtures("restore_db_per_function") @@ -630,7 +623,7 @@ def _test_import_project(self, username, project_id, format_name, data): (_, response) = api_client.projects_api.create_dataset( id=project_id, format=format_name, - dataset_write_request=deepcopy(data), + dataset_file_request={"dataset_file": data}, _content_type="multipart/form-data", ) assert response.status == HTTPStatus.ACCEPTED @@ -670,11 +663,7 @@ def test_can_import_dataset_in_org(self, admin_user: str): tmp_file = io.BytesIO(dataset) tmp_file.name = 
"dataset.zip" - import_data = { - "dataset_file": tmp_file, - } - - self._test_import_project(admin_user, project_id, "CVAT 1.1", import_data) + self._test_import_project(admin_user, project_id, "CVAT 1.1", tmp_file) @pytest.mark.parametrize( "export_format, import_format", @@ -712,11 +701,8 @@ def test_can_export_and_import_dataset_with_skeletons( tmp_file = io.BytesIO(dataset) tmp_file.name = "dataset.zip" - import_data = { - "dataset_file": tmp_file, - } - self._test_import_project(admin_user, project_id, import_format, import_data) + self._test_import_project(admin_user, project_id, import_format, tmp_file) @pytest.mark.parametrize("format_name", ("Datumaro 1.0", "ImageNet 1.0", "PASCAL VOC 1.1")) def test_can_import_export_dataset_with_some_format(self, format_name: str): @@ -735,11 +721,7 @@ def test_can_import_export_dataset_with_some_format(self, format_name: str): tmp_file = io.BytesIO(dataset) tmp_file.name = "dataset.zip" - import_data = { - "dataset_file": tmp_file, - } - - self._test_import_project(username, project_id, format_name, import_data) + self._test_import_project(username, project_id, format_name, tmp_file) @pytest.mark.parametrize("username, pid", [("admin1", 8)]) @pytest.mark.parametrize( @@ -801,11 +783,7 @@ def test_can_import_export_annotations_with_rotation(self): tmp_file = io.BytesIO(dataset) tmp_file.name = "dataset.zip" - import_data = { - "dataset_file": tmp_file, - } - - self._test_import_project(username, project_id, "CVAT 1.1", import_data) + self._test_import_project(username, project_id, "CVAT 1.1", tmp_file) response = get_method(username, f"tasks", project_id=project_id) assert response.status_code == HTTPStatus.OK @@ -915,11 +893,7 @@ def test_can_export_and_import_dataset_after_deleting_related_storage( with io.BytesIO(dataset) as tmp_file: tmp_file.name = "dataset.zip" - import_data = { - "dataset_file": tmp_file, - } - - self._test_import_project(admin_user, project_id, "CVAT 1.1", import_data) + self._test_import_project(admin_user, project_id, "CVAT 1.1", tmp_file) @pytest.mark.parametrize( "dimension, format_name", @@ -972,14 +946,12 @@ def _export_task(task_id: int, format_name: str) -> io.BytesIO: ) ) - import_data = {"dataset_file": dataset_file} - with pytest.raises(exceptions.ApiException, match="Dataset file should be zip archive"): self._test_import_project( admin_user, project.id, format_name=format_name, - data=import_data, + data=dataset_file, ) @pytest.mark.parametrize( diff --git a/tests/python/rest_api/test_quality_control.py b/tests/python/rest_api/test_quality_control.py index aab71846c7f0..9da1381ee56f 100644 --- a/tests/python/rest_api/test_quality_control.py +++ b/tests/python/rest_api/test_quality_control.py @@ -7,7 +7,7 @@ from copy import deepcopy from functools import partial from http import HTTPStatus -from itertools import groupby +from itertools import groupby, product from typing import Any, Callable, Optional import pytest @@ -18,7 +18,13 @@ from shared.utils.config import make_api_client -from .utils import CollectionSimpleFilterTestBase, parse_frame_step +from .utils import ( + CollectionSimpleFilterTestBase, + invite_user_to_org, + parse_frame_step, + register_new_user, + wait_background_request, +) class _PermissionTestBase: @@ -31,14 +37,14 @@ def create_quality_report(self, user: str, task_id: int): assert response.status == HTTPStatus.ACCEPTED rq_id = json.loads(response.data)["rq_id"] - while True: - (_, response) = api_client.quality_api.create_report( - rq_id=rq_id, _parse_response=False - ) - assert 
response.status in [HTTPStatus.CREATED, HTTPStatus.ACCEPTED] + background_request, _ = wait_background_request(api_client, rq_id) + assert ( + background_request.status.value + == models.RequestStatus.allowed_values[("value",)]["FINISHED"] + ) + report_id = background_request.result_id - if response.status == HTTPStatus.CREATED: - break + _, response = api_client.quality_api.retrieve_report(report_id, _parse_response=False) return json.loads(response.data) @@ -59,7 +65,7 @@ def create_gt_job(self, user, task_id): (labels, _) = api_client.labels_api.list(task_id=task_id) api_client.jobs_api.update_annotations( job.id, - job_annotations_update_request=dict( + labeled_data_request=dict( shapes=[ dict( frame=start_frame, @@ -169,6 +175,7 @@ def find_org_task_without_gt(self, find_org_task): ("worker", False, False), ], ) + _default_org_roles = ("owner", "maintainer", "supervisor", "worker") @pytest.mark.usefixtures("restore_db_per_class") @@ -581,28 +588,63 @@ def _initialize_report_creation(task_id: int, user: str) -> str: return rq_id - # only rq job owner or admin now has the right to check status of report creation - def _test_check_status_of_report_creation_by_non_rq_job_owner( + # users with task:view rights can check status of report creation + def _test_check_status_of_report_creation( self, rq_id: str, *, task_staff: str, another_user: str, + another_user_status: int = HTTPStatus.FORBIDDEN, ): with make_api_client(another_user) as api_client: - (_, response) = api_client.quality_api.create_report( - rq_id=rq_id, _parse_response=False, _check_status=False + (_, response) = api_client.requests_api.retrieve( + rq_id, _parse_response=False, _check_status=False ) - assert response.status == HTTPStatus.NOT_FOUND - assert json.loads(response.data)["detail"] == "Unknown request id" + assert response.status == another_user_status with make_api_client(task_staff) as api_client: - (_, response) = api_client.quality_api.create_report( - rq_id=rq_id, _parse_response=False, _check_status=False + wait_background_request(api_client, rq_id) + + @pytest.mark.parametrize( + "role", + # owner and maintainer have rights even without being assigned to a task + ("supervisor", "worker"), + ) + def test_task_assignee_can_check_status_of_report_creation_in_org( + self, + find_org_task_without_gt: Callable[[bool, str], tuple[dict[str, Any], dict[str, Any]]], + role: str, + admin_user: str, + ): + task, another_user = find_org_task_without_gt(is_staff=False, user_org_role=role) + self.create_gt_job(admin_user, task["id"]) + + task_owner = task["owner"] + + rq_id = self._initialize_report_creation(task_id=task["id"], user=task_owner["username"]) + self._test_check_status_of_report_creation( + rq_id, + task_staff=task_owner["username"], + another_user=another_user["username"], + ) + + with make_api_client(task_owner["username"]) as api_client: + api_client.tasks_api.partial_update( + task["id"], + patched_task_write_request=models.PatchedTaskWriteRequest( + assignee_id=another_user["id"] + ), ) - assert response.status in {HTTPStatus.ACCEPTED, HTTPStatus.CREATED} - def test_non_rq_job_owner_cannot_check_status_of_report_creation_in_sandbox( + self._test_check_status_of_report_creation( + rq_id, + task_staff=task_owner["username"], + another_user=another_user["username"], + another_user_status=HTTPStatus.OK, + ) + + def test_user_without_rights_cannot_check_status_of_report_creation_in_sandbox( self, find_sandbox_task_without_gt: Callable[[bool], tuple[dict[str, Any], dict[str, Any]]], admin_user: str, @@ -619,36 
+661,45 @@ def test_non_rq_job_owner_cannot_check_status_of_report_creation_in_sandbox( u["id"] != task_staff["id"] and not u["is_superuser"] and u["id"] != task["owner"]["id"] + and u["id"] != (task["assignee"] or {}).get("id") ) ) rq_id = self._initialize_report_creation(task["id"], task_staff["username"]) - self._test_check_status_of_report_creation_by_non_rq_job_owner( + self._test_check_status_of_report_creation( rq_id, task_staff=task_staff["username"], another_user=another_user["username"] ) - @pytest.mark.parametrize("role", ("owner", "maintainer", "supervisor", "worker")) - def test_non_rq_job_owner_cannot_check_status_of_report_creation_in_org( + @pytest.mark.parametrize( + "same_org, role", + [ + pair + for pair in product([True, False], _PermissionTestBase._default_org_roles) + if not (pair[0] and pair[1] in ["owner", "maintainer"]) + ], + ) + def test_user_without_rights_cannot_check_status_of_report_creation_in_org( self, + same_org: bool, role: str, admin_user: str, find_org_task_without_gt: Callable[[bool, str], tuple[dict[str, Any], dict[str, Any]]], - find_users: Callable[..., list[dict[str, Any]]], + organizations, ): task, task_staff = find_org_task_without_gt(is_staff=True, user_org_role="supervisor") self.create_gt_job(admin_user, task["id"]) - another_user = next( - u - for u in find_users(role=role, org=task["organization"]) - if ( - u["id"] != task_staff["id"] - and not u["is_superuser"] - and u["id"] != task["owner"]["id"] - ) + # create another user that passes the requirements + another_user = register_new_user(f"{same_org}{role}") + org_id = ( + task["organization"] + if same_org + else next(o for o in organizations if o["id"] != task["organization"])["id"] ) + invite_user_to_org(another_user["email"], org_id, role) + rq_id = self._initialize_report_creation(task["id"], task_staff["username"]) - self._test_check_status_of_report_creation_by_non_rq_job_owner( + self._test_check_status_of_report_creation( rq_id, task_staff=task_staff["username"], another_user=another_user["username"] ) @@ -682,8 +733,7 @@ def test_admin_can_check_status_of_report_creation( rq_id = self._initialize_report_creation(task["id"], task_staff["username"]) with make_api_client(admin["username"]) as api_client: - (_, response) = api_client.quality_api.create_report(rq_id=rq_id, _parse_response=False) - assert response.status in {HTTPStatus.ACCEPTED, HTTPStatus.CREATED} + wait_background_request(api_client, rq_id) class TestSimpleQualityReportsFilters(CollectionSimpleFilterTestBase): @@ -1308,7 +1358,7 @@ def test_can_compute_quality_if_non_skeleton_label_follows_skeleton_label( new_label_id = new_label_obj.results[0].id api_client.tasks_api.update_annotations( task_id, - task_annotations_update_request={ + labeled_data_request={ "shapes": [ models.LabeledShapeRequest( type="rectangle", @@ -1518,12 +1568,10 @@ def test_quality_metrics_in_task_with_gt_and_tracks( ] } - api_client.jobs_api.update_annotations( - gt_job.id, job_annotations_update_request=gt_annotations - ) + api_client.jobs_api.update_annotations(gt_job.id, labeled_data_request=gt_annotations) api_client.tasks_api.update_annotations( - task_id, task_annotations_update_request=normal_annotations + task_id, labeled_data_request=normal_annotations ) api_client.jobs_api.partial_update( diff --git a/tests/python/rest_api/test_requests.py b/tests/python/rest_api/test_requests.py index b9bff93eece6..797307e3c60e 100644 --- a/tests/python/rest_api/test_requests.py +++ b/tests/python/rest_api/test_requests.py @@ -5,7 +5,8 @@ import io 
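Among other things, the `test_requests.py` changes below verify backward compatibility for legacy request IDs: an old-format ID such as `export:project-<id>-<subresource>-in-<format>-format-by-<user_id>` is still accepted by `GET /api/requests/{rq_id}`, which resolves it to the request under its new canonical ID. A hedged sketch of that lookup, assuming a locally running instance and hypothetical credentials and IDs:

```python
# Sketch of resolving a legacy-format request ID via the unified requests API.
# The host, credentials, and the concrete legacy ID are assumptions for
# illustration; the retrieve() call and its (data, response) return shape match
# the generated SDK usage in the tests below.
from cvat_sdk.api_client import ApiClient, Configuration

config = Configuration(
    host="http://localhost:8080",  # assumed instance
    username="admin1",             # assumed credentials
    password="password",
)

with ApiClient(config) as api_client:
    # Hypothetical legacy ID: project 1, dataset export in "CVAT for images 1.1"
    # format, initiated by user 2 (spaces -> "_", dots -> "@").
    legacy_id = "export:project-1-dataset-in-CVAT_for_images_1@1-format-by-2"

    background_request, response = api_client.requests_api.retrieve(legacy_id)
    assert response.status == 200
    print(background_request.id)  # the canonical (new-style) request ID
```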
import json from http import HTTPStatus -from urllib.parse import urlparse +from typing import Optional +from urllib.parse import parse_qsl, urlparse import pytest from cvat_sdk.api_client import ApiClient, models @@ -25,7 +26,12 @@ export_project_dataset, export_task_backup, export_task_dataset, + import_job_annotations, + import_project_backup, + import_project_dataset, + import_task_annotations, import_task_backup, + wait_background_request, ) @@ -112,9 +118,7 @@ def _make_requests(project_ids: list[int], task_ids: list[int], job_ids: list[in if resource_type == "task" and subresource == "backup": import_task_backup( self.user, - data={ - "task_file": tmp_file, - }, + file_content=tmp_file, ) empty_file = io.BytesIO(b"empty_file") @@ -123,9 +127,7 @@ def _make_requests(project_ids: list[int], task_ids: list[int], job_ids: list[in # import corrupted backup import_task_backup( self.user, - data={ - "task_file": empty_file, - }, + file_content=empty_file, ) return _make_requests @@ -259,8 +261,8 @@ def test_list_requests_when_there_is_job_with_non_regular_or_corrupted_meta( assert 2 == background_requests.count corrupted_job, normal_job = background_requests.results - - remove_meta_command = f'redis-cli -e HDEL rq:job:{corrupted_job["id"]} meta' + corrupted_job_key = f"rq:job:{corrupted_job['id']}" + remove_meta_command = f'redis-cli -e HDEL "{corrupted_job_key}" meta' if request.config.getoption("--platform") == "local": stdout, _ = docker_exec_redis_inmem(["sh", "-c", remove_meta_command]) @@ -286,10 +288,14 @@ def test_list_requests_when_there_is_job_with_non_regular_or_corrupted_meta( @pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestGetRequests: - def _test_get_request_200(self, api_client: ApiClient, rq_id: str, **kwargs) -> models.Request: + def _test_get_request_200( + self, api_client: ApiClient, rq_id: str, validate_rq_id: bool = True, **kwargs + ) -> models.Request: (background_request, response) = api_client.requests_api.retrieve(rq_id, **kwargs) assert response.status == HTTPStatus.OK - assert background_request.id == rq_id + + if validate_rq_id: + assert background_request.id == rq_id return background_request @@ -312,16 +318,16 @@ def test_owner_can_retrieve_request(self, format_name: str, save_images: bool, p owner = project["owner"] subresource = "dataset" if save_images else "annotations" - export_project_dataset( + request_id = export_project_dataset( owner["username"], save_images=save_images, id=project["id"], download_result=False, + format=format_name, ) - rq_id = f'export:project-{project["id"]}-{subresource}-in-{format_name.replace(" ", "_").replace(".", "@")}-format-by-{owner["id"]}' with make_api_client(owner["username"]) as owner_client: - bg_request = self._test_get_request_200(owner_client, rq_id) + bg_request = self._test_get_request_200(owner_client, request_id) assert ( bg_request.created_date @@ -331,7 +337,7 @@ def test_owner_can_retrieve_request(self, format_name: str, save_images: bool, p ) assert bg_request.operation.format == format_name assert bg_request.operation.project_id == project["id"] - assert bg_request.operation.target.value == "project" + assert bg_request.operation.target == "project" assert bg_request.operation.task_id is None assert bg_request.operation.job_id is None assert bg_request.operation.type == f"export:{subresource}" @@ -341,8 +347,7 @@ def test_owner_can_retrieve_request(self, format_name: str, save_images: bool, p parsed_url = urlparse(bg_request.result_url) assert all([parsed_url.scheme, 
parsed_url.netloc, parsed_url.path, parsed_url.query]) - @pytest.mark.parametrize("format_name", ("CVAT for images 1.1",)) - def test_non_owner_cannot_retrieve_request(self, find_users, projects, format_name: str): + def test_non_owner_cannot_retrieve_request(self, find_users, projects): project = next( ( p @@ -353,13 +358,254 @@ def test_non_owner_cannot_retrieve_request(self, find_users, projects, format_na owner = project["owner"] malefactor = find_users(exclude_username=owner["username"])[0] - export_project_dataset( + request_id = export_project_dataset( owner["username"], save_images=True, id=project["id"], download_result=False, ) - rq_id = f'export:project-{project["id"]}-dataset-in-{format_name.replace(" ", "_").replace(".", "@")}-format-by-{owner["id"]}' - with make_api_client(malefactor["username"]) as malefactor_client: - self._test_get_request_403(malefactor_client, rq_id) + self._test_get_request_403(malefactor_client, request_id) + + def _test_get_request_using_legacy_id( + self, + legacy_request_id: str, + username: str, + *, + action: str, + target_type: str, + subresource: Optional[str] = None, + ): + with make_api_client(username) as api_client: + bg_requests, _ = api_client.requests_api.list( + target=target_type, + action=action, + **({"subresource": subresource} if subresource else {}), + ) + assert len(bg_requests.results) == 1 + request_id = bg_requests.results[0].id + bg_request = self._test_get_request_200( + api_client, legacy_request_id, validate_rq_id=False + ) + assert bg_request.id == request_id + + @pytest.mark.parametrize("target_type", ("project", "task", "job")) + @pytest.mark.parametrize("save_images", (True, False)) + @pytest.mark.parametrize("export_format", ("CVAT for images 1.1",)) + @pytest.mark.parametrize("import_format", ("CVAT 1.1",)) + def test_can_retrieve_dataset_import_export_requests_using_legacy_ids( + self, + target_type: str, + save_images: bool, + export_format: str, + import_format: str, + projects, + tasks, + jobs, + ): + def build_legacy_id_for_export_request( + *, + target_type: str, + target_id: int, + subresource: str, + format_name: str, + user_id: int, + ): + return f"export:{target_type}-{target_id}-{subresource}-in-{format_name.replace(' ', '_').replace('.', '@')}-format-by-{user_id}" + + def build_legacy_id_for_import_request( + *, + target_type: str, + target_id: int, + subresource: str, + ): + return f"import:{target_type}-{target_id}-{subresource}" + + if target_type == "project": + export_func, import_func = export_project_dataset, import_project_dataset + target = next(iter(projects)) + owner = target["owner"] + elif target_type == "task": + export_func, import_func = export_task_dataset, import_task_annotations + target = next(iter(tasks)) + owner = target["owner"] + else: + assert target_type == "job" + export_func, import_func = export_job_dataset, import_job_annotations + target = next(iter(jobs)) + owner = tasks[target["task_id"]]["owner"] + + target_id = target["id"] + subresource = "dataset" if save_images else "annotations" + file_content = io.BytesIO( + export_func( + owner["username"], + save_images=save_images, + format=export_format, + id=target_id, + ) + ) + file_content.name = "file.zip" + + legacy_request_id = build_legacy_id_for_export_request( + target_type=target_type, + target_id=target["id"], + subresource=subresource, + format_name=export_format, + user_id=owner["id"], + ) + + self._test_get_request_using_legacy_id( + legacy_request_id, + owner["username"], + action="export", + 
target_type=target_type, + subresource=subresource, + ) + + # check import requests + if not save_images and target_type == "project" or save_images and target_type != "project": + # skip: + # importing annotations into a project + # importing datasets into a task or job + return + + import_func( + owner["username"], + file_content=file_content, + id=target_id, + format=import_format, + ) + + legacy_request_id = build_legacy_id_for_import_request( + target_type=target_type, target_id=target_id, subresource=subresource + ) + self._test_get_request_using_legacy_id( + legacy_request_id, + owner["username"], + action="import", + target_type=target_type, + subresource=subresource, + ) + + @pytest.mark.parametrize("target_type", ("project", "task")) + def test_can_retrieve_backup_import_export_requests_using_legacy_ids( + self, + target_type: str, + projects, + tasks, + ): + def build_legacy_id_for_export_request( + *, + target_type: str, + target_id: int, + user_id: int, + ): + return f"export:{target_type}-{target_id}-backup-by-{user_id}" + + def build_legacy_id_for_import_request( + *, + target_type: str, + uuid_: str, + ): + return f"import:{target_type}-{uuid_}-backup" + + if target_type == "project": + export_func, import_func = export_project_backup, import_project_backup + target = next(iter(projects)) + else: + assert target_type == "task" + export_func, import_func = export_task_backup, import_task_backup + target = next(iter(tasks)) + + owner = target["owner"] + + # check export requests + backup_file = io.BytesIO( + export_func( + owner["username"], + id=target["id"], + ) + ) + backup_file.name = "file.zip" + + legacy_request_id = build_legacy_id_for_export_request( + target_type=target_type, target_id=target["id"], user_id=owner["id"] + ) + self._test_get_request_using_legacy_id( + legacy_request_id, + owner["username"], + action="export", + target_type=target_type, + subresource="backup", + ) + + # check import requests + result_id = import_func( + owner["username"], + file_content=backup_file, + ).id + legacy_request_id = build_legacy_id_for_import_request( + target_type=target_type, uuid_=dict(parse_qsl(result_id))["id"] + ) + + self._test_get_request_using_legacy_id( + legacy_request_id, + owner["username"], + action="import", + target_type=target_type, + subresource="backup", + ) + + def test_can_retrieve_task_creation_requests_using_legacy_ids(self, admin_user: str): + task_id = create_task( + admin_user, + spec={"name": "Test task", "labels": [{"name": "car"}]}, + data={ + "image_quality": 75, + "client_files": generate_image_files(2), + "segment_size": 1, + }, + )[0] + + legacy_request_id = f"create:task-{task_id}" + self._test_get_request_using_legacy_id( + legacy_request_id, admin_user, action="create", target_type="task" + ) + + def test_can_retrieve_quality_calculation_requests_using_legacy_ids(self, jobs, tasks): + gt_job = next( + j + for j in jobs + if ( + j["type"] == "ground_truth" + and j["stage"] == "acceptance" + and j["state"] == "completed" + ) + ) + task_id = gt_job["task_id"] + owner = tasks[task_id]["owner"] + + legacy_request_id = f"quality-check-task-{task_id}-user-{owner['id']}" + + with make_api_client(owner["username"]) as api_client: + # initiate quality report calculation + (_, response) = api_client.quality_api.create_report( + quality_report_create_request=models.QualityReportCreateRequest(task_id=task_id), + _parse_response=False, + ) + assert response.status == HTTPStatus.ACCEPTED + request_id = json.loads(response.data)["rq_id"] + + # get 
background request details using common request API + bg_request = self._test_get_request_200( + api_client, legacy_request_id, validate_rq_id=False + ) + assert bg_request.id == request_id + + # get quality report by legacy request ID using the deprecated API endpoint + wait_background_request(api_client, request_id) + api_client.quality_api.create_report( + quality_report_create_request=models.QualityReportCreateRequest(task_id=task_id), + rq_id=request_id, + ) diff --git a/tests/python/rest_api/test_resource_import_export.py b/tests/python/rest_api/test_resource_import_export.py index 84582871c2cc..13f0151563cc 100644 --- a/tests/python/rest_api/test_resource_import_export.py +++ b/tests/python/rest_api/test_resource_import_export.py @@ -202,7 +202,19 @@ def test_import_resource_from_cloud_storage_with_specific_location( resource, is_default=False, obj=obj, cloud_storage_id=cloud_storage_id ) self._export_resource(cloud_storage, obj_id, obj, resource, **export_kwargs) - self._import_resource(cloud_storage, resource, obj_id, obj, **kwargs) + self._import_resource( + cloud_storage, + resource, + *( + [ + obj_id, + ] + if resource != "backup" + else [] + ), + obj, + **kwargs, + ) @pytest.mark.usefixtures("restore_redis_inmem_per_function") @pytest.mark.parametrize( @@ -326,7 +338,13 @@ def test_user_cannot_import_from_cloud_storage_with_specific_location_without_ac self._import_resource( cloud_storage, resource, - obj_id, + *( + [ + obj_id, + ] + if resource != "backup" + else [] + ), obj, user=user, _expect_status=HTTPStatus.FORBIDDEN, diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index 3f2339774589..337e3769c6d8 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -4116,8 +4116,8 @@ def test_cannot_export_backup_for_task_without_data(self, tasks): with pytest.raises(ApiException) as exc: self._test_can_export_backup(task_id) - assert exc.status == HTTPStatus.BAD_REQUEST - assert "Backup of a task without data is not allowed" == exc.body.encode() + assert exc.value.status == HTTPStatus.BAD_REQUEST + assert "Backup of a task without data is not allowed" in exc.value.body.decode() @pytest.mark.with_external_services def test_can_export_and_import_backup_task_with_cloud_storage(self, tasks): diff --git a/tests/python/rest_api/utils.py b/tests/python/rest_api/utils.py index 460695f8a887..822dd97dfc52 100644 --- a/tests/python/rest_api/utils.py +++ b/tests/python/rest_api/utils.py @@ -7,6 +7,7 @@ from collections.abc import Hashable, Iterator, Sequence from copy import deepcopy from http import HTTPStatus +from io import BytesIO from time import sleep from typing import Any, Callable, Iterable, Optional, TypeVar, Union @@ -19,8 +20,9 @@ from cvat_sdk.api_client.exceptions import ForbiddenException from cvat_sdk.core.helpers import get_paginated_collection from deepdiff import DeepDiff +from urllib3 import HTTPResponse -from shared.utils.config import make_api_client +from shared.utils.config import USER_PASS, make_api_client, post_method def initialize_export(endpoint: Endpoint, *, expect_forbidden: bool = False, **kwargs) -> str: @@ -41,14 +43,13 @@ def initialize_export(endpoint: Endpoint, *, expect_forbidden: bool = False, **k return rq_id -def wait_and_download_v2( +def wait_background_request( api_client: ApiClient, rq_id: str, *, max_retries: int = 50, interval: float = 0.1, - download_result: bool = True, -) -> Optional[bytes]: +) -> tuple[models.Request, HTTPResponse]: for _ in range(max_retries): 
(background_request, response) = api_client.requests_api.retrieve(rq_id) assert response.status == HTTPStatus.OK @@ -56,28 +57,34 @@ def wait_and_download_v2( background_request.status.value == models.RequestStatus.allowed_values[("value",)]["FINISHED"] ): - break + return background_request, response sleep(interval) - else: - assert False, ( - f"Export process was not finished within allowed time ({interval * max_retries}, sec). " - + f"Last status was: {background_request.status.value}" - ) - if not download_result: - return None + assert False, ( + f"Export process was not finished within allowed time ({interval * max_retries}, sec). " + + f"Last status was: {background_request.status.value}" + ) - # return downloaded file in case of local downloading or None otherwise - if background_request.result_url: - response = requests.get( - background_request.result_url, - auth=(api_client.configuration.username, api_client.configuration.password), - ) - assert response.status_code == HTTPStatus.OK, f"Status: {response.status_code}" - return response.content +def wait_and_download_v2( + api_client: ApiClient, + rq_id: str, + *, + max_retries: int = 50, + interval: float = 0.1, +) -> bytes: + background_request, _ = wait_background_request( + api_client, rq_id, max_retries=max_retries, interval=interval + ) - return None + # return downloaded file in case of local downloading + assert background_request.result_url + response = requests.get( + background_request.result_url, + auth=(api_client.configuration.username, api_client.configuration.password), + ) + assert response.status_code == HTTPStatus.OK, f"Status: {response.status_code}" + return response.content def export_v2( @@ -89,7 +96,7 @@ def export_v2( wait_result: bool = True, download_result: bool = True, **kwargs, -) -> Optional[bytes]: +) -> Union[bytes, str]: """Export datasets|annotations|backups using the second version of export API Args: @@ -101,22 +108,27 @@ def export_v2( Returns: bytes: The content of the file if downloaded locally. - None: If `wait_result` or `download_result` were False or the file is downloaded to cloud storage. + str: If `wait_result` or `download_result` were False. """ # initialize background process and ensure that the first request returns 403 code if request should be forbidden rq_id = initialize_export(endpoint, expect_forbidden=expect_forbidden, **kwargs) if not wait_result: - return None + return rq_id # check status of background process - return wait_and_download_v2( - endpoint.api_client, - rq_id, - max_retries=max_retries, - interval=interval, - download_result=download_result, + if download_result: + return wait_and_download_v2( + endpoint.api_client, + rq_id, + max_retries=max_retries, + interval=interval, + ) + + background_request, _ = wait_background_request( + endpoint.api_client, rq_id, max_retries=max_retries, interval=interval ) + return background_request.id def export_dataset( @@ -184,7 +196,7 @@ def import_resource( expect_forbidden: bool = False, wait_result: bool = True, **kwargs, -) -> None: +) -> Optional[models.Request]: # initialize background process and ensure that the first request returns 403 code if request should be forbidden (_, response) = endpoint.call_with_http_info( **kwargs, @@ -220,6 +232,7 @@ def import_resource( f"Import process was not finished within allowed time ({interval * max_retries}, sec). 
" + f"Last status was: {background_request.status.value}" ) + return background_request def import_backup( @@ -228,19 +241,50 @@ def import_backup( max_retries: int = 50, interval: float = 0.1, **kwargs, -) -> None: +): endpoint = api.create_backup_endpoint return import_resource(endpoint, max_retries=max_retries, interval=interval, **kwargs) -def import_project_backup(username: str, data: dict, **kwargs) -> None: +def import_project_backup(username: str, file_content: BytesIO, **kwargs): + with make_api_client(username) as api_client: + return import_backup( + api_client.projects_api, project_file_request={"project_file": file_content}, **kwargs + ) + + +def import_task_backup(username: str, file_content: BytesIO, **kwargs): + with make_api_client(username) as api_client: + return import_backup( + api_client.tasks_api, task_file_request={"task_file": file_content}, **kwargs + ) + + +def import_project_dataset(username: str, file_content: BytesIO, **kwargs): with make_api_client(username) as api_client: - return import_backup(api_client.projects_api, project_file_request=deepcopy(data), **kwargs) + return import_resource( + api_client.projects_api.create_dataset_endpoint, + dataset_file_request={"dataset_file": file_content}, + **kwargs, + ) + + +def import_task_annotations(username: str, file_content: BytesIO, **kwargs): + with make_api_client(username) as api_client: + return import_resource( + api_client.tasks_api.create_annotations_endpoint, + annotation_file_request={"annotation_file": file_content}, + **kwargs, + ) -def import_task_backup(username: str, data: dict, **kwargs) -> None: +def import_job_annotations(username: str, file_content: BytesIO, **kwargs): with make_api_client(username) as api_client: - return import_backup(api_client.tasks_api, task_file_request=deepcopy(data), **kwargs) + return import_resource( + api_client.jobs_api.create_annotations_endpoint, + annotation_file_request={"annotation_file": file_content}, + **kwargs, + ) FieldPath = Sequence[Union[str, Callable]] @@ -467,3 +511,35 @@ def unique( it: Union[Iterator[_T], Iterable[_T]], *, key: Callable[[_T], Hashable] = None ) -> Iterable[_T]: return {key(v): v for v in it}.values() + + +def register_new_user(username: str) -> dict[str, Any]: + response = post_method( + "admin1", + "auth/register", + data={ + "username": username, + "password1": USER_PASS, + "password2": USER_PASS, + "email": f"{username}@email.com", + }, + ) + + assert response.status_code == HTTPStatus.CREATED + return response.json() + + +def invite_user_to_org( + user_email: str, + org_id: int, + role: str, +): + with make_api_client("admin1") as api_client: + invitation, _ = api_client.invitations_api.create( + models.InvitationWriteRequest( + role=role, + email=user_email, + ), + org_id=org_id, + ) + return invitation diff --git a/tests/python/shared/fixtures/data.py b/tests/python/shared/fixtures/data.py index eb8f7393cd8c..544ed74a14ff 100644 --- a/tests/python/shared/fixtures/data.py +++ b/tests/python/shared/fixtures/data.py @@ -394,6 +394,7 @@ def add_row(**kwargs): id=user["id"], privilege=group, has_analytics_access=user["has_analytics_access"], + is_superuser=user["is_superuser"], ) for membership in memberships: @@ -407,6 +408,7 @@ def add_row(**kwargs): org=membership["organization"], membership_id=membership["id"], has_analytics_access=users_by_name[username]["has_analytics_access"], + is_superuser=users_by_name[username]["is_superuser"], ) return data diff --git a/tests/python/shared/utils/resource_import_export.py 
b/tests/python/shared/utils/resource_import_export.py index c61d5874b3fe..e9ce58555013 100644 --- a/tests/python/shared/utils/resource_import_export.py +++ b/tests/python/shared/utils/resource_import_export.py @@ -4,13 +4,13 @@ from contextlib import ExitStack from http import HTTPStatus from time import sleep -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar import pytest T = TypeVar("T") -from shared.utils.config import get_method, post_method, put_method +from shared.utils.config import get_method, post_method FILENAME_TEMPLATE = "cvat/{}/{}.zip" EXPORT_FORMAT = "CVAT for images 1.1" @@ -53,6 +53,26 @@ def _make_import_resource_params( return params +# FUTURE-TODO: reuse common logic from rest_api/utils +def _wait_request( + user: str, + request_id: str, + *, + sleep_interval: float = 0.1, + number_of_checks: int = 100, +): + for _ in range(number_of_checks): + sleep(sleep_interval) + response = get_method(user, f"requests/{request_id}") + assert response.status_code == HTTPStatus.OK + + request_details = json.loads(response.content) + status = request_details["status"] + assert status in {"started", "queued", "finished", "failed"} + if status in {"finished", "failed"}: + return + + class _CloudStorageResourceTest(ABC): @staticmethod @abstractmethod @@ -90,13 +110,9 @@ def _export_resource_to_cloud_storage( resource: str, *, user: str, - _expect_status: Optional[int] = None, + _expect_status: HTTPStatus = HTTPStatus.ACCEPTED, **kwargs, ): - _expect_status = _expect_status or HTTPStatus.ACCEPTED - - sleep_interval = 0.1 - number_of_checks = 100 # initialize the export process response = post_method( @@ -113,47 +129,22 @@ def _export_resource_to_cloud_storage( rq_id = json.loads(response.content).get("rq_id") assert rq_id, "The rq_id was not found in server request" - for _ in range(number_of_checks): - sleep(sleep_interval) - # use new requests API for checking the status of the operation - response = get_method(user, f"requests/{rq_id}") - assert response.status_code == HTTPStatus.OK - - request_details = json.loads(response.content) - status = request_details["status"] - assert status in {"started", "queued", "finished", "failed"} - if status in {"finished", "failed"}: - break + _wait_request(user, rq_id) def _import_resource_from_cloud_storage( - self, url: str, *, user: str, _expect_status: Optional[int] = None, **kwargs + self, url: str, *, user: str, _expect_status: HTTPStatus = HTTPStatus.ACCEPTED, **kwargs ) -> None: - _expect_status = _expect_status or HTTPStatus.ACCEPTED - response = post_method(user, url, data=None, **kwargs) status = response.status_code - assert status == _expect_status + assert status == _expect_status, status if status == HTTPStatus.FORBIDDEN: return rq_id = response.json().get("rq_id") assert rq_id, "The rq_id parameter was not found in the server response" - number_of_checks = 100 - sleep_interval = 0.1 - - for _ in range(number_of_checks): - sleep(sleep_interval) - # use new requests API for checking the status of the operation - response = get_method(user, f"requests/{rq_id}") - assert response.status_code == HTTPStatus.OK - - request_details = json.loads(response.content) - status = request_details["status"] - assert status in {"started", "queued", "finished", "failed"} - if status in {"finished", "failed"}: - break + _wait_request(user, rq_id) def _import_annotations_from_cloud_storage( self, @@ -161,27 +152,17 @@ def _import_annotations_from_cloud_storage( obj, *, user, - _expect_status: Optional[int] = None, + 
_expect_status: HTTPStatus = HTTPStatus.ACCEPTED, _check_uploaded: bool = True, **kwargs, ): - _expect_status = _expect_status or HTTPStatus.CREATED - url = f"{obj}/{obj_id}/annotations" - response = post_method(user, url, data=None, **kwargs) - status = response.status_code - - # Only the first POST request contains rq_id in response. - # Exclude cases with 403 expected status. - rq_id = None - if status == HTTPStatus.ACCEPTED: - rq_id = response.json().get("rq_id") - assert rq_id, "The rq_id was not found in the response" + self._import_resource_from_cloud_storage( + url, user=user, _expect_status=_expect_status, **kwargs + ) - while status != _expect_status: - assert status == HTTPStatus.ACCEPTED - response = put_method(user, url, data=None, rq_id=rq_id, **kwargs) - status = response.status_code + if _expect_status == HTTPStatus.FORBIDDEN: + return if _check_uploaded: response = get_method(user, url) @@ -192,40 +173,18 @@ def _import_annotations_from_cloud_storage( assert len(annotations["shapes"]) def _import_backup_from_cloud_storage( - self, obj_id, obj, *, user, _expect_status: Optional[int] = None, **kwargs + self, obj, *, user, _expect_status: HTTPStatus = HTTPStatus.ACCEPTED, **kwargs ): - _expect_status = _expect_status or HTTPStatus.CREATED - - url = f"{obj}/backup" - response = post_method(user, url, data=None, **kwargs) - status = response.status_code - - while status != _expect_status: - assert status == HTTPStatus.ACCEPTED - data = json.loads(response.content.decode("utf8")) - response = post_method(user, url, data=data, **kwargs) - status = response.status_code + self._import_resource_from_cloud_storage( + f"{obj}/backup", user=user, _expect_status=_expect_status, **kwargs + ) def _import_dataset_from_cloud_storage( - self, obj_id, obj, *, user, _expect_status: Optional[int] = None, **kwargs + self, obj_id, obj, *, user, _expect_status: HTTPStatus = HTTPStatus.ACCEPTED, **kwargs ): - _expect_status = _expect_status or HTTPStatus.CREATED - - url = f"{obj}/{obj_id}/dataset" - response = post_method(user, url, data=None, **kwargs) - status = response.status_code - - # Only the first POST request contains rq_id in response. - # Exclude cases with 403 expected status. - rq_id = None - if status == HTTPStatus.ACCEPTED: - rq_id = response.json().get("rq_id") - assert rq_id, "The rq_id was not found in the response" - - while status != _expect_status: - assert status == HTTPStatus.ACCEPTED - response = get_method(user, url, action="import_status", rq_id=rq_id) - status = response.status_code + self._import_resource_from_cloud_storage( + f"{obj}/{obj_id}/dataset", user=user, _expect_status=_expect_status, **kwargs + ) def _import_resource(self, cloud_storage: dict[str, Any], resource_type: str, *args, **kwargs): methods = {