Skip to content

Commit 0a17498

Browse files
update tests for jax 0.5.3
1 parent eeefdc8 commit 0a17498

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

axlearn/audio/decoder_asr_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def test_forward_summary(self):
569569
summaries = output_collections.summaries
570570
# 6 out of 8 examples are valid, therefore the average example weight is 0.75
571571
self._check_summary(summaries, "loss/example_weight", WeightedScalar(0.75, 8))
572-
self._check_summary(summaries, "loss/ctc_loss", WeightedScalar(6972.1353, 6))
572+
self._check_summary(summaries, "loss/ctc_loss", WeightedScalar(6972.135, 6))
573573
self._check_summary(summaries, "loss/invalid_seq_percent", 0.25)
574574
total_ctc_loss = summaries["loss/ctc_loss"].weight * summaries["loss/ctc_loss"].mean
575575
num_valid_frames = jnp.sum(safe_not(paddings) * per_example_weight[:, None])

axlearn/audio/frontend_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def test_small_input(self, input_dtype):
236236
output_with_large_mel_floor,
237237
),
238238
output_with_correct_mel_floor,
239+
rtol=6e-4,
239240
)
240241

241242
@set_threefry_partitionable(False) # TODO(Luzy): update for threefry_partitionable True
@@ -279,7 +280,7 @@ def test_fft(self):
279280
ref_outputs = self._jit_forward(ref_layer, inputs, paddings)
280281
test_outputs = self._jit_forward(layer, inputs, paddings)
281282

282-
self.assertAllClose(ref_outputs["outputs"], test_outputs["outputs"])
283+
self.assertAllClose(ref_outputs["outputs"], test_outputs["outputs"], rtol=5e-3)
283284
self.assertAllClose(ref_outputs["paddings"], test_outputs["paddings"])
284285

285286
@parameterized.product(

axlearn/audio/frontend_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def test_fft(self):
417417
ref_ffts = jax.jit(jnp.fft.fft, static_argnames="n")(inputs, n=fft_size)
418418
test_ffts = fft_fn(inputs)
419419

420-
assert_allclose(ref_ffts, test_ffts)
420+
assert_allclose(ref_ffts, test_ffts, rtol=1e-3)
421421
# Run the following on gpu.
422422
jax.debug.inspect_array_sharding(test_ffts, callback=print)
423423
jax.debug.inspect_array_sharding(ref_ffts, callback=print)

axlearn/common/base_layer_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
FanAxes,
4242
WeightInitializer,
4343
)
44-
from axlearn.common.test_utils import TestCase, assert_allclose, set_threefry_partitionable
44+
from axlearn.common.test_utils import TestCase, assert_allclose
4545
from axlearn.common.utils import safe_not
4646

4747

@@ -375,7 +375,6 @@ def test_apply_parameter_noise_recursively(self, param_noise_cfg):
375375
self.assertNestedAllClose(jnp.zeros_like(orig_value), noisy_value)
376376

377377
@parameterized.parameters(False, True)
378-
@set_threefry_partitionable(True) # TODO(mhopkins): remove during jax 0.5.0+ upgrade
379378
def test_tensor_stats(self, inline_child_summaries: bool):
380379
test_layer: TestLayer = (
381380
TestLayer.default_config()

axlearn/common/quantizer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_forward(
348348
),
349349
expected_values[batch_size][normalize_codebook]["q_vecs"],
350350
atol=1e-6,
351-
rtol=1e-6,
351+
rtol=2e-6,
352352
)
353353
assert_allclose(
354354
np.sum(q_outputs.ids * safe_not(paddings)[:, :, None]),

0 commit comments

Comments
 (0)