1+ # Copyright 2024 The JAX Authors.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+ # Helper script for tools/utilities used by the JAX build CLI.
16+ import collections
17+ import hashlib
18+ import logging
19+ import os
20+ import pathlib
21+ import platform
22+ import re
23+ import shutil
24+ import stat
25+ import subprocess
26+ import sys
27+ import urllib .request
28+
29+ logger = logging .getLogger (__name__ )
30+
31+ def is_windows ():
32+ return sys .platform .startswith ("win32" )
33+
34+ def shell (cmd ):
35+ try :
36+ logger .info ("shell(): %s" , cmd )
37+ output = subprocess .check_output (cmd )
38+ except subprocess .CalledProcessError as e :
39+ logger .info ("subprocess raised: %s" , e )
40+ if e .output :
41+ print (e .output )
42+ raise
43+ except Exception as e :
44+ logger .info ("subprocess raised: %s" , e )
45+ raise
46+ return output .decode ("UTF-8" ).strip ()
47+
48+
49+ # Bazel
50+ BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
51+ BazelPackage = collections .namedtuple (
52+ "BazelPackage" , ["base_uri" , "file" , "sha256" ]
53+ )
54+ bazel_packages = {
55+ ("Linux" , "x86_64" ): BazelPackage (
56+ base_uri = None ,
57+ file = "bazel-6.5.0-linux-x86_64" ,
58+ sha256 = (
59+ "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"
60+ ),
61+ ),
62+ ("Linux" , "aarch64" ): BazelPackage (
63+ base_uri = None ,
64+ file = "bazel-6.5.0-linux-arm64" ,
65+ sha256 = (
66+ "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"
67+ ),
68+ ),
69+ ("Darwin" , "x86_64" ): BazelPackage (
70+ base_uri = None ,
71+ file = "bazel-6.5.0-darwin-x86_64" ,
72+ sha256 = (
73+ "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"
74+ ),
75+ ),
76+ ("Darwin" , "arm64" ): BazelPackage (
77+ base_uri = None ,
78+ file = "bazel-6.5.0-darwin-arm64" ,
79+ sha256 = (
80+ "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"
81+ ),
82+ ),
83+ ("Windows" , "AMD64" ): BazelPackage (
84+ base_uri = None ,
85+ file = "bazel-6.5.0-windows-x86_64.exe" ,
86+ sha256 = (
87+ "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"
88+ ),
89+ ),
90+ }
91+
92+
93+ def download_and_verify_bazel ():
94+ """Downloads a bazel binary from GitHub, verifying its SHA256 hash."""
95+ package = bazel_packages .get ((platform .system (), platform .machine ()))
96+ if package is None :
97+ return None
98+
99+ if not os .access (package .file , os .X_OK ):
100+ uri = (package .base_uri or BAZEL_BASE_URI ) + package .file
101+ sys .stdout .write (f"Downloading bazel from: { uri } \n " )
102+
103+ def progress (block_count , block_size , total_size ):
104+ if total_size <= 0 :
105+ total_size = 170 ** 6
106+ progress = (block_count * block_size ) / total_size
107+ num_chars = 40
108+ progress_chars = int (num_chars * progress )
109+ sys .stdout .write (
110+ "{} [{}{}] {}%\r " .format (
111+ package .file ,
112+ "#" * progress_chars ,
113+ "." * (num_chars - progress_chars ),
114+ int (progress * 100.0 ),
115+ )
116+ )
117+
118+ tmp_path , _ = urllib .request .urlretrieve (
119+ uri , None , progress if sys .stdout .isatty () else None
120+ )
121+ sys .stdout .write ("\n " )
122+
123+ # Verify that the downloaded Bazel binary has the expected SHA256.
124+ with open (tmp_path , "rb" ) as downloaded_file :
125+ contents = downloaded_file .read ()
126+
127+ digest = hashlib .sha256 (contents ).hexdigest ()
128+ if digest != package .sha256 :
129+ print (
130+ "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
131+ .format (package .sha256 , digest )
132+ )
133+ sys .exit (- 1 )
134+
135+ # Write the file as the bazel file name.
136+ with open (package .file , "wb" ) as out_file :
137+ out_file .write (contents )
138+
139+ # Mark the file as executable.
140+ st = os .stat (package .file )
141+ os .chmod (
142+ package .file , st .st_mode | stat .S_IXUSR | stat .S_IXGRP | stat .S_IXOTH
143+ )
144+
145+ return os .path .join ("." , package .file )
146+
147+
148+ def get_bazel_paths (bazel_path_flag ):
149+ """Yields a sequence of guesses about bazel path.
150+
151+ Some of sequence elements can be None. The resulting iterator is lazy and
152+ potentially has a side effects.
153+ """
154+ yield bazel_path_flag
155+ yield shutil .which ("bazel" )
156+ yield download_and_verify_bazel ()
157+
158+
159+ def get_bazel_path (bazel_path_flag ):
160+ """Returns the path to a Bazel binary, downloading Bazel if not found.
161+
162+ Also, checks Bazel's version is at least newer than 6.5.0
163+
164+ A manual version check is needed only for really old bazel versions.
165+ Newer bazel releases perform their own version check against .bazelversion
166+ (see for details
167+ https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
168+ """
169+ for path in filter (None , get_bazel_paths (bazel_path_flag )):
170+ version = get_bazel_version (path )
171+ if version is not None and version >= (6 , 5 , 0 ):
172+ return path , "." .join (map (str , version ))
173+
174+ print (
175+ "Cannot find or download a suitable version of bazel."
176+ "Please install bazel >= 6.5.0."
177+ )
178+ sys .exit (- 1 )
179+
180+
181+ def get_bazel_version (bazel_path ):
182+ try :
183+ version_output = shell ([bazel_path , "--version" ])
184+ except (subprocess .CalledProcessError , OSError ):
185+ return None
186+ match = re .search (r"bazel *([0-9\\.]+)" , version_output )
187+ if match is None :
188+ return None
189+ return tuple (int (x ) for x in match .group (1 ).split ("." ))
190+
191+
192+ def get_clang_path_or_exit ():
193+ which_clang_output = shutil .which ("clang" )
194+ if which_clang_output :
195+ # If we've found a clang on the path, need to get the fully resolved path
196+ # to ensure that system headers are found.
197+ return str (pathlib .Path (which_clang_output ).resolve ())
198+ else :
199+ print (
200+ "--clang_path is unset and clang cannot be found"
201+ " on the PATH. Please pass --clang_path directly."
202+ )
203+ sys .exit (- 1 )
204+
205+ def get_githash ():
206+ try :
207+ return subprocess .run (
208+ ["git" , "rev-parse" , "HEAD" ], encoding = "utf-8" , capture_output = True
209+ ).stdout .strip ()
210+ except OSError :
211+ return ""
212+
213+ def get_bazelrc_config (os_name : str , arch : str , artifact : str , use_rbe : bool ):
214+ """Returns the bazelrc config for the given architecture and OS.
215+ Used in CI builds to retrive either the "ci_"/"rbe_" configs from the .bazelrc
216+ """
217+
218+ bazelrc_config = f"{ os_name } _{ arch } "
219+
220+ # If a build is requesting RBE, the CLI will use RBE if the host system supports
221+ # it, otherwise it will use the "ci_" (non RBE) config.
222+ if use_rbe :
223+ if (os_name == "linux" and arch == "x86_64" ) \
224+ or (os_name == "windows" and arch == "amd64" ):
225+ bazelrc_config = "rbe_" + bazelrc_config
226+ else :
227+ logger .warning ("RBE is not supported on %s_%s. Using the non RBE, ci_%s_%s, config instead." , os_name , arch )
228+ bazelrc_config = "ci_" + bazelrc_config
229+ else :
230+ bazelrc_config = "ci_" + bazelrc_config
231+
232+ # When building jax-cuda-plugin or jax-cuda-pjrt, append "_cuda" to the
233+ # bazelrc config to use the CUDA specific configs.
234+ if "cuda" in artifact :
235+ bazelrc_config = bazelrc_config + "_cuda"
236+
237+ return bazelrc_config
238+
239+ def adjust_paths_for_windows (output_dir : str , arch : str ) -> tuple [str , str ]:
240+ """Adjusts the paths to be compatible with Windows."""
241+ logger .debug ("Adjusting paths for Windows..." )
242+ output_dir = output_dir .replace ("/" , "\\ " )
243+
244+ # Change to upper case to match the case in
245+ # "jax/tools/build_utils.py" for Windows.
246+ arch = arch .upper ()
247+
248+ return (output_dir , arch )
0 commit comments