1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15- # Helper script for the JAX build CLI for running subprocess commands.
16- import asyncio
17- import dataclasses
18- import datetime
19- import os
15+ # Helper script for tools/utilities used by the JAX build CLI.
16+ import collections
17+ import hashlib
2018import logging
21- from typing import Dict , Optional
19+ import os
20+ import pathlib
21+ import platform
22+ import re
23+ import shutil
24+ import stat
25+ import subprocess
26+ import sys
27+ import urllib .request
2228
2329logger = logging .getLogger (__name__ )
2430
25- class CommandBuilder :
26- def __init__ (self , base_command : str ):
27- self .command = [base_command ]
28-
29- def append (self , parameter : str ):
30- self .command .append (parameter )
31- return self
# Release directory hosting the pinned Bazel 6.5.0 binaries. It is used for
# every package entry that does not carry its own base URI.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"

# One downloadable Bazel binary: an optional base-URI override (None means
# "use BAZEL_BASE_URI"), the release file name, and the hex SHA256 digest the
# downloaded bytes must match.
BazelPackage = collections.namedtuple(
    "BazelPackage", ("base_uri", "file", "sha256")
)

# Known Bazel binaries, keyed by (platform.system(), platform.machine()).
bazel_packages = {
    ("Linux", "x86_64"): BazelPackage(
        None,
        "bazel-6.5.0-linux-x86_64",
        "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307",
    ),
    ("Linux", "aarch64"): BazelPackage(
        None,
        "bazel-6.5.0-linux-arm64",
        "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f",
    ),
    ("Darwin", "x86_64"): BazelPackage(
        None,
        "bazel-6.5.0-darwin-x86_64",
        "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29",
    ),
    ("Darwin", "arm64"): BazelPackage(
        None,
        "bazel-6.5.0-darwin-arm64",
        "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb",
    ),
    ("Windows", "AMD64"): BazelPackage(
        None,
        "bazel-6.5.0-windows-x86_64.exe",
        "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6",
    ),
}
72+
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The binary is chosen from `bazel_packages` using the current
  (platform.system(), platform.machine()) pair. If a file with the package's
  name already exists in the current directory and is executable, nothing is
  downloaded (and no checksum is computed for that existing file).

  Returns:
    The relative path "./<file name>" of the bazel binary, or None if there is
    no known package for the current platform.

  Exits the process with status -1 if the downloaded bytes do not match the
  expected SHA256 digest.
  """
  package = bazel_packages.get((platform.system(), platform.machine()))
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = (package.base_uri or BAZEL_BASE_URI) + package.file
    sys.stdout.write(f"Downloading bazel from: {uri}\n")

    def progress(block_count, block_size, total_size):
      # urlretrieve reports a non-positive total size when the server does not
      # announce one; fall back to a large placeholder so the bar still draws.
      if total_size <= 0:
        total_size = 170**6
      # The final block usually overshoots total_size; clamp so the bar never
      # grows past its width and the percentage never exceeds 100.
      fraction = min(1.0, (block_count * block_size) / total_size)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write(
          "{} [{}{}] {}%\r".format(
              package.file,
              "#" * progress_chars,
              "." * (num_chars - progress_chars),
              int(fraction * 100.0),
          )
      )

    # Only draw the progress bar when stdout is an interactive terminal.
    tmp_path, _ = urllib.request.urlretrieve(
        uri, None, progress if sys.stdout.isatty() else None
    )
    try:
      sys.stdout.write("\n")
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      # urlretrieve stored the download in a temporary file that nothing else
      # deletes; remove it now that its bytes are in memory. Best effort only.
      try:
        os.remove(tmp_path)
      except OSError:
        pass

    # Verify that the downloaded Bazel binary has the expected SHA256.
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest)
      )
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(
        package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
    )

  return os.path.join(".", package.file)
35126
36- def get_command_as_list ( self ) -> list [ str ] :
37- return self . command
def get_bazel_paths(bazel_path_flag):
  """Lazily yields candidate bazel paths, in order of preference.

  The candidates are, in order: `bazel_path_flag` itself, a `bazel` executable
  found on the PATH, and the result of `download_and_verify_bazel()`. Any of
  them may be None. Each candidate is computed only when the consumer asks for
  it, so the download (a side effect) is attempted only if iteration reaches
  the last element.
  """
  candidate_sources = (
      lambda: bazel_path_flag,
      lambda: shutil.which("bazel"),
      lambda: download_and_verify_bazel(),
  )
  for produce_candidate in candidate_sources:
    yield produce_candidate()
53136
54- async def _process_log_stream (stream , result : CommandResult ):
55- """Logs the output of a subprocess stream."""
56- while True :
57- line_bytes = await stream .readline ()
58- if not line_bytes :
59- break
60- line = line_bytes .decode ().rstrip ()
61- result .logs += line
62- logger .info ("%s" , line )
def get_bazel_path(bazel_path_flag):
  """Returns the path to a Bazel binary, downloading Bazel if not found.

  Candidates from `get_bazel_paths(bazel_path_flag)` are tried in order
  (skipping None/empty ones) and the first whose version is at least 6.5.0 is
  used.

  A manual version check is needed only for really old bazel versions.
  Newer bazel releases perform their own version check against .bazelversion
  (see for details
  https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).

  Returns:
    A `(path, version)` tuple, where `version` is the dotted version string
    (for example "6.5.0").

  Exits the process with status -1 if no suitable bazel could be found or
  downloaded.
  """
  for path in filter(None, get_bazel_paths(bazel_path_flag)):
    version = get_bazel_version(path)
    if version is not None and version >= (6, 5, 0):
      return path, ".".join(map(str, version))

  # The two literals are concatenated, so the separating space must be
  # explicit; without it the message reads "...of bazel.Please install...".
  print(
      "Cannot find or download a suitable version of bazel. "
      "Please install bazel >= 6.5.0."
  )
  sys.exit(-1)
157+
def get_bazel_version(bazel_path):
  """Returns the version reported by `<bazel_path> --version` as an int tuple.

  Args:
    bazel_path: Path of the bazel executable to query.

  Returns:
    A tuple of ints such as `(6, 5, 0)`, or None if the binary cannot be run,
    exits with a non-zero status, or prints nothing that looks like
    "bazel <dotted version>".
  """
  try:
    version_output = subprocess.run(
        [bazel_path, "--version"],
        encoding="utf-8",
        capture_output=True,
        check=True,
    ).stdout.strip()
  except (subprocess.CalledProcessError, OSError):
    return None
  # Match only digit groups separated by single dots, so every component is a
  # non-empty integer. A looser character class would also accept leading,
  # trailing or doubled dots (and backslashes) and make int() raise below.
  match = re.search(r"bazel *([0-9]+(?:\.[0-9]+)*)", version_output)
  if match is None:
    return None
  return tuple(int(x) for x in match.group(1).split("."))
172+
def get_clang_path_or_exit():
  """Returns the fully resolved path of the `clang` found on the PATH.

  If no `clang` is on the PATH, prints a message asking the user to pass
  --clang_path and exits the process with status -1.
  """
  clang_on_path = shutil.which("clang")
  if not clang_on_path:
    print(
        "--clang_path is unset and clang cannot be found"
        " on the PATH. Please pass --clang_path directly."
    )
    sys.exit(-1)

  # A clang found on the PATH must be fully resolved so that system headers
  # are found.
  resolved = pathlib.Path(clang_on_path).resolve()
  return str(resolved)
185+
def get_clang_major_version(clang_path):
  """Returns the major version of the clang at `clang_path` as an int.

  The version is obtained by running clang's preprocessor (`-E -P`) on the
  text `__clang_major__` supplied via stdin and parsing what it prints.
  Because the command runs with `check=True`, a non-zero exit status raises
  `subprocess.CalledProcessError`.
  """
  preprocess_cmd = [clang_path, "-E", "-P", "-"]
  completed = subprocess.run(
      preprocess_cmd,
      input="__clang_major__",
      text=True,
      capture_output=True,
      check=True,
  )
  return int(completed.stdout)
197+
def get_jax_configure_bazel_options(bazel_command: list[str]) -> str:
  """Returns the bazel options to be written to .jax_configure.bazelrc.

  Args:
    bazel_command: The bazel command line as a list of arguments. Everything
      after the first "run" element is treated as a build option.

  Returns:
    One "build <option>" line (newline-terminated) per option, concatenated
    into a single string. Returns "" and logs an error if `bazel_command`
    contains no "run" element.
  """
  # Build options come after "run", so locate it first. The lookup is the only
  # statement that can raise ValueError, so it is the only one guarded.
  try:
    start = bazel_command.index("run")
  except ValueError:
    logging.error("Unable to find index for 'run' in the Bazel command")
    return ""

  # On Windows, replace all backslashes with double backslashes to avoid
  # unintended escape sequences.
  escape_backslashes = platform.system() == "Windows"
  lines = []
  for bazel_flag in bazel_command[start + 1:]:
    if escape_backslashes:
      bazel_flag = bazel_flag.replace("\\", "\\\\")
    lines.append(f"build {bazel_flag}\n")
  return "".join(lines)
216+
def get_githash():
  """Returns the commit hash of HEAD as reported by `git rev-parse HEAD`.

  Returns:
    The stripped output of the git command, or "" if git cannot be executed
    or exits with a non-zero status (for example when the current directory
    is not inside a git repository).
  """
  try:
    return subprocess.run(
        ["git", "rev-parse", "HEAD"],
        encoding="utf-8",
        capture_output=True,
        check=True,
    ).stdout.strip()
  # check=True turns a failing git invocation into CalledProcessError, which
  # is not an OSError; treat both as "no hash available".
  except (subprocess.CalledProcessError, OSError):
    return ""
112227
113228def _parse_string_as_bool (s ):
114229 """Parses a string as a boolean value."""
@@ -118,4 +233,4 @@ def _parse_string_as_bool(s):
118233 elif lower == "false" :
119234 return False
120235 else :
121- raise ValueError (f"Expected either 'true' or 'false'; got { s } " )
236+ raise ValueError (f"Expected either 'true' or 'false'; got { s } " )
0 commit comments