Skip to content

Commit 0c4db8f

Browse files
authored
Reduce expected fallback warning noise (#2921)
## Summary

- suppress expected CPU-side pallas fallback warning (`requires TPU backend`) while still falling back to XLA
- emit pallas fallback warnings once per unique message instead of repeating on every call
- adapt inferred pallas `b_block_size`/`h_block_size` to TPU-lane-aligned divisors of the actual shape when available
- reduce splash fallback warning spam by warning once per fallback category
- make tracker no-op logging APIs (`jit_log`, `log_summary`, etc.) silently drop events when no global tracker is configured (with one info log)

## Validation

- `cd lib/levanter && PYTHONPATH=tests:src:. uv run --package levanter --group test pytest tests/kernels/test_pallas_fused_cross_entropy_loss.py -q`
- `cd lib/levanter && PYTHONPATH=tests:src:. uv run --package levanter --group test pytest tests/test_tracker.py tests/test_skip_step.py tests/test_weight_decay_mask.py -q`
- `cd lib/levanter && PYTHONPATH=tests:src:. uv run --package levanter --group test pytest tests/grug/test_grugformer_fused_loss.py tests/grug/test_grugformer_model_loss.py -q`
- `cd lib/levanter && PYTHONPATH=tests:src:. uv run --package levanter --group test pytest tests/test_attention.py -q`
1 parent e957390 commit 0c4db8f

File tree

5 files changed

+112
-27
lines changed

5 files changed

+112
-27
lines changed

lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"xla": linear_softmax_cross_entropy_loss_xla,
2828
}
2929
_DEFAULT_IMPLEMENTATION: tuple[Implementation, ...] = ("xla",)
30+
_PALLAS_FALLBACK_WARNINGS_EMITTED: set[str] = set()
3031

3132
try:
3233
from .pallas_tpu import PallasUnsupportedError, linear_softmax_cross_entropy_loss_pallas
@@ -37,6 +38,19 @@
3738
PallasUnsupportedError = NotImplementedError # type: ignore[assignment]
3839

3940

