|
1 | 1 | % test_GNN.m - Unit tests for GNN wrapper class |
2 | 2 | % |
3 | | -% Tests the GNN class functionality including: |
4 | | -% - Constructor variants |
5 | | -% - Forward pass (evaluate) |
6 | | -% - Reachability analysis |
7 | | -% - Graph structure management (setGraph) |
8 | | -% |
9 | | -% Author: Anne Tumlin |
10 | | -% Date: 01/13/2026 |
11 | | - |
12 | | -%% 1) Empty constructor test |
13 | | -gnn_empty = GNN(); |
14 | | -assert(gnn_empty.numLayers == 0, 'Empty GNN should have 0 layers'); |
15 | | -assert(isempty(gnn_empty.Layers), 'Empty GNN should have empty Layers'); |
| 3 | +% Tests: constructor, evaluate, reach + soundness, setGraph, precision |
16 | 4 |
|
17 | | -%% 2) Constructor test - layers only |
| 5 | +% Shared setup (before any %% sections) |
18 | 6 | W1 = rand(4, 8); b1 = rand(8, 1); |
19 | 7 | W2 = rand(8, 4); b2 = rand(4, 1); |
20 | 8 | L1 = GCNLayer('gcn1', W1, b1); |
21 | 9 | L2 = GCNLayer('gcn2', W2, b2); |
22 | 10 |
|
23 | | -gnn_layers = GNN({L1, L2}); |
24 | | -assert(gnn_layers.numLayers == 2, 'Should have 2 layers'); |
25 | | -assert(isempty(gnn_layers.A_norm), 'A_norm should be empty'); |
26 | | - |
27 | | -%% 3) Constructor test - GCN-only network (layers + A_norm) |
28 | 11 | numNodes = 5; |
29 | 12 | A_norm = rand(numNodes, numNodes); |
| 13 | +X = rand(numNodes, 4); |
| 14 | + |
| 15 | +NF = rand(numNodes, 4); |
| 16 | +LB = -0.1 * ones(numNodes, 4); |
| 17 | +UB = 0.1 * ones(numNodes, 4); |
| 18 | +GS_in = GraphStar(NF, LB, UB); |
| 19 | + |
30 | 20 | gnn = GNN({L1, L2}, A_norm); |
31 | 21 |
|
| 22 | +%% 1) Constructor test |
32 | 23 | assert(gnn.numLayers == 2, 'Should have 2 layers'); |
33 | 24 | assert(isequal(gnn.A_norm, A_norm), 'A_norm should match'); |
34 | 25 | assert(gnn.InputSize == 4, 'InputSize should be 4'); |
35 | 26 | assert(gnn.OutputSize == 4, 'OutputSize should be 4'); |
36 | 27 |
|
37 | | -%% 4) Evaluate test - GCN-only |
38 | | -X = rand(numNodes, 4); % 5 nodes, 4 features |
| 28 | +%% 2) Evaluate test |
39 | 29 | Y = gnn.evaluate(X); |
40 | | - |
41 | 30 | assert(size(Y, 1) == numNodes, 'Output should have same number of nodes'); |
42 | 31 | assert(size(Y, 2) == 4, 'Output should have 4 features'); |
43 | 32 |
|
44 | | -%% 5) Verify evaluate matches manual layer-by-layer computation |
| 33 | +% Verify matches manual layer-by-layer computation |
45 | 34 | Y_manual = L1.evaluate(X, A_norm); |
46 | 35 | Y_manual = L2.evaluate(Y_manual, A_norm); |
47 | 36 | assert(max(abs(Y - Y_manual), [], 'all') < 1e-10, 'GNN.evaluate should match manual computation'); |
48 | 37 |
|
49 | | -%% 6) Constructor test - Full GNN (layers + A_norm + adj_list + E) |
50 | | -W_node = rand(8, 8); b_node = rand(8, 1); |
51 | | -W_edge = rand(3, 8); b_edge = rand(8, 1); |
52 | | -L_gine = GINELayer('gine', W_node, b_node, W_edge, b_edge); |
53 | | - |
54 | | -adj_list = [1 2; 2 3; 3 4; 4 5; 5 1]; % 5 edges forming a cycle |
55 | | -E = rand(5, 3); % 5 edges, 3 edge features |
56 | | - |
57 | | -gnn_full = GNN({L1, L_gine, L2}, A_norm, adj_list, E); |
58 | | - |
59 | | -assert(gnn_full.numLayers == 3, 'Should have 3 layers'); |
60 | | -assert(isequal(gnn_full.adj_list, adj_list), 'adj_list should match'); |
61 | | -assert(isequal(gnn_full.E, E), 'E should match'); |
62 | | - |
63 | | -%% 7) Constructor test - Full GNN with name |
64 | | -gnn_named = GNN({L1, L2}, A_norm, adj_list, E, [], 'my_gnn'); |
65 | | -assert(strcmp(gnn_named.Name, 'my_gnn'), 'Name should match'); |
66 | | - |
67 | | -%% 8) Evaluate test - Mixed GCN + GINE network |
68 | | -Y_mixed = gnn_full.evaluate(X); |
69 | | - |
70 | | -assert(size(Y_mixed, 1) == numNodes, 'Output should have same number of nodes'); |
71 | | -assert(size(Y_mixed, 2) == 4, 'Output should have 4 features'); |
72 | | - |
73 | | -%% 9) Reachability test - GCN-only network |
74 | | -NF = rand(numNodes, 4); |
75 | | -LB = -0.1 * ones(numNodes, 4); |
76 | | -UB = 0.1 * ones(numNodes, 4); |
77 | | -GS_in = GraphStar(NF, LB, UB); |
78 | | - |
| 38 | +%% 3) Reach and soundness test |
79 | 39 | reachOpts = struct('reachMethod', 'approx-star'); |
80 | 40 | GS_out = gnn.reach(GS_in, reachOpts); |
81 | 41 |
|
82 | 42 | assert(isa(GS_out, 'GraphStar'), 'Output should be GraphStar'); |
83 | 43 | assert(GS_out.numNodes == numNodes, 'Output should have same number of nodes'); |
84 | 44 | assert(GS_out.numFeatures == 4, 'Output should have 4 features'); |
85 | 45 |
|
86 | | -%% 10) Verify center matches evaluate for GCN-only |
| 46 | +% Soundness: center should match evaluate |
87 | 47 | center_in = GS_in.V(:, :, 1); |
88 | 48 | center_out = GS_out.V(:, :, 1); |
89 | 49 | expected = gnn.evaluate(center_in); |
90 | | - |
91 | 50 | assert(max(abs(center_out - expected), [], 'all') < 1e-10, ... |
92 | 51 | 'Center of output GraphStar should match evaluate()'); |
93 | 52 |
|
94 | | -%% 11) Verify reachSet and reachTime are populated |
| 53 | +% Containment: center output should be within bounds |
| 54 | +[lb_out, ub_out] = GS_out.getRanges(); |
| 55 | +Y_center = gnn.evaluate(GS_in.V(:,:,1)); |
| 56 | +tol = 1e-6; |
| 57 | +assert(all(Y_center(:) >= lb_out(:) - tol), 'Center output should be >= lower bound'); |
| 58 | +assert(all(Y_center(:) <= ub_out(:) + tol), 'Center output should be <= upper bound'); |
| 59 | + |
| 60 | +% Verify reachSet and reachTime populated |
95 | 61 | assert(length(gnn.reachSet) == gnn.numLayers, 'reachSet should have entry per layer'); |
96 | 62 | assert(length(gnn.reachTime) == gnn.numLayers, 'reachTime should have entry per layer'); |
97 | 63 | assert(all(gnn.reachTime > 0), 'reachTime entries should be positive'); |
98 | 64 |
|
99 | | -%% 12) Test setGraph - update A_norm only |
| 65 | +%% 4) setGraph test |
| 66 | +Y_original = gnn.evaluate(X); % Store original output |
100 | 67 | A_norm_new = rand(numNodes, numNodes); |
101 | 68 | gnn.setGraph(A_norm_new); |
102 | | - |
103 | 69 | assert(isequal(gnn.A_norm, A_norm_new), 'A_norm should be updated'); |
104 | 70 |
|
105 | | -%% 13) Test setGraph - weight reuse produces different output |
106 | 71 | Y_new = gnn.evaluate(X); |
107 | | -assert(~isequal(Y, Y_new), 'Different graph should produce different output'); |
108 | | - |
109 | | -%% 14) Test setGraph - full update |
110 | | -adj_list_new = [1 3; 2 4; 3 5; 4 1; 5 2]; |
111 | | -E_new = rand(5, 3); |
112 | | -gnn_full.setGraph(A_norm_new, adj_list_new, E_new); |
| 72 | +assert(~isequal(Y_original, Y_new), 'Different graph should produce different output'); |
113 | 73 |
|
114 | | -assert(isequal(gnn_full.A_norm, A_norm_new), 'A_norm should be updated'); |
115 | | -assert(isequal(gnn_full.adj_list, adj_list_new), 'adj_list should be updated'); |
116 | | -assert(isequal(gnn_full.E, E_new), 'E should be updated'); |
117 | | - |
118 | | -%% 15) Test precision change |
| 74 | +%% 5) Precision change test |
119 | 75 | gnn_prec = GNN({L1, L2}, A_norm); |
120 | 76 | gnn_prec.changeParamsPrecision('single'); |
121 | | - |
122 | 77 | assert(isa(gnn_prec.Layers{1}.Weights, 'single'), 'Weights should be single precision'); |
123 | | -assert(isa(gnn_prec.Layers{2}.Weights, 'single'), 'Weights should be single precision'); |
124 | 78 |
|
125 | 79 | gnn_prec.changeParamsPrecision('double'); |
126 | 80 | assert(isa(gnn_prec.Layers{1}.Weights, 'double'), 'Weights should be double precision'); |
127 | 81 |
|
128 | | -%% 16) Test getInfo |
129 | | -info = gnn.getInfo(); |
130 | | - |
131 | | -assert(info.numLayers == 2, 'numLayers should be 2'); |
132 | | -assert(info.hasAdjacency == true, 'hasAdjacency should be true'); |
133 | | -assert(strcmp(info.layerTypes{1}, 'GCNLayer'), 'First layer should be GCNLayer'); |
134 | | - |
135 | | -%% 17) Test with default reachOptions |
136 | | -gnn.setGraph(A_norm); % Reset to original A_norm |
137 | | -GS_out_default = gnn.reach(GS_in); |
138 | | - |
139 | | -assert(isa(GS_out_default, 'GraphStar'), 'Output should be GraphStar with default options'); |
140 | | - |
141 | | -%% 18) Test output bounds contain samples |
142 | | -num_samples = 10; |
143 | | -[lb_out, ub_out] = GS_out.getRanges(); |
144 | | - |
145 | | -for s = 1:num_samples |
146 | | - % Generate random sample from input GraphStar |
147 | | - alpha = rand(GS_in.numPred, 1) .* (GS_in.pred_ub - GS_in.pred_lb) + GS_in.pred_lb; |
148 | | - X_sample = GS_in.V(:, :, 1); |
149 | | - for k = 1:GS_in.numPred |
150 | | - X_sample = X_sample + alpha(k) * GS_in.V(:, :, k+1); |
151 | | - end |
152 | | - |
153 | | - % Evaluate at sample |
154 | | - Y_sample = gnn.evaluate(X_sample); |
155 | | - |
156 | | - % Check sample is within bounds |
157 | | - tol = 1e-6; |
158 | | - assert(all(Y_sample(:) >= lb_out(:) - tol), 'Sample should be above lower bound'); |
159 | | - assert(all(Y_sample(:) <= ub_out(:) + tol), 'Sample should be below upper bound'); |
160 | | -end |
161 | | - |
162 | 82 | disp('All GNN tests passed!'); |
0 commit comments