Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions ads/opctl/backend/ads_ml_job.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Copyright (c) 2022, 2026 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import copy
import json
import os
import re
import shlex
import shutil
import tempfile
import time
import re
from distutils import dir_util
from typing import Dict, Tuple, Union

from ads.common.auth import AuthContext, AuthType, create_signer
from ads.common.oci_client import OCIClientFactory
from ads.config import (
CONDA_BUCKET_NAME,
CONDA_BUCKET_NS,
)
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
from ads.jobs import (
ContainerRuntime,
DataScienceJob,
Expand All @@ -41,6 +37,7 @@
)
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader
from ads.opctl.utils import secure_copytree

REQUIRED_FIELDS = [
"project_id",
Expand Down Expand Up @@ -72,17 +69,16 @@ def __init__(self, config: Dict) -> None:
self.client = OCIClientFactory(**self.oci_auth).data_science
self.object_storage = OCIClientFactory(**self.oci_auth).object_storage

def _get_latest_conda_pack(self,
prefix,
python_version,
base_conda) -> str:
def _get_latest_conda_pack(self, prefix, python_version, base_conda) -> str:
"""
get the latest conda pack.
"""
try:
objects = self.object_storage.list_objects(namespace_name=CONDA_BUCKET_NS,
bucket_name=CONDA_BUCKET_NAME,
prefix=prefix).data.objects
objects = self.object_storage.list_objects(
namespace_name=CONDA_BUCKET_NS,
bucket_name=CONDA_BUCKET_NAME,
prefix=prefix,
).data.objects
py_str = python_version.replace(".", "")
py_filter = [obj for obj in objects if f"p{py_str}" in obj.name]

Expand All @@ -96,7 +92,6 @@ def extract_version(obj_name):
logger.warning(f"Error while fetching latest conda pack: {e}")
return base_conda


def init(
self,
uri: Union[str, None] = None,
Expand Down Expand Up @@ -135,9 +130,7 @@ def init(
if ":" in conda_slug:
base_conda = conda_slug.split(":")[0]
conda_slug = self._get_latest_conda_pack(
self.config["prefix"],
self.config["python_version"],
base_conda
self.config["prefix"], self.config["python_version"], base_conda
)
logger.info(f"Proceeding with the {conda_slug} conda pack.")

Expand Down Expand Up @@ -487,9 +480,9 @@ def run_diagnostics(self, cluster_info, dry_run=False, **kwargs):
print(f"Creating Job with payload: \n{self.job}")
print("+" * 200)

print(f"Creating Main Job Run with following details:")
print("Creating Main Job Run with following details:")
print(f"Name: {main_jobrun_conf['name']}")
print(f"Additional Environment Variables: ")
print("Additional Environment Variables: ")
main_env_Vars = main_jobrun_conf.get("envVars", {})
for k in main_env_Vars:
print(f"\t{k}:{main_env_Vars[k]}")
Expand Down Expand Up @@ -532,17 +525,17 @@ def run(self, cluster_info, dry_run=False) -> None:
"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
)

print(f"Creating Main Job Run with following details:")
print("Creating Main Job Run with following details:")
print(f"Name: {main_jobrun_conf['name']}")
print(f"Additional Environment Variables: ")
print("Additional Environment Variables: ")
main_env_Vars = main_jobrun_conf.get("envVars", {})
for k in main_env_Vars:
print(f"\t{k}:{main_env_Vars[k]}")
print(
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
)
if cluster_info.cluster.worker:
print(f"Creating Job Runs with following details:")
print("Creating Job Runs with following details:")
for i in range(len(worker_jobrun_conf_list)):
worker_jobrun_conf = worker_jobrun_conf_list[i]
print("Name: " + worker_jobrun_conf.get("name"))
Expand Down Expand Up @@ -686,7 +679,7 @@ def _adjust_python_runtime(self):
fp.write(f"python3 -m {self.operator_info.type}")

# copy the operator's source code to the temporary folder
shutil.copytree(
secure_copytree(
self.operator_info.path.rstrip("/"),
os.path.join(temp_dir, self.operator_info.type),
dirs_exist_ok=True,
Expand Down
28 changes: 13 additions & 15 deletions ads/opctl/operator/cmd.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
#!/usr/bin/env python

# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
# Copyright (c) 2023, 2026 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
import re
import runpy
import shutil
import tempfile
from typing import Any, Dict, Union

import fsspec
import yaml
from ads.opctl.operator.common.utils import print_traceback
from tabulate import tabulate

from ads.common import utils as ads_common_utils
Expand All @@ -33,22 +30,23 @@
from ads.opctl.constants import DEFAULT_ADS_CONFIG_FOLDER
from ads.opctl.decorator.common import validate_environment
from ads.opctl.operator.common.const import (
OPERATOR_BACKEND_SECTION_NAME,
OPERATOR_BASE_DOCKER_FILE,
OPERATOR_BASE_DOCKER_GPU_FILE,
OPERATOR_BASE_GPU_IMAGE,
OPERATOR_BASE_IMAGE,
OPERATOR_BACKEND_SECTION_NAME,
)
from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader
from ads.opctl.operator.common.utils import print_traceback
from ads.opctl.utils import publish_image as publish_image_cmd
from ads.opctl.utils import secure_copytree

from .__init__ import __operators__
from .common import utils as operator_utils
from .common.backend_factory import BackendFactory
from .common.errors import (
InvalidParameterError,
OperatorCondaNotFoundError,
OperatorImageNotFoundError,
InvalidParameterError,
)
from .common.operator_loader import _operator_info_list

Expand Down Expand Up @@ -105,7 +103,7 @@ def info(
readme_file_path = os.path.join(operator_info.path, "README.md")

if os.path.exists(readme_file_path):
with open(readme_file_path, "r") as readme_file:
with open(readme_file_path) as readme_file:
operator_readme = readme_file.read()

console.print(
Expand Down Expand Up @@ -153,7 +151,7 @@ def init(
"""
# validation
if not type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

# load operator info
operator_info: OperatorInfo = OperatorLoader.from_uri(uri=type).load()
Expand Down Expand Up @@ -255,7 +253,7 @@ def build_image(

# validation
if not type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

# load operator info
operator_info: OperatorInfo = OperatorLoader.from_uri(uri=type).load()
Expand Down Expand Up @@ -294,7 +292,7 @@ def build_image(
)

with tempfile.TemporaryDirectory() as td:
shutil.copytree(operator_info.path, os.path.join(td, "operator"))
secure_copytree(operator_info.path, os.path.join(td, "operator"))

run_command = [
f"FROM {base_image_name}",
Expand Down Expand Up @@ -359,7 +357,7 @@ def publish_image(

# validation
if not type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

client = docker.from_env()

Expand Down Expand Up @@ -410,7 +408,7 @@ def verify(

# validation
if not operator_type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

# load operator info
operator_info: OperatorInfo = OperatorLoader.from_uri(uri=operator_type).load()
Expand Down Expand Up @@ -475,7 +473,7 @@ def build_conda(

# validation
if not type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

# load operator info
operator_info: OperatorInfo = OperatorLoader.from_uri(uri=type).load()
Expand Down Expand Up @@ -529,7 +527,7 @@ def publish_conda(

# validation
if not type:
raise ValueError(f"The `type` attribute must be specified.")
raise ValueError("The `type` attribute must be specified.")

# load operator info
operator_info: OperatorInfo = OperatorLoader.from_uri(uri=type).load()
Expand Down
65 changes: 48 additions & 17 deletions ads/opctl/utils.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,32 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Copyright (c) 2022, 2026 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import functools
import logging
import os
import re
import shlex
import shutil
import subprocess
import sys
import shlex
import urllib.parse
from subprocess import Popen, PIPE, STDOUT
from typing import Union, List, Tuple, Dict
from pathlib import Path
from subprocess import PIPE, STDOUT, Popen
from typing import Dict, List, Tuple, Union

import yaml
import re

import ads
from ads.common.oci_client import OCIClientFactory
from ads.opctl import logger
from ads.opctl.constants import (
ML_JOB_IMAGE,
ML_JOB_GPU_IMAGE,
)
from ads.common.decorator.runtime_dependency import (
runtime_dependency,
OptionalDependency,
runtime_dependency,
)

from ads.common.oci_client import OCIClientFactory
from ads.opctl import logger
from ads.opctl.constants import ML_JOB_GPU_IMAGE, ML_JOB_IMAGE

CONTAINER_NETWORK = "CONTAINER_NETWORK"

Expand Down Expand Up @@ -67,7 +65,7 @@ def parse_conda_uri(uri: str) -> Tuple[str, str, str, str]:

def list_ads_operators() -> dict:
curr_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(curr_dir, "index.yaml"), "r") as f:
with open(os.path.join(curr_dir, "index.yaml")) as f:
ads_operators = yaml.safe_load(f.read())
return ads_operators or []

Expand Down Expand Up @@ -154,7 +152,7 @@ def build_image(image_type: str, gpu: bool = False) -> None:
# Just get the manufacturer of the processors
manufacturer = cpuinfo.get_cpu_info().get("brand_raw")
arch = (
"arm" if re.search("apple m\d", manufacturer, re.IGNORECASE) else "other"
"arm" if re.search(r"apple m\d", manufacturer, re.IGNORECASE) else "other"
)
print(f"The local machine's platform is {arch}.")
image, dockerfile, target = _get_image_name_dockerfile_target(
Expand Down Expand Up @@ -207,6 +205,39 @@ def run_command(
return proc


def _ensure_no_symlinks(root: Path) -> None:
    """Raise if the tree rooted at ``root`` contains any symbolic link.

    Walks every entry below ``root`` and fails fast on the first
    symlink found.

    Parameters
    ----------
    root: Path
        Directory whose contents are scanned.

    Raises
    ------
    RuntimeError
        If any entry under ``root`` is a symbolic link.
    """
    symlinks = (item for item in root.rglob("*") if item.is_symlink())
    for entry in symlinks:
        raise RuntimeError(f"Symbolic links are not allowed inside {root}: {entry}")


def secure_copytree(src: Union[str, Path], dst: Union[str, Path], **kwargs) -> None:
    """Safely copies a directory tree without following symbolic links.

    The source tree is validated first: it must exist, be a directory,
    and contain no symlinks anywhere below it. The copy itself runs with
    ``symlinks=True`` so even a link racing in after the scan would be
    reproduced as a link rather than followed.

    Parameters
    ----------
    src: Union[str, Path]
        Source directory to copy from.
    dst: Union[str, Path]
        Destination directory to copy to.
    kwargs: dict
        Additional arguments forwarded to ``shutil.copytree``.
    """
    # resolve(strict=True) raises FileNotFoundError for a missing source.
    source = Path(src).resolve(strict=True)
    if not source.is_dir():
        raise ValueError(f"Source path must be a directory: {src}")

    _ensure_no_symlinks(source)

    # Conservative defaults; any of them may be overridden by the caller.
    options = {
        "symlinks": True,
        "ignore_dangling_symlinks": True,
        "dirs_exist_ok": False,
    }
    options.update(kwargs)
    shutil.copytree(source, Path(dst), **options)


class _DebugTraceback:
def __init__(self, debug):
self.cur_logging_level = logger.getEffectiveLevel()
Expand Down Expand Up @@ -327,7 +358,7 @@ def run_container(
detach=True,
entrypoint=entrypoint,
user=0,
**kwargs
**kwargs,
# auto_remove=True,
)
logger.info("Container ID: %s", container.id)
Expand Down
Loading
Loading