Skip to content

Commit a442673

Browse files
committed
[luci/pass] Extend FuseInstanceNormPass on keras version
This commit adds a new InstanceNorm fusing pattern. It is based on the tf.keras.layers.GroupNormalization(groups=-1) layer from the Keras library. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer m.bencer@partner.samsung.com
1 parent bbd17bd commit a442673

4 files changed

Lines changed: 779 additions & 10 deletions

File tree

compiler/circle2circle-dredd-recipe-test/test.lst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Add(Net_InstanceNorm_005 PASS fuse_instnorm)
6464
Add(Net_InstanceNorm_006 PASS fuse_instnorm)
6565
Add(Net_InstanceNorm_007 PASS fuse_instnorm)
6666
Add(Net_InstanceNorm_008 PASS fuse_instnorm)
67+
Add(Net_InstanceNorm_009 PASS fuse_instnorm)
6768
Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6)
6869
Add(Net_Mul_Add_000 PASS remove_unnecessary_add)
6970
Add(Net_Mul_Add_001 PASS remove_unnecessary_add)

compiler/luci/pass/src/FuseInstanceNormPass.cpp

Lines changed: 264 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,79 @@
2525

2626
#include <cassert>
2727
#include <set>
28+
#include <optional>
2829

2930
// Helper to check detail
3031

31-
/// @return true When node has shape of '1 x .. x 1 x depth'
32+
/**
 * @brief Return false from the enclosing function when 'condition' does not hold
 *
 * @note The body is wrapped in do { } while (0) so that the macro expands to exactly
 *       one statement. This keeps it safe inside an unbraced if/else, where a bare
 *       'if (...) return false;' expansion would either capture the following 'else'
 *       or fail to compile. Always use it with a trailing semicolon.
 */
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (0)
35+
36+
/// @return true When node has exactly one dim other than `1` and that dim equals depth (like '1 x .. x 1 x depth' or '1 x .. x depth x 1')
3237
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
3338
{
34-
auto rank = node->rank();
35-
uint32_t axis;
36-
for (axis = 0; axis < rank - 1; ++axis)
39+
const auto rank = node->rank();
40+
std::optional<uint32_t> depth_axis;
41+
for (uint32_t axis = 0; axis < rank; ++axis)
3742
{
3843
if (node->dim(axis).value() != 1)
39-
return false;
44+
{
45+
// only one axis can be other than 1
46+
if (depth_axis.has_value())
47+
{
48+
return false;
49+
}
50+
depth_axis = axis;
51+
}
52+
}
53+
if (!depth_axis.has_value())
54+
{
55+
return false;
4056
}
41-
return node->dim(axis).value() == depth;
57+
return node->dim(depth_axis.value()).value() == depth;
58+
}
59+
60+
/// @return true if the provided begin_reshape Reshape op adds `1` dimension
61+
/// and terminal_reshape Reshape op removes it (the result is neutral for further processing)
62+
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
63+
{
64+
auto const begin_ifm = dynamic_cast<luci::CircleNode *>(begin_reshape->tensor());
65+
CHECK_OR_FALSE(begin_ifm);
66+
auto const begin_ofm = loco::must_cast<luci::CircleNode *>(begin_reshape);
67+
CHECK_OR_FALSE(begin_ofm);
68+
69+
// check last axis
70+
CHECK_OR_FALSE((begin_ifm->rank() + 1) == begin_ofm->rank());
71+
72+
// check unchanged part of begin_shape
73+
for(uint32_t axis=0;axis<begin_ifm->rank();++axis)
74+
{
75+
// skip dynamic cases
76+
CHECK_OR_FALSE(begin_ifm->dim(axis).known() && begin_ofm->dim(axis).known());
77+
CHECK_OR_FALSE(begin_ifm->dim(axis).value() == begin_ofm->dim(axis).value());
78+
}
79+
// check last axis
80+
CHECK_OR_FALSE(begin_ofm->dim(begin_ofm->rank()-1) == 1);
81+
82+
auto const terminal_ifm = dynamic_cast<luci::CircleNode *>(terminal_reshape->tensor());
83+
CHECK_OR_FALSE(terminal_ifm);
84+
auto const terminal_ofm = loco::must_cast<luci::CircleNode *>(terminal_reshape);
85+
CHECK_OR_FALSE(terminal_ofm);
86+
87+
CHECK_OR_FALSE(terminal_ifm->rank() == terminal_ofm->rank() + 1);
88+
89+
// check last axis
90+
CHECK_OR_FALSE(terminal_ifm->dim(begin_ofm->rank()-1) == 1);
91+
92+
// check unchanged part of terminal_reshape
93+
for(uint32_t axis=0;axis<terminal_ofm->rank();++axis)
94+
{
95+
// skip dynamic cases
96+
CHECK_OR_FALSE(terminal_ifm->dim(axis).known() && terminal_ofm->dim(axis).known());
97+
CHECK_OR_FALSE(terminal_ifm->dim(axis).value() == terminal_ofm->dim(axis).value());
98+
}
99+
100+
return true;
42101
}
43102

