Skip to content

File transfer and sublattice enhancements #1981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 22, 2025
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ node_modules/

# Ignore mock database
**/*.sqlite

# Ignore virtual envs
*.venv
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Added

- Allow registering custom file transfer strategies

### Changed

- Improved automatic file transfer strategy selection
- HTTP strategy can now upload files too
- Adjusted sublattice logic. The sublattice builder now attempts to
link the sublattice with its parent electron.
- Replaced the JSON sublattice flow with a new tarball importer to enable future
memory-footprint enhancements

## [0.239.0-rc.0] - 2025-04-16

### Authors
Expand Down
24 changes: 24 additions & 0 deletions covalent/_api/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get(self, endpoint: str, **kwargs):
with requests.Session() as session:
if self.adapter:
session.mount("http://", self.adapter)
session.mount("https://", self.adapter)

r = session.get(url, headers=headers, **kwargs)

Expand All @@ -61,13 +62,34 @@ def get(self, endpoint: str, **kwargs):

return r

def patch(self, endpoint: str, **kwargs):
    """Issue an HTTP PATCH request against the dispatcher server.

    Args:
        endpoint: Path appended to the configured dispatcher address.
        **kwargs: Extra arguments forwarded to ``requests``; headers are
            prepared via ``self.prepare_headers``.

    Returns:
        The ``requests.Response`` object.

    Raises:
        requests.exceptions.ConnectionError: When the server is unreachable.
    """
    request_headers = self.prepare_headers(kwargs)
    target_url = self.dispatcher_addr + endpoint
    try:
        with requests.Session() as session:
            # Mount the adapter for both transports so it applies to
            # plain and TLS connections alike.
            if self.adapter:
                for scheme in ("http://", "https://"):
                    session.mount(scheme, self.adapter)

            response = session.patch(target_url, headers=request_headers, **kwargs)

            if self.auto_raise:
                response.raise_for_status()
    except requests.exceptions.ConnectionError:
        print(
            f"The Covalent server cannot be reached at {target_url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage."
        )
        raise

    return response

def put(self, endpoint: str, **kwargs):
headers = self.prepare_headers(kwargs)
url = self.dispatcher_addr + endpoint
try:
with requests.Session() as session:
if self.adapter:
session.mount("http://", self.adapter)
session.mount("https://", self.adapter)

r = session.put(url, headers=headers, **kwargs)

Expand All @@ -87,6 +109,7 @@ def post(self, endpoint: str, **kwargs):
with requests.Session() as session:
if self.adapter:
session.mount("http://", self.adapter)
session.mount("https://", self.adapter)

r = session.post(url, headers=headers, **kwargs)

Expand All @@ -106,6 +129,7 @@ def delete(self, endpoint: str, **kwargs):
with requests.Session() as session:
if self.adapter:
session.mount("http://", self.adapter)
session.mount("https://", self.adapter)

r = session.delete(url, headers=headers, **kwargs)

Expand Down
123 changes: 82 additions & 41 deletions covalent/_dispatcher_plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
import tarfile
import tempfile
from copy import deepcopy
from functools import wraps
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from furl import furl
from typing import Callable, Dict, List, Optional, Tuple, Union

from .._api.apiclient import CovalentAPIClient as APIClient
from .._file_transfer import FileTransfer
from .._results_manager.result import Result
from .._results_manager.results_manager import get_result, get_result_manager
from .._serialize.result import (
Expand All @@ -36,7 +37,7 @@
from .._shared_files.config import get_config
from .._shared_files.schemas.asset import AssetSchema
from .._shared_files.schemas.result import ResultSchema
from .._shared_files.utils import copy_file_locally, format_server_url
from .._shared_files.utils import format_server_url
from .._workflow.lattice import Lattice
from ..triggers import BaseTrigger
from .base import BaseDispatcher
Expand Down Expand Up @@ -251,7 +252,7 @@ def start(
if dispatcher_addr is None:
dispatcher_addr = format_server_url()

endpoint = f"/api/v2/dispatches/{dispatch_id}/status"
endpoint = f"{BASE_ENDPOINT}/{dispatch_id}/status"
body = {"status": "RUNNING"}
r = APIClient(dispatcher_addr).put(endpoint, json=body)
r.raise_for_status()
Expand Down Expand Up @@ -463,7 +464,6 @@ def prepare_manifest(lattice, storage_path) -> ResultSchema:
def register_manifest(
manifest: ResultSchema,
dispatcher_addr: Optional[str] = None,
parent_dispatch_id: Optional[str] = None,
push_assets: bool = True,
) -> ResultSchema:
"""Submits a manifest for registration.
Expand All @@ -482,9 +482,6 @@ def register_manifest(
stripped = strip_local_uris(manifest) if push_assets else manifest
endpoint = BASE_ENDPOINT

if parent_dispatch_id:
endpoint = f"{BASE_ENDPOINT}/{parent_dispatch_id}/sublattices"

r = APIClient(dispatcher_addr).post(endpoint, data=stripped.model_dump_json())
r.raise_for_status()

Expand Down Expand Up @@ -512,7 +509,7 @@ def register_derived_manifest(
# We don't yet support pulling assets for redispatch
stripped = strip_local_uris(manifest)

endpoint = f"/api/v2/dispatches/{dispatch_id}/redispatches"
endpoint = f"{BASE_ENDPOINT}/{dispatch_id}/redispatches"

params = {"reuse_previous_results": reuse_previous_results}
r = APIClient(dispatcher_addr).post(
Expand All @@ -531,45 +528,89 @@ def upload_assets(manifest: ResultSchema):

@staticmethod
def _upload(assets: List[AssetSchema]):
local_scheme_prefix = "file://"
total = len(assets)
number_uploaded = 0
for i, asset in enumerate(assets):
if not asset.remote_uri or not asset.uri:
app_log.debug(f"Skipping asset {i + 1} out of {total}")
continue
if asset.remote_uri.startswith(local_scheme_prefix):
copy_file_locally(asset.uri, asset.remote_uri)
number_uploaded += 1
else:
_upload_asset(asset.uri, asset.remote_uri)
number_uploaded += 1
app_log.debug(f"Uploaded asset {i + 1} out of {total}.")

_upload_asset(asset.uri, asset.remote_uri)
number_uploaded += 1
app_log.debug(f"Uploaded asset {i + 1} out of {total}.")
app_log.debug(f"uploaded {number_uploaded} assets.")


def _upload_asset(local_uri, remote_uri):
    """Transfer a single asset from ``local_uri`` to ``remote_uri``.

    Builds a ``FileTransfer`` for the pair of URIs and invokes the
    returned transfer callable (the second element of ``cp()``'s result).
    """
    transfer = FileTransfer(local_uri, remote_uri)
    _, do_transfer = transfer.cp()
    do_transfer()


# Archive staging directory and manifest
# Used for sublattice dispatch when the executor cannot directly
# submit the sublattice to the control plane
def pack_staging_dir(staging_dir, manifest: "ResultSchema") -> str:
    """Bundle a staging directory, plus its manifest, into a tar archive.

    Args:
        staging_dir: Path of the directory to archive.
        manifest: Result manifest; its JSON serialization is written to
            ``manifest.json`` at the staging-dir root before archiving.

    Returns:
        Path to the created ``.tar`` file. The caller is responsible for
        deleting it when done (it is not auto-cleaned).
    """
    # save manifest json to staging root
    with open(os.path.join(staging_dir, "manifest.json"), "w") as f:
        f.write(manifest.model_dump_json())

    # Create the temp file atomically and keep the fd-backed path.
    # (Previously the file was created inside a NamedTemporaryFile
    # context and re-created after the context unlinked it, which is a
    # mktemp-style race: another process could claim the freed name.)
    fd, tar_path = tempfile.mkstemp(suffix=".tar")
    os.close(fd)

    with tarfile.TarFile(tar_path, "w") as tar:
        tar.add(staging_dir, recursive=True)
    return tar_path


# Inverse of `pack_staging_dir`
# Consumed by server-side tarball importer
def untar_staging_dir(tar_name) -> Tuple[str, ResultSchema]:

# Working directory for unpacking the archive
with tempfile.TemporaryDirectory(prefix="postprocess-") as work_dir:
...

# Find and extract manifest
with tarfile.TarFile(tar_name) as tar:
manifest_path = list(filter(lambda x: x.endswith("manifest.json"), tar.getnames()))
if len(manifest_path) == 0:
raise RuntimeError("Archive contains no manifest")

manifest = ResultSchema.model_validate_json(tar.extractfile(manifest_path[0]).read())

tar.extractall(path=work_dir, filter="tar")

# prepend work_dir to each asset path
scheme_prefix = "file://"
if local_uri.startswith(scheme_prefix):
local_path = local_uri[len(scheme_prefix) :]
else:
local_path = local_uri

filesize = os.path.getsize(local_path)
with open(local_path, "rb") as reader:
app_log.debug(f"uploading to {remote_uri}")
f = furl(remote_uri)
scheme = f.scheme
host = f.host
port = f.port
dispatcher_addr = f"{scheme}://{host}:{port}"
endpoint = str(f.path)
api_client = APIClient(dispatcher_addr)
if f.query:
endpoint = f"{endpoint}?{f.query}"

# Workaround for Requests bug when streaming from empty files
data = reader.read() if filesize < 50 else reader

r = api_client.put(endpoint, headers={"Content-Length": str(filesize)}, data=data)
r.raise_for_status()
for _, asset in manifest.assets:
if asset.uri:
path = asset.uri[len(scheme_prefix) :]
asset.uri = f"{scheme_prefix}{work_dir}{path}"
print("Rewrote asset uri ", asset.uri)
for _, asset in manifest.lattice.assets:
if asset.uri:
path = asset.uri[len(scheme_prefix) :]
asset.uri = f"{scheme_prefix}{work_dir}{path}"
print("Rewrote asset uri ", asset.uri)

for node in manifest.lattice.transport_graph.nodes:
for _, asset in node.assets:
if asset.uri:
path = asset.uri[len(scheme_prefix) :]
asset.uri = f"{scheme_prefix}{work_dir}{path}"
print("Rewrote asset uri ", asset.uri)

return work_dir, manifest


# Consumed by server-side tarball importer (`import_b64_staging_tarball`)
# TODO: support streaming decode to avoid having to load the entire buffer in mem
def decode_b64_tar(b64_buffer: str) -> str:
    """Decode a base64-encoded tar archive into a temporary file.

    Args:
        b64_buffer: Base64 text encoding the raw bytes of a tar archive.

    Returns:
        Path to the decoded ``.tar`` file. The caller is responsible for
        deleting it (the file is created with ``delete=False``).
    """
    # Write through the still-open NamedTemporaryFile handle instead of
    # closing it (which unlinked the file) and re-opening the freed path —
    # the old pattern was a mktemp-style race and shadowed `tar_file`.
    with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tar_file:
        tar_file.write(base64.b64decode(b64_buffer.encode("utf-8")))
        return tar_file.name
2 changes: 1 addition & 1 deletion covalent/_file_transfer/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Order(str, enum.Enum):
class FileSchemes(str, enum.Enum):
File = "file"
S3 = "s3"
Blob = "https"
Blob = "blob"
GCloud = "gs"
Globus = "globus"
HTTP = "http"
Expand Down
59 changes: 27 additions & 32 deletions covalent/_file_transfer/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,24 @@

from furl import furl

from .enums import FileSchemes, FileTransferStrategyTypes, SchemeToStrategyMap
from .enums import FileSchemes

# Maps URI scheme string -> whether that scheme denotes a remote location.
# Keyed by plain strings (the enum members' `.value`) so lookups with the
# string returned by `File.resolve_scheme` hit directly.
_is_remote_scheme = {
    FileSchemes.S3.value: True,
    FileSchemes.Blob.value: True,
    FileSchemes.GCloud.value: True,
    FileSchemes.Globus.value: True,
    FileSchemes.HTTP.value: True,
    FileSchemes.HTTPS.value: True,
    FileSchemes.FTP.value: True,
    # Use `.value` here too for consistency with the other keys: the enum
    # member itself only matched string lookups via str-enum hashing.
    FileSchemes.File.value: False,
}


# For registering additional file transfer strategies; this will be called by
# `register_uploader` and `register_downloader`
def register_remote_scheme(s: str):
    """Mark URI scheme *s* as remote so files using it are treated as remote."""
    _is_remote_scheme[s] = True


class File:
Expand Down Expand Up @@ -80,19 +97,7 @@ def get_temp_filepath(self):

@property
def is_remote(self):
return self._is_remote or self.scheme in [
FileSchemes.S3,
FileSchemes.Blob,
FileSchemes.GCloud,
FileSchemes.Globus,
FileSchemes.HTTP,
FileSchemes.HTTPS,
FileSchemes.FTP,
]

@property
def mapped_strategy_type(self) -> FileTransferStrategyTypes:
return SchemeToStrategyMap[self.scheme.value]
return self._is_remote or _is_remote_scheme[self.scheme]

@property
def filepath(self) -> str:
Expand Down Expand Up @@ -127,23 +132,13 @@ def get_uri(scheme: str, path: str) -> str:
return path_components.url

@staticmethod
def resolve_scheme(path: str) -> FileSchemes:
def resolve_scheme(path: str) -> str:
scheme = furl(path).scheme
host = furl(path).host
if scheme == FileSchemes.Globus:
return FileSchemes.Globus
if scheme == FileSchemes.S3:
return FileSchemes.S3
if scheme == FileSchemes.Blob and "blob.core.windows.net" in host:
return FileSchemes.Blob
if scheme == FileSchemes.GCloud:
return FileSchemes.GCloud
if scheme == FileSchemes.FTP:
return FileSchemes.FTP
if scheme == FileSchemes.HTTP:
return FileSchemes.HTTP
if scheme == FileSchemes.HTTPS:
return FileSchemes.HTTPS
if scheme is None or scheme == FileSchemes.File:
return FileSchemes.File
raise ValueError(f"Provided File scheme ({scheme}) is not supported.")
# Canonicalize file system paths to file:// urls
if not scheme:
return FileSchemes.File.value
if scheme in _is_remote_scheme:
return scheme
else:
raise ValueError(f"Provided File scheme ({scheme}) is not supported.")
Loading
Loading