Skip to content

Commit 95443a8

Browse files
committed
refactor: add artifact and job repositories
This commit adds the artifact and job repositories, which are modeled after patterns in the existing repositories. It updates the service layers to use the new data access pattern. It adds a new test suite for the artifact and job repositories. It makes external services (mlflow and redis) accessible by injecting those services into UnitOfWork instead of injecting them into other services directly.
1 parent fb9fe9a commit 95443a8

23 files changed

Lines changed: 3605 additions & 947 deletions

File tree

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
# This Software (Dioptra) is being made available as a public service by the
2+
# National Institute of Standards and Technology (NIST), an Agency of the United
3+
# States Department of Commerce. This software was developed in part by employees of
4+
# NIST and in part by NIST contractors. Copyright in portions of this software that
5+
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
6+
# to Title 17 United States Code Section 105, works of NIST employees are not
7+
# subject to copyright protection in the United States. However, NIST may hold
8+
# international copyright in software created by its employees and domestic
9+
# copyright (or licensing rights) in portions of software that were assigned or
10+
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
11+
# being made available under the Creative Commons Attribution 4.0 International
12+
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
13+
# of the software developed or licensed by NIST.
14+
#
15+
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
16+
# https://creativecommons.org/licenses/by/4.0/legalcode
17+
"""
18+
The artifact repository: data operations related to artifacts
19+
"""
20+
21+
from collections.abc import Iterable, Sequence
22+
from typing import Any, Final, overload
23+
24+
from sqlalchemy import Select, func, select
25+
from sqlalchemy.orm import aliased
26+
27+
import dioptra.restapi.db.repository.utils as utils
28+
from dioptra.restapi.db.models import Artifact, Group, Resource, Tag
29+
from dioptra.restapi.db.models.plugins import (
30+
ArtifactTask,
31+
PluginTaskOutputParameter,
32+
PluginTaskParameterType,
33+
)
34+
from dioptra.restapi.v1.entity_types import EntityType
35+
36+
37+
class ArtifactRepository:
38+
SEARCHABLE_FIELDS: Final[dict[str, Any]] = {
39+
"artifactUri": lambda x: Artifact.uri.like(x, escape="/"),
40+
"description": lambda x: Artifact.description.like(x, escape="/"),
41+
"tag": lambda x: Artifact.tags.any(Tag.name.like(x, escape="/")),
42+
}
43+
44+
# Maps a general sort criterion name to a Artifact attribute
45+
SORTABLE_FIELDS: Final[dict[str, Any]] = {
46+
"uri": Artifact.uri,
47+
"createdOn": Artifact.created_on,
48+
"lastModifiedOn": Resource.last_modified_on,
49+
"description": Artifact.description,
50+
"job": Artifact.job_id,
51+
}
52+
53+
def __init__(self, session: utils.CompatibleSession[utils.S]):
54+
self.session = session
55+
56+
def create(self, artifact: Artifact) -> None:
57+
"""
58+
Create a new artifact resource. This creates both the resource and the
59+
initial snapshot.
60+
61+
Args:
62+
artifact: The artifact to create
63+
64+
Raises:
65+
EntityExistsError: if the artifact resource or snapshot already
66+
exists, or the artifact name collides with another artifact in the
67+
same group
68+
EntityDoesNotExistError: if the group owner or user creator does
69+
not exist
70+
EntityDeletedError: if the artifact, its creator, or its group owner
71+
is deleted
72+
UserNotInGroupError: if the user creator is not a member of the
73+
group who will own the resource
74+
MismatchedResourceTypeError: if the snapshot or resource's type is
75+
not "artifact"
76+
"""
77+
78+
# Consistency rules:
79+
# - Latest-snapshot artifact uris must be unique within the owning job
80+
# - Artifact snapshots must be of artifact resources
81+
# - For now, the snapshot creator must be a member of the group who
82+
# owns the resource. I think this will become more complicated when
83+
# we implement shares and permissions.
84+
85+
utils.assert_can_create_resource(self.session, artifact, EntityType.ARTIFACT)
86+
87+
self.session.add(artifact)
88+
89+
def create_snapshot(self, artifact: Artifact) -> None:
90+
"""
91+
Create a new artifact snapshot.
92+
93+
Args:
94+
artifact: A Artifact object with the desired snapshot settings
95+
96+
Raises:
97+
EntityDoesNotExistError: if the artifact resource or snapshot creator
98+
user does not exist
99+
EntityExistsError: if the snapshot already exists, or this new
100+
snapshot's artifact name collides with another artifact in the same
101+
group
102+
EntityDeletedError: if the artifact or snapshot creator user are
103+
deleted
104+
UserNotInGroupError: if the snapshot creator user is not a member
105+
of the group who owns the artifact
106+
MismatchedResourceTypeError: if the snapshot or resource's type is
107+
not "artifact"
108+
"""
109+
# Consistency rules:
110+
# - Latest-snapshot artifact uris must be unique within the owning job
111+
# - Artifact snapshots must be of artifact resources
112+
# - Snapshot timestamps must be monotonically increasing(?)
113+
# - For now, the snapshot creator must be a member of the group who
114+
# owns the resource. I think this will become more complicated when
115+
# we implement shares and permissions.
116+
117+
utils.assert_can_create_snapshot(self.session, artifact, EntityType.ARTIFACT)
118+
119+
# Assume that the new snapshot's created_on timestamp is later than the
120+
# current latest timestamp?
121+
122+
self.session.add(artifact)
123+
124+
def delete(self, artifact: Artifact | int) -> None:
125+
"""
126+
Delete a artifact. No-op if the artifact is already deleted.
127+
128+
Args:
129+
artifact: A Artifact object or resource_id primary key value identifying
130+
a artifact resource
131+
132+
Raises:
133+
EntityDoesNotExistError: if the artifact does not exist
134+
"""
135+
136+
utils.delete_resource(self.session, artifact)
137+
138+
@overload
139+
def get(
140+
self,
141+
resource_ids: int,
142+
deletion_policy: utils.DeletionPolicy,
143+
) -> Artifact | None: ...
144+
145+
@overload
146+
def get(
147+
self,
148+
resource_ids: Iterable[int],
149+
deletion_policy: utils.DeletionPolicy,
150+
) -> Sequence[Artifact]: ...
151+
152+
def get(
153+
self,
154+
resource_ids: int | Iterable[int],
155+
deletion_policy: utils.DeletionPolicy = utils.DeletionPolicy.NOT_DELETED,
156+
) -> Artifact | Sequence[Artifact] | None:
157+
"""
158+
Get the latest snapshot of the given artifact resource.
159+
160+
Args:
161+
resource_ids: A single or iterable of artifact resource IDs
162+
deletion_policy: Whether to look at deleted artifacts, non-deleted
163+
artifacts, or all artifacts
164+
165+
Returns:
166+
A Artifact/list of Artifact objects, or None/empty list if none were
167+
found with the given ID(s)
168+
"""
169+
170+
return utils.get_latest_snapshots(
171+
self.session, Artifact, resource_ids, deletion_policy
172+
)
173+
174+
def get_by_job(
175+
self, *job_ids: int, deletion_policy: utils.DeletionPolicy
176+
) -> Sequence[Artifact]:
177+
"""
178+
Get the latest Artifact snapshots associated with the given Job ID(s).
179+
180+
Args:
181+
job_ids: One or more Job resource IDs.
182+
deletion_policy: Whether to look at deleted artifacts, non-deleted
183+
artifacts, or all artifacts
184+
185+
Returns:
186+
A list of Artifact objects, or empty list if none were found with the given
187+
Job ID.
188+
"""
189+
return utils.get_latest_snapshots_where(
190+
self.session,
191+
Artifact,
192+
Artifact.job_id.in_(job_ids),
193+
deletion_policy=deletion_policy,
194+
)
195+
196+
def get_one(
197+
self,
198+
resource_id: int,
199+
deletion_policy: utils.DeletionPolicy,
200+
) -> Artifact:
201+
"""
202+
Get the latest snapshot of the given artifact resource; require that
203+
exactly one is found, or raise an exception.
204+
205+
Args:
206+
resource_id: A resource ID
207+
deletion_policy: Whether to look at deleted artifacts, non-deleted
208+
artifacts, or all artifacts
209+
210+
Returns:
211+
A Artifact object
212+
213+
Raises:
214+
EntityDoesNotExistError: if the artifact does not exist in the
215+
database (deleted or not)
216+
EntityExistsError: if the artifact exists and is not deleted, but
217+
policy was to find a deleted artifact
218+
EntityDeletedError: if the artifact is deleted, but policy was to find
219+
a non-deleted artifact
220+
"""
221+
return utils.get_one_latest_snapshot(
222+
self.session, Artifact, resource_id, deletion_policy
223+
)
224+
225+
def get_one_snapshot(
226+
self,
227+
resource_id: int,
228+
snapshot_id: int,
229+
deletion_policy: utils.DeletionPolicy,
230+
) -> Artifact:
231+
"""
232+
Get the a specific artifact snapshot given the resource snapshot ID; require
233+
that exactly one is found, or raise an exception.
234+
235+
Args:
236+
resource_id: A resource ID
237+
snapshot_id: A resource snapshot ID
238+
deletion_policy: Whether to look at deleted artifacts, non-deleted
239+
artifacts, or all artifacts
240+
241+
Returns:
242+
An Artifact object
243+
244+
Raises:
245+
EntityDoesNotExistError: if the artifact does not exist in the
246+
database (deleted or not)
247+
EntityExistsError: if the artifact exists and is not deleted, but
248+
policy was to find a deleted artifact
249+
EntityDeletedError: if the artifact is deleted, but policy was to
250+
find a non-deleted artifact
251+
"""
252+
return utils.get_one_snapshot(
253+
self.session, Artifact, resource_id, snapshot_id, deletion_policy
254+
)
255+
256+
def get_by_filters_paged(
257+
self,
258+
group: Group | int | None,
259+
filters: list[dict],
260+
output_params: list[int] | None,
261+
page_start: int,
262+
page_length: int,
263+
sort_by: str | None,
264+
descending: bool,
265+
deletion_policy: utils.DeletionPolicy = utils.DeletionPolicy.NOT_DELETED,
266+
) -> tuple[Sequence[Artifact], int]:
267+
"""
268+
Get some artifacts according to search criteria.
269+
270+
Args:
271+
group: Limit artifacts to those owned by this group; None to not limit
272+
the search
273+
filters: Search criteria, see parse_search_text()
274+
page_start: Zero-based row index where the page should start
275+
page_length: Maximum number of rows in the page; use <= 0 for
276+
unlimited length
277+
sort_by: Sort criterion; must be a key of SORTABLE_FIELDS. None
278+
to sort in an implementation-dependent way.
279+
descending: Whether to sort in descending order; only applicable
280+
if sort_by is given
281+
deletion_policy: Whether to look at deleted artifacts, non-deleted
282+
artifacts, or all artifacts
283+
284+
Returns:
285+
A 2-tuple including the page of artifacts and total count of matching
286+
artifacts which exist
287+
288+
Raises:
289+
SearchParseError: if filters includes a non-searchable field
290+
SortParameterValidationError: if sort_by is a non-sortable field
291+
EntityDoesNotExistError: if the given group does not exist
292+
EntityDeletedError: if the given group is deleted
293+
"""
294+
295+
output_params_filter = []
296+
if output_params is not None and output_params:
297+
output_params_filter.append(
298+
lambda stmt: self._apply_ouput_params_filter(stmt, output_params)
299+
)
300+
301+
return utils.get_by_filters_paged(
302+
self.session,
303+
Artifact,
304+
self.SORTABLE_FIELDS,
305+
self.SEARCHABLE_FIELDS,
306+
group,
307+
filters,
308+
page_start,
309+
page_length,
310+
sort_by,
311+
descending,
312+
deletion_policy,
313+
output_params_filter,
314+
)
315+
316+
def _apply_ouput_params_filter(
317+
self, stmt: Select, output_params: list[int]
318+
) -> Select:
319+
# creates a comparison for each outuput parameter and makes
320+
# sure the type is correct for that parameter_number
321+
for index, p in enumerate(output_params):
322+
task_alias = aliased(ArtifactTask)
323+
parameter_alias = aliased(PluginTaskOutputParameter)
324+
type_alias = aliased(PluginTaskParameterType)
325+
stmt = (
326+
stmt.join(Artifact.task.of_type(task_alias))
327+
.join(ArtifactTask.output_parameters.of_type(parameter_alias))
328+
.join(
329+
type_alias,
330+
type_alias.resource_snapshot_id
331+
== parameter_alias.plugin_task_parameter_type_resource_snapshot_id,
332+
)
333+
.where(
334+
type_alias.resource_id == p,
335+
parameter_alias.parameter_number == index,
336+
)
337+
)
338+
339+
# verifies that the number of parameters is what we are looking for
340+
# prevents picking up artifacts which match the ones we are looking
341+
# for, but have more parameters
342+
count_subquery = (
343+
select(
344+
PluginTaskOutputParameter.task_id,
345+
func.count().label("param_count"),
346+
)
347+
.group_by(PluginTaskOutputParameter.task_id)
348+
.subquery()
349+
)
350+
task_alias = aliased(ArtifactTask)
351+
stmt = (
352+
stmt.join(Artifact.task.of_type(task_alias))
353+
.join(count_subquery, task_alias.task_id == count_subquery.c.task_id)
354+
.where(
355+
count_subquery.c.param_count == len(output_params),
356+
)
357+
)
358+
return stmt

0 commit comments

Comments
 (0)