Skip to content

Commit 7ba66e2

Browse files
committed
feat: use obstore for azure blob storage (abfs)
Signed-off-by: machichima <[email protected]>
1 parent caaa657 commit 7ba66e2

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

flytekit/core/data_persistence.py

+27-10
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
from fsspec.asyn import AsyncFileSystem
3434
from fsspec.utils import get_protocol
3535
from obstore.fsspec import AsyncFsspecStore
36-
from obstore.store import GCSStore, S3Store
36+
from obstore.store import GCSStore, S3Store, AzureStore
3737
from typing_extensions import Unpack
3838

3939
from flytekit import configuration
@@ -132,28 +132,41 @@ def split_path(path: str) -> Tuple[str, str]:
132132
bucket = path_li[0]
133133
# use obstore for s3, gcs and azure blob (abfs) only now, no need to split
134134
# bucket out of path for other storage
135-
support_types = ["s3", "gs"]
135+
support_types = ["s3", "gs", "abfs"]
136136
if protocol in support_types:
137137
file_path = "/".join(path_li[1:])
138138
return (bucket, file_path)
139139
else:
140140
return bucket, path
141141

142142

143-
def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]:
143+
def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, container: str = "", anonymous: bool = False) -> Dict[str, Any]:
144144
kwargs: Dict[str, Any] = {}
145+
store_kwargs: Dict[str, Any] = {}
145146

146147
if azure_cfg.account_name:
147-
kwargs["account_name"] = azure_cfg.account_name
148+
store_kwargs["account_name"] = azure_cfg.account_name
148149
if azure_cfg.account_key:
149-
kwargs["account_key"] = azure_cfg.account_key
150+
store_kwargs["account_key"] = azure_cfg.account_key
150151
if azure_cfg.client_id:
151-
kwargs["client_id"] = azure_cfg.client_id
152+
store_kwargs["client_id"] = azure_cfg.client_id
152153
if azure_cfg.client_secret:
153-
kwargs["client_secret"] = azure_cfg.client_secret
154+
store_kwargs["client_secret"] = azure_cfg.client_secret
154155
if azure_cfg.tenant_id:
155-
kwargs["tenant_id"] = azure_cfg.tenant_id
156-
kwargs[_ANON] = anonymous
156+
store_kwargs["tenant_id"] = azure_cfg.tenant_id
157+
158+
store = AzureStore.from_env(
159+
container,
160+
config={
161+
**store_kwargs,
162+
},
163+
)
164+
165+
kwargs["store"] = store
166+
167+
if anonymous:
168+
kwargs[_ANON] = True
169+
157170
return kwargs
158171

159172

@@ -273,7 +286,10 @@ def get_filesystem(
273286
gskwargs = gs_setup_args(self._data_config.gcs, bucket, anonymous=anonymous)
274287
gskwargs.update(kwargs)
275288
return fsspec.filesystem(protocol, **gskwargs) # type: ignore
276-
# TODO: add azure
289+
elif protocol == "abfs":
290+
azkwargs = azure_setup_args(self._data_config.azure, bucket, anonymous=anonymous)
291+
azkwargs.update(kwargs)
292+
return fsspec.filesystem(protocol, **azkwargs) # type: ignore
277293
elif protocol == "ftp":
278294
kwargs.update(fsspec.implementations.ftp.FTPFileSystem._get_kwargs_from_urls(path))
279295
return fsspec.filesystem(protocol, **kwargs)
@@ -713,6 +729,7 @@ async def async_put_data(
713729

714730
fsspec.register_implementation("s3", AsyncFsspecStore)
715731
fsspec.register_implementation("gs", AsyncFsspecStore)
732+
fsspec.register_implementation("abfs", AsyncFsspecStore)
716733

717734
flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
718735
default_local_file_access_provider = FileAccessProvider(

0 commit comments

Comments
 (0)