1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515# ==============================================================================
16- # Build JAX artifacts.
16+ # Build JAX artifacts.
1717# Usage: ./ci/build_artifacts.sh "<artifact>"
1818# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
1919# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
@@ -46,10 +46,12 @@ if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then
4646fi
4747
4848if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
49+
4950 # Build the jax artifact
5051 if [[ " $artifact " == " jax" ]]; then
5152 python -m build --outdir $JAXCI_OUTPUT_DIR
5253 else
54+
5355 # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
5456 bazelrc_config=" ${os} _${arch} "
5557 if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
@@ -60,27 +62,28 @@ if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
6062
6163 # Build the jaxlib CPU artifact
6264 if [[ " $artifact " == " jaxlib" ]]; then
63- python build/build.py build --wheels=" jaxlib" --bazel_build_options =--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
65+ python build/build.py build --wheels=" jaxlib" --bazel_options =--config=" $bazelrc_config " --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
6466 fi
6567
6668 # Build the jax-cuda-plugin artifact
6769 if [[ " $artifact " == " jax-cuda-plugin" ]]; then
68- python build/build.py build --wheels=" jax-cuda-plugin" --bazel_build_options =--config=" ${bazelrc_config} _cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
70+ python build/build.py build --wheels=" jax-cuda-plugin" --bazel_options =--config=" ${bazelrc_config} _cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
6971 fi
7072
7173 # Build the jax-cuda-pjrt artifact
7274 if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
73- python build/build.py build --wheels=" jax-cuda-pjrt" --bazel_build_options =--config=" ${bazelrc_config} _cuda" --verbose
75+ python build/build.py build --wheels=" jax-cuda-pjrt" --bazel_options =--config=" ${bazelrc_config} _cuda" --verbose
7476 fi
7577
7678 # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
7779 # run `auditwheel show` to verify manylinux compliance.
7880 if [[ " $os " == " linux" ]]; then
7981 ./ci/utilities/run_auditwheel.sh
8082 fi
81- fi
83+
84+ fi
8285
8386else
84- echo " Error: Invalid artifact ' $artifact ' . Allowed values are: ${allowed_artifacts[@]} "
87+ echo " Error: Invalid artifact: $artifact . Allowed values are: ${allowed_artifacts[@]} "
8588 exit 1
8689fi
0 commit comments