44103
bool is_instance_mean_v1(luci::CircleMean *mean)
@@ -370,6 +429,57 @@ namespace
370429
* |
371430
* V
372431
* [Out]
432+
*-------------------------------------------------------------------
433+
* Version_7
434+
* [In]
435+
* |
436+
* V
437+
* ifm
438+
* |
439+
* |
440+
* +-----------------------reshape_ifm -------------------------+
441+
* | | |
442+
* | (reduction indices) | |
443+
* | | | |
444+
* V V | |
445+
* mean_of_ifm | |
446+
* | V |
447+
* +----------------------> sub_2 |
448+
* | | |
449+
* | V |
450+
* | square (reduction indices) |
451+
* | | | |
452+
* | V | |
453+
* | mean_as_variance <----+ |
454+
* | | const_as_epsilon |
455+
* | | | |
456+
* | V | |
457+
* | add_as_variance <----------+ |
458+
* | | |
459+
* | | |
460+
* | V |
461+
* | rsqrt const_as_gamma |
462+
* | | | |
463+
* | | | |
464+
* neg_mean mul_gamma <------+ |
465+
* | | |
466+
* | | |
467+
* V | V
468+
* mul_as_scaled_mean <---------+-----------------------> mul_as_scaled_ifm
469+
* | |
470+
* | const_as_beta |
471+
* | | |
472+
* V V |
473+
* add_neg_mul |
474+
* | |
475+
* +---------------> add_as_terminal <------------------------+
476+
* |
477+
* |
478+
* V
479+
* reshape_as_terminal
480+
* |
481+
* V
482+
* [Out]
373483
*/
374484
class InstanceNormPattern final
375485
{
@@ -383,6 +493,7 @@ class InstanceNormPattern final
383493
Version_4,
384494
Version_5,
385495
Version_6, // For only 3D I/O
496+
Version_7,
386497
};
387498

388499
InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
@@ -399,6 +510,13 @@ class InstanceNormPattern final
399510
_pv = pv;
400511
}
401512

513+
/// @brief Create a pattern whose matching starts from a terminal Reshape node
/// @param candidate Reshape node stored as reshape_as_terminal (must not be null)
/// @param pv        Pattern version to match against
InstanceNormPattern(luci::CircleReshape *candidate, PatternVersion pv)
  : reshape_as_terminal(candidate), _pv(pv)
{
  assert(candidate);
}
519+
402520
private:
403521
bool condition_common_1_5(uint32_t ifm_channel_depth);
404522
bool condition_common_3_4();
@@ -424,6 +542,7 @@ class InstanceNormPattern final
424542
luci::CircleMean *mean_as_variance = nullptr;
425543
luci::CircleConst *const_as_epsilon = nullptr;
426544
luci::CircleAdd *add_as_variance = nullptr;
545+
luci::CircleAdd *add_neg_mul = nullptr;
427546
luci::CircleRsqrt *rsqrt = nullptr;
428547
luci::CircleConst *const_as_gamma = nullptr;
429548
luci::CircleMul *mul_gamma = nullptr;
@@ -437,16 +556,15 @@ class InstanceNormPattern final
437556
luci::CirclePow *pow = nullptr;
438557
luci::CircleSqrt *sqrt = nullptr;
439558
luci::CircleDiv *div = nullptr;
559+
luci::CircleConst *reshape_terminal_target_shape = nullptr;
560+
luci::CircleReshape *reshape_as_terminal = nullptr;
561+
luci::CircleNeg *neg_mean = nullptr;
440562

441563
private:
442564
bool _matched = false;
443565
PatternVersion _pv;
444566
};
445567

