
Commit 7eced04

lrdxgm authored and tensorflower-gardener committed
Add an optimization pattern to convert a fully_connected op whose weights have a last dimension of 1 into a broadcasting mul op.
PiperOrigin-RevId: 755893068
1 parent ccb2a3d commit 7eced04
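
The rewrite is sound because a fully_connected whose filter has shape [n, 1] contracts over a size-1 dimension, so no real reduction happens: out[..., j] = in[..., 0] * w[j, 0], which is exactly a broadcast multiply of the input against the flattened filter. Below is a minimal NumPy sketch of that equivalence (illustrative only, not part of the commit; shapes and filter values mirror the test added in this change):

    import numpy as np

    # Shapes from the test below: input [5, 3, 1], filter [2, 1], no bias.
    x = np.random.rand(5, 3, 1).astype(np.float32)
    w = np.array([[1.0], [2.0]], dtype=np.float32)   # filter, shape [2, 1]

    # fully_connected with keep_num_dims = true: out[..., j] = sum_k x[..., k] * w[j, k]
    fc = np.einsum('abk,jk->abj', x, w)              # shape [5, 3, 2]

    # The broadcasting mul the pattern emits: [5, 3, 1] * [2] -> [5, 3, 2]
    mul = x * w.reshape(-1)

    assert np.allclose(fc, mul)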

2 files changed: +52 -0 lines changed

tensorflow/compiler/mlir/lite/tests/optimize.mlir

Lines changed: 28 additions & 0 deletions
@@ -4800,3 +4800,31 @@ func.func @AddComputedZeroNegative(%arg0: tensor<1x512xf32>, %arg1: tensor<512x5
   // CHECK: %0 = tfl.sub %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<512x512xf32>
   // CHECK: %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<1x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
 }
+
+// CHECK-LABEL: @DegerateFC
+func.func @DegerateFC(%input: tensor<5x3x1xf32>) -> tensor<5x3x2xf32> {
+  %weights = arith.constant dense<[[1.0], [2.0]]> : tensor<2x1xf32>
+  %bias = "tfl.no_value"() {value} : () -> none
+  %0 = "tfl.fully_connected"(%input, %weights, %bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1xf32>, none) -> tensor<5x3x2xf32>
+  func.return %0: tensor<5x3x2xf32>
+
+  // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<5x3x1xf32>, tensor<2xf32>) -> tensor<5x3x2xf32>
+}
+
+// CHECK-LABEL: @DegerateFCNegative
+func.func @DegerateFCNegative(%input_ok: tensor<5x3x1xf32>, %input_too_many_dims: tensor<11x7x5x3x1xf32>, %input_last_dim_not_1: tensor<5x3x2xf32>) -> (tensor<11x7x5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>) {
+  %weights_ok = arith.constant dense<[[1.0], [2.0]]> : tensor<2x1xf32>
+  %weights_last_dim_not_1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %weights_quantized = "tfl.pseudo_qconst"() <{qtype = tensor<2x1x!quant.uniform<i8:f32:0, {1.0}>>, value = dense<42> : tensor<2x1xi8>}> : () -> tensor<2x1x!quant.uniform<i8:f32:0, {1.0}>>
+
+  %bias_ok = "tfl.no_value"() {value} : () -> none
+  %bias_notnull = arith.constant dense<[1.0, 2.0]>: tensor<2xf32>
+
+  %1 = "tfl.fully_connected"(%input_too_many_dims, %weights_ok, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<11x7x5x3x1xf32>, tensor<2x1xf32>, none) -> tensor<11x7x5x3x2xf32>
+  %2 = "tfl.fully_connected"(%input_last_dim_not_1, %weights_last_dim_not_1, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x2xf32>, tensor<2x2xf32>, none) -> tensor<5x3x2xf32>
+  %3 = "tfl.fully_connected"(%input_ok, %weights_quantized, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1x!quant.uniform<i8:f32:0, {1.0}>>, none) -> tensor<5x3x2xf32>
+  %4 = "tfl.fully_connected"(%input_ok, %weights_ok, %bias_notnull) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1xf32>, tensor<2xf32>) -> tensor<5x3x2xf32>
+  func.return %1, %2, %3, %4 : tensor<11x7x5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>
+
+  // CHECK-NOT: tfl.mul
+}

tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td

Lines changed: 24 additions & 0 deletions
@@ -2191,3 +2191,27 @@ def AddComputedZeroLHS : Pat<
     TFL_AF_None),
   (replaceWithValue $input),
   [(HasSameType $input, $output)]>;
+
+// Replace matmul where inputs & weights have a last dimension of 1 with an
+// elementwise multiplication that broadcasts, i.e. replace:
+//   [a, b, 1] x [n, 1] => [a, b, n]
+// with:
+//   [a, b, 1] * [n] => [a, b, n]
+def DegenerateFCtoMul : Pat<
+  (TFL_FullyConnectedOp
+    $input,
+    (Arith_ConstantOp:$filter $filterVal),
+    $bias,
+    $fused_activation_function,
+    TFL_FCWO_Default,
+    ConstBoolAttrTrue,
+    $asymmetric_quantize_inputs),
+  (TFL_MulOp
+    $input,
+    (Arith_ConstantOp (FlattenTo1D $filterVal)),
+    $fused_activation_function),
+  [(HasRankAtMost<4> $input),
+   (HasRank<2> $filter),
+   (IsLastDimensionEqualOne $input),
+   (SameElementType $input, $filter),
+   (IsNoneType $bias)]>;
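
The pattern only fires under the constraints listed above, which the negative tests exercise: the input may have rank at most 4, the constant filter must be 2-D, the input's last dimension must be 1, input and filter must share an element type (ruling out the quantized filter), and the bias must be none; the match itself additionally requires weights_format DEFAULT and keep_num_dims = true. A rough Python sketch of that gating logic (hypothetical helper for illustration, not the actual C++ predicate implementations):

    def matches_degenerate_fc(input_shape, filter_shape, same_element_type, bias_is_none):
        """Mirrors the DRR constraints on DegenerateFCtoMul."""
        return (len(input_shape) <= 4          # HasRankAtMost<4> $input
                and len(filter_shape) == 2     # HasRank<2> $filter
                and input_shape[-1] == 1       # IsLastDimensionEqualOne $input
                and same_element_type          # SameElementType $input, $filter
                and bias_is_none)              # IsNoneType $bias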
