2525
2626#include < cassert>
2727#include < set>
28+ #include < optional>
2829
2930// Helper to check detail
3031
31- // / @return true When node has shape of '1 x .. x 1 x depth'
32- bool is_1D_with_dummy_dim (luci::CircleConst *node, uint32_t depth)
/// Return `false` from the enclosing function when `condition` does not hold.
///
/// The body is wrapped in `do { } while (false)` so that `CHECK_OR_FALSE(x);`
/// expands to exactly one statement. This keeps the macro safe inside an
/// unbraced `if`/`else` and avoids a stray empty statement after the call.
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (false)
35+
36+ // / @return true When node has shape with one dim other than `1` (like '1 x .. x 1 x depth' or '1
37+ // / x .. x depth' x 1)
38+ bool is_unsqueezed_1D (luci::CircleConst *node, uint32_t depth)
3339{
34- auto rank = node->rank ();
35- uint32_t axis ;
36- for (axis = 0 ; axis < rank - 1 ; ++axis)
40+ const auto rank = node->rank ();
41+ std::optional< uint32_t > depth_axis ;
42+ for (uint32_t axis = 0 ; axis < rank; ++axis)
3743 {
3844 if (node->dim (axis).value () != 1 )
39- return false ;
45+ {
46+ // only one axis can be other than 1
47+ if (depth_axis.has_value ())
48+ {
49+ return false ;
50+ }
51+ depth_axis = axis;
52+ }
53+ }
54+ if (!depth_axis.has_value ())
55+ {
56+ return false ;
57+ }
58+ return node->dim (depth_axis.value ()).value () == depth;
59+ }
60+
61+ // / @return true if the provided begin_reshape Reshape op adds `1` dimension
62+ // / and terminal_reshape Reshape op removes it (the result is neutral for further
63+ // / processing)
64+ bool is_unsqueeze_squeeze_pair (luci::CircleReshape *begin_reshape,
65+ luci::CircleReshape *terminal_reshape)
66+ {
67+ auto const begin_reshape_ifm = dynamic_cast <luci::CircleNode *>(begin_reshape->tensor ());
68+ CHECK_OR_FALSE (begin_reshape_ifm);
69+
70+ // check last axis
71+ CHECK_OR_FALSE ((begin_reshape_ifm->rank () + 1 ) == begin_reshape->rank ());
72+
73+ // check unchanged part of begin_shape
74+ for (uint32_t axis = 0 ; axis < begin_reshape_ifm->rank (); ++axis)
75+ {
76+ // skip dynamic cases
77+ CHECK_OR_FALSE (begin_reshape_ifm->dim (axis).known () && begin_reshape->dim (axis).known ());
78+ CHECK_OR_FALSE (begin_reshape_ifm->dim (axis).value () == begin_reshape->dim (axis).value ());
79+ }
80+ // check last axis
81+ CHECK_OR_FALSE (begin_reshape->dim (begin_reshape->rank () - 1 ) == 1 );
82+
83+ auto const terminal_reshape_ifm = dynamic_cast <luci::CircleNode *>(terminal_reshape->tensor ());
84+ CHECK_OR_FALSE (terminal_reshape_ifm);
85+
86+ CHECK_OR_FALSE (terminal_reshape_ifm->rank () == terminal_reshape->rank () + 1 );
87+
88+ // check last axis
89+ CHECK_OR_FALSE (terminal_reshape_ifm->dim (begin_reshape->rank () - 1 ) == 1 );
90+
91+ // check unchanged part of terminal_reshape
92+ for (uint32_t axis = 0 ; axis < terminal_reshape->rank (); ++axis)
93+ {
94+ // skip dynamic cases
95+ CHECK_OR_FALSE (terminal_reshape_ifm->dim (axis).known () && terminal_reshape->dim (axis).known ());
96+ CHECK_OR_FALSE (terminal_reshape_ifm->dim (axis).value () == terminal_reshape->dim (axis).value ());
4097 }
41- return node->dim (axis).value () == depth;
98+
99+ return true ;
42100}
43101
44102bool is_instance_mean_v1 (luci::CircleMean *mean)
@@ -370,6 +428,57 @@ namespace
370428 * |
371429 * V
372430 * [Out]
431+ *-------------------------------------------------------------------
432+ * Version_7
433+ * [In]
434+ * |
435+ * V
436+ * ifm
437+ * |
438+ * |
439+ * +-----------------------reshape_ifm -------------------------+
440+ * | | |
441+ * | (reduction indices) | |
442+ * | | | |
443+ * V V | |
444+ * mean_of_ifm | |
445+ * | V |
446+ * +----------------------> sub_2 |
447+ * | | |
448+ * | V |
449+ * | square (reduction indices) |
450+ * | | | |
451+ * | V | |
452+ * | mean_as_variance <----+ |
453+ * | | const_as_epsilon |
454+ * | | | |
455+ * | V | |
456+ * | add_as_variance <----------+ |
457+ * | | |
458+ * | | |
459+ * | V |
460+ * | rsqrt const_as_gamma |
461+ * | | | |
462+ * | | | |
463+ * neg_mean mul_gamma <------+ |
464+ * | | |
465+ * | | |
466+ * V | V
467+ * mul_as_scaled_mean <---------+-----------------------> mul_as_scaled_ifm
468+ * | |
469+ * | const_as_beta |
470+ * | | |
471+ * V V |
472+ * add_neg_mul |
473+ * | |
474+ * +---------------> add_as_terminal <------------------------+
475+ * |
476+ * |
477+ * V
478+ * reshape_as_terminal
479+ * |
480+ * V
481+ * [Out]
373482 */
374483class InstanceNormPattern final
375484{
@@ -383,6 +492,7 @@ class InstanceNormPattern final
383492 Version_4,
384493 Version_5,
385494 Version_6, // For only 3D I/O
495+ Version_7,
386496 };
387497
388498 InstanceNormPattern (luci::CircleAdd *candidate, PatternVersion pv)
@@ -399,6 +509,13 @@ class InstanceNormPattern final
399509 _pv = pv;
400510 }
401511
512+ InstanceNormPattern (luci::CircleReshape *candidate, PatternVersion pv)
513+ {
514+ assert (candidate);
515+ reshape_as_terminal = candidate;
516+ _pv = pv;
517+ }
518+
402519private:
403520 bool condition_common_1_5 (uint32_t ifm_channel_depth);
404521 bool condition_common_3_4 ();
@@ -424,6 +541,7 @@ class InstanceNormPattern final
424541 luci::CircleMean *mean_as_variance = nullptr ;
425542 luci::CircleConst *const_as_epsilon = nullptr ;
426543 luci::CircleAdd *add_as_variance = nullptr ;
544+ luci::CircleAdd *add_neg_mul = nullptr ;
427545 luci::CircleRsqrt *rsqrt = nullptr ;
428546 luci::CircleConst *const_as_gamma = nullptr ;
429547 luci::CircleMul *mul_gamma = nullptr ;
@@ -437,16 +555,15 @@ class InstanceNormPattern final
437555 luci::CirclePow *pow = nullptr ;
438556 luci::CircleSqrt *sqrt = nullptr ;
439557 luci::CircleDiv *div = nullptr ;
558+ luci::CircleConst *reshape_terminal_target_shape = nullptr ;
559+ luci::CircleReshape *reshape_as_terminal = nullptr ;
560+ luci::CircleNeg *neg_mean = nullptr ;
440561
441562private:
442563 bool _matched = false ;
443564 PatternVersion _pv;
444565};
445566
446- #define CHECK_OR_FALSE (condition ) \
447- if (not (condition)) \
448- return false ;
449-
450567bool InstanceNormPattern::condition_common_1_5 (uint32_t ifm_channel_depth)
451568{
452569 add_as_variance = dynamic_cast <luci::CircleAdd *>(rsqrt->x ());
@@ -472,7 +589,7 @@ bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
472589
473590 const_as_beta = dynamic_cast <luci::CircleConst *>(sub->x ());
474591 CHECK_OR_FALSE (const_as_beta);
475- CHECK_OR_FALSE (is_1D_with_dummy_dim (const_as_beta, ifm_channel_depth));
592+ CHECK_OR_FALSE (is_unsqueezed_1D (const_as_beta, ifm_channel_depth));
476593
477594 return true ;
478595}
@@ -534,7 +651,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
534651
535652 CHECK_OR_FALSE (luci::fill (&rsqrt, &const_as_gamma).with_commutative_args_of (mul_gamma));
536653
537- CHECK_OR_FALSE (is_1D_with_dummy_dim (const_as_gamma, ifm_channel_depth));
654+ CHECK_OR_FALSE (is_unsqueezed_1D (const_as_gamma, ifm_channel_depth));
538655
539656 CHECK_OR_FALSE (condition_common_1_5 (ifm_channel_depth));
540657
@@ -751,6 +868,83 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
751868 return true ;
752869}
753870
871+ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
872+ {
873+ add_as_terminal = dynamic_cast <luci::CircleAdd *>(reshape_as_terminal->tensor ());
874+ CHECK_OR_FALSE (add_as_terminal);
875+
876+ CHECK_OR_FALSE (
877+ luci::fill (&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of (add_as_terminal));
878+ CHECK_OR_FALSE (
879+ luci::fill (&reshape_of_ifm, &mul_gamma).with_commutative_args_of (mul_as_scaled_ifm));
880+
881+ mul_as_scaled_mean = dynamic_cast <luci::CircleMul *>(add_neg_mul->x ());
882+ CHECK_OR_FALSE (mul_as_scaled_mean);
883+
884+ neg_mean = dynamic_cast <luci::CircleNeg *>(mul_as_scaled_mean->x ());
885+ CHECK_OR_FALSE (neg_mean);
886+
887+ luci::CircleMul *mul_gamma_should_be = nullptr ;
888+ luci::CircleNeg *neg_should_be = nullptr ;
889+
890+ CHECK_OR_FALSE (
891+ luci::fill (&mul_gamma_should_be, &neg_should_be).with_commutative_args_of (mul_as_scaled_mean));
892+
893+ CHECK_OR_FALSE (mul_gamma == mul_gamma_should_be);
894+ CHECK_OR_FALSE (neg_mean == neg_should_be);
895+
896+ mean_of_ifm = dynamic_cast <luci::CircleMean *>(neg_mean->x ());
897+ CHECK_OR_FALSE (mean_of_ifm);
898+
899+ luci::CircleReshape *reshape_of_ifm_should_be = nullptr ;
900+ reshape_of_ifm_should_be = dynamic_cast <luci::CircleReshape *>(mean_of_ifm->input ());
901+ CHECK_OR_FALSE (reshape_of_ifm_should_be == reshape_of_ifm);
902+
903+ ifm = reshape_of_ifm->tensor ();
904+ auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
905+ CHECK_OR_FALSE (ifm_circle);
906+
907+ CHECK_OR_FALSE (ifm_circle->shape_status () == luci::ShapeStatus::VALID);
908+ CHECK_OR_FALSE (ifm_circle->rank () == 4 );
909+ CHECK_OR_FALSE (ifm_circle->dim (3 ).known ());
910+ uint32_t ifm_channel_depth = ifm_circle->dim (3 ).value ();
911+
912+ const_as_beta = dynamic_cast <luci::CircleConst *>(add_neg_mul->y ());
913+ CHECK_OR_FALSE (const_as_beta);
914+ CHECK_OR_FALSE (is_unsqueezed_1D (const_as_beta, ifm_channel_depth));
915+
916+ CHECK_OR_FALSE (luci::fill (&rsqrt, &const_as_gamma).with_commutative_args_of (mul_gamma));
917+ CHECK_OR_FALSE (is_unsqueezed_1D (const_as_gamma, ifm_channel_depth));
918+
919+ add_as_variance = dynamic_cast <luci::CircleAdd *>(rsqrt->x ());
920+ CHECK_OR_FALSE (add_as_variance);
921+
922+ CHECK_OR_FALSE (
923+ luci::fill (&mean_as_variance, &const_as_epsilon).with_commutative_args_of (add_as_variance));
924+ CHECK_OR_FALSE (mean_as_variance);
925+
926+ CHECK_OR_FALSE (const_as_epsilon->dtype () == loco::DataType::FLOAT32);
927+ // TODO Support regarding broadcast
928+ CHECK_OR_FALSE (const_as_epsilon->size <loco::DataType::FLOAT32>() == 1 );
929+
930+ square = dynamic_cast <luci::CircleSquare *>(mean_as_variance->input ());
931+ CHECK_OR_FALSE (square);
932+
933+ sub_2 = dynamic_cast <luci::CircleSub *>(square->x ());
934+ CHECK_OR_FALSE (sub_2);
935+
936+ auto mean_of_ifm_should_be = dynamic_cast <luci::CircleMean *>(sub_2->y ());
937+ CHECK_OR_FALSE (mean_of_ifm == mean_of_ifm_should_be);
938+
939+ auto reshape_of_ifm_should_be_2 = dynamic_cast <luci::CircleReshape *>(sub_2->x ());
940+ CHECK_OR_FALSE (reshape_of_ifm_should_be_2 == reshape_of_ifm);
941+
942+ CHECK_OR_FALSE (is_unsqueeze_squeeze_pair (reshape_of_ifm, reshape_as_terminal));
943+
944+ _matched = true ;
945+ return true ;
946+ }
947+
754948bool InstanceNormPattern::matched ()
755949{
756950 if (_matched)
@@ -772,6 +966,8 @@ bool InstanceNormPattern::matched()
772966 return match<PatternVersion::Version_5>();
773967 case PatternVersion::Version_6:
774968 return match<PatternVersion::Version_6>();
969+ case PatternVersion::Version_7:
970+ return match<PatternVersion::Version_7>();
775971
776972 default :
777973 break ;
@@ -1006,6 +1202,36 @@ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Ve
10061202 replace (_p.add_as_terminal ).with (instance_norm);
10071203}
10081204
1205+ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1206+ {
1207+ auto graph = _p.reshape_as_terminal ->graph ();
1208+
1209+ reshape_gamma_beta ();
1210+
1211+ auto instance_norm = create_inst_norm (graph);
1212+
1213+ // set origin
1214+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1215+ luci::get_origin (_p.reshape_of_ifm ),
1216+ luci::get_origin (_p.mean_of_ifm ),
1217+ luci::get_origin (_p.sub_2 ),
1218+ luci::get_origin (_p.square ),
1219+ luci::get_origin (_p.mean_as_variance ),
1220+ luci::get_origin (_p.add_as_variance ),
1221+ luci::get_origin (_p.rsqrt ),
1222+ luci::get_origin (_p.mul_gamma ),
1223+ luci::get_origin (_p.neg_mean ),
1224+ luci::get_origin (_p.mul_as_scaled_ifm ),
1225+ luci::get_origin (_p.mul_as_scaled_mean ),
1226+ luci::get_origin (_p.add_neg_mul ),
1227+ luci::get_origin (_p.add_as_terminal ),
1228+ luci::get_origin (_p.reshape_as_terminal )};
1229+
1230+ luci::add_origin (instance_norm, luci::composite_origin (origin_vec));
1231+
1232+ replace (_p.reshape_as_terminal ).with (instance_norm);
1233+ }
1234+
10091235void FuseInstanceNorm::apply ()
10101236{
10111237 assert (_p.matched ());
@@ -1030,6 +1256,9 @@ void FuseInstanceNorm::apply()
10301256 case InstanceNormPattern::PatternVersion::Version_6:
10311257 apply<InstanceNormPattern::PatternVersion::Version_6>();
10321258 break ;
1259+ case InstanceNormPattern::PatternVersion::Version_7:
1260+ apply<InstanceNormPattern::PatternVersion::Version_7>();
1261+ break ;
10331262
10341263 default :
10351264 break ;
@@ -1256,6 +1485,21 @@ bool fuse_instance_norm(luci::CircleDiv *div)
12561485 return false ;
12571486}
12581487
1488+ bool fuse_instance_norm (luci::CircleReshape *reshape)
1489+ {
1490+ InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1491+
1492+ InstanceNormPattern pattern (reshape, pv);
1493+ if (pattern.matched ())
1494+ {
1495+ FuseInstanceNorm fuse (pattern);
1496+ fuse.apply ();
1497+ return true ;
1498+ }
1499+
1500+ return false ;
1501+ }
1502+
12591503bool post_fusion (luci::CircleInstanceNorm *inst_norm)
12601504{
12611505 PostFusion postfusion (inst_norm);
@@ -1294,6 +1538,17 @@ bool FuseInstanceNormPass::run(loco::Graph *g)
12941538 changed = true ;
12951539 }
12961540
1541+ // Check Version_7(from Reshape) if other versions not found
1542+ for (auto node : loco::active_nodes (loco::output_nodes (g)))
1543+ {
1544+ auto reshape = dynamic_cast <luci::CircleReshape *>(node);
1545+ if (not reshape)
1546+ continue ;
1547+
1548+ if (fuse_instance_norm (reshape))
1549+ changed = true ;
1550+ }
1551+
12971552 // Post processing of FuseInstanceNorm
12981553 for (auto node : loco::active_nodes (loco::output_nodes (g)))
12991554 {
0 commit comments