Skip to content

Commit 8c48c95

Browse files
KfacJaxDevKfacJaxDev
authored andcommitted
Internal Change
PiperOrigin-RevId: 805859317
1 parent 72062e3 commit 8c48c95

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

kfac_jax/_src/utils/staging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def pmean_if_pmap_wrapper(
142142
func: Callable[..., TArrayTree],
143143
) -> Callable[..., TArrayTree]:
144144
"""Wraps a function to perform a pmean if `multi_device`."""
145-
if self.multi_device:
145+
if self.multi_device and not self.debug:
146146
return lambda *args, **kwargs: lax.pmean(
147147
func(*args, **kwargs), self.pmap_axis_name
148148
)

0 commit comments

Comments
 (0)