Skip to content

Commit d88f93f

Browse files
committed
Fixing forward functions across files
1 parent 1b7dd25 commit d88f93f

File tree

4 files changed

+56
-28
lines changed

4 files changed

+56
-28
lines changed

code/nnv/engine/nn/Prob_reach/ProbReach_ImageStar.m

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,41 @@
7171

7272
end
7373

74-
function out = forward(obj, x)
75-
74+
function out = forward(obj, x, inputFormat)
75+
7676
model_source = class(obj.model);
77-
77+
7878
switch model_source
79-
79+
8080
case 'SeriesNetwork'
8181
out = obj.model.predict(x);
82-
82+
8383
case 'DAGNetwork'
8484
out = obj.model.predict(x);
85-
85+
8686
case 'dlnetwork'
87-
dlX = dlarray(x);
87+
% dlX = dlarray(x,inputFormat);
88+
if strcmp(inputFormat, "default")
89+
if isa(obj.model.Layers(1, 1), 'nnet.cnn.layer.ImageInputLayer')
90+
dlX = dlarray(x, "SSCB");
91+
elseif isa(obj.model.Layers(1, 1), 'nnet.cnn.layer.FeatureInputLayer') || isa(model.Layers(1, 1), 'nnet.onnx.layer.FeatureInputLayer')
92+
dlX = dlarray(x, "CB");
93+
else
94+
disp(obj.model.Layers(1,1));
95+
error("Unknown input format");
96+
end
97+
else
98+
if contains(inputFormat, "U")
99+
dlX = dlarray(x, inputFormat+"U");
100+
else
101+
dlX = dlarray(x, inputFormat);
102+
end
103+
end
88104
out = obj.model.predict(dlX);
89-
105+
90106
case 'NN'
91107
out = obj.model.evaluate(x);
92-
108+
93109
otherwise
94110
error("Unknown model source: " + model_source + ". We only cover NN, SeriesNetwork, dlnetwork and DAGNetwork.");
95111
end
@@ -106,6 +122,7 @@
106122
N = obj.params.Nt;
107123
N_dir = obj.params.N_dir;
108124
trn_batch = obj.params.trn_batch;
125+
inputFormat = obj.params.inputFormat;
109126

110127
N_perturbed = size(obj.indices , 1);
111128

@@ -124,7 +141,7 @@
124141
end
125142
Inp = single(obj.LB + d_at);
126143
X(:,i) = Rand;
127-
Y(:,:,:,i) = obj.forward(Inp);
144+
Y(:,:,:,i) = obj.forward(Inp,inputFormat);
128145
end
129146
%%%%%%%%%%%%%
130147
n1 = numel(Y(:,:,:,1));
@@ -340,7 +357,7 @@
340357
end
341358
Inp = obj.LB + d_at;
342359
X_test_nc(:,i) = Rand;
343-
Y_test_nc(:,:,:,i) = obj.forward(Inp);
360+
Y_test_nc(:,:,:,i) = obj.forward(Inp,inputFormat);
344361
end
345362
%%%%%%%%%%%%%
346363

@@ -455,7 +472,8 @@
455472
disp(' The Image Star is large for your memory and should be presented in sparse format.')
456473
disp('Unfortunately matlab does not support sparse representation for (N>2)D arrays.')
457474
disp('Thus we provide the vectorized format of ImageStar() that is a Star() via sparse 2D arrays. ')
458-
p = input('Do you want to continue? Yes <-- 1 / No <-- 0 ');
475+
% p = input('Do you want to continue? Yes <-- 1 / No <-- 0 ');
476+
p = 1;
459477

460478
if p==1
461479

code/nnv/engine/utils/Prob_reach.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
params.threshold_normal = threshold_normal;
108108
params.guarantee = coverage;
109109
params.py_dir = py_dir;
110+
params.inputFormat = inputFormat;
110111

111112

112113
obj = ProbReach_ImageStar(Net,LB, UB,indices, SizeOut, train_mode, params);

code/nnv/examples/Submission/VNN_COMP2025/run_vnncomp_instance.m

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
% Load networks
1515

16-
[net, nnvnet, needReshape, reachOptionsList, inputSize,inputFormat] = load_vnncomp_network(category, onnx, vnnlib);
16+
[net, nnvnet, needReshape, reachOptionsList, inputSize, inputFormat] = load_vnncomp_network(category, onnx, vnnlib);
1717

1818
if isempty(inputSize)
1919
inputSize = net.Layers(1, 1).InputSize;
@@ -71,7 +71,7 @@
7171

7272
vT = tic;
7373

