Skip to content

Commit 4c93d9f

Browse files
committed
Move around logic for when we are building jaxlib/cuda artifacts
1 parent 0fceda9 commit 4c93d9f

File tree

1 file changed

+27
-36
lines changed

1 file changed

+27
-36
lines changed

ci/build_artifacts.sh

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,7 @@
2525
# -o allexport: export all functions and variables to be available to subscripts
2626
set -exu -o history -o allexport
2727

28-
# Store the comma-separated string in a variable
29-
artifacts="$1"
30-
31-
# Replace commas with spaces
32-
artifacts=$(echo "$artifacts" | sed 's/,/ /g')
33-
34-
# Create an array from the space-separated string
35-
artifacts=($artifacts)
28+
artifact="$1"
3629

3730
# Source default JAXCI environment variables.
3831
source ci/envs/default.env
@@ -52,42 +45,40 @@ if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then
5245
arch="amd64"
5346
fi
5447

55-
# For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
56-
if [[ "$artifact" != "jax" ]]; then
57-
bazelrc_config="${os}_${arch}"
58-
if ( [[ "$os" == "linux" ]] && [[ "$arch" == "x86_64" ]] ) || [[ "$os" == "windows" ]]; then
59-
bazelrc_config="rbe_$bazelrc_config"
60-
else
61-
bazelrc_config="ci_$bazelrc_config"
62-
fi
63-
fi
64-
6548
if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
6649
# Build the jax artifact
6750
if [[ "$artifact" == "jax" ]]; then
6851
python -m build --outdir $JAXCI_OUTPUT_DIR
69-
fi
52+
else
53+
# For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms
54+
bazelrc_config="${os}_${arch}"
55+
if ( [[ "$os" == "linux" ]] && [[ "$arch" == "x86_64" ]] ) || [[ "$os" == "windows" ]]; then
56+
bazelrc_config="rbe_$bazelrc_config"
57+
else
58+
bazelrc_config="ci_$bazelrc_config"
59+
fi
7060

71-
# Build the jaxlib CPU artifact
72-
if [[ "$artifact" == "jaxlib" ]]; then
73-
python build/build.py build --wheels="jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
74-
fi
61+
# Build the jaxlib CPU artifact
62+
if [[ "$artifact" == "jaxlib" ]]; then
63+
python build/build.py build --wheels="jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
64+
fi
7565

76-
# Build the jax-cuda-plugin artifact
77-
if [[ "$artifact" == "jax-cuda-plugin" ]]; then
78-
python build/build.py build --wheels="jax-cuda-plugin" --bazel_build_options=--config="${bazelrc_config}_cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
79-
fi
66+
# Build the jax-cuda-plugin artifact
67+
if [[ "$artifact" == "jax-cuda-plugin" ]]; then
68+
python build/build.py build --wheels="jax-cuda-plugin" --bazel_build_options=--config="${bazelrc_config}_cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
69+
fi
8070

81-
# Build the jax-cuda-pjrt artifact
82-
if [[ "$artifact" == "jax-cuda-pjrt" ]]; then
83-
python build/build.py build --wheels="jax-cuda-pjrt" --bazel_build_options=--config="${bazelrc_config}_cuda" --verbose
84-
fi
71+
# Build the jax-cuda-pjrt artifact
72+
if [[ "$artifact" == "jax-cuda-pjrt" ]]; then
73+
python build/build.py build --wheels="jax-cuda-pjrt" --bazel_build_options=--config="${bazelrc_config}_cuda" --verbose
74+
fi
8575

86-
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
87-
# run `auditwheel show` to verify manylinux compliance.
88-
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
89-
./ci/utilities/run_auditwheel.sh
90-
fi
76+
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
77+
# run `auditwheel show` to verify manylinux compliance.
78+
if [[ "$os" == "linux" ]]; then
79+
./ci/utilities/run_auditwheel.sh
80+
fi
81+
fi
9182

9283
else
9384
echo "Error: Invalid artifact '$artifact'. Allowed values are: ${allowed_artifacts[@]}"

0 commit comments

Comments
 (0)