Commit c72c5d5

used torch instead of numpy

1 parent ab64bc3

File tree

1 file changed: +4, -9 lines

tests/test_torch_compile.py (4 additions, 9 deletions)
@@ -9,7 +9,6 @@
 """
 
 # global modules
-import json
 import unittest
 
 import numpy as np
@@ -22,6 +21,8 @@
 # set seed for reproducibility
 torch.manual_seed(42)
 
+N = 512
+
 
 class TestTorchCompile(unittest.TestCase):
     """
@@ -34,7 +35,6 @@ def test_cox_equivalence(self):
         """
 
         # random data and parameters
-        N = 32
         log_hz = torch.randn(N)
         event = torch.randint(low=0, high=2, size=(N,)).bool()
         time = torch.randint(low=1, high=100, size=(N,))
@@ -45,17 +45,14 @@ def test_cox_equivalence(self):
         loss_cox = cox(log_hz, event, time)
         loss_ccox = ccox(log_hz, event, time)
 
-        self.assertTrue(
-            np.isclose(loss_cox.numpy(), loss_ccox.numpy(), rtol=1e-3, atol=1e-3)
-        )
+        self.assertTrue(torch.allclose(loss_cox, loss_ccox, rtol=1e-3, atol=1e-3))
 
     def test_weibull_equivalence(self):
         """
         whether the compiled version of weibull evaluates to the same value
         """
 
         # random data and parameters
-        N = 32
         log_hz = torch.randn(N)
         event = torch.randint(low=0, high=2, size=(N,)).bool()
         time = torch.randint(low=1, high=100, size=(N,))
@@ -67,9 +64,7 @@ def test_weibull_equivalence(self):
         loss_cweibull = cweibull(log_hz, event, time)
 
         self.assertTrue(
-            np.isclose(
-                loss_weibull.numpy(), loss_cweibull.numpy(), rtol=1e-3, atol=1e-3
-            )
+            torch.allclose(loss_weibull, loss_cweibull, rtol=1e-3, atol=1e-3)
        )
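For context, a minimal sketch of what the new assertion does (the tensor values below are illustrative and not taken from the test suite): torch.allclose checks element-wise closeness within rtol/atol directly on torch tensors, so the .numpy() round-trip that the old np.isclose comparison needed is no longer required.

import numpy as np
import torch

# illustrative values, not taken from the test suite
loss_a = torch.tensor(1.2345)
loss_b = loss_a + 1e-5

# old style: convert tensors to numpy and compare with np.isclose
assert np.isclose(loss_a.numpy(), loss_b.numpy(), rtol=1e-3, atol=1e-3)

# new style: compare the tensors directly with torch.allclose
assert torch.allclose(loss_a, loss_b, rtol=1e-3, atol=1e-3)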