Skip to content

Commit bb85ceb

Browse files
authored
Add/custom registerstatistics (#9)
* Add custom tf.python.framework.ops.RegisterStatistics * Support global max pooling: add RegisterStatistics for the Max op * Support batch normalization: add RegisterStatistics for the FusedBatchNormV3 op
1 parent cf6d828 commit bb85ceb

File tree

5 files changed

+94
-59
lines changed

5 files changed

+94
-59
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ Support `tf.keras.layers` as follows,
5858
| Pooling | AveragePooling[1D/2D] |
5959
| | GlobalAveragePooling[1D/2D/3D]|
6060
| | MaxPooling[1D/2D] |
61+
| | GlobalMaxPool[1D/2D/3D] |
62+
| Normalization | BatchNormalization |
6163
| Activation | Softmax |
6264
| Attention | Attention |
6365
| | AdditiveAttention |
@@ -72,10 +74,8 @@ Not support `tf.keras.layers` as follows. They are calculated as zero or smaller
7274
| Conv | Conv3DTranspose |
7375
| Pooling | AveragePooling3D |
7476
| | MaxPooling3D |
75-
| | GlobalMaxPool[1D/2D/3D] |
7677
| | UpSampling[1D/2D/3D] |
77-
| Normalization | BatchNormalization |
78-
| | LayerNormalization |
78+
| Normalization | LayerNormalization |
7979
| RNN | SimpleRNN |
8080
| | LSTM |
8181
| | GRU |

keras_flops/flops_calculation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from tensorflow.keras import Sequential, Model
88

9+
import keras_flops.flops_registory
10+
911

1012
def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None) -> int:
1113
"""
@@ -35,5 +37,6 @@ def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None)
3537
flops = tf.compat.v1.profiler.profile(
3638
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
3739
)
40+
# print(frozen_func.graph.get_operations())
3841
# TODO: show each FLOPS
3942
return flops.total_float_ops

keras_flops/flops_registory.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
from tensorflow.python.framework import ops
3+
from tensorflow.python.framework import graph_util
4+
from tensorflow.python.profiler.internal.flops_registry import _reduction_op_flops
5+
6+
7+
@ops.RegisterStatistics("FusedBatchNormV3", "flops")
def _flops_fused_batch_norm_v3(graph, node):
    """Register "flops" statistics for the FusedBatchNormV3 op (inference only).

    FLOPs are counted as |x| + 4 * |variance| + |mean|: one multiply-add per
    input element for the scale/shift, plus per-channel work to fold mean,
    variance, gamma and beta into the fused scale/offset.
    |mean| == |variance| == channel size, so this totals
    input elements + 5 * channels.

    Args:
        graph: tf.Graph containing the node.
        node: NodeDef of the FusedBatchNormV3 op.

    Returns:
        ops.OpStats holding the total float ops.

    Raises:
        ValueError: if the node was built with is_training=True — training-mode
            statistics (running-mean updates) are not modeled here.
    """
    in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
    in_shape.assert_is_fully_defined()
    # input[3] / input[4] are the moving mean / moving variance tensors.
    mean_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[3])
    mean_shape.assert_is_fully_defined()
    variance_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[4])
    variance_shape.assert_is_fully_defined()

    # node.attr[...].b is a plain bool; PEP 8 says never compare with `is True`.
    if node.attr["is_training"].b:
        raise ValueError("Only supports inference mode")

    num_flops = (
        in_shape.num_elements()
        + 4 * variance_shape.num_elements()
        + mean_shape.num_elements()
    )
    return ops.OpStats("flops", num_flops)
26+
27+
28+
@ops.RegisterStatistics("Max", "flops")
def _flops_max(graph, node):
    """Register "flops" statistics for the Max reduction op (inference).

    A max reduction performs one comparison per reduced element and needs
    no finalization pass, so only the per-element reduce cost is counted.
    """
    compare_per_element = 1  # one comparison folds each element into the max
    no_finalize = 0  # unlike e.g. mean, max needs no final division step
    return _reduction_op_flops(
        graph, node, reduce_flops=compare_per_element, finalize_flops=no_finalize
    )
33+

tests/test_flops.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
GlobalAveragePooling3D,
1616
MaxPooling1D,
1717
MaxPooling2D,
18+
GlobalMaxPooling1D,
19+
GlobalMaxPooling2D,
20+
GlobalMaxPooling3D,
1821
BatchNormalization,
1922
AdditiveAttention,
2023
Attention,
@@ -235,6 +238,29 @@ def test_maxpooling1d2d3d():
235238
assert flops == in_w * in_h * kernel
236239

237240

241+
def test_global_maxpooling1d2d3d():
    """
    Global max pooling reduces all N spatial elements of each channel,
    which costs N - 1 comparison ops per channel.
    """
    in_w = 32
    in_h = 32
    in_z = 32
    kernel = 3

    # (layer class, input shape, expected FLOPs) for 1D / 2D / 3D pooling.
    cases = [
        (GlobalMaxPooling1D, (in_w, kernel), (in_w - 1) * kernel),
        (GlobalMaxPooling2D, (in_w, in_h, kernel), (in_w * in_h - 1) * kernel),
        (
            GlobalMaxPooling3D,
            (in_w, in_h, in_z, kernel),
            (in_w * in_h * in_z - 1) * kernel,
        ),
    ]
    for layer_cls, shape, expected in cases:
        model = Sequential(layer_cls(input_shape=shape))
        flops = get_flops(model, batch_size=1)
        assert flops == expected
262+
263+
238264
def test_softmax():
239265
kernel = 8
240266
model = Sequential(Activation("softmax", input_shape=(kernel,)))
@@ -293,10 +319,7 @@ def test_batchnormalization():
293319
2. (1 ops * |var|) inv *= gamma (scale)
294320
3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
295321
, where |var| = |mean| = channel size in default
296-
Thus, 5 * channel size + input element size.
297-
298-
NOTE: support only fused=False
299-
Use gen_nn_ops.fused_batch_norm_v3 but this is not registered yet and calculated as zero.
322+
Thus, tot FLOPs = 5 * channel size + input element size.
300323
"""
301324
in_w = 32
302325
in_h = 32
@@ -310,7 +333,21 @@ def test_batchnormalization():
310333
)
311334
)
312335
flops = get_flops(model, batch_size=1)
313-
assert flops == 5 * in_ch + in_w * in_ch, "fused is False"
336+
assert (
337+
flops == 5 * in_ch + in_w * in_ch
338+
), "fused is False. see nn_impl.batch_normalization"
339+
340+
model = Sequential(
341+
BatchNormalization(
342+
beta_initializer="ones",
343+
gamma_initializer="ones",
344+
input_shape=(in_w, in_h, in_ch),
345+
)
346+
)
347+
flops = get_flops(model, batch_size=1)
348+
assert (
349+
flops == 5 * in_ch + in_w * in_h * in_ch
350+
), "fused is True, see gen_nn.fused_batch_norm_v3"
314351

