Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions ci/build_oneapi_artifacts.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/bin/bash
# Copyright 2026 The JAX Authors.
##
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Build JAX artifacts.
# Usage: ./ci/build_oneapi_artifacts.sh "<artifact>"
# Supported artifact values are: jax-oneapi-plugin, jax-oneapi-pjrt
# E.g: ./ci/build_oneapi_artifacts.sh "jax-oneapi-plugin" or ./ci/build_oneapi_artifacts.sh "jax-oneapi-pjrt"
#
# -e: abort script if one command fails
# -u: error if undefined variable used
# -x: log all commands
# -o history: record shell history
# -o allexport: export all functions and variables to be available to subscripts
set -exu -o history -o allexport

artifact="$1"

# Source default JAXCI environment variables.
source ci/envs/default.env

# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"

allowed_artifacts=("jax-oneapi-plugin" "jax-oneapi-pjrt")
Comment thread
ashahba marked this conversation as resolved.

if [[ ! " ${allowed_artifacts[*]} " =~ " ${artifact} " ]]; then
echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[*]}"
exit 1
fi

os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

bazel_startup_options=""
if [[ -n "${JAXCI_BAZEL_OUTPUT_BASE}" ]]; then
bazel_startup_options="--output_base=${JAXCI_BAZEL_OUTPUT_BASE}"
fi

# Determine the artifact tag flags based on the artifact type. A release
# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is
# tagged with the release version and a nightly suffix that contains the
# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with
# the git commit hash of the HEAD of the current branch and the date of the
# commit (e.g. 0.5.1.dev20250128+3e75e20c7).
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
current_date=$(date +%Y%m%d)
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
else
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
exit 1
fi

if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then
JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION%-nogil}-ft
fi

# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"

# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""

if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
bazel_cpu_pool_config="--config=rbe_cpu_pool"
else
bazelrc_config="ci_${bazelrc_config}"
bazel_cpu_pool_config=""

# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
fi

oneapi_version_flag="--oneapi_version=$JAXCI_ONEAPI_VERSION"

# Build the artifact.
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
$bazel_cpu_pool_config \
--bazel_startup_options="$bazel_startup_options" \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
$oneapi_version_flag \
--verbose --detailed_timestamped_log \
--output_path="$JAXCI_OUTPUT_DIR" \
$artifact_tag_flags

# If building release artifacts, we also build a release candidate ("rc")
# tagged wheel.
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
$bazel_cpu_pool_config \
--bazel_startup_options="$bazel_startup_options" \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
$oneapi_version_flag \
--verbose --detailed_timestamped_log \
--output_path="$JAXCI_OUTPUT_DIR" \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi

./ci/utilities/run_auditwheel.sh
4 changes: 4 additions & 0 deletions ci/envs/default.env
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ export JAXCI_CUDA_VERSION=${JAXCI_CUDA_VERSION:-12}
# running the tests.
export JAXCI_ROCM_VERSION=${JAXCI_ROCM_VERSION:-7}

# Controls the OneAPI version to use when building the JAX artifacts or
# running the tests.
export JAXCI_ONEAPI_VERSION=${JAXCI_ONEAPI_VERSION:-2025.1}

# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local
# copy of XLA instead of the pinned version in the WORKSPACE.
export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
Expand Down
4 changes: 2 additions & 2 deletions ci/utilities/run_auditwheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# Get a list of all the wheels in the output directory. Only look for wheels
# that need to be verified for manylinux compliance.
WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" -o -name "*jax*rocm*pjrt*whl" -o -name "*jax*rocm*plugin*whl" \))
WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" -o -name "*jax*rocm*pjrt*whl" -o -name "*jax*rocm*plugin*whl" -o -name "*jax*oneapi*pjrt*whl" -o -name "*jax*oneapi*plugin*whl" \))

if [[ -z "$WHEELS" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
Expand All @@ -27,7 +27,7 @@ fi

for wheel in $WHEELS; do
# Skip checking manylinux compliance for jax wheel.
if [[ "$wheel" =~ 'jax-' ]]; then
if [[ "${wheel##*/}" =~ ^jax- ]]; then
continue
fi
printf "\nRunning auditwheel on the following wheel:"
Expand Down
Loading