forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
130 lines (127 loc) · 5.81 KB
/
bazel_cuda.yml
File metadata and controls
130 lines (127 loc) · 5.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# CI - Bazel CUDA tests
#
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Bazel CUDA tests (RBE)`, `CI - Wheel Tests (Continuous)`,
# and `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel CUDA tests.
#
# It consists of the following job:
# run-tests:
#    - Downloads the jaxlib and CUDA artifacts from a GCS bucket if build_jaxlib is `false`.
#      Otherwise, the artifacts are built from source.
#    - Downloads the jax artifact from a GCS bucket if build_jax is `false`.
#      Otherwise, the artifact is built from source.
#    - If `run_multiaccelerator_tests` is `false`, executes the `run_bazel_test_cuda_rbe.sh` script,
#      which performs the following actions:
#      - `build_jaxlib=wheel`: Runs the Bazel CUDA tests with py_import dependencies.
#      - `build_jaxlib=false`: Runs the Bazel CUDA tests with downloaded wheel dependencies.
#      - `build_jaxlib=true`: Runs the Bazel CUDA tests with individual Bazel target dependencies.
#    - If `run_multiaccelerator_tests` is anything other than `false` (e.g. `true`), executes the
#      `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
#      - `build_jaxlib=wheel`: Runs the Bazel CUDA tests with py_import dependencies.
#      - `build_jaxlib=false`: Runs the Bazel CUDA tests with downloaded wheel dependencies.
name: CI - Bazel CUDA tests
on:
  # Reusable workflow: the only trigger is `workflow_call`, so it runs solely
  # when another workflow invokes it.
  workflow_call:
    # Every input is declared as `type: string`; flag-like inputs therefore
    # carry string values such as '0'/'1', 'true'/'false' or 'no'.
    inputs:
      # ---- Test configuration ----
      runner:
        description: 'Which runner should the workflow run on?'
        type: string
        default: 'linux-x86-n4-16'
      python:
        description: 'Which python version to test?'
        type: string
        # Quoted so the version stays a string (an unquoted 3.10 would become 3.1).
        default: '3.12'
      cuda-version:
        description: 'Which CUDA version to test?'
        type: string
        default: '12'
      enable-x64:
        description: 'Should x64 mode be enabled?'
        type: string
        default: '0'
      jaxlib-version:
        description: 'Which jaxlib version to test? (head/pypi_latest)'
        type: string
        default: 'head'

      # ---- Artifact download (forwarded to the download-jax-cuda-wheels action) ----
      download-jax-from-gcs:
        description: 'Whether to download the jax wheel from GCS'
        type: string
        default: '1'
      skip-download-jaxlib-and-plugins-from-gcs:
        description: 'Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)'
        type: string
        default: '0'
      gcs_download_uri:
        description: 'GCS location URI from where the artifacts should be downloaded'
        type: string
        # Default path is keyed by workflow name, run number and run attempt.
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

      # ---- Build mode (callers must supply both) ----
      build_jaxlib:
        # The wheel download step in the job runs only when this is exactly 'false'.
        description: 'Should jaxlib be built from source?'
        type: string
        required: true
      build_jax:
        description: 'Should jax be built from source?'
        type: string
        required: true

      # ---- Optional switches ----
      write_to_bazel_remote_cache:
        description: 'Whether to enable writing to the Bazel remote cache bucket'
        type: string
        required: false
        default: '0'
      run_multiaccelerator_tests:
        # Compared against the literal string 'false' when the test script is chosen.
        description: 'Whether to run multi-accelerator tests'
        type: string
        required: false
        default: 'false'
      clone_main_xla:
        description: 'Should latest XLA be used?'
        type: string
        required: true
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
# Empty permissions map: the workflow requests no GITHUB_TOKEN scopes.
permissions: {}
jobs:
  # Single job: fetch the sources, optionally download prebuilt wheels, then run
  # one of the two Bazel CUDA test scripts.
  run-tests:
    runs-on: ${{ inputs.runner }}
    container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'
    defaults:
      run:
        # Use bash for every `run` step in this job.
        shell: bash
    # Workflow inputs re-exported as JAXCI_* environment variables
    # (presumably read by the ci/*.sh scripts - confirm against those scripts).
    env:
      JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
      JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
      JAXCI_CLONE_MAIN_XLA: ${{ inputs.clone_main_xla }}
      JAXCI_CUDA_VERSION: ${{ inputs.cuda-version }}
      JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
      JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }}
    # The job name below is kept verbatim between its two marker comments.
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
        (contains(inputs.runner, 'linux-arm64') && 'linux arm64') ||
        (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, jaxlib=${{ inputs.jaxlib-version }}, CUDA=${{ inputs.cuda-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
    # End Presubmit Naming Check github-cuda-presubmits
    steps:
      # Check out the repository; the token is not persisted for later git commands.
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3  # v6.0.0
        with:
          persist-credentials: false

      # Fetch prebuilt wheels only when jaxlib is not being built (build_jaxlib == 'false').
      - name: Download JAX CUDA wheels
        if: inputs.build_jaxlib == 'false'
        uses: ./.github/actions/download-jax-cuda-wheels
        with:
          python: ${{ inputs.python }}
          cuda-version: ${{ inputs.cuda-version }}
          jaxlib-version: ${{ inputs.jaxlib-version }}
          download-jax-from-gcs: ${{ inputs.download-jax-from-gcs }}
          skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
          gcs_download_uri: ${{ inputs.gcs_download_uri }}

      # Halt for testing: driven by the `halt-for-connection` input.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}

      # Runs the RBE script when run_multiaccelerator_tests is exactly 'false';
      # any other value selects the non-RBE script.
      - name: 'Bazel CUDA tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}'
        timeout-minutes: 60
        run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_cuda_rbe.sh') || './ci/run_bazel_test_cuda_non_rbe.sh' }}