Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/bazel_test_tpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ on:
description: "Which TPU type is used for testing?"
type: string
default: "v5e-8"
tpu-parallelism-mode:
description: "How to split single-accelerator TPU test workers: chip or core."
type: string
default: "chip"
python:
description: "Which Python version should be used for testing?"
type: string
Expand Down Expand Up @@ -105,6 +109,7 @@ jobs:
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
JAXCI_TPU_CORES: "${{ inputs.cores }}"
JAXCI_TPU_PARALLELISM_MODE: "${{ inputs.tpu-parallelism-mode }}"
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/pytest_tpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ on:
description: "Which TPU type is used for testing?"
type: string
default: "v5e-8"
tpu-parallelism-mode:
description: "How to split single-accelerator TPU test workers: chip or core."
type: string
default: "chip"
python:
description: "Which Python version should be used for testing?"
type: string
Expand Down Expand Up @@ -79,6 +83,7 @@ jobs:
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
JAXCI_TPU_CORES: "${{ inputs.cores }}"
JAXCI_TPU_PARALLELISM_MODE: "${{ inputs.tpu-parallelism-mode }}"

steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
Expand Down
33 changes: 24 additions & 9 deletions .github/workflows/wheel_tests_continuous.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix strategy in the CPU tests job
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
# Temporarily narrowed for TPU-only validation of TPU7x core splitting.
runner: ["linux-x86-n4-16"]
artifact: ["jaxlib"]
python: ["3.11"]
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
Expand All @@ -76,6 +77,8 @@ jobs:
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

