# Workflow file for the CI run of PR #124419:
# "[pallas] Fix Pallas Mosaic 64-bit mode loop indexing"

name: CI
on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch.
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
permissions: {}
concurrency:
  # Key pull-request runs by PR number, not github.head_ref: head_ref is only
  # the source branch name, so PRs from different forks that reuse a branch
  # name (e.g. "main") would share a group and cancel each other. Pushes have
  # no pull_request payload and fall back to the fully qualified github.ref.
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  # Don't cancel in-progress jobs for the main branch. github.ref is fully
  # qualified (refs/heads/main), so it must be compared against that form.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
  UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
  UV_MANAGED_PYTHON: true # Make sure `uv` uses its own Python installations
  PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
  # Print cache hits/misses for setup-uv-python (Python and packages), as well
  # as timing, in job summaries. Package install time is not reported by the
  # action, so each install step echoes its own duration.
  CI_UV_DEBUG: true
defaults:
  run:
    shell: bash
jobs:
  lint_and_typecheck:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    env:
      PYTHON_VERSION: '3.11'
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Set up Python ${{ env.PYTHON_VERSION }}
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Install lint dependencies
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" pre-commit
          echo "### lint_and_typecheck dependency install" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Compute pre-commit cache date
        id: get-date
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
      - name: Cache pre-commit environments
        uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
        with:
          path: ~/.cache/pre-commit
          # The key rotates daily and whenever the hook config or setup.py changes.
          key: pre-commit-${{ steps.get-date.outputs.date }}-py${{ env.PYTHON_VERSION }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
      - name: Run pre-commit
        env:
          # PYTHON_BIN is only defined as step-level env elsewhere in this job,
          # so it has to be set on this step too.
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          "$PYTHON_BIN" -m pre_commit run --show-diff-on-failure --color=always --all-files
  zizmor:
    name: GitHub Actions Security Analysis
    runs-on: ubuntu-latest
    # Bound runaway runs, as every other job in this workflow does.
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Install uv
        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
      - name: Run zizmor
        # TODO(belitskiy): Update workflows and unpin/update zizmor version.
        run: uvx zizmor@1.24.1 --format=github .
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  build:
    # Skip in forks: they do not have this self-hosted runner type.
    if: github.repository == 'jax-ml/jax'
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-22.04, x64=${{ matrix.enable-x64 }})"
    runs-on: linux-x86-n4-32
    container:
      image: index.docker.io/library/ubuntu@sha256:4e0171b9275e12d375863f2b3ae9ce00a4c53ddda176bd55868df97ac6f21a6e # ratchet:ubuntu:22.04
    timeout-minutes: 60
    strategy:
      matrix:
        # Test the oldest and newest supported Python versions here.
        include:
          - name-prefix: "with 3.11"
            python-version: "3.11"
            enable-x64: 1
            prng-upgrade: 1
            num_generated_cases: 1
          - name-prefix: "with 3.14"
            python-version: "3.14"
            enable-x64: 0
            prng-upgrade: 0
            num_generated_cases: 1
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Set up Python ${{ matrix.python-version }}
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" '.[minimum-jaxlib]' -r 'build/test-requirements.txt'
          echo "### build dependency install (${{ matrix.python-version }})" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Run tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
          JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
          JAX_ENABLE_CHECKS: true
          JAX_SKIP_SLOW_TESTS: true
          PY_COLORS: 1
        run: |
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
          echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          "$PYTHON_BIN" -m pytest -n auto --tb=short --maxfail=20 tests examples
  documentation:
    name: Documentation - test code snippets
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        python-version: ['3.12']
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Set up Python ${{ matrix.python-version }}
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" -r docs/requirements.txt
          echo "### documentation dependency install (${{ matrix.python-version }})" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Test documentation
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
          JAX_TRACEBACK_FILTERING: "off"
          JAX_ARRAY: 1
          PY_COLORS: 1
        run: |
          # Doctests embedded in the Markdown/reST documentation.
          "$PYTHON_BIN" -m pytest -n auto --tb=short \
            --doctest-glob='*.md' --doctest-glob='*.rst' \
            docs --doctest-continue-on-failure \
            --ignore=docs/multi_process.md
          # Doctests in the jax package's docstrings, excluding modules listed below.
          "$PYTHON_BIN" -m pytest -n auto --tb=short --doctest-modules jax \
            --ignore=jax/config.py \
            --ignore=jax/experimental/jax2tf \
            --ignore=jax/_src/lib/mlir \
            --ignore=jax/_src/lib/triton.py \
            --ignore=jax/_src/lib/mosaic_gpu.py \
            --ignore=jax/interpreters/mlir.py \
            --ignore=jax/experimental/array_serialization \
            --ignore=jax/collect_profile.py \
            --ignore=jax/_src/tpu_custom_call.py \
            --ignore=jax/experimental/mosaic \
            --ignore=jax/experimental/pallas \
            --ignore=jax/_src/pallas \
            --ignore=jax/lib
  documentation_render:
    if: github.repository_owner == 'jax-ml'
    name: Documentation - render documentation
    runs-on: linux-x86-n4-16
    container:
      image: index.docker.io/library/ubuntu@sha256:4e0171b9275e12d375863f2b3ae9ce00a4c53ddda176bd55868df97ac6f21a6e # ratchet:ubuntu:22.04
    timeout-minutes: 20
    strategy:
      matrix:
        python-version: ['3.12']
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Image Setup
        env:
          # Keep package installation from blocking on interactive prompts.
          DEBIAN_FRONTEND: noninteractive
        run: |
          # apt-get rather than apt: apt's CLI is not stable for scripted use.
          apt-get update
          apt-get install -y build-essential
      - name: Set up Python ${{ matrix.python-version }}
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" -r docs/requirements.txt
          echo "### documentation_render dependency install (${{ matrix.python-version }})" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Render documentation
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          "$PYTHON_BIN" build/portserver.py &
          SERVER_PID=$!
          # Stop the background portserver even if Sphinx exits with an error.
          cleanup_portserver() {
            kill "$SERVER_PID" 2>/dev/null || true
            wait "$SERVER_PID" 2>/dev/null || true
          }
          trap cleanup_portserver EXIT
          PORTSERVER_ADDRESS=@unittest-portserver "$PYTHON_BIN" -m sphinx -j auto --color -W --keep-going -b html docs docs/build/html
  ffi:
    if: github.repository_owner == 'jax-ml'
    name: FFI example
    runs-on: linux-x86-g2-16-l4-1gpu
    container:
      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Set up Python 3.12
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: '3.12'
      - name: Install JAX and FFI example
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          # We test building using GCC instead of clang. All other JAX builds use
          # clang, but it is useful to make sure that FFI users can compile using
          # a different toolchain. The compiler is set explicitly so the build
          # does not depend on whichever compiler this container defaults to.
          CMAKE_ARGS: "-DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON"
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" '.[cuda12]'
          uv pip install --python "$PYTHON_BIN" './examples/ffi[test]'
          echo "### ffi dependency install" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Run CPU tests
        run: |
          "$PYTHON_BIN" -m pytest examples/ffi/tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          JAX_PLATFORM_NAME: cpu
      - name: Run GPU tests
        run: |
          "$PYTHON_BIN" -m pytest examples/ffi/tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}