@@ -1615,15 +1615,29 @@ def test_correct_class_is_exposed(self, backend_name):
16151615 self .assertIs (backend .RNGHandler , backend ._TFRNGHandler )
16161616 # pylint: enable=protected-access
16171617
1618- @parameterized .named_parameters (("tensorflow" , _TF ), ("jax" , _JAX ))
1619- def test_initialization_with_none_seed_is_noop (self , backend_name ):
1620- """Verifies that a None seed creates a handler that returns None."""
1618+ @parameterized .named_parameters (
1619+ dict (
1620+ testcase_name = "tensorflow" ,
1621+ backend_name = _TF ,
1622+ assert_fn_name = "assertIsNone" ,
1623+ ),
1624+ dict (
1625+ testcase_name = "jax" ,
1626+ backend_name = _JAX ,
1627+ assert_fn_name = "assertIsNotNone" ,
1628+ ),
1629+ )
1630+ def test_initialization_with_none_seed_is_noop (
1631+ self , backend_name , assert_fn_name
1632+ ):
1633+ """Verifies behavior when initialized with None."""
16211634 self ._set_backend_for_test (backend_name )
16221635 handler = backend .RNGHandler (None )
1636+ assertion = getattr (self , assert_fn_name )
16231637
16241638 self .assertIsNone (handler ._seed_input )
1625- self . assertIsNone (handler .get_next_seed ())
1626- self . assertIsNone (handler .get_kernel_seed ())
1639+ assertion (handler .get_next_seed ())
1640+ assertion (handler .get_kernel_seed ())
16271641
16281642 @parameterized .named_parameters (("tensorflow" , _TF ), ("jax" , _JAX ))
16291643 def test_initialization_with_integer_seed (self , backend_name ):
@@ -1770,17 +1784,29 @@ def test_get_next_seed_is_reproducible(self, backend_name):
17701784 else :
17711785 test_utils .assert_allequal (s1 , s2 )
17721786
1773- @parameterized .named_parameters (("tensorflow" , _TF ), ("jax" , _JAX ))
1774- def test_advance_handler_with_none_seed (self , backend_name ):
1775- """Tests that advancing a no-op handler produces another no-op handler."""
1787+ @parameterized .named_parameters (
1788+ dict (
1789+ testcase_name = "tensorflow" ,
1790+ backend_name = _TF ,
1791+ assert_fn_name = "assertIsNone" ,
1792+ ),
1793+ dict (
1794+ testcase_name = "jax" ,
1795+ backend_name = _JAX ,
1796+ assert_fn_name = "assertIsNotNone" ,
1797+ ),
1798+ )
1799+ def test_advance_handler_with_none_seed (self , backend_name , assert_fn_name ):
1800+ """Tests advancing a handler initialized with None."""
17761801 self ._set_backend_for_test (backend_name )
17771802 handler = backend .RNGHandler (None )
17781803 new_handler = handler .advance_handler ()
1804+ assertion = getattr (self , assert_fn_name )
17791805
17801806 self .assertIsNot (handler , new_handler )
1781- self . assertIsNone (new_handler ._seed_input )
1782- self . assertIsNone (handler .get_next_seed ())
1783- self . assertIsNone (new_handler .get_kernel_seed ())
1807+ assertion (new_handler ._seed_input )
1808+ assertion (handler .get_next_seed ())
1809+ assertion (new_handler .get_kernel_seed ())
17841810
17851811 @parameterized .named_parameters (("tensorflow" , _TF ), ("jax" , _JAX ))
17861812 def test_advance_handler_provides_independent_handlers (self , backend_name ):
0 commit comments