build-cuda-artifacts:
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
Expand All @@ -99,7 +102,8 @@ jobs:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_cpu.yml
strategy:
Expand All @@ -121,7 +125,8 @@ jobs:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/pytest_cuda.yml
strategy:
Expand Down Expand Up @@ -162,6 +167,8 @@ jobs:
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-bazel-test-cpu-py-import:
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
uses: ./.github/workflows/bazel_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
Expand All @@ -182,7 +189,8 @@ jobs:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/bazel_cuda.yml
strategy:
Expand Down Expand Up @@ -214,7 +222,8 @@ jobs:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
uses: ./.github/workflows/bazel_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
Expand Down Expand Up @@ -259,8 +268,9 @@ jobs:
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
cores: ${{ matrix.tpu-specs.type == 'v7x-8' && '8' || matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
tpu-parallelism-mode: ${{ matrix.tpu-specs.type == 'v7x-8' && 'core' || 'chip' }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
Expand All @@ -286,8 +296,9 @@ jobs:
name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
cores: ${{ matrix.tpu-specs.type == 'v7x-8' && '8' || matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
tpu-parallelism-mode: ${{ matrix.tpu-specs.type == 'v7x-8' && 'core' || 'chip' }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
Expand All @@ -297,6 +308,8 @@ jobs:
clone_main_xla: 1

build-rocm-artifacts:
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
uses: ./.github/workflows/build_rocm_artifacts.yml
permissions:
id-token: write
Expand Down Expand Up @@ -327,7 +340,8 @@ jobs:
s3_upload_uri: 's3://jax-ci-amd/rocm-wheels/wheel-tests-continuous/${{ github.run_number }}/${{ github.run_attempt }}'

run-pytest-rocm:
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts]
uses: ./.github/workflows/pytest_rocm.yml
permissions:
Expand All @@ -353,7 +367,8 @@ jobs:
s3_download_uri: 's3://jax-ci-amd/rocm-wheels/wheel-tests-continuous/${{ github.run_number }}/${{ github.run_attempt }}'

run-bazel-test-rocm:
if: ${{ !cancelled() }}
# Temporarily disabled for TPU-only validation of TPU7x core splitting.
if: ${{ false }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts]
uses: ./.github/workflows/bazel_rocm.yml
permissions:
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/wheel_tests_nightly_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,9 @@ jobs:
name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
cores: ${{ matrix.tpu-specs.type == 'v7x-8' && matrix.libtpu-version-type == 'nightly' && '8' || matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
tpu-parallelism-mode: ${{ matrix.tpu-specs.type == 'v7x-8' && matrix.libtpu-version-type == 'nightly' && 'core' || 'chip' }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
Expand Down Expand Up @@ -351,8 +352,9 @@ jobs:
name: "Bazel tests TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
cores: ${{ matrix.tpu-specs.type == 'v7x-8' && matrix.libtpu-version-type == 'nightly' && '8' || matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
tpu-parallelism-mode: ${{ matrix.tpu-specs.type == 'v7x-8' && matrix.libtpu-version-type == 'nightly' && 'core' || 'chip' }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
Expand Down
21 changes: 19 additions & 2 deletions build/parallel_accelerator_execute.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
# Environment variables:
# JAX_ACCELERATOR_COUNT = Number of accelerators (GPUs/TPUs) available.
# JAX_TESTS_PER_ACCELERATOR = Number of tests to run concurrently on each accelerator.
# JAX_TPU_XDIST_VISIBILITY_MODE = "chips" or "devices" for TPU assignment.

JAX_ACCELERATOR_COUNT=${JAX_ACCELERATOR_COUNT:-4}
JAX_TESTS_PER_ACCELERATOR=${JAX_TESTS_PER_ACCELERATOR:-8}
JAX_TPU_XDIST_VISIBILITY_MODE=${JAX_TPU_XDIST_VISIBILITY_MODE:-chips}

export TF_PER_DEVICE_MEMORY_LIMIT_MB=${TF_PER_DEVICE_MEMORY_LIMIT_MB:-2048}

Expand Down Expand Up @@ -73,7 +75,22 @@ for j in `seq 0 $((JAX_TESTS_PER_ACCELERATOR-1))`; do
(
# This export only works within the brackets, so it is isolated to one
# single command.
export TPU_VISIBLE_CHIPS=$i
case "$JAX_TPU_XDIST_VISIBILITY_MODE" in
devices)
unset TPU_VISIBLE_CHIPS
export TPU_VISIBLE_DEVICES=$i
export TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1,1
export TPU_PROCESS_BOUNDS=1,1,1,1
;;
chips)
unset TPU_VISIBLE_DEVICES
export TPU_VISIBLE_CHIPS=$i
;;
*)
echo "Unknown JAX_TPU_XDIST_VISIBILITY_MODE: $JAX_TPU_XDIST_VISIBILITY_MODE"
exit 1
;;
esac
export CUDA_VISIBLE_DEVICES=$i
export ROCR_VISIBLE_DEVICES=$i
echo "Running test $TEST_BINARY $* on accelerator $i"
Expand All @@ -87,4 +104,4 @@ for j in `seq 0 $((JAX_TESTS_PER_ACCELERATOR-1))`; do
done

echo "Cannot find a free accelerator to run the test $* on, exiting with failure..."
exit 1
exit 1
15 changes: 15 additions & 0 deletions ci/run_bazel_test_tpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ NB_TPUS=$JAXCI_TPU_CORES
JOBS_PER_ACC=1
J=$((NB_TPUS * JOBS_PER_ACC))

case "${JAXCI_TPU_PARALLELISM_MODE:-chip}" in
core)
TPU_XDIST_VISIBILITY_MODE="devices"
;;
chip)
TPU_XDIST_VISIBILITY_MODE="chips"
;;
*)
echo "Unknown JAXCI_TPU_PARALLELISM_MODE: ${JAXCI_TPU_PARALLELISM_MODE}"
exit 1
;;
esac

# TODO(ybaturina): Bazel cache shouldn't be invalidated when
# `VBAR_CONTROL_SERVICE_URL` changes.
COMMON_TPU_TEST_ENV_VARS="--test_env=TPU_SKIP_MDS_QUERY=true \
Expand Down Expand Up @@ -107,6 +120,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
--local_test_jobs=$J \
--test_env=JAX_TEST_NUM_THREADS=$J \
--test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \
--test_env=JAX_TPU_XDIST_VISIBILITY_MODE=${TPU_XDIST_VISIBILITY_MODE} \
--test_env=JAX_SKIP_SLOW_TESTS=1 \
--test_env=JAX_ENABLE_TPU_XDIST=1 \
--test_env=JAX_PLATFORMS=tpu,cpu \
Expand Down Expand Up @@ -176,6 +190,7 @@ else
--local_test_jobs=$J \
--test_env=JAX_TEST_NUM_THREADS=$J \
--test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \
--test_env=JAX_TPU_XDIST_VISIBILITY_MODE=${TPU_XDIST_VISIBILITY_MODE} \
--test_env=JAX_SKIP_SLOW_TESTS=1 \
--test_env=JAX_ENABLE_TPU_XDIST=1 \
--test_env=JAX_PLATFORMS=tpu,cpu \
Expand Down
21 changes: 18 additions & 3 deletions ci/run_pytest_tpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ export JAX_PLATFORMS=tpu,cpu
export JAX_SKIP_SLOW_TESTS=true
# End of common test environment variable setup

case "${JAXCI_TPU_PARALLELISM_MODE:-chip}" in
core)
TPU_XDIST_VISIBILITY_MODE="devices"
;;
chip)
TPU_XDIST_VISIBILITY_MODE="chips"
;;
*)
echo "Unknown JAXCI_TPU_PARALLELISM_MODE: ${JAXCI_TPU_PARALLELISM_MODE}"
exit 1
;;
esac

echo "Running TPU tests..."
mkdir -p test-artifacts

Expand All @@ -70,7 +83,8 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
fi

# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
JAX_ENABLE_TPU_XDIST=true JAX_TPU_XDIST_VISIBILITY_MODE="$TPU_XDIST_VISIBILITY_MODE" \
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--junitxml=test-artifacts/junit-single.xml \
--deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
--deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \
Expand All @@ -89,7 +103,8 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
second_cmd_retval=$?
else
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
JAX_ENABLE_TPU_XDIST=true JAX_TPU_XDIST_VISIBILITY_MODE="$TPU_XDIST_VISIBILITY_MODE" \
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--junitxml=test-artifacts/junit-single.xml \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
Expand Down Expand Up @@ -133,4 +148,4 @@ elif [[ $third_cmd_retval -ne 0 ]]; then
exit $third_cmd_retval
else
exit 0
fi
fi
19 changes: 17 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def add_imports(doctest_namespace):
# For TPU, the env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an
# effect. We do this to minimize any effect on non-TPU tests, and as a pointer
# in test code to this "magic" hook. TPU tests should not specify more xdist
# workers than the number of TPU chips.
# workers than the number of TPU chips, unless
# JAX_TPU_XDIST_VISIBILITY_MODE=devices is set for hosts where logical devices
# are a finer-grained unit than chips.
#
# For GPU, the env var JAX_ENABLE_CUDA_XDIST must be set equal to the number of
# CUDA devices. Test processes will be assigned in round robin fashion across
Expand All @@ -59,7 +61,20 @@ def pytest_collection() -> None:
if not xdist_worker_name.startswith("gw"):
return
xdist_worker_number = int(xdist_worker_name[len("gw") :])
os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
tpu_visibility_mode = os.environ.get(
"JAX_TPU_XDIST_VISIBILITY_MODE", "chips"
)
if tpu_visibility_mode == "devices":
os.environ.setdefault("TPU_VISIBLE_DEVICES", str(xdist_worker_number))
os.environ.setdefault("TPU_CHIPS_PER_PROCESS_BOUNDS", "1,1,1,1")
os.environ.setdefault("TPU_PROCESS_BOUNDS", "1,1,1,1")
elif tpu_visibility_mode == "chips":
os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number))
else:
raise ValueError(
"JAX_TPU_XDIST_VISIBILITY_MODE must be 'chips' or 'devices'; "
f"got {tpu_visibility_mode!r}"
)
os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true")

elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None):
Expand Down
Loading