2525
2626#include < cassert>
2727#include < set>
28+ #include < optional>
2829
2930// Helper to check detail
3031
31- // / @return true When node has shape of '1 x .. x 1 x depth'
// Return false from the enclosing function when `condition` does not hold.
// Wrapped in do/while(false) so that the macro expands to exactly one statement
// and stays safe inside unbraced if/else. Call sites must end with ';'.
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (false)
35+
36+ // / @return true When node has shape with one dim other than `1` (like '1 x .. x 1 x depth' or '1 x .. x depth' x 1)
3237bool is_1D_with_dummy_dim (luci::CircleConst *node, uint32_t depth)
3338{
34- auto rank = node->rank ();
35- uint32_t axis ;
36- for (axis = 0 ; axis < rank - 1 ; ++axis)
39+ const auto rank = node->rank ();
40+ std::optional< uint32_t > depth_axis ;
41+ for (uint32_t axis = 0 ; axis < rank; ++axis)
3742 {
3843 if (node->dim (axis).value () != 1 )
39- return false ;
44+ {
45+ // only one axis can be other than 1
46+ if (depth_axis.has_value ())
47+ {
48+ return false ;
49+ }
50+ depth_axis = axis;
51+ }
52+ }
53+ if (!depth_axis.has_value ())
54+ {
55+ return false ;
4056 }
41- return node->dim (axis).value () == depth;
57+ return node->dim (depth_axis.value ()).value () == depth;
58+ }
59+
60+ // / @return true if the provided begin_reshape Reshape op adds `1` dimension
61+ // / and terminal_reshape Reshape op removes it (the result is neutral for further processing)
62+ bool is_unsqueeze_squeeze_pair (luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
63+ {
64+ auto const begin_ifm = dynamic_cast <luci::CircleNode *>(begin_reshape->tensor ());
65+ CHECK_OR_FALSE (begin_ifm);
66+ auto const begin_ofm = loco::must_cast<luci::CircleNode *>(begin_reshape);
67+ CHECK_OR_FALSE (begin_ofm);
68+
69+ // check last axis
70+ CHECK_OR_FALSE ((begin_ifm->rank () + 1 ) == begin_ofm->rank ());
71+
72+ // check unchanged part of begin_shape
73+ for (uint32_t axis=0 ;axis<begin_ifm->rank ();++axis)
74+ {
75+ // skip dynamic cases
76+ CHECK_OR_FALSE (begin_ifm->dim (axis).known () && begin_ofm->dim (axis).known ());
77+ CHECK_OR_FALSE (begin_ifm->dim (axis).value () == begin_ofm->dim (axis).value ());
78+ }
79+ // check last axis
80+ CHECK_OR_FALSE (begin_ofm->dim (begin_ofm->rank ()-1 ) == 1 );
81+
82+ auto const terminal_ifm = dynamic_cast <luci::CircleNode *>(terminal_reshape->tensor ());
83+ CHECK_OR_FALSE (terminal_ifm);
84+ auto const terminal_ofm = loco::must_cast<luci::CircleNode *>(terminal_reshape);
85+ CHECK_OR_FALSE (terminal_ofm);
86+
87+ CHECK_OR_FALSE (terminal_ifm->rank () == terminal_ofm->rank () + 1 );
88+
89+ // check last axis
90+ CHECK_OR_FALSE (terminal_ifm->dim (begin_ofm->rank ()-1 ) == 1 );
91+
92+ // check unchanged part of terminal_reshape
93+ for (uint32_t axis=0 ;axis<terminal_ofm->rank ();++axis)
94+ {
95+ // skip dynamic cases
96+ CHECK_OR_FALSE (terminal_ifm->dim (axis).known () && terminal_ofm->dim (axis).known ());
97+ CHECK_OR_FALSE (terminal_ifm->dim (axis).value () == terminal_ofm->dim (axis).value ());
98+ }
99+
100+ return true ;
42101}
43102
44103bool is_instance_mean_v1 (luci::CircleMean *mean)
@@ -370,6 +429,57 @@ namespace
370429 * |
371430 * V
372431 * [Out]
432+ *-------------------------------------------------------------------
433+ * Version_7
434+ * [In]
435+ * |
436+ * V
437+ * ifm
438+ * |
439+ * |
440+ * +-----------------------reshape_ifm -------------------------+
441+ * | | |
442+ * | (reduction indices) | |
443+ * | | | |
444+ * V V | |
445+ * mean_of_ifm | |
446+ * | V |
447+ * +----------------------> sub_2 |
448+ * | | |
449+ * | V |
450+ * | square (reduction indices) |
451+ * | | | |
452+ * | V | |
453+ * | mean_as_variance <----+ |
454+ * | | const_as_epsilon |
455+ * | | | |
456+ * | V | |
457+ * | add_as_variance <----------+ |
458+ * | | |
459+ * | | |
460+ * | V |
461+ * | rsqrt const_as_gamma |
462+ * | | | |
463+ * | | | |
464+ * neg_mean mul_gamma <------+ |
465+ * | | |
466+ * | | |
467+ * V | V
468+ * mul_as_scaled_mean <---------+-----------------------> mul_as_scaled_ifm
469+ * | |
470+ * | const_as_beta |
471+ * | | |
472+ * V V |
473+ * add_neg_mul |
474+ * | |
475+ * +---------------> add_as_terminal <------------------------+
476+ * |
477+ * |
478+ * V
479+ * reshape_as_terminal
480+ * |
481+ * V
482+ * [Out]
373483 */
374484class InstanceNormPattern final
375485{
@@ -383,6 +493,7 @@ class InstanceNormPattern final
383493 Version_4,
384494 Version_5,
385495 Version_6, // For only 3D I/O
496+ Version_7,
386497 };
387498
388499 InstanceNormPattern (luci::CircleAdd *candidate, PatternVersion pv)
@@ -399,6 +510,13 @@ class InstanceNormPattern final
399510 _pv = pv;
400511 }
401512
513+ InstanceNormPattern (luci::CircleReshape *candidate, PatternVersion pv)
514+ {
515+ assert (candidate);
516+ reshape_as_terminal = candidate;
517+ _pv = pv;
518+ }
519+
402520private:
403521 bool condition_common_1_5 (uint32_t ifm_channel_depth);
404522 bool condition_common_3_4 ();
@@ -424,6 +542,7 @@ class InstanceNormPattern final
424542 luci::CircleMean *mean_as_variance = nullptr ;
425543 luci::CircleConst *const_as_epsilon = nullptr ;
426544 luci::CircleAdd *add_as_variance = nullptr ;
545+ luci::CircleAdd *add_neg_mul = nullptr ;
427546 luci::CircleRsqrt *rsqrt = nullptr ;
428547 luci::CircleConst *const_as_gamma = nullptr ;
429548 luci::CircleMul *mul_gamma = nullptr ;
@@ -437,16 +556,15 @@ class InstanceNormPattern final
437556 luci::CirclePow *pow = nullptr ;
438557 luci::CircleSqrt *sqrt = nullptr ;
439558 luci::CircleDiv *div = nullptr ;
559+ luci::CircleConst *reshape_terminal_target_shape = nullptr ;
560+ luci::CircleReshape *reshape_as_terminal = nullptr ;
561+ luci::CircleNeg *neg_mean = nullptr ;
440562
441563private:
442564 bool _matched = false ;
443565 PatternVersion _pv;
444566};
445567
446- #define CHECK_OR_FALSE (condition ) \
447- if (not (condition)) \
448- return false ;
449-
450568bool InstanceNormPattern::condition_common_1_5 (uint32_t ifm_channel_depth)
451569{
452570 add_as_variance = dynamic_cast <luci::CircleAdd *>(rsqrt->x ());
@@ -751,6 +869,81 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
751869 return true ;
752870}
753871
872+ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
873+ {
874+ add_as_terminal = dynamic_cast <luci::CircleAdd *>(reshape_as_terminal->tensor ());
875+ CHECK_OR_FALSE (add_as_terminal);
876+
877+ CHECK_OR_FALSE (luci::fill (&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of (add_as_terminal));
878+ CHECK_OR_FALSE (luci::fill (&reshape_of_ifm, &mul_gamma).with_commutative_args_of (mul_as_scaled_ifm));
879+
880+ mul_as_scaled_mean = dynamic_cast <luci::CircleMul *>(add_neg_mul->x ());
881+ CHECK_OR_FALSE (mul_as_scaled_mean);
882+
883+ neg_mean = dynamic_cast <luci::CircleNeg *>(mul_as_scaled_mean->x ());
884+ CHECK_OR_FALSE (neg_mean);
885+
886+ luci::CircleMul *mul_gamma_should_be = nullptr ;
887+ luci::CircleNeg *neg_should_be = nullptr ;
888+
889+ CHECK_OR_FALSE (luci::fill (&mul_gamma_should_be, &neg_should_be)
890+ .with_commutative_args_of (mul_as_scaled_mean));
891+
892+
893+ CHECK_OR_FALSE (mul_gamma == mul_gamma_should_be);
894+ CHECK_OR_FALSE (neg_mean == neg_should_be);
895+
896+ mean_of_ifm = dynamic_cast <luci::CircleMean *>(neg_mean->x ());
897+ CHECK_OR_FALSE (mean_of_ifm);
898+
899+ luci::CircleReshape *reshape_of_ifm_should_be = nullptr ;
900+ reshape_of_ifm_should_be = dynamic_cast <luci::CircleReshape *>(mean_of_ifm->input ());
901+ CHECK_OR_FALSE (reshape_of_ifm_should_be == reshape_of_ifm);
902+
903+ ifm = reshape_of_ifm->tensor ();
904+ auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
905+ CHECK_OR_FALSE (ifm_circle);
906+
907+ CHECK_OR_FALSE (ifm_circle->shape_status () == luci::ShapeStatus::VALID);
908+ CHECK_OR_FALSE (ifm_circle->rank () == 4 );
909+ CHECK_OR_FALSE (ifm_circle->dim (3 ).known ());
910+ uint32_t ifm_channel_depth = ifm_circle->dim (3 ).value ();
911+
912+ const_as_beta = dynamic_cast <luci::CircleConst *>(add_neg_mul->y ());
913+ CHECK_OR_FALSE (const_as_beta);
914+ CHECK_OR_FALSE (is_1D_with_dummy_dim (const_as_beta, ifm_channel_depth));
915+
916+ CHECK_OR_FALSE (luci::fill (&rsqrt, &const_as_gamma).with_commutative_args_of (mul_gamma));
917+ CHECK_OR_FALSE (is_1D_with_dummy_dim (const_as_gamma, ifm_channel_depth));
918+
919+ add_as_variance = dynamic_cast <luci::CircleAdd *>(rsqrt->x ());
920+ CHECK_OR_FALSE (add_as_variance);
921+
922+ CHECK_OR_FALSE (luci::fill (&mean_as_variance, &const_as_epsilon).with_commutative_args_of (add_as_variance));
923+ CHECK_OR_FALSE (mean_as_variance);
924+
925+ CHECK_OR_FALSE (const_as_epsilon->dtype () == loco::DataType::FLOAT32);
926+ // TODO Support regarding broadcast
927+ CHECK_OR_FALSE (const_as_epsilon->size <loco::DataType::FLOAT32>() == 1 );
928+
929+ square = dynamic_cast <luci::CircleSquare *>(mean_as_variance->input ());
930+ CHECK_OR_FALSE (square);
931+
932+ sub_2 = dynamic_cast <luci::CircleSub *>(square->x ());
933+ CHECK_OR_FALSE (sub_2);
934+
935+ auto mean_of_ifm_should_be = dynamic_cast <luci::CircleMean *>(sub_2->y ());
936+ CHECK_OR_FALSE (mean_of_ifm == mean_of_ifm_should_be);
937+
938+ auto reshape_of_ifm_should_be_2 = dynamic_cast <luci::CircleReshape *>(sub_2->x ());
939+ CHECK_OR_FALSE (reshape_of_ifm_should_be_2 == reshape_of_ifm);
940+
941+ CHECK_OR_FALSE (is_unsqueeze_squeeze_pair (reshape_of_ifm, reshape_as_terminal));
942+
943+ _matched = true ;
944+ return true ;
945+ }
946+
754947bool InstanceNormPattern::matched ()
755948{
756949 if (_matched)
@@ -772,6 +965,8 @@ bool InstanceNormPattern::matched()
772965 return match<PatternVersion::Version_5>();
773966 case PatternVersion::Version_6:
774967 return match<PatternVersion::Version_6>();
968+ case PatternVersion::Version_7:
969+ return match<PatternVersion::Version_7>();
775970
776971 default :
777972 break ;
@@ -1006,6 +1201,36 @@ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Ve
10061201 replace (_p.add_as_terminal ).with (instance_norm);
10071202}
10081203
1204+ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1205+ {
1206+ auto graph = _p.reshape_as_terminal ->graph ();
1207+
1208+ reshape_gamma_beta ();
1209+
1210+ auto instance_norm = create_inst_norm (graph);
1211+
1212+ // set origin
1213+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1214+ luci::get_origin (_p.reshape_of_ifm ),
1215+ luci::get_origin (_p.mean_of_ifm ),
1216+ luci::get_origin (_p.sub_2 ),
1217+ luci::get_origin (_p.square ),
1218+ luci::get_origin (_p.mean_as_variance ),
1219+ luci::get_origin (_p.add_as_variance ),
1220+ luci::get_origin (_p.rsqrt ),
1221+ luci::get_origin (_p.mul_gamma ),
1222+ luci::get_origin (_p.neg_mean ),
1223+ luci::get_origin (_p.mul_as_scaled_ifm ),
1224+ luci::get_origin (_p.mul_as_scaled_mean ),
1225+ luci::get_origin (_p.add_neg_mul ),
1226+ luci::get_origin (_p.add_as_terminal ),
1227+ luci::get_origin (_p.reshape_as_terminal )};
1228+
1229+ luci::add_origin (instance_norm, luci::composite_origin (origin_vec));
1230+
1231+ replace (_p.reshape_as_terminal ).with (instance_norm);
1232+ }
1233+
10091234void FuseInstanceNorm::apply ()
10101235{
10111236 assert (_p.matched ());
@@ -1030,6 +1255,9 @@ void FuseInstanceNorm::apply()
10301255 case InstanceNormPattern::PatternVersion::Version_6:
10311256 apply<InstanceNormPattern::PatternVersion::Version_6>();
10321257 break ;
1258+ case InstanceNormPattern::PatternVersion::Version_7:
1259+ apply<InstanceNormPattern::PatternVersion::Version_7>();
1260+ break ;
10331261
10341262 default :
10351263 break ;
@@ -1256,6 +1484,21 @@ bool fuse_instance_norm(luci::CircleDiv *div)
12561484 return false ;
12571485}
12581486
1487+ bool fuse_instance_norm (luci::CircleReshape *reshape)
1488+ {
1489+ InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1490+
1491+ InstanceNormPattern pattern (reshape, pv);
1492+ if (pattern.matched ())
1493+ {
1494+ FuseInstanceNorm fuse (pattern);
1495+ fuse.apply ();
1496+ return true ;
1497+ }
1498+
1499+ return false ;
1500+ }
1501+
12591502bool post_fusion (luci::CircleInstanceNorm *inst_norm)
12601503{
12611504 PostFusion postfusion (inst_norm);
@@ -1294,6 +1537,17 @@ bool FuseInstanceNormPass::run(loco::Graph *g)
12941537 changed = true ;
12951538 }
12961539
1540+ // Check Version_7(from Reshape) if other versions not found
1541+ for (auto node : loco::active_nodes (loco::output_nodes (g)))
1542+ {
1543+ auto reshape = dynamic_cast <luci::CircleReshape *>(node);
1544+ if (not reshape)
1545+ continue ;
1546+
1547+ if (fuse_instance_norm (reshape))
1548+ changed = true ;
1549+ }
1550+
12971551 // Post processing of FuseInstanceNorm
12981552 for (auto node : loco::active_nodes (loco::output_nodes (g)))
12991553 {
0 commit comments