Skip to content

Commit e7e5e06

Browse files
committed
Merge branch 'main' of https://github.com/jax-ml/jax
2 parents 9267b5d + 5ade371 commit e7e5e06

File tree

90 files changed

+3098
-928
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+3098
-928
lines changed

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ jobs:
5454
with:
5555
repository: openxla/xla
5656
path: xla
57+
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
58+
- name: Mark GitHub workspace as safe
59+
run: |
60+
git config --global --add safe.directory "$GITHUB_WORKSPACE"
5761
- name: Install JAX test requirements
5862
run: |
5963
$PYTHON -m pip install -U -r build/test-requirements.txt
@@ -63,9 +67,11 @@ jobs:
6367
$PYTHON -m pip uninstall -y jax jaxlib libtpu
6468
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
6569
# Build and install jaxlib at head
66-
$PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67-
--bazel_options="--override_repository=xla=$(pwd)/xla" \
68-
--bazel_options=--color=yes
70+
$PYTHON build/build.py build --wheels=jaxlib \
71+
--bazel_options=--config=rbe_linux_x86_64 \
72+
--local_xla_path="$(pwd)/xla" \
73+
--verbose
74+
6975
$PYTHON -m pip install dist/*.whl
7076
7177
# Install "jax" at head

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
31+
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}

build/rocm/Dockerfile.ms

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
################################################################################
2-
FROM ubuntu:20.04 AS rocm_base
2+
ARG BASE_DOCKER=ubuntu:22.04
3+
FROM $BASE_DOCKER AS rocm_base
34
################################################################################
45

56
RUN --mount=type=cache,target=/var/cache/apt \

build/rocm/ci_build

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def dist_wheels(
8989

9090
mounts = [
9191
"-v",
92-
"./:/jax",
92+
os.path.abspath("./") + ":/jax",
9393
"-v",
94-
"./wheelhouse:/wheelhouse",
94+
os.path.abspath("./wheelhouse") + ":/wheelhouse",
9595
]
9696

9797
if xla_path:
@@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path):
143143

144144
def dist_docker(
145145
rocm_version,
146+
base_docker,
146147
python_versions,
147148
xla_path,
148149
rocm_build_job="",
@@ -168,6 +169,7 @@ def dist_docker(
168169
"--build-arg=ROCM_VERSION=%s" % rocm_version,
169170
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
170171
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
172+
"--build-arg=BASE_DOCKER=%s" % base_docker,
171173
"--build-arg=PYTHON_VERSION=%s" % python_version,
172174
"--build-arg=JAX_VERSION=%(jax_version)s" % md,
173175
"--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
@@ -210,7 +212,7 @@ def test(image_name):
210212
# JAX and jaxlib are already installed from wheels
211213
mounts = [
212214
"-v",
213-
"./:/jax",
215+
os.path.abspath("./") + ":/jax",
214216
]
215217

216218
cmd.extend(mounts)
@@ -231,6 +233,12 @@ def test(image_name):
231233

232234
def parse_args():
233235
p = argparse.ArgumentParser()
236+
p.add_argument(
237+
"--base-docker",
238+
default="",
239+
help="Argument to override base docker in dockerfile",
240+
)
241+
234242
p.add_argument(
235243
"--python-versions",
236244
type=lambda x: x.split(","),
@@ -308,6 +316,7 @@ def main():
308316
)
309317
dist_docker(
310318
args.rocm_version,
319+
args.base_docker,
311320
args.python_versions,
312321
args.xla_source_dir,
313322
rocm_build_job=args.rocm_build_job,

build/rocm/ci_build.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ PYTHON_VERSION="3.10"
4848
ROCM_VERSION="6.1.3"
4949
ROCM_BUILD_JOB=""
5050
ROCM_BUILD_NUM=""
51-
BASE_DOCKER="ubuntu:20.04"
51+
BASE_DOCKER="ubuntu:22.04"
5252
CUSTOM_INSTALL=""
5353
JAX_USE_CLANG=""
5454
POSITIONAL_ARGS=()
@@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do
9090
ROCM_BUILD_NUM="$2"
9191
shift 2
9292
;;
93+
--base_docker)
94+
BASE_DOCKER="$2"
95+
shift 2
96+
;;
9397
--use_clang)
9498
JAX_USE_CLANG="$2"
9599
shift 2
@@ -154,6 +158,7 @@ fi
154158
# which is the ROCm image that is shipped for users to use (i.e. distributable).
155159
./build/rocm/ci_build \
156160
--rocm-version $ROCM_VERSION \
161+
--base-docker $BASE_DOCKER \
157162
--python-versions $PYTHON_VERSION \
158163
--xla-source-dir=$XLA_CLONE_DIR \
159164
--rocm-build-job=$ROCM_BUILD_JOB \

build/rocm/run_multi_gpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
set -eu
16+
set -xu
1717

1818
# Function to run tests with specified GPUs
1919
run_tests() {

build/rocm/tools/get_rocm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _build_installer_url(rocm_version, metadata):
229229

230230
rv = parse_version(rocm_version)
231231

232-
base_url = "http://artifactory-cdn.amd.com/artifactory/list"
232+
base_url = "https://artifactory-cdn.amd.com/artifactory/list"
233233

234234
if md["ID"] == "ubuntu":
235235
fmt = "amdgpu-install-internal_%(rocm_major)s.%(rocm_minor)s-%(os_version)s-1_all.deb"

ci/envs/default.env

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,12 @@ export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12}
5858
# this value in the Github action workflow files.
5959
export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0}
6060

61-
# #############################################################################
62-
# Docker specific environment variables.
63-
# #############################################################################
61+
# Pytest specific environment variables below. Used in run_pytest_*.sh scripts.
62+
# Sets the number of TPU cores for the TPU machine type. These values are
63+
# defined in the TPU GitHub Actions workflow.
64+
export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
6465

65-
# Docker specifc environment variables. Used by `run_docker_container.sh`
66-
export JAXCI_DOCKER_WORK_DIR="/jax"
67-
export JAXCI_DOCKER_IMAGE=""
68-
export JAXCI_DOCKER_ARGS=""
66+
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
67+
# on the system. By default, it is set to match the version of the hermetic
68+
# Python used by Bazel for building the wheels.
69+
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}

ci/run_pytest_cpu.sh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616
# Runs Pyest CPU tests. Requires a jaxlib wheel to be present
17+
<<<<<<< HEAD
1718
# inside $JAXCI_OUTPUT_DIR (../dist)
19+
=======
20+
# inside the $JAXCI_OUTPUT_DIR (../dist)
21+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
1822
#
1923
# -e: abort script if one command fails
2024
# -u: error if undefined variable used
@@ -23,20 +27,38 @@
2327
# -o allexport: export all functions and variables to be available to subscripts
2428
set -exu -o history -o allexport
2529

30+
<<<<<<< HEAD
2631
# Inherit default JAXCI environment variables.
2732
source ci/envs/default.env
2833

34+
=======
35+
# Source default JAXCI environment variables.
36+
source ci/envs/default.env
37+
38+
# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system.
39+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
2940
echo "Installing wheels locally..."
3041
source ./ci/utilities/install_wheels_locally.sh
3142

3243
# Set up the build environment.
3344
source "ci/utilities/setup_build_environment.sh"
3445

46+
<<<<<<< HEAD
3547
export PY_COLORS=1
3648
export JAX_SKIP_SLOW_TESTS=true
3749

3850
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
3951

4052
export TF_CPP_MIN_LOG_LEVEL=0
53+
=======
54+
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
55+
56+
# Set up all test environment variables
57+
export PY_COLORS=1
58+
export JAX_SKIP_SLOW_TESTS=true
59+
export TF_CPP_MIN_LOG_LEVEL=0
60+
# End of test environment variable setup
61+
62+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
4163
echo "Running CPU tests..."
4264
"$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples

ci/run_pytest_gpu.sh

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ==============================================================================
16+
<<<<<<< HEAD
1617
# Runs Pyest CPU tests. Requires all jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
18+
=======
19+
# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
20+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
1721
# wheels to be present inside $JAXCI_OUTPUT_DIR (../dist)
1822
#
1923
# -e: abort script if one command fails
@@ -23,16 +27,25 @@
2327
# -o allexport: export all functions and variables to be available to subscripts
2428
set -exu -o history -o allexport
2529

30+
<<<<<<< HEAD
2631
# Inherit default JAXCI environment variables.
2732
source ci/envs/default.env
2833

2934
# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels on the system.
35+
=======
36+
# Source default JAXCI environment variables.
37+
source ci/envs/default.env
38+
39+
# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the
40+
# $JAXCI_OUTPUT_DIR directory on the system.
41+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
3042
echo "Installing wheels locally..."
3143
source ./ci/utilities/install_wheels_locally.sh
3244

3345
# Set up the build environment.
3446
source "ci/utilities/setup_build_environment.sh"
3547

48+
<<<<<<< HEAD
3649
export PY_COLORS=1
3750
export JAX_SKIP_SLOW_TESTS=true
3851

@@ -46,6 +59,28 @@ echo "Running GPU tests..."
4659
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
4760
export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
4861
"$JAXCI_PYTHON" -m pytest -n 8 --tb=short --maxfail=20 \
62+
=======
63+
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
64+
65+
nvidia-smi
66+
67+
# Set up all test environment variables
68+
export PY_COLORS=1
69+
export JAX_SKIP_SLOW_TESTS=true
70+
export NCCL_DEBUG=WARN
71+
export TF_CPP_MIN_LOG_LEVEL=0
72+
73+
# Set the number of processes to run to be 4x the number of GPUs.
74+
export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
75+
export num_processes=`expr 4 \* $gpu_count`
76+
77+
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
78+
export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
79+
# End of test environment variable setup
80+
81+
echo "Running GPU tests..."
82+
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
83+
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
4984
tests examples \
5085
--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
5186
--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \

0 commit comments

Comments
 (0)