|
33 | 33 | from fsspec.asyn import AsyncFileSystem
|
34 | 34 | from fsspec.utils import get_protocol
|
35 | 35 | from obstore.fsspec import AsyncFsspecStore
|
36 |
| -from obstore.store import GCSStore, S3Store |
| 36 | +from obstore.store import GCSStore, S3Store, AzureStore |
37 | 37 | from typing_extensions import Unpack
|
38 | 38 |
|
39 | 39 | from flytekit import configuration
|
@@ -132,28 +132,41 @@ def split_path(path: str) -> Tuple[str, str]:
|
132 | 132 | bucket = path_li[0]
|
133 | 133 | # use obstore for s3 and gcs only now, no need to split
|
134 | 134 | # bucket out of path for other storage
|
135 |
| - support_types = ["s3", "gs"] |
| 135 | + support_types = ["s3", "gs", "abfs"] |
136 | 136 | if protocol in support_types:
|
137 | 137 | file_path = "/".join(path_li[1:])
|
138 | 138 | return (bucket, file_path)
|
139 | 139 | else:
|
140 | 140 | return bucket, path
|
141 | 141 |
|
142 | 142 |
|
143 |
| -def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]: |
| 143 | +def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, container: str = "", anonymous: bool = False) -> Dict[str, Any]: |
144 | 144 | kwargs: Dict[str, Any] = {}
|
| 145 | + store_kwargs: Dict[str, Any] = {} |
145 | 146 |
|
146 | 147 | if azure_cfg.account_name:
|
147 |
| - kwargs["account_name"] = azure_cfg.account_name |
| 148 | + store_kwargs["account_name"] = azure_cfg.account_name |
148 | 149 | if azure_cfg.account_key:
|
149 |
| - kwargs["account_key"] = azure_cfg.account_key |
| 150 | + store_kwargs["account_key"] = azure_cfg.account_key |
150 | 151 | if azure_cfg.client_id:
|
151 |
| - kwargs["client_id"] = azure_cfg.client_id |
| 152 | + store_kwargs["client_id"] = azure_cfg.client_id |
152 | 153 | if azure_cfg.client_secret:
|
153 |
| - kwargs["client_secret"] = azure_cfg.client_secret |
| 154 | + store_kwargs["client_secret"] = azure_cfg.client_secret |
154 | 155 | if azure_cfg.tenant_id:
|
155 |
| - kwargs["tenant_id"] = azure_cfg.tenant_id |
156 |
| - kwargs[_ANON] = anonymous |
| 156 | + store_kwargs["tenant_id"] = azure_cfg.tenant_id |
| 157 | + |
| 158 | + store = AzureStore.from_env( |
| 159 | + container, |
| 160 | + config={ |
| 161 | + **store_kwargs, |
| 162 | + }, |
| 163 | + ) |
| 164 | + |
| 165 | + kwargs["store"] = store |
| 166 | + |
| 167 | + if anonymous: |
| 168 | + kwargs[_ANON] = True |
| 169 | + |
157 | 170 | return kwargs
|
158 | 171 |
|
159 | 172 |
|
@@ -273,7 +286,10 @@ def get_filesystem(
|
273 | 286 | gskwargs = gs_setup_args(self._data_config.gcs, bucket, anonymous=anonymous)
|
274 | 287 | gskwargs.update(kwargs)
|
275 | 288 | return fsspec.filesystem(protocol, **gskwargs) # type: ignore
|
276 |
| - # TODO: add azure |
| 289 | + elif protocol == "abfs": |
| 290 | + azkwargs = azure_setup_args(self._data_config.azure, bucket, anonymous=anonymous) |
| 291 | + azkwargs.update(kwargs) |
| 292 | + return fsspec.filesystem(protocol, **azkwargs) # type: ignore |
277 | 293 | elif protocol == "ftp":
|
278 | 294 | kwargs.update(fsspec.implementations.ftp.FTPFileSystem._get_kwargs_from_urls(path))
|
279 | 295 | return fsspec.filesystem(protocol, **kwargs)
|
@@ -713,6 +729,7 @@ async def async_put_data(
|
713 | 729 |
|
714 | 730 | fsspec.register_implementation("s3", AsyncFsspecStore)
|
715 | 731 | fsspec.register_implementation("gs", AsyncFsspecStore)
|
| 732 | +fsspec.register_implementation("abfs", AsyncFsspecStore) |
716 | 733 |
|
717 | 734 | flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
|
718 | 735 | default_local_file_access_provider = FileAccessProvider(
|
|
0 commit comments