diff --git a/Dockerfile.dev b/Dockerfile.dev
index 1a839104e4..ea51253710 100644
--- a/Dockerfile.dev
+++ b/Dockerfile.dev
@@ -40,6 +40,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
     -e /flytekit \
     -e /flytekit/plugins/flytekit-deck-standard \
     -e /flytekit/plugins/flytekit-flyteinteractive \
+    obstore==0.6.0 \
     markdown \
     pandas \
     pillow \
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index 25d1ddc688..59301fa169 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -24,6 +24,7 @@
 import pathlib
 import tempfile
 import typing
+from datetime import timedelta
 from time import sleep
 from typing import Any, Dict, Optional, Union, cast
 from uuid import UUID
@@ -32,23 +33,28 @@
 from decorator import decorator
 from fsspec.asyn import AsyncFileSystem
 from fsspec.utils import get_protocol
+from obstore.exceptions import GenericError
+from obstore.fsspec import register
 from typing_extensions import Unpack
 
 from flytekit import configuration
 from flytekit.configuration import DataConfig
 from flytekit.core.local_fsspec import FlyteLocalFileSystem
 from flytekit.core.utils import timeit
-from flytekit.exceptions.system import FlyteDownloadDataException, FlyteUploadDataException
+from flytekit.exceptions.system import (
+    FlyteDownloadDataException,
+    FlyteUploadDataException,
+)
 from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException
 from flytekit.interfaces.random import random
 from flytekit.loggers import logger
 from flytekit.utils.asyn import loop_manager
 
-# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198
+# Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
 # for key and secret
-_FSSPEC_S3_KEY_ID = "key"
-_FSSPEC_S3_SECRET = "secret"
-_ANON = "anon"
+_FSSPEC_S3_KEY_ID = "access_key_id"
+_FSSPEC_S3_SECRET = "secret_access_key"
+_SKIP_SIGNATURE = "skip_signature"
 
 Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO]
 
@@ -57,58 +63,102 @@
 _WRITE_SIZE_CHUNK_BYTES = int(os.environ.get("_F_P_WRITE_CHUNK_SIZE", "26214400"))  # 25 * 2**20
 
 
-def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
-    kwargs: Dict[str, Any] = {
-        "cache_regions": True,
-    }
-    if s3_cfg.access_key_id:
-        kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id
+def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False, **kwargs) -> Dict[str, Any]:
+    """
+    Set up S3 storage options; a bucket is needed to create the obstore store object.
+    """
+
+    config: Dict[str, Any] = {}
 
-    if s3_cfg.secret_access_key:
-        kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key
+    if _FSSPEC_S3_KEY_ID in kwargs or s3_cfg.access_key_id:
+        config[_FSSPEC_S3_KEY_ID] = kwargs.pop(_FSSPEC_S3_KEY_ID, s3_cfg.access_key_id)
+    if _FSSPEC_S3_SECRET in kwargs or s3_cfg.secret_access_key:
+        config[_FSSPEC_S3_SECRET] = kwargs.pop(_FSSPEC_S3_SECRET, s3_cfg.secret_access_key)
+    if "endpoint_url" in kwargs or s3_cfg.endpoint:
+        config["endpoint_url"] = kwargs.pop("endpoint_url", s3_cfg.endpoint)
 
-    # S3fs takes this as a special arg
-    if s3_cfg.endpoint is not None:
-        kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint}
+    retries = kwargs.pop("retries", s3_cfg.retries)
+    backoff = kwargs.pop("backoff", s3_cfg.backoff)
 
     if anonymous:
