Skip to content

Commit bae09af

Browse files
authored
[luci/pass] Extend FuseInstanceNormPass with keras support (#16357)
This commit adds a new pattern of InstanceNorm fusing. It's based on the tf.keras.layers.GroupNormalization(groups=-1) layer from keras library in version 2.18.1. Signed-off-by: Mateusz Bencer <m.bencer@partner.samsung.com>
1 parent 0f04853 commit bae09af

File tree

1 file changed

+268
-13
lines changed

1 file changed

+268
-13
lines changed

compiler/luci/pass/src/FuseInstanceNormPass.cpp

Lines changed: 268 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,78 @@
2525

2626
#include <cassert>
2727
#include <set>
28+
#include <optional>
2829

2930
// Helper to check detail
3031

31-
/// @return true When node has shape of '1 x .. x 1 x depth'
32-
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
32+
/// Make the enclosing function return false when `condition` does not hold.
/// The body is wrapped in do { } while (0) so that `CHECK_OR_FALSE(x);` expands to exactly one
/// statement and stays safe inside an unbraced if/else.
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (0)
35+
36+
/// @return true When node has shape with one dim other than `1` (like '1 x .. x 1 x depth' or '1
37+
/// x .. x depth' x 1)
38+
bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
3339
{
34-
auto rank = node->rank();
35-
uint32_t axis;
36-
for (axis = 0; axis < rank - 1; ++axis)
40+
const auto rank = node->rank();
41+
std::optional<uint32_t> depth_axis;
42+
for (uint32_t axis = 0; axis < rank; ++axis)
3743
{
3844
if (node->dim(axis).value() != 1)
39-
return false;
45+
{
46+
// only one axis can be other than 1
47+
if (depth_axis.has_value())
48+
{
49+
return false;
50+
}
51+
depth_axis = axis;
52+
}
53+
}
54+
if (!depth_axis.has_value())
55+
{
56+
return false;
57+
}
58+
return node->dim(depth_axis.value()).value() == depth;
59+
}
60+
61+
/// @return true if the provided begin_reshape Reshape op adds `1` dimension
62+
/// and terminal_reshape Reshape op removes it (the result is neutral for further
63+
/// processing)
64+
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape,
65+
luci::CircleReshape *terminal_reshape)
66+
{
67+
auto const begin_reshape_ifm = dynamic_cast<luci::CircleNode *>(begin_reshape->tensor());
68+
CHECK_OR_FALSE(begin_reshape_ifm);
69+
70+
// check last axis
71+
CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
72+
73+
// check unchanged part of begin_shape
74+
for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
75+
{
76+
// skip dynamic cases
77+
CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
78+
CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
79+
}
80+
// check last axis
81+
CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
82+
83+
auto const terminal_reshape_ifm = dynamic_cast<luci::CircleNode *>(terminal_reshape->tensor());
84+
CHECK_OR_FALSE(terminal_reshape_ifm);
85+
86+
CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
87+
88+
// check last axis
89+
CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
90+
91+
// check unchanged part of terminal_reshape
92+
for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
93+
{
94+
// skip dynamic cases
95+
CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
96+
CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
4097
}
41-
return node->dim(axis).value() == depth;
98+
99+
return true;
42100
}
43101

44102
bool is_instance_mean_v1(luci::CircleMean *mean)
@@ -370,6 +428,57 @@ namespace
370428
* |
371429
* V
372430
* [Out]
431+
*-------------------------------------------------------------------
432+
* Version_7
433+
* [In]
434+
* |
435+
* V
436+
* ifm
437+
* |
438+
* |
439+
* +-----------------------reshape_ifm -------------------------+
440+
* | | |
441+
 * | (reduction indices) | |
442+
* | | | |
443+
* V V | |
444+
* mean_of_ifm | |
445+
* | V |
446+
* +----------------------> sub_2 |
447+
* | | |
448+
* | V |
449+
 * | square (reduction indices) |
450+
* | | | |
451+
* | V | |
452+
* | mean_as_variance <----+ |
453+
* | | const_as_epsilon |
454+
* | | | |
455+
* | V | |
456+
* | add_as_variance <----------+ |
457+
* | | |
458+
* | | |
459+
* | V |
460+
* | rsqrt const_as_gamma |
461+
* | | | |
462+
* | | | |
463+
* neg_mean mul_gamma <------+ |
464+
* | | |
465+
* | | |
466+
* V | V
467+
* mul_as_scaled_mean <---------+-----------------------> mul_as_scaled_ifm
468+
* | |
469+
* | const_as_beta |
470+
* | | |
471+
* V V |
472+
* add_neg_mul |
473+
* | |
474+
* +---------------> add_as_terminal <------------------------+
475+
* |
476+
* |
477+
* V
478+
* reshape_as_terminal
479+
* |
480+
* V
481+
* [Out]
373482
*/
374483
class InstanceNormPattern final
375484
{
@@ -383,6 +492,7 @@ class InstanceNormPattern final
383492
Version_4,
384493
Version_5,
385494
Version_6, // For only 3D I/O
495+
Version_7,
386496
};
387497

