-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest_cnn.py
31 lines (27 loc) · 1.02 KB
/
test_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from cnn_model import ConvNN
class TestCNN:
    """Output-shape checks for ``ConvNN`` on single-channel square images."""

    @staticmethod
    def _assert_output_size(net, rows, cols, expected=(1, 10)):
        """Feed one random ``(1, 1, rows, cols)`` tensor through ``net`` and check the output size.

        Args:
            net: the model under test (called like a function on the input tensor).
            rows: input image height.
            cols: input image width.
            expected: the output size ``net`` must produce; defaults to ``(1, 10)``.
        """
        # Only the output shape is checked, so skip building an autograd graph.
        with torch.no_grad():
            out = net(torch.randn(1, 1, rows, cols))
        # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
        assert out.size() == expected, (
            f"expect output size of {expected}, got {tuple(out.size())}"
        )

    def test_output_size_28x28(self):
        """A default-constructed ``ConvNN`` maps a 1x1x28x28 input to (1, 10)."""
        # NOTE(review): relies on ConvNN's default img_rows/img_cols accepting
        # 28x28 input — confirm in cnn_model.
        self._assert_output_size(ConvNN(), 28, 28)

    def test_output_size_32x32(self):
        """A ``ConvNN`` configured for 32x32 images maps a 1x1x32x32 input to (1, 10)."""
        self._assert_output_size(ConvNN(img_rows=32, img_cols=32), 32, 32)