Skip to content

Commit de73307

Browse files
committed
remove build artifact envs
1 parent 5306637 commit de73307

File tree

12 files changed

+63
-147
lines changed

12 files changed

+63
-147
lines changed

.github/workflows/bazel_gpu_non_rbe.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ jobs:
3838
- name: Build jaxlib
3939
env:
4040
JAXCI_CLONE_MAIN_XLA: 1
41-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
41+
run: ./ci/build_artifacts.sh "jaxlib"
4242
- name: Build jax-cuda-plugin
4343
env:
4444
JAXCI_CLONE_MAIN_XLA: 1
45-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env"
45+
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
4646
- name: Build jax-cuda-pjrt
4747
env:
4848
JAXCI_CLONE_MAIN_XLA: 1
49-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env"
49+
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
5050
- name: Run Bazel GPU tests locally
5151
run: ./ci/run_bazel_test_gpu_non_rbe.sh

.github/workflows/build_artifacts.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Build JAX Artifacts
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -71,4 +71,4 @@ jobs:
7171
- name: Build ${{ matrix.artifact }}
7272
env:
7373
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
74-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/${{ matrix.artifact }}.env"
74+
run: ./ci/build_artifacts.sh "${{ matrix.artifact }}"

.github/workflows/pytest_cpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
with:
4545
halt-dispatch-input: ${{ inputs.halt-for-connection }}
4646
- name: Build jaxlib
47-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
47+
run: ./ci/build_artifacts.sh "jaxlib"
4848
- name: Install pytest
4949
env:
5050
JAXCI_PYTHON: python${{ matrix.python }}

.github/workflows/pytest_gpu.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ jobs:
3737
with:
3838
halt-dispatch-input: ${{ inputs.halt-for-connection }}
3939
- name: Build jaxlib
40-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
40+
run: ./ci/build_artifacts.sh "jaxlib"
4141
- name: Build jax-cuda-plugin
42-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env"
42+
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
4343
- name: Build jax-cuda-pjrt
44-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env"
44+
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
4545

4646
run_tests:
4747
needs: build_artifacts

.github/workflows/pytest_tpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
with:
4040
halt-dispatch-input: ${{ inputs.halt-for-connection }}
4141
- name: Build jaxlib
42-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
42+
run: ./ci/build_artifacts.sh "jaxlib"
4343
- name: Install pytest
4444
env:
4545
JAXCI_PYTHON: python${{ matrix.python }}

ci/build_artifacts.sh

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ==============================================================================
16-
# Build JAX artifacts. Requires an env file from the ci/envs/build_artifacts to
17-
# be passed as an argument
16+
# Build JAX artifacts.
17+
# Usage: ./ci/build_artifacts.sh "<comma-separated artifact values>"
18+
# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
19+
# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
20+
# Multiple artifacts builds are permitted. E.g: ./ci/build_artifacts.sh "jax,jaxlib"
1821
#
1922
# -e: abort script if one command fails
2023
# -u: error if undefined variable used
@@ -23,43 +26,56 @@
2326
# -o allexport: export all functions and variables to be available to subscripts
2427
set -exu -o history -o allexport
2528

26-
# If a JAX CI env file has not been passed, exit.
27-
if [[ -z "$1" ]]; then
28-
echo "ERROR: No JAX CI env file passed."
29-
echo "build_artifacts.sh requires that a path to a JAX CI env file to be"
30-
echo "passed as an argument when invoking the build scripts."
31-
echo "Pass in a corresponding env file from the ci/envs/build_artifacts"
32-
echo "directory to continue."
33-
exit 1
34-
fi
29+
# Store the comma-separated string in a variable
30+
artifacts="$1"
31+
32+
# Replace commas with spaces
33+
artifacts=$(echo "$artifacts" | sed 's/,/ /g')
34+
35+
# Create an array from the space-separated string
36+
artifacts=($artifacts)
37+
38+
# Source default JAXCI environment variables.
39+
source ci/envs/default.env
3540

36-
# Source JAXCI environment variables.
37-
source "$1"
3841
# Set up the build environment.
3942
source "ci/utilities/setup_build_environment.sh"
4043

41-
# Build the jax artifact
42-
if [[ "$JAXCI_BUILD_JAX" == 1 ]]; then
43-
python -m build --outdir $JAXCI_OUTPUT_DIR
44-
fi
44+
os=$(uname -s | awk '{print tolower($0)}')
45+
allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt")
46+
47+
for artifact in "${artifacts[@]}"; do
4548

46-
# Build the jaxlib CPU artifact
47-
if [[ "$JAXCI_BUILD_JAXLIB" == 1 ]]; then
48-
python build/build.py build_artifacts --wheel_list="jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
49-
fi
49+
if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
50+
# Build the jax artifact
51+
if [[ "$artifact" == "jax" ]]; then
52+
python -m build --outdir $JAXCI_OUTPUT_DIR
53+
fi
5054

51-
# Build the jax-cuda-plugin artifact
52-
if [[ "$JAXCI_BUILD_PLUGIN" == 1 ]]; then
53-
python build/build.py build_artifacts --wheel_list="jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
54-
fi
55+
# Build the jaxlib CPU artifact
56+
if [[ "$artifact" == "jaxlib" ]]; then
57+
python build/build.py build_artifacts --wheel_list="jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
58+
fi
5559

56-
# Build the jax-cuda-pjrt artifact
57-
if [[ "$JAXCI_BUILD_PJRT" == 1 ]]; then
58-
python build/build.py build_artifacts --wheel_list="jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
59-
fi
60+
# Build the jax-cuda-plugin artifact
61+
if [[ "$artifact" == "jax-cuda-plugin" ]]; then
62+
python build/build.py build_artifacts --wheel_list="jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
63+
fi
64+
65+
# Build the jax-cuda-pjrt artifact
66+
if [[ "$artifact" == "jax-cuda-pjrt" ]]; then
67+
python build/build.py build_artifacts --wheel_list="jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
68+
fi
69+
70+
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
71+
# run `auditwheel show` to verify manylinux compliance.
72+
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
73+
./ci/utilities/run_auditwheel.sh
74+
fi
75+
76+
else
77+
echo "Error: Invalid artifact '$artifact'. Allowed values are: ${allowed_artifacts[@]}"
78+
exit 1
79+
fi
6080

61-
# After building `jaxlib`, `jaxcuda-plugin`, and `jax-cuda-pjrt`, we run
62-
# `auditwheel show` to ensure manylinux compliance.
63-
if [[ "$JAXCI_RUN_AUDITWHEEL" == 1 ]]; then
64-
./ci/utilities/run_auditwheel.sh
65-
fi
81+
done

ci/envs/build_artifacts/jax-cuda-pjrt.env

Lines changed: 0 additions & 22 deletions
This file was deleted.

ci/envs/build_artifacts/jax-cuda-plugin.env

Lines changed: 0 additions & 22 deletions
This file was deleted.

ci/envs/build_artifacts/jax.env

Lines changed: 0 additions & 19 deletions
This file was deleted.

ci/envs/build_artifacts/jaxlib.env

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)