Skip to content

Commit 03d788c

Browse files
committed
Testing global pooling layer
1 parent 95cbe55 commit 03d788c

File tree

4 files changed

+135
-1
lines changed

4 files changed

+135
-1
lines changed

code/nnv/engine/nn/layers/GlobalAveragePooling2DLayer.m

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,14 @@
3838
numOutputs = 1;
3939
inputNames = {'in1'};
4040
outputNames = {'out'};
41+
case 0
42+
name = 'global_average_pooling_2d';
43+
numInputs = 1;
44+
numOutputs = 1;
45+
inputNames = {'in1'};
46+
outputNames = {'out'};
4147
otherwise
42-
error('Invalid number of input arguments, should be 1 or 5');
48+
error('Invalid number of input arguments, should be 0, 1 or 5');
4349
end
4450

4551
if ~ischar(name)
500 Bytes
Binary file not shown.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
2+
3+
%% test 1: constructor
4+
L1 = GlobalAveragePooling2DLayer();
5+
6+
7+
%% test 2: inference
8+
L1 = GlobalAveragePooling2DLayer();
9+
x = load('one_image.mat');
10+
x = x.one_image;
11+
L1.evaluate(x);
12+
13+
14+
%% test 3: equivalence (inference)
15+
L1 = GlobalAveragePooling2DLayer();
16+
x = load('one_image.mat');
17+
x = x.one_image;
18+
y = L1.evaluate(x);
19+
20+
dlX = dlarray(x, 'SSBC');
21+
dlY = avgpool(dlX,'global');
22+
23+
assert(all(dlY == y, 'all'));
24+
25+
%% test 4: inference, higher dimension
26+
27+
miniBatchSize = 10;
28+
inputSize = [5 5];
29+
numChannels = 3;
30+
X = rand(inputSize(1),inputSize(2),numChannels,miniBatchSize);
31+
32+
L1 = GlobalAveragePooling2DLayer();
33+
Y = L1.evaluate(X);
34+
35+
dlX = dlarray(X,'SSCB');
36+
dlY = avgpool(dlX,'global');
37+
dlY = extractdata(dlY);
38+
39+
assert(all(dlY == Y, 'all'));
40+
41+
%% test 5: reachability
42+
43+
x = load('one_image.mat');
44+
X = x.one_image;
45+
46+
lb = X - 0.1;
47+
ub = X + 0.1;
48+
IS = ImageStar(lb,ub);
49+
50+
L1 = GlobalAveragePooling2DLayer();
51+
Y = L1.evaluate(X);
52+
Yset = L1.reach(IS,'approx-star');
53+
54+
[LB,UB] = Yset.estimateRanges;
55+
56+
assert(all(LB <= Y,'all'))
57+
assert(all(UB >= Y,'all'))
58+
59+
%% test 6: reach (sound)
60+
61+
N = 100; % random samples
62+
63+
x = load('one_image.mat');
64+
X = x.one_image;
65+
66+
lb = X - 0.1;
67+
ub = X + 0.1;
68+
IS = ImageStar(lb,ub);
69+
x_samples = IS.sample(N);
70+
71+
L1 = GlobalAveragePooling2DLayer();
72+
Yset = L1.reach(IS,'approx-star');
73+
[LB,UB] = Yset.estimateRanges;
74+
75+
for i=1:N
76+
xi = x_samples{i};
77+
Yi = L1.evaluate(xi);
78+
assert(all(LB <= Yi,'all'))
79+
assert(all(UB >= Yi,'all'))
80+
end
81+
82+
83+
%% test 7: reachability
84+
85+
miniBatchSize = 1;
86+
inputSize = [5 5];
87+
numChannels = 3;
88+
X = rand(inputSize(1),inputSize(2),numChannels,miniBatchSize);
89+
90+
lb = X - 0.1;
91+
ub = X + 0.1;
92+
IS = ImageStar(lb,ub);
93+
94+
L1 = GlobalAveragePooling2DLayer();
95+
Y = L1.evaluate(X);
96+
Yset = L1.reach(IS,'approx-star');
97+
98+
[LB,UB] = Yset.estimateRanges;
99+
100+
assert(all(LB <= Y,'all'))
101+
assert(all(UB >= Y,'all'))
102+
103+
%% test 8: reach (sound)
104+
105+
N = 200; % random samples
106+
107+
miniBatchSize = 1;
108+
inputSize = [5 5];
109+
numChannels = 3;
110+
X = rand(inputSize(1),inputSize(2),numChannels,miniBatchSize);
111+
112+
lb = X - 0.1;
113+
ub = X + 0.1;
114+
IS = ImageStar(lb,ub);
115+
116+
x_samples = IS.sample(N);
117+
118+
L1 = GlobalAveragePooling2DLayer();
119+
Yset = L1.reach(IS,'approx-star');
120+
[LB,UB] = Yset.estimateRanges;
121+
122+
for i=1:N
123+
xi = x_samples{i};
124+
Yi = L1.evaluate(xi);
125+
assert(all(LB <= Yi,'all'))
126+
assert(all(UB >= Yi,'all'))
127+
end
128+

code/nnv/tests/nn/layers/GlobalAveragePooling2DLayer/todo_test.m

Whitespace-only changes.

0 commit comments

Comments
 (0)