2525# -o allexport: export all functions and variables to be available to subscripts
2626set -exu -o history -o allexport
2727
28- # Store the comma-separated string in a variable
29- artifacts=" $1 "
30-
31- # Replace commas with spaces
32- artifacts=$( echo " $artifacts " | sed ' s/,/ /g' )
33-
34- # Create an array from the space-separated string
35- artifacts=($artifacts )
28+ artifact=" $1 "
3629
3730# Source default JAXCI environment variables.
3831source ci/envs/default.env
@@ -52,42 +45,40 @@ if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then
5245 arch=" amd64"
5346fi
5447
55- # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
56- if [[ " $artifact " != " jax" ]]; then
57- bazelrc_config=" ${os} _${arch} "
58- if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
59- bazelrc_config=" rbe_$bazelrc_config "
60- else
61- bazelrc_config=" ci_$bazelrc_config "
62- fi
63- fi
64-
6548if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
6649 # Build the jax artifact
6750 if [[ " $artifact " == " jax" ]]; then
6851 python -m build --outdir $JAXCI_OUTPUT_DIR
69- fi
52+ else
53+ # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
54+ bazelrc_config=" ${os} _${arch} "
55+ if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
56+ bazelrc_config=" rbe_$bazelrc_config "
57+ else
58+ bazelrc_config=" ci_$bazelrc_config "
59+ fi
7060
71- # Build the jaxlib CPU artifact
72- if [[ " $artifact " == " jaxlib" ]]; then
73- python build/build.py build --wheels=" jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
74- fi
61+ # Build the jaxlib CPU artifact
62+ if [[ " $artifact " == " jaxlib" ]]; then
63+ python build/build.py build --wheels=" jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
64+ fi
7565
76- # Build the jax-cuda-plugin artifact
77- if [[ " $artifact " == " jax-cuda-plugin" ]]; then
78- python build/build.py build --wheels=" jax-cuda-plugin" --bazel_build_options=--config=" ${bazelrc_config} _cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
79- fi
66+ # Build the jax-cuda-plugin artifact
67+ if [[ " $artifact " == " jax-cuda-plugin" ]]; then
68+ python build/build.py build --wheels=" jax-cuda-plugin" --bazel_build_options=--config=" ${bazelrc_config} _cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
69+ fi
8070
81- # Build the jax-cuda-pjrt artifact
82- if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
83- python build/build.py build --wheels=" jax-cuda-pjrt" --bazel_build_options=--config=" ${bazelrc_config} _cuda" --verbose
84- fi
71+ # Build the jax-cuda-pjrt artifact
72+ if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
73+ python build/build.py build --wheels=" jax-cuda-pjrt" --bazel_build_options=--config=" ${bazelrc_config} _cuda" --verbose
74+ fi
8575
86- # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
87- # run `auditwheel show` to verify manylinux compliance.
88- if [[ " $os " == " linux" ]] && [[ " $artifact " != " jax" ]]; then
89- ./ci/utilities/run_auditwheel.sh
90- fi
76+ # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
77+ # run `auditwheel show` to verify manylinux compliance.
78+ if [[ " $os " == " linux" ]]; then
79+ ./ci/utilities/run_auditwheel.sh
80+ fi
81+ fi
9182
9283else
9384 echo " Error: Invalid artifact '$artifact '. Allowed values are: ${allowed_artifacts[@]} "
0 commit comments