Skip to content

Commit 417d7a2

Browse files
Merge pull request jax-ml#25511 from jakevdp:fix-scan-err
PiperOrigin-RevId: 706780339
2 parents cd7109e + 74e9275 commit 417d7a2

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ def _get_states(attrs_tracked):
347347
vals.extend(leaves)
348348
return vals
349349

350+
def _capitalize(s):
351+
# s.capitalize() converts s[1:] to lowercase which we don't want.
352+
return s[0].capitalize() + s[1:]
353+
350354
def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
351355
try:
352356
sig = inspect.signature(body_fun)
@@ -380,7 +384,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
380384
# The trees may have different aux data but structures are the same.
381385
return
382386
if len(diffs) == 1:
383-
differences = f'{diffs[0]}.\n'.capitalize()
387+
differences = f'{_capitalize(diffs[0])}.\n'
384388
else:
385389
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
386390
+ f' * {diffs[-1]}.\n')
@@ -400,7 +404,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
400404
# The trees may have different aux data but structures are the same.
401405
return
402406
if len(diffs) == 1:
403-
differences = f'{diffs[0]}.\n'.capitalize()
407+
differences = f'{_capitalize(diffs[0])}.\n'
404408
else:
405409
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
406410
+ f' * {diffs[-1]}.\n')

tests/lax_control_flow_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,16 @@ def testScanBodyOutputError(self):
18971897
re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
18981898
lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.))
18991899

1900+
def testScanMetadataError(self):
1901+
# Regression test for https://github.com/jax-ml/jax/issues/25507
1902+
def f(loop_i, x):
1903+
return {'T': jnp.array([0.5])}
1904+
1905+
init_val = {'t': jnp.array([1.0])}
1906+
msg = r".*with pytree metadata \('t',\).*with pytree metadata \('T',\)"
1907+
with self.assertRaisesRegex(TypeError, msg):
1908+
jax.lax.fori_loop(0, 1, f, init_val)
1909+
19001910
def testScanBodyCarryPytreeMismatchErrors(self):
19011911
with self.assertRaisesRegex(
19021912
TypeError,

0 commit comments

Comments
 (0)