Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/dioptra/client/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
PLUGINS: Final[str] = "plugins"
ARTIFACT_PLUGINS: Final[str] = "artifactPlugins"
BUNDLE: Final[str] = "bundle"
DYNAMIC_GLOBAL_PARAMETERS: Final[str] = "dynamicGlobalParameters"

T = TypeVar("T")

Expand Down Expand Up @@ -506,6 +507,26 @@ def get_artifact_plugins_bundle(
params=params,
)

def get_task_graph_global_params(
self, entrypoint_id: int, entrypoint_snapshot_id: int, swaps: dict[str, str]
) -> T:
"""Get the global parameters used by an entrypoint graph with specified swap tasks.

Args:
entrypoint_id: The entrypoint id, an integer.
entrypoint_snapshot_id: The entrypoint snapshot id, an integer.
swaps: The selected task for each swappable step in the entrypoint graph.

Returns:
The response from the Dioptra API.
"""
return self._session.get(
self.build_sub_collection_url(entrypoint_id),
str(entrypoint_snapshot_id),
DYNAMIC_GLOBAL_PARAMETERS,
params={"swaps": ",".join([f"{k}:{v}" for k, v in swaps.items()])},
)


class EntrypointsCollectionClient(CollectionClient[T]):
"""The client for managing Dioptra's /entrypoints collection.
Expand Down
12 changes: 12 additions & 0 deletions src/dioptra/restapi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ def __init__(self, message: str):
super().__init__(message)


class EntrypointSwapsRenderError(DioptraError):
"""Password Error."""

def __init__(self, message: str):
super().__init__(message)


class JobStoreError(DioptraError):
"""JobStoreError Error."""

Expand Down Expand Up @@ -894,6 +901,11 @@ def handle_user_password_error(error: UserPasswordError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.UNAUTHORIZED, {})

@api.errorhandler(EntrypointSwapsRenderError)
def handle_entrypoint_swaps_error(error: EntrypointSwapsRenderError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.BAD_REQUEST, {})

@api.errorhandler(JobStoreError)
def handle_mlflow_error(error: JobStoreError):
log.debug(error.to_message())
Expand Down
47 changes: 47 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
from dioptra.restapi.v1.shared.task_engine_yaml.service import TaskEngineYamlService

from .schema import (
DynamicGlobalParametersRequestSchema,
DynamicGlobalParametersResponseSchema,
EntrypointArtifactPluginMutableFieldsSchema,
EntrypointDraftSchema,
EntrypointGetQueryParameters,
Expand All @@ -65,6 +67,7 @@
from .service import (
RESOURCE_TYPE,
SEARCHABLE_FIELDS,
DynamicGlobalParametersService,
EntrypointIdArtifactPluginsIdService,
EntrypointIdArtifactPluginsService,
EntrypointIdPluginsIdService,
Expand Down Expand Up @@ -708,6 +711,50 @@ def delete(self, id: int, queueId):
return self._entrypoint_id_queues_id_service.delete(id, queueId, log=log)


@api.route("/<int:id>/snapshots/<int:snapshotId>/dynamicGlobalParameters")
@api.param("id", "ID for the Entrypoint resource.")
@api.param("snapshotId", "Snapshot ID for the Entrypoint snapshot.")
class DynamicGlobalParametersEntrypoint(Resource):
@inject
def __init__(
self,
dynamic_global_parameters_service: DynamicGlobalParametersService,
*args,
**kwargs,
) -> None:
"""Initialize the workflow resource.

All arguments are provided via dependency injection.

