@@ -51,6 +51,8 @@ usage() {
5151 echo " --clean-only Do not build, just cleanup"
5252 echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
5353 echo " --debug Build in debug mode"
54+ echo " --extra-targets T1[,T2[...] Extra bazel targets that will be built and copied to --extra-target-dest."
55+ echo " --extra-target-dest PATH Where extra target output files will be copied to."
5456 echo " -h, --help Print usage."
5557 echo " --install Install the JAX wheels when build succeeds"
5658 echo " --no-install Do not install the JAX wheels when build succeeds"
@@ -75,11 +77,13 @@ CLEANONLY=0
7577CPU_ARCH=" $( dpkg --print-architecture) "
7678CUDA_COMPUTE_CAPABILITIES=" local"
7779DEBUG=0
80+ EXTRA_TARGETS=()
81+ EXTRA_TARGET_DEST=" "
7882INSTALL=1
7983SRC_PATH_JAX=" /opt/jax"
8084SRC_PATH_XLA=" /opt/xla"
8185
82- args=$( getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- " $@ " )
86+ args=$( getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,extra-targets:,extra-target-dest:, no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- " $@ " )
8387if [[ $? -ne 0 ]]; then
8488 exit 1
8589fi
@@ -123,6 +127,14 @@ while [ : ]; do
123127 DEBUG=1
124128 shift 1
125129 ;;
130+ --extra-targets)
131+ IFS=' ,' read -r -a EXTRA_TARGETS <<< " $2"
132+ shift 2
133+ ;;
134+ --extra-target-dest)
135+ EXTRA_TARGET_DEST=" $2 "
136+ shift 2
137+ ;;
126138 -h | --help)
127139 usage 1
128140 ;;
@@ -202,13 +214,20 @@ print_var CPU_ARCH
202214print_var CUDA_COMPUTE_CAPABILITIES
203215print_var CUDA_MAJOR_VERSION
204216print_var DEBUG
217+ print_var EXTRA_TARGETS
218+ print_var EXTRA_TARGET_DEST
205219print_var INSTALL
206220print_var PYTHON_VERSION
207221print_var SRC_PATH_JAX
208222print_var SRC_PATH_XLA
209223
210224echo " =================================================="
211225
226+ if [[ -n " ${EXTRA_TARGET_DEST} " && ! -d " ${EXTRA_TARGET_DEST} " ]]; then
227+ echo " You must pass a directory to --extra-target-dest"
228+ exit 1
229+ fi
230+
212231set -x
213232if [[ ${CLEANONLY} == 1 ]]; then
214233 clean
@@ -245,7 +264,21 @@ for component in jaxlib "jax-cuda${CUDA_MAJOR_VERSION}-pjrt" "jax-cuda${CUDA_MAJ
245264 # version, so nvidia-*-cu12 wheels disappear from the lock file
246265 sed -i " s|^${component} .*$|${component} @ file://${BUILD_PATH_JAXLIB} /${component// -/ _} |" build/requirements.in
247266done
248- bazel run --config=cuda_libraries_from_stubs --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=" ${PYTHON_VERSION} "
267+ # Bazel args to avoid cache invalidation
268+ BAZEL_ARGS=(
269+ --config=cuda_libraries_from_stubs
270+ --repo_env=HERMETIC_PYTHON_VERSION=" ${PYTHON_VERSION} "
271+ )
272+ bazel run " ${BAZEL_ARGS[@]} " --verbose_failures=true //build:requirements.update
273+ if (( "${# EXTRA_TARGETS[@]} " > 0 )) ; then
274+ bazel build " ${BAZEL_ARGS[@]} " --verbose_failures=true " ${EXTRA_TARGETS[@]} "
275+ if [[ -n " ${EXTRA_TARGET_DEST} " ]]; then
276+ for target in " ${EXTRA_TARGETS[@]} " ; do
277+ output_files=$( bazel cquery " ${BAZEL_ARGS[@]} " " ${target} " --output files)
278+ cp -v ${output_files} " ${EXTRA_TARGET_DEST} "
279+ done
280+ fi
281+ fi
249282popd
250283
251284# # Install the built packages
0 commit comments