Skip to content

Commit 6ed7d8c

Browse files
authored
Merge pull request #81 from Novartis/67-torch.compile-tests
* torch.compile runs without error and evaluates cox and weibull losses to the same value * ran black test_torch_compile.py * used torch instead of numpy * removed numpy --------- Co-authored-by: corolth1 <[email protected]>
2 parents a505bfc + 3801106 commit 6ed7d8c

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

tests/test_torch_compile.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
3+
Tests for torch.compile
4+
5+
References:
6+
- https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
7+
- https://github.com/pytorch/pytorch/issues/122094
8+
9+
"""
10+
11+
# global modules
12+
import unittest
13+
14+
import torch
15+
16+
# Local modules
17+
from torchsurv.loss.cox import neg_partial_log_likelihood as cox
18+
from torchsurv.loss.weibull import neg_log_likelihood as weibull
19+
20+
# set seed for reproducibility
21+
torch.manual_seed(42)
22+
23+
N = 512
24+
25+
26+
class TestTorchCompile(unittest.TestCase):
27+
"""
28+
Tests using torch.compile with cox
29+
"""
30+
31+
def test_cox_equivalence(self):
32+
"""
33+
whether the compiled version of cox evaluates to the same value
34+
"""
35+
36+
# random data and parameters
37+
log_hz = torch.randn(N)
38+
event = torch.randint(low=0, high=2, size=(N,)).bool()
39+
time = torch.randint(low=1, high=100, size=(N,))
40+
41+
# compiled version of cox
42+
ccox = torch.compile(cox)
43+
44+
loss_cox = cox(log_hz, event, time)
45+
loss_ccox = ccox(log_hz, event, time)
46+
47+
self.assertTrue(torch.allclose(loss_cox, loss_ccox, rtol=1e-3, atol=1e-3))
48+
49+
def test_weibull_equivalence(self):
50+
"""
51+
whether the compiled version of weibull evaluates to the same value
52+
"""
53+
54+
# random data and parameters
55+
log_hz = torch.randn(N)
56+
event = torch.randint(low=0, high=2, size=(N,)).bool()
57+
time = torch.randint(low=1, high=100, size=(N,))
58+
59+
# compiled version of weibull
60+
cweibull = torch.compile(weibull)
61+
62+
loss_weibull = weibull(log_hz, event, time)
63+
loss_cweibull = cweibull(log_hz, event, time)
64+
65+
self.assertTrue(
66+
torch.allclose(loss_weibull, loss_cweibull, rtol=1e-3, atol=1e-3)
67+
)
68+
69+
70+
if __name__ == "__main__":
71+
72+
unittest.main()

0 commit comments

Comments
 (0)