388498
InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
@@ -399,6 +509,13 @@ class InstanceNormPattern final
399509
_pv = pv;
400510
}
401511

512+
/// Build a matcher anchored at a Reshape node; the node is stored as reshape_as_terminal and
/// `pv` selects which pattern version matched() will try.
InstanceNormPattern(luci::CircleReshape *candidate, PatternVersion pv)
  : reshape_as_terminal(candidate), _pv(pv)
{
  // a pattern cannot be anchored at a null node
  assert(candidate);
}
518+
402519
private:
403520
bool condition_common_1_5(uint32_t ifm_channel_depth);
404521
bool condition_common_3_4();
@@ -424,6 +541,7 @@ class InstanceNormPattern final
424541
luci::CircleMean *mean_as_variance = nullptr;
425542
luci::CircleConst *const_as_epsilon = nullptr;
426543
luci::CircleAdd *add_as_variance = nullptr;
544+
luci::CircleAdd *add_neg_mul = nullptr;
427545
luci::CircleRsqrt *rsqrt = nullptr;
428546
luci::CircleConst *const_as_gamma = nullptr;
429547
luci::CircleMul *mul_gamma = nullptr;
@@ -437,16 +555,15 @@ class InstanceNormPattern final
437555
luci::CirclePow *pow = nullptr;
438556
luci::CircleSqrt *sqrt = nullptr;
439557
luci::CircleDiv *div = nullptr;
558+
luci::CircleConst *reshape_terminal_target_shape = nullptr;
559+
luci::CircleReshape *reshape_as_terminal = nullptr;
560+
luci::CircleNeg *neg_mean = nullptr;
440561

441562
private:
442563
bool _matched = false;
443564
PatternVersion _pv;
444565
};
445566

446-
#define CHECK_OR_FALSE(condition) \
447-
if (not(condition)) \
448-
return false;
449-
450567
bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
451568
{
452569
add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
@@ -472,7 +589,7 @@ bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
472589

473590
const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
474591
CHECK_OR_FALSE(const_as_beta);
475-
CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
592+
CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));
476593

477594
return true;
478595
}
@@ -534,7 +651,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
534651

535652
CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
536653

537-
CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
654+
CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));
538655

539656
CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
540657

@@ -751,6 +868,83 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
751868
return true;
752869
}
753870

871+
/// @brief Try to match the Version_7 pattern, walking backwards from reshape_as_terminal
///
/// On success every member node used by the pattern is filled and `_matched` is set.
/// On failure members may be left partially filled.
///
/// NOTE(review): this matcher does not inspect the reduction indices or keep_dims of
/// mean_of_ifm / mean_as_variance - confirm whether that has to be verified before fusing.
///
/// @return true when the sub-graph ending at reshape_as_terminal has the Version_7 structure
template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
{
  add_as_terminal = dynamic_cast<luci::CircleAdd *>(reshape_as_terminal->tensor());
  CHECK_OR_FALSE(add_as_terminal);

  CHECK_OR_FALSE(
    luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
  CHECK_OR_FALSE(
    luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));

  // add_neg_mul = mul_as_scaled_mean + const_as_beta; Add is commutative, accept both orders
  CHECK_OR_FALSE(
    luci::fill(&mul_as_scaled_mean, &const_as_beta).with_commutative_args_of(add_neg_mul));

  // mul_as_scaled_mean = neg_mean * mul_gamma; Mul is commutative, accept both orders,
  // but the Mul operand has to be the very same mul_gamma that scales the input
  luci::CircleMul *mul_gamma_should_be = nullptr;
  CHECK_OR_FALSE(
    luci::fill(&mul_gamma_should_be, &neg_mean).with_commutative_args_of(mul_as_scaled_mean));
  CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);

  mean_of_ifm = dynamic_cast<luci::CircleMean *>(neg_mean->x());
  CHECK_OR_FALSE(mean_of_ifm);

  // the mean has to be computed from the same reshaped input
  auto const reshape_of_ifm_should_be = dynamic_cast<luci::CircleReshape *>(mean_of_ifm->input());
  CHECK_OR_FALSE(reshape_of_ifm_should_be == reshape_of_ifm);

  ifm = reshape_of_ifm->tensor();
  // dynamic_cast (not must_cast) so that a mismatch rejects the pattern via the check below
  auto const ifm_circle = dynamic_cast<luci::CircleNode *>(ifm);
  CHECK_OR_FALSE(ifm_circle);

  CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
  CHECK_OR_FALSE(ifm_circle->rank() == 4);
  CHECK_OR_FALSE(ifm_circle->dim(3).known());
  const uint32_t ifm_channel_depth = ifm_circle->dim(3).value();

  CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));

  CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
  CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));

  add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
  CHECK_OR_FALSE(add_as_variance);

  CHECK_OR_FALSE(
    luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));

  CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
  // TODO Support regarding broadcast
  CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);

  square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
  CHECK_OR_FALSE(square);

  sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
  CHECK_OR_FALSE(sub_2);

  // sub_2 = reshape_of_ifm - mean_of_ifm; Sub is not commutative, so operand order is fixed
  auto const mean_of_ifm_should_be = dynamic_cast<luci::CircleMean *>(sub_2->y());
  CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);

  auto const reshape_of_ifm_should_be_2 = dynamic_cast<luci::CircleReshape *>(sub_2->x());
  CHECK_OR_FALSE(reshape_of_ifm_should_be_2 == reshape_of_ifm);

  // leading and terminal Reshape must cancel each other out
  CHECK_OR_FALSE(is_unsqueeze_squeeze_pair(reshape_of_ifm, reshape_as_terminal));

  _matched = true;
  return true;
}
947+
754948
bool InstanceNormPattern::matched()
755949
{
756950
if (_matched)
@@ -772,6 +966,8 @@ bool InstanceNormPattern::matched()
772966
return match<PatternVersion::Version_5>();
773967
case PatternVersion::Version_6:
774968
return match<PatternVersion::Version_6>();
969+
case PatternVersion::Version_7:
970+
return match<PatternVersion::Version_7>();
775971

776972
default:
777973
break;
@@ -1006,6 +1202,36 @@ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Ve
10061202
replace(_p.add_as_terminal).with(instance_norm);
10071203
}
10081204

