Question about chunk_gated_delta_rule_ref function #14

@wyc1997

Description

Hi, thanks for the interesting paper and code release. I am looking through the implementation trying to understand the code. In https://github.com/NVlabs/GatedDeltaNet/blob/main/lit_gpt/gated_delta_rule_ops/chunk.py#L662-L666, we have:

    attn = -((k_beta @ k.transpose(-1, -2)) * L_mask).masked_fill(mask, 0)
    for i in range(1, chunk_size):
      attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :i, None].clone() * attn[..., :i, :i].clone()).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    attn = attn
    k_cumsum = attn @ v

and yet in the immediately following lines, attn is overwritten with another value that does not include the L_mask term:

    attn = -((k_beta @ k.transpose(-1, -2))).masked_fill(mask, 0)
    for i in range(1, chunk_size):
      attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :i, None].clone() * attn[..., :i, :i].clone()).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    attn = attn
    w = k_cumdecay = attn @ k_beta
    u = v = k_cumsum

So w and u are taking the values calculated without the L_mask, no? I haven't looked at the kernel implementation yet, because I want to understand the reference version first (which should presumably be easier to follow). Is the chunk_gated_delta_rule_ref function correctly implementing what it should? If it is, could you explain the purpose of the two blocks of code that compute attn?
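For what it's worth, my reading of the shared in-place loop (the part that is identical in both blocks) is that it performs forward substitution: starting from a strictly lower-triangular matrix A, it builds (I - A)^{-1} row by row, and the final torch.eye addition supplies the diagonal. Here is a small NumPy sketch (my own reconstruction, not code from the repo) that checks this interpretation against a direct matrix inverse:

```python
import numpy as np

# Assumption: after masked_fill, attn starts as a strictly
# lower-triangular matrix A (diagonal and above are zero).
chunk_size = 8
rng = np.random.default_rng(0)
A = np.tril(rng.standard_normal((chunk_size, chunk_size)), k=-1)

# Same update as the loop in chunk_gated_delta_rule_ref,
# minus the batch/head dimensions and the .clone() calls
# (NumPy evaluates the right-hand side before assigning).
attn = A.copy()
for i in range(1, chunk_size):
    attn[i, :i] = attn[i, :i] + (attn[i, :i, None] * attn[:i, :i]).sum(-2)
attn = attn + np.eye(chunk_size)

# Compare against the closed-form inverse of (I - A).
T = np.linalg.inv(np.eye(chunk_size) - A)
print(np.allclose(attn, T))  # expect True
```

Under this reading, both blocks invert a lower-triangular system; the only difference is whether the L_mask decay term is folded into A first, which is exactly the discrepancy I am asking about.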

Thank you!
