Skip to content

Commit 9017a46

Browse files
committed
Update utils.py to match upstream
1 parent f9ce549 commit 9017a46

File tree

1 file changed

+209
-84
lines changed

1 file changed

+209
-84
lines changed

build/tools/utils.py

Lines changed: 209 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -12,100 +12,225 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
# Helper script for the JAX build CLI for running subprocess commands.
16-
import asyncio
17-
import dataclasses
18-
import datetime
19-
import os
15+
# Helper script for tools/utilities used by the JAX build CLI.
16+
import collections
17+
import hashlib
2018
import logging
21-
from typing import Dict, Optional
19+
import os
20+
import pathlib
21+
import platform
22+
import re
23+
import shutil
24+
import stat
25+
import subprocess
26+
import sys
27+
import urllib.request
2228

2329
logger = logging.getLogger(__name__)
2430

25-
class CommandBuilder:
26-
def __init__(self, base_command: str):
27-
self.command = [base_command]
28-
29-
def append(self, parameter: str):
30-
self.command.append(parameter)
31-
return self
31+
# Default download location for the pinned Bazel release. A package entry may
# override it by carrying its own non-empty ``base_uri``.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"

# One downloadable Bazel binary: where to fetch it from (None means "use
# BAZEL_BASE_URI"), its file name, and the SHA256 digest it must hash to.
BazelPackage = collections.namedtuple(
    "BazelPackage", ("base_uri", "file", "sha256")
)

# Known Bazel binaries, keyed by (platform.system(), platform.machine()).
bazel_packages = {
    platform_key: BazelPackage(base_uri=None, file=file_name, sha256=digest)
    for platform_key, file_name, digest in (
        (
            ("Linux", "x86_64"),
            "bazel-6.5.0-linux-x86_64",
            "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307",
        ),
        (
            ("Linux", "aarch64"),
            "bazel-6.5.0-linux-arm64",
            "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f",
        ),
        (
            ("Darwin", "x86_64"),
            "bazel-6.5.0-darwin-x86_64",
            "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29",
        ),
        (
            ("Darwin", "arm64"),
            "bazel-6.5.0-darwin-arm64",
            "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb",
        ),
        (
            ("Windows", "AMD64"),
            "bazel-6.5.0-windows-x86_64.exe",
            "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6",
        ),
    )
}
72+
73+
def download_and_verify_bazel():
    """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

    The package is selected from `bazel_packages` using the current
    (platform.system(), platform.machine()) pair. If a file with the package's
    name is already present and executable in the current working directory,
    nothing is downloaded and that file is reused without re-hashing it.

    Returns:
      The relative path "./<file name>" of the bazel binary, or None when there
      is no known package for the current platform.

    Exits the process with status -1 if the downloaded file's SHA256 digest
    does not match the expected one.
    """
    package = bazel_packages.get((platform.system(), platform.machine()))
    if package is None:
        return None

    if not os.access(package.file, os.X_OK):
        uri = (package.base_uri or BAZEL_BASE_URI) + package.file
        sys.stdout.write(f"Downloading bazel from: {uri}\n")

        def report_progress(block_count, block_size, total_size):
            # urlretrieve reports a non-positive size when the server does not
            # announce one; fall back to a fixed placeholder denominator.
            if total_size <= 0:
                total_size = 170**6
            # The final block usually overshoots the real size, so clamp to
            # keep the bar at 40 characters and the percentage at <= 100.
            fraction = min((block_count * block_size) / total_size, 1.0)
            num_chars = 40
            filled_chars = int(num_chars * fraction)
            sys.stdout.write(
                "{} [{}{}] {}%\r".format(
                    package.file,
                    "#" * filled_chars,
                    "." * (num_chars - filled_chars),
                    int(fraction * 100.0),
                )
            )

        # Only draw the progress bar when writing to an interactive terminal.
        tmp_path, _ = urllib.request.urlretrieve(
            uri, None, report_progress if sys.stdout.isatty() else None
        )
        try:
            sys.stdout.write("\n")
            with open(tmp_path, "rb") as downloaded_file:
                contents = downloaded_file.read()
        finally:
            # urlretrieve leaves its temporary file behind; remove it so that
            # repeated downloads do not accumulate copies. Best effort only.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

        # Verify that the downloaded Bazel binary has the expected SHA256.
        digest = hashlib.sha256(contents).hexdigest()
        if digest != package.sha256:
            print(
                "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
                .format(package.sha256, digest)
            )
            sys.exit(-1)

        # Write the file as the bazel file name.
        with open(package.file, "wb") as out_file:
            out_file.write(contents)

        # Mark the file as executable.
        st = os.stat(package.file)
        os.chmod(
            package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
        )

    return os.path.join(".", package.file)
35126

36-
def get_command_as_list(self) -> list[str]:
37-
return self.command
127+
def get_bazel_paths(bazel_path_flag):
    """Yields candidate bazel paths, from most to least preferred.

    The candidates are, in order: the value passed in `bazel_path_flag`, the
    `bazel` found on the PATH, and a freshly downloaded bazel binary. Any of
    them may be None. Candidates are produced lazily, so later ones (and their
    side effects, such as the download) only happen if the caller keeps
    iterating.
    """
    candidate_sources = (
        lambda: bazel_path_flag,
        lambda: shutil.which("bazel"),
        download_and_verify_bazel,
    )
    for produce_candidate in candidate_sources:
        yield produce_candidate()
44136

