Skip to content

Commit 478696c

Browse files
committed
add the modules used by build.py
1 parent 1bf2ca0 commit 478696c

File tree

3 files changed

+350
-1
lines changed

3 files changed

+350
-1
lines changed

build/build.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def parse_and_append_bazel_options(bazel_command: command.CommandBuilder, bazel_
268268
bazel_command.append(option)
269269

270270
async def main():
271-
cwd = os.getcwd()
272271
parser = argparse.ArgumentParser(
273272
description=(
274273
"CLI for building one of the following packages from source: jaxlib, "

build/tools/command.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# Helper script for running subprocess commands.
16+
import asyncio
17+
import dataclasses
18+
import datetime
19+
import os
20+
import logging
21+
from typing import Dict, Optional
22+
23+
logger = logging.getLogger(__name__)
24+
25+
class CommandBuilder:
26+
def __init__(self, base_command: str):
27+
self.command = base_command
28+
29+
def append(self, parameter: str):
30+
self.command += " {}".format(parameter)
31+
return self
32+
33+
@dataclasses.dataclass
34+
class CommandResult:
35+
"""
36+
Represents the result of executing a subprocess command.
37+
"""
38+
39+
command: str
40+
return_code: int = 2 # Defaults to not successful
41+
logs: str = ""
42+
start_time: datetime.datetime = dataclasses.field(
43+
default_factory=datetime.datetime.now
44+
)
45+
end_time: Optional[datetime.datetime] = None
46+
47+
class SubprocessExecutor:
48+
"""
49+
Manages execution of subprocess commands with reusable environment and logging.
50+
"""
51+
52+
def __init__(self, environment: Dict[str, str] = None):
53+
"""
54+
55+
Args:
56+
environment:
57+
"""
58+
self.environment = environment or dict(os.environ)
59+
60+
async def run(self, cmd: str, dry_run: bool = False) -> CommandResult:
61+
"""
62+
Executes a subprocess command.
63+
64+
Args:
65+
cmd: The command to execute.
66+
dry_run: If True, prints the command instead of executing it.
67+
68+
Returns:
69+
A CommandResult instance.
70+
"""
71+
result = CommandResult(command=cmd)
72+
if dry_run:
73+
logger.info("[DRY RUN] %s", cmd)
74+
result.return_code = 0 # Dry run is a success
75+
return result
76+
77+
logger.info("[EXECUTING] %s", cmd)
78+
79+
process = await asyncio.create_subprocess_shell(
80+
cmd,
81+
stdout=asyncio.subprocess.PIPE,
82+
stderr=asyncio.subprocess.PIPE,
83+
env=self.environment,
84+
)
85+
86+
async def log_stream(stream, result: CommandResult):
87+
while True:
88+
line_bytes = await stream.readline()
89+
if not line_bytes:
90+
break
91+
line = line_bytes.decode().rstrip()
92+
result.logs += line
93+
logger.info("%s", line)
94+
95+
await asyncio.gather(
96+
log_stream(process.stdout, result), log_stream(process.stderr, result)
97+
)
98+
99+
result.return_code = await process.wait()
100+
result.end_time = datetime.datetime.now()
101+
logger.debug("Command finished with return code %s", result.return_code)
102+
return result

build/tools/utils.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# Helper script for tools/utilities used by the JAX build CLI.
16+
import collections
17+
import hashlib
18+
import logging
19+
import os
20+
import pathlib
21+
import platform
22+
import re
23+
import shutil
24+
import stat
25+
import subprocess
26+
import sys
27+
import urllib.request
28+
29+
logger = logging.getLogger(__name__)
30+
31+
def is_windows():
32+
return sys.platform.startswith("win32")
33+
34+
def shell(cmd):
35+
try:
36+
logger.info("shell(): %s", cmd)
37+
output = subprocess.check_output(cmd)
38+
except subprocess.CalledProcessError as e:
39+
logger.info("subprocess raised: %s", e)
40+
if e.output:
41+
print(e.output)
42+
raise
43+
except Exception as e:
44+
logger.info("subprocess raised: %s", e)
45+
raise
46+
return output.decode("UTF-8").strip()
47+
48+
49+
# Bazel
50+
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
51+
BazelPackage = collections.namedtuple(
52+
"BazelPackage", ["base_uri", "file", "sha256"]
53+
)
54+
bazel_packages = {
55+
("Linux", "x86_64"): BazelPackage(
56+
base_uri=None,
57+
file="bazel-6.5.0-linux-x86_64",
58+
sha256=(
59+
"a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"
60+
),
61+
),
62+
("Linux", "aarch64"): BazelPackage(
63+
base_uri=None,
64+
file="bazel-6.5.0-linux-arm64",
65+
sha256=(
66+
"5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"
67+
),
68+
),
69+
("Darwin", "x86_64"): BazelPackage(
70+
base_uri=None,
71+
file="bazel-6.5.0-darwin-x86_64",
72+
sha256=(
73+
"bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"
74+
),
75+
),
76+
("Darwin", "arm64"): BazelPackage(
77+
base_uri=None,
78+
file="bazel-6.5.0-darwin-arm64",
79+
sha256=(
80+
"c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"
81+
),
82+
),
83+
("Windows", "AMD64"): BazelPackage(
84+
base_uri=None,
85+
file="bazel-6.5.0-windows-x86_64.exe",
86+
sha256=(
87+
"6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"
88+
),
89+
),
90+
}
91+
92+
93+
def download_and_verify_bazel():
94+
"""Downloads a bazel binary from GitHub, verifying its SHA256 hash."""
95+
package = bazel_packages.get((platform.system(), platform.machine()))
96+
if package is None:
97+
return None
98+
99+
if not os.access(package.file, os.X_OK):
100+
uri = (package.base_uri or BAZEL_BASE_URI) + package.file
101+
sys.stdout.write(f"Downloading bazel from: {uri}\n")
102+
103+
def progress(block_count, block_size, total_size):
104+
if total_size <= 0:
105+
total_size = 170**6
106+
progress = (block_count * block_size) / total_size
107+
num_chars = 40
108+
progress_chars = int(num_chars * progress)
109+
sys.stdout.write(
110+
"{} [{}{}] {}%\r".format(
111+
package.file,
112+
"#" * progress_chars,
113+
"." * (num_chars - progress_chars),
114+
int(progress * 100.0),
115+
)
116+
)
117+
118+
tmp_path, _ = urllib.request.urlretrieve(
119+
uri, None, progress if sys.stdout.isatty() else None
120+
)
121+
sys.stdout.write("\n")
122+
123+
# Verify that the downloaded Bazel binary has the expected SHA256.
124+
with open(tmp_path, "rb") as downloaded_file:
125+
contents = downloaded_file.read()
126+
127+
digest = hashlib.sha256(contents).hexdigest()
128+
if digest != package.sha256:
129+
print(
130+
"Checksum mismatch for downloaded bazel binary (expected {}; got {})."
131+
.format(package.sha256, digest)
132+
)
133+
sys.exit(-1)
134+
135+
# Write the file as the bazel file name.
136+
with open(package.file, "wb") as out_file:
137+
out_file.write(contents)
138+
139+
# Mark the file as executable.
140+
st = os.stat(package.file)
141+
os.chmod(
142+
package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
143+
)
144+
145+
return os.path.join(".", package.file)
146+
147+
148+
def get_bazel_paths(bazel_path_flag):
149+
"""Yields a sequence of guesses about bazel path.
150+
151+
Some of sequence elements can be None. The resulting iterator is lazy and
152+
potentially has a side effects.
153+
"""
154+
yield bazel_path_flag
155+
yield shutil.which("bazel")
156+
yield download_and_verify_bazel()
157+
158+
159+
def get_bazel_path(bazel_path_flag):
160+
"""Returns the path to a Bazel binary, downloading Bazel if not found.
161+
162+
Also, checks Bazel's version is at least newer than 6.5.0
163+
164+
A manual version check is needed only for really old bazel versions.
165+
Newer bazel releases perform their own version check against .bazelversion
166+
(see for details
167+
https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
168+
"""
169+
for path in filter(None, get_bazel_paths(bazel_path_flag)):
170+
version = get_bazel_version(path)
171+
if version is not None and version >= (6, 5, 0):
172+
return path, ".".join(map(str, version))
173+
174+
print(
175+
"Cannot find or download a suitable version of bazel."
176+
"Please install bazel >= 6.5.0."
177+
)
178+
sys.exit(-1)
179+
180+
181+
def get_bazel_version(bazel_path):
182+
try:
183+
version_output = shell([bazel_path, "--version"])
184+
except (subprocess.CalledProcessError, OSError):
185+
return None
186+
match = re.search(r"bazel *([0-9\\.]+)", version_output)
187+
if match is None:
188+
return None
189+
return tuple(int(x) for x in match.group(1).split("."))
190+
191+
192+
def get_clang_path_or_exit():
193+
which_clang_output = shutil.which("clang")
194+
if which_clang_output:
195+
# If we've found a clang on the path, need to get the fully resolved path
196+
# to ensure that system headers are found.
197+
return str(pathlib.Path(which_clang_output).resolve())
198+
else:
199+
print(
200+
"--clang_path is unset and clang cannot be found"
201+
" on the PATH. Please pass --clang_path directly."
202+
)
203+
sys.exit(-1)
204+
205+
def get_githash():
206+
try:
207+
return subprocess.run(
208+
["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True
209+
).stdout.strip()
210+
except OSError:
211+
return ""
212+
213+
def get_bazelrc_config(os_name: str, arch: str, artifact: str, use_rbe: bool):
214+
"""Returns the bazelrc config for the given architecture and OS.
215+
Used in CI builds to retrive either the "ci_"/"rbe_" configs from the .bazelrc
216+
"""
217+
218+
bazelrc_config = f"{os_name}_{arch}"
219+
220+
# If a build is requesting RBE, the CLI will use RBE if the host system supports
221+
# it, otherwise it will use the "ci_" (non RBE) config.
222+
if use_rbe:
223+
if (os_name == "linux" and arch == "x86_64") \
224+
or (os_name == "windows" and arch == "amd64"):
225+
bazelrc_config = "rbe_" + bazelrc_config
226+
else:
227+
logger.warning("RBE is not supported on %s_%s. Using the non RBE, ci_%s_%s, config instead.", os_name, arch)
228+
bazelrc_config = "ci_" + bazelrc_config
229+
else:
230+
bazelrc_config = "ci_" + bazelrc_config
231+
232+
# When building jax-cuda-plugin or jax-cuda-pjrt, append "_cuda" to the
233+
# bazelrc config to use the CUDA specific configs.
234+
if "cuda" in artifact:
235+
bazelrc_config = bazelrc_config + "_cuda"
236+
237+
return bazelrc_config
238+
239+
def adjust_paths_for_windows(output_dir: str, arch: str) -> tuple[str, str]:
240+
"""Adjusts the paths to be compatible with Windows."""
241+
logger.debug("Adjusting paths for Windows...")
242+
output_dir = output_dir.replace("/", "\\")
243+
244+
# Change to upper case to match the case in
245+
# "jax/tools/build_utils.py" for Windows.
246+
arch = arch.upper()
247+
248+
return (output_dir, arch)

0 commit comments

Comments
 (0)