99""" 
1010
1111# global modules 
12- import  json 
1312import  unittest 
1413
1514import  numpy  as  np 
2221# set seed for reproducibility 
2322torch .manual_seed (42 )
2423
24+ N  =  512 
25+ 
2526
2627class  TestTorchCompile (unittest .TestCase ):
2728    """ 
@@ -34,7 +35,6 @@ def test_cox_equivalence(self):
3435        """ 
3536
3637        # random data and parameters 
37-         N  =  32 
3838        log_hz  =  torch .randn (N )
3939        event  =  torch .randint (low = 0 , high = 2 , size = (N ,)).bool ()
4040        time  =  torch .randint (low = 1 , high = 100 , size = (N ,))
@@ -45,17 +45,14 @@ def test_cox_equivalence(self):
4545        loss_cox  =  cox (log_hz , event , time )
4646        loss_ccox  =  ccox (log_hz , event , time )
4747
48-         self .assertTrue (
49-             np .isclose (loss_cox .numpy (), loss_ccox .numpy (), rtol = 1e-3 , atol = 1e-3 )
50-         )
48+         self .assertTrue (torch .allclose (loss_cox , loss_ccox , rtol = 1e-3 , atol = 1e-3 ))
5149
5250    def  test_weibull_equivalence (self ):
5351        """ 
5452        whether the compiled version of weibull evaluates to the same value 
5553        """ 
5654
5755        # random data and parameters 
58-         N  =  32 
5956        log_hz  =  torch .randn (N )
6057        event  =  torch .randint (low = 0 , high = 2 , size = (N ,)).bool ()
6158        time  =  torch .randint (low = 1 , high = 100 , size = (N ,))
@@ -67,9 +64,7 @@ def test_weibull_equivalence(self):
6764        loss_cweibull  =  cweibull (log_hz , event , time )
6865
6966        self .assertTrue (
70-             np .isclose (
71-                 loss_weibull .numpy (), loss_cweibull .numpy (), rtol = 1e-3 , atol = 1e-3 
72-             )
67+             torch .allclose (loss_weibull , loss_cweibull , rtol = 1e-3 , atol = 1e-3 )
7368        )
7469
7570
0 commit comments