-        kwargs[_ANON] = True
+        config[_SKIP_SIGNATURE] = True
+
+    retry_config = {
"max_retries": retries, + "backoff": { + "base": 2, + "init_backoff": backoff, + "max_backoff": timedelta(seconds=16), + }, + "retry_timeout": timedelta(minutes=3), + } + + client_options = {"timeout": "99999s", "allow_http": True} + + if config: + kwargs["config"] = config + kwargs["client_options"] = client_options or None + kwargs["retry_config"] = retry_config or None return kwargs -def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]: - kwargs: Dict[str, Any] = {} - - if azure_cfg.account_name: - kwargs["account_name"] = azure_cfg.account_name - if azure_cfg.account_key: - kwargs["account_key"] = azure_cfg.account_key - if azure_cfg.client_id: - kwargs["client_id"] = azure_cfg.client_id - if azure_cfg.client_secret: - kwargs["client_secret"] = azure_cfg.client_secret - if azure_cfg.tenant_id: - kwargs["tenant_id"] = azure_cfg.tenant_id - kwargs[_ANON] = anonymous +def azure_setup_args( + azure_cfg: configuration.AzureBlobStorageConfig, + anonymous: bool = False, + **kwargs, +) -> Dict[str, Any]: + """ + Setup azure blob storage, bucket is needed to create obstore store object + """ + + config: Dict[str, Any] = {} + + if "account_name" in kwargs or azure_cfg.account_name: + config["account_name"] = kwargs.get("account_name", azure_cfg.account_name) + if "account_key" in kwargs or azure_cfg.account_key: + config["account_key"] = kwargs.get("account_key", azure_cfg.account_key) + if "client_id" in kwargs or azure_cfg.client_id: + config["client_id"] = kwargs.get("client_id", azure_cfg.client_id) + if "client_secret" in kwargs or azure_cfg.client_secret: + config["client_secret"] = kwargs.get("client_secret", azure_cfg.client_secret) + if "tenant_id" in kwargs or azure_cfg.tenant_id: + config["tenant_id"] = kwargs.get("tenant_id", azure_cfg.tenant_id) + + if anonymous: + config[_SKIP_SIGNATURE] = True + + client_options = {"timeout": "99999s", "allow_http": "true"} + + if config: + kwargs["config"] = config + kwargs["client_options"] = client_options + return kwargs def get_fsspec_storage_options( - protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs + protocol: str, + data_config: typing.Optional[DataConfig] = None, + anonymous: bool = False, + **kwargs, ) -> Dict[str, Any]: data_config = data_config or DataConfig.auto() if protocol == "file": return {"auto_mkdir": True, **kwargs} if protocol == "s3": - return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs} + return { + **s3_setup_args(data_config.s3, anonymous=anonymous, **kwargs), + **kwargs, + } if protocol == "gs": - if anonymous: - kwargs["token"] = _ANON return kwargs if protocol in ("abfs", "abfss"): - return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs} + return { + **azure_setup_args(data_config.azure, anonymous=anonymous, **kwargs), + **kwargs, + } return {} @@ -222,19 +272,24 @@ def get_filesystem( kwargs["auto_mkdir"] = True return FlyteLocalFileSystem(**kwargs) elif protocol == "s3": - s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous, **kwargs) s3kwargs.update(kwargs) return fsspec.filesystem(protocol, **s3kwargs) # type: ignore elif protocol == "gs": - if anonymous: - kwargs["token"] = _ANON return fsspec.filesystem(protocol, **kwargs) # type: ignore + elif protocol in ("abfs", "abfss"): + azkwargs = azure_setup_args(self._data_config.azure, anonymous=anonymous, **kwargs) + 
+            azkwargs.update(kwargs)
+            return fsspec.filesystem(protocol, **azkwargs)  # type: ignore
         elif protocol == "ftp":
             kwargs.update(fsspec.implementations.ftp.FTPFileSystem._get_kwargs_from_urls(path))
             return fsspec.filesystem(protocol, **kwargs)
 
         storage_options = get_fsspec_storage_options(
-            protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
+            protocol=protocol,
+            anonymous=anonymous,
+            data_config=self._data_config,
+            **kwargs,
         )
         kwargs.update(storage_options)
 
@@ -246,7 +301,14 @@ async def get_async_filesystem_for_path(
         protocol = get_protocol(path)
         loop = asyncio.get_running_loop()
 
-        return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs)
+        return self.get_filesystem(
+            protocol,
+            anonymous=anonymous,
+            path=path,
+            asynchronous=True,
+            loop=loop,
+            **kwargs,
+        )
 
     def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
         protocol = get_protocol(path)
@@ -328,7 +390,9 @@ async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwa
                 import shutil
 
                 return shutil.copytree(
-                    self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
+                    self.strip_file_header(from_path),
+                    self.strip_file_header(to_path),
+                    dirs_exist_ok=True,
                 )
             logger.info(f"Getting {from_path} to {to_path}")
             if isinstance(file_system, AsyncFileSystem):
@@ -338,10 +402,15 @@ async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwa
             if isinstance(dst, (str, pathlib.Path)):
                 return dst
             return to_path
-        except OSError as oe:
+        except (OSError, GenericError) as oe:
             logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
             if isinstance(file_system, AsyncFileSystem):
