
Commit 6553f3f

[levanter] Share Pallas autotune helpers and restore compile offload (#4130)
Move shard-aware autotune benchmarking out of fused cross-entropy into a shared Pallas helper and restore compile offload for shard-mapped autotune sweeps. This keeps fused CE behavior intact while making the benchmark path reusable from other kernels. Fixes #4129
1 parent 040c1f2 commit 6553f3f

4 files changed

Lines changed: 437 additions & 183 deletions


lib/levanter/src/levanter/kernels/pallas/autotune_utils.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright The Levanter Authors
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Any, cast

import jax
from jax import core as jax_core
from jax._src import mesh as mesh_lib
from jax.sharding import NamedSharding


_AUTOTUNE_THREAD_POOL = ThreadPoolExecutor(max_workers=1, thread_name_prefix="pallas_autotune")


def sharding_of(value: jax.Array):
    """Return array sharding metadata when available."""
    sharding = None
    try:
        sharding = value.sharding  # type: ignore[attr-defined]
    except Exception:
        sharding = None
    if sharding is not None:
        return sharding

    aval = getattr(value, "aval", None)
    if aval is None:
        return None
    return getattr(aval, "sharding", None)


def named_sharding_of(value: jax.Array) -> NamedSharding | None:
    """Return a NamedSharding for the value when one is attached."""
    sharding = sharding_of(value)
    if isinstance(sharding, NamedSharding):
        return sharding
    return None


def hlo_sharding_of(value: jax.Array):
    """Return XLA HLO sharding metadata when it can be derived."""
    sharding = sharding_of(value)
    if sharding is None:
        return None
    to_hlo = getattr(sharding, "_to_xla_hlo_sharding", None)
    if to_hlo is None:
        return None
    try:
        return to_hlo(value.ndim)
    except Exception:
        return None


def value_uses_manual_sharding(value: jax.Array) -> bool:
    """Detect shard_map-local tracer values that carry manual sharding."""
    hlo_sharding = hlo_sharding_of(value)
    return hlo_sharding is not None and hlo_sharding.is_manual()


def shape_dtype_struct_for_benchmark(value: jax.Array) -> jax.ShapeDtypeStruct:
    """Build a lowering spec while preserving compatible global sharding."""
    sharding = sharding_of(value)
    if sharding is None or value_uses_manual_sharding(value):
        return jax.ShapeDtypeStruct(value.shape, value.dtype)
    return jax.ShapeDtypeStruct(value.shape, value.dtype, sharding=sharding)


def contains_tracer(*values: jax.Array) -> bool:
    """Whether any lowering input is already a tracer."""
    return any(isinstance(value, jax_core.Tracer) for value in values)


def benchmark_lowering_args(*values: jax.Array) -> tuple[jax.Array | jax.ShapeDtypeStruct, ...]:
    """Choose tracer-aware lowering inputs for autotune benchmarks."""
    if contains_tracer(*values):
        return values
    return tuple(shape_dtype_struct_for_benchmark(value) for value in values)


def should_offload_compile(*values: jax.Array) -> bool:
    """Whether benchmark lowering should run on the shared autotune thread."""
    return (
        contains_tracer(*values)
        or any(value_uses_manual_sharding(value) for value in values)
        or jax_core.unsafe_am_i_under_a_jit_DO_NOT_USE()
        or not mesh_lib.thread_resources.env.physical_mesh.empty
    )


def compile_benchmark_fn_current_thread(
    benchmark_fn: Callable[..., jax.Array],
    lowering_args: tuple[jax.Array | jax.ShapeDtypeStruct, ...],
) -> float:
    """Compile a benchmark function on the current thread and return compile time."""
    jitted = jax.jit(benchmark_fn)
    start = time.perf_counter()
    lowered = jitted.lower(*lowering_args)
    lowered.compile()
    return time.perf_counter() - start


def compile_benchmark_fn(
    *,
    benchmark_fn: Callable[..., jax.Array],
    lowering_args: tuple[jax.Array | jax.ShapeDtypeStruct, ...],
    args: Sequence[jax.Array],
) -> float:
    """Compile a benchmark function, offloading when JAX thread-local state is unsafe."""
    if not should_offload_compile(*args):
        return compile_benchmark_fn_current_thread(benchmark_fn, lowering_args)
    return _AUTOTUNE_THREAD_POOL.submit(
        compile_benchmark_fn_current_thread,
        benchmark_fn,
        lowering_args,
    ).result()


def maybe_wrap_in_shard_map(
    fn: Callable[..., jax.Array],
    *,
    args: Sequence[jax.Array],
    out_specs: Any,
    check_vma: bool = False,
) -> Callable[..., jax.Array]:
    """Wrap a benchmark function in shard_map when inputs are globally NamedSharded."""
    if not args or any(value_uses_manual_sharding(value) for value in args):
        return fn

    shardings = tuple(named_sharding_of(value) for value in args)
    if any(sharding is None for sharding in shardings):
        return fn

    named_shardings = cast(tuple[NamedSharding, ...], shardings)
    mesh = named_shardings[0].mesh
    if any(sharding.mesh != mesh for sharding in named_shardings[1:]):
        return fn

    return jax.shard_map(
        fn,
        mesh=mesh,
        in_specs=tuple(sharding.spec for sharding in named_shardings),
        out_specs=out_specs,
        check_vma=check_vma,
    )


__all__ = [
    "benchmark_lowering_args",
    "compile_benchmark_fn",
    "compile_benchmark_fn_current_thread",
    "contains_tracer",
    "hlo_sharding_of",
    "maybe_wrap_in_shard_map",
    "named_sharding_of",
    "shape_dtype_struct_for_benchmark",
    "sharding_of",
    "should_offload_compile",
    "value_uses_manual_sharding",
]
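For reuse from another kernel, the intended call pattern mirrors the fused CE call site in api.py below. This is a minimal sketch under assumptions: candidate_fn, x, and w are hypothetical names for some other kernel's benchmark; only the autotune_utils calls come from this commit.

import time

import jax

from levanter.kernels.pallas import autotune_utils


def _benchmark_candidate(candidate_fn, x, w):
    # Wrap the candidate in shard_map when the inputs carry global NamedShardings,
    # so it is lowered for per-shard shapes rather than global ones.
    x_sharding = autotune_utils.named_sharding_of(x)
    benchmark_fn = autotune_utils.maybe_wrap_in_shard_map(
        candidate_fn,
        args=(x, w),
        out_specs=x_sharding.spec if x_sharding is not None else None,
    )
    # Lower against ShapeDtypeStructs (or live tracers) and compile; the helper
    # offloads to the shared autotune thread when JAX thread-local state
    # (an enclosing jit trace, manual sharding, or an active mesh) makes
    # compiling on the current thread unsafe.
    lowering_args = autotune_utils.benchmark_lowering_args(x, w)
    compile_time = autotune_utils.compile_benchmark_fn(
        benchmark_fn=benchmark_fn,
        lowering_args=lowering_args,
        args=(x, w),
    )
    # Under tracing only the compile time is meaningful; otherwise run and time
    # the compiled function as well.
    if autotune_utils.contains_tracer(x, w):
        return compile_time
    jitted = jax.jit(benchmark_fn)
    start = time.perf_counter()
    out = jitted(x, w)
    jax.block_until_ready(out)
    return time.perf_counter() - start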

lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py

Lines changed: 11 additions & 101 deletions
@@ -10,12 +10,10 @@
 import warnings

 import jax
-from jax import core as jax_core
 import jax.numpy as jnp
-from jax.sharding import NamedSharding
 from jaxtyping import Array, Float, Int

-from levanter.kernels.pallas import autotune_cache_utils
+from levanter.kernels.pallas import autotune_cache_utils, autotune_utils

 from .config import BlockSizes
 from .tuned_block_sizes import (
@@ -114,80 +112,6 @@ def _is_tpu_vmem_compile_error(exc: Exception) -> bool:
     return "resource_exhausted" in message and "vmem" in message


-def _sharding_of(value: jax.Array):
-    sharding = None
-    try:
-        sharding = value.sharding  # type: ignore[attr-defined]
-    except Exception:
-        sharding = None
-    if sharding is not None:
-        return sharding
-
-    aval = getattr(value, "aval", None)
-    if aval is None:
-        return None
-    return getattr(aval, "sharding", None)
-
-
-def _named_sharding_of(value: jax.Array) -> NamedSharding | None:
-    sharding = _sharding_of(value)
-    if isinstance(sharding, NamedSharding):
-        return sharding
-    return None
-
-
-def _hlo_sharding_of(value: jax.Array):
-    sharding = _sharding_of(value)
-    if sharding is None:
-        return None
-    to_hlo = getattr(sharding, "_to_xla_hlo_sharding", None)
-    if to_hlo is None:
-        return None
-    try:
-        return to_hlo(value.ndim)
-    except Exception:
-        return None
-
-
-def _value_uses_manual_sharding(value: jax.Array) -> bool:
-    hlo_sharding = _hlo_sharding_of(value)
-    return hlo_sharding is not None and hlo_sharding.is_manual()
-
-
-def _shape_dtype_struct_for_benchmark(value: jax.Array) -> jax.ShapeDtypeStruct:
-    sharding = _sharding_of(value)
-    if sharding is None or _value_uses_manual_sharding(value):
-        return jax.ShapeDtypeStruct(value.shape, value.dtype)
-    return jax.ShapeDtypeStruct(value.shape, value.dtype, sharding=sharding)
-
-
-def _maybe_wrap_loss_in_shard_map_for_benchmark(
-    fn: Callable[[jax.Array, jax.Array, jax.Array], jax.Array],
-    *,
-    x: jax.Array,
-    labels: jax.Array,
-    w: jax.Array,
-) -> Callable[[jax.Array, jax.Array, jax.Array], jax.Array]:
-    if _value_uses_manual_sharding(x) or _value_uses_manual_sharding(labels) or _value_uses_manual_sharding(w):
-        return fn
-
-    x_sharding = _named_sharding_of(x)
-    labels_sharding = _named_sharding_of(labels)
-    w_sharding = _named_sharding_of(w)
-    if x_sharding is None or labels_sharding is None or w_sharding is None:
-        return fn
-    if x_sharding.mesh != labels_sharding.mesh or x_sharding.mesh != w_sharding.mesh:
-        return fn
-
-    return jax.shard_map(
-        fn,
-        mesh=x_sharding.mesh,
-        in_specs=(x_sharding.spec, labels_sharding.spec, w_sharding.spec),
-        out_specs=labels_sharding.spec,
-        check_vma=False,
-    )
-
-
 def _warn_vmem_compile_fallback_once(exc: Exception, *, impl_name: str) -> None:
     message = str(exc)
     key = f"{impl_name}|{message}"
@@ -403,10 +327,6 @@ def _candidate_block_sizes(
     return deduped


-def _is_tracer(x: jax.Array) -> bool:
-    return isinstance(x, jax_core.Tracer)
-
-
 def _benchmark_block_sizes_candidate(
     *,
     fn: ArrayImpl,
@@ -431,38 +351,28 @@ def _loss_only(x_value: jax.Array, labels_value: jax.Array, w_value: jax.Array)
         out = fn(x_value, labels_value, w_value, **kwargs)
         return out[0]

-    benchmark_fn = _maybe_wrap_loss_in_shard_map_for_benchmark(
+    benchmark_fn = autotune_utils.maybe_wrap_in_shard_map(
         _loss_only,
-        x=x,
-        labels=labels,
-        w=w,
+        args=(x, labels, w),
+        out_specs=autotune_utils.named_sharding_of(labels).spec if autotune_utils.named_sharding_of(labels) else None,
     )
-    jitted = jax.jit(benchmark_fn)
-
-    use_tracer_lowering = _is_tracer(x) or _is_tracer(labels) or _is_tracer(w)
-    lowering_args = (
-        (x, labels, w)
-        if use_tracer_lowering
-        else (
-            _shape_dtype_struct_for_benchmark(x),
-            _shape_dtype_struct_for_benchmark(labels),
-            _shape_dtype_struct_for_benchmark(w),
-        )
+    lowering_args = autotune_utils.benchmark_lowering_args(x, labels, w)
+    compile_time = autotune_utils.compile_benchmark_fn(
+        benchmark_fn=benchmark_fn,
+        lowering_args=lowering_args,
+        args=(x, labels, w),
     )
-    start = time.perf_counter()
-    lowered = jitted.lower(*lowering_args)
-    lowered.compile()
-    compile_time = time.perf_counter() - start
     if compile_time <= _AUTOTUNE_COMPILE_HIT_THRESHOLD_S:
         logger.info(
             "Fused CE autotune candidate %s likely hit JAX compilation cache (compile %.3fs).",
             candidate,
             compile_time,
         )

-    if _is_tracer(x) or _is_tracer(labels) or _is_tracer(w):
+    if autotune_utils.contains_tracer(x, labels, w):
         return compile_time

+    jitted = jax.jit(benchmark_fn)
     start = time.perf_counter()
     out = jitted(x, labels, w)
     jax.block_until_ready(out)