Skip to content

Commit fb1fec4

Browse files
Merge pull request #11 from matlab-deep-learning/feature/pre-19b
Support R2019a
2 parents c5086ea + 87173dd commit fb1fec4

15 files changed

+328
-54
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
*.mltbx
2+
code/mtcnn/weights/dag*.mat

README.md

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Face Detection and Alignment MTCNN
22

3-
![matlab-deep-learning](https://circleci.com/gh/matlab-deep-learning/mtcnn-face-detection.svg?style=svg)
3+
![circleci](https://circleci.com/gh/matlab-deep-learning/mtcnn-face-detection.svg?style=svg)
44
[![codecov](https://codecov.io/gh/matlab-deep-learning/mtcnn-face-detection/branch/master/graph/badge.svg)](https://codecov.io/gh/matlab-deep-learning/mtcnn-face-detection)
55

66
## [__Download the toolbox here__](https://github.com/matlab-deep-learning/mtcnn-face-detection/releases/latest/download/MTCNN-Face-Detection.mltbx)
@@ -18,7 +18,7 @@ _Note: This code supports inference using a pretrained model. Training from scra
1818
## Installation
1919

2020
- Face Detection and Alignment MTCNN requires the following products:
21-
- MATLAB R2019b or later
21+
- MATLAB R2019a or later (_R2019a is now supported!_)
2222
- Deep Learning Toolbox
2323
- Computer Vision Toolbox
2424
- Image Processing Toolbox
@@ -65,10 +65,6 @@ _Face detection from MTCNN in yellow, detections from the built in vision.Cascad
6565

6666
## Contribute
6767

68-
Please file any bug reports or feature requests as [GitHub issues](https://github.com/matlab-deep-learning/mtcnn-face-detection/issues). In particular comment on the following two issues if they interest you!
69-
70-
- [Support training MTCNN](https://github.com/matlab-deep-learning/mtcnn-face-detection/issues/1)
71-
- [Support MATLAB versions earlier than R2019b](https://github.com/matlab-deep-learning/mtcnn-face-detection/issues/2)
72-
68+
Please file any bug reports or feature requests as [GitHub issues](https://github.com/matlab-deep-learning/mtcnn-face-detection/issues). In particular, if you'd be interested in training your own MTCNN network, comment on the following issue: [Support training MTCNN](https://github.com/matlab-deep-learning/mtcnn-face-detection/issues/1)
7369

7470
_Copyright 2019 The MathWorks, Inc._
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
classdef DagNetworkStrategy < handle
    % DagNetworkStrategy   Run the MTCNN stages from networks stored in MAT-files.
    %
    %   The three network properties stay empty until load() is called.
    %   load() reads them from the "weights" folder under mtcnnRoot().

    properties (SetAccess=private)
        % Networks read from dagPNet.mat, dagRNet.mat and dagONet.mat
        Pnet
        Rnet
        Onet
    end

    methods
        function obj = DagNetworkStrategy()
            % DagNetworkStrategy   Construct the strategy; nothing is loaded here.
        end

        function load(obj)
            % load   Read the P, R and O networks from the weights folder.
            weightsDir = fullfile(mtcnnRoot(), "weights");

            obj.Pnet = importdata(fullfile(weightsDir, "dagPNet.mat"));
            obj.Rnet = importdata(fullfile(weightsDir, "dagRNet.mat"));
            obj.Onet = importdata(fullfile(weightsDir, "dagONet.mat"));
        end

        function pnet = getPNet(obj)
            % getPNet   Return the stored P network as-is.
            pnet = obj.Pnet;
        end

        function [probs, correction] = applyRNet(obj, im)
            % applyRNet   Run the R network on im and split its output by column.
            %
            %   probs      - columns 1:2 of the prediction
            %   correction - columns 3:end of the prediction
            prediction = obj.Rnet.predict(im);

            probs = prediction(:, 1:2);
            correction = prediction(:, 3:end);
        end

        function [probs, correction, landmarks] = applyONet(obj, im)
            % applyONet   Run the O network on im and split its output by column.
            %
            %   probs      - columns 1:2 of the prediction
            %   correction - columns 3:6 of the prediction
            %   landmarks  - columns 7:end of the prediction
            prediction = obj.Onet.predict(im);

            probs = prediction(:, 1:2);
            correction = prediction(:, 3:6);
            landmarks = prediction(:, 7:end);
        end

    end
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
classdef DlNetworkStrategy < handle
    % DlNetworkStrategy   Run the MTCNN stages as functions over loaded weights.
    %
    %   Weights are read from MAT-files by load() and passed to mtcnn.rnet
    %   and mtcnn.onet on each call. When UseGPU is true, load() also maps
    %   every weight through gpuArray.

    properties (SetAccess=private)
        % Flag given to the constructor; checked in load()
        UseGPU
        % Contents of pnet.mat, rnet.mat and onet.mat as returned by load
        PnetWeights
        RnetWeights
        OnetWeights
    end

    methods
        function obj = DlNetworkStrategy(useGpu)
            % DlNetworkStrategy   Store the GPU flag; weights are loaded later.
            obj.UseGPU = useGpu;
        end

        function load(obj)
            % load   Read the weights for all three stages from file.
            weightsDir = fullfile(mtcnnRoot(), "weights");

            obj.PnetWeights = load(fullfile(weightsDir, "pnet.mat"));
            obj.RnetWeights = load(fullfile(weightsDir, "rnet.mat"));
            obj.OnetWeights = load(fullfile(weightsDir, "onet.mat"));

            if obj.UseGPU
                % Apply gpuArray to every weight of every stage.
                toGpu = @gpuArray;
                obj.PnetWeights = dlupdate(toGpu, obj.PnetWeights);
                obj.RnetWeights = dlupdate(toGpu, obj.RnetWeights);
                obj.OnetWeights = dlupdate(toGpu, obj.OnetWeights);
            end
        end

        function pnet = getPNet(obj)
            % getPNet   Return the stored P-stage weights.
            pnet = obj.PnetWeights;
        end

        function [probs, correction] = applyRNet(obj, im)
            % applyRNet   Evaluate mtcnn.rnet on im with the stored R weights.
            %
            %   im is labelled "SSCB" before the call. Both outputs are
            %   unwrapped with extractdata and transposed.
            labelled = dlarray(im, "SSCB");

            [rawProbs, rawCorrection] = mtcnn.rnet(labelled, obj.RnetWeights);

            probs = extractdata(rawProbs)';
            correction = extractdata(rawCorrection)';
        end

        function [probs, correction, landmarks] = applyONet(obj, im)
            % applyONet   Evaluate mtcnn.onet on im with the stored O weights.
            %
            %   im is labelled "SSCB" before the call. All three outputs are
            %   unwrapped with extractdata and transposed.
            labelled = dlarray(im, "SSCB");

            [rawProbs, rawCorrection, rawLandmarks] = mtcnn.onet(labelled, obj.OnetWeights);

            probs = extractdata(rawProbs)';
            correction = extractdata(rawCorrection)';
            landmarks = extractdata(rawLandmarks)';
        end

    end
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
classdef (Abstract) NetworkStrategy < handle
    % NetworkStrategy   Interface for objects that run the MTCNN stage networks.
    %
    %   Concrete subclasses must implement every method below.

    % The Abstract attribute is required on this block. Signature-only
    % entries in a plain methods block declare methods that MATLAB expects
    % to find in separate files in the class folder, not abstract methods.
    methods (Abstract)
        % Load whatever the strategy needs (weights or networks) from file.
        load(obj)

        % Return the P-stage network or weights held by the strategy.
        pnet = getPNet(obj)

        % Run the R stage on im.
        [probs, correction] = applyRNet(obj, im)

        % Run the O stage on im.
        [probs, correction, landmarks] = applyONet(obj, im)
    end
end
+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
function net = convertToDagNet(stage)
    % convertToDagNet   Build a DAG network equivalent to one MTCNN stage function.
    %
    %   net = convertToDagNet(stage) traces the stage function (mtcnn.pnet,
    %   mtcnn.rnet or mtcnn.onet) with functionToLayerGraph, replaces the
    %   traced "plus_N" layers with mtcnn.util.preluLayer instances, adds an
    %   image input layer, concatenates the stage outputs along dimension 3
    %   into a single regression output and assembles the network.
    %
    %   stage - "p", "r" or "o"; selects the weights file "<stage>net.mat"
    %           in the weights folder under mtcnnRoot.
    %
    %   Errors:
    %     mtcnn:convertToDagNet:unknownStage    stage is not "p", "r" or "o".
    %     mtcnn:convertToDagNet:outputMismatch  the assembled network does not
    %                                           reproduce the function output
    %                                           on an all-zero input.

    % Silence the placeholder-layer warning for the duration of this call;
    % the onCleanup object restores the previous state on exit.
    warnId = "deep:functionToLayerGraph:Placeholder";
    warnState = warning('off', warnId);
    restoreWarn = onCleanup(@() warning(warnState)); %#ok<NASGU>

    switch stage
        case "p"
            inputSize = 12;
            nBlocks = 3;
            finalConnections = [sprintf("conv_%d", nBlocks), sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "conv_5"];
        case "r"
            inputSize = 24;
            nBlocks = 4;
            finalConnections = [sprintf("prelu_%d", nBlocks-1), "fc_1";
                                "fc_1", sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "fc_3"];
        case "o"
            inputSize = 48;
            nBlocks = 5;
            finalConnections = ["fc_1", sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "fc_3", "fc_4"];
        otherwise
            error("mtcnn:convertToDagNet:unknownStage", ...
                    "Stage '%s' is not recognised", stage)
    end

    matFilename = strcat(stage, "net.mat");
    weightsFile = load(fullfile(mtcnnRoot, "weights", matFilename));
    sampleInput = dlarray(zeros(inputSize, inputSize, 3, "single"), "SSCB");

    % Evaluate the original function once to obtain the reference output
    % that the assembled network is checked against at the end.
    switch stage
        case "p"
            netFunc = @(x) mtcnn.pnet(x, weightsFile);
            [a, b] = netFunc(sampleInput);
            output = cat(3, a, b);
        case "r"
            netFunc = @(x) mtcnn.rnet(x, weightsFile);
            [a, b] = netFunc(sampleInput);
            output = cat(1, a, b);
        case "o"
            netFunc = @(x) mtcnn.onet(x, weightsFile);
            [a, b, c] = netFunc(sampleInput);
            output = cat(1, a, b, c);
    end

    lgraph = functionToLayerGraph(netFunc, sampleInput);
    placeholders = findPlaceholderLayers(lgraph);
    lgraph = removeLayers(lgraph, {placeholders.Name});

    % Swap each traced "plus_N" layer for a PReLU layer carrying the stored
    % weights, then reconnect it to the layer that feeds it.
    for iPrelu = 1:nBlocks
        name = sprintf("prelu_%d", iPrelu);
        weightName = sprintf("features_prelu%d_weight", iPrelu);
        if iPrelu ~= nBlocks
            weights = weightsFile.(weightName);
        else
            % The last block's weights are reshaped to 1-by-1-by-N.
            weights = reshape(weightsFile.(weightName), 1, 1, []);
        end
        prelu = mtcnn.util.preluLayer(weights, name);
        lgraph = replaceLayer(lgraph, sprintf("plus_%d", iPrelu), prelu, "ReconnectBy", "order");

        if iPrelu ~= nBlocks
            lgraph = connectLayers(lgraph, sprintf("conv_%d", iPrelu), sprintf("prelu_%d", iPrelu));
        else
            % need to make different connections at the end of the
            % repeating blocks
            for iConnection = 1:size(finalConnections, 1)
                lgraph = connectLayers(lgraph, ...
                                        finalConnections(iConnection, 1), ...
                                        finalConnections(iConnection, 2));
            end
        end
    end

    lgraph = addLayers(lgraph, imageInputLayer([inputSize, inputSize, 3], ...
                                                "Name", "input", ...
                                                "Normalization", "none"));
    lgraph = connectLayers(lgraph, "input", "conv_1");

    % Merge the stage outputs into one tensor so the network has a single
    % regression output.
    lgraph = addLayers(lgraph, concatenationLayer(3, numel(catConnections), "Name", "concat"));
    for iConnection = 1:numel(catConnections)
        lgraph = connectLayers(lgraph, ...
                                catConnections(iConnection), ...
                                sprintf("concat/in%d", iConnection));
    end
    lgraph = addLayers(lgraph, regressionLayer("Name", "output"));
    lgraph = connectLayers(lgraph, "concat", "output");

    net = assembleNetwork(lgraph);
    result = net.predict(zeros(inputSize, inputSize, 3, "single"));

    % Compare element-wise on flattened vectors. A signed sum lets errors of
    % opposite sign cancel and lets any net-negative error pass the
    % threshold, so the largest absolute difference is used instead.
    % Flattening also avoids transposing the prediction, which is undefined
    % for N-D arrays.
    % NOTE(review): stage "p" presumably yields a 1-by-1-by-N prediction,
    % which the previous result' could not handle; confirm.
    expected = extractdata(output);
    sameCount = numel(expected) == numel(result);

    assert(sameCount && max(abs(expected(:) - result(:))) < 1e-6, ...
            "mtcnn:convertToDagNet:outputMismatch", ...
            "Outputs of function and dag net do not match")
end

code/mtcnn/+mtcnn/+util/preluLayer.m

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
classdef preluLayer < nnet.layer.Layer
    % preluLayer   Custom PReLU layer with a learnable scaling coefficient.
    %   Taken from "Define Custom Deep Learning Layer with Learnable
    %   Parameters"

    % Copyright 2020 The MathWorks, Inc.

    properties (Learnable)
        % Coefficient applied to the non-positive part of the input
        Alpha
    end

    methods
        function layer = preluLayer(weights, name)
            % layer = preluLayer(weights, name) creates a PReLU layer whose
            % scaling coefficient Alpha is set to weights and whose layer
            % name is set to name.

            layer.Name = name;
            layer.Alpha = weights;
        end

        function Z = predict(layer, X)
            % Z = predict(layer, X) forwards X through the layer: positive
            % values pass through unchanged, the rest is scaled by Alpha.
            positivePart = max(X, 0);
            negativePart = min(0, X);
            Z = positivePart + layer.Alpha .* negativePart;
        end

        function [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~) returns the
            % derivatives of the loss with respect to the input X and the
            % coefficient Alpha, given the derivative dLdZ with respect to
            % the layer output.

            % Input gradient: Alpha-scaled everywhere, then overwritten
            % with the unscaled gradient wherever X is positive.
            isPositive = X > 0;
            dLdX = layer.Alpha .* dLdZ;
            dLdX(isPositive) = dLdZ(isPositive);

            % Coefficient gradient: accumulate over dimensions 1 and 2 ...
            perElement = min(0, X) .* dLdZ;
            perChannel = sum(sum(perElement, 1), 2);

            % ... and then over all observations in the mini-batch.
            dLdAlpha = sum(perChannel, 4);
        end
    end
end

0 commit comments

Comments
 (0)