forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
149 lines (149 loc) · 6.62 KB
/
bazel_cuda_h100_b200.yml
File metadata and controls
149 lines (149 loc) · 6.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
name: CI - Bazel H100 and B200 CUDA tests
# Runs the single-GPU B200 job and the multi-GPU H100 job.
# A job runs when any of the following conditions is met:
#   - Workflow dispatch: H100 and B200
#   - Schedule (every two hours): H100 and B200
#   - PR to main that modifies Mosaic GPU files or this file (the path list is
#     in the changed_files job): B200 only
#   - PR to main that has the 'CI Optional GPU Presubmit' label: H100 and B200
on:
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'
  # The workflow starts for these PR event types against main. Whether a job
  # actually runs (matching file changes, or the "CI Optional GPU Presubmit"
  # label) is decided by each job's own `if:` condition, not by this trigger.
  pull_request:
    branches:
      - main
    types:
      - labeled
      - synchronize
      - opened
      - reopened
  schedule:
    - cron: "0 */2 * * *" # Run once every 2 hours
# Workflow-wide default: no token permissions are granted.
permissions: {}
concurrency:
  # One group per workflow and PR head branch (or ref for non-PR events).
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Don't cancel in-progress jobs for main/release branches.
  # github.ref is the fully qualified ref (e.g. refs/heads/main), so the
  # comparison must use that form; comparing against the bare name 'main'
  # never matches and would let runs on main cancel one another.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
  # Reports whether a pull request touches the Mosaic GPU sources or this
  # workflow file. run_tests depends on this job and reads `any_changed`.
  changed_files:
    runs-on: ubuntu-latest # Do not run tj-actions on self-hosted runners
    permissions: {} # No permissions given
    outputs:
      # Resolves to 'false' when the changed-files step produced no output,
      # e.g. because it was skipped on a non-PR event.
      any_changed: ${{ steps.changed-files.outputs.any_changed || 'false' }}
    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      # We only run this if it is a pull request, do not run tj-actions on non PR event
      - name: Get and filter changed files
        id: changed-files
        if: ${{ github.event_name == 'pull_request' }}
        uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46
        with:
          files: |
            jax/_src/pallas/mosaic_gpu/**
            jax/experimental/mosaic/gpu/**
            jaxlib/mosaic/dialect/gpu/**
            jaxlib/mosaic/gpu/**
            .github/workflows/bazel_cuda_h100_b200.yml
      # Logging only: echoes each path reported by the step above.
      - name: List all changed files
        env:
          ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
        run: |
          for file in ${ALL_CHANGED_FILES}; do
            echo "$file was changed"
          done

  # Single-GPU B200 job. Skipped on forks; otherwise runs on schedule, on
  # workflow dispatch, when changed_files reported a match, or when the PR
  # carries the 'CI Optional GPU Presubmit' label.
  # NOTE(review): the "End Presubmit Naming Check" line below looks like a
  # marker read by an external check, so it is kept verbatim right after
  # `name:` - confirm before moving or rewording it.
  run_tests:
    needs: changed_files
    if: >-
      ${{ github.event.repository.fork == false &&
      (github.event_name == 'schedule' ||
      needs.changed_files.outputs.any_changed == 'true' ||
      github.event_name == 'workflow_dispatch' ||
      contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
    runs-on: linux-x86-a4-224-b200-1gpu
    container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'
    name: "Bazel single B200 CUDA tests"
# End Presubmit Naming Check github-cuda-presubmits
    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      # Receives the workflow_dispatch `halt-for-connection` input.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Tests tagged `multiaccelerator` are filtered out here; the H100 job
      # below selects exactly that tag.
      - name: Run Bazel single B200 CUDA Tests
        run: |
          nvidia-smi
          bazel test \
            --config=ci_linux_x86_64_cuda \
            --config=ci_rbe_cache \
            --config=hermetic_cuda_umd \
            --repo_env=HERMETIC_PYTHON_VERSION="3.14" \
            --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \
            --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \
            --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
            --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
            --test_output=errors \
            --test_tag_filters=-multiaccelerator \
            --test_env=JAX_ACCELERATOR_COUNT=1 \
            --test_env=JAX_TESTS_PER_ACCELERATOR=8 \
            --strategy=TestRunner=local \
            --local_test_jobs=8 \
            --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \
            --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
            --test_env=JAX_SKIP_SLOW_TESTS=true \
            --action_env=JAX_ENABLE_X64="1" \
            --action_env=NCCL_DEBUG=WARN \
            --flaky_test_attempts=1 \
            --test_timeout=420 \
            --color=yes \
            //tests:cudnn_fusion_test_gpu \
            //tests:scaled_matmul_stablehlo_test_gpu \
            //tests:fused_attention_stablehlo_test_gpu \
            //tests:nn_test_gpu \
            //tests/pallas:gpu_tests \
            //tests/mosaic:gpu_tests

  # Multi-GPU H100 job. Skipped on forks; otherwise runs on schedule, on
  # workflow dispatch, or when the PR carries the 'CI Optional GPU Presubmit'
  # label. It has no `needs:` and does not consult changed_files.
  run_multiaccelerator_tests:
    name: "Bazel multiple H100 CUDA tests"
    if: >-
      ${{ github.event.repository.fork == false &&
      (github.event_name == 'schedule' ||
      github.event_name == 'workflow_dispatch' ||
      contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
    runs-on: linux-x86-a3-8g-h100-8gpu
    container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'
    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      # Receives the workflow_dispatch `halt-for-connection` input.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Only tests tagged `multiaccelerator` are selected from these targets.
      - name: Run Bazel multiple H100 CUDA Tests
        run: |
          nvidia-smi
          bazel test \
            --config=ci_linux_x86_64_cuda \
            --config=ci_rbe_cache \
            --config=hermetic_cuda_umd \
            --repo_env=HERMETIC_PYTHON_VERSION="3.14" \
            --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \
            --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \
            --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
            --test_output=errors \
            --strategy=TestRunner=local \
            --local_test_jobs=8 \
            --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \
            --test_tag_filters=multiaccelerator \
            --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
            --test_env=JAX_SKIP_SLOW_TESTS=true \
            --action_env=JAX_ENABLE_X64="1" \
            --action_env=NCCL_DEBUG=WARN \
            --flaky_test_attempts=1 \
            --color=yes \
            //tests/mosaic:gpu_tests \
            //tests/pallas:gpu_tests \
            //tests:array_interoperability_test_gpu \
            //tests:cudnn_fusion_test_gpu \
            //tests:fused_attention_stablehlo_test_gpu \
            //tests:gpu_tests \
            //tests:python_callback_test_gpu \
            //tests:ragged_collective_test_gpu \
            //tests/multiprocess:gpu_tests \
            //jax/experimental/jax2tf/tests/multiprocess:gpu_tests