Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions jax_rocm_plugin/build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ import subprocess
import sys


def get_rocm_jax_commit():
"""Get the rocm-jax git commit hash from the parent repo."""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
cwd=os.path.abspath("../"),
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except subprocess.CalledProcessError:
return ""


def dist_wheels(
rocm_version,
python_versions,
Expand Down Expand Up @@ -91,18 +106,19 @@ def dist_wheels(
cmd.append("-it")

# NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
# NOTE: GIT_DIR and GIT_WORK_TREE are NOT set because they interfere with
# Bazel's git_repository rule. Instead, we pass ROCM_JAX_COMMIT explicitly.
rocm_jax_commit = get_rocm_jax_commit()
cmd.extend(
[
"--init",
"--rm",
"--shm-size",
"64G",
"-e",
"GIT_DIR=/repo/.git",
"-e",
"GIT_WORK_TREE=/repo",
"-e",
"ROCM_VERSION_EXTRA=" + rocm_version,
"-e",
"ROCM_JAX_COMMIT=" + rocm_jax_commit,
builder_image,
"bash",
"-c",
Expand Down
12 changes: 11 additions & 1 deletion jax_rocm_plugin/build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,17 @@ def get_jax_configure_bazel_options(bazel_command: list[str]):


def get_githash():
"""dummy docstring"""
"""Get the rocm-jax git commit hash.

First checks ROCM_JAX_COMMIT env var (set by ci_build when running in Docker),
then falls back to running git rev-parse HEAD.
"""
# Check env var first (set by ci_build for Docker builds).
env_hash = os.environ.get("ROCM_JAX_COMMIT", "")
if env_hash:
return env_hash

# Fall back to git command.
try:
return subprocess.run(
["git", "rev-parse", "HEAD"],
Expand Down
2 changes: 1 addition & 1 deletion jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

load("@jax//jaxlib:jax.bzl", "jax_wheel", "wheel_sources")
load("//third_party/jax:workspace.bzl", jax_commit = "COMMIT")
load("//third_party/jax:workspace.bzl", jax_commit = "JAX_COMMIT")
load("//third_party/xla:workspace.bzl", xla_commit = "XLA_COMMIT")

licenses(["notice"]) # Apache 2
Expand Down
2 changes: 1 addition & 1 deletion jax_rocm_plugin/pjrt/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

load("@jax//jaxlib:jax.bzl", "jax_wheel", "wheel_sources")
load("//third_party/jax:workspace.bzl", jax_commit = "COMMIT")
load("//third_party/jax:workspace.bzl", jax_commit = "JAX_COMMIT")
load("//third_party/xla:workspace.bzl", xla_commit = "XLA_COMMIT")

licenses(["notice"]) # Apache 2
Expand Down

This file was deleted.

This file was deleted.

92 changes: 0 additions & 92 deletions jax_rocm_plugin/third_party/jax/0003-hipblas-typedef-fix.patch

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
diff --git a/WORKSPACE b/WORKSPACE
index 1d2096fbe..9c5d8f0d5 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -100,6 +100,10 @@ load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

flatbuffers()

+load("//third_party/external_deps:workspace.bzl", "external_deps_repository")
+
+external_deps_repository(name = "rocm_external_test_deps")
+
load("//:test_shard_count.bzl", "test_shard_count_repository")

test_shard_count_repository(
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index 304bab34f..34f760cd0 100644
--- a/jaxlib/jax.bzl
Expand Down
20 changes: 12 additions & 8 deletions jax_rocm_plugin/third_party/jax/workspace.bzl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
load("//third_party:repo.bzl", "amd_http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

COMMIT = "fbfa695aea59ed578b81d8fc72ab23bba5d2cfaa"
SHA = "b740b326b468ce7ef967fbfab0accfb19850fab9c43ab6a3a37112eff34223c2"
# To update JAX:
# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.8.2 branch)
# 2. Update JAX_COMMIT below

JAX_COMMIT = "fbfa695aea59ed578b81d8fc72ab23bba5d2cfaa"

def repo():
amd_http_archive(
git_repository(
name = "jax",
sha256 = SHA,
strip_prefix = "jax-{commit}".format(commit = COMMIT),
urls = ["https://github.com/ROCm/jax/archive/{commit}.tar.gz".format(commit = COMMIT)],
patch_file = [
remote = "https://github.com/ROCm/jax.git",
commit = JAX_COMMIT,
patch_tool = "patch",
patch_args = ["-p1"],
patches = [
"//third_party/jax:0005-Fix-HIP-availability-errors.patch",
"//third_party/jax:0006-Enable-testing-with-ROCm-plugin-wheels.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34641
"//third_party/jax:0007-Fix-legacy-create-init.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34770
Expand Down
26 changes: 10 additions & 16 deletions jax_rocm_plugin/third_party/xla/workspace.bzl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
Expand All @@ -8,34 +7,29 @@
# limitations under the License.

# buildifier: disable=module-docstring
load("//third_party:repo.bzl", "amd_http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

# To update XLA to a new revision,
# a) update XLA_COMMIT to the new git commit hash
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
# To update XLA:
# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.8.2 branch)
# 2. Update XLA_COMMIT below

XLA_COMMIT = "24c5f10ae8fc24aefd20b43c501ade7f66fd0cfd"
XLA_SHA256 = "f00db8761e86bcb51b52e64bc983717181050c8752c040e33ecb1429d861c30b"

def repo():
amd_http_archive(
git_repository(
name = "xla",
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = ["https://github.com/ROCm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
patch_file = [],
remote = "https://github.com/ROCm/xla.git",
commit = XLA_COMMIT,
)

# For development, one often wants to make changes to the TF repository as well
# For development, one often wants to make changes to the XLA repository as well
# as the JAX repository. You can override the pinned repository above with a
# local checkout by either:
# a) overriding the TF repository on the build.py command line by passing a flag
# a) overriding the XLA repository on the build.py command line by passing a flag
# like:
# python build/build.py build --local_xla_path=/path/to/xla
# or
# b) by commenting out the http_archive above and uncommenting the following:
# b) by commenting out the git_repository above and uncommenting the following:
# local_repository(
# name = "xla",
# path = "/path/to/xla",
Expand Down
Loading