446-
#define CHECK_OR_FALSE(condition) \
447-
if (not(condition)) \
448-
return false;
449-
450568
bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
451569
{
452570
add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
@@ -751,6 +869,81 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
751869
return true;
752870
}
753871

872+
/**
 * @brief Match the Version_7 pattern, walking backwards from reshape_as_terminal
 *
 * Fills the pattern's node members while traversing the graph and returns false at
 * the first node that does not fit. On success _matched is set to true.
 *
 * @return true when the whole Version_7 subgraph was recognized
 */
template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
{
  // terminal Reshape is fed by the terminal Add
  add_as_terminal = dynamic_cast<luci::CircleAdd *>(reshape_as_terminal->tensor());
  CHECK_OR_FALSE(add_as_terminal);

  // add_as_terminal = mul_as_scaled_ifm + add_neg_mul (any order)
  CHECK_OR_FALSE(
    luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
  // mul_as_scaled_ifm = reshape_of_ifm * mul_gamma (any order)
  CHECK_OR_FALSE(
    luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));

  // add_neg_mul->x() : mul_as_scaled_mean, whose x() is neg_mean
  mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(add_neg_mul->x());
  CHECK_OR_FALSE(mul_as_scaled_mean);

  neg_mean = dynamic_cast<luci::CircleNeg *>(mul_as_scaled_mean->x());
  CHECK_OR_FALSE(neg_mean);

  // operands of mul_as_scaled_mean have to be the already found mul_gamma and neg_mean
  luci::CircleMul *expected_mul_gamma = nullptr;
  luci::CircleNeg *expected_neg = nullptr;
  CHECK_OR_FALSE(
    luci::fill(&expected_mul_gamma, &expected_neg).with_commutative_args_of(mul_as_scaled_mean));
  CHECK_OR_FALSE(mul_gamma == expected_mul_gamma);
  CHECK_OR_FALSE(neg_mean == expected_neg);

  // neg_mean negates mean_of_ifm, which reduces the very same reshape_of_ifm
  mean_of_ifm = dynamic_cast<luci::CircleMean *>(neg_mean->x());
  CHECK_OR_FALSE(mean_of_ifm);

  luci::CircleReshape *const mean_input = dynamic_cast<luci::CircleReshape *>(mean_of_ifm->input());
  CHECK_OR_FALSE(mean_input == reshape_of_ifm);

  // ifm is the input of the leading Reshape; its shape must be valid 4D with known dim(3)
  ifm = reshape_of_ifm->tensor();
  auto const ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
  CHECK_OR_FALSE(ifm_node);

  CHECK_OR_FALSE(ifm_node->shape_status() == luci::ShapeStatus::VALID);
  CHECK_OR_FALSE(ifm_node->rank() == 4);
  CHECK_OR_FALSE(ifm_node->dim(3).known());
  const uint32_t ifm_channel_depth = ifm_node->dim(3).value();

  // beta : add_neg_mul->y()
  const_as_beta = dynamic_cast<luci::CircleConst *>(add_neg_mul->y());
  CHECK_OR_FALSE(const_as_beta);
  CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));

  // mul_gamma = rsqrt * const_as_gamma (any order)
  CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
  CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));

  // rsqrt(add_as_variance), add_as_variance = mean_as_variance + const_as_epsilon (any order)
  add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
  CHECK_OR_FALSE(add_as_variance);

  CHECK_OR_FALSE(
    luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
  CHECK_OR_FALSE(mean_as_variance);

  CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
  // TODO Support regarding broadcast
  CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);

  // mean_as_variance <- square <- sub_2
  square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
  CHECK_OR_FALSE(square);

  sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
  CHECK_OR_FALSE(sub_2);

  // sub_2 = reshape_of_ifm - mean_of_ifm, both being the nodes found above
  auto const sub_rhs_mean = dynamic_cast<luci::CircleMean *>(sub_2->y());
  CHECK_OR_FALSE(mean_of_ifm == sub_rhs_mean);

  auto const sub_lhs_reshape = dynamic_cast<luci::CircleReshape *>(sub_2->x());
  CHECK_OR_FALSE(sub_lhs_reshape == reshape_of_ifm);

  // leading and terminal Reshape ops have to cancel each other
  CHECK_OR_FALSE(is_unsqueeze_squeeze_pair(reshape_of_ifm, reshape_as_terminal));

  _matched = true;
  return true;
}
946+
754947
bool InstanceNormPattern::matched()
755948
{
756949
if (_matched)
@@ -772,6 +965,8 @@ bool InstanceNormPattern::matched()
772965
return match<PatternVersion::Version_5>();
773966
case PatternVersion::Version_6:
774967
return match<PatternVersion::Version_6>();
968+
case PatternVersion::Version_7:
969+
return match<PatternVersion::Version_7>();
775970

776971
default:
777972
break;
@@ -1006,6 +1201,36 @@ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Ve
10061201
replace(_p.add_as_terminal).with(instance_norm);
10071202
}
10081203