74-
quickRun = true;
74+
quickRun = false;
7575
%
7676
% if quickRun
7777
% tTime = toc(t);
@@ -100,7 +100,7 @@
100100
if ~strcmp(reachOptions.reachMethod, "cp-star")
101101
ySet = nnvnet.reach(IS, reachOptions);
102102
else
103-
ySet = Prob_reach(net, IS, []);
103+
ySet = Prob_reach(net, IS, reachOptions);
104104
end
105105

106106
% Verify property
@@ -143,7 +143,7 @@
143143
if ~strcmp(reachOptions.reachMethod, "cp-star")
144144
ySet = nnvnet.reach(IS, reachOptions);
145145
else
146-
ySet = Prob_reach(net, IS, []);
146+
ySet = Prob_reach(net, IS, reachOptions);
147147
end
148148

149149
% Verify property
@@ -203,7 +203,7 @@
203203
if ~strcmp(reachOptions.reachMethod, "cp-star")
204204
ySet = nnvnet.reach(IS, reachOptions);
205205
else
206-
ySet = Prob_reach(net, IS, []);
206+
ySet = Prob_reach(net, IS, reachOptions);
207207
end
208208

209209
% Add verification status
@@ -320,6 +320,7 @@
320320
function [net,nnvnet,needReshape,reachOptionsList,inputSize,inputFormat] = load_vnncomp_network(category, onnx, vnnlib)
321321
% load participating vnncomp 2025 benchmark NNs
322322
% Not yet supported:
323+
% - cctsdb (some errrors when forward propagating)
323324
% - lsnc_relu
324325
% - ml4acopf
325326
% - traffic_signs_recognition
@@ -350,15 +351,17 @@
350351
end
351352

352353
elseif contains(category, "cctsdb_yolo")
353-
net = importNetworkFromONNX(onnx);
354-
nnvnet = "";
355-
inputSize = [12296, 1];
356-
inputFormat = "UU";
357-
X = dlarray(rand(12296, 1), inputFormat);
358-
net = initialize(net, X);
359-
reachOptions = struct;
360-
reachOptions.reachMethod = 'cp-star';
361-
reachOptionsList{1} = reachOptions;
354+
% net = importNetworkFromONNX(onnx);
355+
% nnvnet = "";
356+
% inputSize = [12296, 1];
357+
% inputFormat = "UU";
358+
% X = dlarray(rand(12296, 1), inputFormat);
359+
% net = initialize(net, X);
360+
% reachOptions = struct;
361+
% reachOptions.reachMethod = 'cp-star';
362+
% reachOptions.inputFormat = inputFormat;
363+
% reachOptionsList{1} = reachOptions;
364+
error("Working on supporting this one");
362365

363366
elseif contains(category, "cersyve")
364367
net = importNetworkFromONNX(onnx, "InputDataFormats", "BC");
@@ -520,6 +523,7 @@
520523
% net = initialize(net, X);
521524
% nnvnet = "";
522525
% reachOptions = struct;
526+
% reachOptions.inputFormat = inputFormat;
523527
% reachOptions.reachMethod = 'cp-star';
524528
% reachOptionsList{1} = reachOptions;
525529
error("Not supported");
@@ -562,6 +566,7 @@
562566
net = initialize(net, X);
563567
nnvnet = "";
564568
reachOptions = struct;
569+
reachOptions.inputFormat = inputFormat;
565570
reachOptions.reachMethod = 'cp-star'; % default parameters
566571
reachOptionsList{1} = reachOptions;
567572
end
@@ -586,6 +591,7 @@
586591
reachOptions = struct;
587592
reachOptions.reachMethod = 'relax-star-area';
588593
reachOptions.relaxFactor = 1;
594+
reachOptions.inputFormat = inputFormat;
589595
reachOptionsList{1} = reachOptions;
590596
% reachOptions.reachMethod = 'relax-star-area';
591597
% reachOptions.relaxFactor = 0.5;
@@ -660,6 +666,7 @@
660666
% % needReshape = 1; % 1 is wrong
661667
% reachOptions = struct;
662668
% % inputFormat = "BSSC";
669+
% reachOptions.inputFormat = inputFormat;
663670
% reachOptions.reachMethod = 'cp-star';
664671
% reachOptionsList{1} = reachOptions;
665672
error("IR and opset not yet supported in MATLAB")

code/nnv/examples/Submission/VNN_COMP2025/test_instances.m

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@
6666
end
6767
end
6868

69-
% Errors on both? None of them finished...
70-
% Let's test this again
69+
% There are errors in some of the forward prediction (gather to multlayer)
70+
% There seems to be some assertion that fails sometimes when randomly
71+
% executing
72+
% Let's test this again in the future
7173

7274
%% cersyve
7375

0 commit comments

Comments
 (0)