Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/huggingface_hub/_commit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def token_refresher() -> tuple[str, int]:

if len(all_paths_ops) > 0:
all_paths = [str(op.path_or_fileobj) for op in all_paths_ops]
all_sha256s = [op.upload_info.sha256.hex() for op in all_paths_ops]
upload_files(
all_paths,
xet_endpoint,
Expand All @@ -649,10 +650,12 @@ def token_refresher() -> tuple[str, int]:
progress_callback,
repo_type,
request_headers=xet_headers,
sha256s=all_sha256s,
)

if len(all_bytes_ops) > 0:
all_bytes = [op.path_or_fileobj for op in all_bytes_ops]
all_sha256s = [op.upload_info.sha256.hex() for op in all_bytes_ops]
upload_bytes(
all_bytes,
xet_endpoint,
Expand All @@ -661,6 +664,7 @@ def token_refresher() -> tuple[str, int]:
progress_callback,
repo_type,
request_headers=xet_headers,
sha256s=all_sha256s,
)

finally:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_xet_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_request_headers_passed_to_upload_files(self, tmp_path):
assert request_headers.get("x-custom-header") == "custom_value"
assert request_headers.get("user-agent") == "test-agent"
assert "authorization" not in request_headers
assert mock_upload_files.call_args.kwargs["sha256s"] == [addition.upload_info.sha256.hex()]

def test_request_headers_passed_to_upload_bytes(self):
"""Test that headers (minus authorization) are passed as request_headers to hf_xet.upload_bytes."""
Expand Down Expand Up @@ -289,6 +290,7 @@ def test_request_headers_passed_to_upload_bytes(self):
assert request_headers.get("x-custom-header") == "custom_value"
assert request_headers.get("user-agent") == "test-agent"
assert "authorization" not in request_headers
assert mock_upload_bytes.call_args.kwargs["sha256s"] == [addition.upload_info.sha256.hex()]

def test_upload_folder(self, api, repo_url):
repo_id = repo_url.repo_id
Expand Down
Loading