41+
def _warn_pallas_fallback_once(exc: Exception) -> None:
42+
message = str(exc)
43+
if "requires TPU backend" in message:
44+
return
45+
if message in _PALLAS_FALLBACK_WARNINGS_EMITTED:
46+
return
47+
_PALLAS_FALLBACK_WARNINGS_EMITTED.add(message)
48+
warnings.warn(
49+
f"Pallas fused cross-entropy unavailable, falling back to XLA: {message}",
50+
RuntimeWarning,
51+
)
52+
53+
4054
def _validate_inputs(x: jax.Array, labels: jax.Array, w: jax.Array) -> None:
4155
if x.ndim != 2:
4256
raise ValueError(f"x must be rank-2 [B, H], got shape {x.shape}.")
@@ -161,19 +175,13 @@ def fused_cross_entropy_loss_and_logsumexp_penalty(
161175
except PallasUnsupportedError as e:
162176
if explicit:
163177
raise
164-
warnings.warn(
165-
f"Pallas fused cross-entropy unavailable, falling back to XLA: {e}",
166-
RuntimeWarning,
167-
)
178+
_warn_pallas_fallback_once(e)
168179
errors.append(e)
169180
continue
170181
except NotImplementedError as e:
171182
if explicit:
172183
raise
173-
warnings.warn(
174-
f"Pallas fused cross-entropy unavailable, falling back to XLA: {e}",
175-
RuntimeWarning,
176-
)
184+
_warn_pallas_fallback_once(e)
177185
errors.append(e)
178186
continue
179187
else:
@@ -193,19 +201,13 @@ def fused_cross_entropy_loss_and_logsumexp_penalty(
193201
except PallasUnsupportedError as e:
194202
if explicit:
195203
raise
196-
warnings.warn(
197-
f"Pallas fused cross-entropy unavailable, falling back to XLA: {e}",
198-
RuntimeWarning,
199-
)
204+
_warn_pallas_fallback_once(e)
200205
errors.append(e)
201206
continue
202207
except NotImplementedError as e:
203208
if explicit:
204209
raise
205-
warnings.warn(
206-
f"Pallas fused cross-entropy unavailable, falling back to XLA: {e}",
207-
RuntimeWarning,
208-
)
210+
_warn_pallas_fallback_once(e)
209211
errors.append(e)
210212
continue
211213

lib/levanter/src/levanter/layers/attention.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class AttentionBackend(StrEnum):
6262

6363

6464
DEFAULT_SPLASH_BLOCK_SIZE = 512
65+
_SPLASH_FALLBACK_WARNINGS_EMITTED: set[str] = set()
66+
67+
68+
def _warn_splash_fallback_once(message: str) -> None:
69+
if message in _SPLASH_FALLBACK_WARNINGS_EMITTED:
70+
return
71+
_SPLASH_FALLBACK_WARNINGS_EMITTED.add(message)
72+
warnings.warn(message, stacklevel=3)
6573

6674

6775
def default_attention_type() -> AttentionBackend:
@@ -1185,13 +1193,13 @@ def _try_tpu_splash_attention(
11851193
if dropout != 0.0:
11861194
if force_flash:
11871195
raise NotImplementedError("Splash attention does not support dropout.")
1188-
warnings.warn("Splash attention does not support. Falling back to the reference implementation.")
1196+
_warn_splash_fallback_once("Splash attention does not support dropout. Falling back to the reference.")
11891197
return None
11901198

11911199
if bias is not None:
11921200
if force_flash:
11931201
raise NotImplementedError("Splash attention does not support bias.")
1194-
warnings.warn("Splash attention does not support bias. Falling back to the reference implementation.")
1202+
_warn_splash_fallback_once("Splash attention does not support bias. Falling back to the reference.")
11951203
return None
11961204

11971205
try:
@@ -1219,16 +1227,17 @@ def _try_tpu_splash_attention(
12191227
raise
12201228
if force_flash:
12211229
raise ImportError("Could not import splash attention. You need to update your JAX to at least 0.7.2.")
1222-
warnings.warn(
1230+
_warn_splash_fallback_once(
12231231
"Could not import splash attention. You need to update your JAX to at least 0.7.2. "
1224-
"Falling back to the reference implementation."
1232+
"Falling back to the reference implementation.",
12251233
)
12261234
return None
12271235
except NotImplementedError as e:
12281236
message = str(e)
12291237
if force_flash:
12301238
raise NotImplementedError(f"Could not use splash attention: {message}")
1231-
warnings.warn(f"Could not use splash attention: {message}. Falling back to the reference")
1239+
logger.info("Could not use splash attention. Falling back to the reference implementation: %s", message)
1240+
_warn_splash_fallback_once("Could not use splash attention. Falling back to the reference implementation.")
12321241
return None
12331242

12341243

lib/levanter/src/levanter/tracker/tracker_fns.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,19 @@
2929

3030
_should_use_callback = True
3131
_global_tracker: Optional["Tracker"] = None
32+
_has_logged_missing_tracker = False
3233

3334
LoggableValue: typing.TypeAlias = Scalar | jax.Array | str | dict | Histogram
3435

3536

37+
def _log_missing_tracker_once() -> None:
38+
global _has_logged_missing_tracker
39+
if _has_logged_missing_tracker:
40+
return
41+
_has_logged_missing_tracker = True
42+
logger.info("No global tracker set; tracker logs are being dropped.")
43+
44+
3645
def log(metrics: typing.Mapping[str, LoggableValue | Any], *, step: Optional[int], commit: Optional[bool] = None):
3746
"""
3847
Log metrics to the global tracker.
@@ -45,7 +54,8 @@ def log(metrics: typing.Mapping[str, LoggableValue | Any], *, step: Optional[int
4554
"""
4655
global _global_tracker
4756
if _global_tracker is None:
48-
raise RuntimeError("No global tracker set")
57+
_log_missing_tracker_once()
58+
return
4959

5060
if is_inside_jit():
5161
# we're inside a jit, so we need to log from the host
@@ -71,7 +81,7 @@ def log_metrics(
7181
def _do_jit_log(metrics, *, step=None):
7282
try:
7383
if _global_tracker is None:
74-
warnings.warn("No global tracker set")
84+
_log_missing_tracker_once()
7585
else:
7686
_global_tracker.log(metrics, step=step, commit=False)
7787
except Exception:
@@ -91,7 +101,7 @@ def jit_log(metrics, *, step=None):
91101
We strongly recommend using the first method, as it is much more performant.
92102
"""
93103
if _global_tracker is None:
94-
warnings.warn("No global tracker set")
104+
_log_missing_tracker_once()
95105
return
96106
if not _should_use_callback:
97107
# we're not using the callback, so we assume we're inside a defer_tracker_for_jit context manager
@@ -154,7 +164,7 @@ def log_summary(metrics: dict[str, Any]):
154164
"""
155165
global _global_tracker
156166
if _global_tracker is None:
157-
warnings.warn("No global tracker set")
167+
_log_missing_tracker_once()
158168
return
159169

160170
_global_tracker.log_summary(metrics)
@@ -169,7 +179,7 @@ def log_hyperparameters(hparams: dict[str, Any]):
169179
"""
170180
global _global_tracker
171181
if _global_tracker is None:
172-
warnings.warn("No global tracker set")
182+
_log_missing_tracker_once()
173183
return
174184

175185
_global_tracker.log_hyperparameters(hparams)
@@ -185,7 +195,7 @@ def log_configuration(hparams: Any, config_name: Optional[str] = None):
185195
"""
186196
global _global_tracker
187197
if _global_tracker is None:
188-
warnings.warn("No global tracker set")
198+
_log_missing_tracker_once()
189199
return
190200

191201
hparams_dict = hparams_to_dict(hparams)

lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,52 @@ def test_fused_cross_entropy_pallas_requires_tpu():
123123
)
124124

125125

126+
def test_infer_block_sizes_adapts_to_supported_divisors():
127+
block_sizes = infer_block_sizes(
128+
b=512,
129+
h=128,
130+
v=4096,
131+
dtype=jnp.float32,
132+
device_kind="TPU v5e",
133+
)
134+
135+
assert block_sizes.b_block_size == 512
136+
assert block_sizes.h_block_size == 128
137+
138+
139+
def test_infer_block_sizes_preserves_defaults_without_128_aligned_divisors():
140+
block_sizes = infer_block_sizes(
141+
b=96,
142+
h=64,
143+
v=4096,
144+
dtype=jnp.float32,
145+
device_kind="TPU v5e",
146+
)
147+
148+
assert block_sizes.b_block_size == 1024
149+
assert block_sizes.h_block_size == 512
150+
151+
152+
def test_default_implementation_on_cpu_skips_expected_tpu_warning():
153+
if jax.default_backend() == "tpu":
154+
pytest.skip("requires non-TPU backend")
155+
156+
x = jnp.zeros((32, 64), dtype=jnp.float32)
157+
w = jnp.zeros((64, 128), dtype=jnp.float32)
158+
y = jnp.zeros((32,), dtype=jnp.int32)
159+
160+
with warnings.catch_warnings(record=True) as caught:
161+
warnings.simplefilter("always")
162+
fused_api.fused_cross_entropy_loss_and_logsumexp_penalty(
163+
x,
164+
y,
165+
w,
166+
reduction=None,
167+
)
168+
169+
assert not any("requires TPU backend" in str(warning.message) for warning in caught)
170+
171+
126172
def test_fused_cross_entropy_default_matches_reference():
127173
backend = jax.default_backend()
128174
if backend == "tpu":

lib/levanter/tests/test_tracker.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass.
55
import dataclasses
6+
import warnings
67
from typing import Tuple
78

89
import pytest
@@ -82,3 +83,20 @@ def test_get_tracker_by_name(monkeypatch):
8283

8384
with pytest.raises(KeyError):
8485
levanter.tracker.get_tracker("foo")
86+
87+
88+
def test_tracker_logging_without_global_tracker_emits_no_warning(monkeypatch):
89+
import levanter.tracker.tracker_fns as tracker_fns
90+
91+
monkeypatch.setattr(tracker_fns, "_global_tracker", None)
92+
monkeypatch.setattr(tracker_fns, "_has_logged_missing_tracker", False)
93+
94+
with warnings.catch_warnings(record=True) as caught:
95+
warnings.simplefilter("always")
96+
tracker_fns.log({"metric": 1.0}, step=0)
97+
tracker_fns.jit_log({"metric": 1.0}, step=0)
98+
tracker_fns.log_summary({"metric": 1.0})
99+
tracker_fns.log_hyperparameters({"metric": 1.0})
100+
tracker_fns.log_configuration({"metric": 1.0})
101+
102+
assert not caught

0 commit comments

Comments
 (0)