45-
command: str
46-
return_code: int = 2 # Defaults to not successful
47-
logs: str = ""
48-
start_time: datetime.datetime = dataclasses.field(
49-
default_factory=datetime.datetime.now
50-
)
51-
end_time: Optional[datetime.datetime] = None
52-
137+
def get_bazel_path(bazel_path_flag):
    """Returns the path to a Bazel binary, downloading Bazel if not found.

    Also checks that Bazel's version is at least 6.5.0.

    A manual version check is needed only for really old bazel versions.
    Newer bazel releases perform their own version check against .bazelversion
    (see for details
    https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).

    Args:
      bazel_path_flag: Preferred bazel path; falsy values are skipped.

    Returns:
      A `(path, version)` tuple, where `version` is the dotted version string
      of the first candidate whose version is at least 6.5.0.

    Exits the process with status -1 when no candidate qualifies.
    """
    # filter(None, ...) drops candidates that could not be produced while
    # keeping the underlying generator lazy.
    for path in filter(None, get_bazel_paths(bazel_path_flag)):
        version = get_bazel_version(path)
        if version is not None and version >= (6, 5, 0):
            return path, ".".join(map(str, version))

    # The trailing space matters: adjacent literals are concatenated verbatim.
    print(
        "Cannot find or download a suitable version of bazel. "
        "Please install bazel >= 6.5.0."
    )
    sys.exit(-1)
157+
158+
def get_bazel_version(bazel_path):
    """Returns the version reported by `bazel_path --version` as an int tuple.

    Args:
      bazel_path: Path of the bazel executable to query.

    Returns:
      A tuple of ints such as `(6, 5, 0)`, or None if the binary cannot be
      run, exits with a non-zero status, or prints no recognizable version.
    """
    try:
        version_output = subprocess.run(
            [bazel_path, "--version"],
            encoding="utf-8",
            capture_output=True,
            check=True,
        ).stdout.strip()
    except (subprocess.CalledProcessError, OSError):
        return None
    # Capture only dot-separated runs of digits so that every component is a
    # valid int; stray dots or backslashes are never part of the match.
    match = re.search(r"bazel *([0-9]+(?:\.[0-9]+)*)", version_output)
    if match is None:
        return None
    return tuple(int(component) for component in match.group(1).split("."))
172+
173+
def get_clang_path_or_exit():
    """Returns the fully resolved path of the `clang` found on the PATH.

    Exits the process with status -1, after printing a hint to pass
    --clang_path, when no `clang` is found on the PATH.
    """
    clang_on_path = shutil.which("clang")
    if not clang_on_path:
        print(
            "--clang_path is unset and clang cannot be found"
            " on the PATH. Please pass --clang_path directly."
        )
        sys.exit(-1)

    # If we've found a clang on the path, need to get the fully resolved path
    # to ensure that system headers are found.
    resolved_clang = pathlib.Path(clang_on_path).resolve()
    return str(resolved_clang)
185+
186+
def get_clang_major_version(clang_path):
    """Returns the major version of the clang at `clang_path` as an int.

    The version is obtained by preprocessing the `__clang_major__` macro read
    from stdin. Raises subprocess.CalledProcessError if the command exits with
    a non-zero status, and ValueError if its output is not an integer.
    """
    preprocess_command = [clang_path, "-E", "-P", "-"]
    completed = subprocess.run(
        preprocess_command,
        input="__clang_major__",
        text=True,
        capture_output=True,
        check=True,
    )
    return int(completed.stdout)
197+
198+
def get_jax_configure_bazel_options(bazel_command: list[str]):
    """Returns the bazel options to be written to .jax_configure.bazelrc.

    Args:
      bazel_command: The bazel command line as a list of arguments. Every
        argument after the first "run" entry is treated as a build option.

    Returns:
      A string with one "build <option>" line per option, or an empty string
      (after logging an error) if the command contains no "run" argument.
    """
    # Build options come after "run", so locate it and keep everything after
    # it. The lookup itself is what raises ValueError, so it must be guarded.
    try:
        start = bazel_command.index("run")
    except ValueError:
        logging.error("Unable to find index for 'run' in the Bazel command")
        return ""

    # On Windows, replace all backslashes with double backslashes to avoid
    # unintended escape sequences.
    escape_backslashes = platform.system() == "Windows"
    option_lines = []
    for bazel_flag in bazel_command[start + 1:]:
        if escape_backslashes:
            bazel_flag = bazel_flag.replace("\\", "\\\\")
        option_lines.append(f"build {bazel_flag}\n")
    return "".join(option_lines)
216+
217+
def get_githash():
    """Returns the commit hash of the current git HEAD, or "" if unavailable.

    The lookup is best effort: an empty string is returned when git cannot be
    executed (OSError) or when `git rev-parse HEAD` exits with a non-zero
    status, e.g. because the working directory is not inside a git checkout.
    """
    try:
        return subprocess.run(
            ["git", "rev-parse", "HEAD"],
            encoding="utf-8",
            capture_output=True,
            check=True,
        ).stdout.strip()
    except (subprocess.CalledProcessError, OSError):
        # check=True turns a failing git invocation into CalledProcessError,
        # which must be handled just like a missing git executable.
        return ""
227+
228+
def _parse_string_as_bool(s):
    """Converts the string "true" or "false" (in any letter case) to a bool.

    Raises:
      ValueError: If `s` is neither "true" nor "false", ignoring case.
    """
    parsed = {"true": True, "false": False}.get(s.lower())
    if parsed is None:
        raise ValueError(f"Expected either 'true' or 'false'; got {s}")
    return parsed

0 commit comments

Comments
 (0)