Commit 0facc52
Author: Orbax Authors
Support the XLA GPU compilation flags in Orbax
PiperOrigin-RevId: 873745325
1 parent: c03189f

File tree: 3 files changed, +226 −7 lines

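What this enables, in user terms: an exporter can now attach XLA GPU flags to the serialized compile options through Jax2ObmOptions.xla_flags_per_platform. The sketch below is distilled from the integration test added in obm_module_test.py further down; the import paths and the simple_add apply function are assumptions for illustration, not part of this commit.

import jax
import jax.numpy as jnp

# Assumed import paths, inferred from the identifiers used in the test file.
from orbax.export import obm_configs
from orbax.export.modules import obm_module


def simple_add(params, inputs):
  # Hypothetical stand-in for the test's apply function.
  return params + inputs


options = obm_configs.Jax2ObmOptions(
    checkpoint_path='checkpoint_path',
    native_serialization_platforms=('cuda',),
    # New in this commit: per-platform XLA flags, validated and recorded as
    # env_option_overrides in the exported compile options.
    xla_flags_per_platform={
        'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
    },
)

module = obm_module.ObmModule(
    params=jax.ShapeDtypeStruct(shape=(2, 5), dtype=jnp.float32),
    apply_fn={'simple_add': simple_add},
    jax2obm_options=options,
)

# The flags surface as a CompileOptionsProtoMap keyed by platform name.
cuda_options = module.xla_compile_options_per_platform.map['cuda']
assert 'xla_gpu_enable_latency_hiding_scheduler' in (
    cuda_options.env_option_overrides
)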

export/orbax/export/modules/obm_module_test.py

Lines changed: 30 additions & 0 deletions
@@ -384,6 +384,36 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
     with self.subTest('test_weights_b_dtype'):
       self.assertEqual(module.model_params['b'].dtype, expected_dtype)
 
+  def test_obm_module_gpu_xla_flags_integration_stable(self):
+    param_shape = (2, 5)
+    param_dtype = jnp.dtype(jnp.float32)
+    param_spec = jax.ShapeDtypeStruct(shape=param_shape, dtype=param_dtype)
+    model_function_name = 'simple_add'
+
+    jax2obm_options = obm_configs.Jax2ObmOptions(
+        checkpoint_path='checkpoint_path',
+        native_serialization_platforms=('cuda',),
+        xla_flags_per_platform={
+            'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
+        },
+    )
+
+    orbax_model_module = obm_module.ObmModule(
+        params=param_spec,
+        apply_fn={model_function_name: simple_add},
+        jax2obm_options=jax2obm_options,
+    )
+
+    xla_compile_options_map = (
+        orbax_model_module.xla_compile_options_per_platform
+    )
+    self.assertIsNotNone(xla_compile_options_map)
+    build_options_cuda = xla_compile_options_map.map['cuda']
+    self.assertIn(
+        'xla_gpu_enable_latency_hiding_scheduler',
+        build_options_cuda.env_option_overrides,
+    )
+
 
 class GetSharedValueTest(parameterized.TestCase):
 
model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 84 additions & 7 deletions
@@ -165,6 +165,10 @@ def generate_xla_compile_options(
   tpu_platform_name = manifest_pb2.Platform.Name(
       manifest_pb2.Platform.TPU
   ).lower()
+  cuda_platform_name = manifest_pb2.Platform.Name(
+      manifest_pb2.Platform.CUDA
+  ).lower()
+
   compile_options_map = manifest_pb2.CompileOptionsProtoMap()
   if native_serialization_platforms is None:
     # If no native serialization platforms are specified, we will set the
@@ -195,24 +199,97 @@ def generate_xla_compile_options(
     )
 
   for platform in platforms:
-    if platform.lower() == tpu_platform_name:
-      if xla_flags_per_platform:
-        xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+    if xla_flags_per_platform:
+      xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+      if xla_flags_overrides:
         _validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags)
-      else:
-        xla_flags_overrides = None
+    else:
+      xla_flags_overrides = None
+
+    platform_lower = platform.lower()
+    if platform_lower == tpu_platform_name:
       compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
+    elif platform_lower == cuda_platform_name:
+      # GPU Trick: Empty proto to bypass 'None' check and enable jax_mesh
+      # serialization.
+      compile_environment = xla_pb2.CompilationEnvironmentsProto()
     else:
+      # CPU Path: Leave as None to preserve legacy portable execution behavior.
       compile_environment = None
-    compile_options_map.map[platform.lower()].CopyFrom(
-        _generate_compilation_options(compile_environment, jax_mesh)
+
+    compile_options = _generate_compilation_options(
+        compile_environment, jax_mesh
     )
+
+    # Inject env_option_overrides natively for GPU using a dedicated helper.
+    if platform_lower == cuda_platform_name and xla_flags_overrides:
+      _apply_gpu_compilation_env_options(compile_options, xla_flags_overrides)
+
+    compile_options_map.map[platform_lower].CopyFrom(compile_options)
+
   if not persist_xla_flags:
     for compile_options in compile_options_map.map.values():
       compile_options.executable_build_options.comp_envs.Clear()
   return compile_options_map
 
 
+def _apply_gpu_compilation_env_options(
+    compile_options: compile_options_pb2.CompileOptionsProto,
+    xla_flags_overrides: Sequence[str],
+) -> None:
+  """Applies XLA flag overrides generically for GPU platforms.
+
+  Args:
+    compile_options: The compilation options proto to be modified.
+    xla_flags_overrides: A sequence of XLA flags to apply as option overrides.
+  """
+  overrides_map = _parse_env_option_overrides(xla_flags_overrides)
+  for k, v in overrides_map.items():
+    compile_options.env_option_overrides[k].CopyFrom(v)
+
+
+def _parse_env_option_overrides(
+    xla_flags: Sequence[str],
+) -> dict[str, compile_options_pb2.OptionOverrideProto]:
+  """Parses a list of XLA flags into a dictionary of OptionOverrideProto."""
+  overrides = {}
+  for flag in xla_flags:
+    if not flag.startswith('--'):
+      raise ValueError(f"Flag {flag} must start with '--'")
+
+    try:
+      # Use the C++ ValidateXlaGPUFlag logic to ensure consistent policy
+      # enforcement across Python and C++ layers.
+      # The C++ function expects the flag with the '--' prefix.
+      tfrt_config.validate_xla_gpu_flag(flag, strict=True)
+    except Exception as e:
+      # pybind11_abseil appends the status code name to the exception string.
+      # Remove it to match exactly what users would see from the C++ binaries.
+      err_msg = str(e)
+      if err_msg.endswith(' [INVALID_ARGUMENT]'):
+        err_msg = err_msg.removesuffix(' [INVALID_ARGUMENT]')
+      raise ValueError(err_msg) from e
+
+    key, value = flag[2:].split('=', 1)
+    override_proto = compile_options_pb2.OptionOverrideProto()
+
+    # Infer type (True/False/Int/Float/String)
+    if value.lower() == 'true':
+      override_proto.bool_field = True
+    elif value.lower() == 'false':
+      override_proto.bool_field = False
+    elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
+      override_proto.int_field = int(value)
+    else:
+      try:
+        override_proto.double_field = float(value)
+      except ValueError:
+        override_proto.string_field = value
+
+    overrides[key] = override_proto
+  return overrides
+
+
 def _validate_xla_flags_setting(
     xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool
 ) -> None:
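
The parser above falls back through bool, int, and float before treating a value as a string. Here is a standalone illustration of that rule, using plain Python values in place of OptionOverrideProto; this is a sketch for reference, not the shipped code.

def infer_override_value(value: str) -> bool | int | float | str:
  """Mirrors the bool -> int -> float -> string fallback in the parser."""
  if value.lower() == 'true':
    return True
  if value.lower() == 'false':
    return False
  if value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
    return int(value)
  try:
    return float(value)
  except ValueError:
    return value


assert infer_override_value('TRUE') is True
assert infer_override_value('-1') == -1
assert infer_override_value('1.5') == 1.5
assert infer_override_value('/usr/local/cuda') == '/usr/local/cuda'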

model/orbax/experimental/model/core/python/compile_options_util_test.py

Lines changed: 112 additions & 0 deletions
@@ -328,6 +328,118 @@ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
         persist_xla_flags=False,
     )
 
+  def test_generate_xla_compile_options_env_overrides(self):
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={
+            'cuda': [
+                '--xla_gpu_enable_latency_hiding_scheduler=true',
+                '--xla_gpu_autotune_level=0',
+            ]
+        },
+        persist_xla_flags=True,
+    )
+    self.assertIn('cuda', compile_options_map.map)
+    compile_options = compile_options_map.map['cuda']
+
+    overrides = compile_options.env_option_overrides
+    self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
+    self.assertTrue(
+        overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
+    )
+
+    self.assertIn('xla_gpu_autotune_level', overrides)
+    self.assertEqual(overrides['xla_gpu_autotune_level'].int_field, 0)
+
+  def test_generate_xla_compile_options_gpu_flags_experimental_rejection(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        r'XLA GPU compilation flag --xla_gpu_experimental_flag=true is not'
+        r' supported. Please check field description at'
+        r' CompilationConfig::xla_gpu_flags',
+    ):
+      compile_options_util.generate_xla_compile_options(
+          native_serialization_platforms=['cuda'],
+          xla_flags_per_platform={'cuda': ['--xla_gpu_experimental_flag=true']},
+          persist_xla_flags=True,
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='bool_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=true',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='bool_false',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=false',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=False,
+      ),
+      dict(
+          testcase_name='bool_uppercase_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=TRUE',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='int_positive',
+          flag='--xla_gpu_autotune_level=4',
+          expected_key='xla_gpu_autotune_level',
+          expected_field='int_field',
+          expected_value=4,
+      ),
+      dict(
+          testcase_name='int_negative',
+          flag='--xla_gpu_nccl_termination_timeout_seconds=-1',
+          expected_key='xla_gpu_nccl_termination_timeout_seconds',
+          expected_field='int_field',
+          expected_value=-1,
+      ),
+      dict(
+          testcase_name='float_positive',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=1.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=1.5,
+      ),
+      dict(
+          testcase_name='float_negative',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=-0.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=-0.5,
+      ),
+      dict(
+          testcase_name='string_value',
+          flag='--xla_gpu_cuda_data_dir=/usr/local/cuda',
+          expected_key='xla_gpu_cuda_data_dir',
+          expected_field='string_field',
+          expected_value='/usr/local/cuda',
+      ),
+  )
+  def test_generate_xla_compile_options_gpu_flags_type_inference(
+      self, flag, expected_key, expected_field, expected_value
+  ):
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={'cuda': [flag]},
+        persist_xla_flags=True,
+    )
+    self.assertIsNotNone(compile_options_map.map)
+    build_options_cuda = compile_options_map.map['cuda']
+    self.assertIn(expected_key, build_options_cuda.env_option_overrides)
+    override_proto = build_options_cuda.env_option_overrides[expected_key]
+    with self.subTest('test_oneof_field'):
+      self.assertEqual(override_proto.WhichOneof('value'), expected_field)
+    with self.subTest('test_value'):
+      self.assertEqual(getattr(override_proto, expected_field), expected_value)
+
   @parameterized.named_parameters(
       dict(
          testcase_name='1d_mesh',
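
End to end, the same behavior can be exercised directly against generate_xla_compile_options, mirroring the tests above. A short sketch; the import path is an assumption.

from orbax.experimental.model.core.python import compile_options_util

options_map = compile_options_util.generate_xla_compile_options(
    native_serialization_platforms=['cuda'],
    xla_flags_per_platform={'cuda': ['--xla_gpu_autotune_level=4']},
    persist_xla_flags=True,
)

# The flag value is parsed into the int_field arm of the override oneof.
override = options_map.map['cuda'].env_option_overrides[
    'xla_gpu_autotune_level'
]
assert override.WhichOneof('value') == 'int_field'
assert override.int_field == 4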
