Skip to content

Commit e49e9b2

Browse files
authored
build-jax.sh: plumb through building + saving extra targets (#1515)
Example usage: `--extra-targets=@xla//xla/tools/multihost_hlo_runner:hlo_runner_main` to build the multi-host HLO runner. Integrating `--extra-target-dest` into the script makes it useful in conjunction with `--clean`.
1 parent a3c31cf commit e49e9b2

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

.github/container/Dockerfile.jax

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# syntax=docker/dockerfile:1-labs
22
ARG BASE_IMAGE=ghcr.io/nvidia/jax:base
33
ARG BUILD_PATH_JAXLIB=/opt/jaxlibs
4+
# Extra targets to build and copy outputs of, can be used for HLO tools. For example
5+
# @xla//xla/tools/multihost_hlo_runner:hlo_runner_main,@xla//xla/hlo/tools:hlo-opt
6+
ARG EXTRA_BAZEL_TARGETS=""
47
ARG URLREF_JAX=https://github.com/google/jax.git#main
58
ARG URLREF_XLA=https://github.com/openxla/xla.git#main
69
ARG URLREF_FLAX=https://github.com/google/flax.git#main
@@ -28,6 +31,7 @@ ARG SRC_PATH_TRANSFORMER_ENGINE
2831
ARG SRC_PATH_XLA
2932
ARG BAZEL_CACHE
3033
ARG BUILD_PATH_JAXLIB
34+
ARG EXTRA_BAZEL_TARGETS
3135
ARG GIT_USER_NAME
3236
ARG GIT_USER_EMAIL
3337

@@ -45,9 +49,11 @@ RUN ARCH="$(dpkg --print-architecture)" && \
4549
chmod +x /usr/local/bin/bazel
4650
# Populate ${BUILD_PATH_JAXLIB} with editable wheels; --no-install because
4751
# (a) this is the builder stage, and (b) pip-finalize.sh does the install
48-
RUN build-jax.sh \
52+
RUN mkdir -p /builder/extra-targets && build-jax.sh \
4953
--bazel-cache ${BAZEL_CACHE} \
5054
--build-path-jaxlib ${BUILD_PATH_JAXLIB} \
55+
--extra-targets "${EXTRA_BAZEL_TARGETS}" \
56+
--extra-target-dest /builder/extra-targets \
5157
--no-install \
5258
--src-path-jax ${SRC_PATH_JAX} \
5359
--src-path-xla ${SRC_PATH_XLA} \
@@ -91,6 +97,7 @@ COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
9197
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
9298
COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
9399
COPY --from=builder /usr/local/bin/bazel /usr/local/bin/bazel
100+
COPY --from=builder /builder/extra-targets/* /usr/local/bin/
94101
# Preserve the versions of jax and xla
95102
COPY --from=builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml
96103
ADD build-jax.sh build-te.sh local_cuda_arch pytest-xdist.sh test-jax.sh /usr/local/bin/

.github/container/build-jax.sh

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ usage() {
5151
echo " --clean-only Do not build, just cleanup"
5252
echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
5353
echo " --debug Build in debug mode"
54+
echo " --extra-targets T1[,T2[...] Extra bazel targets that will be built and copied to --extra-target-dest."
55+
echo " --extra-target-dest PATH Where extra target output files will be copied to."
5456
echo " -h, --help Print usage."
5557
echo " --install Install the JAX wheels when build succeeds"
5658
echo " --no-install Do not install the JAX wheels when build succeeds"
@@ -75,11 +77,13 @@ CLEANONLY=0
7577
CPU_ARCH="$(dpkg --print-architecture)"
7678
CUDA_COMPUTE_CAPABILITIES="local"
7779
DEBUG=0
80+
EXTRA_TARGETS=()
81+
EXTRA_TARGET_DEST=""
7882
INSTALL=1
7983
SRC_PATH_JAX="/opt/jax"
8084
SRC_PATH_XLA="/opt/xla"
8185

82-
args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
86+
args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,extra-targets:,extra-target-dest:,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
8387
if [[ $? -ne 0 ]]; then
8488
exit 1
8589
fi
@@ -123,6 +127,14 @@ while [ : ]; do
123127
DEBUG=1
124128
shift 1
125129
;;
130+
--extra-targets)
131+
IFS=',' read -r -a EXTRA_TARGETS <<< "$2"
132+
shift 2
133+
;;
134+
--extra-target-dest)
135+
EXTRA_TARGET_DEST="$2"
136+
shift 2
137+
;;
126138
-h | --help)
127139
usage 1
128140
;;
@@ -202,13 +214,20 @@ print_var CPU_ARCH
202214
print_var CUDA_COMPUTE_CAPABILITIES
203215
print_var CUDA_MAJOR_VERSION
204216
print_var DEBUG
217+
print_var EXTRA_TARGETS
218+
print_var EXTRA_TARGET_DEST
205219
print_var INSTALL
206220
print_var PYTHON_VERSION
207221
print_var SRC_PATH_JAX
208222
print_var SRC_PATH_XLA
209223

210224
echo "=================================================="
211225

226+
if [[ -n "${EXTRA_TARGET_DEST}" && ! -d "${EXTRA_TARGET_DEST}" ]]; then
227+
echo "You must pass a directory to --extra-target-dest"
228+
exit 1
229+
fi
230+
212231
set -x
213232
if [[ ${CLEANONLY} == 1 ]]; then
214233
clean
@@ -245,7 +264,21 @@ for component in jaxlib "jax-cuda${CUDA_MAJOR_VERSION}-pjrt" "jax-cuda${CUDA_MAJ
245264
# version, so nvidia-*-cu12 wheels disappear from the lock file
246265
sed -i "s|^${component}.*$|${component} @ file://${BUILD_PATH_JAXLIB}/${component//-/_}|" build/requirements.in
247266
done
248-
bazel run --config=cuda_libraries_from_stubs --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
267+
# Bazel args to avoid cache invalidation
268+
BAZEL_ARGS=(
269+
--config=cuda_libraries_from_stubs
270+
--repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
271+
)
272+
bazel run "${BAZEL_ARGS[@]}" --verbose_failures=true //build:requirements.update
273+
if (( "${#EXTRA_TARGETS[@]}" > 0 )); then
274+
bazel build "${BAZEL_ARGS[@]}" --verbose_failures=true "${EXTRA_TARGETS[@]}"
275+
if [[ -n "${EXTRA_TARGET_DEST}" ]]; then
276+
for target in "${EXTRA_TARGETS[@]}"; do
277+
output_files=$(bazel cquery "${BAZEL_ARGS[@]}" "${target}" --output files)
278+
cp -v ${output_files} "${EXTRA_TARGET_DEST}"
279+
done
280+
fi
281+
fi
249282
popd
250283

251284
## Install the built packages

0 commit comments

Comments
 (0)