ITEP-64780 weights uploader #225

Merged · 7 commits · May 19, 2025
@@ -229,6 +229,7 @@ def create_flyte_container_task(  # noqa: PLR0913
V1EnvVar(name="TASK_ID", value=container_name),
V1EnvVar(name="SESSION_ORGANIZATION_ID", value=str(session.organization_id)),
V1EnvVar(name="SESSION_WORKSPACE_ID", value=str(session.workspace_id)),
V1EnvVar(name="WEIGHTS_URL", value="https://storage.geti.intel.com/weights"),
V1EnvVar(
name="KAFKA_TOPIC_PREFIX",
value_from=V1EnvVarSource(
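The new WEIGHTS_URL variable exposes the public weights mirror to the Flyte-launched training containers; the MLflow job below reads it via os.environ.get("WEIGHTS_URL") whenever a requested object is missing from the local S3 bucket.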
@@ -13,7 +13,31 @@
logger = logging.getLogger("mlflow_job")


-def download_file_from_s3(bucket_name, object_name, file_path, endpoint):  # noqa: ANN001, ANN201, D103
+def download_file_from_url(client: Minio, bucket_name: str, object_name: str, file_path: str) -> None:
+    """
+    Download a file from the weights URL and save it in the target S3 bucket.
+    """
+    try:
+        # Try to download the file from the Internet
+        url = f"{os.environ.get('WEIGHTS_URL')}/{object_name}"
+        resp = requests.get(url, timeout=600)
+        if resp.status_code == 200:
+            with open(file_path, "wb") as f:
+                for chunk in resp.iter_content(chunk_size=512):
+                    if chunk:
+                        f.write(chunk)
+            logger.info(f"File '{object_name}' downloaded successfully from {url} to '{file_path}'")
+            # Upload the file to the S3 bucket; this overwrites any object that was put there in the meantime
+            client.fput_object(bucket_name, object_name, file_path)
+            logger.info(f"File '{object_name}' uploaded successfully to S3")
+        else:
+            raise RuntimeError(f"Failed to download '{object_name}' from {url}. Status code: {resp.status_code}")
+    except Exception:
+        logger.error(f"Failed to download '{object_name}' from the Internet.")
+        raise
+
+
+def download_file(bucket_name, object_name, file_path, endpoint):  # noqa: ANN001, ANN201, D103
    # Initialize the Minio client
    s3_credentials_provider = os.environ.get("S3_CREDENTIALS_PROVIDER")
    if s3_credentials_provider == "local":
@@ -51,21 +75,24 @@ def download_file_from_s3(bucket_name, object_name, file_path, endpoint):  # noqa: ANN001, ANN201, D103
        # Download the file from the S3 bucket
        client.fget_object(bucket_name, object_name, file_path)
        print(f"File '{object_name}' downloaded successfully to '{file_path}'")
-    except S3Error:
-        logger.warning(f"{traceback.print_exc()}")
-        logger.warning("Trying to get object using presigned URL")
-        url = presigned_urls_client.presigned_get_object(bucket_name, object_name)
-        try:
-            resp = requests.get(url, timeout=600)
-            if resp.status_code == 200:
-                with open(file_path, "wb") as f:
-                    for chunk in resp.iter_content(chunk_size=512):
-                        if chunk:
-                            f.write(chunk)
-                logger.info(f"File '{object_name}' downloaded successfully to '{file_path}' using presigned URL")
-        except Exception:
-            print(f"{traceback.print_exc()}")
-            raise
+    except S3Error as e:
+        if e.code == "NoSuchKey":
+            download_file_from_url(client, bucket_name, object_name, file_path)
+        else:
+            logger.warning(f"{traceback.format_exc()}")
+            logger.warning("Trying to get object using presigned URL")
+            url = presigned_urls_client.presigned_get_object(bucket_name, object_name)
+            try:
+                resp = requests.get(url, timeout=600)
+                if resp.status_code == 200:
+                    with open(file_path, "wb") as f:
+                        for chunk in resp.iter_content(chunk_size=512):
+                            if chunk:
+                                f.write(chunk)
+                    logger.info(f"File '{object_name}' downloaded successfully to '{file_path}' using presigned URL")
+            except Exception:
+                print(f"{traceback.format_exc()}")
+                raise


# Example usage
@@ -75,4 +102,4 @@
file_path = "./temp_downloaded.obj"
endpoint = os.environ.get("S3_HOST", "impt-seaweed-fs:8333")

-download_file_from_s3(bucket_name, object_name, file_path, endpoint)
+download_file(bucket_name, object_name, file_path, endpoint)
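Taken together, download_file now resolves objects in three steps: fetch from the S3 bucket via fget_object; on a NoSuchKey error, fall back to WEIGHTS_URL (re-seeding the bucket with fput_object on success); on any other S3Error, retry through a presigned URL. A minimal sketch of the miss path — the object name is hypothetical, and the S3 credential environment variables are assumed to be set as the job normally requires:

import os

# Values below mirror the defaults visible in this diff.
os.environ.setdefault("WEIGHTS_URL", "https://storage.geti.intel.com/weights")
os.environ.setdefault("S3_HOST", "impt-seaweed-fs:8333")

from minio_util import download_file

# If "some_model.pth" (hypothetical) is absent from the bucket, the NoSuchKey
# branch downloads it from WEIGHTS_URL and uploads it back to S3, so later
# jobs are served from the bucket directly.
download_file(
    "pretrainedweights",    # bucket_name (default used by the MLflow job)
    "some_model.pth",       # object_name
    "./some_model.pth",     # file_path
    os.environ["S3_HOST"],  # endpoint
)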
@@ -12,7 +12,7 @@
from typing import TYPE_CHECKING

import mlflow
-from minio_util import download_file_from_s3
+from minio_util import download_file
from mlflow_io import AsyncCaller, download_config_file, download_shard_files, log_error, log_full
from optimize import optimize
from train import train
@@ -34,7 +34,7 @@ def download_pretrained_weights(template_id: str) -> None:
    host_name = os.environ.get("S3_HOST", "impt-seaweed-fs:8333")
    bucket_name = os.environ.get("BUCKET_NAME_PRETRAINEDWEIGHTS", "pretrainedweights")
    metadata_path = os.path.join(str(work_dir), "metadata.json")
-    download_file_from_s3(bucket_name, "pretrained_models.json", metadata_path, host_name)
+    download_file(bucket_name, "pretrained_models.json", metadata_path, host_name)

    if not os.path.exists(metadata_path):
        raise RuntimeError(f"Metadata file {metadata_path} does not exist")
@@ -62,7 +62,7 @@ def download_pretrained_weights(template_id: str) -> None:
    host_name = os.environ.get("S3_HOST", "impt-seaweed-fs:8333")
    for obj_name in obj_names:
        file_path = os.path.join(model_cache_dir, obj_name)
-        download_file_from_s3(bucket_name, obj_name, file_path, host_name)
+        download_file(bucket_name, obj_name, file_path, host_name)
        if file_path.endswith(".zip"):
            with zipfile.ZipFile(file_path) as zip_ref:
                zip_ref.extractall(os.path.dirname(file_path))
4 changes: 4 additions & 0 deletions platform/services/weights_uploader/.dockerignore
@@ -0,0 +1,4 @@
*
!app
!uv.lock
!pyproject.toml
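The .dockerignore is an allow-list: * first excludes the whole build context, then !app, !uv.lock and !pyproject.toml re-include exactly the paths the Dockerfile copies, keeping the context minimal and the build cache stable.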
47 changes: 47 additions & 0 deletions platform/services/weights_uploader/Dockerfile
@@ -0,0 +1,47 @@
FROM python:3.10-slim-bookworm AS base

FROM base AS build

ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy

# Disable Python downloads, because we want to use the system interpreter
# across both images.
ENV UV_PYTHON_DOWNLOADS=0

# Copy the service dependencies
WORKDIR /builder
COPY --link --from=libs . ../libs

WORKDIR /builder/weights_uploader/app

COPY --link --from=ghcr.io/astral-sh/uv:0.6.12 /uv /bin/uv

COPY --link app .

RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    uv venv --relocatable && \
    uv sync --frozen --no-dev --no-editable

FROM base AS runtime

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        curl && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

RUN useradd -l -u 10001 non-root && \
    pip3 uninstall -y setuptools pip wheel && \
    rm -rf /root/.cache/pip

USER non-root

# Copy the application from the builder
COPY --link --from=build --chown=10001 /builder/weights_uploader/app /app

# Place executables in the environment at the front of the path
ENV PATH="/app/.venv/bin:$PATH"

WORKDIR /app
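The Dockerfile follows the common two-stage uv layout: the build stage assembles a relocatable virtual environment (uv venv --relocatable plus uv sync --frozen) against the shared libs build context, and the slim runtime stage copies only /app, uninstalls pip, setuptools and wheel, and runs as a non-root user.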
6 changes: 6 additions & 0 deletions platform/services/weights_uploader/Makefile
@@ -0,0 +1,6 @@
# Copyright (C) 2022-2025 Intel Corporation
# LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

include ../../../Makefile.shared-python

DOCKER_BUILD_CONTEXT := --build-context libs=../../../libs
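The Makefile only adds the extra build context on top of the shared Python rules; the actual targets live in Makefile.shared-python, which is not part of this diff. A local build would therefore presumably resemble docker build --build-context libs=../../../libs -t weights-uploader . from platform/services/weights_uploader, though the exact invocation is an assumption, not taken from the PR.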
Empty file.
@@ -0,0 +1,119 @@
# Copyright (C) 2022-2025 Intel Corporation
# LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import hashlib
import logging
import os
import shutil
import urllib.error
import urllib.request
import zipfile
from collections.abc import Callable

logging.basicConfig(level=logging.INFO)

RETRIES = 5

logger = logging.getLogger(__name__)


def sha256sum(filepath: str):  # noqa: ANN201, D103
    sha256 = hashlib.sha256()
    with open(filepath, "rb") as f:
        while True:
            data = f.read(65536)
            if not data:
                break
            sha256.update(data)
    return sha256.hexdigest()


def download_file(url: str, target_path: str, auto_unzip=True):  # noqa: ANN001, ANN201, D103
    logger.info(f"Downloading file: {url}")
    url_original_filename = os.path.basename(url)
    if "?" in url_original_filename:
        url_original_filename = url_original_filename.split("?")[0]

    target_dir_path = os.path.dirname(target_path)
    download_temp_target_path = os.path.join(target_dir_path, url_original_filename)

    with (
        urllib.request.urlopen(url) as response,  # noqa: S310
        open(download_temp_target_path, "wb") as out_file,
    ):
        shutil.copyfileobj(response, out_file)

    # Do not use 'zipfile.is_zipfile' here!
    # Some '.pth' files are actually zip archives, and they must not be unzipped.
    if auto_unzip and download_temp_target_path.endswith(".zip"):
        with zipfile.ZipFile(download_temp_target_path) as zip_ref:
            files_in_zip = zip_ref.namelist()
            number_of_files_in_zip = len(files_in_zip)
            if number_of_files_in_zip != 1:
                raise RuntimeError(
                    f"Unexpected number of files: {number_of_files_in_zip}, expected: 1 in: {download_temp_target_path}"
                )
            zip_ref.extractall(target_dir_path)
        os.remove(download_temp_target_path)
        shutil.move(os.path.join(target_dir_path, files_in_zip[0]), target_path)
    elif os.path.dirname(download_temp_target_path) != os.path.dirname(target_path) or (
        os.path.basename(download_temp_target_path) != os.path.basename(target_path)
    ):
        shutil.move(download_temp_target_path, target_path)


class MaxTriesExhausted(Exception):
    pass


# No retry library is used here on purpose, to avoid installing additional dependencies.
def retry_call(call: Callable, retries: int = RETRIES, **kwargs):  # noqa: ANN201, D103
    for i in range(retries):
        logger.info(f"Try {i + 1}/{retries}")
        try:
            call(**kwargs)
            break
        except Exception:
            logger.exception(f"Failed try {i + 1}/{retries}")
    else:
        raise MaxTriesExhausted


def download_pretrained_model(model_spec: dict, target_dir: str, weights_url: str | None = None):  # noqa: ANN201, D103
    model_external_url = model_spec["url"]
    target_path = model_spec["target"]
    auto_unzip = model_spec.get("unzip", True)
    sha_sum = model_spec.get("sha_sum")

    target_download_path = os.path.join(target_dir, os.path.basename(target_path))
    if weights_url is not None:
        model_external_url = os.path.join(weights_url, os.path.basename(model_external_url))

    if os.path.exists(target_download_path):
        if sha_sum is None:
            logger.warning(f"Model already exists: {target_download_path}, but sha_sum is not specified")
            logger.warning(f"Consider adding sha_sum to the model spec: {sha256sum(target_download_path)}")
        elif sha256sum(target_download_path) == sha_sum:
            logger.info(f"Model already downloaded: {target_download_path}")
            return
        else:
            logger.warning(f"Model already downloaded but SHA mismatch: {target_download_path}")
            logger.warning("Redownloading...")
            os.remove(target_download_path)

    try:
        retry_call(
            download_file,
            url=model_external_url,
            target_path=target_download_path,
            auto_unzip=auto_unzip,
        )
    except MaxTriesExhausted:
        raise

    # Verify SHA
    if sha_sum is not None:
        received_sha = sha256sum(target_download_path)
        if sha_sum != received_sha:
            raise RuntimeError(f"Wrong SHA sum for: {target_download_path}. Expected: {sha_sum}, got: {received_sha}")
        logger.info("SHA match")