|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
15 | | -# Helper script for the JAX build CLI for running subprocess commands. |
16 | | -import asyncio |
17 | | -import dataclasses |
18 | | -import datetime |
19 | | -import os |
| 15 | +# Helper script for tools/utilities used by the JAX build CLI. |
| 16 | +import collections |
| 17 | +import hashlib |
20 | 18 | import logging |
21 | | -from typing import Dict, Optional |
| 19 | +import os |
| 20 | +import pathlib |
| 21 | +import platform |
| 22 | +import re |
| 23 | +import shutil |
| 24 | +import stat |
| 25 | +import subprocess |
| 26 | +import sys |
| 27 | +import urllib.request |
22 | 28 |
|
23 | 29 | logger = logging.getLogger(__name__) |
24 | 30 |
|
25 | | -class CommandBuilder: |
26 | | - def __init__(self, base_command: str): |
27 | | - self.command = [base_command] |
28 | | - |
29 | | - def append(self, parameter: str): |
30 | | - self.command.append(parameter) |
31 | | - return self |
| 31 | +BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" |
| 32 | +BazelPackage = collections.namedtuple( |
| 33 | + "BazelPackage", ["base_uri", "file", "sha256"] |
| 34 | +) |
| 35 | +bazel_packages = { |
| 36 | + ("Linux", "x86_64"): BazelPackage( |
| 37 | + base_uri=None, |
| 38 | + file="bazel-6.5.0-linux-x86_64", |
| 39 | + sha256=( |
| 40 | + "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307" |
| 41 | + ), |
| 42 | + ), |
| 43 | + ("Linux", "aarch64"): BazelPackage( |
| 44 | + base_uri=None, |
| 45 | + file="bazel-6.5.0-linux-arm64", |
| 46 | + sha256=( |
| 47 | + "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f" |
| 48 | + ), |
| 49 | + ), |
| 50 | + ("Darwin", "x86_64"): BazelPackage( |
| 51 | + base_uri=None, |
| 52 | + file="bazel-6.5.0-darwin-x86_64", |
| 53 | + sha256=( |
| 54 | + "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29" |
| 55 | + ), |
| 56 | + ), |
| 57 | + ("Darwin", "arm64"): BazelPackage( |
| 58 | + base_uri=None, |
| 59 | + file="bazel-6.5.0-darwin-arm64", |
| 60 | + sha256=( |
| 61 | + "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb" |
| 62 | + ), |
| 63 | + ), |
| 64 | + ("Windows", "AMD64"): BazelPackage( |
| 65 | + base_uri=None, |
| 66 | + file="bazel-6.5.0-windows-x86_64.exe", |
| 67 | + sha256=( |
| 68 | + "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6" |
| 69 | + ), |
| 70 | + ), |
| 71 | +} |
| 72 | + |
| 73 | +def download_and_verify_bazel(): |
| 74 | + """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" |
| 75 | + package = bazel_packages.get((platform.system(), platform.machine())) |
| 76 | + if package is None: |
| 77 | + return None |
| 78 | + |
| 79 | + if not os.access(package.file, os.X_OK): |
| 80 | + uri = (package.base_uri or BAZEL_BASE_URI) + package.file |
| 81 | + sys.stdout.write(f"Downloading bazel from: {uri}\n") |
| 82 | + |
| 83 | + def progress(block_count, block_size, total_size): |
| 84 | + if total_size <= 0: |
| 85 | + total_size = 170**6 |
| 86 | + progress = (block_count * block_size) / total_size |
| 87 | + num_chars = 40 |
| 88 | + progress_chars = int(num_chars * progress) |
| 89 | + sys.stdout.write( |
| 90 | + "{} [{}{}] {}%\r".format( |
| 91 | + package.file, |
| 92 | + "#" * progress_chars, |
| 93 | + "." * (num_chars - progress_chars), |
| 94 | + int(progress * 100.0), |
| 95 | + ) |
| 96 | + ) |
| 97 | + |
| 98 | + tmp_path, _ = urllib.request.urlretrieve( |
| 99 | + uri, None, progress if sys.stdout.isatty() else None |
| 100 | + ) |
| 101 | + sys.stdout.write("\n") |
| 102 | + |
| 103 | + # Verify that the downloaded Bazel binary has the expected SHA256. |
| 104 | + with open(tmp_path, "rb") as downloaded_file: |
| 105 | + contents = downloaded_file.read() |
| 106 | + |
| 107 | + digest = hashlib.sha256(contents).hexdigest() |
| 108 | + if digest != package.sha256: |
| 109 | + print( |
| 110 | + "Checksum mismatch for downloaded bazel binary (expected {}; got {})." |
| 111 | + .format(package.sha256, digest) |
| 112 | + ) |
| 113 | + sys.exit(-1) |
| 114 | + |
| 115 | + # Write the file as the bazel file name. |
| 116 | + with open(package.file, "wb") as out_file: |
| 117 | + out_file.write(contents) |
| 118 | + |
| 119 | + # Mark the file as executable. |
| 120 | + st = os.stat(package.file) |
| 121 | + os.chmod( |
| 122 | + package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH |
| 123 | + ) |
32 | 124 |
|
33 | | - def get_command_as_string(self) -> str: |
34 | | - return " ".join(self.command) |
| 125 | + return os.path.join(".", package.file) |
35 | 126 |
|
36 | | - def get_command_as_list(self) -> list[str]: |
37 | | - return self.command |
| 127 | +def get_bazel_paths(bazel_path_flag): |
| 128 | + """Yields a sequence of guesses about bazel path. |
38 | 129 |
|
39 | | -@dataclasses.dataclass |
40 | | -class CommandResult: |
41 | | - """ |
42 | | - Represents the result of executing a subprocess command. |
| 130 | + Some of sequence elements can be None. The resulting iterator is lazy and |
| 131 | + potentially has a side effects. |
43 | 132 | """ |
| 133 | + yield bazel_path_flag |
| 134 | + yield shutil.which("bazel") |
| 135 | + yield download_and_verify_bazel() |
44 | 136 |
|
45 | | - command: str |
46 | | - return_code: int = 2 # Defaults to not successful |
47 | | - logs: str = "" |
48 | | - start_time: datetime.datetime = dataclasses.field( |
49 | | - default_factory=datetime.datetime.now |
50 | | - ) |
51 | | - end_time: Optional[datetime.datetime] = None |
52 | | - |
| 137 | +def get_bazel_path(bazel_path_flag): |
| 138 | + """Returns the path to a Bazel binary, downloading Bazel if not found. |
53 | 139 |
|
54 | | -async def _process_log_stream(stream, result: CommandResult): |
55 | | - """Logs the output of a subprocess stream.""" |
56 | | - while True: |
57 | | - line_bytes = await stream.readline() |
58 | | - if not line_bytes: |
59 | | - break |
60 | | - line = line_bytes.decode().rstrip() |
61 | | - result.logs += line |
62 | | - logger.info("%s", line) |
| 140 | + Also, checks Bazel's version is at least newer than 6.5.0 |
63 | 141 |
|
64 | | - |
65 | | -class SubprocessExecutor: |
66 | | - """ |
67 | | - Manages execution of subprocess commands with reusable environment and logging. |
| 142 | + A manual version check is needed only for really old bazel versions. |
| 143 | + Newer bazel releases perform their own version check against .bazelversion |
| 144 | + (see for details |
| 145 | + https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). |
68 | 146 | """ |
69 | | - |
70 | | - def __init__(self, environment: Dict[str, str] = None): |
71 | | - """ |
72 | | -
|
73 | | - Args: |
74 | | - environment: |
75 | | - """ |
76 | | - self.environment = environment or dict(os.environ) |
77 | | - |
78 | | - async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: |
79 | | - """ |
80 | | - Executes a subprocess command. |
81 | | -
|
82 | | - Args: |
83 | | - cmd: The command to execute. |
84 | | - dry_run: If True, prints the command instead of executing it. |
85 | | -
|
86 | | - Returns: |
87 | | - A CommandResult instance. |
88 | | - """ |
89 | | - result = CommandResult(command=cmd) |
90 | | - if dry_run: |
91 | | - logger.info("[DRY RUN] %s", cmd) |
92 | | - result.return_code = 0 # Dry run is a success |
93 | | - return result |
94 | | - |
95 | | - logger.info("[EXECUTING] %s", cmd) |
96 | | - |
97 | | - process = await asyncio.create_subprocess_shell( |
98 | | - cmd, |
99 | | - stdout=asyncio.subprocess.PIPE, |
100 | | - stderr=asyncio.subprocess.PIPE, |
101 | | - env=self.environment, |
102 | | - ) |
103 | | - |
104 | | - await asyncio.gather( |
105 | | - _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) |
| 147 | + for path in filter(None, get_bazel_paths(bazel_path_flag)): |
| 148 | + version = get_bazel_version(path) |
| 149 | + if version is not None and version >= (6, 5, 0): |
| 150 | + return path, ".".join(map(str, version)) |
| 151 | + |
| 152 | + print( |
| 153 | + "Cannot find or download a suitable version of bazel." |
| 154 | + "Please install bazel >= 6.5.0." |
| 155 | + ) |
| 156 | + sys.exit(-1) |
| 157 | + |
| 158 | +def get_bazel_version(bazel_path): |
| 159 | + try: |
| 160 | + version_output = subprocess.run( |
| 161 | + [bazel_path, "--version"], |
| 162 | + encoding="utf-8", |
| 163 | + capture_output=True, |
| 164 | + check=True, |
| 165 | + ).stdout.strip() |
| 166 | + except (subprocess.CalledProcessError, OSError): |
| 167 | + return None |
| 168 | + match = re.search(r"bazel *([0-9\\.]+)", version_output) |
| 169 | + if match is None: |
| 170 | + return None |
| 171 | + return tuple(int(x) for x in match.group(1).split(".")) |
| 172 | + |
| 173 | +def get_clang_path_or_exit(): |
| 174 | + which_clang_output = shutil.which("clang") |
| 175 | + if which_clang_output: |
| 176 | + # If we've found a clang on the path, need to get the fully resolved path |
| 177 | + # to ensure that system headers are found. |
| 178 | + return str(pathlib.Path(which_clang_output).resolve()) |
| 179 | + else: |
| 180 | + print( |
| 181 | + "--clang_path is unset and clang cannot be found" |
| 182 | + " on the PATH. Please pass --clang_path directly." |
106 | 183 | ) |
107 | | - |
108 | | - result.return_code = await process.wait() |
109 | | - result.end_time = datetime.datetime.now() |
110 | | - logger.debug("Command finished with return code %s", result.return_code) |
111 | | - return result |
| 184 | + sys.exit(-1) |
| 185 | + |
| 186 | +def get_clang_major_version(clang_path): |
| 187 | + clang_version_proc = subprocess.run( |
| 188 | + [clang_path, "-E", "-P", "-"], |
| 189 | + input="__clang_major__", |
| 190 | + check=True, |
| 191 | + capture_output=True, |
| 192 | + text=True, |
| 193 | + ) |
| 194 | + major_version = int(clang_version_proc.stdout) |
| 195 | + |
| 196 | + return major_version |
| 197 | + |
| 198 | +def get_jax_configure_bazel_options(bazel_command: list[str]): |
| 199 | + """Returns the bazel options to be written to .jax_configure.bazelrc.""" |
| 200 | + # Get the index of the "run" parameter. Build options will come after "run" so |
| 201 | + # we find the index of "run" and filter everything after it. |
| 202 | + start = bazel_command.index("run") |
| 203 | + jax_configure_bazel_options = "" |
| 204 | + try: |
| 205 | + for i in range(start + 1, len(bazel_command)): |
| 206 | + bazel_flag = bazel_command[i] |
| 207 | + # On Windows, replace all backslashes with double backslashes to avoid |
| 208 | + # unintended escape sequences. |
| 209 | + if platform.system() == "Windows": |
| 210 | + bazel_flag = bazel_flag.replace("\\", "\\\\") |
| 211 | + jax_configure_bazel_options += f"build {bazel_flag}\n" |
| 212 | + return jax_configure_bazel_options |
| 213 | + except ValueError: |
| 214 | + logging.error("Unable to find index for 'run' in the Bazel command") |
| 215 | + return "" |
| 216 | + |
| 217 | +def get_githash(): |
| 218 | + try: |
| 219 | + return subprocess.run( |
| 220 | + ["git", "rev-parse", "HEAD"], |
| 221 | + encoding="utf-8", |
| 222 | + capture_output=True, |
| 223 | + check=True, |
| 224 | + ).stdout.strip() |
| 225 | + except OSError: |
| 226 | + return "" |
| 227 | + |
| 228 | +def _parse_string_as_bool(s): |
| 229 | + """Parses a string as a boolean value.""" |
| 230 | + lower = s.lower() |
| 231 | + if lower == "true": |
| 232 | + return True |
| 233 | + elif lower == "false": |
| 234 | + return False |
| 235 | + else: |
| 236 | + raise ValueError(f"Expected either 'true' or 'false'; got {s}") |
0 commit comments