-
Notifications
You must be signed in to change notification settings - Fork 963
Expand file tree
/
Copy pathtest_xet_upload.py
More file actions
509 lines (426 loc) · 21.3 KB
/
test_xet_upload.py
File metadata and controls
509 lines (426 loc) · 21.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from huggingface_hub import HfApi, RepoUrl
from huggingface_hub._commit_api import CommitOperationAdd, _upload_files, _upload_lfs_files, _upload_xet_files
from huggingface_hub.file_download import (
_get_metadata_or_catch_error,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
)
from huggingface_hub.utils import XetConnectionInfo, build_hf_headers, refresh_xet_connection_info
from .testing_constants import ENDPOINT_STAGING, TOKEN
from .testing_utils import repo_name, requires
@contextmanager
def assert_upload_mode(mode: str):
    """Check that the code run inside the ``with`` block uploads through exactly one backend.

    ``_upload_xet_files`` and ``_upload_lfs_files`` in ``huggingface_hub._commit_api``
    are replaced by spies that still delegate to the real functions (``wraps=``).
    When the ``with`` body finishes normally, the spy for ``mode`` must have been
    called and the other spy must not.

    Args:
        mode: ``"xet"`` or ``"lfs"``.

    Raises:
        ValueError: If ``mode`` is neither ``"xet"`` nor ``"lfs"``.
        AssertionError: If the backend actually used does not match ``mode``.
    """
    # Validate before patching anything so a bad mode fails fast.
    if mode not in ("xet", "lfs"):
        raise ValueError("Mode must be either 'xet' or 'lfs'")

    expect_xet = mode == "xet"
    xet_target = "huggingface_hub._commit_api._upload_xet_files"
    lfs_target = "huggingface_hub._commit_api._upload_lfs_files"

    with patch(xet_target, wraps=_upload_xet_files) as xet_spy:
        with patch(lfs_target, wraps=_upload_lfs_files) as lfs_spy:
            yield
            # Exactly one of the two spies must have fired, matching `mode`.
            assert xet_spy.called == expect_xet, (
                f"Expected {'XET' if mode == 'xet' else 'LFS'} upload to be used"
            )
            assert lfs_spy.called == (not expect_xet), (
                f"Expected {'LFS' if mode == 'lfs' else 'XET'} upload to be used"
            )
@pytest.fixture(scope="module")
def api():
    """Provide an ``HfApi`` client for the staging Hub, shared by every test in this module."""
    staging_client = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
    return staging_client
@pytest.fixture
def repo_url(api, repo_type: str = "model"):
    """Create a uniquely named repo for one test, yield its ``RepoUrl``, then delete it.

    Args:
        api: The ``HfApi`` client from the ``api`` fixture.
        repo_type: Type of repo to create (defaults to ``"model"``); also used as the
            prefix of the generated repo name.
    """
    created = api.create_repo(repo_id=repo_name(prefix=repo_type), repo_type=repo_type)
    yield created
    # Teardown: remove the repo that was created above.
    api.delete_repo(repo_id=created.repo_id, repo_type=repo_type)
@pytest.fixture
def xet_setup(request, tmp_path):
    """Write a small tree of sample files under ``tmp_path`` and expose it on the test instance.

    Files written::

        text_file.txt                      text
        binary_file.bin                    1 MiB of b"0"
        nested/nested_text.txt             text
        nested/nested_binary.safetensors   1 MiB of b"1"

    Attributes set on ``request.instance``: ``folder_path``, ``text_content``,
    ``bin_file``, ``bin_content``, ``nested_text_content`` and ``nested_bin_content``.
    If ``request`` has no ``instance`` (or it is ``None``), nothing is created.
    """
    owner = getattr(request, "instance", None)
    if owner is not None:
        one_mib = 1024 * 1024
        owner.folder_path = tmp_path

        # Top-level text file
        owner.text_content = "This is a regular text file"
        (owner.folder_path / "text_file.txt").write_text(owner.text_content)

        # Top-level binary file (path kept on the instance for upload tests)
        owner.bin_file = owner.folder_path / "binary_file.bin"
        owner.bin_content = b"0" * one_mib
        owner.bin_file.write_bytes(owner.bin_content)

        # Nested directory holding one text and one binary file
        sub_dir = owner.folder_path / "nested"
        sub_dir.mkdir()

        owner.nested_text_content = "This is a nested text file"
        (sub_dir / "nested_text.txt").write_text(owner.nested_text_content)

        owner.nested_bin_content = b"1" * one_mib
        (sub_dir / "nested_binary.safetensors").write_bytes(owner.nested_bin_content)
    yield