1204+
/**
 * @brief Replace the matched Version_7 subgraph with a single instance norm node
 *
 * The node returned by create_inst_norm() receives a composite origin built from
 * the matched nodes and then takes the place of reshape_as_terminal.
 */
template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
{
  auto const pattern_graph = _p.reshape_as_terminal->graph();

  reshape_gamma_beta();

  auto fused = create_inst_norm(pattern_graph);

  // set origin: gather origins of the nodes which are being fused
  std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origins;
  origins.reserve(14);
  origins.push_back(luci::get_origin(_p.reshape_of_ifm));
  origins.push_back(luci::get_origin(_p.mean_of_ifm));
  origins.push_back(luci::get_origin(_p.sub_2));
  origins.push_back(luci::get_origin(_p.square));
  origins.push_back(luci::get_origin(_p.mean_as_variance));
  origins.push_back(luci::get_origin(_p.add_as_variance));
  origins.push_back(luci::get_origin(_p.rsqrt));
  origins.push_back(luci::get_origin(_p.mul_gamma));
  origins.push_back(luci::get_origin(_p.neg_mean));
  origins.push_back(luci::get_origin(_p.mul_as_scaled_ifm));
  origins.push_back(luci::get_origin(_p.mul_as_scaled_mean));
  origins.push_back(luci::get_origin(_p.add_neg_mul));
  origins.push_back(luci::get_origin(_p.add_as_terminal));
  origins.push_back(luci::get_origin(_p.reshape_as_terminal));

  luci::add_origin(fused, luci::composite_origin(origins));

  replace(_p.reshape_as_terminal).with(fused);
}
1233+
10091234
void FuseInstanceNorm::apply()
10101235
{
10111236
assert(_p.matched());
@@ -1030,6 +1255,9 @@ void FuseInstanceNorm::apply()
10301255
case InstanceNormPattern::PatternVersion::Version_6:
10311256
apply<InstanceNormPattern::PatternVersion::Version_6>();
10321257
break;
1258+
case InstanceNormPattern::PatternVersion::Version_7:
1259+
apply<InstanceNormPattern::PatternVersion::Version_7>();
1260+
break;
10331261

10341262
default:
10351263
break;
@@ -1256,6 +1484,21 @@ bool fuse_instance_norm(luci::CircleDiv *div)
12561484
return false;
12571485
}
12581486

1487+
/**
 * @brief Try to fuse a Version_7 pattern that ends at the given Reshape node
 *
 * @param reshape Candidate terminal Reshape node
 * @return true when the pattern matched and the fusion was applied, false otherwise
 */
bool fuse_instance_norm(luci::CircleReshape *reshape)
{
  InstanceNormPattern pattern(reshape, InstanceNormPattern::PatternVersion::Version_7);
  if (not pattern.matched())
    return false;

  FuseInstanceNorm fuse(pattern);
  fuse.apply();
  return true;
}
1501+
12591502
bool post_fusion(luci::CircleInstanceNorm *inst_norm)
12601503
{
12611504
PostFusion postfusion(inst_norm);
@@ -1294,6 +1537,17 @@ bool FuseInstanceNormPass::run(loco::Graph *g)
12941537
changed = true;
12951538
}
12961539

1540+
// Check Version_7(from Reshape) if other versions not found
1541+
for (auto node : loco::active_nodes(loco::output_nodes(g)))
1542+
{
1543+
auto reshape = dynamic_cast<luci::CircleReshape *>(node);
1544+
if (not reshape)
1545+
continue;
1546+
1547+
if (fuse_instance_norm(reshape))
1548+
changed = true;
1549+
}
1550+
12971551
// Post processing of FuseInstanceNorm
12981552
for (auto node : loco::active_nodes(loco::output_nodes(g)))
12991553
{

0 commit comments

Comments
 (0)