Hi, thanks for the interesting paper and code release. I'm reading through the implementation to understand the code. In https://github.com/NVlabs/GatedDeltaNet/blob/main/lit_gpt/gated_delta_rule_ops/chunk.py#L662-L666, we have:
```python
attn = -((k_beta @ k.transpose(-1, -2)) * L_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
    attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :i, None].clone() * attn[..., :i, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
attn = attn
k_cumsum = attn @ v
```
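If I'm reading it right, the in-place loop is a forward substitution, so the block as a whole computes the inverse of a unit lower-triangular matrix. Here is a small NumPy sketch of my understanding (variable names are mine, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8  # stands in for chunk_size

# Strictly lower-triangular matrix, playing the role of the
# strictly lower part of k_beta @ k.T * L_mask after masked_fill.
L = np.tril(rng.standard_normal((n, n)), k=-1)

# Same in-place recurrence as the reference code, starting from -L
# (the .clone() calls in the torch version are only needed for autograd;
# NumPy evaluates the right-hand side before assigning).
attn = -L.copy()
for i in range(1, n):
    attn[i, :i] = attn[i, :i] + attn[i, :i] @ attn[:i, :i]
attn = attn + np.eye(n)

# Forward substitution yields the inverse of the unit lower-triangular matrix.
assert np.allclose(attn, np.linalg.inv(np.eye(n) + L))
```

So, if this is right, the first block builds the mixing matrix (I + tril(k_beta @ k^T * L_mask, -1))^{-1} before applying it to v.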
and yet in the immediately following lines, attn is overwritten with another value that does not include the L_mask term:
```python
attn = -((k_beta @ k.transpose(-1, -2))).masked_fill(mask, 0)
for i in range(1, chunk_size):
    attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :i, None].clone() * attn[..., :i, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
attn = attn
w = k_cumdecay = attn @ k_beta
u = v = k_cumsum
```
So w and u take the value computed without the L_mask, no? I haven't looked at the kernel implementation yet, because I want to understand the reference version first (which should presumably be easier to follow). Is the chunk_gated_delta_rule_ref function correctly implementing what it should? If it is, can you enlighten me on the purpose of the two blocks of code that compute attn?
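To make the question concrete, here is a small NumPy sketch (my own stand-in variables, not from the repo) showing that the two recurrences generally produce different mixing matrices whenever L_mask is not all ones:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.standard_normal((n, n))                        # stands in for k_beta @ k.T
L_mask = np.tril(np.exp(rng.standard_normal((n, n))))  # stand-in decay mask

def mixing_matrix(M):
    # The shared pattern from both blocks: forward-substitution
    # inverse of (I + strictly-lower(M)).
    attn = -np.tril(M, k=-1)
    for i in range(1, n):
        attn[i, :i] = attn[i, :i] + attn[i, :i] @ attn[:i, :i]
    return attn + np.eye(n)

T_decay = mixing_matrix(A * L_mask)   # first block (with L_mask)
T_plain = mixing_matrix(A)            # second block (no L_mask)
print(np.allclose(T_decay, T_plain))  # False here: the two matrices differ
```

So unless L_mask happens to be all ones, the second block really does overwrite attn with a different matrix, which is why I'm unsure which of the two w and u are meant to use.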
Thank you!