Skip to content

Commit 4002fd4

Browse files
authored
Merge pull request #107 from y0z/feature/fix-force-reload
Change the internal behavior of force-reloading
2 parents 7bf1813 + f3fb84a commit 4002fd4

File tree

1 file changed

+43
-29
lines changed

1 file changed

+43
-29
lines changed

optunahub/hub.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
import shutil
99
import sys
10+
import tempfile
1011
import types
1112
from urllib.parse import urlparse
1213
from urllib.request import Request
@@ -151,7 +152,6 @@ def load_module(
151152
repo_name=repo_name,
152153
dir_path=dir_path,
153154
ref=ref,
154-
package_cache_dir=package_cache_dir,
155155
cache_dir_prefix=cache_dir_prefix,
156156
)
157157

@@ -188,16 +188,26 @@ def _download_via_git(
188188
) -> None:
189189
repo_url_separator = "/" if "://" in base_url else ":"
190190
repo_url = f"{base_url.rstrip('/')}{repo_url_separator}{repo_owner}/{repo_name}"
191-
repo = Repo.init(cache_dir_prefix)
192-
origin = (
193-
repo.remotes.origin if "origin" in repo.remotes else repo.create_remote("origin", repo_url)
194-
)
195-
if repo.remotes.origin.url != repo_url:
196-
repo.remotes.origin.set_url(repo_url)
197-
repo.git.sparse_checkout("init", "--cone")
198-
repo.git.sparse_checkout("set", dir_path)
199-
origin.fetch(refspec=ref, depth=1)
200-
repo.git.checkout("FETCH_HEAD")
191+
with tempfile.TemporaryDirectory() as tmpdir:
192+
# Initialize a temporary Git repository to perform sparse checkout.
193+
repo = Repo.init(tmpdir)
194+
origin = (
195+
repo.remotes.origin
196+
if "origin" in repo.remotes
197+
else repo.create_remote("origin", repo_url)
198+
)
199+
if repo.remotes.origin.url != repo_url:
200+
repo.remotes.origin.set_url(repo_url)
201+
repo.git.sparse_checkout("init", "--cone")
202+
repo.git.sparse_checkout("set", dir_path)
203+
origin.fetch(refspec=ref, depth=1)
204+
repo.git.checkout("FETCH_HEAD")
205+
206+
# Move the downloaded package to the cache directory.
207+
package_cache_dir = os.path.join(cache_dir_prefix, dir_path)
208+
shutil.rmtree(package_cache_dir, ignore_errors=True)
209+
os.makedirs(os.path.dirname(package_cache_dir), exist_ok=True)
210+
shutil.move(os.path.join(tmpdir, dir_path), package_cache_dir)
201211

202212

203213
def _download_via_github_api(
@@ -207,7 +217,6 @@ def _download_via_github_api(
207217
repo_name: str,
208218
dir_path: str,
209219
ref: str,
210-
package_cache_dir: str,
211220
cache_dir_prefix: str,
212221
) -> None:
213222
g = Github(auth=auth, base_url=base_url)
@@ -218,23 +227,28 @@ def _download_via_github_api(
218227
if isinstance(package_contents, ContentFile):
219228
package_contents = [package_contents]
220229

221-
shutil.rmtree(package_cache_dir, ignore_errors=True)
222-
os.makedirs(cache_dir_prefix, exist_ok=True)
223-
for m in package_contents:
224-
file_path = os.path.join(cache_dir_prefix, m.path)
225-
os.makedirs(os.path.dirname(file_path), exist_ok=True)
226-
if m.type == "dir":
227-
dir_contents = repo.get_contents(m.path, ref)
228-
if isinstance(dir_contents, ContentFile):
229-
dir_contents = [dir_contents]
230-
package_contents.extend(dir_contents)
231-
else:
232-
with open(file_path, "wb") as f:
233-
try:
234-
decoded_content = m.decoded_content
235-
except AssertionError:
236-
continue
237-
f.write(decoded_content)
230+
with tempfile.TemporaryDirectory() as tmpdir:
231+
for m in package_contents:
232+
file_path = os.path.join(tmpdir, m.path)
233+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
234+
if m.type == "dir":
235+
dir_contents = repo.get_contents(m.path, ref)
236+
if isinstance(dir_contents, ContentFile):
237+
dir_contents = [dir_contents]
238+
package_contents.extend(dir_contents)
239+
else:
240+
with open(file_path, "wb") as f:
241+
try:
242+
decoded_content = m.decoded_content
243+
except AssertionError:
244+
continue
245+
f.write(decoded_content)
246+
247+
# Move the downloaded package to the cache directory.
248+
package_cache_dir = os.path.join(cache_dir_prefix, dir_path)
249+
shutil.rmtree(package_cache_dir, ignore_errors=True)
250+
os.makedirs(os.path.dirname(package_cache_dir), exist_ok=True)
251+
shutil.move(os.path.join(tmpdir, dir_path), package_cache_dir)
238252

239253

240254
def load_local_module(

0 commit comments

Comments
 (0)