diff --git a/.gitignore b/.gitignore index 9e53722e0..be75df848 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,6 @@ node_modules/ # Ignore mock database **/*.sqlite + +# Ignore virtual envs +*.venv diff --git a/CHANGELOG.md b/CHANGELOG.md index 61af4125d..e0a133e57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Added + +- Allow registering custom file transfer strategies + +### Changed + +- Improved automatic file transfer strategy selection +- HTTP strategy can now upload files too +- Adjusted sublattice logic. The sublattice builder now attempts to +link the sublattice with its parent electron. +- Replaced json sublattice flow with new tarball importer to allow future memory +footprint enhancements + ## [0.239.0-rc.0] - 2025-04-16 ### Authors diff --git a/covalent/_api/apiclient.py b/covalent/_api/apiclient.py index d3be6bd4a..4c6d0a77e 100644 --- a/covalent/_api/apiclient.py +++ b/covalent/_api/apiclient.py @@ -48,6 +48,7 @@ def get(self, endpoint: str, **kwargs): with requests.Session() as session: if self.adapter: session.mount("http://", self.adapter) + session.mount("https://", self.adapter) r = session.get(url, headers=headers, **kwargs) @@ -61,6 +62,26 @@ def get(self, endpoint: str, **kwargs): return r + def patch(self, endpoint: str, **kwargs): + headers = self.prepare_headers(kwargs) + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + session.mount("https://", self.adapter) + + r = session.patch(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + def put(self, endpoint: str, **kwargs): headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint @@ -68,6 +89,7 @@ def put(self, endpoint: str, **kwargs): with requests.Session() as session: if self.adapter: session.mount("http://", self.adapter) + session.mount("https://", self.adapter) r = session.put(url, headers=headers, **kwargs) @@ -87,6 +109,7 @@ def post(self, endpoint: str, **kwargs): with requests.Session() as session: if self.adapter: session.mount("http://", self.adapter) + session.mount("https://", self.adapter) r = session.post(url, headers=headers, **kwargs) @@ -106,6 +129,7 @@ def delete(self, endpoint: str, **kwargs): with requests.Session() as session: if self.adapter: session.mount("http://", self.adapter) + session.mount("https://", self.adapter) r = session.delete(url, headers=headers, **kwargs) diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index bc29af30c..17e080a1e 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -14,16 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
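# Reviewer sketch (illustrative, not part of the patch): the new
# CovalentAPIClient.patch added above mirrors the existing get/put/post/delete
# helpers, and is what later links a sublattice dispatch to its parent
# electron. The dispatch and node ids below are placeholders.
from covalent._api.apiclient import CovalentAPIClient
from covalent._shared_files.utils import format_server_url

client = CovalentAPIClient(format_server_url())
endpoint = "/api/v2/dispatches/<parent-dispatch-id>/electrons/<node-id>"
resp = client.patch(endpoint, json={"sub_dispatch_id": "<sub-dispatch-id>"})
resp.raise_for_status()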
+import base64 import os +import tarfile import tempfile from copy import deepcopy from functools import wraps from pathlib import Path -from typing import Callable, Dict, List, Optional, Union - -from furl import furl +from typing import Callable, Dict, List, Optional, Tuple, Union from .._api.apiclient import CovalentAPIClient as APIClient +from .._file_transfer import FileTransfer from .._results_manager.result import Result from .._results_manager.results_manager import get_result, get_result_manager from .._serialize.result import ( @@ -36,7 +37,7 @@ from .._shared_files.config import get_config from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema -from .._shared_files.utils import copy_file_locally, format_server_url +from .._shared_files.utils import format_server_url from .._workflow.lattice import Lattice from ..triggers import BaseTrigger from .base import BaseDispatcher @@ -251,7 +252,7 @@ def start( if dispatcher_addr is None: dispatcher_addr = format_server_url() - endpoint = f"/api/v2/dispatches/{dispatch_id}/status" + endpoint = f"{BASE_ENDPOINT}/{dispatch_id}/status" body = {"status": "RUNNING"} r = APIClient(dispatcher_addr).put(endpoint, json=body) r.raise_for_status() @@ -463,7 +464,6 @@ def prepare_manifest(lattice, storage_path) -> ResultSchema: def register_manifest( manifest: ResultSchema, dispatcher_addr: Optional[str] = None, - parent_dispatch_id: Optional[str] = None, push_assets: bool = True, ) -> ResultSchema: """Submits a manifest for registration. @@ -482,9 +482,6 @@ def register_manifest( stripped = strip_local_uris(manifest) if push_assets else manifest endpoint = BASE_ENDPOINT - if parent_dispatch_id: - endpoint = f"{BASE_ENDPOINT}/{parent_dispatch_id}/sublattices" - r = APIClient(dispatcher_addr).post(endpoint, data=stripped.model_dump_json()) r.raise_for_status() @@ -512,7 +509,7 @@ def register_derived_manifest( # We don't yet support pulling assets for redispatch stripped = strip_local_uris(manifest) - endpoint = f"/api/v2/dispatches/{dispatch_id}/redispatches" + endpoint = f"{BASE_ENDPOINT}/{dispatch_id}/redispatches" params = {"reuse_previous_results": reuse_previous_results} r = APIClient(dispatcher_addr).post( @@ -531,45 +528,89 @@ def upload_assets(manifest: ResultSchema): @staticmethod def _upload(assets: List[AssetSchema]): - local_scheme_prefix = "file://" total = len(assets) number_uploaded = 0 for i, asset in enumerate(assets): if not asset.remote_uri or not asset.uri: app_log.debug(f"Skipping asset {i + 1} out of {total}") continue - if asset.remote_uri.startswith(local_scheme_prefix): - copy_file_locally(asset.uri, asset.remote_uri) - number_uploaded += 1 - else: - _upload_asset(asset.uri, asset.remote_uri) - number_uploaded += 1 - app_log.debug(f"Uploaded asset {i + 1} out of {total}.") + + _upload_asset(asset.uri, asset.remote_uri) + number_uploaded += 1 + app_log.debug(f"Uploaded asset {i + 1} out of {total}.") app_log.debug(f"uploaded {number_uploaded} assets.") def _upload_asset(local_uri, remote_uri): + _, ft = FileTransfer(local_uri, remote_uri).cp() + ft() + + +# Archive staging directory and manifest +# Used for sublattice dispatch when the executor cannot directly +# submit the sublattice to the control plane +def pack_staging_dir(staging_dir, manifest: ResultSchema) -> str: + # save manifest json to staging root + with open(os.path.join(staging_dir, "manifest.json"), "w") as f: + f.write(manifest.model_dump_json()) + + # Tar up staging dir + with 
tempfile.NamedTemporaryFile(suffix=".tar") as f: + tar_path = f.name + + with tarfile.TarFile(tar_path, "w") as tar: + tar.add(staging_dir, recursive=True) + return tar_path + + +# Inverse of `pack_staging_dir` +# Consumed by server-side tarball importer +def untar_staging_dir(tar_name) -> Tuple[str, ResultSchema]: + + # Working directory for unpacking the archive + with tempfile.TemporaryDirectory(prefix="postprocess-") as work_dir: + ... + + # Find and extract manifest + with tarfile.TarFile(tar_name) as tar: + manifest_path = list(filter(lambda x: x.endswith("manifest.json"), tar.getnames())) + if len(manifest_path) == 0: + raise RuntimeError("Archive contains no manifest") + + manifest = ResultSchema.model_validate_json(tar.extractfile(manifest_path[0]).read()) + + tar.extractall(path=work_dir, filter="tar") + + # prepend work_dir to each asset path scheme_prefix = "file://" - if local_uri.startswith(scheme_prefix): - local_path = local_uri[len(scheme_prefix) :] - else: - local_path = local_uri - - filesize = os.path.getsize(local_path) - with open(local_path, "rb") as reader: - app_log.debug(f"uploading to {remote_uri}") - f = furl(remote_uri) - scheme = f.scheme - host = f.host - port = f.port - dispatcher_addr = f"{scheme}://{host}:{port}" - endpoint = str(f.path) - api_client = APIClient(dispatcher_addr) - if f.query: - endpoint = f"{endpoint}?{f.query}" - - # Workaround for Requests bug when streaming from empty files - data = reader.read() if filesize < 50 else reader - - r = api_client.put(endpoint, headers={"Content-Length": str(filesize)}, data=data) - r.raise_for_status() + for _, asset in manifest.assets: + if asset.uri: + path = asset.uri[len(scheme_prefix) :] + asset.uri = f"{scheme_prefix}{work_dir}{path}" + print("Rewrote asset uri ", asset.uri) + for _, asset in manifest.lattice.assets: + if asset.uri: + path = asset.uri[len(scheme_prefix) :] + asset.uri = f"{scheme_prefix}{work_dir}{path}" + print("Rewrote asset uri ", asset.uri) + + for node in manifest.lattice.transport_graph.nodes: + for _, asset in node.assets: + if asset.uri: + path = asset.uri[len(scheme_prefix) :] + asset.uri = f"{scheme_prefix}{work_dir}{path}" + print("Rewrote asset uri ", asset.uri) + + return work_dir, manifest + + +# Consumed by server-side tarball importer (`import_b64_staging_tarball`) +# TODO: support streaming decode to avoid having to load the entire buffer in mem +def decode_b64_tar(b64_buffer: str) -> str: + with tempfile.NamedTemporaryFile(suffix=".tar") as tar_file: + tar_path = tar_file.name + + with open(tar_path, "wb") as tar_file: + tar_file.write(base64.b64decode(b64_buffer.encode("utf-8"))) + + return tar_path diff --git a/covalent/_file_transfer/enums.py b/covalent/_file_transfer/enums.py index 603cc38a4..4e5f63239 100644 --- a/covalent/_file_transfer/enums.py +++ b/covalent/_file_transfer/enums.py @@ -25,7 +25,7 @@ class Order(str, enum.Enum): class FileSchemes(str, enum.Enum): File = "file" S3 = "s3" - Blob = "https" + Blob = "blob" GCloud = "gs" Globus = "globus" HTTP = "http" diff --git a/covalent/_file_transfer/file.py b/covalent/_file_transfer/file.py index 9b7d2b709..631e6818d 100644 --- a/covalent/_file_transfer/file.py +++ b/covalent/_file_transfer/file.py @@ -20,7 +20,24 @@ from furl import furl -from .enums import FileSchemes, FileTransferStrategyTypes, SchemeToStrategyMap +from .enums import FileSchemes + +_is_remote_scheme = { + FileSchemes.S3.value: True, + FileSchemes.Blob.value: True, + FileSchemes.GCloud.value: True, + FileSchemes.Globus.value: True, + 
FileSchemes.HTTP.value: True, + FileSchemes.HTTPS.value: True, + FileSchemes.FTP.value: True, + FileSchemes.File: False, +} + + +# For registering additional file transfer strategies; this will be called by +# `register_uploader`` and `register_downloader`` +def register_remote_scheme(s: str): + _is_remote_scheme[s] = True class File: @@ -80,19 +97,7 @@ def get_temp_filepath(self): @property def is_remote(self): - return self._is_remote or self.scheme in [ - FileSchemes.S3, - FileSchemes.Blob, - FileSchemes.GCloud, - FileSchemes.Globus, - FileSchemes.HTTP, - FileSchemes.HTTPS, - FileSchemes.FTP, - ] - - @property - def mapped_strategy_type(self) -> FileTransferStrategyTypes: - return SchemeToStrategyMap[self.scheme.value] + return self._is_remote or _is_remote_scheme[self.scheme] @property def filepath(self) -> str: @@ -127,23 +132,13 @@ def get_uri(scheme: str, path: str) -> str: return path_components.url @staticmethod - def resolve_scheme(path: str) -> FileSchemes: + def resolve_scheme(path: str) -> str: scheme = furl(path).scheme host = furl(path).host - if scheme == FileSchemes.Globus: - return FileSchemes.Globus - if scheme == FileSchemes.S3: - return FileSchemes.S3 - if scheme == FileSchemes.Blob and "blob.core.windows.net" in host: - return FileSchemes.Blob - if scheme == FileSchemes.GCloud: - return FileSchemes.GCloud - if scheme == FileSchemes.FTP: - return FileSchemes.FTP - if scheme == FileSchemes.HTTP: - return FileSchemes.HTTP - if scheme == FileSchemes.HTTPS: - return FileSchemes.HTTPS - if scheme is None or scheme == FileSchemes.File: - return FileSchemes.File - raise ValueError(f"Provided File scheme ({scheme}) is not supported.") + # Canonicalize file system paths to file:// urls + if not scheme: + return FileSchemes.File.value + if scheme in _is_remote_scheme: + return scheme + else: + raise ValueError(f"Provided File scheme ({scheme}) is not supported.") diff --git a/covalent/_file_transfer/file_transfer.py b/covalent/_file_transfer/file_transfer.py index ee982a92f..baaa39c2a 100644 --- a/covalent/_file_transfer/file_transfer.py +++ b/covalent/_file_transfer/file_transfer.py @@ -16,12 +16,73 @@ from typing import Optional, Union -from .enums import FileTransferStrategyTypes, FtCallDepReturnValue, Order -from .file import File +from .enums import FileSchemes, FtCallDepReturnValue, Order +from .file import File, register_remote_scheme +from .strategies.blob_strategy import Blob +from .strategies.gcloud_strategy import GCloud from .strategies.http_strategy import HTTP +from .strategies.s3_strategy import S3 from .strategies.shutil_strategy import Shutil from .strategies.transfer_strategy_base import FileTransferStrategy +_downloaders = { + FileSchemes.File.value: Shutil, + FileSchemes.HTTP.value: HTTP, + FileSchemes.HTTPS.value: HTTP, + FileSchemes.GCloud.value: GCloud, + FileSchemes.Blob.value: Blob, + FileSchemes.S3.value: S3, +} + +_uploaders = { + FileSchemes.File.value: Shutil, + FileSchemes.HTTP.value: HTTP, + FileSchemes.HTTPS.value: HTTP, + FileSchemes.GCloud.value: GCloud, + FileSchemes.Blob.value: Blob, + FileSchemes.S3.value: S3, +} + + +# For registering additional file transfer strategies +def register_downloader(scheme: str, cls: FileTransferStrategy): + register_remote_scheme(scheme) + _downloaders[scheme] = cls + + +def register_uploader(scheme: str, cls: FileTransferStrategy): + register_remote_scheme(scheme) + _uploaders[scheme] = cls + + +def guess_transfer_strategy(from_file: File, to_file: File) -> FileTransferStrategy: + # Handle the following cases 
automatically + # Local-Remote (except HTTP destination) + # Remote-local + # Local-local + + # Local-Remote + if not from_file.is_remote and to_file.is_remote: + strategy = _uploaders.get(to_file.scheme) + if not strategy: + raise AttributeError(f"Cannot guess upload strategy for remote {to_file.uri}") + return strategy + + # Remote-Local + if from_file.is_remote and not to_file.is_remote: + strategy = _downloaders.get(from_file.scheme) + if not strategy: + raise AttributeError(f"Cannot guess download strategy for remote {from_file.uri}") + return strategy + + # Local-Local + if not from_file.is_remote and not to_file.is_remote: + strategy = Shutil + return strategy + + else: + raise AttributeError("FileTransfer requires a file transfer strategy to be specified") + class FileTransfer: """ @@ -58,15 +119,8 @@ def __init__( # assign explicit strategy or default to strategy based on from_file & to_file schemes if strategy: self.strategy = strategy - elif ( - from_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil - and to_file.mapped_strategy_type == FileTransferStrategyTypes.Shutil - ): - self.strategy = Shutil() - elif from_file.mapped_strategy_type == FileTransferStrategyTypes.HTTP: - self.strategy = HTTP() else: - raise AttributeError("FileTransfer requires a file transfer strategy to be specified") + self.strategy = guess_transfer_strategy(from_file, to_file)() self.to_file = to_file self.from_file = from_file diff --git a/covalent/_file_transfer/strategies/blob_strategy.py b/covalent/_file_transfer/strategies/blob_strategy.py index 0fb6a2afa..2b5925648 100644 --- a/covalent/_file_transfer/strategies/blob_strategy.py +++ b/covalent/_file_transfer/strategies/blob_strategy.py @@ -73,7 +73,7 @@ def _parse_blob_uri(self, blob_uri: str) -> Tuple[str, str, str]: """Parses a blob URI and returns the account name, container name, and blob name. Args: - blob_uri: A URI for an Azure Blob object in the form https://.blob.core.windows.net// + blob_uri: A URI for an Azure Blob object in the form blob://.blob.core.windows.net// Returns: parsed_uri: A tuple containing the storage account name, container name, and blob name diff --git a/covalent/_file_transfer/strategies/http_strategy.py b/covalent/_file_transfer/strategies/http_strategy.py index 3e86dd9e3..feeff99a2 100644 --- a/covalent/_file_transfer/strategies/http_strategy.py +++ b/covalent/_file_transfer/strategies/http_strategy.py @@ -14,15 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import urllib.request +import os +import requests + +from ..._shared_files import logger from .. import File from .transfer_strategy_base import FileTransferStrategy +app_log = logger.app_log + class HTTP(FileTransferStrategy): """ - Implements Base FileTransferStrategy class to use HTTP to download files from public URLs. + Implements Base FileTransferStrategy class to use download files from public http(s) URLs. 
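# Reviewer sketch (illustrative, not part of the patch): with the uploader and
# downloader registries above, FileTransfer can now pick a strategy from the
# URI schemes instead of requiring an explicit one. The bucket and URL below
# are placeholders, and the S3/HTTP strategies still need their optional
# dependencies installed; .cp() only returns the transfer callables.
from covalent._file_transfer.file_transfer import FileTransfer

# Remote -> local: guess_transfer_strategy selects the S3 downloader
_, fetch = FileTransfer("s3://my-bucket/data.csv", "/tmp/data.csv").cp()

# Local -> remote (e.g. a presigned URL): the HTTP strategy now uploads too
_, push = FileTransfer("/tmp/data.csv", "https://example.com/presigned-put").cp()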
""" # return callable to download here implies 'from' is a remote source @@ -31,15 +36,30 @@ def download(self, from_file: File, to_file: File = File()) -> File: to_filepath = to_file.filepath def callable(): - urllib.request.urlretrieve(from_filepath, to_filepath) - return to_filepath + resp = requests.get(from_filepath, stream=True) + resp.raise_for_status() + with open(to_filepath, "wb") as f: + for chunk in resp.iter_content(chunk_size=1024): + f.write(chunk) return callable - # HTTP Strategy is read only + # Upload a file to a (possibly presigned) HTTP(s) URL def upload(self, from_file: File, to_file: File = File()) -> File: - raise NotImplementedError + from_filepath = from_file.filepath + to_filepath = to_file.uri + filesize = os.path.getsize(from_filepath) + + def callable(): + with open(from_filepath, "rb") as reader: + # Workaround for Requests bug when streaming from empty files + app_log.debug(f"uploading to {to_filepath}") + data = reader.read() if filesize < 50 else reader + r = requests.put(to_filepath, headers={"Content-Length": str(filesize)}, data=data) + r.raise_for_status() + + return callable - # HTTP Strategy is read only + # HTTP Strategy does not support server-side copy between two remote URLs def cp(self, from_file: File, to_file: File = File()) -> File: raise NotImplementedError diff --git a/covalent/_file_transfer/strategies/s3_strategy.py b/covalent/_file_transfer/strategies/s3_strategy.py index 5011d2744..f49db85fe 100644 --- a/covalent/_file_transfer/strategies/s3_strategy.py +++ b/covalent/_file_transfer/strategies/s3_strategy.py @@ -15,6 +15,7 @@ # limitations under the License. import os +from typing import Callable from furl import furl @@ -47,7 +48,7 @@ def __init__(self, credentials: str = None, profile: str = None, region_name: st os.environ["AWS_SHARED_CREDENTIALS_FILE"] = self.credentials # return callable to download here implies 'from' is a remote source - def download(self, from_file: File, to_file: File = File()) -> File: + def download(self, from_file: File, to_file: File = File()) -> Callable: """Download files or the contents of folders from S3 bucket.""" app_log.debug(f"Is dir: {from_file._is_dir}") diff --git a/covalent/_file_transfer/strategies/shutil_strategy.py b/covalent/_file_transfer/strategies/shutil_strategy.py index 319d47d04..b2585a5ab 100644 --- a/covalent/_file_transfer/strategies/shutil_strategy.py +++ b/covalent/_file_transfer/strategies/shutil_strategy.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import shutil from .. 
import File @@ -46,6 +47,7 @@ def cp(self, from_file: File, to_file: File = File()) -> None: """ def callable(): + os.makedirs(os.path.dirname(to_file.filepath), exist_ok=True) shutil.copyfile(from_file.filepath, to_file.filepath) return callable diff --git a/covalent/_file_transfer/strategies/transfer_strategy_base.py b/covalent/_file_transfer/strategies/transfer_strategy_base.py index 8d95c6d59..e86359762 100644 --- a/covalent/_file_transfer/strategies/transfer_strategy_base.py +++ b/covalent/_file_transfer/strategies/transfer_strategy_base.py @@ -16,6 +16,7 @@ from abc import ABC, abstractmethod from pathlib import Path +from typing import Callable from ..enums import FtCallDepReturnValue from ..file import File @@ -34,7 +35,7 @@ def cp(self, from_file: File, to_file: File) -> None: # download here implies 'from' is a remote source @abstractmethod - def download(self, from_file: File, to_file: File) -> File: + def download(self, from_file: File, to_file: File) -> Callable: raise NotImplementedError # upload here implies 'to' is a remote source diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 00544785f..03de0727f 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -23,9 +23,8 @@ from pathlib import Path from typing import List, Optional -from furl import furl - from .._api.apiclient import CovalentAPIClient +from .._file_transfer import FileTransfer from .._serialize.common import load_asset from .._serialize.electron import ASSET_FILENAME_MAP as ELECTRON_ASSET_FILENAMES from .._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES @@ -40,12 +39,13 @@ from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema from .._shared_files.util_classes import RESULT_STATUS, Status -from .._shared_files.utils import copy_file_locally, format_server_url +from .._shared_files.utils import format_server_url from .result import Result app_log = logger.app_log log_stack_info = logger.log_stack_info +BASE_ENDPOINT = os.getenv("COVALENT_DISPATCH_BASE_ENDPOINT", "/api/v2/dispatches") SDK_NODE_META_KEYS = { "executor", @@ -124,7 +124,7 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = task_ids = [] api_client = CovalentAPIClient(dispatcher_addr) - endpoint = f"/api/v2/dispatches/{dispatch_id}/status" + endpoint = f"{BASE_ENDPOINT}/{dispatch_id}/status" if isinstance(task_ids, int): task_ids = [task_ids] @@ -138,7 +138,7 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = def _query_dispatch_status(dispatch_id: str, api_client: CovalentAPIClient): - endpoint = "/api/v2/dispatches" + endpoint = BASE_ENDPOINT resp = api_client.get(endpoint, params={"dispatch_id": dispatch_id, "status_only": True}) resp.raise_for_status() dispatches = resp.json()["dispatches"] @@ -167,7 +167,7 @@ def _get_result_export_from_dispatcher( MissingLatticeRecordError: If the result is not found. 
""" - endpoint = f"/api/v2/dispatches/{dispatch_id}" + endpoint = f"{BASE_ENDPOINT}/{dispatch_id}" response = api_client.get(endpoint) if response.status_code == 404: raise MissingLatticeRecordError @@ -220,21 +220,8 @@ def get_result_asset_path(results_dir: str, key: str): def download_asset(remote_uri: str, local_path: str, chunk_size: int = 1024 * 1024): - local_scheme = "file" - if remote_uri.startswith(local_scheme): - copy_file_locally(remote_uri, f"file://{local_path}") - else: - f = furl(remote_uri) - scheme = f.scheme - host = f.host - port = f.port - dispatcher_addr = f"{scheme}://{host}:{port}" - endpoint = str(f.path) - api_client = CovalentAPIClient(dispatcher_addr) - r = api_client.get(endpoint, stream=True) - with open(local_path, "wb") as f: - for chunk in r.iter_content(chunk_size=chunk_size): - f.write(chunk) + _, ft = FileTransfer(remote_uri, local_path).cp() + ft() def _download_result_asset(manifest: dict, results_dir: str, key: str): @@ -292,6 +279,13 @@ def __init__(self, manifest: ResultSchema, results_dir: str): self._manifest = manifest.model_dump() self._results_dir = results_dir + # Compute Result._error message from electron statuses + tg = manifest.lattice.transport_graph + failed_nodes = list(filter(lambda x: x.metadata.status == Result.FAILED, tg.nodes)) + if len(failed_nodes) > 0: + failed_nodes_msg = "".join(map(lambda x: f"{x.id}: {x.metadata.name}", failed_nodes)) + self.result_object._error = "The following tasks failed:\n" + failed_nodes_msg + def save(self, path: Optional[str] = None): if not path: path = os.path.join(self._results_dir, "manifest.json") @@ -316,7 +310,8 @@ def download_node_asset(self, node_id: int, key: str): def load_result_asset(self, key: str): data = _load_result_asset(self._manifest, key) - self.result_object.__dict__[f"_{key}"] = data + if data is not None: + self.result_object.__dict__[f"_{key}"] = data def load_lattice_asset(self, key: str): data = _load_lattice_asset(self._manifest, key) diff --git a/covalent/_shared_files/defaults.py b/covalent/_shared_files/defaults.py index 3956a69ae..71d9e9fc1 100644 --- a/covalent/_shared_files/defaults.py +++ b/covalent/_shared_files/defaults.py @@ -111,7 +111,6 @@ def get_default_dispatcher_config(): "heartbeat", ), "use_async_dispatcher": os.environ.get("COVALENT_USE_ASYNC_DISPATCHER", "true") or "false", - "data_uri_filter_policy": os.environ.get("COVALENT_DATA_URI_FILTER_POLICY", "http"), "asset_cache_size": int(os.environ.get("COVALENT_ASSET_CACHE_SIZE", 32)), } diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 915d22f66..37f36ebfa 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -16,10 +16,12 @@ """Class corresponding to computation nodes.""" +import base64 import inspect import json import operator import tempfile +import traceback from builtins import list from copy import deepcopy from dataclasses import asdict @@ -27,7 +29,12 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union -from covalent._dispatcher_plugins.local import LocalDispatcher +from covalent._dispatcher_plugins.local import ( + BASE_ENDPOINT, + APIClient, + LocalDispatcher, + pack_staging_dir, +) from .._file_transfer.enums import Order from .._file_transfer.file_transfer import FileTransfer @@ -858,29 +865,64 @@ def _build_sublattice_graph(sub: Lattice, json_parent_metadata: str, *args, **kw DISABLE_LEGACY_SUBLATTICES = os.environ.get("COVALENT_DISABLE_LEGACY_SUBLATTICES") == "1" 
try: - # Attempt multistage sublattice dispatch. For now we require - # the executor to reach the Covalent server + # Attempt to build and submit the sublattice directly to the control plane + parent_dispatch_id = os.environ["COVALENT_DISPATCH_ID"] dispatcher_url = os.environ["COVALENT_DISPATCHER_URL"] + tasks = json.loads(os.environ["COVALENT_TASKS"]) with tempfile.TemporaryDirectory(prefix="covalent-") as staging_path: manifest = LocalDispatcher.prepare_manifest(sub, staging_path) - - # Omit these two steps to return the manifest to Covalent and - # request the assets be pulled recv_manifest = LocalDispatcher.register_manifest( manifest, dispatcher_addr=dispatcher_url, - parent_dispatch_id=parent_dispatch_id, push_assets=True, ) + + # Read the parent electron from COVALENT_TASKS and connect it + # with the sublattice dispatch + if len(tasks) == 1: + node_id = tasks[0]["electron_id"] + body = {"sub_dispatch_id": recv_manifest.metadata.dispatch_id} + endpoint = f"{BASE_ENDPOINT}/{parent_dispatch_id}/electrons/{node_id}" + r = APIClient(dispatcher_url).patch(endpoint, json=body) + r.raise_for_status() + else: + raise RuntimeError( + "Error: could not deduce sublattice electron from COVALENT_TASKS" + ) + LocalDispatcher.upload_assets(recv_manifest) - return recv_manifest.model_dump_json() + # Indicate that the sublattice was successfully submitted + manifest.metadata.dispatch_id = recv_manifest.metadata.dispatch_id + manifest.metadata.root_dispatch_id = recv_manifest.metadata.root_dispatch_id + tar_file = pack_staging_dir(staging_path, manifest) + with open(tar_file, "rb") as tar: + tar_b64 = base64.b64encode(tar.read()).decode("utf-8") + + os.unlink(tar_file) + return tar_b64 except Exception as ex: - # Fall back to legacy sublattice handling + # If the executor can't reach the control plane, pack up the staging directory + # and let the control plane import the tarball + + tb = "".join(traceback.TracebackException.from_exception(ex).format()) if DISABLE_LEGACY_SUBLATTICES: + print(f"Unable to submit sublattice dispatch: {tb}") raise - print("Falling back to legacy sublattice handling") - return sub.serialize_to_json() + + print("Packing staging directory for server-side import") + with tempfile.TemporaryDirectory(prefix="covalent-") as staging_path: + manifest = LocalDispatcher.prepare_manifest(sub, staging_path) + + tar_file = pack_staging_dir(staging_path, manifest) + + # The base64-encoded tarball will be read server-side + # as `TransportableObject.object_string` + with open(tar_file, "rb") as tar: + tar_b64 = base64.b64encode(tar.read()).decode("utf-8") + + os.unlink(tar_file) + return tar_b64 diff --git a/covalent/executor/base.py b/covalent/executor/base.py index 1e7670ad8..218c3aea1 100644 --- a/covalent/executor/base.py +++ b/covalent/executor/base.py @@ -459,9 +459,9 @@ async def send( # # Task spec: # { - # "function_id": int, - # "args_ids": List[int], - # "kwargs_ids": Dict[str, int], + # "electron_id": int, + # "args": List[int], + # "kwargs": Dict[str, int], # "deps_id": str, # "call_before_id": str, # "call_after_id": str, @@ -815,9 +815,9 @@ async def send( # # Task spec: # { - # "function_id": int, - # "args_ids": List[int], - # "kwargs_ids": Dict[str, int], + # "electron_id": int, + # "args": List[int], + # "kwargs": Dict[str, int], # "deps_id": str, # "call_before_id": str, # "call_after_id": str, diff --git a/covalent/executor/executor_plugins/local.py b/covalent/executor/executor_plugins/local.py index 211ebc07e..edacc6905 100644 --- 
a/covalent/executor/executor_plugins/local.py +++ b/covalent/executor/executor_plugins/local.py @@ -22,6 +22,7 @@ import asyncio import os +import traceback from concurrent.futures import ProcessPoolExecutor from enum import Enum from typing import Any, Callable, Dict, List, Optional @@ -161,30 +162,39 @@ def _send( resources: ResourceMap, task_group_metadata: dict, ): + os.makedirs(self.workdir, exist_ok=True) dispatch_id = task_group_metadata["dispatch_id"] task_ids = task_group_metadata["node_ids"] gid = task_group_metadata["task_group_id"] output_uris = [] for node_id in task_ids: - result_uri = os.path.join(self.cache_dir, f"result_{dispatch_id}-{node_id}.pkl") - stdout_uri = os.path.join(self.cache_dir, f"stdout_{dispatch_id}-{node_id}.txt") - stderr_uri = os.path.join(self.cache_dir, f"stderr_{dispatch_id}-{node_id}.txt") + result_uri = os.path.join(self.workdir, f"result_{dispatch_id}-{node_id}.pkl") + stdout_uri = os.path.join(self.workdir, f"stdout_{dispatch_id}-{node_id}.txt") + stderr_uri = os.path.join(self.workdir, f"stderr_{dispatch_id}-{node_id}.txt") output_uris.append((result_uri, stdout_uri, stderr_uri)) server_url = format_server_url() app_log.debug(f"Running task group {dispatch_id}:{task_ids}") + app_log.debug(f"Generated artifacts will be saved at: {output_uris}") future = proc_pool.submit( run_task_group, list(map(lambda t: t.model_dump(), task_specs)), output_uris, - self.cache_dir, + self.workdir, task_group_metadata, server_url, ) def handle_cancelled(fut): app_log.debug(f"In done callback for {dispatch_id}:{gid}, future {fut}") + ex = fut.exception(timeout=0) + if ex is not None: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.debug(tb) + for task_id in task_ids: + url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/job" + requests.put(url, json={"status": "FAILED"}) if fut.cancelled(): for task_id in task_ids: url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/job" diff --git a/covalent/executor/schemas.py b/covalent/executor/schemas.py index f8a5f1bd1..c60396896 100644 --- a/covalent/executor/schemas.py +++ b/covalent/executor/schemas.py @@ -53,9 +53,9 @@ class TaskSpec(BaseModel): """An abstract description of a runnable task. Attributes: - function_id: The `node_id` of the function. - args_ids: The `node_id`s of the function's args - kwargs_ids: The `node_id`s of the function's kwargs {key: node_id} + electron_id: The `node_id` of the function. + args: The `node_id`s of the function's args + kwargs: The `node_id`s of the function's kwargs {key: node_id} hooks_id: An opaque string representing the task's hooks. The attribute values can be used in conjunction with a @@ -63,9 +63,9 @@ class TaskSpec(BaseModel): environment. """ - function_id: int - args_ids: List[int] - kwargs_ids: Dict[str, int] + electron_id: int + args: List[int] + kwargs: Dict[str, int] class ResourceMap(BaseModel): @@ -77,8 +77,8 @@ class ResourceMap(BaseModel): Resource identifiers are the attribute values of TaskSpecs. 
For instance, if ts is a `TaskSpec` and rm is the corresponding `ResourceMap`, then - - the serialized function has URI `rm.functions[ts.function_id]` - - the serialized args have URIs `rm.inputs[ts.args_ids[i]]` + - the serialized function has URI `rm.functions[ts.electron_id]` + - the serialized args have URIs `rm.inputs[ts.args[i]]` - the call_before has URI `rm.deps[ts.call_before_id]` Attributes: diff --git a/covalent/executor/utils/wrappers.py b/covalent/executor/utils/wrappers.py index e35d4f942..322ea1c43 100644 --- a/covalent/executor/utils/wrappers.py +++ b/covalent/executor/utils/wrappers.py @@ -29,6 +29,7 @@ import requests +from covalent._file_transfer import FileTransfer from covalent._workflow.depsbash import DepsBash from covalent._workflow.depscall import RESERVED_RETVAL_KEY__FILES, DepsCall from covalent._workflow.depspip import DepsPip @@ -169,17 +170,14 @@ def run_task_group( """ - prefix = "file://" - prefix_len = len(prefix) - outputs = {} results = [] dispatch_id = task_group_metadata["dispatch_id"] task_ids = task_group_metadata["node_ids"] - gid = task_group_metadata["task_group_id"] os.environ["COVALENT_DISPATCH_ID"] = dispatch_id os.environ["COVALENT_DISPATCHER_URL"] = server_url + os.environ["COVALENT_TASKS"] = json.dumps([task for task in task_specs]) for i, task in enumerate(task_specs): result_uri, stdout_uri, stderr_uri = output_uris[i] @@ -187,14 +185,19 @@ def run_task_group( with open(stdout_uri, "w") as stdout, open(stderr_uri, "w") as stderr: with redirect_stdout(stdout), redirect_stderr(stderr): try: - task_id = task["function_id"] - args_ids = task["args_ids"] - kwargs_ids = task["kwargs_ids"] + task_id = task["electron_id"] + args = task["args"] + kwargs = task["kwargs"] function_uri = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/assets/function" # Download function - resp = requests.get(function_uri, stream=True) + # Get remote uri + uri_resp = requests.get(function_uri) + uri_resp.raise_for_status() + remote_uri = uri_resp.json()["remote_uri"] + resp = requests.get(remote_uri, stream=True) + resp.raise_for_status() serialized_fn = deserialize_node_asset(resp.content, "function") @@ -202,21 +205,32 @@ def run_task_group( ser_kwargs = {} # Download args and kwargs - for node_id in args_ids: + for node_id in args: url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/output" - resp = requests.get(url, stream=True) + uri_resp = requests.get(url) + uri_resp.raise_for_status() + remote_url = uri_resp.json()["remote_uri"] + + resp = requests.get(remote_url, stream=True) resp.raise_for_status() ser_args.append(deserialize_node_asset(resp.content, "output")) - for k, node_id in kwargs_ids.items(): + for k, node_id in kwargs.items(): url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/output" - resp = requests.get(url, stream=True) + uri_resp = requests.get(url) + uri_resp.raise_for_status() + remote_url = uri_resp.json()["remote_uri"] + resp = requests.get(remote_url, stream=True) resp.raise_for_status() ser_kwargs[k] = deserialize_node_asset(resp.content, "output") # Download deps, call_before, and call_after hooks_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/assets/hooks" - resp = requests.get(hooks_url, stream=True) + uri_resp = requests.get(hooks_url) + uri_resp.raise_for_status() + remote_url = uri_resp.json()["remote_uri"] + + resp = requests.get(remote_url, stream=True) resp.raise_for_status() hooks_json = deserialize_node_asset(resp.content, 
"hooks") deps_json = hooks_json.get("deps", {}) @@ -267,22 +281,32 @@ def run_task_group( # POST task artifacts if result_uri: upload_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/assets/output" - with open(result_uri, "rb") as f: - requests.put(upload_url, data=f) + headers = {"Content-Length": str(os.path.getsize(result_uri))} + uri_resp = requests.post(upload_url, headers=headers) + uri_resp.raise_for_status() + remote_uri = uri_resp.json()["remote_uri"] + _, cp = FileTransfer(f"file://{result_uri}", remote_uri).cp() + cp() sys.stdout.flush() if stdout_uri: upload_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/assets/stdout" - with open(stdout_uri, "rb") as f: - headers = {"Content-Length": os.path.getsize(stdout_uri)} - requests.put(upload_url, data=f) + headers = {"Content-Length": str(os.path.getsize(stdout_uri))} + uri_resp = requests.post(upload_url, headers=headers) + uri_resp.raise_for_status() + remote_uri = uri_resp.json()["remote_uri"] + _, cp = FileTransfer(f"file://{stdout_uri}", remote_uri).cp() + cp() sys.stderr.flush() if stderr_uri: upload_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/assets/stderr" - with open(stderr_uri, "rb") as f: - headers = {"Content-Length": os.path.getsize(stderr_uri)} - requests.put(upload_url, data=f) + headers = {"Content-Length": str(os.path.getsize(stderr_uri))} + uri_resp = requests.post(upload_url, headers=headers) + uri_resp.raise_for_status() + remote_uri = uri_resp.json()["remote_uri"] + _, cp = FileTransfer(f"file://{stderr_uri}", remote_uri).cp() + cp() result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json") @@ -349,8 +373,9 @@ def run_task_group_alt( task_ids = task_group_metadata["node_ids"] gid = task_group_metadata["task_group_id"] - os.environ["COVALENT_DISPATCH_ID"] = dispatch_id - os.environ["COVALENT_DISPATCHER_URL"] = server_url + # os.environ["COVALENT_DISPATCH_ID"] = dispatch_id + # os.environ["COVALENT_DISPATCHER_URL"] = server_url + # os.environ["COVALENT_TASKS"] = json.dumps([task for task in task_specs]) for i, task in enumerate(task_specs): result_uri, stdout_uri, stderr_uri = output_uris[i] @@ -358,9 +383,9 @@ def run_task_group_alt( with open(stdout_uri, "w") as stdout, open(stderr_uri, "w") as stderr: with redirect_stdout(stdout), redirect_stderr(stderr): try: - task_id = task["function_id"] - args_ids = task["args_ids"] - kwargs_ids = task["kwargs_ids"] + task_id = task["electron_id"] + args = task["args"] + kwargs = task["kwargs"] # Load function function_uri = resources["functions"][task_id] @@ -374,14 +399,14 @@ def run_task_group_alt( ser_args = [] ser_kwargs = {} - args_uris = [resources["inputs"][index] for index in args_ids] + args_uris = [resources["inputs"][index] for index in args] for uri in args_uris: if uri.startswith(prefix): uri = uri[prefix_len:] with open(uri, "rb") as f: ser_args.append(deserialize_node_asset(f.read(), "output")) - kwargs_uris = {k: resources["inputs"][v] for k, v in kwargs_ids.items()} + kwargs_uris = {k: resources["inputs"][v] for k, v in kwargs.items()} for key, uri in kwargs_uris.items(): if uri.startswith(prefix): uri = uri[prefix_len:] diff --git a/covalent_dispatcher/__init__.py b/covalent_dispatcher/__init__.py index 0ad60669e..ca7c2f33b 100644 --- a/covalent_dispatcher/__init__.py +++ b/covalent_dispatcher/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .entry_point import cancel_running_dispatch, run_dispatcher +from .entry_point import cancel_running_dispatch # nopycln: import diff --git a/covalent_dispatcher/_core/__init__.py b/covalent_dispatcher/_core/__init__.py index 58c050f35..a3c58bd78 100644 --- a/covalent_dispatcher/_core/__init__.py +++ b/covalent_dispatcher/_core/__init__.py @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .data_manager import make_dispatch from .data_modules.importer import copy_futures -from .dispatcher import cancel_dispatch, run_dispatch +from .dispatcher import cancel_dispatch, run_dispatch # nopycln: import diff --git a/covalent_dispatcher/_core/data_manager.py b/covalent_dispatcher/_core/data_manager.py index 6b5e52b71..fff11e228 100644 --- a/covalent_dispatcher/_core/data_manager.py +++ b/covalent_dispatcher/_core/data_manager.py @@ -18,19 +18,13 @@ Defines the core functionality of the result service """ -import asyncio -import tempfile import traceback from typing import Dict -from pydantic import ValidationError - -from covalent._dispatcher_plugins.local import LocalDispatcher from covalent._results_manager import Result from covalent._shared_files import logger from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS -from covalent._workflow.lattice import Lattice from .._dal.result import Result as SRVResult from .._dal.result import get_result_object as get_result_object_from_db @@ -99,9 +93,10 @@ async def update_node_result(dispatch_id, node_result): # Handle returns from _build_sublattice_graph -- change # COMPLETED -> DISPATCHING - node_result = _filter_sublattice_status( + node_result = await _filter_sublattice_status( dispatch_id, node_id, node_status, node_type, sub_dispatch_id, node_result ) + app_log.debug(f"Filtered node result: {node_result}") valid_update = await electron.update(dispatch_id, node_result) if not valid_update: @@ -110,15 +105,26 @@ async def update_node_result(dispatch_id, node_result): ) return + # TODO: refactor _make_sublattice_dispatch if node_result["status"] == RESULT_STATUS.DISPATCHING: - app_log.debug("Received sublattice dispatch") - try: - sub_dispatch_id = await _make_sublattice_dispatch(dispatch_id, node_result) - except Exception as ex: - tb = "".join(traceback.TracebackException.from_exception(ex).format()) - node_result["status"] = RESULT_STATUS.FAILED - node_result["error"] = tb - await electron.update(dispatch_id, node_result) + if sub_dispatch_id: + # `_build_sublattice_graph` linked` the sublattice dispatch id with its parent electron + app_log.debug( + f"Electron {dispatch_id}:{node_id} is already linked to subdispatch {sub_dispatch_id}" + ) + else: + try: + # If `_build_sublattice_graph` was unable to reach the server, import the + # b64-encoded staging tarball from `output.object_string` + sub_dispatch_id = await _make_sublattice_dispatch(dispatch_id, node_result) + app_log.debug( + f"Created sublattice dispatch {sub_dispatch_id} for {dispatch_id}:{node_id}" + ) + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + node_result["status"] = RESULT_STATUS.FAILED + node_result["error"] = tb + await electron.update(dispatch_id, node_result) except KeyError as ex: valid_update = False @@ -142,59 +148,6 @@ async def update_node_result(dispatch_id, node_result): await dispatcher.notify_node_status(dispatch_id, node_id, node_status, detail) -# Domain: result 
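# Reviewer sketch (illustrative, not part of the patch): the JSON-lattice
# redirect removed below is superseded by a tarball round trip. The executor
# packs the sublattice staging directory and returns it base64-encoded; the
# dispatcher decodes it (decode_b64_tar + untar_staging_dir) and imports the
# manifest. `staging_dir`, `manifest`, and the parent ids are placeholders.
import base64

from covalent._dispatcher_plugins.local import pack_staging_dir
from covalent_dispatcher._core.data_modules import importer as manifest_importer

# Executor side (see `_build_sublattice_graph`)
tar_path = pack_staging_dir(staging_dir, manifest)
with open(tar_path, "rb") as f:
    tar_b64 = base64.b64encode(f.read()).decode("utf-8")  # travels as TransportableObject.object_string

# Dispatcher side (see `_import_sublattice_tarball` below)
sub_manifest = manifest_importer.import_b64_staging_tarball(
    tar_b64, parent_dispatch_id, parent_electron_id
)
sub_dispatch_id = sub_manifest.metadata.dispatch_id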
-def _redirect_lattice( - json_lattice: str, - parent_dispatch_id: str, - parent_electron_id: int, - loop: asyncio.AbstractEventLoop, -) -> str: - """Redirect a JSON lattice through the new DAL. - - Args: - json_lattice: A JSON-serialized lattice. - parent_dispatch_id: The id of a sublattice's parent dispatch. - - This will only be triggered from either the monolithic /submit - endpoint or a monolithic sublattice dispatch. - - Returns: - The dispatch manifest - - """ - lattice = Lattice.deserialize_from_json(json_lattice) - with tempfile.TemporaryDirectory() as staging_dir: - manifest = LocalDispatcher.prepare_manifest(lattice, staging_dir) - - # Trigger an internal asset pull from /tmp to object store - coro = manifest_importer.import_manifest( - manifest, - parent_dispatch_id, - parent_electron_id, - ) - filtered_manifest = manifest_importer._import_manifest( - manifest, - parent_dispatch_id, - parent_electron_id, - ) - - manifest_importer._pull_assets(filtered_manifest) - - return filtered_manifest.metadata.dispatch_id - - -async def make_dispatch( - json_lattice: str, parent_dispatch_id: str = None, parent_electron_id: int = None -) -> str: - return await run_in_executor( - _redirect_lattice, - json_lattice, - parent_dispatch_id, - parent_electron_id, - asyncio.get_running_loop(), - ) - - def get_result_object(dispatch_id: str, bare: bool = True) -> SRVResult: app_log.debug(f"Getting result object from db, bare={bare}") return get_result_object_from_db(dispatch_id, bare) @@ -225,69 +178,56 @@ async def _update_parent_electron(dispatch_id: str): await update_node_result(parent_result_obj.dispatch_id, node_result) -def _filter_sublattice_status( +async def _filter_sublattice_status( dispatch_id, node_id, status, node_type, sub_dispatch_id, node_result ): - if status == Result.COMPLETED and node_type == "sublattice" and not sub_dispatch_id: + # COMPLETED -> DISPATCHING if either + # * Dispatch with id `sub_dispatch_id` has not started running (NEW_OBJ) + # * Electron is a sublattice but no sub_dispatch_id (legacy style sublattices) + if status != Result.COMPLETED: + return node_result + if sub_dispatch_id: + dispatch_info = await dispatch.get(sub_dispatch_id, ["status"]) + if dispatch_info["status"] == RESULT_STATUS.NEW_OBJECT: + node_result["status"] = RESULT_STATUS.DISPATCHING + if node_type == "sublattice" and not sub_dispatch_id: node_result["status"] = RESULT_STATUS.DISPATCHING return node_result -async def _make_sublattice_dispatch(dispatch_id: str, node_result: dict): +async def _make_sublattice_dispatch(dispatch_id: str, node_result: dict) -> str: try: - manifest, parent_electron_id = await run_in_executor( - _make_sublattice_dispatch_helper, + subl_manifest = await run_in_executor( + _import_sublattice_tarball, dispatch_id, node_result, ) + return subl_manifest.metadata.dispatch_id - imported_manifest = await manifest_importer.import_manifest( - manifest=manifest, - parent_dispatch_id=dispatch_id, - parent_electron_id=parent_electron_id, - ) - - return imported_manifest.metadata.dispatch_id - - except ValidationError as ex: - # Fall back to legacy sublattice handling - # NB: this loads the JSON sublattice in memory - json_lattice, parent_electron_id = await run_in_executor( - _legacy_sublattice_dispatch_helper, - dispatch_id, - node_result, - ) - return await make_dispatch( - json_lattice, - dispatch_id, - parent_electron_id, - ) + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + raise RuntimeError(f"Failed to import sublattice 
tarball:\n{ex}") -def _legacy_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): - app_log.debug("falling back to legacy sublattice dispatch") +def _import_sublattice_tarball(dispatch_id: str, node_result: Dict) -> ResultSchema: + app_log.debug("Importing sublattice tarball") result_object = get_result_object(dispatch_id, bare=True) node_id = node_result["node_id"] parent_node = result_object.lattice.transport_graph.get_node(node_id) - bg_output = parent_node.get_value("output") - parent_electron_id = parent_node._electron_id - json_lattice = bg_output.object_string - return json_lattice, parent_electron_id - -def _make_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): - """Helper function for performing DB queries related to sublattices.""" - result_object = get_result_object(dispatch_id, bare=True) - node_id = node_result["node_id"] - parent_node = result_object.lattice.transport_graph.get_node(node_id) + # Extract the base64-encoded tarball from the graph builder output's `object_string` attribute. + # TODO: stream the base64 tarball to the decoder to avoid loading the entire + # tarball in memory bg_output = parent_node.get_value("output") + b64_tarball = bg_output.object_string - manifest = ResultSchema.parse_raw(bg_output.object_string) - parent_electron_id = parent_node._electron_id + subl_manifest = manifest_importer.import_b64_staging_tarball( + b64_tarball, dispatch_id, parent_electron_id + ) - return manifest, parent_electron_id + return subl_manifest # Common Result object queries diff --git a/covalent_dispatcher/_core/data_modules/importer.py b/covalent_dispatcher/_core/data_modules/importer.py index 4630b9544..c85be951e 100644 --- a/covalent_dispatcher/_core/data_modules/importer.py +++ b/covalent_dispatcher/_core/data_modules/importer.py @@ -18,17 +18,18 @@ Functionality for importing dispatch submissions """ +import shutil import uuid -from typing import Optional +from typing import Optional, Tuple +from covalent._dispatcher_plugins.local import decode_b64_tar, untar_staging_dir from covalent._shared_files import logger from covalent._shared_files.config import get_config from covalent._shared_files.schemas.result import ResultSchema -from covalent_dispatcher._dal.asset import copy_asset from covalent_dispatcher._dal.importers.result import handle_redispatch, import_result from covalent_dispatcher._dal.result import Result as SRVResult -from .utils import dm_pool, run_in_executor +from .utils import run_in_executor BASE_PATH = get_config("dispatcher.results_dir") @@ -82,16 +83,18 @@ def _get_all_assets(dispatch_id: str): def _pull_assets(manifest: ResultSchema) -> None: dispatch_id = manifest.metadata.dispatch_id assets = _get_all_assets(dispatch_id) - futs = [] + download_count = 0 for asset in assets["lattice"]: if asset.remote_uri: + download_count += 1 asset.download(asset.remote_uri) for asset in assets["nodes"]: if asset.remote_uri: + download_count += 1 asset.download(asset.remote_uri) - app_log.debug(f"imported {len(futs)} assets for dispatch {dispatch_id}") + app_log.debug(f"imported {download_count} assets for dispatch {dispatch_id}") async def import_manifest( @@ -107,12 +110,6 @@ async def import_manifest( return filtered_manifest -def _copy_assets(assets_to_copy): - for item in assets_to_copy: - src, dest = item - copy_asset(src, dest) - - def _import_derived_manifest( manifest: ResultSchema, parent_dispatch_id: str, @@ -123,11 +120,6 @@ def _import_derived_manifest( filtered_manifest, parent_dispatch_id, reuse_previous_results ) - 
dispatch_id = filtered_manifest.metadata.dispatch_id - fut = dm_pool.submit(_copy_assets, assets_to_copy) - copy_futures[dispatch_id] = fut - fut.add_done_callback(lambda x: copy_futures.pop(dispatch_id)) - return filtered_manifest @@ -146,3 +138,24 @@ async def import_derived_manifest( await run_in_executor(_pull_assets, filtered_manifest) return filtered_manifest + + +# Import b64 tarball of a client-side staging directory +# For handling sublattice dispatches +def import_b64_staging_tarball( + b64_buffer: str, parent_dispatch_id: str, parent_electron_id: str +) -> Tuple[str, ResultSchema]: + tar_path = decode_b64_tar(b64_buffer) + work_dir, manifest = untar_staging_dir(tar_path) + + app_log.debug(f"Extracted tarball to working directory {work_dir}") + filtered_manifest = _import_manifest( + manifest, + parent_dispatch_id, + parent_electron_id, + ) + _pull_assets(filtered_manifest) + + shutil.rmtree(work_dir) + app_log.debug(f"Cleaned up working directory {work_dir}") + return filtered_manifest diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index 14ae116f9..8c653d7a1 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -40,7 +40,7 @@ app_log = logger.app_log log_stack_info = logger.log_stack_info _global_status_queue = None -_status_queues = {} +_background_tasks = set() _futures = {} _global_event_listener = None @@ -212,10 +212,10 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro selected_executor = executor_attrs["executor"] selected_executor_data = executor_attrs["executor_data"] task_spec = { - "function_id": node_id, + "electron_id": node_id, "name": node_name, - "args_ids": abs_task_input["args"], - "kwargs_ids": abs_task_input["kwargs"], + "args": abs_task_input["args"], + "kwargs": abs_task_input["kwargs"], } # Task inputs that don't belong to the task group have already beeen resolved external_task_args = filter( @@ -243,7 +243,9 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro selected_executor=[selected_executor, selected_executor_data], ) - asyncio.create_task(coro) + fut = asyncio.create_task(coro) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) else: ts = datetime.now(timezone.utc) for node_id in sorted_nodes: @@ -363,7 +365,10 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: def run_dispatch(dispatch_id: str) -> asyncio.Future: - return asyncio.create_task(run_workflow(dispatch_id)) + fut = asyncio.create_task(run_workflow(dispatch_id)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + return fut async def notify_node_status( @@ -391,16 +396,11 @@ async def _finalize_dispatch(dispatch_id: str): cancelled = incomplete_tasks["cancelled"] if failed or cancelled: app_log.debug(f"Workflow {dispatch_id} cancelled or failed") - failed_nodes = failed - failed_nodes = map(lambda x: f"{x[0]}: {x[1]}", failed_nodes) - failed_nodes_msg = "\n".join(failed_nodes) - error_msg = "The following tasks failed:\n" + failed_nodes_msg ts = datetime.now(timezone.utc) status = RESULT_STATUS.FAILED if failed else RESULT_STATUS.CANCELLED result_update = datasvc.generate_dispatch_result( dispatch_id, status=status, - error=error_msg, end_time=ts, ) await datasvc.dispatch.update(dispatch_id, result_update) @@ -509,7 +509,9 @@ async def _node_event_listener(): while True: msg = await _global_status_queue.get() - 
asyncio.create_task(_handle_event(msg)) + fut = asyncio.create_task(_handle_event(msg)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) async def _handle_event(msg: Dict): diff --git a/covalent_dispatcher/_core/runner_ng.py b/covalent_dispatcher/_core/runner_ng.py index 5c98a7721..8e63ea946 100644 --- a/covalent_dispatcher/_core/runner_ng.py +++ b/covalent_dispatcher/_core/runner_ng.py @@ -70,10 +70,10 @@ async def _submit_abstract_task_group( known_nodes: list, executor: AsyncBaseExecutor, ) -> None: - # Task sequence of the form {"function_id": task_id, "args_ids": - # [node_ids], "kwargs_ids": {key: node_id}} + # Task sequence of the form {"electron_id": task_id, "args": + # [node_ids], "kwargs": {key: node_id}} - task_ids = [task["function_id"] for task in task_seq] + task_ids = [task["electron_id"] for task in task_seq] task_specs = [] task_group_metadata = { "dispatch_id": dispatch_id, @@ -89,7 +89,7 @@ async def _submit_abstract_task_group( # Get upload URIs for task_spec in task_seq: - task_id = task_spec["function_id"] + task_id = task_spec["electron_id"] function_uri = executor.get_upload_uri(task_group_metadata, f"function-{task_id}") hooks_uri = executor.get_upload_uri(task_group_metadata, f"hooks-{task_id}") @@ -255,7 +255,7 @@ async def run_abstract_task_group( try: app_log.debug(f"Attempting to instantiate executor {selected_executor}") - task_ids = [task["function_id"] for task in task_seq] + task_ids = [task["electron_id"] for task in task_seq] app_log.debug(f"Running task group {dispatch_id}:{task_group_id}") executor = get_executor( node_id=task_group_id, @@ -281,11 +281,11 @@ async def run_abstract_task_group( raise RuntimeError("Task packing not supported by executor plugin") task_spec = task_seq[0] - node_id = task_spec["function_id"] + node_id = task_spec["electron_id"] name = task_spec["name"] abstract_inputs = { - "args": task_spec["args_ids"], - "kwargs": task_spec["kwargs_ids"], + "args": task_spec["args"], + "kwargs": task_spec["kwargs"], } app_log.debug(f"Reverting to legacy runner for task {task_group_id}") coro = runner_legacy.run_abstract_task( diff --git a/covalent_dispatcher/_dal/asset.py b/covalent_dispatcher/_dal/asset.py index 146b47624..a674a6973 100644 --- a/covalent_dispatcher/_dal/asset.py +++ b/covalent_dispatcher/_dal/asset.py @@ -33,13 +33,13 @@ app_log = logger.app_log -class StorageType(Enum): - LOCAL = "file" +class StorageType(str, Enum): + LOCAL = local_store.scheme # "file" S3 = "s3" _storage_provider_map = { - StorageType.LOCAL: local_store, + StorageType.LOCAL.value: local_store, } @@ -73,8 +73,8 @@ def primary_key(self): return self._id @property - def storage_type(self) -> StorageType: - return StorageType(self._attrs["storage_type"]) + def storage_type(self) -> str: + return self._attrs["storage_type"] @property def storage_path(self) -> str: @@ -98,8 +98,7 @@ def remote_uri(self) -> str: @property def internal_uri(self) -> str: - scheme = self.storage_type.value - return f"{scheme}://" + str(Path(self.storage_path) / self.object_key) + return f"{self.storage_type}://" + str(Path(self.storage_path) / self.object_key) @property def size(self) -> int: @@ -123,14 +122,14 @@ def load_data(self) -> Any: return self.object_store.load_file(self.storage_path, self.object_key) def download(self, src_uri: str): - scheme = self.storage_type.value + scheme = self.storage_type dest_uri = scheme + "://" + os.path.join(self.storage_path, self.object_key) app_log.debug(f"Downloading asset from {src_uri} to {dest_uri}") 
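# Reviewer note (illustrative, not part of the patch): the `_background_tasks`
# set introduced in the dispatcher changes above keeps a strong reference to
# fire-and-forget tasks so they are not garbage collected before completing.
# Minimal form of the pattern; the `spawn` helper name is hypothetical.
import asyncio

_background_tasks = set()

def spawn(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task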
cp(src_uri, dest_uri) def upload(self, dest_uri: str): - scheme = self.storage_type.value + scheme = self.storage_type src_uri = scheme + "://" + os.path.join(self.storage_path, self.object_key) app_log.debug(f"Uploading asset from {src_uri} to {dest_uri}") cp(src_uri, dest_uri) @@ -154,7 +153,7 @@ def copy_asset(src: Asset, dest: Asset): """ if src.size > 0: - scheme = dest.storage_type.value + scheme = dest.storage_type dest_uri = scheme + "://" + os.path.join(dest.storage_path, dest.object_key) src.upload(dest_uri) else: @@ -174,5 +173,9 @@ def copy_asset_meta(session: Session, src: Asset, dest: Asset): "digest_alg": src.digest_alg, "digest": src.digest, "size": src.size, + "storage_type": src.storage_type, + "storage_path": src.storage_path, + "object_key": src.object_key, + "remote_uri": src.remote_uri, } dest.update(session, values=update) diff --git a/covalent_dispatcher/_dal/base.py b/covalent_dispatcher/_dal/base.py index 071d9a44b..6df880283 100644 --- a/covalent_dispatcher/_dal/base.py +++ b/covalent_dispatcher/_dal/base.py @@ -20,7 +20,7 @@ from typing import Any, Dict, Generator, Generic, List, Type, TypeVar, Union from sqlalchemy import select -from sqlalchemy.orm import Session, load_only +from sqlalchemy.orm import Session from .._db.datastore import workflow_db from . import controller @@ -220,15 +220,13 @@ def get_linked_assets( .join(link_model) .join(cls.meta_type.model) ) - if len(fields) == 0: - fields = FIELDS + fields = FIELDS for attr, val in equality_filters.items(): stmt = stmt.where(getattr(cls.meta_type.model, attr) == val) for attr, vals in membership_filters.items(): stmt = stmt.where(getattr(cls.meta_type.model, attr).in_(vals)) attrs = [getattr(Asset.model, f) for f in fields] - stmt = stmt.options(load_only(*attrs)) records = session.execute(stmt) diff --git a/covalent_dispatcher/_dal/exporters/electron.py b/covalent_dispatcher/_dal/exporters/electron.py index e84211291..62ab7d0ae 100644 --- a/covalent_dispatcher/_dal/exporters/electron.py +++ b/covalent_dispatcher/_dal/exporters/electron.py @@ -26,6 +26,7 @@ ElectronSchema, ) from covalent_dispatcher._dal.electron import ASSET_KEYS, Electron +from covalent_dispatcher._object_store.base import TransferDirection app_log = logger.app_log @@ -60,8 +61,12 @@ def _export_electron_assets(e: Electron) -> ElectronAssets: size = asset.size digest_alg = asset.digest_alg digest = asset.digest - scheme = asset.storage_type.value - remote_uri = f"{scheme}://{asset.storage_path}/{asset.object_key}" + object_store = asset.object_store + remote_uri = object_store.get_public_uri( + asset.storage_path, + asset.object_key, + transfer_direction=TransferDirection.download, + ) manifests[asset_key] = AssetSchema( remote_uri=remote_uri, size=size, digest_alg=digest_alg, digest=digest ) diff --git a/covalent_dispatcher/_dal/exporters/lattice.py b/covalent_dispatcher/_dal/exporters/lattice.py index 4b630fcd1..0f3535e64 100644 --- a/covalent_dispatcher/_dal/exporters/lattice.py +++ b/covalent_dispatcher/_dal/exporters/lattice.py @@ -21,6 +21,7 @@ from covalent._shared_files.schemas.asset import AssetSchema from covalent._shared_files.schemas.lattice import LatticeAssets, LatticeMetadata, LatticeSchema from covalent_dispatcher._dal.lattice import ASSET_KEYS, METADATA_KEYS, Lattice +from covalent_dispatcher._object_store.base import TransferDirection from .tg import export_transport_graph @@ -37,8 +38,12 @@ def _export_lattice_assets(lat: Lattice) -> LatticeAssets: size = asset.size digest_alg = asset.digest_alg digest = asset.digest 
- scheme = asset.storage_type.value - remote_uri = f"{scheme}://{asset.storage_path}/{asset.object_key}" + object_store = asset.object_store + remote_uri = object_store.get_public_uri( + asset.storage_path, + asset.object_key, + transfer_direction=TransferDirection.download, + ) manifests[asset_key] = AssetSchema( remote_uri=remote_uri, size=size, digest_alg=digest_alg, digest=digest ) diff --git a/covalent_dispatcher/_dal/exporters/result.py b/covalent_dispatcher/_dal/exporters/result.py index dfc21bca5..5958afa09 100644 --- a/covalent_dispatcher/_dal/exporters/result.py +++ b/covalent_dispatcher/_dal/exporters/result.py @@ -19,7 +19,6 @@ from covalent._shared_files import logger -from covalent._shared_files.config import get_config from covalent._shared_files.schemas.asset import AssetSchema from covalent._shared_files.schemas.result import ( ASSET_KEYS, @@ -28,16 +27,13 @@ ResultMetadata, ResultSchema, ) -from covalent._shared_files.utils import format_server_url from covalent_dispatcher._dal.electron import Electron from covalent_dispatcher._dal.result import Result, get_result_object +from covalent_dispatcher._object_store.base import TransferDirection -from ..utils.uri_filters import AssetScope, URIFilterPolicy, filter_asset_uri from .lattice import export_lattice METADATA_KEYS_TO_OMIT = {"num_nodes"} -SERVER_URL = format_server_url(get_config("dispatcher.address"), get_config("dispatcher.port")) -URI_FILTER_POLICY = URIFilterPolicy[get_config("dispatcher.data_uri_filter_policy")] app_log = logger.app_log @@ -98,8 +94,12 @@ def _export_result_assets(res: Result) -> ResultAssets: size = asset.size digest_alg = asset.digest_alg digest = asset.digest - scheme = asset.storage_type.value - remote_uri = f"{scheme}://{asset.storage_path}/{asset.object_key}" + object_store = asset.object_store + remote_uri = object_store.get_public_uri( + asset.storage_path, + asset.object_key, + transfer_direction=TransferDirection.download, + ) manifests[asset_key] = AssetSchema( remote_uri=remote_uri, size=size, digest_alg=digest_alg, digest=digest ) @@ -119,35 +119,7 @@ def export_result(res: Result) -> ResultSchema: # Filter asset URIs - return _filter_remote_uris(ResultSchema(metadata=metadata, assets=assets, lattice=lattice)) - - -def _filter_remote_uris(manifest: ResultSchema) -> ResultSchema: - dispatch_id = manifest.metadata.dispatch_id - - # Workflow-level - for key, asset in manifest.assets: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, asset.remote_uri, {}, AssetScope.DISPATCH, dispatch_id, None, key - ) - asset.remote_uri = filtered_uri - - for key, asset in manifest.lattice.assets: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, asset.remote_uri, {}, AssetScope.LATTICE, dispatch_id, None, key - ) - asset.remote_uri = filtered_uri - - # Now filter each node - tg = manifest.lattice.transport_graph - for node in tg.nodes: - for key, asset in node.assets: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, asset.remote_uri, {}, AssetScope.NODE, dispatch_id, node.id, key - ) - asset.remote_uri = filtered_uri - - return manifest + return ResultSchema(metadata=metadata, assets=assets, lattice=lattice) def export_result_manifest(dispatch_id: str) -> ResultSchema: diff --git a/covalent_dispatcher/_dal/importers/electron.py b/covalent_dispatcher/_dal/importers/electron.py index ac60577bf..4f8924e7b 100644 --- a/covalent_dispatcher/_dal/importers/electron.py +++ b/covalent_dispatcher/_dal/importers/electron.py @@ -18,14 +18,12 @@ """Functions to transform ResultSchema -> Result""" 
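The exporters above now delegate URI construction to the object store's get_public_uri hook instead of hand-formatting scheme://storage_path/object_key. A rough sketch of that contract with a toy provider (the class body is a stand-in for illustration, not the real LocalProvider, and the server URL is assumed):

from enum import Enum

class TransferDirection(str, Enum):
    upload = "put"
    download = "get"

class ToyProvider:
    # Illustrative only; real providers subclass BaseProvider.
    base_path = "/tmp/results"
    server_url = "http://localhost:48008"  # assumed dispatcher address

    def get_public_uri(self, storage_path: str, object_key: str, **options) -> str:
        # Only objects under the managed results directory get a public URL;
        # anything else yields an empty URI, mirroring local_store.
        if not storage_path.startswith(self.base_path):
            return ""
        return f"{self.server_url}/api/v0/files/{object_key}"

store = ToyProvider()
print(store.get_public_uri("/tmp/results", "dispatch_1/node_0/function.tobj",
                           transfer_direction=TransferDirection.download))
# -> http://localhost:48008/api/v0/files/dispatch_1/node_0/function.tobj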
import json -import os from typing import Dict, Tuple from sqlalchemy.orm import Session from covalent._shared_files import logger from covalent._shared_files.schemas.electron import ( - ASSET_FILENAME_MAP, ELECTRON_ERROR_FILENAME, ELECTRON_FUNCTION_FILENAME, ELECTRON_FUNCTION_STRING_FILENAME, @@ -43,7 +41,7 @@ from covalent_dispatcher._dal.lattice import Lattice from covalent_dispatcher._db import models from covalent_dispatcher._db.write_result_to_db import get_electron_type -from covalent_dispatcher._object_store.base import BaseProvider +from covalent_dispatcher._object_store.base import BaseProvider, TransferDirection app_log = logger.app_log @@ -142,8 +140,6 @@ def import_electron_assets( asset_key, ) - object_key = ASSET_FILENAME_MAP[asset_key] - local_uri = os.path.join(node_storage_path, object_key) asset_kwargs = { "storage_type": object_store.scheme, "storage_path": node_storage_path, @@ -156,28 +152,29 @@ def import_electron_assets( asset_recs[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False) # Send this back to the client + remote_uri = object_store.get_public_uri( + node_storage_path, + object_key, + transfer_direction=TransferDirection.upload, + ) asset.digest = None - asset.remote_uri = f"file://{local_uri}" - - # Register custom assets - if e.assets._custom: - for asset_key, asset in e.assets._custom.items(): - object_key = f"{asset_key}.data" - local_uri = os.path.join(node_storage_path, object_key) - - asset_kwargs = { - "storage_type": object_store.scheme, - "storage_path": node_storage_path, - "object_key": object_key, - "digest_alg": asset.digest_alg, - "digest": asset.digest, - "remote_uri": asset.uri, - "size": asset.size, - } - asset_recs[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False) - - # Send this back to the client - asset.remote_uri = f"file://{local_uri}" if asset.digest else "" - asset.digest = None + asset.remote_uri = remote_uri + + # Declare an asset for sublattice manifests + electron_type = get_electron_type(e.metadata.name) + if electron_type == "sublattice": + object_key = "result.tobj" + asset_kwargs = { + "storage_type": object_store.scheme, + "storage_path": node_storage_path, + "object_key": object_key, + "digest_alg": None, + "digest": None, + "remote_uri": None, + "size": 0, + } + asset_recs["sublattice_manifest"] = Asset.create( + session, insert_kwargs=asset_kwargs, flush=False + ) return e.assets, asset_recs diff --git a/covalent_dispatcher/_dal/importers/lattice.py b/covalent_dispatcher/_dal/importers/lattice.py index 55fa50925..f100e997d 100644 --- a/covalent_dispatcher/_dal/importers/lattice.py +++ b/covalent_dispatcher/_dal/importers/lattice.py @@ -37,7 +37,7 @@ ) from covalent_dispatcher._dal.asset import Asset from covalent_dispatcher._dal.lattice import Lattice -from covalent_dispatcher._object_store.local import BaseProvider +from covalent_dispatcher._object_store.base import BaseProvider, TransferDirection def _get_lattice_meta(lat: LatticeSchema, storage_path) -> dict: @@ -97,7 +97,6 @@ def import_lattice_assets( ) local_uri = os.path.join(storage_path, object_key) - asset_kwargs = { "storage_type": object_store.scheme, "storage_path": storage_path, @@ -111,30 +110,12 @@ def import_lattice_assets( # Send this back to the client asset.digest = None - asset.remote_uri = f"file://{local_uri}" - - # Register custom assets - if lat.assets._custom: - for asset_key, asset in lat.assets._custom.items(): - object_key = f"{asset_key}.data" - local_uri = os.path.join(storage_path, 
object_key) - - asset_kwargs = { - "storage_type": object_store.scheme, - "storage_path": storage_path, - "object_key": object_key, - "digest_alg": asset.digest_alg, - "digest": asset.digest, - "remote_uri": asset.uri, - "size": asset.size, - } - asset_ids[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False) - - # Send this back to the client - asset.remote_uri = f"file://{local_uri}" if asset.digest else "" - asset.digest = None - - session.flush() + remote_uri = object_store.get_public_uri( + storage_path, + object_key, + transfer_direction=TransferDirection.upload, + ) + asset.remote_uri = remote_uri # Write asset records to DB session.flush() diff --git a/covalent_dispatcher/_dal/importers/result.py b/covalent_dispatcher/_dal/importers/result.py index 7e4bd36f9..dfcb13e35 100644 --- a/covalent_dispatcher/_dal/importers/result.py +++ b/covalent_dispatcher/_dal/importers/result.py @@ -24,27 +24,22 @@ from sqlalchemy.orm import Session from covalent._shared_files import logger -from covalent._shared_files.config import get_config from covalent._shared_files.schemas.lattice import LatticeSchema from covalent._shared_files.schemas.result import ResultAssets, ResultSchema -from covalent._shared_files.utils import format_server_url from covalent_dispatcher._dal.asset import Asset from covalent_dispatcher._dal.electron import ElectronMeta from covalent_dispatcher._dal.job import Job from covalent_dispatcher._dal.result import Result, ResultMeta -from covalent_dispatcher._object_store.local import BaseProvider, local_store +from covalent_dispatcher._object_store.base import BaseProvider, TransferDirection +from covalent_dispatcher._object_store.local import local_store from ..asset import copy_asset_meta from ..tg_ops import TransportGraphOps -from ..utils.uri_filters import AssetScope, URIFilterPolicy, filter_asset_uri from .lattice import _get_lattice_meta, import_lattice_assets from .tg import import_transport_graph -SERVER_URL = format_server_url(get_config("dispatcher.address"), get_config("dispatcher.port")) - -URI_FILTER_POLICY = URIFilterPolicy[get_config("dispatcher.data_uri_filter_policy")] - app_log = logger.app_log +DEFAULT_OBJECT_STORE = local_store def import_result( @@ -80,14 +75,14 @@ def import_result( st = datetime.now() lattice_row = ResultMeta.create(session, insert_kwargs=lattice_record_kwargs, flush=True) res_record = Result(session, lattice_row, True) - res_assets = import_result_assets(session, res, res_record, local_store) + res_assets = import_result_assets(session, res, res_record, DEFAULT_OBJECT_STORE) lat_assets = import_lattice_assets( session, dispatch_id, res.lattice, res_record.lattice, - local_store, + DEFAULT_OBJECT_STORE, ) et = datetime.now() delta = (et - st).total_seconds() @@ -99,7 +94,7 @@ def import_result( dispatch_id, res.lattice.transport_graph, res_record.lattice, - local_store, + DEFAULT_OBJECT_STORE, electron_id, ) et = datetime.now() @@ -108,13 +103,7 @@ def import_result( lat = LatticeSchema(metadata=res.lattice.metadata, assets=lat_assets, transport_graph=tg) - output = ResultSchema(metadata=res.metadata, assets=res_assets, lattice=lat) - st = datetime.now() - filtered_uris = _filter_remote_uris(output) - et = datetime.now() - delta = (et - st).total_seconds() - app_log.debug(f"{dispatch_id}: Filtering URIs took {delta} seconds") - return filtered_uris + return ResultSchema(metadata=res.metadata, assets=res_assets, lattice=lat) def _connect_result_to_electron( @@ -165,49 +154,6 @@ def _connect_result_to_electron( return 
res -def _filter_remote_uris(manifest: ResultSchema) -> ResultSchema: - dispatch_id = manifest.metadata.dispatch_id - - # Workflow-level - for key, asset in manifest.assets: - if asset.remote_uri: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, - asset.remote_uri, - {}, - AssetScope.DISPATCH, - dispatch_id, - None, - key, - ) - asset.remote_uri = filtered_uri - - for key, asset in manifest.lattice.assets: - if asset.remote_uri: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, asset.remote_uri, {}, AssetScope.LATTICE, dispatch_id, None, key - ) - asset.remote_uri = filtered_uri - - # Now filter each node - tg = manifest.lattice.transport_graph - for node in tg.nodes: - for key, asset in node.assets: - if asset.remote_uri: - filtered_uri = filter_asset_uri( - URI_FILTER_POLICY, - asset.remote_uri, - {}, - AssetScope.NODE, - dispatch_id, - node.id, - key, - ) - asset.remote_uri = filtered_uri - - return manifest - - def _get_result_meta(res: ResultSchema, storage_path: str, electron_id: Optional[int]) -> dict: kwargs = { "dispatch_id": res.metadata.dispatch_id, @@ -239,7 +185,6 @@ def import_result_assets( node_id=None, asset_key=asset_key, ) - local_uri = os.path.join(storage_path, object_key) asset_kwargs = { "storage_type": object_store.scheme, @@ -254,7 +199,12 @@ def import_result_assets( # Send this back to the client asset.digest = None - asset.remote_uri = f"file://{local_uri}" + remote_uri = object_store.get_public_uri( + storage_path, + object_key, + transfer_direction=TransferDirection.upload, + ) + asset.remote_uri = remote_uri # Write asset records to DB n_records = len(asset_ids) diff --git a/covalent_dispatcher/_dal/result.py b/covalent_dispatcher/_dal/result.py index 9780ab294..fc7528968 100644 --- a/covalent_dispatcher/_dal/result.py +++ b/covalent_dispatcher/_dal/result.py @@ -28,7 +28,7 @@ from covalent._shared_files.util_classes import RESULT_STATUS, Status from .._db import models -from .asset import Asset, copy_asset, copy_asset_meta +from .asset import Asset, copy_asset_meta from .base import DispatchedObject from .controller import Record from .db_interfaces.result_utils import ASSET_KEYS # nopycln: import @@ -186,14 +186,16 @@ def _update_dispatch( electron_output = parent_electron.get_asset("output", session) electron_err = parent_electron.get_asset("error", session) + # Archive result of _build_sublattice_graph since electron output will + # be overwritten by the sublattice execution result + subl_manifest = parent_electron.get_asset("sublattice_manifest", session) + app_log.debug("Copying sublattice output to parent electron") with self.session() as session: + copy_asset_meta(session, electron_output, subl_manifest) copy_asset_meta(session, subl_output, electron_output) copy_asset_meta(session, subl_err, electron_err) - copy_asset(subl_output, electron_output) - copy_asset(subl_err, electron_err) - def _update_node( self, node_id: int, @@ -275,7 +277,6 @@ def _update_node( workflow_result = self.get_asset("result", session) node_output = tg.get_node(node_id).get_asset("output", session) copy_asset_meta(session, node_output, workflow_result) - copy_asset(node_output, workflow_result) self._update_dispatch(status=status, end_time=end_time) diff --git a/covalent_dispatcher/_dal/tg_ops.py b/covalent_dispatcher/_dal/tg_ops.py index 79f87fa4c..1b1be3e83 100644 --- a/covalent_dispatcher/_dal/tg_ops.py +++ b/covalent_dispatcher/_dal/tg_ops.py @@ -24,7 +24,7 @@ from covalent._shared_files import logger from covalent._shared_files.util_classes import 
RESULT_STATUS -from .asset import copy_asset, copy_asset_meta +from .asset import copy_asset_meta from .electron import ASSET_KEYS, METADATA_KEYS from .tg import _TransportGraph @@ -108,12 +108,6 @@ def copy_nodes_from( copy_asset_meta(session, old, new) assets_to_copy.append((old, new)) - # Now perform all data copy operations (this could be slow) - if not defer_copy_objects: - for item in assets_to_copy: - src, dest = item - copy_asset(src, dest) - # Return the assets to copy at a later time return assets_to_copy diff --git a/covalent_dispatcher/_dal/utils/uri_filters.py b/covalent_dispatcher/_dal/utils/uri_filters.py deleted file mode 100644 index 4d9afdbb8..000000000 --- a/covalent_dispatcher/_dal/utils/uri_filters.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions to transform URIs""" - -import enum -from typing import Optional - -from covalent._shared_files import logger -from covalent._shared_files.config import get_config -from covalent._shared_files.utils import format_server_url - -SERVER_URL = format_server_url(get_config("dispatcher.address"), get_config("dispatcher.port")) - -app_log = logger.app_log - - -class AssetScope(enum.Enum): - DISPATCH = "dispatch" - LATTICE = "lattice" - NODE = "node" - - -class URIFilterPolicy(enum.Enum): - raw = "raw" # expose raw URIs - http = "http" # return data endpoints - - -def _srv_asset_uri( - uri: str, attrs: dict, scope: AssetScope, dispatch_id: str, node_id: Optional[int], key: str -) -> str: - base_uri = f"{SERVER_URL}/api/v2/dispatches/{dispatch_id}" - - if scope == AssetScope.DISPATCH: - return f"{base_uri}/assets/{key}" - elif scope == AssetScope.LATTICE: - return f"{base_uri}/lattice/assets/{key}" - else: - return f"{base_uri}/electrons/{node_id}/assets/{key}" - - -def _raw( - uri: str, attrs: dict, scope: AssetScope, dispatch_id: str, node_id: Optional[int], key: str -): - return uri - - -_filter_map = { - URIFilterPolicy.raw: _raw, - URIFilterPolicy.http: _srv_asset_uri, -} - - -def filter_asset_uri( - filter_policy: URIFilterPolicy, - uri: str, - attrs: dict, - scope: AssetScope, - dispatch_id: str, - node_id: Optional[int], - key: str, -) -> str: - """Transform an internal URI for an asset to an external URI. 
- - Parameters: - uri: internal URI - attrs: attributes for the external URI - scope: asset scope ("dispatch", "lattice", "node") - key: asset key - - Returns: - The external URI for the asset - - """ - - selected_filter = _filter_map[filter_policy] - return selected_filter( - uri=uri, - attrs=attrs, - scope=scope, - dispatch_id=dispatch_id, - node_id=node_id, - key=key, - ) diff --git a/covalent_dispatcher/_db/upsert.py b/covalent_dispatcher/_db/upsert.py index 2cd7fd3e5..b58b681a9 100644 --- a/covalent_dispatcher/_db/upsert.py +++ b/covalent_dispatcher/_db/upsert.py @@ -285,6 +285,21 @@ def _electron_data( assets[key] = Asset.create(session, insert_kwargs=asset_record_kwargs, flush=True) + # Register sublattice manifest for sublattice electrons + electron_type = get_electron_type(tg.get_node_value(node_id, "name")) + if electron_type == "sublattice": + asset_record_kwargs = { + "storage_type": ELECTRON_STORAGE_TYPE, + "storage_path": str(node_path), + "object_key": "result.tobj", + "digest_alg": None, + "digest": None, + "size": 0, + } + assets["sublattice_manifest"] = Asset.create( + session, insert_kwargs=asset_record_kwargs, flush=True + ) + # Register custom assets node_metadata = tg.get_node_value(node_id, "metadata") if CUSTOM_ASSETS_FIELD in node_metadata: diff --git a/covalent_dispatcher/_object_store/base.py b/covalent_dispatcher/_object_store/base.py index 7329c6671..e8fea2595 100644 --- a/covalent_dispatcher/_object_store/base.py +++ b/covalent_dispatcher/_object_store/base.py @@ -17,9 +17,15 @@ """Base storage backend provider""" from dataclasses import dataclass +from enum import Enum from typing import Optional, Tuple +class TransferDirection(str, Enum): + upload = "put" + download = "get" + + @dataclass class Digest: algorithm: str @@ -59,3 +65,6 @@ def get_uri_components( """ raise NotImplementedError + + def get_public_uri(self, storage_path: str, object_key: str, **options) -> str: + raise NotImplementedError diff --git a/covalent_dispatcher/_object_store/local.py b/covalent_dispatcher/_object_store/local.py index 6bd4e7ec9..9abda2297 100644 --- a/covalent_dispatcher/_object_store/local.py +++ b/covalent_dispatcher/_object_store/local.py @@ -25,6 +25,7 @@ from covalent._shared_files.config import get_config from covalent._shared_files.schemas import electron, lattice, result +from covalent._shared_files.utils import format_server_url from covalent._workflow.transport import TransportableObject from .base import BaseProvider, Digest @@ -36,6 +37,7 @@ WORKFLOW_ASSET_FILENAME_MAP.update(lattice.ASSET_FILENAME_MAP) ELECTRON_ASSET_FILENAME_MAP = electron.ASSET_FILENAME_MAP.copy() +SERVER_URL = format_server_url(get_config("dispatcher.address"), get_config("dispatcher.port")) # Moved from write_result_to_db.py @@ -92,17 +94,17 @@ def get_uri_components( the asset. 
""" - storage_path = os.path.join(self.base_path, dispatch_id) + rel_dir = dispatch_id if node_id is not None: - storage_path = os.path.join(storage_path, f"node_{node_id}") - object_key = ELECTRON_ASSET_FILENAME_MAP[asset_key] + rel_dir = f"{dispatch_id}/node_{node_id}" + basename = ELECTRON_ASSET_FILENAME_MAP[asset_key] else: - object_key = WORKFLOW_ASSET_FILENAME_MAP[asset_key] + basename = WORKFLOW_ASSET_FILENAME_MAP[asset_key] - os.makedirs(storage_path, exist_ok=True) + object_key = os.path.join(rel_dir, basename) - return storage_path, object_key + return self.base_path, object_key def store_file(self, storage_path: str, filename: str, data: Any = None) -> Tuple[Digest, int]: """This function writes data corresponding to the filepaths in the DB.""" @@ -167,5 +169,11 @@ def load_file(self, storage_path: str, filename: str) -> Any: return data + def get_public_uri(self, storage_path: str, object_key: str, **options) -> str: + if storage_path.startswith(self.base_path): + return f"{SERVER_URL}/api/v0/files/{object_key}" + else: + return "" + local_store = LocalProvider() diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index bf02b66ac..f815486b5 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -42,6 +42,7 @@ BulkGetMetadata, DispatchStatusSetSchema, DispatchSummary, + ElectronUpdateSchema, TargetDispatchStatus, ) @@ -171,28 +172,26 @@ async def register(manifest: ResultSchema) -> ResultSchema: ) from e -@router.post("/dispatches/{dispatch_id}/sublattices", status_code=201) -async def register_subdispatch( - manifest: ResultSchema, - dispatch_id: str, -) -> ResultSchema: - """Register a subdispatch in the database. +@router.patch("/dispatches/{dispatch_id}/electrons/{node_id}") +def associate_sublattice_with_electron(dispatch_id: str, node_id: str, body: ElectronUpdateSchema): + dispatch_controller = Result.meta_type - Args: - manifest: Declares all metadata and assets in the workflow - dispatch_id: The parent dispatch id + with workflow_db.session() as session: + res = get_result_object(dispatch_id, bare=True, session=session) + node = res.lattice.transport_graph.get_node(node_id, session=session) + electron_id = node._electron_id + + dispatch_controller.update_bulk( + session=session, + values={"electron_id": electron_id}, + equality_filters={"dispatch_id": body.sub_dispatch_id}, + membership_filters={}, + ) + session.commit() - Returns: - The manifest with `dispatch_id` and remote URIs for each asset populated. 
- """ - try: - return await dispatcher.register_dispatch(manifest, dispatch_id) - except Exception as e: - app_log.debug(f"Exception in register: {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to submit workflow: {e}", - ) from e + app_log.debug( + f"Associated sublattice {body.sub_dispatch_id} with parent electron {dispatch_id}:{node_id}" + ) @router.post("/dispatches/{dispatch_id}/redispatches", status_code=201) diff --git a/covalent_dispatcher/_service/assets.py b/covalent_dispatcher/_service/assets.py index 0664e5058..c2b10b413 100644 --- a/covalent_dispatcher/_service/assets.py +++ b/covalent_dispatcher/_service/assets.py @@ -16,36 +16,19 @@ """Endpoints for uploading and downloading workflow assets""" -import asyncio -import mmap -import os from functools import lru_cache -from typing import Tuple, Union - -import aiofiles -import aiofiles.os -from fastapi import APIRouter, Header, HTTPException, Request -from fastapi.responses import StreamingResponse -from furl import furl - -from covalent._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES -from covalent._serialize.lattice import ASSET_TYPES as LATTICE_ASSET_TYPES -from covalent._serialize.result import ASSET_TYPES as RESULT_ASSET_TYPES -from covalent._serialize.result import AssetType +from typing import Union + +from fastapi import APIRouter, Header, HTTPException + from covalent._shared_files import logger from covalent._shared_files.config import get_config -from covalent._workflow.transportable_object import TOArchiveUtils +from covalent._shared_files.schemas.asset import AssetSchema +from covalent_dispatcher._object_store.base import TransferDirection from .._dal.result import get_result_object from .._db.datastore import workflow_db -from .models import ( - AssetRepresentation, - DispatchAssetKey, - ElectronAssetKey, - LatticeAssetKey, - range_pattern, - range_regex, -) +from .models import ElectronAssetKey app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -67,35 +50,16 @@ def get_node_asset( dispatch_id: str, node_id: int, key: ElectronAssetKey, - representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), -): +) -> AssetSchema: """Returns an asset for an electron. Args: dispatch_id: The dispatch's unique id. node_id: The id of the electron. key: The name of the asset - representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` - range: (optional) range request header - - If `representation` is specified, it will override the range request. 
""" - start_byte = 0 - end_byte = -1 try: - if Range: - start_byte, end_byte = _extract_byte_range(Range) - - if end_byte >= 0 and end_byte < start_byte: - raise HTTPException( - status_code=400, - detail="Invalid byte range", - ) - app_log.debug( - f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for node {dispatch_id}:{node_id}" - ) result_object = get_cached_result_object(dispatch_id) @@ -104,148 +68,26 @@ def get_node_asset( node = result_object.lattice.transport_graph.get_node(node_id) with workflow_db.session() as session: asset = node.get_asset(key=key.value, session=session) - - # Explicit representation overrides the byte range - if representation is None or ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) - - app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") - generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) - return StreamingResponse(generator) - - except Exception as e: - app_log.debug(e) - raise - - -@router.get("/dispatches/{dispatch_id}/assets/{key}") -def get_dispatch_asset( - dispatch_id: str, - key: DispatchAssetKey, - representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), -): - """Returns a dynamic asset for a workflow - - Args: - dispatch_id: The dispatch's unique id. - key: The name of the asset - representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` - range: (optional) range request header - - If `representation` is specified, it will override the range request. 
- """ - start_byte = 0 - end_byte = -1 - - try: - if Range: - start_byte, end_byte = _extract_byte_range(Range) - - if end_byte >= 0 and end_byte < start_byte: - raise HTTPException( - status_code=400, - detail="Invalid byte range", + remote_uri = asset.object_store.get_public_uri( + asset.storage_path, asset.object_key, direction=TransferDirection.download ) - app_log.debug( - f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for dispatch {dispatch_id}" - ) - result_object = get_cached_result_object(dispatch_id) + return AssetSchema(size=asset.size, remote_uri=remote_uri) - app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") - with workflow_db.session() as session: - asset = result_object.get_asset(key=key.value, session=session) - - # Explicit representation overrides the byte range - if representation is None or RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) - - app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") - generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) - return StreamingResponse(generator) except Exception as e: app_log.debug(e) raise -@router.get("/dispatches/{dispatch_id}/lattice/assets/{key}") -def get_lattice_asset( - dispatch_id: str, - key: LatticeAssetKey, - representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), -): - """Returns a static asset for a workflow - - Args: - dispatch_id: The dispatch's unique id. - key: The name of the asset - representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` - range: (optional) range request header - - If `representation` is specified, it will override the range request. 
- """ - start_byte = 0 - end_byte = -1 - - try: - if Range: - start_byte, end_byte = _extract_byte_range(Range) - - if end_byte >= 0 and end_byte < start_byte: - raise HTTPException( - status_code=400, - detail="Invalid byte range", - ) - app_log.debug( - f"Requested lattice asset {key.value} ([{start_byte}:{end_byte}])for dispatch {dispatch_id}" - ) - - result_object = get_cached_result_object(dispatch_id) - app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") - - with workflow_db.session() as session: - asset = result_object.lattice.get_asset(key=key.value, session=session) - - # Explicit representation overrides the byte range - if representation is None or LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) - - app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") - generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) - return StreamingResponse(generator) - - except Exception as e: - app_log.debug(e) - raise e - - -@router.put("/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") -async def upload_node_asset( - req: Request, +@router.post("/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") +def upload_node_asset( dispatch_id: str, node_id: int, key: ElectronAssetKey, content_length: int = Header(default=0), digest_alg: Union[str, None] = Header(default=None), digest: Union[str, None] = Header(default=None), -): +) -> AssetSchema: """Upload an electron asset. Args: @@ -256,179 +98,23 @@ async def upload_node_asset( content_length: (header) digest: (header) """ - app_log.debug(f"Uploading node asset {dispatch_id}:{node_id}:{key} ({content_length} bytes) ") - + app_log.debug( + f"Initiating upload for {dispatch_id}:{node_id}:{key.value} ({content_length} bytes) " + ) try: metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} - internal_uri = await _run_in_executor( - _update_node_asset_metadata, + remote_uri = _update_node_asset_metadata( dispatch_id, node_id, key, metadata, ) - # Stream the request body to object store - await _transfer_data(req, internal_uri) - - return f"Uploaded file to {internal_uri}" - except Exception as e: - app_log.debug(e) - raise - - -@router.put("/dispatches/{dispatch_id}/assets/{key}") -async def upload_dispatch_asset( - req: Request, - dispatch_id: str, - key: DispatchAssetKey, - content_length: int = Header(default=0), - digest_alg: Union[str, None] = Header(default=None), - digest: Union[str, None] = Header(default=None), -): - """Upload a dispatch asset. - - Args: - dispatch_id: The dispatch's unique id. 
- key: The name of the asset - asset_file: (body) The file to be uploaded - content_length: (header) - digest: (header) - """ - app_log.debug(f"Uploading dispatch asset {dispatch_id}:{key} ({content_length} bytes) ") - try: - metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} - internal_uri = await _run_in_executor( - _update_dispatch_asset_metadata, - dispatch_id, - key, - metadata, - ) - # Stream the request body to object store - await _transfer_data(req, internal_uri) - return f"Uploaded file to {internal_uri}" + return AssetSchema(size=content_length, remote_uri=remote_uri) except Exception as e: app_log.debug(e) raise -@router.put("/dispatches/{dispatch_id}/lattice/assets/{key}") -async def upload_lattice_asset( - req: Request, - dispatch_id: str, - key: LatticeAssetKey, - content_length: int = Header(default=0), - digest_alg: Union[str, None] = Header(default=None), - digest: Union[str, None] = Header(default=None), -): - """Upload a lattice asset. - - Args: - dispatch_id: The dispatch's unique id. - key: The name of the asset - asset_file: (body) The file to be uploaded - content_length: (header) - digest: (header) - """ - try: - app_log.debug(f"Uploading lattice asset {dispatch_id}:{key} ({content_length} bytes) ") - metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} - internal_uri = await _run_in_executor( - _update_lattice_asset_metadata, - dispatch_id, - key, - metadata, - ) - # Stream the request body to object store - await _transfer_data(req, internal_uri) - return f"Uploaded file to {internal_uri}" - except Exception as e: - app_log.debug(e) - raise - - -def _generate_file_slice(file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - """Generator of a byte slice from a file. - - Args: - file_url: A file:/// type URL pointing to the file - start_byte: The beginning of the byte range - end_byte: The end of the byte range, or -1 to select [start_byte:] - chunk_size: The size of each chunk - - Returns: - Yields chunks of size <= chunk_size - """ - byte_pos = start_byte - file_path = str(furl(file_url).path) - with open(file_path, "rb") as f: - f.seek(start_byte) - if end_byte < 0: - for chunk in f: - yield chunk - else: - while byte_pos + chunk_size < end_byte: - byte_pos += chunk_size - yield f.read(chunk_size) - yield f.read(end_byte - byte_pos) - - -def _extract_byte_range(byte_range_header: str) -> Tuple[int, int]: - """Extract the byte range from a range request header.""" - start_byte = 0 - end_byte = -1 - match = range_pattern.match(byte_range_header) - start = match.group(1) - end = match.group(2) - start_byte = int(start) - if end: - end_byte = int(end) - - return start_byte, end_byte - - -# Helpers for TransportableObject - - -def _get_tobj_string_offsets(file_url: str) -> Tuple[int, int]: - """Get the byte range for the str rep of a stored TObj. - - For a first implementation we just query the filesystem directly. - - Args: - file_url: A file:/// URL pointing to the TransportableObject - - Returns: - (start_byte, end_byte) - """ - - file_path = str(furl(file_url).path) - filelen = os.path.getsize(file_path) - with open(file_path, "rb+") as f: - with mmap.mmap(f.fileno(), filelen) as mm: - # TOArchiveUtils operates on byte arrays - return TOArchiveUtils.string_byte_range(mm) - - -def _get_tobj_pickle_offsets(file_url: str) -> Tuple[int, int]: - """Get the byte range for the picklebytes of a stored TObj. - - For a first implementation we just query the filesystem directly. 
- - Args: - file_url: A file:/// URL pointing to the TransportableObject - - Returns: - (start_byte, -1) - """ - - file_path = str(furl(file_url).path) - filelen = os.path.getsize(file_path) - with open(file_path, "rb+") as f: - with mmap.mmap(f.fileno(), filelen) as mm: - # TOArchiveUtils operates on byte arrays - return TOArchiveUtils.data_byte_range(mm) - - # This must only be used for static data as we don't have yet any # intelligent invalidation logic. @lru_cache(maxsize=LRU_CACHE_SIZE) @@ -474,56 +160,10 @@ def _update_node_asset_metadata(dispatch_id, node_id, key, metadata) -> str: # Update asset metadata update = _filter_null_metadata(metadata) node.update_assets(updates={key: update}, session=session) - app_log.debug(f"Updated node asset {dispatch_id}:{node_id}:{key}") - - return asset.internal_uri - - -def _update_lattice_asset_metadata(dispatch_id, key, metadata) -> str: - result_object = get_cached_result_object(dispatch_id) - - app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") - with workflow_db.session() as session: - asset = result_object.lattice.get_asset(key=key.value, session=session) - - # Update asset metadata - update = _filter_null_metadata(metadata) - result_object.lattice.update_assets(updates={key: update}, session=session) - app_log.debug(f"Updated size for lattice asset {dispatch_id}:{key}") - - return asset.internal_uri - - -def _update_dispatch_asset_metadata(dispatch_id, key, metadata) -> str: - result_object = get_cached_result_object(dispatch_id) + app_log.debug(f"Updated node asset {dispatch_id}:{node_id}:{key.value}") - app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") - with workflow_db.session() as session: - asset = result_object.get_asset(key=key.value, session=session) - - # Update asset metadata - update = _filter_null_metadata(metadata) - result_object.update_assets(updates={key: update}, session=session) - app_log.debug(f"Updated size for dispatch asset {dispatch_id}:{key}") - return asset.internal_uri - - -async def _transfer_data(req: Request, destination_url: str): - dest_url = furl(destination_url) - dest_path = str(dest_url.path) - - # Stream data to a temporary file, then replace the destination - # file atomically - tmp_path = f"{dest_path}.tmp" - app_log.debug(f"Streaming file upload to {tmp_path}") - - async with aiofiles.open(tmp_path, "wb") as f: - async for chunk in req.stream(): - await f.write(chunk) - - await aiofiles.os.replace(tmp_path, dest_path) - - -def _run_in_executor(function, *args) -> asyncio.Future: - loop = asyncio.get_running_loop() - return loop.run_in_executor(None, function, *args) + object_store = asset.object_store + remote_uri = object_store.get_public_uri( + asset.storage_path, asset.object_key, direction=TransferDirection.upload + ) + return remote_uri diff --git a/covalent_dispatcher/_service/files.py b/covalent_dispatcher/_service/files.py new file mode 100644 index 000000000..32a37c20e --- /dev/null +++ b/covalent_dispatcher/_service/files.py @@ -0,0 +1,69 @@ +# Copyright 2024 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. 
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedded file server API""" + +import os + +import aiofiles +import aiofiles.os +from fastapi import APIRouter, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import FileResponse + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config + +router = APIRouter() + +app_log = logger.app_log +BASE_PATH = get_config("dispatcher.results_dir") + + +async def _transfer_data(req: Request, dest_path: str): + + # Stream data to a temporary file, then replace the destination + # file atomically + tmp_path = f"{dest_path}.tmp" + app_log.debug(f"Streaming file upload to {tmp_path}") + + async with aiofiles.open(tmp_path, "wb") as f: + async for chunk in req.stream(): + await f.write(chunk) + + await aiofiles.os.replace(tmp_path, dest_path) + + +# Resolve path to an absolute path and check that it +# doesn't escape the data directory root +def _sanitize_path(path: str) -> str: + abs_path = os.path.realpath(path) + if not abs_path.startswith(BASE_PATH) or len(abs_path) <= len(BASE_PATH): + raise RequestValidationError(f"Invalid object key {path}") + return abs_path + + +@router.get("/files/{object_key:path}") +async def download_file(object_key: str): + path = _sanitize_path(os.path.join(BASE_PATH, object_key)) + return FileResponse(path) + + +@router.put("/files/{object_key:path}") +async def upload_file(req: Request, object_key: str): + path = _sanitize_path(os.path.join(BASE_PATH, object_key)) + os.makedirs(os.path.dirname(path), exist_ok=True) + await _transfer_data(req, path) diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py index af81a30bc..1b47df765 100644 --- a/covalent_dispatcher/_service/models.py +++ b/covalent_dispatcher/_service/models.py @@ -95,3 +95,7 @@ class DispatchSummary(BaseModel): class BulkDispatchGetSchema(BaseModel): dispatches: List[DispatchSummary] metadata: BulkGetMetadata + + +class ElectronUpdateSchema(BaseModel): + sub_dispatch_id: str diff --git a/covalent_dispatcher/entry_point.py b/covalent_dispatcher/entry_point.py index a78242d06..abdffabd3 100644 --- a/covalent_dispatcher/entry_point.py +++ b/covalent_dispatcher/entry_point.py @@ -30,28 +30,6 @@ log_stack_info = logger.log_stack_info -async def make_dispatch(json_lattice: str): - """ - Run the dispatcher from the lattice asynchronously using Dask. - Assign a new dispatch id to the result object and return it. - Also save the result in this initial stage to the file mentioned in the result object. - - Args: - json_lattice: A JSON-serialized lattice - - Returns: - dispatch_id: A string containing the dispatch id of current dispatch. - """ - - from ._core import make_dispatch - - dispatch_id = await make_dispatch(json_lattice) - - app_log.debug(f"Created new dispatch {dispatch_id}") - - return dispatch_id - - async def start_dispatch(dispatch_id: str): """ Run the dispatcher from the lattice asynchronously using Dask. @@ -80,27 +58,6 @@ async def start_dispatch(dispatch_id: str): app_log.debug(f"Running dispatch {dispatch_id}") -async def run_dispatcher(json_lattice: str): - """ - Run the dispatcher from the lattice asynchronously using Dask. 
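The _sanitize_path helper in the new files.py above is what keeps object keys from escaping the results directory. A small standalone illustration of the same realpath-and-prefix check (BASE_PATH here is a made-up example, and a ValueError stands in for FastAPI's RequestValidationError):

import os

BASE_PATH = os.path.realpath("/tmp/covalent/results")

def sanitize(path: str) -> str:
    # realpath collapses "..", symlinks and duplicate separators before the
    # prefix check, so keys like "../../etc/passwd" are rejected.
    abs_path = os.path.realpath(path)
    if not abs_path.startswith(BASE_PATH) or len(abs_path) <= len(BASE_PATH):
        raise ValueError(f"Invalid object key {path}")
    return abs_path

print(sanitize(os.path.join(BASE_PATH, "dispatch_1/node_0/function.tobj")))
# sanitize(os.path.join(BASE_PATH, "../outside.txt")) raises ValueError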
- Assign a new dispatch id to the result object and return it. - Also save the result in this initial stage to the file mentioned in the result object. - - Args: - json_lattice: A JSON-serialized lattice - - Returns: - dispatch_id: A string containing the dispatch id of current dispatch. - """ - - dispatch_id = await make_dispatch(json_lattice) - await start_dispatch(dispatch_id) - - app_log.debug("Submitted result object to run_workflow.") - - return dispatch_id - - async def cancel_running_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: """ Cancels a running dispatch job. diff --git a/covalent_ui/api/v1/routes/routes.py b/covalent_ui/api/v1/routes/routes.py index 9b6c50b45..58c5d501a 100644 --- a/covalent_ui/api/v1/routes/routes.py +++ b/covalent_ui/api/v1/routes/routes.py @@ -18,7 +18,7 @@ from fastapi import APIRouter -from covalent_dispatcher._service import app, assets, runnersvc +from covalent_dispatcher._service import app, assets, files, runnersvc from covalent_dispatcher._triggers_app.app import router as tr_router from covalent_ui.api.v1.routes.end_points import ( electron_routes, @@ -43,3 +43,4 @@ routes.include_router(app.router, prefix="/api/v2", tags=["Dispatcher"]) routes.include_router(assets.router, prefix="/api/v2", tags=["Assets"]) routes.include_router(runnersvc.router, prefix="/api/v2", tags=["Runner"]) +routes.include_router(files.router, prefix="/api/v0", tags=["Files"]) diff --git a/tests/covalent_dispatcher_tests/_core/data_manager_test.py b/tests/covalent_dispatcher_tests/_core/data_manager_test.py index 847c0b152..66b947234 100644 --- a/tests/covalent_dispatcher_tests/_core/data_manager_test.py +++ b/tests/covalent_dispatcher_tests/_core/data_manager_test.py @@ -19,24 +19,23 @@ """ +import base64 +import tempfile from unittest.mock import MagicMock import pytest import covalent as ct +from covalent._dispatcher_plugins.local import LocalDispatcher, pack_staging_dir from covalent._results_manager import Result from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow.lattice import Lattice from covalent_dispatcher._core.data_manager import ( - ResultSchema, - _legacy_sublattice_dispatch_helper, _make_sublattice_dispatch, - _redirect_lattice, _update_parent_electron, ensure_dispatch, finalize_dispatch, get_result_object, - make_dispatch, persist_result, update_node_result, ) @@ -86,6 +85,7 @@ def pipeline(x): (Result.FAILED, "function", Result.FAILED, ""), (Result.CANCELLED, "function", Result.CANCELLED, ""), (Result.COMPLETED, "sublattice", RESULT_STATUS.DISPATCHING, ""), + (Result.COMPLETED, "sublattice", RESULT_STATUS.DISPATCHING, "asdf"), (Result.COMPLETED, "sublattice", RESULT_STATUS.COMPLETED, "asdf"), (Result.FAILED, "sublattice", Result.FAILED, ""), (Result.CANCELLED, "sublattice", Result.CANCELLED, ""), @@ -99,11 +99,16 @@ async def test_update_node_result(mocker, node_status, node_type, output_status, result_object.dispatch_id = "test_update_node_result" node_result = {"node_id": 0, "status": node_status} - mock_update_node = mocker.patch( - "covalent_dispatcher._dal.result.Result._update_node", return_value=True - ) node_info = {"type": node_type, "sub_dispatch_id": sub_id, "status": Result.NEW_OBJ} + subdispatch_info = ( + {"status": Result.NEW_OBJ} + if output_status == RESULT_STATUS.DISPATCHING + else {"status": Result.RUNNING} + ) mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", 
return_value=subdispatch_info + ) mock_notify = mocker.patch( "covalent_dispatcher._core.dispatcher.notify_node_status", @@ -247,17 +252,6 @@ async def test_update_node_result_handles_db_exceptions(mocker): mock_notify.assert_awaited_with(result_object.dispatch_id, 0, Result.FAILED, {}) -@pytest.mark.asyncio -async def test_make_dispatch(mocker): - res = MagicMock() - dispatch_id = "test_make_dispatch" - mock_resubmit_lattice = mocker.patch( - "covalent_dispatcher._core.data_manager._redirect_lattice", return_value=dispatch_id - ) - json_lattice = '{"workflow_function": "asdf"}' - assert dispatch_id == await make_dispatch(json_lattice) - - def test_get_result_object(mocker): result_object = MagicMock() result_object.dispatch_id = "dispatch_1" @@ -347,14 +341,32 @@ async def test_update_parent_electron(mocker, sub_status, mapped_status): @pytest.mark.asyncio async def test_make_sublattice_dispatch(mocker): + + @ct.lattice + @ct.electron + def sublattice(x): + return x**2 + + sublattice.build_graph(3) + + with tempfile.TemporaryDirectory(prefix="covalent-") as staging_path: + manifest = LocalDispatcher.prepare_manifest(sublattice, staging_path) + + # This tarball will be unpacked and resubmitted server-side + tar_file = pack_staging_dir(staging_path, manifest) + with open(tar_file, "rb") as tar: + tar_b64 = base64.b64encode(tar.read()).decode("utf-8") + + # This base64-encoded tarball will be read server-side + # as `TransportableObject.object_string` + node_result = {"node_id": 0, "status": Result.COMPLETED} - output_json = "lattice_json" mock_node = MagicMock() mock_node._electron_id = 5 mock_bg_output = MagicMock() - mock_bg_output.object_string = output_json + mock_bg_output.object_string = tar_b64 mock_node.get_value = MagicMock(return_value=mock_bg_output) @@ -368,100 +380,18 @@ async def test_make_sublattice_dispatch(mocker): "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object, ) - mocker.patch("covalent._shared_files.schemas.result.ResultSchema.parse_raw") mocker.patch( - "covalent_dispatcher._core.data_manager.manifest_importer.import_manifest", + "covalent_dispatcher._core.data_modules.importer._import_manifest", return_value=mock_manifest, ) - mock_make_dispatch = mocker.patch("covalent_dispatcher._core.data_manager.make_dispatch") + mock_pull = mocker.patch("covalent_dispatcher._core.data_modules.importer._pull_assets") sub_dispatch_id = await _make_sublattice_dispatch(result_object.dispatch_id, node_result) + mock_pull.assert_called() assert sub_dispatch_id == mock_manifest.metadata.dispatch_id -@pytest.mark.asyncio -async def test_make_monolithic_sublattice_dispatch(mocker): - """Check that JSON sublattices are handled correctly""" - - dispatch_id = "test_make_monolithic_sublattice_dispatch" - - def _mock_helper(dispatch_id, node_result): - return ResultSchema.parse_raw("invalid_input") - - mocker.patch( - "covalent_dispatcher._core.data_manager._make_sublattice_dispatch_helper", _mock_helper - ) - - json_lattice = "json_lattice" - parent_electron_id = 5 - mock_legacy_subl_helper = mocker.patch( - "covalent_dispatcher._core.data_manager._legacy_sublattice_dispatch_helper", - return_value=(json_lattice, parent_electron_id), - ) - sub_dispatch_id = "sub_dispatch" - mock_make_dispatch = mocker.patch( - "covalent_dispatcher._core.data_manager.make_dispatch", return_value=sub_dispatch_id - ) - - assert sub_dispatch_id == await _make_sublattice_dispatch(dispatch_id, {}) - - mock_make_dispatch.assert_awaited_with(json_lattice, dispatch_id, 
parent_electron_id) - - -def test_legacy_sublattice_dispatch_helper(mocker): - dispatch_id = "test_legacy_sublattice_dispatch_helper" - res_obj = MagicMock() - bg_output = MagicMock() - bg_output.object_string = "json_sublattice" - parent_node = MagicMock() - parent_node._electron_id = 2 - parent_node.get_value = MagicMock(return_value=bg_output) - res_obj.lattice.transport_graph.get_node = MagicMock(return_value=parent_node) - node_result = {"node_id": 0} - - mocker.patch("covalent_dispatcher._core.data_manager.get_result_object", return_value=res_obj) - - assert _legacy_sublattice_dispatch_helper(dispatch_id, node_result) == ("json_sublattice", 2) - - -def test_redirect_lattice(mocker): - """Test redirecting JSON lattices to new DAL.""" - - dispatch_id = "test_redirect_lattice" - mock_manifest = MagicMock() - mock_manifest.metadata.dispatch_id = dispatch_id - mock_prepare_manifest = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.prepare_manifest", - return_value=mock_manifest, - ) - mock_import_manifest = mocker.patch( - "covalent_dispatcher._core.data_manager.manifest_importer._import_manifest", - return_value=mock_manifest, - ) - - mock_pull = mocker.patch( - "covalent_dispatcher._core.data_manager.manifest_importer._pull_assets", - ) - - mock_lat_deserialize = mocker.patch( - "covalent_dispatcher._core.data_manager.Lattice.deserialize_from_json" - ) - - json_lattice = "json_lattice" - - parent_dispatch_id = "parent_dispatch" - parent_electron_id = 3 - - assert ( - _redirect_lattice(json_lattice, parent_dispatch_id, parent_electron_id, None) - == dispatch_id - ) - - mock_import_manifest.assert_called_with(mock_manifest, parent_dispatch_id, parent_electron_id) - mock_pull.assert_called_with(mock_manifest) - - @pytest.mark.asyncio async def test_ensure_dispatch(mocker): mock_ensure_run_once = mocker.patch( diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py index 21b410a53..4ac55892a 100644 --- a/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py +++ b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py @@ -21,7 +21,6 @@ import pytest from covalent_dispatcher._core.data_modules.importer import ( - _copy_assets, import_derived_manifest, import_manifest, ) @@ -91,10 +90,6 @@ async def test_import_derived_manifest(mocker): "covalent_dispatcher._core.data_modules.importer._import_manifest", ) - mock_copy = mocker.patch( - "covalent_dispatcher._core.data_modules.importer._copy_assets", - ) - mock_handle_redispatch = mocker.patch( "covalent_dispatcher._core.data_modules.importer.handle_redispatch", return_value=(mock_manifest, []), @@ -110,11 +105,3 @@ async def test_import_derived_manifest(mocker): mock_import_manifest.assert_called() mock_pull.assert_called() mock_handle_redispatch.assert_called() - mock_copy.assert_called_with([]) - - -def test_copy_assets(mocker): - mock_copy = mocker.patch("covalent_dispatcher._core.data_modules.importer.copy_asset") - - _copy_assets([("src", "dest")]) - mock_copy.assert_called_with("src", "dest") diff --git a/tests/covalent_dispatcher_tests/_core/runner_ng_test.py b/tests/covalent_dispatcher_tests/_core/runner_ng_test.py index 4e6f8a88d..25c7de507 100644 --- a/tests/covalent_dispatcher_tests/_core/runner_ng_test.py +++ b/tests/covalent_dispatcher_tests/_core/runner_ng_test.py @@ -166,11 +166,11 @@ async def test_submit_abstract_task_group(mocker, task_cancelled): mock_node_upload_uri_1 = 
me.get_upload_uri(task_group_metadata, "node_1") mock_node_upload_uri_2 = me.get_upload_uri(task_group_metadata, "node_2") - mock_function_id_0 = 0 - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + mock_electron_id_0 = 0 + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] - mock_function_id_3 = 3 + mock_electron_id_3 = 3 resources = { "functions": { @@ -185,27 +185,27 @@ async def test_submit_abstract_task_group(mocker, task_cancelled): } mock_task_spec_0 = { - "function_id": mock_function_id_0, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id_0, + "args": mock_args, + "kwargs": mock_kwargs, } mock_task_spec_3 = { - "function_id": mock_function_id_3, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id_3, + "args": mock_args, + "kwargs": mock_kwargs, } mock_task_0 = { - "function_id": mock_function_id_0, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id_0, + "args": mock_args, + "kwargs": mock_kwargs, } mock_task_3 = { - "function_id": mock_function_id_3, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id_3, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] @@ -269,14 +269,14 @@ async def test_submit_requires_opt_in(mocker): "covalent_dispatcher._core.runner_ng.datamgr.generate_node_result", return_value=node_result, ) - mock_function_id = task_id - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + mock_electron_id = task_id + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] mock_task = { - "function_id": mock_function_id, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] @@ -522,14 +522,14 @@ async def test_run_abstract_task_group(mocker): node_name = "task" abstract_inputs = {"args": [], "kwargs": {}} selected_executor = ["local", {}] - mock_function_id = node_id - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + mock_electron_id = node_id + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] mock_task = { - "function_id": mock_function_id, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] task_group_metadata = { @@ -574,15 +574,15 @@ async def test_run_abstract_task_group_handles_old_execs(mocker): node_name = "task" abstract_inputs = {"args": [], "kwargs": {}} selected_executor = ["local", {}] - mock_function_id = node_id - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + mock_electron_id = node_id + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] mock_task = { - "function_id": mock_function_id, + "electron_id": mock_electron_id, "name": node_name, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] @@ -614,14 +614,14 @@ async def test_run_abstract_task_group_handles_bad_executors(mocker): node_name = sublattice_prefix abstract_inputs = {"args": [], "kwargs": {}} selected_executor = ["local", {}] - mock_function_id = node_id - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + 
mock_electron_id = node_id + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] mock_task = { - "function_id": mock_function_id, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] @@ -670,14 +670,14 @@ async def test_run_abstract_task_group_handles_cancelled_tasks(mocker): node_name = "task" abstract_inputs = {"args": [], "kwargs": {}} selected_executor = ["local", {}] - mock_function_id = node_id - mock_args_ids = abstract_inputs["args"] - mock_kwargs_ids = abstract_inputs["kwargs"] + mock_electron_id = node_id + mock_args = abstract_inputs["args"] + mock_kwargs = abstract_inputs["kwargs"] mock_task = { - "function_id": mock_function_id, - "args_ids": mock_args_ids, - "kwargs_ids": mock_kwargs_ids, + "electron_id": mock_electron_id, + "args": mock_args, + "kwargs": mock_kwargs, } known_nodes = [1, 2] diff --git a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py index 39c19bb3d..db3cbb99e 100644 --- a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py +++ b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py @@ -24,12 +24,13 @@ import covalent as ct from covalent._results_manager.result import Result as SDKResult from covalent._serialize.result import serialize_result -from covalent._shared_files.schemas.result import AssetSchema, ResultSchema +from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS -from covalent_dispatcher._dal.importers.result import SERVER_URL, handle_redispatch, import_result +from covalent_dispatcher._dal.importers.result import handle_redispatch, import_result from covalent_dispatcher._dal.job import Job from covalent_dispatcher._dal.result import get_result_object from covalent_dispatcher._db.datastore import DataStore +from covalent_dispatcher._object_store.local import SERVER_URL TEMP_RESULTS_DIR = "/tmp/covalent_result_import_test" @@ -128,10 +129,6 @@ def test_import_previously_imported_result(mocker, test_db): mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - mock_filter_uris = mocker.patch( - "covalent_dispatcher._dal.importers.result._filter_remote_uris" - ) - with ( tempfile.TemporaryDirectory(prefix="covalent-") as sdk_dir, tempfile.TemporaryDirectory(prefix="covalent-") as srv_dir, @@ -153,7 +150,6 @@ def test_import_previously_imported_result(mocker, test_db): import_result(sub_res, srv_dir, parent_node._electron_id) sub_srv_res = get_result_object(sub_dispatch_id, bare=True) - assert mock_filter_uris.call_count == 2 assert sub_srv_res._electron_id == parent_node._electron_id @@ -165,10 +161,6 @@ def test_import_subdispatch_cancel_req(mocker, test_db): mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - mock_filter_uris = mocker.patch( - "covalent_dispatcher._dal.importers.result._filter_remote_uris" - ) - with ( tempfile.TemporaryDirectory(prefix="covalent-") as sdk_dir, tempfile.TemporaryDirectory(prefix="covalent-") as srv_dir, @@ -217,7 +209,6 @@ def test_handle_redispatch_identical(mocker, test_db, parent_status, new_status) redispatch_id = "test_handle_redispatch_2" mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - mock_copy_node_asset = mocker.patch("covalent_dispatcher._dal.tg_ops.copy_asset") mock_copy_asset_meta = 
mocker.patch("covalent_dispatcher._dal.asset.copy_asset_meta") mock_copy_workflow_asset_meta = mocker.patch( "covalent_dispatcher._dal.importers.result.copy_asset_meta" @@ -270,31 +261,3 @@ def test_handle_redispatch_identical(mocker, test_db, parent_status, new_status) assert tg.get_node_value(n, "status") == new_status assert len(assets_to_copy) == n_workflow_assets + n_electron_assets - - -def test_import_result_with_custom_assets(mocker, test_db): - dispatch_id = "test_import_result" - - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - with ( - tempfile.TemporaryDirectory(prefix="covalent-") as sdk_dir, - tempfile.TemporaryDirectory(prefix="covalent-") as srv_dir, - ): - manifest = get_mock_result(dispatch_id, sdk_dir) - manifest.lattice.assets._custom = {"custom_lattice_asset": AssetSchema(size=0)} - manifest.lattice.transport_graph.nodes[0].assets._custom = { - "custom_electron_asset": AssetSchema(size=0) - } - filtered_res = import_result(manifest, srv_dir, None) - - with test_db.session() as session: - result_object = get_result_object(dispatch_id, bare=True, session=session) - node_0 = result_object.lattice.transport_graph.get_node(0, session) - node_1 = result_object.lattice.transport_graph.get_node(1, session) - lat_asset_ids = result_object.lattice.get_asset_ids(session, []) - node_0_asset_ids = node_0.get_asset_ids(session, []) - node_1_asset_ids = node_1.get_asset_ids(session, []) - assert "custom_lattice_asset" in lat_asset_ids - assert "custom_electron_asset" in node_0_asset_ids - assert "custom_electron_asset" not in node_1_asset_ids diff --git a/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py b/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py index 3740b3c78..85110b069 100644 --- a/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py +++ b/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py @@ -214,10 +214,10 @@ def replacement(x): mock_old_asset = MagicMock() mock_new_asset = MagicMock() - mock_old_asset.storage_type = StorageType.LOCAL + mock_old_asset.storage_type = StorageType.LOCAL.value mock_old_asset.storage_path = "/tmp" mock_old_asset.object_key = "result.pkl" - mock_new_asset.storage_type = StorageType.LOCAL + mock_new_asset.storage_type = StorageType.LOCAL.value mock_new_asset.storage_path = "/tmp" mock_new_asset.object_key = "result_new.pkl" @@ -231,7 +231,6 @@ def replacement(x): mocker.patch("covalent_dispatcher._dal.tg_ops.METADATA_KEYS", MOCK_META_KEYS) mocker.patch("covalent_dispatcher._dal.tg_ops.ASSET_KEYS", MOCK_ASSET_KEYS) - mock_copy_asset = mocker.patch("covalent_dispatcher._dal.tg_ops.copy_asset") mock_copy_asset_meta = mocker.patch("covalent_dispatcher._dal.tg_ops.copy_asset_meta") tg_new = _TransportGraph(lattice_id=2) @@ -276,7 +275,6 @@ def replacement(x): assert tg_ops.tg._graph.nodes[1]["name"] == "multiply" assert tg_ops.tg._graph.nodes(data=True)[2]["name"] == "replacement" - assert mock_copy_asset.call_count == 2 assert mock_copy_asset_meta.call_count == 2 @@ -364,11 +362,11 @@ def test_get_reusable_nodes(mocker, tg, tg_2): ) mock_old_asset = MagicMock() mock_new_asset = MagicMock() - mock_old_asset.storage_type = StorageType.LOCAL + mock_old_asset.storage_type = StorageType.LOCAL.value mock_old_asset.storage_path = "/tmp" mock_old_asset.object_key = "value.pkl" mock_old_asset.meta = {"digest": "24af"} - mock_new_asset.storage_type = StorageType.LOCAL + mock_new_asset.storage_type = StorageType.LOCAL.value mock_new_asset.storage_path = "/tmp" mock_new_asset.object_key = "value.pkl" mock_new_asset.meta = 
{"digest": "24af"} @@ -392,11 +390,11 @@ def test_get_diff_nodes_integration_test(tg, tg_2): mock_old_asset = MagicMock() mock_new_asset = MagicMock() - mock_old_asset.storage_type = StorageType.LOCAL + mock_old_asset.storage_type = StorageType.LOCAL.value mock_old_asset.storage_path = "/tmp" mock_old_asset.object_key = "value.pkl" mock_old_asset.__dict__.update({"digest": "24af"}) - mock_new_asset.storage_type = StorageType.LOCAL + mock_new_asset.storage_type = StorageType.LOCAL.value mock_new_asset.storage_path = "/tmp" mock_new_asset.object_key = "value.pkl" mock_new_asset.__dict__.update({"digest": "24af"}) diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index d02382b98..c7fcb29f2 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -176,20 +176,6 @@ def test_register_exception(mocker, app, client, mock_manifest): assert resp.status_code == 400 -def test_register_sublattice(mocker, app, client, mock_manifest): - mock_register_dispatch = mocker.patch( - "covalent_dispatcher._service.app.dispatcher.register_dispatch", return_value=mock_manifest - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - resp = client.post( - "/api/v2/dispatches/parent_dispatch/sublattices", - data=mock_manifest.json(), - ) - - assert resp.json() == json.loads(mock_manifest.json()) - mock_register_dispatch.assert_awaited_with(mock_manifest, "parent_dispatch") - - def test_register_redispatch(mocker, app, client, mock_manifest): dispatch_id = "test_register_redispatch" mock_register_redispatch = mocker.patch( diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py index 5f704ca43..8b2c1a5d7 100644 --- a/tests/covalent_dispatcher_tests/_service/assets_test.py +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -16,7 +16,6 @@ """Unit tests for the FastAPI asset endpoints""" -import tempfile from contextlib import contextmanager from typing import Generator from unittest.mock import MagicMock @@ -27,13 +26,7 @@ from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.orm import Session, declarative_base, sessionmaker -from covalent._workflow.transportable_object import TransportableObject -from covalent_dispatcher._service.assets import ( - _generate_file_slice, - _get_tobj_pickle_offsets, - _get_tobj_string_offsets, - get_cached_result_object, -) +from covalent_dispatcher._service.assets import get_cached_result_object from covalent_ui.app import fastapi_app as fast_app DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" @@ -87,7 +80,8 @@ def mock_result_object(): res_obj = MagicMock() mock_node = MagicMock() mock_asset = MagicMock() - mock_asset.internal_uri = INTERNAL_URI + mock_asset.object_store = MagicMock() + mock_asset.object_store.get_public_uri.return_value = "http://localhost:48008/files/output" res_obj.get_asset = MagicMock(return_value=mock_asset) res_obj.update_assets = MagicMock() @@ -107,128 +101,18 @@ def test_get_node_asset(mocker, client, test_db, mock_result_object): Test get node asset """ - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - yield "Hi" - key = "output" node_id = 0 - dispatch_id = "test_get_node_asset_no_dispatch_id" - 
mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") + dispatch_id = "test_get_node_asset_id" mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) mocker.patch( "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get(f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") - assert resp.text == "Hi" - assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] - - -def test_get_node_asset_byte_range(mocker, client, test_db, mock_result_object): - """ - Test get node asset - """ - - test_str = "test_get_node_asset_string_rep" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "output" - node_id = 0 - dispatch_id = "test_get_node_asset_no_dispatch_id" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - - headers = {"Range": "bytes=0-6"} - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get( - f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", headers=headers - ) - - assert resp.text == test_str[0:6] - assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] - - -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) -def test_get_node_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte -): - """ - Test get node asset - """ - - test_str = "test_get_node_asset_rep" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "output" - node_id = 0 - dispatch_id = "test_get_node_asset_no_dispatch_id" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) - ) - - params = {"representation": rep} - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get( - f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", params=params - ) 
- - assert resp.text == test_str[start_byte:end_byte] - assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + assert resp.json()["remote_uri"] == "http://localhost:48008/files/output" def test_get_node_asset_bad_dispatch_id(mocker, client): @@ -247,292 +131,14 @@ def test_get_node_asset_bad_dispatch_id(mocker, client): assert resp.status_code == 400 -def test_get_lattice_asset(mocker, client, test_db, mock_result_object): +def test_post_node_asset(test_db, mocker, client, mock_result_object): """ - Test get lattice asset - """ - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - yield "Hi" - - key = "workflow_function" - dispatch_id = "test_get_lattice_asset_no_dispatch_id" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}") - - assert resp.text == "Hi" - assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] - - -def test_get_lattice_asset_byte_range(mocker, client, test_db, mock_result_object): - """ - Test get lattice asset - """ - - test_str = "test_lattice_asset_byte_range" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "workflow_function" - dispatch_id = "test_get_lattice_asset_no_dispatch_id" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - headers = {"Range": "bytes=0-6"} - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", headers=headers) - - assert resp.text == test_str[0:6] - assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] - - -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) -def test_get_lattice_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte -): - """ - Test get lattice asset - """ - - test_str = "test_get_lattice_asset_rep" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "workflow_function" - dispatch_id = "test_get_lattice_asset_rep" - 
mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - params = {"representation": rep} - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params) - - assert resp.text == test_str[start_byte:end_byte] - assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] - - -def test_get_lattice_asset_bad_dispatch_id(mocker, client): - """ - Test get lattice asset - """ - - key = "workflow_function" - dispatch_id = "test_get_lattice_asset_no_dispatch_id" - - mocker.patch( - "covalent_dispatcher._service.assets.get_cached_result_object", - side_effect=HTTPException(status_code=400), - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}") - assert resp.status_code == 400 - - -def test_get_dispatch_asset(mocker, client, test_db, mock_result_object): - """ - Test get dispatch asset - """ - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - yield "Hi" - - key = "result" - dispatch_id = "test_get_dispatch_asset" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}") - - assert resp.text == "Hi" - assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] - - -def test_get_dispatch_asset_byte_range(mocker, client, test_db, mock_result_object): - """ - Test get dispatch asset - """ - - test_str = "test_dispatch_asset_byte_range" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "result" - dispatch_id = "test_get_dispatch_asset_byte_range" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - mock_generate_file_slice = mocker.patch( - 
"covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - - headers = {"Range": "bytes=0-6"} - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", headers=headers) - - assert resp.text == test_str[0:6] - assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] - - -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) -def test_get_dispatch_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte -): - """ - Test get dispatch asset - """ - - test_str = "test_get_dispatch_asset_rep" - - class MockGenerateFileSlice: - def __init__(self): - self.calls = [] - - def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): - self.calls.append((file_url, start_byte, end_byte, chunk_size)) - if end_byte >= 0: - yield test_str[start_byte:end_byte] - else: - yield test_str[start_byte:] - - key = "result" - dispatch_id = "test_get_dispatch_asset_rep" - mock_generator = MockGenerateFileSlice() - - mocker.patch("fastapi.responses.StreamingResponse") - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mock_generate_file_slice = mocker.patch( - "covalent_dispatcher._service.assets._generate_file_slice", mock_generator - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) - ) - mocker.patch( - "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - params = {"representation": rep} - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params) - - assert resp.text == test_str[start_byte:end_byte] - assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] - - -def test_get_dispatch_asset_bad_dispatch_id(mocker, client): - """ - Test get dispatch asset - """ - - key = "result" - dispatch_id = "test_get_dispatch_asset_no_dispatch_id" - - mocker.patch( - "covalent_dispatcher._service.assets.get_cached_result_object", - side_effect=HTTPException(status_code=400), - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}") - assert resp.status_code == 400 - - -def test_put_node_asset(test_db, mocker, client, mock_result_object): - """ - Test put node asset + Test post node asset """ key = "function" node_id = 0 - dispatch_id = "test_put_node_asset" + dispatch_id = "test_post_node_asset" mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) mocker.patch( @@ -540,26 +146,13 @@ def test_put_node_asset(test_db, mocker, client, mock_result_object): ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post(f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") + assert resp.json()["remote_uri"] == "http://localhost:48008/files/output" - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - headers = {"Digest-alg": "sha", "Digest": "0bf"} - with open(writer.name, "rb") as reader: - resp = client.put( - f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", - data=reader, - headers=headers, - ) - mock_node = mock_result_object.lattice.transport_graph.get_node(node_id) - mock_node.update_assets.assert_called() - assert 
resp.status_code == 200 - - -def test_put_node_asset_bad_dispatch_id(mocker, client): +def test_post_node_asset_bad_dispatch_id(mocker, client): """ - Test put node asset + Test post node asset """ key = "function" node_id = 0 @@ -570,168 +163,10 @@ def test_put_node_asset_bad_dispatch_id(mocker, client): side_effect=HTTPException(status_code=400), ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - - with open(writer.name, "rb") as reader: - resp = client.put( - f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", data=reader - ) - - assert resp.status_code == 400 - - -def test_put_lattice_asset(mocker, client, test_db, mock_result_object): - """ - Test put lattice asset - """ - key = "workflow_function" - dispatch_id = "test_put_lattice_asset" - - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - - with open(writer.name, "rb") as reader: - resp = client.put( - f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", data=reader - ) - mock_lattice = mock_result_object.lattice - mock_lattice.update_assets.assert_called() - assert resp.status_code == 200 - - -def test_put_lattice_asset_bad_dispatch_id(mocker, client): - """ - Test put lattice asset - """ - key = "workflow_function" - dispatch_id = "test_put_lattice_asset_no_dispatch_id" - - mocker.patch( - "covalent_dispatcher._service.assets.get_cached_result_object", - side_effect=HTTPException(status_code=404), - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - - with open(writer.name, "rb") as reader: - resp = client.put( - f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", data=reader - ) - - assert resp.status_code == 400 - - -def test_put_dispatch_asset(mocker, client, test_db, mock_result_object): - """ - Test put dispatch asset - """ - key = "result" - dispatch_id = "test_put_dispatch_asset" - - mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) - mocker.patch( - "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object - ) - - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - - with open(writer.name, "rb") as reader: - resp = client.put(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", data=reader) - mock_result_object.update_assets.assert_called() - assert resp.status_code == 200 - - -def test_put_dispatch_asset_bad_dispatch_id(mocker, client): - """ - Test put dispatch asset - """ - key = "result" - dispatch_id = "test_put_dispatch_asset_no_dispatch_id" - - mocker.patch( - "covalent_dispatcher._service.assets.get_cached_result_object", - side_effect=HTTPException(status_code=400), - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - - with tempfile.NamedTemporaryFile("w") as writer: - writer.write(f"{dispatch_id}") - writer.flush() - - with open(writer.name, "rb") as reader: - resp = 
client.put(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", data=reader) - + resp = client.post(f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") assert resp.status_code == 400 -def test_get_string_offsets(): - tobj = TransportableObject("test_get_string_offsets") - - data = tobj.serialize() - with tempfile.NamedTemporaryFile("wb") as write_file: - write_file.write(data) - write_file.flush() - - start, end = _get_tobj_string_offsets(f"file://{write_file.name}") - - assert data[start:end].decode("utf-8") == tobj.object_string - - -def test_get_pickle_offsets(): - tobj = TransportableObject("test_get_pickle_offsets") - - data = tobj.serialize() - with tempfile.NamedTemporaryFile("wb") as write_file: - write_file.write(data) - write_file.flush() - - start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}") - - assert data[start:].decode("utf-8") == tobj.get_serialized() - - -def test_generate_partial_file_slice(): - """Test generating slices of files.""" - - data = "test_generate_file_slice".encode("utf-8") - with tempfile.NamedTemporaryFile("wb") as write_file: - write_file.write(data) - write_file.flush() - gen = _generate_file_slice(f"file://{write_file.name}", 1, 5, 2) - assert next(gen) == data[1:3] - assert next(gen) == data[3:5] - with pytest.raises(StopIteration): - next(gen) - - -def test_generate_whole_file_slice(): - """Test generating slices of files.""" - - data = "test_generate_file_slice".encode("utf-8") - with tempfile.NamedTemporaryFile("wb") as write_file: - write_file.write(data) - write_file.flush() - gen = _generate_file_slice(f"file://{write_file.name}", 0, -1) - assert next(gen) == data - - def test_get_cached_result_obj(mocker, test_db): mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) mocker.patch("covalent_dispatcher._service.assets.get_result_object", side_effect=KeyError()) diff --git a/tests/covalent_dispatcher_tests/entry_point_test.py b/tests/covalent_dispatcher_tests/entry_point_test.py index 53f92fece..679b0f563 100644 --- a/tests/covalent_dispatcher_tests/entry_point_test.py +++ b/tests/covalent_dispatcher_tests/entry_point_test.py @@ -27,26 +27,12 @@ cancel_running_dispatch, register_dispatch, register_redispatch, - run_dispatcher, start_dispatch, ) DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" -@pytest.mark.asyncio -async def test_run_dispatcher(mocker): - mock_run_dispatch = mocker.patch("covalent_dispatcher._core.run_dispatch") - mock_make_dispatch = mocker.patch( - "covalent_dispatcher._core.make_dispatch", return_value=DISPATCH_ID - ) - json_lattice = '{"workflow_function": "asdf"}' - dispatch_id = await run_dispatcher(json_lattice) - assert dispatch_id == DISPATCH_ID - mock_make_dispatch.assert_awaited_with(json_lattice) - mock_run_dispatch.assert_called_with(dispatch_id) - - @pytest.mark.asyncio async def test_cancel_running_dispatch(mocker): mock_cancel_workflow = mocker.patch("covalent_dispatcher.entry_point.cancel_dispatch") diff --git a/tests/covalent_tests/dispatcher_plugins/local_test.py b/tests/covalent_tests/dispatcher_plugins/local_test.py index e10c83d31..edc04adc7 100644 --- a/tests/covalent_tests/dispatcher_plugins/local_test.py +++ b/tests/covalent_tests/dispatcher_plugins/local_test.py @@ -338,7 +338,7 @@ def workflow(a, b): endpoint = f"/api/v2/dispatches/{dispatch_id}/lattice/assets/dummy" r = Response() r.status_code = 200 - mock_put = mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + mock_put = mocker.patch("requests.put", return_value=r) 
LocalDispatcher.upload_assets(manifest) diff --git a/tests/covalent_tests/executor/executor_plugins/dask_test.py b/tests/covalent_tests/executor/executor_plugins/dask_test.py index 3dd56ef61..18423fe1d 100644 --- a/tests/covalent_tests/executor/executor_plugins/dask_test.py +++ b/tests/covalent_tests/executor/executor_plugins/dask_test.py @@ -320,9 +320,9 @@ def task(x, y): node_2_file.flush() task_spec = TaskSpec( - function_id=0, - args_ids=[1, 2], - kwargs_ids={}, + electron_id=0, + args=[1, 2], + kwargs={}, ) resources = ResourceMap( @@ -417,9 +417,9 @@ def task(x, y): node_2_file.flush() task_spec = TaskSpec( - function_id=0, - args_ids=[1, 2], - kwargs_ids={}, + electron_id=0, + args=[1, 2], + kwargs={}, ) resources = ResourceMap( @@ -515,9 +515,9 @@ def task(x, y): node_2_file.flush() task_spec = TaskSpec( - function_id=0, - args_ids=[1], - kwargs_ids={"y": 2}, + electron_id=0, + args=[1], + kwargs={"y": 2}, ) resources = ResourceMap( diff --git a/tests/covalent_tests/executor/executor_plugins/local_test.py b/tests/covalent_tests/executor/executor_plugins/local_test.py index 6c9d87ce3..9aa65899d 100644 --- a/tests/covalent_tests/executor/executor_plugins/local_test.py +++ b/tests/covalent_tests/executor/executor_plugins/local_test.py @@ -281,47 +281,80 @@ def task(x, y): node_0_function_url = ( f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/function" ) + node_0_function_file_url = f"{server_url}/files/node_0_function" + node_0_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/output" + node_0_output_file_url = f"{server_url}/files/node_0_output" + node_0_stdout_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stdout" + node_0_stdout_file_url = f"{server_url}/files/node_0_stdout" + node_0_stderr_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stderr" + node_0_stderr_file_url = f"{server_url}/files/node_0_stderr" hooks_file = tempfile.NamedTemporaryFile("wb") hooks_file.write(ser_hooks) hooks_file.flush() hooks_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/hooks" + hooks_file_url = f"{server_url}/files/hooks" node_1_file = tempfile.NamedTemporaryFile("wb") node_1_file.write(ser_x) node_1_file.flush() node_1_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/1/assets/output" + node_1_output_file_url = f"{server_url}/node_1_output" node_2_file = tempfile.NamedTemporaryFile("wb") node_2_file.write(ser_y) node_2_file.flush() node_2_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/2/assets/output" + node_2_output_file_url = f"{server_url}/node_2_output" task_spec = TaskSpec( - function_id=0, - args_ids=[1, 2], - kwargs_ids={}, + electron_id=0, + args=[1, 2], + kwargs={}, ) + # GET/POST URLs + url_map = { + node_0_function_url: node_0_function_file_url, + node_0_output_url: node_0_output_file_url, + node_0_stdout_url: node_0_stdout_file_url, + node_0_stderr_url: node_0_stderr_file_url, + hooks_url: hooks_file_url, + node_1_output_url: node_1_output_file_url, + node_2_output_url: node_2_output_file_url, + } + + # GET/PUT files resources = { - node_0_function_url: ser_task, - node_1_output_url: ser_x, - node_2_output_url: ser_y, - hooks_url: ser_hooks, + node_0_function_file_url: ser_task, + node_1_output_file_url: ser_x, + node_2_output_file_url: ser_y, + hooks_file_url: ser_hooks, } - def mock_req_get(url, stream): + def mock_req_get(url, **kwargs): mock_resp = MagicMock() mock_resp.status_code = 200 - mock_resp.content = 
resources[url] + if url in url_map: + mock_resp.json.return_value = {"remote_uri": url_map[url]} + else: + mock_resp.content = resources[url] return mock_resp - def mock_req_post(url, files): - resources[url] = files["asset_file"].read() + def mock_req_post(url, **kwargs): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"remote_uri": url_map[url]} + return mock_resp + + def mock_req_put(url, data=None, headers={}, json={}): + if data is not None: + resources[url] = data if isinstance(data, bytes) else data.read() + return MagicMock() mocker.patch("requests.get", mock_req_get) mocker.patch("requests.post", mock_req_post) - mock_put = mocker.patch("requests.put") + mocker.patch("requests.put", mock_req_put) task_group_metadata = { "dispatch_id": dispatch_id, "node_ids": [node_id], @@ -342,8 +375,7 @@ def mock_req_post(url, files): server_url=server_url, ) - with open(result_file.name, "rb") as f: - output = TransportableObject.deserialize(f.read()) + output = TransportableObject.deserialize(resources[node_0_output_file_url]) assert output.get_deserialized() == 3 with open(cb_tmpfile.name, "r") as f: @@ -352,8 +384,6 @@ def mock_req_post(url, files): with open(ca_tmpfile.name, "r") as f: assert f.read() == "Bye\n" - mock_put.assert_called() - def test_run_task_group_exception(mocker): """Test the wrapper submitted to local""" @@ -398,47 +428,80 @@ def task(x, y): node_0_function_url = ( f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/function" ) + node_0_function_file_url = f"{server_url}/files/node_0_function" + node_0_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/output" + node_0_output_file_url = f"{server_url}/files/node_0_output" + node_0_stdout_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stdout" + node_0_stdout_file_url = f"{server_url}/files/node_0_stdout" + node_0_stderr_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stderr" + node_0_stderr_file_url = f"{server_url}/files/node_0_stderr" hooks_file = tempfile.NamedTemporaryFile("wb") hooks_file.write(ser_hooks) hooks_file.flush() hooks_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/hooks" + hooks_file_url = f"{server_url}/files/hooks" node_1_file = tempfile.NamedTemporaryFile("wb") node_1_file.write(ser_x) node_1_file.flush() node_1_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/1/assets/output" + node_1_output_file_url = f"{server_url}/node_1_output" node_2_file = tempfile.NamedTemporaryFile("wb") node_2_file.write(ser_y) node_2_file.flush() node_2_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/2/assets/output" + node_2_output_file_url = f"{server_url}/node_2_output" task_spec = TaskSpec( - function_id=0, - args_ids=[1], - kwargs_ids={"y": 2}, + electron_id=0, + args=[1], + kwargs={"y": 2}, ) + # GET/POST URLs + url_map = { + node_0_function_url: node_0_function_file_url, + node_0_output_url: node_0_output_file_url, + node_0_stdout_url: node_0_stdout_file_url, + node_0_stderr_url: node_0_stderr_file_url, + hooks_url: hooks_file_url, + node_1_output_url: node_1_output_file_url, + node_2_output_url: node_2_output_file_url, + } + + # GET/PUT files resources = { - node_0_function_url: ser_task, - node_1_output_url: ser_x, - node_2_output_url: ser_y, - hooks_url: ser_hooks, + node_0_function_file_url: ser_task, + node_1_output_file_url: ser_x, + node_2_output_file_url: ser_y, + hooks_file_url: ser_hooks, } - def 
mock_req_get(url, stream): + def mock_req_get(url, **kwargs): mock_resp = MagicMock() mock_resp.status_code = 200 - mock_resp.content = resources[url] + if url in url_map: + mock_resp.json.return_value = {"remote_uri": url_map[url]} + else: + mock_resp.content = resources[url] return mock_resp - def mock_req_post(url, files): - resources[url] = files["asset_file"].read() + def mock_req_post(url, **kwargs): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"remote_uri": url_map[url]} + return mock_resp + + def mock_req_put(url, data=None, headers={}, json={}): + if data is not None: + resources[url] = data if isinstance(data, bytes) else data.read() + return MagicMock() mocker.patch("requests.get", mock_req_get) mocker.patch("requests.post", mock_req_post) - mocker.patch("requests.put") + mocker.patch("requests.put", mock_req_put) task_group_metadata = { "dispatch_id": dispatch_id, "node_ids": [node_id], @@ -459,6 +522,9 @@ def mock_req_post(url, files): server_url=server_url, ) + stderr = resources[node_0_stderr_file_url].decode("utf-8") + assert "AssertionError" in stderr + summary_file_path = f"{results_dir.name}/result-{dispatch_id}:{node_id}.json" with open(summary_file_path, "r") as f: @@ -504,9 +570,9 @@ def mock_proc_pool_submit(mock_future): "id": "happy_path", "task_specs": [ TaskSpec( - function_id=0, - args_ids=[1], - kwargs_ids={"y": 2}, + electron_id=0, + args=[1], + kwargs={"y": 2}, ) ], "resources": ResourceMap( @@ -523,9 +589,9 @@ def mock_proc_pool_submit(mock_future): "id": "future_cancelled", "task_specs": [ TaskSpec( - function_id=0, - args_ids=[1], - kwargs_ids={"y": 2}, + electron_id=0, + args=[1], + kwargs={"y": 2}, ) ], "resources": ResourceMap( @@ -571,7 +637,7 @@ def test_send_internal( run_task_group, list(map(lambda t: t.dict(), test_case["task_specs"])), test_case["expected_output_uris"], - "mock_cache_dir", + local_exec.workdir, test_case["task_group_metadata"], test_case["expected_server_url"], ) @@ -586,9 +652,9 @@ async def test_send(mocker): # Arrange task_group_metadata = {"dispatch_id": "1", "node_ids": ["1", "2"]} task_spec = TaskSpec( - function_id=0, - args_ids=[1], - kwargs_ids={"y": 2}, + electron_id=0, + args=[1], + kwargs={"y": 2}, ) resource = ResourceMap( functions={0: "mock_function_uri"}, diff --git a/tests/covalent_tests/file_transfer/file_test.py b/tests/covalent_tests/file_transfer/file_test.py index ffec8e7b5..60770589b 100644 --- a/tests/covalent_tests/file_transfer/file_test.py +++ b/tests/covalent_tests/file_transfer/file_test.py @@ -18,7 +18,7 @@ import pytest -from covalent._file_transfer.enums import FileSchemes, FileTransferStrategyTypes +from covalent._file_transfer.enums import FileSchemes from covalent._file_transfer.file import File @@ -46,19 +46,12 @@ def test_raise_exception_valid_args(self): ("file:///home/ubuntu/observations.csv", FileSchemes.File), ("s3://mybucket/observations.csv", FileSchemes.S3), ("globus://037f054a-15cf-11e8-b611-0ac6873fc731/observations.txt", FileSchemes.Globus), + ("blob://my-account.blob.core.windows.net/container/blob", FileSchemes.Blob), ], ) def test_scheme_resolution(self, filepath, expected_scheme): assert File(filepath).scheme == expected_scheme - def test_scheme_to_strategy_map(self): - assert File("s3://file").mapped_strategy_type == FileTransferStrategyTypes.S3 - assert File("ftp://file").mapped_strategy_type == FileTransferStrategyTypes.FTP - assert File("globus://file").mapped_strategy_type == FileTransferStrategyTypes.GLOBUS - assert 
File("file://file").mapped_strategy_type == FileTransferStrategyTypes.Shutil - assert File("https://example.com").mapped_strategy_type == FileTransferStrategyTypes.HTTP - assert File("http://example.com").mapped_strategy_type == FileTransferStrategyTypes.HTTP - def test_is_remote_flag(self): assert File("s3://file").is_remote assert File("ftp://file").is_remote @@ -67,6 +60,7 @@ def test_is_remote_flag(self): assert File("file://file", is_remote=True).is_remote assert File("https://example.com").is_remote assert File("http://example.com").is_remote + assert File("blob://msft-acct.blob.core.windows.net/container/blob").is_remote @pytest.mark.parametrize( "filepath, expected_filepath", diff --git a/tests/covalent_tests/file_transfer/file_transfer_test.py b/tests/covalent_tests/file_transfer/file_transfer_test.py index 746e490c9..03114d373 100644 --- a/tests/covalent_tests/file_transfer/file_transfer_test.py +++ b/tests/covalent_tests/file_transfer/file_transfer_test.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable from unittest.mock import Mock import pytest @@ -24,8 +25,22 @@ FileTransfer, TransferFromRemote, TransferToRemote, + guess_transfer_strategy, + register_downloader, + register_uploader, ) from covalent._file_transfer.strategies.rsync_strategy import Rsync +from covalent._file_transfer.strategies.s3_strategy import S3 +from covalent._file_transfer.strategies.shutil_strategy import Shutil + + +# Sample custom transfer strategy +class HiveTransferStrategy: + def download(self, from_file: File, to_file: File) -> Callable: + raise NotImplementedError + + def upload(self, from_file: File, to_file: File) -> Callable: + raise NotImplementedError class TestFileTransfer: @@ -109,3 +124,36 @@ def test_transfer_to_remote(self): with pytest.raises(ValueError): result = TransferToRemote("file:///home/one", "file:///home/one/", strategy=strategy) + + def test_auto_transfer_strategy(self): + from_file = File("s3://bucket/object.pkl") + to_file = File("file:///tmp/object.pkl") + ft = FileTransfer(from_file, to_file) + assert type(ft.strategy) is S3 + + ft = FileTransfer(to_file, from_file) + assert type(ft.strategy) is S3 + + ft = FileTransfer(to_file, to_file) + assert type(ft.strategy) is Shutil + + with pytest.raises(AttributeError): + _ = FileTransfer(from_file, from_file) + + def test_register_custom_schemes_and_transfers(self): + register_downloader("hive", HiveTransferStrategy) + register_uploader("hive", HiveTransferStrategy) + from_file = File("hive://gateway/assets/from_asset") + to_file = File("file:///tmp/stdout.txt") + + assert from_file.is_remote + assert not to_file.is_remote + strategy = guess_transfer_strategy(from_file, to_file) + assert strategy == HiveTransferStrategy + + strategy = guess_transfer_strategy(to_file, from_file) + assert strategy == HiveTransferStrategy + + # Copying not supported + with pytest.raises(AttributeError): + guess_transfer_strategy(from_file, from_file) diff --git a/tests/covalent_tests/file_transfer/strategies/http_strategy_test.py b/tests/covalent_tests/file_transfer/strategies/http_strategy_test.py index 615551a6b..7410f3103 100644 --- a/tests/covalent_tests/file_transfer/strategies/http_strategy_test.py +++ b/tests/covalent_tests/file_transfer/strategies/http_strategy_test.py @@ -21,16 +21,16 @@ class TestHTTPStrategy: - MOCK_LOCAL_FILEPATH = "/Users/user/data.csv" + MOCK_LOCAL_FILEPATH = "/tmp/data.csv" MOCK_REMOTE_FILEPATH = 
"http://example.com/data.csv" def test_download(self, mocker): # validate urlretrieve called with appropriate arguments - urlretrieve_mock = mocker.patch("urllib.request.urlretrieve") + mock_get = mocker.patch("requests.get") from_file = File(self.MOCK_REMOTE_FILEPATH) to_file = File(self.MOCK_LOCAL_FILEPATH) HTTP().download(from_file, to_file)() - urlretrieve_mock.assert_called_with(from_file.uri, to_file.filepath) + mock_get.assert_called_with(from_file.uri, stream=True) @pytest.mark.parametrize( "operation", @@ -39,9 +39,6 @@ def test_download(self, mocker): ("upload"), ], ) - def test_upload_cp_failure(self, operation, mocker): + def test_upload_failure(self, operation, mocker): with pytest.raises(NotImplementedError): - if operation == "upload": - HTTP().upload(File(self.MOCK_REMOTE_FILEPATH), File(self.MOCK_LOCAL_FILEPATH))() - elif operation == "cp": - HTTP().cp(File(self.MOCK_REMOTE_FILEPATH), File(self.MOCK_LOCAL_FILEPATH))() + HTTP().cp(File(self.MOCK_REMOTE_FILEPATH), File(self.MOCK_LOCAL_FILEPATH))() diff --git a/tests/covalent_tests/file_transfer/strategies/shutil_strategy_test.py b/tests/covalent_tests/file_transfer/strategies/shutil_strategy_test.py index 654a2bc11..5dab32488 100644 --- a/tests/covalent_tests/file_transfer/strategies/shutil_strategy_test.py +++ b/tests/covalent_tests/file_transfer/strategies/shutil_strategy_test.py @@ -26,6 +26,7 @@ class TestShutilStrategy: MOCK_TO_FILEPATH = "/home/user/data.csv.bak" def test_cp(self, mocker): + mocker.patch("os.makedirs") mock_copyfile = mocker.patch("shutil.copyfile") from_file = File(TestShutilStrategy.MOCK_FROM_FILEPATH) to_file = File(TestShutilStrategy.MOCK_TO_FILEPATH) diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 72c0260d9..acc15a31c 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -277,20 +277,8 @@ def test_get_status_only(mocker): def test_download_asset(mocker): dispatch_id = "test_download_asset" remote_uri = f"http://localhost:48008/api/v2/dispatches/{dispatch_id}/assets/result" - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.status_code = 200 - - mock_client.get = MagicMock(return_value=mock_response) - mocker.patch( - "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client - ) - - def mock_generator(): - yield "Hello".encode("utf-8") - - mock_response.iter_content = MagicMock(return_value=mock_generator()) + mock_get = mocker.patch("requests.get") with tempfile.NamedTemporaryFile() as local_file: download_asset(remote_uri, local_file.name) - assert local_file.read().decode("utf-8") == "Hello" + mock_get.assert_called_with(remote_uri, stream=True) diff --git a/tests/covalent_tests/workflow/electron_test.py b/tests/covalent_tests/workflow/electron_test.py index 72e629b40..260df7ced 100644 --- a/tests/covalent_tests/workflow/electron_test.py +++ b/tests/covalent_tests/workflow/electron_test.py @@ -16,17 +16,19 @@ """Unit tests for electron""" +import copy import json -from unittest.mock import ANY, MagicMock +from unittest.mock import ANY import flake8 import isort import pytest +import requests import covalent as ct +from covalent._dispatcher_plugins.local import decode_b64_tar, untar_staging_dir from covalent._shared_files.context_managers import active_lattice_manager from covalent._shared_files.defaults import WAIT_EDGE_NAME, 
sublattice_prefix -from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow.electron import ( Electron, @@ -34,7 +36,6 @@ filter_null_metadata, get_serialized_function_str, ) -from covalent._workflow.lattice import Lattice from covalent._workflow.transport import TransportableObject, encode_metadata from covalent.executor.executor_plugins.local import LocalExecutor @@ -106,13 +107,14 @@ def workflow(x): mock_environ = { "COVALENT_DISPATCH_ID": dispatch_id, "COVALENT_DISPATCHER_URL": "http://localhost:48008", + "COVALENT_TASKS": json.dumps([{"electron_id": 0, "args": [], "kwargs": []}]), } - mock_manifest = MagicMock() - mock_manifest.json = MagicMock(return_value=dispatch_id) - def mock_register(manifest, *args, **kwargs): - return manifest + returned_manifest = copy.deepcopy(manifest) + returned_manifest.metadata.dispatch_id = "mock-sublattice-dispatch" + returned_manifest.metadata.root_dispatch_id = "mock-sublattice-dispatch" + return returned_manifest mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest", @@ -125,25 +127,20 @@ def mock_register(manifest, *args, **kwargs): mocker.patch("os.environ", mock_environ) - json_manifest = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1) - - manifest = ResultSchema.parse_raw(json_manifest) + # Mock out the call to associate sublattice with parent electron + r = requests.Response() + r.status_code = 200 + mocker.patch("covalent._api.apiclient.requests.Session.patch", return_value=r) + tar_b64 = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1) mock_upload_assets.assert_called() - - assert len(manifest.lattice.transport_graph.nodes) == 3 - - lat = manifest.lattice - assert lat.metadata.executor == parent_metadata["executor"] - assert lat.metadata.executor_data == parent_metadata["executor_data"] - - assert lat.metadata.workflow_executor == parent_metadata["workflow_executor"] - assert lat.metadata.workflow_executor_data == parent_metadata["workflow_executor_data"] + work_dir, manifest = untar_staging_dir(decode_b64_tar(tar_b64)) + assert manifest.metadata.dispatch_id == "mock-sublattice-dispatch" def test_build_sublattice_graph_fallback(mocker): """ - Test falling back to monolithic sublattice dispatch + Test _build_sublattice_graph when electron is unable to reach the control plane """ dispatch_id = "test_build_sublattice_graph" @@ -175,27 +172,13 @@ def workflow(x): mock_reg = mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest", ) - - mock_upload_assets = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets", - ) - mocker.patch("os.environ", mock_environ) - json_lattice = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1) - - lattice = Lattice.deserialize_from_json(json_lattice) - + tar_b64 = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1) mock_reg.assert_not_called() - mock_upload_assets.assert_not_called() - - assert list(lattice.transport_graph._graph.nodes) == list(range(3)) - for k in lattice.metadata.keys(): - # results_dir will be deprecated soon - if k == "triggers": - assert lattice.metadata[k] is None - elif k != "results_dir": - assert parent_metadata[k] == lattice.metadata[k] + work_dir, manifest = untar_staging_dir(decode_b64_tar(tar_b64)) + assert manifest.metadata.dispatch_id == "" + assert manifest.metadata.root_dispatch_id == "" def test_wait_for_building(): diff --git 
a/tests/covalent_tests/workflow/lepton_test.py b/tests/covalent_tests/workflow/lepton_test.py index 5f33b0156..306547c53 100644 --- a/tests/covalent_tests/workflow/lepton_test.py +++ b/tests/covalent_tests/workflow/lepton_test.py @@ -26,7 +26,7 @@ import pytest from covalent import DepsBash, TransportableObject -from covalent._file_transfer.file_transfer import HTTP, File, FileTransfer, Order +from covalent._file_transfer.file_transfer import File, FileTransfer, Order from covalent._workflow.lepton import Lepton from covalent._workflow.transport import encode_metadata from covalent.executor import LocalExecutor @@ -228,14 +228,6 @@ def test_http_file_transfer(order): mock_file_download = FileTransfer(from_file=mock_from_file, to_file=mock_to_file, order=order) mock_lepton_with_files = Lepton("python", command="mockcmd", files=[mock_file_download]) - # Test that HTTP upload strategy is not currently implemented - with pytest.raises(NotImplementedError): - Lepton( - "python", - command="mockcmd", - files=[FileTransfer(from_file=mock_to_file, to_file=mock_from_file, strategy=HTTP())], - ) - deps = mock_lepton_with_files.get_metadata("hooks")["deps"] assert deps.get("bash") is None assert deps.get("pip") is None diff --git a/tests/functional_tests/dispatcher_stack_test.py b/tests/functional_tests/dispatcher_stack_test.py deleted file mode 100644 index 1f439abcf..000000000 --- a/tests/functional_tests/dispatcher_stack_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Integration test for the dispatcher. -""" - -import pytest - -import covalent_dispatcher as dispatcher -from covalent._results_manager import results_manager as rm -from covalent._shared_files.defaults import parameter_prefix - -from .data import get_mock_result, get_mock_result_2, get_mock_result_3 - - -@pytest.mark.parametrize( - "mock_result,expected_res, expected_node_outputs", - [ - (get_mock_result, 1, {"identity": 1, f"{parameter_prefix}1": 1}), - ( - get_mock_result_2, - 1, - { - "product": 1, - f"{parameter_prefix}1": 1, - f"{parameter_prefix}1": 1, - "identity": 1, - }, - ), - ( - get_mock_result_3, - 1, - {"pipeline": 1, f"{parameter_prefix}1": 1, f"{parameter_prefix}1": 1}, - ), - ], -) -def test_dispatcher_flow(mock_result, expected_res, expected_node_outputs): - """Integration test that given a results object, plans and executes the workflow on the - default executor. 
- """ - - import asyncio - - mock_result_object = mock_result() - serialized_lattice = mock_result_object.lattice.serialize_to_json() - - awaitable = dispatcher.run_dispatcher(json_lattice=serialized_lattice) - dispatch_id = asyncio.run(awaitable) - rm._delete_result(dispatch_id=dispatch_id, remove_parent_directory=True) diff --git a/tests/functional_tests/init_test.py b/tests/functional_tests/init_test.py deleted file mode 100644 index 63b81eedf..000000000 --- a/tests/functional_tests/init_test.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Tests for self-contained entry point for the dispatcher -""" - - -import covalent_dispatcher as dispatcher - -from .data import get_mock_result - - -def test_run_dispatcher(): - """ - Test run_dispatcher by passing a result object for a lattice and check if no exception is raised. - """ - - import asyncio - - try: - awaitable = dispatcher.run_dispatcher( - json_lattice=get_mock_result().lattice.serialize_to_json() - ) - dispatch_id = asyncio.run(awaitable) - except Exception as e: - assert False, f"Exception raised: {e}"