315352

316353
def test_additive_attention():

tests/test_not_support.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@
88
Conv3DTranspose,
99
AveragePooling3D,
1010
MaxPooling3D,
11-
GlobalMaxPooling1D,
12-
GlobalMaxPooling2D,
13-
GlobalMaxPooling3D,
1411
UpSampling1D,
1512
UpSampling2D,
1613
UpSampling3D,
17-
BatchNormalization,
1814
LayerNormalization,
1915
)
2016
from keras_flops import get_flops
@@ -135,26 +131,6 @@ def test_maxpooling1d2d3d():
135131
assert flops == in_w * in_h * in_z * kernel
136132

137133

138-
@pytest.mark.xfail
139-
def test_global_maxpooling1d2d3d():
140-
in_w = 32
141-
in_h = 32
142-
in_z = 32
143-
kernel = 32
144-
145-
model = Sequential(GlobalMaxPooling1D(input_shape=(in_w, kernel)))
146-
flops = get_flops(model, batch_size=1)
147-
assert flops == in_w * kernel
148-
149-
model = Sequential(GlobalMaxPooling2D(input_shape=(in_w, in_h, kernel)))
150-
flops = get_flops(model, batch_size=1)
151-
assert flops == in_w * in_h * kernel
152-
153-
model = Sequential(GlobalMaxPooling3D(input_shape=(in_w, in_h, in_z, kernel)))
154-
flops = get_flops(model, batch_size=1)
155-
assert flops == in_w * in_h * in_z * kernel
156-
157-
158134
@pytest.mark.xfail
159135
def test_upsampling1d2d3d():
160136
in_w = 32
@@ -182,28 +158,6 @@ def test_upsampling1d2d3d():
182158
assert flops == in_w * in_h * in_z * kernel
183159

184160

185-
@pytest.mark.xfail
186-
def test_batchnormalization():
187-
"""
188-
batch normalization in tf uses gen_nn_ops.fused_batch_norm_v3 if input shape are 4D
189-
"""
190-
in_w = 32
191-
in_h = 32
192-
in_ch = 3
193-
194-
model = Sequential(
195-
BatchNormalization(
196-
beta_initializer="ones",
197-
gamma_initializer="ones",
198-
input_shape=(in_w, in_h, in_ch),
199-
)
200-
)
201-
flops = get_flops(model, batch_size=1)
202-
assert (
203-
flops == 5 * in_ch + in_w * in_h * in_ch
204-
), "fused is True, fused_batch_norm_v3 is not supportted"
205-
206-
207161
@pytest.mark.xfail
208162
def test_layernormalization():
209163
"""
@@ -213,11 +167,12 @@ def test_layernormalization():
213167
2. (1 ops * |var|) inv *= gamma (scale)
214168
3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
215169
, where |var| = |mean| = 1 in default
216-
Thus, 5 channel size + input element size.
170+
Thus, 5 + input element size.
217171
218-
Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, above calculation,
219-
but gen_nn_ops.fused_batch_norm_v3 is not registered yet, so can not evaluate corrent FLOPs.
172+
Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, above calculation.
173+
gen_nn_ops.fused_batch_norm_v3 supports only 4D inputs, so the data are reshaped to 4D before being fed in.
220174
squeezed_shape (ndim ops), scale (|x| ops) and shift (not float ops) is calculated.
175+
NOTE: is_training stays True even if the trainable attribute of the tf.keras.Model instance is set to False, so the reported statistics will be incorrect.
221176
"""
222177
in_w = 32
223178
in_h = 32
@@ -244,6 +199,13 @@ def test_layernormalization():
244199
)
245200
)
246201
flops = get_flops(model, batch_size=1)
247-
assert flops == len(input_shape) + 1 + in_w * in_h * in_ch, "fused is True"
202+
assert (
203+
flops
204+
== len(input_shape)
205+
+ 1
206+
+ 5
207+
+ in_w * in_h * in_ch
208+
+ 5 * in_ch
209+
+ in_w * in_h * in_ch
210+
), "fused is True. check gen_nn_ops.fused_batch_norm_v3"
248211

249-
assert flops == len(input_shape) + 1 + 5 * in_ch + in_w * in_h * in_ch

0 commit comments

Comments
 (0)