Skip to content

Commit 9d21e2c

Browse files
authored
Merge pull request #16522 from github/redsun82/lfs
Bazel: allow LFS rules to use cached downloads without internet
2 parents 13a7d9a + d01d657 commit 9d21e2c

File tree

2 files changed

+57
-30
lines changed

2 files changed

+57
-30
lines changed

misc/bazel/internal/git_lfs_probe.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
"""
44
Probe lfs files.
5-
For each source file provided as output, this will print:
5+
For each source file provided as input, this will print:
66
* "local", if the source file is not an LFS pointer
77
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
8+
If --hash-only is provided, the transient URL will not be fetched and printed
89
"""
910

1011
import sys
@@ -19,6 +20,14 @@
1920
import base64
2021
from dataclasses import dataclass
2122
from typing import Dict
23+
import argparse
24+
25+
26+
def options():
27+
p = argparse.ArgumentParser(description=__doc__)
28+
p.add_argument("--hash-only", action="store_true")
29+
p.add_argument("sources", type=pathlib.Path, nargs="+")
30+
return p.parse_args()
2231

2332

2433
@dataclass
@@ -30,7 +39,8 @@ def update_headers(self, d: Dict[str, str]):
3039
self.headers.update((k.capitalize(), v) for k, v in d.items())
3140

3241

33-
sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
42+
opts = options()
43+
sources = [p.resolve() for p in opts.sources]
3444
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
3545
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
3646

@@ -60,7 +70,12 @@ def get_endpoint():
6070
server, _, path = ssh_endpoint.partition(":")
6171
ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
6272
assert ssh_command, "no ssh command found"
63-
resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
73+
resp = json.loads(subprocess.check_output([ssh_command,
74+
"-oStrictHostKeyChecking=accept-new",
75+
server,
76+
"git-lfs-authenticate",
77+
path,
78+
"download"]))
6479
endpoint.href = resp.get("href", endpoint)
6580
endpoint.update_headers(resp.get("header", {}))
6681
url = urlparse(endpoint.href)
@@ -84,11 +99,15 @@ def get_endpoint():
8499
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
85100
def get_locations(objects):
86101
ret = ["local" for _ in objects]
87-
endpoint = get_endpoint()
88102
indexes = [i for i, o in enumerate(objects) if o]
89103
if not indexes:
90104
# all objects are local, do not send an empty request as that would be an error
91105
return ret
106+
if opts.hash_only:
107+
for i in indexes:
108+
ret[i] = objects[i]["oid"]
109+
return ret
110+
endpoint = get_endpoint()
92111
data = {
93112
"operation": "download",
94113
"transfers": ["basic"],

misc/bazel/lfs.bzl

+34-26
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,44 @@
11
def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
2-
for src in srcs:
3-
repository_ctx.watch(src)
4-
script = Label("//misc/bazel/internal:git_lfs_probe.py")
52
python = repository_ctx.which("python3") or repository_ctx.which("python")
63
if not python:
74
fail("Neither python3 nor python executables found")
8-
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
9-
res = repository_ctx.execute([python, script] + srcs, quiet = True)
10-
if res.return_code != 0:
11-
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
12-
promises = []
13-
for src, loc in zip(srcs, res.stdout.splitlines()):
14-
if loc == "local":
15-
if extract:
16-
repository_ctx.report_progress("extracting local %s" % src.basename)
17-
repository_ctx.extract(src, stripPrefix = stripPrefix)
18-
else:
19-
repository_ctx.report_progress("symlinking local %s" % src.basename)
20-
repository_ctx.symlink(src, src.basename)
5+
script = Label("//misc/bazel/internal:git_lfs_probe.py")
6+
7+
def probe(srcs, hash_only = False):
8+
repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
9+
cmd = [python, script]
10+
if hash_only:
11+
cmd.append("--hash-only")
12+
cmd.extend(srcs)
13+
res = repository_ctx.execute(cmd, quiet = True)
14+
if res.return_code != 0:
15+
fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))
16+
return res.stdout.splitlines()
17+
18+
for src in srcs:
19+
repository_ctx.watch(src)
20+
infos = probe(srcs, hash_only = True)
21+
remote = []
22+
for src, info in zip(srcs, infos):
23+
if info == "local":
24+
repository_ctx.report_progress("symlinking local %s" % src.basename)
25+
repository_ctx.symlink(src, src.basename)
2126
else:
22-
sha256, _, url = loc.partition(" ")
23-
if extract:
24-
# we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
25-
# or similar wouldn't work
26-
# it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip",
27-
# download_and_extract will just append ".name.zip" its internal temporary name, so extraction works
28-
possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
29-
repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
30-
repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
31-
else:
27+
repository_ctx.report_progress("trying cache for remote %s" % src.basename)
28+
res = repository_ctx.download([], src.basename, sha256 = info, allow_fail = True)
29+
if not res.success:
30+
remote.append(src)
31+
if remote:
32+
infos = probe(remote)
33+
for src, info in zip(remote, infos):
34+
sha256, _, url = info.partition(" ")
3235
repository_ctx.report_progress("downloading remote %s" % src.basename)
3336
repository_ctx.download(url, src.basename, sha256 = sha256)
37+
if extract:
38+
for src in srcs:
39+
repository_ctx.report_progress("extracting %s" % src.basename)
40+
repository_ctx.extract(src.basename, stripPrefix = stripPrefix)
41+
repository_ctx.delete(src.basename)
3442

3543
def _download_and_extract_lfs(repository_ctx):
3644
attr = repository_ctx.attr

0 commit comments

Comments
 (0)