[ROCm] Enable HLO module transform registration for GPU backends#38142

Open: mminutoli wants to merge 1 commit into jax-ml:main from ROCm:xla-transform-gpu-support
Conversation

@mminutoli
Contributor

@mminutoli mminutoli commented Jun 2, 2026

Summary

  • jax/_src/xla_transform.py: register_hlo_module_transformation(platforms=None) was silently skipping GPU plugin clients. The default platforms list only contained factory keys (e.g. "rocm") while PJRT plugin clients report client.platform as "gpu". Add "gpu" as an explicit entry so the backend initialization hook fires for all GPU plugin clients.
  • tests/BUILD: Enable xla_transform_test for the GPU backend.
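
For readers who want to see the mismatch in isolation, here is a minimal sketch in plain Python. `FakeClient`, `default_platforms`, and `would_register` are hypothetical stand-ins and not JAX APIs; the only details taken from this PR are that the factory key is "rocm", that the plugin client reports "gpu", and that the default list is deduplicated with `dict.fromkeys`.

```python
from dataclasses import dataclass


@dataclass
class FakeClient:
    # Hypothetical stand-in for a PJRT client; only the platform name matters here.
    platform: str


def default_platforms(factory_keys, include_gpu):
    # Mirrors the shape of the default list discussed in this PR (assumed, simplified).
    base = ["cpu", "gpu"] if include_gpu else ["cpu"]
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(base + list(factory_keys)))


def would_register(client, platforms_list):
    # Same membership check the registration hook is described as using.
    return client.platform in platforms_list


rocm_plugin_client = FakeClient(platform="gpu")

before = default_platforms(["rocm"], include_gpu=False)
after = default_platforms(["rocm"], include_gpu=True)

print(before, would_register(rocm_plugin_client, before))
print(after, would_register(rocm_plugin_client, after))
```

With the old default the client is skipped because "gpu" is not among the factory keys; with "gpu" added the membership check succeeds.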

Dependencies

Depends on the corresponding XLA change: openxla/xla#43630

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates register_hlo_module_transformation to include "gpu" in the default platforms list when platforms is None, and enables the GPU backend for the corresponding multiplatform test. However, a correctness issue was identified where explicitly specifying a concrete GPU platform (such as "rocm") causes the transform registration to be silently skipped at runtime because the client reports its platform as "gpu". It is recommended to canonicalize the client's platform in register_on_client to ensure correct matching.

Comment thread jax/_src/xla_transform.py
Comment on lines 76 to 80

      if platforms is None:
    -   platforms_list = ["cpu"] + list(xla_bridge._backend_factories.keys())
    +   platforms_list = ["cpu", "gpu"] + list(xla_bridge._backend_factories.keys())
        platforms_list = list(dict.fromkeys(platforms_list))
      elif isinstance(platforms, str):
        platforms_list = [platforms]
high

There is a correctness issue when a user explicitly specifies a concrete GPU platform (e.g., platforms="rocm" or platforms="cuda") instead of None.

If platforms="rocm", platforms_list becomes ["rocm"]. At runtime, the PJRT plugin client reports client.platform as "gpu".

In register_on_client (lines 90-91):

    def register_on_client(client):
      if client.platform in platforms_list:

This check evaluates "gpu" in ["rocm"], which is False. As a result, the transform is silently skipped and never registered for the GPU backend, even though the user explicitly requested "rocm".

### Suggested Fix

To resolve this, we should canonicalize the client's platform in register_on_client using xla_bridge.canonicalize_platform so that it matches the concrete platform name:

    def register_on_client(client):
      try:
        concrete_platform = xla_bridge.canonicalize_platform(client.platform)
      except Exception:
        concrete_platform = client.platform
      if client.platform in platforms_list or concrete_platform in platforms_list:
        try:
          _xla.register_xla_transform_c_api(client, name, stage_int, callback)
        except RuntimeError:
          logger.debug(
              "Could not register XLA transform via C API for client platform %s",
              client.platform,
          )
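
The matching logic in this suggestion can be exercised without JAX. The sketch below is only an illustration: `ALIASES`, `canonicalize`, and `matches` are hypothetical names, and the alias table is an assumption standing in for whatever xla_bridge.canonicalize_platform resolves at runtime. It is not the proposed patch.

```python
# Hypothetical alias table; an assumption standing in for the real canonicalization.
ALIASES = {"gpu": ["cuda", "rocm"]}


def canonicalize(platform, available):
    """Map an alias such as "gpu" to the first available concrete platform."""
    for concrete in ALIASES.get(platform, [platform]):
        if concrete in available:
            return concrete
    raise RuntimeError(f"no backend available for {platform!r}")


def matches(client_platform, platforms_list, available):
    """Return True if the client should receive the transform registration."""
    try:
        concrete = canonicalize(client_platform, available)
    except RuntimeError:
        # Fall back to the reported name if canonicalization fails.
        concrete = client_platform
    return client_platform in platforms_list or concrete in platforms_list


# A client reporting "gpu" on a machine where only the ROCm backend is available.
print(matches("gpu", ["rocm"], available={"cpu", "rocm"}))
print(matches("gpu", ["cuda"], available={"cpu", "rocm"}))
```

Checking both the reported name and the canonical name keeps platforms="gpu" working while also honoring an explicit platforms="rocm".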

@mminutoli mminutoli added the AMD GPU Issues pertaining to AMD GPUs (ROCM) label Jun 2, 2026