@requires("hf_xet")
@pytest.mark.usefixtures("xet_setup")
class TestXetUpload:
    """Upload tests that check which transfer path (Xet or LFS) ``huggingface_hub`` picks.

    Tests taking ``repo_url`` run against a temporary repo on the staging Hub. The
    ``test_transfers_*`` and ``test_request_headers_*`` tests patch the network layer
    instead. The ``xet_setup`` fixture attaches the local sample files and their
    contents to ``self`` (``folder_path``, ``bin_file``, ``bin_content``, ...).
    """

    @staticmethod
    def _fake_batch_choosing_xet(
        upload_infos, token, repo_type, repo_id, revision=None, endpoint=None, headers=None, transfers=None
    ):
        """Stand-in for ``post_lfs_batch_info`` that always answers with the "xet" transfer.

        Returns a single upload action for the first ``upload_infos`` entry, an empty
        second list (presumably the batch errors) and ``"xet"`` as the chosen transfer.
        """
        action = {
            "oid": upload_infos[0].sha256.hex(),
            "size": upload_infos[0].size,
            "actions": {"upload": {"href": "https://example.invalid", "header": {}}},
        }
        return ([action], [], "xet")

    def test_upload_file(self, api, tmp_path, repo_url):
        filename_in_repo = "binary_file.bin"
        repo_id = repo_url.repo_id
        with assert_upload_mode("xet"):
            return_val = api.upload_file(
                path_or_fileobj=self.bin_file,
                path_in_repo=filename_in_repo,
                repo_id=repo_id,
            )
        assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit")

        # Download and verify content
        downloaded_file = hf_hub_download(repo_id=repo_id, filename=filename_in_repo, cache_dir=tmp_path)
        assert Path(downloaded_file).read_bytes() == self.bin_content

        # Check xet metadata
        url = hf_hub_url(repo_id=repo_id, filename=filename_in_repo)
        metadata = get_hf_file_metadata(url)
        assert metadata.xet_file_data is not None
        xet_connection = refresh_xet_connection_info(file_data=metadata.xet_file_data, headers={})
        assert xet_connection is not None

    def test_upload_file_with_bytesio(self, api, tmp_path, repo_url):
        """A ``BytesIO`` payload is expected to go through LFS, not Xet."""
        repo_id = repo_url.repo_id
        content = BytesIO(self.bin_content)
        with assert_upload_mode("lfs"):
            api.upload_file(
                path_or_fileobj=content,
                path_in_repo="bytesio_file.bin",
                repo_id=repo_id,
            )

        # Download and verify content
        downloaded_file = hf_hub_download(repo_id=repo_id, filename="bytesio_file.bin", cache_dir=tmp_path)
        assert Path(downloaded_file).read_bytes() == self.bin_content

    def test_upload_file_with_byte_array(self, api, tmp_path, repo_url):
        """A raw ``bytes`` payload is expected to go through Xet."""
        repo_id = repo_url.repo_id
        content = self.bin_content
        with assert_upload_mode("xet"):
            api.upload_file(
                path_or_fileobj=content,
                path_in_repo="bytearray_file.bin",
                repo_id=repo_id,
            )

        # Download and verify content
        downloaded_file = hf_hub_download(repo_id=repo_id, filename="bytearray_file.bin", cache_dir=tmp_path)
        assert Path(downloaded_file).read_bytes() == self.bin_content

    def test_fallback_to_lfs_when_xet_not_available(self, api, repo_url):
        repo_id = repo_url.repo_id
        with patch("huggingface_hub._commit_api.is_xet_available", return_value=False):
            with assert_upload_mode("lfs"):
                api.upload_file(
                    path_or_fileobj=self.bin_file,
                    path_in_repo="fallback_file.bin",
                    repo_id=repo_id,
                )

    def test_transfers_to_xet_when_server_returns_xet(self):
        """A file-path addition is uploaded with Xet when the batch endpoint selects "xet"."""
        addition = CommitOperationAdd(path_in_repo="xet.bin", path_or_fileobj=self.bin_file)
        with patch(
            "huggingface_hub._commit_api.post_lfs_batch_info", side_effect=self._fake_batch_choosing_xet
        ) as mock_batch:
            with patch("huggingface_hub._commit_api._upload_lfs_files") as mock_lfs:
                with patch("huggingface_hub._commit_api._upload_xet_files") as mock_xet:
                    _upload_files(
                        additions=[addition],
                        repo_type="model",
                        repo_id="dummy/user-repo",
                        headers={},
                        endpoint="https://hub-ci.huggingface.co",
                        revision="main",
                        create_pr=False,
                    )
        assert mock_batch.call_count == 1
        mock_xet.assert_called_once()
        mock_lfs.assert_not_called()

    def test_transfers_bytesio_renegotiates_to_lfs_when_server_returns_xet(self):
        """A ``BytesIO`` addition is uploaded with LFS even when the batch endpoint selects "xet".

        NOTE: despite the test name, the assertions require exactly one batch call,
        i.e. routing to LFS happens without a second negotiation round.
        """
        addition = CommitOperationAdd(path_in_repo="bytesio.bin", path_or_fileobj=BytesIO(self.bin_content))
        with patch(
            "huggingface_hub._commit_api.post_lfs_batch_info", side_effect=self._fake_batch_choosing_xet
        ) as mock_batch:
            with patch("huggingface_hub._commit_api._upload_lfs_files") as mock_lfs:
                with patch("huggingface_hub._commit_api._upload_xet_files") as mock_xet:
                    _upload_files(
                        additions=[addition],
                        repo_type="model",
                        repo_id="dummy/user-repo",
                        headers={},
                        endpoint="https://hub-ci.huggingface.co",
                        revision="main",
                        create_pr=False,
                    )
        # Single batch call (no renegotiation), and the upload is routed to LFS, not XET.
        assert mock_batch.call_count == 1
        mock_xet.assert_not_called()
        mock_lfs.assert_called_once()

    def test_request_headers_passed_to_upload_files(self, tmp_path):
        """Test that headers (minus authorization) are passed as request_headers to hf_xet.upload_files."""
        headers = {
            "authorization": "Bearer my_token",
            "x-custom-header": "custom_value",
            "user-agent": "test-agent",
        }
        test_file = tmp_path / "test_file.bin"
        test_file.write_bytes(b"test content")
        addition = CommitOperationAdd(path_in_repo="test_file.bin", path_or_fileobj=test_file)

        with patch("huggingface_hub._commit_api.fetch_xet_connection_info_from_repo_info") as mock_fetch:
            mock_fetch.return_value = XetConnectionInfo(
                endpoint="mock_endpoint",
                access_token="mock_token",
                expiration_unix_epoch=9999999999,
            )
            with patch("hf_xet.upload_files") as mock_upload_files:
                with patch("huggingface_hub._commit_api.are_progress_bars_disabled", return_value=True):
                    _upload_xet_files(
                        additions=[addition],
                        repo_type="model",
                        repo_id="test/repo",
                        headers=headers,
                    )

        mock_upload_files.assert_called_once()
        request_headers = mock_upload_files.call_args.kwargs["request_headers"]
        assert request_headers.get("x-custom-header") == "custom_value"
        assert request_headers.get("user-agent") == "test-agent"
        assert "authorization" not in request_headers
        assert mock_upload_files.call_args.kwargs["sha256s"] == [addition.upload_info.sha256.hex()]

    def test_request_headers_passed_to_upload_bytes(self):
        """Test that headers (minus authorization) are passed as request_headers to hf_xet.upload_bytes."""
        headers = {
            "authorization": "Bearer my_token",
            "x-custom-header": "custom_value",
            "user-agent": "test-agent",
        }
        addition = CommitOperationAdd(path_in_repo="test_file.bin", path_or_fileobj=b"test content")

        with patch("huggingface_hub._commit_api.fetch_xet_connection_info_from_repo_info") as mock_fetch:
            mock_fetch.return_value = XetConnectionInfo(
                endpoint="mock_endpoint",
                access_token="mock_token",
                expiration_unix_epoch=9999999999,
            )
            with patch("hf_xet.upload_bytes") as mock_upload_bytes:
                with patch("huggingface_hub._commit_api.are_progress_bars_disabled", return_value=True):
                    _upload_xet_files(
                        additions=[addition],
                        repo_type="model",
                        repo_id="test/repo",
                        headers=headers,
                    )

        mock_upload_bytes.assert_called_once()
        request_headers = mock_upload_bytes.call_args.kwargs["request_headers"]
        assert request_headers.get("x-custom-header") == "custom_value"
        assert request_headers.get("user-agent") == "test-agent"
        assert "authorization" not in request_headers
        assert mock_upload_bytes.call_args.kwargs["sha256s"] == [addition.upload_info.sha256.hex()]

    def test_upload_folder(self, api, repo_url):
        repo_id = repo_url.repo_id
        folder_in_repo = "temp"
        with assert_upload_mode("xet"):
            return_val = api.upload_folder(
                folder_path=self.folder_path,
                path_in_repo=folder_in_repo,
                repo_id=repo_id,
            )
        assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit")

        files_in_repo = set(api.list_repo_files(repo_id=repo_id))
        files = {
            f"{folder_in_repo}/text_file.txt",
            f"{folder_in_repo}/binary_file.bin",
            f"{folder_in_repo}/nested/nested_text.txt",
            f"{folder_in_repo}/nested/nested_binary.safetensors",
        }
        # Report exactly which files are missing instead of a bare `all(...)` failure.
        missing = files - files_in_repo
        assert not missing, f"Files missing from the repo: {sorted(missing)}"

        for rpath in files:
            local_path = self.folder_path / Path(rpath).relative_to(folder_in_repo)
            filepath = hf_hub_download(repo_id=repo_id, filename=rpath)
            assert local_path.read_bytes() == Path(filepath).read_bytes()

    def test_upload_folder_create_pr(self, api, repo_url) -> None:
        repo_id = repo_url.repo_id
        folder_in_repo = "temp_create_pr"
        with assert_upload_mode("xet"):
            return_val = api.upload_folder(
                folder_path=self.folder_path,
                path_in_repo=folder_in_repo,
                repo_id=repo_id,
                create_pr=True,
            )
        assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit")

        # Files are read back from the PR revision, not from main.
        for rpath in ["text_file.txt", "nested/nested_binary.safetensors"]:
            local_path = self.folder_path / rpath
            filepath = hf_hub_download(
                repo_id=repo_id, filename=f"{folder_in_repo}/{rpath}", revision=return_val.pr_revision
            )
            assert local_path.read_bytes() == Path(filepath).read_bytes()