1205+
/// @brief Replace a matched Version_7 sub-graph with a single InstanceNorm node
///
/// The new node takes over the place of reshape_as_terminal and records the origins of all
/// nodes that the pattern consumed.
template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
{
  auto owner_graph = _p.reshape_as_terminal->graph();

  reshape_gamma_beta();

  auto fused = create_inst_norm(owner_graph);

  // set origin: one entry per node covered by the pattern
  using OriginPtr = std::shared_ptr<luci::CircleNodeOrigin>;
  std::vector<OriginPtr> origins{
    luci::get_origin(_p.reshape_of_ifm),     luci::get_origin(_p.mean_of_ifm),
    luci::get_origin(_p.sub_2),              luci::get_origin(_p.square),
    luci::get_origin(_p.mean_as_variance),   luci::get_origin(_p.add_as_variance),
    luci::get_origin(_p.rsqrt),              luci::get_origin(_p.mul_gamma),
    luci::get_origin(_p.neg_mean),           luci::get_origin(_p.mul_as_scaled_ifm),
    luci::get_origin(_p.mul_as_scaled_mean), luci::get_origin(_p.add_neg_mul),
    luci::get_origin(_p.add_as_terminal),    luci::get_origin(_p.reshape_as_terminal)};

  luci::add_origin(fused, luci::composite_origin(origins));

  replace(_p.reshape_as_terminal).with(fused);
}
1234+
10091235
void FuseInstanceNorm::apply()
10101236
{
10111237
assert(_p.matched());
@@ -1030,6 +1256,9 @@ void FuseInstanceNorm::apply()
10301256
case InstanceNormPattern::PatternVersion::Version_6:
10311257
apply<InstanceNormPattern::PatternVersion::Version_6>();
10321258
break;
1259+
case InstanceNormPattern::PatternVersion::Version_7:
1260+
apply<InstanceNormPattern::PatternVersion::Version_7>();
1261+
break;
10331262

10341263
default:
10351264
break;
@@ -1256,6 +1485,21 @@ bool fuse_instance_norm(luci::CircleDiv *div)
12561485
return false;
12571486
}
12581487

1488+
/// @brief Try to fuse the Version_7 InstanceNorm pattern that ends at `reshape`
/// @return true when the pattern matched and the fusion was applied, false otherwise
bool fuse_instance_norm(luci::CircleReshape *reshape)
{
  InstanceNormPattern pattern(reshape, InstanceNormPattern::PatternVersion::Version_7);

  // nothing to do unless the whole pattern is present
  if (not pattern.matched())
    return false;

  FuseInstanceNorm fuse(pattern);
  fuse.apply();
  return true;
}
1502+
12591503
bool post_fusion(luci::CircleInstanceNorm *inst_norm)
12601504
{
12611505
PostFusion postfusion(inst_norm);
@@ -1294,6 +1538,17 @@ bool FuseInstanceNormPass::run(loco::Graph *g)
12941538
changed = true;
12951539
}
12961540

1541+
// Check Version_7(from Reshape) if other versions not found
1542+
for (auto node : loco::active_nodes(loco::output_nodes(g)))
1543+
{
1544+
auto reshape = dynamic_cast<luci::CircleReshape *>(node);
1545+
if (not reshape)
1546+
continue;
1547+
1548+
if (fuse_instance_norm(reshape))
1549+
changed = true;
1550+
}
1551+
12971552
// Post processing of FuseInstanceNorm
12981553
for (auto node : loco::active_nodes(loco::output_nodes(g)))
12991554
{

0 commit comments

Comments
 (0)