Args:
entrypoint_validate_service: An EntrypointValidateService object.
"""
self._dynamic_global_parameters_service = dynamic_global_parameters_service
super().__init__(*args, **kwargs)

@login_required
@accepts(query_params_schema=DynamicGlobalParametersRequestSchema, api=api)
@responds(schema=DynamicGlobalParametersResponseSchema, api=api)
def get(self, id: int, snapshotId: int):
"""Finds the global parameters for the given entrypoint + swap choice dictionary."""
log = LOGGER.new(
request_id=str(uuid.uuid4()),
resource="DynamicGlobalParameters",
request_type="GET",
)

entrypoint_id = id
entrypoint_snapshot_id = snapshotId
swap_choices = request.parsed_query_params["swaps"] # type: ignore

return self._dynamic_global_parameters_service.get_params(
entrypoint_id=entrypoint_id,
entrypoint_snapshot_id=entrypoint_snapshot_id,
swaps=swap_choices,
logger=log,
)


EntrypointDraftResource = generate_resource_drafts_endpoint(
api,
resource_name=RESOURCE_TYPE,
Expand Down
68 changes: 68 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
"""The schemas for serializing/deserializing Entrypoint resources."""

from marshmallow import Schema, fields, validate
from marshmallow.exceptions import ValidationError

from dioptra.restapi.v1.plugins.schema import (
ALLOWED_PLUGIN_TASK_PARAMETER_REGEX,
PluginSnapshotRefSchema,
PluginTaskContainerSchema,
PluginTaskParameterSchema,
)
Expand Down Expand Up @@ -322,3 +324,69 @@ class EntrypointGetQueryParameters(
SortByGetQueryParametersSchema,
):
"""The query parameters for the GET method of the /entrypoints endpoint."""


class DelimitedKeyValuePairs(fields.Field):
def __init__(
self,
*,
delimiter: str = ",",
equality: str = ":",
**additional_metadata,
) -> None:
super().__init__(**additional_metadata)
self.delimiter = delimiter
self.equality = equality

def _deserialize(self, value, attr, data, **kwargs) -> dict[str, str]:
try:
if value == "":
return {}
return {
str(pair.split(self.equality)[0]): str(pair.split(self.equality)[1])
for pair in value.split(self.delimiter)
}
except Exception as e:
raise ValidationError(
f"{attr} is not a delimited list {value}. List format should be key{self.equality}value{self.delimiter}key2{self.equality}value2{self.delimiter}key3{self.equality}value3."
) from e


class DynamicGlobalParametersRequestSchema(Schema):
swaps = DelimitedKeyValuePairs(
attribute="swaps",
data_key="swaps",
metadata={
"description": (
"A list of swap choices to be applied to the entrypoint task graph."
)
},
)


class DynamicGlobalParametersResponseSchema(Schema):
globalParameters = fields.List(
fields.String(),
attribute="entrypoint_params",
data_key="entrypointParams",
metadata={
"description": (
"A list of global parameters used in the entrypoint task graph."
)
},
)
topologicalSort = fields.List(
fields.String(),
attribute="topological_sort",
data_key="topologicalSort",
metadata={
"description": ("A list of task names topologically sorted by dependency.")
},
)
activePlugins = fields.Nested(
PluginSnapshotRefSchema,
attribute="active_plugins",
data_key="activePlugins",
metadata={"description": ("A list of plugin objects used in the entrypoint.")},
many=True,
)
97 changes: 97 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Final, Iterable

import structlog
import yaml
from flask_login import current_user
from injector import inject
from sqlalchemy import Integer, func, select
Expand All @@ -30,6 +31,7 @@
BackendDatabaseError,
EntityDoesNotExistError,
EntityExistsError,
EntrypointSwapsRenderError,
QueryParameterNotUniqueError,
SortParameterValidationError,
)
Expand All @@ -44,8 +46,11 @@
from dioptra.restapi.v1.queues.service import QueueIdsService
from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters
from dioptra.restapi.v1.shared.task_engine_yaml.service import (
TaskEngineYamlService,
coerce_entrypoint_default_param_types,
)
from dioptra.sdk.utilities.entrypoint_swaps import render_swaps_graph
from dioptra.task_engine import util

LOGGER: BoundLogger = structlog.stdlib.get_logger()
PLUGIN_RESOURCE_TYPE: Final[str] = "entry_point_plugin"
Expand Down Expand Up @@ -1587,6 +1592,98 @@ def get(
return entrypoint


class DynamicGlobalParametersService(object):
@inject
def __init__(
self,
entrypoint_snapshot_id_service: EntrypointSnapshotIdService,
task_engine_yaml_service: TaskEngineYamlService,
) -> None:
"""Initialize the entrypoint service.

All arguments are provided via dependency injection.

Args:
task_engine_yaml_service: A TaskEngineYamlService object.
"""
self._entrypoint_snapshot_id_service = entrypoint_snapshot_id_service
self._task_engine_yaml_service = task_engine_yaml_service

def get_params(
self,
entrypoint_id: int,
entrypoint_snapshot_id: int,
swaps: dict[str, str],
logger: BoundLogger | None = None,
) -> dict[str, Any]:
entry_point = self._entrypoint_snapshot_id_service.get(
entrypoint_id=entrypoint_id, entrypoint_snapshot_id=entrypoint_snapshot_id
)

task_graph = entry_point.task_graph

graph = yaml.safe_load(task_graph)

try:
rendered = render_swaps_graph(graph, swaps)
except Exception as e:
raise EntrypointSwapsRenderError(str(e)) from e

vars = rendered.keys()
needed_vars = set()
used_tasks = set()

for step in rendered:
for task in rendered[step]:
used_tasks.add(task)
for ref in util.get_references(rendered[step][task]):
potential_step_name = ref.split(".")[0]
if potential_step_name not in vars:
needed_vars.add(
ref
) # if it is not a step output, it must be a global param

topsorted = util.get_sorted_steps(rendered)

plugin_files = [
plugin_plugin_file
for entry_point_plugin in entry_point.entry_point_plugins
for plugin_plugin_file in entry_point_plugin.plugin.plugin_plugin_files
]

types = self._entrypoint_snapshot_id_service.get_group_plugin_parameter_types(
entry_point.resource.group_id, log=logger
)

task_engine_yaml = self._task_engine_yaml_service.build_dict(
entry_point=entry_point,
plugin_plugin_files=plugin_files,
plugin_parameter_types=types,
logger=logger,
)

active_plugin_names = set()

for task in task_engine_yaml["tasks"]:
if task in used_tasks:
active_plugin_names.add(
task_engine_yaml["tasks"][task]["plugin"].split(".")[0]
)

active_plugins = []

for epp in entry_point.entry_point_plugins:
# print(epp)
if epp.plugin.name in active_plugin_names:
active_plugins.append(epp.plugin)

return {
"entrypoint_params": list(needed_vars),
"topological_sort": topsorted,
"active_plugins": active_plugins,
}


def _get_entrypoint_plugin_snapshots(
entrypoint: models.EntryPoint,
) -> list[utils.PluginWithFilesDict]:
Expand Down
Loading
Loading