-
Notifications
You must be signed in to change notification settings - Fork 235
Expand file tree
/
Copy pathrelease_portable_linux_jax_wheels.yml
More file actions
130 lines (124 loc) · 4.51 KB
/
release_portable_linux_jax_wheels.yml
File metadata and controls
130 lines (124 loc) · 4.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
name: Release portable Linux JAX Wheels

# This workflow has two entry points that expose the same set of inputs:
#   * workflow_call     - reused by another workflow; the caller must pass the
#                         GPU family and both CloudFront URLs explicitly.
#   * workflow_dispatch - started by hand; every input has a default or is
#                         optional.
# Keep the two input lists in sync: the `release` job reads all of them
# through the shared `inputs` context regardless of how the run started.
on:  # yamllint disable-line rule:truthy
  workflow_call:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build the wheels for (e.g. "gfx94X-dcgpu")
        required: true
        type: string
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: Staging subdirectory to push the wheels for test
        type: string
        default: "v2-staging"
      cloudfront_url:
        description: CloudFront URL pointing to Python index
        required: true
        type: string
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        required: true
        type: string
      rocm_version:
        description: ROCm version to install (e.g. "7.10.0a20251124")
        type: string
      tar_url:
        description: "URL to TheRock tarball to build against (e.g. https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251124.tar.gz)"
        type: string
      ref:
        description: "TheRock branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
        # Explicit empty default, matching the workflow_dispatch definition.
        default: ''

  workflow_dispatch:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build the wheels for
        type: choice
        options:
          - gfx101X-dgpu
          - gfx103X-all
          - gfx110X-all
          - gfx1150
          - gfx1151
          - gfx1152
          - gfx1153
          - gfx120X-all
          - gfx900
          - gfx906
          - gfx908
          - gfx90a
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: "Staging subdirectory to push the wheels for test"
        type: string
        default: "v2-staging"
      # The two CloudFront defaults end in the same path segments as the
      # s3_subdir ("v2") and s3_staging_subdir ("v2-staging") defaults above;
      # change them together.
      cloudfront_url:
        description: CloudFront URL pointing to Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2"
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2-staging"
      rocm_version:
        description: ROCm version to install (e.g. "7.10.0a20251124")
        type: string
      tar_url:
        description: "URL to TheRock tarball to build against (e.g. https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251124.tar.gz)"
        type: string
      ref:
        description: "TheRock branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
        default: ''
# Scopes granted to GITHUB_TOKEN for the jobs in this workflow.
permissions:
  contents: read
  packages: read
  # NOTE(review): presumably required by the called build workflow to mint
  # OIDC tokens for cloud credentials (e.g. the S3 upload) — confirm in
  # build_linux_jax_wheels.yml before narrowing.
  id-token: write

# Title shown for each run: GPU family, release type and ROCm version.
run-name: "Release portable Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.release_type }}, ${{ inputs.rocm_version }})"
jobs:
  # Fans out one call of the reusable build workflow per
  # (python_version, jax_ref) pair: 4 x 4 = 16 jobs for the chosen family.
  release:
    name: Release | ${{ inputs.amdgpu_family }} | py ${{ matrix.python_version }} | jax ${{ matrix.jax_ref }}
    strategy:
      # A failure in one combination must not cancel the others.
      fail-fast: false
      matrix:
        # Quoted so the versions stay strings rather than being read as floats.
        python_version:
          - "3.11"
          - "3.12"
          - "3.13"
          - "3.14"
        jax_ref:
          - "rocm-jaxlib-v0.8.0"
          - "rocm-jaxlib-v0.8.2"
          - "rocm-jaxlib-v0.9.0"
          - "rocm-jaxlib-v0.9.1"
        # Every entry below matches an existing jax_ref value, so it adds no
        # new jobs; it only attaches a build_jaxlib flag to all combinations
        # using that ref. rocm-jaxlib-v0.9.1 is the one ref set to false.
        include:
          - jax_ref: "rocm-jaxlib-v0.8.0"
            build_jaxlib: true
          - jax_ref: "rocm-jaxlib-v0.8.2"
            build_jaxlib: true
          - jax_ref: "rocm-jaxlib-v0.9.0"
            build_jaxlib: true
          - jax_ref: "rocm-jaxlib-v0.9.1"
            build_jaxlib: false
    uses: ./.github/workflows/build_linux_jax_wheels.yml
    with:
      # Forwarded unchanged from this workflow's inputs.
      amdgpu_family: ${{ inputs.amdgpu_family }}
      release_type: ${{ inputs.release_type }}
      s3_subdir: ${{ inputs.s3_subdir }}
      s3_staging_subdir: ${{ inputs.s3_staging_subdir }}
      cloudfront_url: ${{ inputs.cloudfront_url }}
      cloudfront_staging_url: ${{ inputs.cloudfront_staging_url }}
      rocm_version: ${{ inputs.rocm_version }}
      tar_url: ${{ inputs.tar_url }}
      ref: ${{ inputs.ref }}
      # Per-job values taken from the matrix.
      python_version: ${{ matrix.python_version }}
      jax_ref: ${{ matrix.jax_ref }}
      build_jaxlib: ${{ matrix.build_jaxlib }}