1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515# ==============================================================================
16- # Build JAX artifacts. Requires an env file from the ci/envs/build_artifacts to
17- # be passed as an argument
16+ # Build JAX artifacts.
17+ # Usage: ./ci/build_artifacts.sh "<comma-separated artifact values>"
18+ # Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
19+ # E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
20+ # Multiple artifacts builds are permitted. E.g: ./ci/build_artifacts.sh "jax,jaxlib"
1821#
1922# -e: abort script if one command fails
2023# -u: error if undefined variable used
2326# -o allexport: export all functions and variables to be available to subscripts
2427set -exu -o history -o allexport
2528
26- # If a JAX CI env file has not been passed, exit.
27- if [[ -z " $1 " ]]; then
28- echo " ERROR: No JAX CI env file passed."
29- echo " build_artifacts.sh requires that a path to a JAX CI env file to be"
30- echo " passed as an argument when invoking the build scripts."
31- echo " Pass in a corresponding env file from the ci/envs/build_artifacts"
32- echo " directory to continue."
33- exit 1
34- fi
29+ # Store the comma-separated string in a variable
30+ artifacts=" $1 "
31+
32+ # Replace commas with spaces
33+ artifacts=$( echo " $artifacts " | sed ' s/,/ /g' )
34+
35+ # Create an array from the space-separated string
36+ artifacts=($artifacts )
37+
38+ # Source default JAXCI environment variables.
39+ source ci/envs/default.env
3540
36- # Source JAXCI environment variables.
37- source " $1 "
3841# Set up the build environment.
3942source " ci/utilities/setup_build_environment.sh"
4043
41- # Build the jax artifact
42- if [[ " $JAXCI_BUILD_JAX " == 1 ]] ; then
43- python -m build --outdir $JAXCI_OUTPUT_DIR
44- fi
44+ os= $( uname -s | awk ' {print tolower($0)} ' )
45+ allowed_artifacts=( " jax " " jaxlib " " jax-cuda-plugin " " jax-cuda-pjrt " )
46+
47+ for artifact in " ${artifacts[@]} " ; do
4548
46- # Build the jaxlib CPU artifact
47- if [[ " $JAXCI_BUILD_JAXLIB " == 1 ]]; then
48- python build/build.py build_artifacts --wheel_list=" jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
49- fi
49+ if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
50+ # Build the jax artifact
51+ if [[ " $artifact " == " jax" ]]; then
52+ python -m build --outdir $JAXCI_OUTPUT_DIR
53+ fi
5054
51- # Build the jax-cuda-plugin artifact
52- if [[ " $JAXCI_BUILD_PLUGIN " == 1 ]]; then
53- python build/build.py build_artifacts --wheel_list=" jax-cuda-plugin " --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
54- fi
55+ # Build the jaxlib CPU artifact
56+ if [[ " $artifact " == " jaxlib " ]]; then
57+ python build/build.py build_artifacts --wheel_list=" jaxlib " --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
58+ fi
5559
56- # Build the jax-cuda-pjrt artifact
57- if [[ " $JAXCI_BUILD_PJRT " == 1 ]]; then
58- python build/build.py build_artifacts --wheel_list=" jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
59- fi
60+ # Build the jax-cuda-plugin artifact
61+ if [[ " $artifact " == " jax-cuda-plugin" ]]; then
62+ python build/build.py build_artifacts --wheel_list=" jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
63+ fi
64+
65+ # Build the jax-cuda-pjrt artifact
66+ if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
67+ python build/build.py build_artifacts --wheel_list=" jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
68+ fi
69+
70+ # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
71+ # run `auditwheel show` to verify manylinux compliance.
72+ if [[ " $os " == " linux" ]] && [[ " $artifact " != " jax" ]]; then
73+ ./ci/utilities/run_auditwheel.sh
74+ fi
75+
76+ else
77+ echo " Error: Invalid artifact '$artifact '. Allowed values are: ${allowed_artifacts[@]} "
78+ exit 1
79+ fi
6080
61- # After building `jaxlib`, `jaxcuda-plugin`, and `jax-cuda-pjrt`, we run
62- # `auditwheel show` to ensure manylinux compliance.
63- if [[ " $JAXCI_RUN_AUDITWHEEL " == 1 ]]; then
64- ./ci/utilities/run_auditwheel.sh
65- fi
81+ done
0 commit comments