1414# limitations under the License.
1515# ==============================================================================
1616# Build JAX artifacts.
17- # Usage: ./ci/build_artifacts.sh "<comma-separated artifact values >"
17+ # Usage: ./ci/build_artifacts.sh "<artifact>"
1818# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
1919# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
20- # Multiple artifacts builds are permitted. E.g: ./ci/build_artifacts.sh "jax,jaxlib"
2120#
2221# -e: abort script if one command fails
2322# -u: error if undefined variable used
@@ -46,53 +45,51 @@ allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt")
4645os=$( uname -s | awk ' {print tolower($0)}' )
4746arch=$( uname -m)
4847
49- # Adjust the values when running on Windows x86 to match the case in
48+ # Adjust the values when running on Windows x86 to match the config in
5049# .bazelrc
5150if [[ $os =~ " msys_nt" ]] && [[ $arch == " x86_64" ]]; then
5251 os=" windows"
5352 arch=" amd64"
5453fi
5554
56- # Use "rbe_" config for Linux x86/Windows and "ci_" for other platforms
57- bazelrc_config=" ${os} _${arch} "
58- if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
59- bazelrc_config=" rbe_$bazelrc_config "
60- else
61- bazelrc_config=" ci_$bazelrc_config "
55+ # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
56+ if [[ " $artifact " != " jax" ]]; then
57+ bazelrc_config=" ${os} _${arch} "
58+ if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
59+ bazelrc_config=" rbe_$bazelrc_config "
60+ else
61+ bazelrc_config=" ci_$bazelrc_config "
62+ fi
6263fi
6364
64- for artifact in " ${artifacts[@]} " ; do
65-
66- if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
67- # Build the jax artifact
68- if [[ " $artifact " == " jax" ]]; then
69- python -m build --outdir $JAXCI_OUTPUT_DIR
70- fi
71-
72- # Build the jaxlib CPU artifact
73- if [[ " $artifact " == " jaxlib" ]]; then
74- python build/build.py build --wheels=" jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
75- fi
65+ if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
66+ # Build the jax artifact
67+ if [[ " $artifact " == " jax" ]]; then
68+ python -m build --outdir $JAXCI_OUTPUT_DIR
69+ fi
7670
77- # Build the jax-cuda-plugin artifact
78- if [[ " $artifact " == " jax-cuda-plugin " ]]; then
79- python build/build.py build --wheels=" jax-cuda-plugin " --bazel_build_options=--config=" ${ bazelrc_config} _cuda " --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
80- fi
71+ # Build the jaxlib CPU artifact
72+ if [[ " $artifact " == " jaxlib " ]]; then
73+ python build/build.py build --wheels=" jaxlib " --bazel_build_options=--config=$ bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
74+ fi
8175
82- # Build the jax-cuda-pjrt artifact
83- if [[ " $artifact " == " jax-cuda-pjrt " ]]; then
84- python build/build.py build --wheels=" jax-cuda-pjrt " --bazel_build_options=--config=" ${bazelrc_config} _cuda" --verbose
85- fi
76+ # Build the jax-cuda-plugin artifact
77+ if [[ " $artifact " == " jax-cuda-plugin " ]]; then
78+ python build/build.py build --wheels=" jax-cuda-plugin " --bazel_build_options=--config=" ${bazelrc_config} _cuda" --python_version= $JAXCI_HERMETIC_PYTHON_VERSION --verbose
79+ fi
8680
87- # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
88- # run `auditwheel show` to verify manylinux compliance.
89- if [[ " $os " == " linux" ]] && [[ " $artifact " != " jax" ]]; then
90- ./ci/utilities/run_auditwheel.sh
91- fi
81+ # Build the jax-cuda-pjrt artifact
82+ if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
83+ python build/build.py build --wheels=" jax-cuda-pjrt" --bazel_build_options=--config=" ${bazelrc_config} _cuda" --verbose
84+ fi
9285
93- else
94- echo " Error: Invalid artifact '$artifact '. Allowed values are: ${allowed_artifacts[@]} "
95- exit 1
86+ # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
87+ # run `auditwheel show` to verify manylinux compliance.
88+ if [[ " $os " == " linux" ]] && [[ " $artifact " != " jax" ]]; then
89+ ./ci/utilities/run_auditwheel.sh
9690 fi
9791
98- done
92+ else
93+ echo " Error: Invalid artifact '$artifact '. Allowed values are: ${allowed_artifacts[@]} "
94+ exit 1
95+ fi
0 commit comments