@requires("hf_xet")
class TestXetLargeUpload:
    """``upload_large_folder`` tests: all files go through Xet, and Xet uploads get batched."""

    def test_upload_large_folder(self, api, tmp_path, repo_url: RepoUrl) -> None:
        N_FILES_PER_FOLDER = 4
        repo_id = repo_url.repo_id
        source_dir = Path(tmp_path) / "large_folder"
        # Every (subfolder index, file index) pair, subfolder-major.
        grid = [(i, j) for i in range(N_FILES_PER_FOLDER) for j in range(N_FILES_PER_FOLDER)]

        # Each subfolder gets one .bin and one .txt file per index.
        for i, j in grid:
            subfolder = source_dir / f"subfolder_{i}"
            subfolder.mkdir(parents=True, exist_ok=True)
            (subfolder / f"file_xet_{i}_{j}.bin").write_bytes(f"content_lfs_{i}_{j}".encode())
            (subfolder / f"file_regular_{i}_{j}.txt").write_bytes(f"content_regular_{i}_{j}".encode())

        with assert_upload_mode("xet"):
            api.upload_large_folder(repo_id=repo_id, repo_type="model", folder_path=source_dir, num_workers=4)

        uploaded = set(api.list_repo_files(repo_id=repo_id))

        # Pull everything back into a fresh directory for content comparison.
        local_dir = Path(tmp_path) / "snapshot"
        local_dir.mkdir()
        api.snapshot_download(repo_id=repo_id, local_dir=local_dir, cache_dir=None)

        for i, j in grid:
            xet_rpath = f"subfolder_{i}/file_xet_{i}_{j}.bin"
            regular_rpath = f"subfolder_{i}/file_regular_{i}_{j}.txt"

            # Both files made it into the repo
            assert xet_rpath in uploaded
            assert regular_rpath in uploaded

            # The .bin file carries xet metadata on the Hub
            metadata = get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename=xet_rpath))
            assert metadata.xet_file_data is not None

            # Downloaded bytes match what was written locally
            assert (local_dir / xet_rpath).read_bytes() == f"content_lfs_{i}_{j}".encode()
            assert (local_dir / regular_rpath).read_bytes() == f"content_regular_{i}_{j}".encode()

    def test_upload_large_folder_batch_size_greater_than_one(self, api, tmp_path, repo_url: RepoUrl) -> None:
        from hf_xet import upload_files as real_upload_files

        N_FILES = 500
        repo_id = repo_url.repo_id
        source_dir = Path(tmp_path) / "large_folder"
        source_dir.mkdir()
        for idx in range(N_FILES):
            (source_dir / f"file_xet_{idx}.bin").write_bytes(f"content_lfs_{idx}".encode())

        # Record how many files each hf_xet.upload_files call receives (its first
        # positional argument), then delegate to the real implementation.
        batch_sizes = []

        def spy_upload_files(*args, **kwargs):
            batch_sizes.append(len(args[0]))
            return real_upload_files(*args, **kwargs)

        with assert_upload_mode("xet"):
            with patch("hf_xet.upload_files", side_effect=spy_upload_files):
                api.upload_large_folder(repo_id=repo_id, repo_type="model", folder_path=source_dir, num_workers=4)

        # The batch size is 256, but hashing speed and get_upload_mode timing mean
        # batches can be smaller (down to a single file) when no other job is ready.
        # If batching works, it is very unlikely that *every* call carries exactly
        # one file, so require at least one multi-file batch.
        assert any(size > 1 for size in batch_sizes)
