[ROCm] Enable HLO module transform registration for GPU backends #38142
mminutoli wants to merge 1 commit
register_hlo_module_transformation(platforms=None) was silently skipping GPU plugin clients because the default platforms list only contained factory keys (e.g. "rocm") while PJRT plugin clients report client.platform as "gpu". Add "gpu" as an explicit entry so the backend initialization hook fires for all GPU plugin clients. Also enable the xla_transform_test for the GPU backend.
Code Review
This pull request updates register_hlo_module_transformation to include "gpu" in the default platforms list when platforms is None, and enables the GPU backend for the corresponding multiplatform test. However, a correctness issue was identified where explicitly specifying a concrete GPU platform (such as "rocm") causes the transform registration to be silently skipped at runtime because the client reports its platform as "gpu". It is recommended to canonicalize the client's platform in register_on_client to ensure correct matching.
      if platforms is None:
    -   platforms_list = ["cpu"] + list(xla_bridge._backend_factories.keys())
    +   platforms_list = ["cpu", "gpu"] + list(xla_bridge._backend_factories.keys())
        platforms_list = list(dict.fromkeys(platforms_list))
      elif isinstance(platforms, str):
        platforms_list = [platforms]
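The effect of this hunk can be reproduced in isolation. The sketch below is only an illustration: `FACTORY_KEYS` is a hypothetical stand-in for `xla_bridge._backend_factories.keys()` (the real keys depend on which backends and plugins are installed), and `default_platforms` is a helper invented here, not a function in JAX.

```python
# Hypothetical stand-in for xla_bridge._backend_factories.keys(); the real
# keys depend on which backends and plugins are installed.
FACTORY_KEYS = ["cpu", "rocm"]


def default_platforms(factory_keys, include_gpu):
  """Builds the default platforms list the way the hunk above does."""
  base = ["cpu", "gpu"] if include_gpu else ["cpu"]
  # dict.fromkeys drops duplicates while keeping first-seen order.
  return list(dict.fromkeys(base + list(factory_keys)))


before = default_platforms(FACTORY_KEYS, include_gpu=False)
after = default_platforms(FACTORY_KEYS, include_gpu=True)

# Per the PR description, a GPU plugin client reports its platform as "gpu".
client_platform = "gpu"

print(before, client_platform in before)  # ['cpu', 'rocm'] False
print(after, client_platform in after)    # ['cpu', 'gpu', 'rocm'] True
```

With the old default, the membership check fails for a client reporting "gpu", so the registration hook never fires; with the new default, it succeeds.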
There is a correctness issue when a user explicitly specifies a concrete GPU platform (e.g., platforms="rocm" or platforms="cuda") instead of None.

If platforms="rocm", platforms_list becomes ["rocm"]. At runtime, the PJRT plugin client reports client.platform as "gpu".

In register_on_client (lines 90-91):

    def register_on_client(client):
      if client.platform in platforms_list:

This check evaluates "gpu" in ["rocm"], which is False. As a result, the transform is silently skipped and never registered for the GPU backend, even though the user explicitly requested "rocm".

### Suggested Fix

To resolve this, we should canonicalize the client's platform in register_on_client using xla_bridge.canonicalize_platform so that it matches the concrete platform name:

    def register_on_client(client):
      try:
        concrete_platform = xla_bridge.canonicalize_platform(client.platform)
      except Exception:
        concrete_platform = client.platform
      if client.platform in platforms_list or concrete_platform in platforms_list:
        try:
          _xla.register_xla_transform_c_api(client, name, stage_int, callback)
        except RuntimeError:
          logger.debug(
              "Could not register XLA transform via C API for client platform %s",
              client.platform,
          )
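The matching rule proposed in this comment can be checked without a GPU. The sketch below is a standalone approximation under stated assumptions: `canonicalize_platform` here is a hypothetical stand-in for `xla_bridge.canonicalize_platform`, which is assumed to map the "gpu" alias to a concrete backend name (hard-coded to "rocm" for illustration), and `matches` is a helper invented for this example.

```python
def canonicalize_platform(platform, gpu_backend="rocm"):
  # Hypothetical stand-in for xla_bridge.canonicalize_platform: maps the
  # "gpu" alias to an assumed concrete backend and leaves other names alone.
  return gpu_backend if platform == "gpu" else platform


def matches(client_platform, platforms_list):
  """Returns True if the client should receive the transform registration."""
  concrete_platform = canonicalize_platform(client_platform)
  # Accept either the reported name or its canonicalized concrete name.
  return client_platform in platforms_list or concrete_platform in platforms_list


print(matches("gpu", ["rocm"]))  # True: concrete name matches the request
print(matches("gpu", ["gpu"]))   # True: reported name matches directly
print(matches("gpu", ["cuda"]))  # False: a different concrete backend
print(matches("cpu", ["rocm"]))  # False: unrelated platform
```

Checking both names keeps the behavior of the default "gpu" entry intact while also honoring an explicit platforms="rocm" request.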
Summary
- jax/_src/xla_transform.py: register_hlo_module_transformation(platforms=None) was silently skipping GPU plugin clients. The default platforms list only contained factory keys (e.g. "rocm") while PJRT plugin clients report client.platform as "gpu". Add "gpu" as an explicit entry so the backend initialization hook fires for all GPU plugin clients.
- tests/BUILD: Enable xla_transform_test for the GPU backend.

Dependencies
Depends on the corresponding XLA change: openxla/xla#43630