-                exists = await file_system._exists(from_path)  # pylint: disable=W0212
+                try:
+                    exists = await file_system._exists(from_path)  # pylint: disable=W0212
+                except GenericError:
+                    # obstore surfaces missing objects through fsspec as GenericError rather than
+                    # FileNotFoundError, so assume the path exists and fall through to the
+                    # get_filesystem(anonymous=True) retry below
+                    exists = True
             else:
                 exists = file_system.exists(from_path)
 
             if not exists:
@@ -371,7 +440,9 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
                 import shutil
 
                 return shutil.copytree(
-                    self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
+                    self.strip_file_header(from_path),
+                    self.strip_file_header(to_path),
+                    dirs_exist_ok=True,
                 )
             from_path, to_path = self.recursive_paths(from_path, to_path)
         if self._execution_metadata:
@@ -633,7 +704,11 @@ async def async_get_data(self, remote_path: str, local_path: str, is_multipart:
     get_data = loop_manager.synced(async_get_data)
 
     async def async_put_data(
-        self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs
+        self,
+        local_path: Union[str, os.PathLike],
+        remote_path: str,
+        is_multipart: bool = False,
+        **kwargs,
     ) -> str:
         """
         The implication here is that we're always going to put data to the remote location, so we .remote to ensure
@@ -664,6 +739,9 @@ async def async_put_data(
     put_data = loop_manager.synced(async_put_data)
 
 
+register(["s3", "gs", "abfs", "abfss"], asynchronous=True)
+
+
 flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
 default_local_file_access_provider = FileAccessProvider(
     local_sandbox_dir=os.path.join(flyte_tmp_dir, "sandbox"),
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index f9f5d536a6..88a7896d24 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -37,7 +37,11 @@ def get_pandas_storage_options(
     from pandas.io.common import is_fsspec_url
 
     if is_fsspec_url(uri):
-        return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous)
+        return get_fsspec_storage_options(
+            protocol=get_protocol(uri),
+            data_config=data_config,
+            anonymous=anonymous,
+        )
 
     # Pandas does not allow storage_options for non-fsspec paths e.g. local.
     return None
diff --git a/plugins/flytekit-async-fsspec/setup.py b/plugins/flytekit-async-fsspec/setup.py
index 414658365a..13aca40f51 100644
--- a/plugins/flytekit-async-fsspec/setup.py
+++ b/plugins/flytekit-async-fsspec/setup.py
@@ -4,7 +4,7 @@
 
 microlib_name = "flytekitplugins-async-fsspec"
 
-plugin_requires = ["flytekit"]
+plugin_requires = ["flytekit", "s3fs>=2023.3.0,!=2024.3.1"]
 
 __version__ = "0.0.0+develop"
 
diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index e6359641ca..1924f57d84 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -91,7 +91,6 @@ def decode(
         current_task_metadata: StructuredDatasetMetadata,
     ) -> pl.DataFrame:
         uri = flyte_value.uri
-
         kwargs = get_fsspec_storage_options(
             protocol=fsspec_utils.get_protocol(uri),
             data_config=ctx.file_access.data_config,
@@ -153,7 +152,6 @@ def decode(
         current_task_metadata: StructuredDatasetMetadata,
     ) -> pl.LazyFrame:
         uri = flyte_value.uri
-
         kwargs = get_fsspec_storage_options(
             protocol=fsspec_utils.get_protocol(uri),
             data_config=ctx.file_access.data_config,
diff --git a/pyproject.toml b/pyproject.toml
index 1ca3ad783e..37f5aaa6d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,8 @@ readme = { file = "README.md", content-type = "text/markdown" }
 requires-python = ">=3.9,<3.13"
 dependencies = [
     # Please maintain an alphabetical order in the following list
-    "adlfs>=2023.3.0",
+    "aiohttp>=3.11.13",
+    "botocore>=1.37.15",
     "click>=6.6",
     "cloudpickle>=2.0.0",
     "croniter>=0.3.20",
@@ -22,7 +23,6 @@ dependencies = [
     "docstring-parser>=0.9.0",
     "flyteidl>=1.15.1",
     "fsspec>=2023.3.0",
-    "gcsfs>=2023.3.0",
     "googleapis-common-protos>=1.57",
     # Skipping those versions to account for the unwanted output coming from grpcio and grpcio-status.
     # Issue being tracked in https://github.com/flyteorg/flyte/issues/6082.
@@ -38,6 +38,7 @@ dependencies = [
     "marshmallow-jsonschema>=0.12.0",
     "mashumaro>=3.15",
     "msgpack>=1.1.0",
+    "obstore==0.6.0",
     "protobuf!=4.25.0",
     "pygments",
     "python-json-logger>=2.0.0",
@@ -46,7 +47,6 @@ dependencies = [
     "requests>=2.18.4",
     "rich",
     "rich_click",
-    "s3fs>=2023.3.0,!=2024.3.1",
     "statsd>=3.0.0",
     "typing_extensions",
     "urllib3>=1.22",
diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py
index 8388bb77c6..ae97212e82 100644
--- a/tests/flytekit/unit/bin/test_python_entrypoint.py
+++ b/tests/flytekit/unit/bin/test_python_entrypoint.py
@@ -428,7 +428,7 @@ def test_setup_for_fast_register():
 @mock.patch("google.auth.compute_engine._metadata")
 def test_setup_cloud_prefix(mock_gcs):
     with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx:
-        assert ctx.file_access._default_remote.protocol[0] == "s3"
+        assert ctx.file_access._default_remote.protocol == "s3"
 
     with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx:
         assert "gs" in ctx.file_access._default_remote.protocol
diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py
index 42e74f453c..f69041236f 100644
--- a/tests/flytekit/unit/core/test_data.py
+++ b/tests/flytekit/unit/core/test_data.py
@@ -1,3 +1,4 @@
+from datetime import timedelta
 import os
 import random
 import shutil
@@ -5,14 +6,21 @@ from uuid import UUID
 import typing
 import asyncio
+import base64
 import fsspec
 import mock
 import pytest
-from s3fs import S3FileSystem
+from obstore.fsspec import FsspecStore
 
 from flytekit.configuration import Config, DataConfig, S3Config
 from flytekit.core.context_manager import FlyteContextManager, FlyteContext
-from flytekit.core.data_persistence import FileAccessProvider, get_fsspec_storage_options, s3_setup_args
+from flytekit.core.data_persistence import (
+    FileAccessProvider,
+    get_fsspec_storage_options,
+    s3_setup_args,
+    _FSSPEC_S3_KEY_ID,
+    _FSSPEC_S3_SECRET,
+)
 from flytekit.core.type_engine import TypeEngine
 from flytekit.types.directory.types import FlyteDirectory
 from flytekit.types.file import FlyteFile
@@ -32,15 +40,21 @@ def test_path_getting(mock_uuid_class, mock_gcs):
     # Testing with raw output prefix pointing to a local path
     loc_sandbox = os.path.join(root, "tmp", "unittest")
     loc_data = os.path.join(root, "tmp", "unittestdata")
-    local_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data)
+    local_raw_fp = FileAccessProvider(
+        local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data
+    )
     r = local_raw_fp.get_random_string()
     rr = local_raw_fp.join(local_raw_fp.raw_output_prefix, r)
     assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123")
-    rr = local_raw_fp.join(local_raw_fp.raw_output_prefix, r, local_raw_fp.get_file_tail("/fsa/blah.csv"))
+    rr = local_raw_fp.join(
+        local_raw_fp.raw_output_prefix, r, local_raw_fp.get_file_tail("/fsa/blah.csv")
+    )
     assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123", "blah.csv")
 
     # Test local path and directory
-    assert local_raw_fp.get_random_local_path() == os.path.join(root, "tmp", "unittest", "local_flytekit", "abcdef123")
+    assert local_raw_fp.get_random_local_path() == os.path.join(
+        root, "tmp", "unittest", "local_flytekit", "abcdef123"
+    )
     assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == os.path.join(
         root, "tmp", "unittest", "local_flytekit", "abcdef123", "blah.txt"
     )
@@ -49,20 +63,28 @@ def test_path_getting(mock_uuid_class, mock_gcs):
     )
     # Recursive paths
-    assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths(
+    assert (
+        "file:///abc/happy/",
+        "s3://my-s3-bucket/bucket1/",
+    ) == local_raw_fp.recursive_paths(
         "file:///abc/happy/", "s3://my-s3-bucket/bucket1/"
     )
-    assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths(
+    assert (
+        "file:///abc/happy/",
+        "s3://my-s3-bucket/bucket1/",
+    ) == local_raw_fp.recursive_paths(
         "file:///abc/happy", "s3://my-s3-bucket/bucket1"
     )
 
     # Test with remote pointed to s3.
-    s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket")
+    s3_fa = FileAccessProvider(
+        local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket"
+    )
     r = s3_fa.get_random_string()
     rr = s3_fa.join(s3_fa.raw_output_prefix, r)
     assert rr == "s3://my-s3-bucket/abcdef123"
 
     # trailing slash should make no difference
-    s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/")
+    s3_fa = FileAccessProvider(
+        local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/"
+    )
     r = s3_fa.get_random_string()
     rr = s3_fa.join(s3_fa.raw_output_prefix, r)
     assert rr == "s3://my-s3-bucket/abcdef123"
@@ -70,17 +92,23 @@ def test_path_getting(mock_uuid_class, mock_gcs):
     # Testing with raw output prefix pointing to file://
     # Skip tests for windows
     if os.name != "nt":
-        file_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata")
+        file_raw_fp = FileAccessProvider(
+            local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata"
+        )
         r = file_raw_fp.get_random_string()
         rr = file_raw_fp.join(file_raw_fp.raw_output_prefix, r)
         rr = file_raw_fp.strip_file_header(rr)
         assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123")
 
         r = file_raw_fp.get_random_string()
-        rr = file_raw_fp.join(file_raw_fp.raw_output_prefix, r, file_raw_fp.get_file_tail("/fsa/blah.csv"))
+        rr = file_raw_fp.join(
+            file_raw_fp.raw_output_prefix, r, file_raw_fp.get_file_tail("/fsa/blah.csv")
+        )
         rr = file_raw_fp.strip_file_header(rr)
         assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123", "blah.csv")
 
-    g_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/")
+    g_fa = FileAccessProvider(
+        local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/"
+    )
     r = g_fa.get_random_string()
     rr = g_fa.join(g_fa.raw_output_prefix, r)
     assert rr == "gs://my-s3-bucket/abcdef123"
@@ -119,7 +147,11 @@ async def test_local_provider(source_folder):
     # dest folder exists.
     dc = Config.for_sandbox().data_config
     with tempfile.TemporaryDirectory() as dest_tmpdir:
-        provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc)
+        provider = FileAccessProvider(
+            local_sandbox_dir="/tmp/unittest",
+            raw_output_prefix=dest_tmpdir,
+            data_config=dc,
+        )
         r = provider.get_random_string()
         doesnotexist = provider.join(provider.raw_output_prefix, r)
         await provider.async_put_data(source_folder, doesnotexist, is_multipart=True)
@@ -139,16 +171,22 @@ async def test_async_file_system():
     remote_path = "test:///tmp/test.py"
     local_path = "test.py"
 
-    class MockAsyncFileSystem(S3FileSystem):
+    class MockAsyncFileSystem(FsspecStore):
+        protocol = "test"
+        asynchronous = True
+
         def __init__(self, *args, **kwargs):
-            super().__init__(args, kwargs)
+            super().__init__(*args, **kwargs)
 
         async def _put_file(self, *args, **kwargs):
-            # s3fs._put_file returns None as well
+            # FsspecStore._put_file returns None as well
             return None
 
+        async def _isdir(self, *args, **kwargs):
+            # Report that the target is not a directory
+            return False
+
         async def _get_file(self, *args, **kwargs):
-            # s3fs._get_file returns None as well
+            # FsspecStore._get_file returns None as well
             return None
 
         async def _lsdir(
@@ -176,9 +214,13 @@ def test_s3_provider(source_folder):
     # Running mkdir on s3 filesystem doesn't do anything so leaving out for now
     dc = Config.for_sandbox().data_config
     provider = FileAccessProvider(
-        local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc
+        local_sandbox_dir="/tmp/unittest",
+        raw_output_prefix="s3://my-s3-bucket/testdata/",
+        data_config=dc,
+    )
+    doesnotexist = provider.join(
+        provider.raw_output_prefix, provider.get_random_string()
     )
-    doesnotexist = provider.join(provider.raw_output_prefix, provider.get_random_string())
     provider.put_data(source_folder, doesnotexist, is_multipart=True)
     fs = provider.get_filesystem_for_path(doesnotexist)
     files = fs.find(doesnotexist)
@@ -190,7 +232,9 @@ def test_local_provider_get_empty():
     with tempfile.TemporaryDirectory() as empty_source:
         with tempfile.TemporaryDirectory() as dest_folder:
             provider = FileAccessProvider(
-                local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc
+                local_sandbox_dir="/tmp/unittest",
+                raw_output_prefix=empty_source,
+                data_config=dc,
             )
             provider.get_data(empty_source, dest_folder, is_multipart=True)
             loc = provider.get_filesystem_for_path(dest_folder)
@@ -207,13 +251,14 @@ def test_s3_setup_args_env_empty(mock_os, mock_get_config_file):
     mock_os.get.return_value = None
     s3c = S3Config.auto()
     kwargs = s3_setup_args(s3c)
-    assert kwargs == {"cache_regions": True}
+    assert all(key in kwargs for key in ("client_options", "retry_config"))
 
 
 @mock.patch("flytekit.configuration.get_config_file")
 @mock.patch("os.environ")
 def test_s3_setup_args_env_both(mock_os, mock_get_config_file):
     mock_get_config_file.return_value = None
+
     ee = {
         "AWS_ACCESS_KEY_ID": "ignore-user",
         "AWS_SECRET_ACCESS_KEY": "ignore-secret",
         "FLYTE_AWS_ACCESS_KEY_ID": "flyte",
         "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret",
     }
     mock_os.get.side_effect = lambda x, y: ee.get(x)
     kwargs = s3_setup_args(S3Config.auto())
-    assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True}
+
+    assert kwargs["config"] == {
+        _FSSPEC_S3_KEY_ID: "flyte",
+        _FSSPEC_S3_SECRET: "flyte-secret",
+    }
@@ -235,7 +284,11 @@ def test_s3_setup_args_env_flyte(mock_os, mock_get_config_file):
     ee = {
         "FLYTE_AWS_ACCESS_KEY_ID": "flyte",
         "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret",
     }
     mock_os.get.side_effect = lambda x, y: ee.get(x)
     kwargs = s3_setup_args(S3Config.auto())
-    assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True}
+
+    assert kwargs["config"] == {
+        _FSSPEC_S3_KEY_ID: "flyte",
+        _FSSPEC_S3_SECRET: "flyte-secret",
+    }
 
 
 @mock.patch("flytekit.configuration.get_config_file")
@@ -249,7 +302,7 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file):
     mock_os.get.side_effect = lambda x, y: ee.get(x)
     kwargs = s3_setup_args(S3Config.auto())
     # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default
-    assert kwargs == {"cache_regions": True}
+    assert "config" not in kwargs
 
 
 @mock.patch("flytekit.configuration.get_config_file")
@@ -272,51 +325,61 @@ def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_
         "FLYTE_GCP_GSUTIL_PARALLELISM": "False",
     }
     mock_os.get.side_effect = lambda x, y: ee.get(x)
-    storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value")
-    assert storage_options == {"token": "anon", "other_argument": "value"}
+    storage_options = get_fsspec_storage_options(
+        "gs", DataConfig.auto(), anonymous=True, other_argument="value"
+    )
+    assert "other_argument" in storage_options
 
 
 @mock.patch("flytekit.configuration.get_config_file")
 @mock.patch("os.environ")
 def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file):
     mock_get_config_file.return_value = None
+
+    account_key = "accountkey"
+
+    account_key_base64 = base64.b64encode(account_key.encode()).decode()
+
     ee = {
         "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
-        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
+        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": account_key_base64,
         "FLYTE_AZURE_TENANT_ID": "tenantid",
         "FLYTE_AZURE_CLIENT_ID": "clientid",
         "FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
     }
     mock_os.get.side_effect = lambda x, y: ee.get(x)
     storage_options = get_fsspec_storage_options("abfs", DataConfig.auto())
-    assert storage_options == {
-        "account_name": "accountname",
-        "account_key": "accountkey",
-        "client_id": "clientid",
-        "client_secret": "clientsecret",
-        "tenant_id": "tenantid",
-        "anon": False,
-    }
+
+    assert storage_options["config"]["account_name"] == "accountname"
+    assert storage_options["config"]["account_key"] == account_key_base64
+    assert storage_options["config"]["client_id"] == "clientid"
+    assert storage_options["config"]["client_secret"] == "clientsecret"
+    assert storage_options["config"]["tenant_id"] == "tenantid"
 
 
 @mock.patch("flytekit.configuration.get_config_file")
 @mock.patch("os.environ")
 def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file):
     mock_get_config_file.return_value = None
+
+    account_key = "accountkey"
+    account_key_base64 = base64.b64encode(account_key.encode()).decode()
+
     ee = {
         "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
-        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
+        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": account_key_base64,
     }
     mock_os.get.side_effect = lambda x, y: ee.get(x)
     storage_options = get_fsspec_storage_options(
-        "abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value"
+        "abfs",
+        DataConfig.auto(),
+        anonymous=True,
+        account_name="other_accountname",
+        other_argument="value",
     )
-    assert storage_options == {
-        "account_name": "other_accountname",
-        "account_key": "accountkey",
-        "anon": True,
-        "other_argument": "value",
-    }
+
+    assert storage_options["config"]["account_name"] == "other_accountname"
storage_options["config"]["account_key"] == account_key_base64 + assert storage_options["config"]["skip_signature"] == True def test_crawl_local_nt(source_folder): @@ -352,8 +415,14 @@ def test_crawl_local_non_nt(source_folder): res = fd.crawl() split = [(x, y) for x, y in res] files = [os.path.join(x, y) for x, y in split] - assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} - expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(split) == { + (source_folder, "original.txt"), + (source_folder, os.path.join("nested", "more.txt")), + } + expected = { + os.path.join(source_folder, "original.txt"), + os.path.join(source_folder, "nested", "more.txt"), + } assert set(files) == expected # Test crawling a directory without trailing / or \ @@ -379,12 +448,19 @@ def test_crawl_s3(source_folder): # Running mkdir on s3 filesystem doesn't do anything so leaving out for now dc = Config.for_sandbox().data_config provider = FileAccessProvider( - local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + local_sandbox_dir="/tmp/unittest", + raw_output_prefix="s3://my-s3-bucket/testdata/", + data_config=dc, + ) + s3_random_target = provider.join( + provider.raw_output_prefix, provider.get_random_string() ) - s3_random_target = provider.join(provider.raw_output_prefix, provider.get_random_string()) provider.put_data(source_folder, s3_random_target, is_multipart=True) ctx = FlyteContextManager.current_context() - expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + expected = { + f"{s3_random_target}/original.txt", + f"{s3_random_target}/nested/more.txt", + } with FlyteContextManager.with_context(ctx.with_file_access(provider)): fd = FlyteDirectory(path=s3_random_target) @@ -392,7 +468,10 @@ def test_crawl_s3(source_folder): res = [(x, y) for x, y in res] files = [os.path.join(x, y) for x, y in res] assert set(files) == expected - assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + assert set(res) == { + (s3_random_target, "original.txt"), + (s3_random_target, os.path.join("nested", "more.txt")), + } fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") res = fd_file.crawl() @@ -405,7 +484,11 @@ def test_walk_local_copy_to_s3(source_folder): dc = Config.for_sandbox().data_config explicit_empty_folder = UUID(int=random.getrandbits(128)).hex raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", + raw_output_prefix=raw_output_path, + data_config=dc, + ) ctx = FlyteContextManager.current_context() local_fd = FlyteDirectory(path=source_folder) @@ -433,7 +516,9 @@ def test_s3_metadata(): dc = Config.for_sandbox().data_config random_folder = UUID(int=random.getrandbits(64)).hex raw_output = f"s3://my-s3-bucket/testing/metadata_test/{random_folder}" - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc + ) _, local_zip = tempfile.mkstemp(suffix=".gz") with open(local_zip, "w") as f: f.write("hello world") @@ -454,7 +539,9 @@ def test_s3_metadata(): assert len(files) == 2 
-async def dummy_output_to_literal_map(ctx: FlyteContext, ff: typing.List[FlyteFile]) -> Literal:
+async def dummy_output_to_literal_map(
+    ctx: FlyteContext, ff: typing.List[FlyteFile]
+) -> Literal:
     lt = TypeEngine.to_literal_type(typing.List[FlyteFile])
     lit = await TypeEngine.async_to_literal(ctx, ff, typing.List[FlyteFile], lt)
     return lit
@@ -479,7 +566,9 @@ def test_async_local_copy_to_s3():
     random_folder = UUID(int=random.getrandbits(64)).hex
     raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}"
     print(f"Uploading to {raw_output}")
-    provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc)
+    provider = FileAccessProvider(
+        local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc
+    )
 
     start_time = datetime.datetime.now(datetime.timezone.utc)
     start_wall_time = time.perf_counter()
@@ -522,10 +611,17 @@ def test_async_download_from_s3():
     random_folder = UUID(int=random.getrandbits(64)).hex
     raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}"
     print(f"Uploading to {raw_output}")
-    provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc)
+    provider = FileAccessProvider(
+        local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc
+    )
 
     with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx:
-        lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile]))
+        lit = TypeEngine.to_literal(
+            ctx,
+            ff,
+            typing.List[FlyteFile],
+            TypeEngine.to_literal_type(typing.List[FlyteFile]),
+        )
         print(f"Literal is {lit}")
         python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile])
 
@@ -545,10 +641,17 @@ def test_async_download_from_s3():
     print(f"Time taken (serial download): {end_time - start_time}")
     print(f"Wall time taken (serial download): {end_wall_time - start_wall_time}")
-    print(f"Process time taken (serial download): {end_process_time - start_process_time}")
+    print(
+        f"Process time taken (serial download): {end_process_time - start_process_time}"
+    )
 
     with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx:
-        lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile]))
+        lit = TypeEngine.to_literal(
+            ctx,
+            ff,
+            typing.List[FlyteFile],
+            TypeEngine.to_literal_type(typing.List[FlyteFile]),
+        )
         print(f"Literal is {lit}")
         python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile])
diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py
index ab56f5d07d..f83bfd6f26 100644
--- a/tests/flytekit/unit/core/test_data_persistence.py
+++ b/tests/flytekit/unit/core/test_data_persistence.py
@@ -9,7 +9,7 @@
 import fsspec
 import mock
 import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
+import base64
 from mock import AsyncMock
 
 from flytekit.configuration import Config
@@ -155,14 +155,17 @@ def test_generate_new_custom_path():
 
 
 def test_initialise_azure_file_provider_with_account_key():
+    account_key = "accountkey"
+    account_key_base64 = base64.b64encode(account_key.encode()).decode()
+
     with mock.patch.dict(
         os.environ,
-        {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"},
+        {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": account_key_base64},
     ):
FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert fp.get_filesystem().account_key == "accountkey" - assert fp.get_filesystem().sync_credential is None + + assert fp.get_filesystem().config["account_name"] == "accountname" + assert fp.get_filesystem().config["account_key"] == account_key_base64 def test_initialise_azure_file_provider_with_service_principal(): @@ -176,11 +179,11 @@ def test_initialise_azure_file_provider_with_service_principal(): }, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential) - assert fp.get_filesystem().client_secret == "clientsecret" - assert fp.get_filesystem().client_id == "clientid" - assert fp.get_filesystem().tenant_id == "tenantid" + + assert fp.get_filesystem().config["account_name"] == "accountname" + assert fp.get_filesystem().config["client_secret"] == "clientsecret" + assert fp.get_filesystem().config["client_id"] == "clientid" + assert fp.get_filesystem().config["tenant_id"] == "tenantid" def test_initialise_azure_file_provider_with_default_credential(): @@ -192,8 +195,8 @@ def test_initialise_azure_file_provider_with_default_credential(): }, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) + + assert fp.get_filesystem().config["account_name"] == "accountname" def test_get_file_system(): diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index fdb12e1dae..0ba355b8e4 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -319,11 +319,11 @@ def test_directory_guess(): assert fft.extension() == "" -@mock.patch("s3fs.core.S3FileSystem._lsdir") +@mock.patch("obstore.fsspec.FsspecStore.listdir") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -def test_list_dir(mock_get_data, mock_lsdir): +def test_list_dir(mock_get_data, mock_ls): remote_dir = "s3://test-flytedir" - mock_lsdir.return_value = [ + mock_ls.return_value = [ {"name": os.path.join(remote_dir, "file1.txt"), "type": "file"}, {"name": os.path.join(remote_dir, "file2.txt"), "type": "file"}, {"name": os.path.join(remote_dir, "subdir"), "type": "directory"}, diff --git a/tests/flytekit/unit/remote/test_fs_remote.py b/tests/flytekit/unit/remote/test_fs_remote.py index 5c635376b4..274ad7bc45 100644 --- a/tests/flytekit/unit/remote/test_fs_remote.py +++ b/tests/flytekit/unit/remote/test_fs_remote.py @@ -7,6 +7,7 @@ import fsspec import pytest from fsspec.implementations.http import HTTPFileSystem +from obstore.fsspec import register from flytekit.configuration import Config from flytekit.core.data_persistence import FileAccessProvider @@ -118,6 +119,8 @@ def test_remote_upload_with_data_persistence(sandbox_remote): @pytest.mark.parametrize("url_prefix", ["s3://my-s3-bucket", "abfs://my-azure-container", "abfss://my-azure-container", "gcs://my-gcs-bucket"]) def test_common_matching(url_prefix): + # ensure all url_prefix are registered + register(["s3", "abfs", "abfss", "gcs"]) urls = [ url_prefix + url_suffix for url_suffix in [