@requires("hf_xet")
@pytest.mark.usefixtures("xet_setup")
class TestXetE2E:
    """End-to-end tests that call ``hf_xet`` directly against the staging Hub."""

    def test_hf_xet_with_token_refresher(self, api, tmp_path, repo_url):
        """
        Test the hf_xet.download_files function with a token refresher.

        The test calls hf_xet.download_files directly with a locally defined token
        refresher, to check that the token refresh mechanism works and to catch
        regressions in hf_xet.download_files.

        * Define a token refresher that fetches a fresh access token and expiration time.
        * Wrap it in a MagicMock so its calls can be counted.
        * Build the headers and file metadata needed for the download.
        * Call download_files with an already-expired token (expiration 0) to force a refresh.
        * Assert that the token refresher was called exactly once.

        The test also checks that the downloaded file matches the uploaded one.
        """
        from hf_xet import PyXetDownloadInfo, download_files

        filename_in_repo = "binary_file.bin"
        repo_id = repo_url.repo_id

        # Upload a file
        api.upload_file(
            path_or_fileobj=self.bin_file,
            path_in_repo=filename_in_repo,
            repo_id=repo_id,
        )

        headers = build_hf_headers(token=TOKEN)

        # Resolve file metadata; only the size, the xet file data and the error are needed here.
        (_, _, _, expected_size, xet_filedata, head_call_error) = _get_metadata_or_catch_error(
            repo_id=repo_id,
            filename=filename_in_repo,
            revision="main",
            repo_type="model",
            headers=headers,
            endpoint=api.endpoint,
            token=TOKEN,
            etag_timeout=None,
            local_files_only=False,
        )
        assert head_call_error is None  # ensure we got metadata successfully
        # Everything below dereferences xet_filedata; fail clearly if the file was not stored with xet.
        assert xet_filedata is not None, "Expected xet file data in the file metadata"

        xet_connection_info = refresh_xet_connection_info(file_data=xet_filedata, headers=headers)

        # Build the hf_xet.download_files arguments by hand and pass a locally defined
        # token_refresher, to check that token refresh works.
        def token_refresher() -> tuple[str, int]:
            # Fetch a fresh access token and its expiration time
            new_connection = refresh_xet_connection_info(file_data=xet_filedata, headers=headers)
            return new_connection.access_token, new_connection.expiration_unix_epoch

        mock_token_refresher = MagicMock(side_effect=token_refresher)

        incomplete_path = Path(tmp_path) / "file.bin.incomplete"
        file_info = [
            PyXetDownloadInfo(
                destination_path=str(incomplete_path.absolute()), hash=xet_filedata.file_hash, file_size=expected_size
            )
        ]

        # Expiration 0 marks the initial token as already expired, forcing a refresh
        download_files(
            file_info,
            endpoint=xet_connection_info.endpoint,
            token_info=(xet_connection_info.access_token, 0),
            token_refresher=mock_token_refresher,
            progress_updater=None,
        )

        # hf_xet must have called our token_refresher exactly once
        mock_token_refresher.assert_called_once()

        # The downloaded file must match the uploaded one
        assert incomplete_path.read_bytes() == self.bin_content