
I'm trying to add "MultiHeadAttention" layer support. #15

@dennistang742

Description


Hi Tokusumi,

Thanks for your great work and for kindly sharing it. I'm learning the "pytest" package and how the FLOPs calculation works through your assertion conditions.
Now I want to add "MultiHeadAttention" layer support, write a test case for the "MultiHeadAttention" layer, and submit it to your repository. But I've run into some serious problems:

  1. I studied your function test_attention. In the assert condition [0], why is the FLOPs count for softmax "5 * Tq * Tv"?
    [0]

        def test_attention():

    I attached the log by the way:

        Profile:
        node name | # float_ops
        TFProfRoot (--/6.90k flops)
        model_4/attention/MatMul (3.20k/3.20k flops)
        model_4/attention/MatMul_1 (3.20k/3.20k flops)
        model_4/attention/Softmax (500/500 flops)
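Working backward from that log (assuming Tq = Tv = 10 and dim = 16, which is my guess from the numbers), the counts are consistent with 2 * m * k * n FLOPs per matmul and 5 FLOPs per softmax element (roughly one pass each for max, subtract, exp, sum, divide), though I'm not sure that is exactly how the profiler derives the 5:

```python
# Sanity-check the profiler numbers above under an assumed
# shape of Tq = Tv = 10, dim = 16 (my guess, inferred from the log).
Tq, Tv, dim = 10, 10, 16

# A matmul of (m x k) @ (k x n) costs m*k*n MACs = 2*m*k*n FLOPs.
def matmul_flops(m, k, n):
    return 2 * m * k * n

# scores = Q @ K^T: (Tq x dim) @ (dim x Tv)
scores = matmul_flops(Tq, dim, Tv)   # the "MatMul" line
# output = weights @ V: (Tq x Tv) @ (Tv x dim)
output = matmul_flops(Tq, Tv, dim)   # the "MatMul_1" line
# The profiler appears to charge 5 FLOPs per softmax element.
softmax = 5 * Tq * Tv                # the "Softmax" line

print(scores, output, softmax)       # 3200 3200 500
print(scores + output + softmax)     # 6900 -> "6.90k flops"
```

The totals reproduce every line of the profile, including the 6.90k root.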

  2. I wrote a toy test case for MultiHeadAttention as follows:

        from tensorflow.keras import Input, Model
        from tensorflow.keras.layers import MultiHeadAttention
        from keras_flops import get_flops

        def test_multiheadattention():
            Tq = 10
            dim = 16
            q_shape = (Tq, dim)
            q = Input(q_shape)
            x = MultiHeadAttention(num_heads=1, key_dim=2)(q, q)
            model = Model(inputs=q, outputs=x)
            flops = get_flops(model, batch_size=1)
            print(f"{flops}")

    Profile:

        node name | # float_ops
        TFProfRoot (--/740 flops)
        model/multi_head_attention/softmax/Softmax (500/500 flops)
        model/multi_head_attention/attention_output/add (160/160 flops)
        model/multi_head_attention/Mul (20/20 flops)
        model/multi_head_attention/key/add (20/20 flops)
        model/multi_head_attention/query/add (20/20 flops)
        model/multi_head_attention/value/add (20/20 flops)

    What I expected is:
    the query input is in M_{10, 16}, the key input is in M_{10, 16}, and the value input is in M_{10, 16}. First the query and key are projected to M_{10, 2} by two M_{16, 2} weight matrices respectively, which takes 10 * 16 * 2 * 2 [converting MACs to FLOPs] FLOPs each. The attention matrix requires 10 * 2 * 10 * 2 [converting MACs to FLOPs] FLOPs. Then applying the attention weights to the values needs 10 * 10 * 16 * 2 [converting MACs to FLOPs] FLOPs. I did not count the softmax operation at this time.
    Does my reasoning have a fundamental flaw somewhere?
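For what it's worth, here is how I would tally the matmul FLOPs for this layer under the standard single-head formulation. This is a sketch under my own assumptions, not necessarily what the layer does internally: value_dim defaults to key_dim = 2, and there is a final output projection back to dim = 16.

```python
# Expected matmul FLOPs for MultiHeadAttention(num_heads=1, key_dim=2)
# on a (10, 16) input, assuming value_dim == key_dim == 2 and a final
# output projection back to dim == 16 (my assumptions).
Tq, dim, key_dim = 10, 16, 2

def matmul_flops(m, k, n):
    return 2 * m * k * n  # MACs converted to FLOPs

q_proj = matmul_flops(Tq, dim, key_dim)    # query projection
k_proj = matmul_flops(Tq, dim, key_dim)    # key projection
v_proj = matmul_flops(Tq, dim, key_dim)    # value projection
scores = matmul_flops(Tq, key_dim, Tq)     # Q @ K^T
context = matmul_flops(Tq, Tq, key_dim)    # weights @ V
out_proj = matmul_flops(Tq, key_dim, dim)  # output projection

total = q_proj + k_proj + v_proj + scores + context + out_proj
print(total)  # 3360 matmul FLOPs expected
```

Note that the profiled ops above sum to exactly 740 = 500 (softmax) + 160 + 3 * 20 (bias adds) + 20 (the scaling Mul), so none of the projection or attention matmuls appear to be counted at all; MultiHeadAttention implements them with tf.einsum, which may be why the profiler misses them.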
