@@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None:
2626 logging .debug (f"Downloading { url } to { destination } " )
2727
2828 # get the size of the file
29- response = requests .get (url , stream = True )
29+ response = requests .get (url = url , stream = True )
3030 response .raise_for_status ()
3131 total = int (response .headers .get ("content-length" , 0 ))
32+ chunk_size = 1024 * 1000 # 1,024,000 bytes (1000 KiB, slightly under 1 MiB)
3233
3334 # create a progress bar
3435 bar = tqdm (
@@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None:
4546 with destination .open ("wb" ) as f :
4647 with requests .get (url , stream = True ) as r :
4748 r .raise_for_status ()
48- for chunk in r .iter_content (chunk_size = 1024 * 1000 ):
49+ for chunk in r .iter_content (chunk_size = chunk_size ):
4950 size = f .write (chunk )
5051 bar .update (size )
5152 bar .close ()
@@ -63,8 +64,8 @@ def __init__(
6364 self ,
6465 repo_id : str ,
6566 filename : str ,
66- expected_sha256 : str ,
6767 revision : str = "main" ,
68+ expected_sha256 : str | None = None ,
6869 download_url : str | None = None ,
6970 ) -> None :
7071 """Initialize the HubPath.
@@ -73,14 +74,14 @@ def __init__(
7374 repo_id: The repository identifier on the hub.
7475 filename: The filename of the file in the repository.
7576 revision: The revision of the file on the hf hub.
76- expected_sha256: The sha256 hash of the file.
77+ expected_sha256: The expected sha256 hash of the file. Optional (but strongly recommended); when provided, it is checked against the local and remote hashes.
7778 download_url: The url to download the file from, if not from the huggingface hub.
7879 """
7980 self .repo_id = repo_id
8081 self .filename = filename
8182 self .revision = revision
82- self .expected_sha256 = expected_sha256 .lower ()
83- self .override_download_url = download_url
83+ self .expected_sha256 = expected_sha256 .lower () if expected_sha256 is not None else None
84+ self .download_url = download_url
8485
8586 @staticmethod
8687 def hub_location ():
@@ -90,16 +91,22 @@ def hub_location():
9091 @property
9192 def hf_url (self ) -> str :
9293 """Return the url to the file on the hf hub."""
93- assert self .override_download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
94+ assert self .download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
9495 return hf_hub_url (
9596 repo_id = self .repo_id ,
9697 filename = self .filename ,
9798 revision = self .revision ,
9899 )
99100
101+ @property
102+ def hf_metadata (self ) -> HfFileMetadata :
103+ """Return the metadata of the file on the hf hub."""
104+ return get_hf_file_metadata (self .hf_url )
105+
100106 @property
101107 def hf_cache_path (self ) -> Path :
102108 """Download the file from the hf hub and return its path in the local hf cache."""
109+ assert self .download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
103110 return Path (
104111 hf_hub_download (
105112 repo_id = self .repo_id ,
@@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path:
108115 ),
109116 )
110117
111- @property
112- def hf_metadata (self ) -> HfFileMetadata :
113- """Return the metadata of the file on the hf hub."""
114- return get_hf_file_metadata (self .hf_url )
115-
116118 @property
117119 def hf_sha256_hash (self ) -> str :
118120 """Return the sha256 hash of the file on the hf hub."""
@@ -127,24 +129,32 @@ def local_path(self) -> Path:
127129 return self .hub_location () / self .repo_id / self .filename
128130
129131 @property
130- def local_hash (self ) -> str :
132+ def local_sha256_hash (self ) -> str :
131133 """Return the sha256 hash of the file in the local hub."""
132134 assert self .local_path .is_file (), f"{ self .local_path } does not exist"
133136 # TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest once only python >= 3.11 is supported
134136 return sha256 (self .local_path .read_bytes ()).hexdigest ().lower ()
135137
136138 def check_local_hash (self ) -> bool :
137139 """Check if the sha256 hash of the file in the local hub is correct."""
138- if self .expected_sha256 != self .local_hash :
139- logging .warning (f"{ self .local_path } local sha256 mismatch, { self .local_hash } != { self .expected_sha256 } " )
140+ if self .expected_sha256 is None :
141+ logging .warning (f"{ self .repo_id } /{ self .filename } has no expected sha256 hash, skipping check" )
142+ return True
143+ elif self .expected_sha256 != self .local_sha256_hash :
144+ logging .warning (
145+ f"{ self .local_path } local sha256 mismatch, { self .local_sha256_hash } != { self .expected_sha256 } "
146+ )
140147 return False
141148 else :
142- logging .debug (f"{ self .local_path } local sha256 is correct ({ self .local_hash } )" )
149+ logging .debug (f"{ self .local_path } local sha256 is correct ({ self .local_sha256_hash } )" )
143150 return True
144151
145152 def check_remote_hash (self ) -> bool :
146153 """Check if the sha256 hash of the file on the hf hub is correct."""
147- if self .expected_sha256 != self .hf_sha256_hash :
154+ if self .expected_sha256 is None :
155+ logging .warning (f"{ self .repo_id } /{ self .filename } has no expected sha256 hash, skipping check" )
156+ return True
157+ elif self .expected_sha256 != self .hf_sha256_hash :
148158 logging .warning (
149159 f"{ self .local_path } remote sha256 mismatch, { self .hf_sha256_hash } != { self .expected_sha256 } "
150160 )
@@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool:
154164 return True
155165
156166 def download (self ) -> None :
157- """Download the file from the hf hub or from the override download url."""
158- self .local_path .parent .mkdir (parents = True , exist_ok = True )
167+ """Download the file from the hf hub or from the override download url, and save it to the local hub."""
159168 if self .local_path .is_file ():
160169 logging .warning (f"{ self .local_path } already exists" )
161- elif self .override_download_url is not None :
162- download_file_url (url = self .override_download_url , destination = self .local_path )
170+ elif self .download_url is not None :
171+ self .local_path .parent .mkdir (parents = True , exist_ok = True )
172+ download_file_url (url = self .download_url , destination = self .local_path )
163173 else :
164- # TODO: not enough log messages when local_path does not exist and the file comes from the hf cache
174+ self . local_path . parent . mkdir ( parents = True , exist_ok = True )
165175 self .local_path .symlink_to (self .hf_cache_path )
166176 assert self .check_local_hash ()
167177
0 commit comments