Description
Hi Tokusumi,
Thanks for your great work and for kindly sharing it. I'm learning how to use the "pytest" package and how FLOPs are calculated by reading your assertion conditions.
Now I want to add support for the "MultiHeadAttention" layer, write a test case for it, and submit it to your repository. But I have run into some serious problems:
- I studied your function test_attention. In the assert condition [0], why is the FLOPs count for softmax "5 * Tq * Tv"?
[0] keras-flops/tests/test_flops.py, line 385 in 5c130bf: def test_attention():
For reference, here is the profiler log for that test:
Profile:
node name | # float_ops
TFProfRoot (--/6.90k flops)
model_4/attention/MatMul (3.20k/3.20k flops)
model_4/attention/MatMul_1 (3.20k/3.20k flops)
model_4/attention/Softmax (500/500 flops)
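The numbers in this log can be reproduced with a few lines of arithmetic. This is only a sketch: the values Tq = Tv = 10 and dim = 16 are my assumption (they are not stated in this issue, I chose them because they match the log), and I count one multiply-accumulate as 2 FLOPs.

```python
# Assumed shapes for test_attention; chosen because they reproduce the log above.
Tq, Tv, dim = 10, 10, 16

# One multiply-accumulate (MAC) is counted as 2 FLOPs.
scores_flops = 2 * Tq * Tv * dim   # query @ key^T        -> MatMul
output_flops = 2 * Tq * Tv * dim   # weights @ value      -> MatMul_1
softmax_flops = 5 * Tq * Tv        # factor 5 taken from the assertion

total = scores_flops + output_flops + softmax_flops
print(scores_flops, output_flops, softmax_flops, total)
```

Under these assumptions the three terms match the MatMul, MatMul_1 and Softmax rows, so the profiler seems to charge 5 FLOPs per softmax element. Where that factor comes from is exactly what I am asking.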
- I wrote a toy test case for MultiHeadAttention as follows:
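from keras_flops import get_flops
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import MultiHeadAttention

def test_multiheadattention():
    Tq = 10
    dim = 16
    q_shape = (Tq, dim)
    q = Input(q_shape)
    # Self-attention: the same tensor is passed as query and value.
    x = MultiHeadAttention(num_heads=1, key_dim=2)(q, q)
    model = Model(inputs=q, outputs=x)
    flops = get_flops(model, batch_size=1)
    print(f'{flops}')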
Profile:
node name | # float_ops
TFProfRoot (--/740 flops)
model/multi_head_attention/softmax/Softmax (500/500 flops)
model/multi_head_attention/attention_output/add (160/160 flops)
model/multi_head_attention/Mul (20/20 flops)
model/multi_head_attention/key/add (20/20 flops)
model/multi_head_attention/query/add (20/20 flops)
model/multi_head_attention/value/add (20/20 flops)
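To double-check what the profiler actually counted, the log can be parsed and summed. The helper below is a hypothetical sketch of mine (parse_flops is not part of keras-flops), and the "m" and "b" unit suffixes are an assumption; only plain numbers and "k" appear in my logs.

```python
import re

# Node lines copied from the profile above (the TFProfRoot line is left out).
PROFILE = """\
model/multi_head_attention/softmax/Softmax (500/500 flops)
model/multi_head_attention/attention_output/add (160/160 flops)
model/multi_head_attention/Mul (20/20 flops)
model/multi_head_attention/key/add (20/20 flops)
model/multi_head_attention/query/add (20/20 flops)
model/multi_head_attention/value/add (20/20 flops)
"""

# Unit suffixes; "m" and "b" are assumed, only "k" is seen in the logs.
UNITS = {"": 1, "k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
LINE = re.compile(r"^(\S+) \((\d+(?:\.\d+)?)([kmb]?)/")


def parse_flops(text):
    """Return {node_name: flops} using the first number in '(self/total flops)'."""
    result = {}
    for line in text.splitlines():
        match = LINE.match(line.strip())
        if match:
            name, number, unit = match.groups()
            result[name] = round(float(number) * UNITS[unit])
    return result


nodes = parse_flops(PROFILE)
total = sum(nodes.values())
has_matmul = any("matmul" in name.lower() or "einsum" in name.lower() for name in nodes)
print(total, has_matmul)
```

The parsed rows add up to the reported root total, and none of them is a matrix-multiplication node. Only the softmax, the bias additions and one scaling multiply are counted, which may be why the total is so small.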
What I expected is:
The query, key and value inputs are each \in M_{10, 16}. First, the query and key are each projected to M_{10, 2} by a weight matrix \in M_{16, 2}, which takes a total of 10 * 16 * 2 * 2 [convert MAC to flops] flops. The attention matrix requires 10 * 2 * 10 * 2 [convert MAC to flops] flops. Then the value projection needs 10 * 10 * 16 * 2 [convert MAC to flops] flops. I did not count the softmax operation this time.
Does my reasoning have a fundamental flaw somewhere?
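To make my expectation concrete, here is the same arithmetic as a small script. It only evaluates the three expressions I wrote above, with one MAC counted as 2 FLOPs; the variable names are mine, and the result is my own estimate, not a verified count for the layer.

```python
Tq, Tv, dim, key_dim = 10, 10, 16, 2
MAC_TO_FLOPS = 2  # one multiply-accumulate counted as 2 FLOPs

# The three expressions from my reasoning, evaluated as written.
projection_flops = Tq * dim * key_dim * MAC_TO_FLOPS  # 10 * 16 * 2 * 2
scores_flops = Tq * key_dim * Tv * MAC_TO_FLOPS       # 10 * 2 * 10 * 2
value_flops = Tq * Tv * dim * MAC_TO_FLOPS            # 10 * 10 * 16 * 2

expected = projection_flops + scores_flops + value_flops
reported = 740  # TFProfRoot total from the log above

print(projection_flops, scores_flops, value_flops, expected, reported)
```

Even without softmax, my estimate is several times larger than the 740 FLOPs that get_flops reports, so either my reasoning is wrong or some operations are not being counted.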