Skip to content

Commit c9c0a3e

Browse files
authored
Aistudio 使用 SDK 进行下载 (#10678)
* aistudio SDK * test: transformer download util * style: pre-commit check * test * fix: remove redundant parameter * Update aistudio-sdk requirement
1 parent 7f5d719 commit c9c0a3e

4 files changed

Lines changed: 108 additions & 69 deletions

File tree

paddlenlp/transformers/aistudio_utils.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from typing import Optional
1616

17-
from aistudio_sdk.hub import download
17+
from aistudio_sdk.file_download import model_file_download as download
18+
from requests import HTTPError
1819

1920

2021
class UnauthorizedError(Exception):
@@ -46,22 +47,24 @@ def aistudio_download(
4647
if revision is not None:
4748
download_kwargs["revision"] = revision
4849
if cache_dir is not None:
49-
download_kwargs["cache_dir"] = cache_dir
50-
res = download(
51-
repo_id=repo_id,
52-
filename=filename,
53-
**download_kwargs,
54-
)
55-
if "path" in res:
56-
return res["path"]
57-
else:
58-
if res["error_code"] == 10001:
59-
raise ValueError("Illegal argument error")
60-
elif res["error_code"] == 10002:
61-
raise UnauthorizedError(
62-
"Unauthorized Access. Please ensure that you have provided the AIStudio Access Token and you have access to the requested asset"
63-
)
64-
elif res["error_code"] == 12001:
65-
raise EntryNotFoundError(f"Cannot find the requested file '{filename}' in repo '{repo_id}'")
66-
else:
67-
raise Exception(f"Unknown error: {res}")
50+
download_kwargs["local_dir"] = cache_dir
51+
52+
try:
53+
return download(
54+
repo_id=repo_id,
55+
file_path=filename,
56+
**download_kwargs,
57+
)
58+
except ValueError:
59+
raise EnvironmentError(
60+
f"Cannot find {filename} in the cached files and it looks like {repo_id} is not the path to a directory containing the {filename} or"
61+
" \nCheckout your internet connection or see how to run the library in offline mode."
62+
)
63+
except EntryNotFoundError:
64+
raise EnvironmentError(
65+
f"Cannot find the requested file {filename} in {repo_id}, please make sure the {filename} under the repo {repo_id}"
66+
)
67+
except HTTPError as err:
68+
raise EnvironmentError(f"There was a specific connection error when trying to load {repo_id}:\n{err}")
69+
except Exception:
70+
raise EnvironmentError(f"Please make sure the {filename} under the repo {repo_id}")

paddlenlp/utils/download/__init__.py

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@
3030
from paddle import __version__
3131
from requests import HTTPError
3232

33-
from .aistudio_hub_download import (
34-
aistudio_hub_download,
35-
aistudio_hub_file_exists,
36-
aistudio_hub_try_to_load_from_cache,
37-
)
3833
from .bos_download import bos_download, bos_file_exists, bos_try_to_load_from_cache
3934

4035

@@ -146,8 +141,8 @@ def resolve_file_path(
146141

147142
# check cache
148143
for filename in filenames:
149-
cache_file_name = bos_aistudio_hf_try_to_load_from_cache(
150-
repo_id, filename, cache_dir, subfolder, revision, repo_type, from_bos, from_aistudio, from_hf_hub
144+
cache_file_name = bos_hf_try_to_load_from_cache(
145+
repo_id, filename, cache_dir, subfolder, revision, repo_type, from_bos, from_hf_hub
151146
)
152147
if from_hf_hub and cache_file_name is _CACHED_NO_EXIST:
153148
cache_file_name = None
@@ -187,32 +182,25 @@ def resolve_file_path(
187182
return None
188183

189184
elif from_aistudio:
190-
log_endpoint = "Aistudio Hub"
191-
for filename in filenames:
192-
download_kwargs["filename"] = filename
193-
is_available = bos_aistudio_hf_file_exist(
194-
repo_id,
195-
filename,
196-
subfolder=subfolder,
197-
repo_type=repo_type,
198-
revision=revision,
199-
token=token,
200-
endpoint=endpoint,
201-
from_bos=from_bos,
202-
from_aistudio=from_aistudio,
203-
from_hf_hub=from_hf_hub,
204-
)
205-
if is_available:
206-
cached_file = aistudio_hub_download(
207-
**download_kwargs,
185+
for index, filename in enumerate(filenames):
186+
try:
187+
from aistudio_sdk.file_download import (
188+
model_file_download as aistudio_download,
208189
)
209-
if cached_file is not None:
210-
return cached_file
190+
191+
return aistudio_download(repo_id, filename, revision, local_files_only, local_dir)
192+
except Exception:
193+
if index < len(filenames) - 1:
194+
continue
195+
else:
196+
print(f"please make sure one of the {filenames} under the repo {repo_id}")
197+
return None
198+
211199
elif from_hf_hub:
212200
log_endpoint = "Huggingface Hub"
213201
for filename in filenames:
214202
download_kwargs["filename"] = filename
215-
is_available = bos_aistudio_hf_file_exist(
203+
is_available = bos_hf_file_exist(
216204
repo_id,
217205
filename,
218206
subfolder=subfolder,
@@ -221,7 +209,6 @@ def resolve_file_path(
221209
token=token,
222210
endpoint=endpoint,
223211
from_bos=from_bos,
224-
from_aistudio=from_aistudio,
225212
from_hf_hub=from_hf_hub,
226213
)
227214
if is_available:
@@ -235,7 +222,7 @@ def resolve_file_path(
235222
download_kwargs["url"] = url
236223
for filename in filenames:
237224
download_kwargs["filename"] = filename
238-
is_available = bos_aistudio_hf_file_exist(
225+
is_available = bos_hf_file_exist(
239226
repo_id,
240227
filename,
241228
subfolder=subfolder,
@@ -244,7 +231,6 @@ def resolve_file_path(
244231
token=token,
245232
endpoint=endpoint,
246233
from_bos=from_bos,
247-
from_aistudio=from_aistudio,
248234
from_hf_hub=from_hf_hub,
249235
)
250236
if is_available:
@@ -291,7 +277,7 @@ def resolve_file_path(
291277
)
292278

293279

294-
def bos_aistudio_hf_file_exist(
280+
def bos_hf_file_exist(
295281
repo_id: str,
296282
filename: str,
297283
*,
@@ -301,7 +287,6 @@ def bos_aistudio_hf_file_exist(
301287
token: Optional[str] = None,
302288
endpoint: Optional[str] = None,
303289
from_bos: bool = True,
304-
from_aistudio: bool = False,
305290
from_hf_hub: bool = False,
306291
):
307292
assert repo_id is not None, "repo_id cannot be None"
@@ -310,16 +295,7 @@ def bos_aistudio_hf_file_exist(
310295
if subfolder is None:
311296
subfolder = ""
312297
filename = os.path.join(subfolder, filename)
313-
if from_aistudio:
314-
out = aistudio_hub_file_exists(
315-
repo_id=repo_id,
316-
filename=filename,
317-
repo_type=repo_type,
318-
revision=revision,
319-
token=token,
320-
endpoint=endpoint,
321-
)
322-
elif from_hf_hub:
298+
if from_hf_hub:
323299
out = hf_hub_file_exists(
324300
repo_id=repo_id,
325301
filename=filename,
@@ -339,15 +315,14 @@ def bos_aistudio_hf_file_exist(
339315
return out
340316

341317

342-
def bos_aistudio_hf_try_to_load_from_cache(
318+
def bos_hf_try_to_load_from_cache(
343319
repo_id: str,
344320
filename: str,
345321
cache_dir: Union[str, Path, None] = None,
346322
subfolder: str = None,
347323
revision: Optional[str] = None,
348324
repo_type: Optional[str] = None,
349325
from_bos: bool = True,
350-
from_aistudio: bool = False,
351326
from_hf_hub: bool = False,
352327
):
353328
if subfolder is None:
@@ -359,9 +334,7 @@ def bos_aistudio_hf_try_to_load_from_cache(
359334
revision=revision,
360335
repo_type=repo_type,
361336
)
362-
if from_aistudio:
363-
return aistudio_hub_try_to_load_from_cache(**load_kwargs)
364-
elif from_hf_hub:
337+
if from_hf_hub:
365338
return hf_hub_try_to_load_from_cache(**load_kwargs)
366339
else:
367340
return bos_try_to_load_from_cache(**load_kwargs)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ typer
2121
rich
2222
safetensors
2323
fast_dataindex>=0.1.1 ; platform_system == "Linux"
24-
aistudio-sdk==0.2.6
24+
aistudio-sdk>=0.3.0
2525
jinja2
2626
regex
2727
numpy<=1.26.4
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from paddlenlp.transformers.aistudio_utils import aistudio_download
18+
from paddlenlp.utils.download import resolve_file_path
19+
20+
21+
class TestAistudioDownload(unittest.TestCase):
22+
def test_aistudio_download(self):
23+
# 设置测试数据
24+
repo_id = "PaddleNLP/DeepSeek-R1-Distill-Qwen-1.5B"
25+
filename = "model.safetensors"
26+
revision = "master"
27+
local_dir = "./local/model"
28+
29+
# 调用待测试的函数
30+
result = resolve_file_path(
31+
repo_id=repo_id,
32+
filenames=filename,
33+
revision=revision,
34+
from_aistudio=True,
35+
local_dir=local_dir,
36+
)
37+
38+
# 验证结果
39+
print(result)
40+
self.assertEqual(result, f"{local_dir}/{filename}")
41+
42+
def test_aistudio_download_transformer(self):
43+
# 设置测试数据
44+
repo_id = "PaddleNLP/DeepSeek-R1-Distill-Qwen-1.5B"
45+
filename = "model.safetensors"
46+
revision = "master"
47+
cache_dir = "./local/model"
48+
49+
# 调用待测试的函数
50+
result = aistudio_download(
51+
repo_id=repo_id,
52+
filename=filename,
53+
revision=revision,
54+
cache_dir=cache_dir,
55+
)
56+
57+
# 验证结果
58+
print(result)
59+
self.assertEqual(result, f"{cache_dir}/{filename}")
60+
61+
62+
if __name__ == "__main__":
63+
unittest.main()

0 commit comments

Comments
 (0)