
Question on chunkwise Gated DeltaNet notation: should M include decay (Gamma)? #15

@WKQ9411


Hi, thank you for the great work.

While reading the paper, I noticed a potential notation inconsistency in the final chunkwise formula for Gated DeltaNet:

$$ \mathbf{O}_{[t]} = \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t]}^\top + \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \left( \widetilde{\mathbf{U}_{[t]}} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t]}^\top \right) $$


Earlier in the paper (e.g., in the Mamba2 linear attention section), $\mathbf{M}$ is defined as a binary causal mask, while the chunkwise formulation for the decay setting introduces $\mathbf{\Gamma}$ instead. Based on my derivation, I believe the second term of the formula should use $\mathbf{\Gamma}$ rather than $\mathbf{M}$:

$$ \begin{equation} \mathbf{O}_{[t]} = \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t]}^\top + \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{\Gamma} \right) \left( \widetilde{\mathbf{U}_{[t]}} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t]}^\top \right) \end{equation} $$
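For reference, the derivation below assumes the gated delta rule with a scalar decay gate $\alpha_t$ and write strength $\beta_t$, and reads $\mathbf{\Gamma}$ as the intra-chunk decay matrix built from cumulative gate products (indices run over positions inside chunk $[t]$; the paper's exact notation may differ):

$$ \mathbf{S}_t = \mathbf{S}_{t-1}\,\alpha_t \left( \mathbf{I} - \beta_t k_t k_t^\top \right) + \beta_t v_t k_t^\top, \qquad o_t = \mathbf{S}_t q_t $$

$$ \gamma^r_{[t]} = \prod_{j=1}^{r} \alpha^j_{[t]}, \qquad \mathbf{\Gamma}_{ri} = \begin{cases} \gamma^r_{[t]} / \gamma^i_{[t]} & i \le r \\ 0 & i > r \end{cases}, \qquad \mathbf{M}_{ri} = \begin{cases} 1 & i \le r \\ 0 & i > r \end{cases} $$

The two masks coincide only when there is no decay between positions within the chunk (e.g., all $\alpha^j_{[t]} = 1$), which is why the distinction matters for Gated DeltaNet but not for plain DeltaNet.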

The detailed derivation is as follows.

For the $r$-th token of the chunk, we have:

$$ \begin{equation} \begin{aligned} \mathbf{S}^r_{[t]} &= \mathbf{S}_{[t]} \mathbf{F}_{[t]}^r + \mathbf{G}_{[t]}^r \\ &= \mathbf{S}_{[t]} \gamma_{[t]}^r \mathbf{P}_{[t]}^r + \mathbf{G}_{[t]}^r \\ & = \mathbf{S}_{[t]} \gamma_{[t]}^r \left( \mathbf{I} - \sum_{i=1}^{r} w^i_{[t]} k^{i\top}_{[t]} \right) + \sum_{i=1}^{r} \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} \tilde{u}^i_{[t]} k^{i\top}_{[t]} \\ &= \mathbf{S}_{[t]} \gamma_{[t]}^r + \left( \sum_{i=1}^{r} \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} \tilde{u}^i_{[t]} k^{i\top}_{[t]} - \mathbf{S}_{[t]} \sum_{i=1}^{r} \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} \gamma^i_{[t]} w^i_{[t]} k^{i\top}_{[t]} \right) \\ &= \mathbf{S}_{[t]} \gamma_{[t]}^r + \sum_{i=1}^{r} \left( \tilde{u}^i_{[t]} - \mathbf{S}_{[t]}\gamma^i_{[t]} w^i_{[t]} \right) \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} k^{i\top}_{[t]} \\ &= \mathbf{S}_{[t]} \gamma_{[t]}^r + \sum_{i=1}^{r} \left( \tilde{u}^i_{[t]} - \mathbf{S}_{[t]} \overleftarrow{w^i_{[t]}} \right) \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} k^{i\top}_{[t]} \end{aligned} \end{equation} $$
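The vectors $w^i_{[t]}$ and $\tilde{u}^i_{[t]}$ in the third line are not spelled out above. Assuming the per-token recurrence $\mathbf{S}^r_{[t]} = \mathbf{S}^{r-1}_{[t]}\,\alpha^r_{[t]} (\mathbf{I} - \beta^r_{[t]} k^r_{[t]} k^{r\top}_{[t]}) + \beta^r_{[t]} v^r_{[t]} k^{r\top}_{[t]}$, unrolling $\mathbf{P}^r_{[t]} = \mathbf{P}^{r-1}_{[t]} (\mathbf{I} - \beta^r_{[t]} k^r_{[t]} k^{r\top}_{[t]})$ and $\mathbf{G}^r_{[t]} = \alpha^r_{[t]} \mathbf{G}^{r-1}_{[t]} (\mathbf{I} - \beta^r_{[t]} k^r_{[t]} k^{r\top}_{[t]}) + \beta^r_{[t]} v^r_{[t]} k^{r\top}_{[t]}$ yields one consistent set of definitions (a sketch under that assumption; the paper may present them differently):

$$ \begin{aligned} w^r_{[t]} &= \beta^r_{[t]} \left( k^r_{[t]} - \sum_{i=1}^{r-1} w^i_{[t]} \, k^{i\top}_{[t]} k^r_{[t]} \right), \qquad \overleftarrow{w^r_{[t]}} = \gamma^r_{[t]} w^r_{[t]}, \\ \tilde{u}^r_{[t]} &= \beta^r_{[t]} \left( v^r_{[t]} - \sum_{i=1}^{r-1} \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} \, \tilde{u}^i_{[t]} \, k^{i\top}_{[t]} k^r_{[t]} \right) \end{aligned} $$

Note that $\tilde{u}^i_{[t]}$ only accounts for decay among positions up to $i$; the decay from position $i$ to the query position $r$ still has to be applied when the output is formed, and that is exactly the factor $\gamma^r_{[t]} / \gamma^i_{[t]}$ in the last line.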

The output $o^r_{[t]}$ is then:

$$ \begin{equation} \begin{aligned} o^r_{[t]} &= \mathbf{S}^r_{[t]} q^r_{[t]} \in \mathbb{R}^{d_v} \\ &= \mathbf{S}_{[t]} \gamma_{[t]}^r q^r_{[t]} + \sum_{i=1}^{r} \left( \tilde{u}^i_{[t]} - \mathbf{S}_{[t]} \overleftarrow{w^i_{[t]}} \right) \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} k^{i\top}_{[t]} q^r_{[t]} \end{aligned} \end{equation} $$
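To see which mask appears when the rows are stacked, it helps to compare row $r$ of the matrix expression with the sum above. Here $\mathbf{R}_{[t]} = \widetilde{\mathbf{U}_{[t]}} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t]}^\top$ is a shorthand introduced only for this check, with $i$-th row $(\tilde{u}^i_{[t]} - \mathbf{S}_{[t]} \overleftarrow{w^i_{[t]}})^\top$, and the $r$-th row of $\overleftarrow{\mathbf{Q}_{[t]}}$ is assumed to be $\gamma^r_{[t]} q^{r\top}_{[t]}$ (the same convention as $\overleftarrow{w^i_{[t]}} = \gamma^i_{[t]} w^i_{[t]}$):

$$ \begin{aligned} \left( o^r_{[t]} \right)^\top &= \gamma^r_{[t]} \, q^{r\top}_{[t]} \mathbf{S}_{[t]}^\top + \sum_{i=1}^{r} \frac{\gamma^r_{[t]}}{\gamma^i_{[t]}} \left( q^{r\top}_{[t]} k^i_{[t]} \right) \left( \tilde{u}^i_{[t]} - \mathbf{S}_{[t]} \overleftarrow{w^i_{[t]}} \right)^\top \\ &= \left[ \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t]}^\top \right]_{r,:} + \sum_{i=1}^{C} \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \right)_{ri} \mathbf{\Gamma}_{ri} \left( \mathbf{R}_{[t]} \right)_{i,:} \end{aligned} $$

The weight on row $i$ of $\mathbf{R}_{[t]}$ is $\gamma^r_{[t]} / \gamma^i_{[t]}$ for $i \le r$ and $0$ otherwise, i.e., the entries of the decay matrix $\mathbf{\Gamma}$. A binary causal mask would replace this weight by $1$ and therefore drop the intra-chunk decay.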

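Here $q^r_{[t]}$ attends only to the keys $k^i_{[t]}$ (and the corresponding values) at positions $i \le r$, and each such term carries the decay factor $\gamma^r_{[t]} / \gamma^i_{[t]}$. Therefore, stacking the rows for $r = 1, 2, \ldots, C$ gives the output of the chunk: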

$$ \begin{equation} \mathbf{O}_{[t]} = \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t]}^\top + \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{\Gamma} \right) \left( \widetilde{\mathbf{U}_{[t]}} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t]}^\top \right) \in \mathbb{R}^{C \times d_v} \end{equation} $$
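As a numerical sanity check, the sketch below runs the gated delta rule token by token on one random chunk and compares the result with the stacked chunkwise expression, once using the decay matrix $\mathbf{\Gamma}$ and once using a binary causal mask $\mathbf{M}$. It is plain Python (no NumPy). The recurrence $\mathbf{S}_r = \mathbf{S}_{r-1} \alpha_r (\mathbf{I} - \beta_r k_r k_r^\top) + \beta_r v_r k_r^\top$, the recursions used for $w$ and $\tilde{u}$, and all function names are assumptions of this sketch, not code from the paper or its repository.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def dot(x, y):
    return sum(xi * yi for xi, yi in zip(x, y))


def recurrent_outputs(S0, q, k, v, alpha, beta):
    """Assumed gated delta rule: S_r = S_{r-1} a_r (I - b_r k_r k_r^T) + b_r v_r k_r^T, o_r = S_r q_r."""
    S = [row[:] for row in S0]  # d_v x d_k state entering the chunk
    outs = []
    for r in range(len(q)):
        Sk = [dot(row, k[r]) for row in S]  # S k_r, since S (I - b k k^T) = S - b (S k) k^T
        S = [[alpha[r] * (S[a][c] - beta[r] * Sk[a] * k[r][c]) + beta[r] * v[r][a] * k[r][c]
              for c in range(len(k[r]))] for a in range(len(S))]
        outs.append([dot(row, q[r]) for row in S])
    return outs


def chunkwise_outputs(S0, q, k, v, alpha, beta, decay_mask):
    """O = Q_dec S0^T + (Q K^T * mask)(U~ - W_dec S0^T) for a single chunk."""
    C = len(q)
    gamma, g = [], 1.0
    for a in alpha:  # cumulative decay gamma_r = a_1 * ... * a_r
        g *= a
        gamma.append(g)
    w, u = [], []
    for r in range(C):
        kk = [dot(k[i], k[r]) for i in range(r)]  # k_i^T k_r for i < r
        w.append([beta[r] * (k[r][d] - sum(w[i][d] * kk[i] for i in range(r)))
                  for d in range(len(k[r]))])
        u.append([beta[r] * (v[r][d] - sum(gamma[r] / gamma[i] * u[i][d] * kk[i] for i in range(r)))
                  for d in range(len(v[r]))])
    if decay_mask:  # Gamma_{ri} = gamma_r / gamma_i for i <= r, else 0
        mask = [[gamma[r] / gamma[i] if i <= r else 0.0 for i in range(C)] for r in range(C)]
    else:  # binary causal mask M
        mask = [[1.0 if i <= r else 0.0 for i in range(C)] for r in range(C)]
    S0_T = transpose(S0)  # d_k x d_v
    q_dec = [[gamma[r] * x for x in q[r]] for r in range(C)]
    w_dec = [[gamma[r] * x for x in w[r]] for r in range(C)]
    qk = matmul(q, transpose(k))
    attn = [[qk[r][i] * mask[r][i] for i in range(C)] for r in range(C)]
    w_s = matmul(w_dec, S0_T)
    resid = [[u[r][d] - w_s[r][d] for d in range(len(u[r]))] for r in range(C)]
    inter, intra = matmul(q_dec, S0_T), matmul(attn, resid)
    return [[x + y for x, y in zip(ri, rj)] for ri, rj in zip(inter, intra)]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def rand_vec(n):
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


random.seed(0)
C, d_k, d_v = 4, 3, 2
q = [rand_vec(d_k) for _ in range(C)]
k = [rand_vec(d_k) for _ in range(C)]
v = [rand_vec(d_v) for _ in range(C)]
alpha = [random.uniform(0.5, 0.95) for _ in range(C)]  # decay gates, all < 1
beta = [random.uniform(0.1, 0.9) for _ in range(C)]    # write strengths
S0 = [rand_vec(d_k) for _ in range(d_v)]                # state entering the chunk

ref = recurrent_outputs(S0, q, k, v, alpha, beta)
err_gamma = max_abs_diff(ref, chunkwise_outputs(S0, q, k, v, alpha, beta, decay_mask=True))
err_binary = max_abs_diff(ref, chunkwise_outputs(S0, q, k, v, alpha, beta, decay_mask=False))
print(f"with Gamma: max abs error = {err_gamma:.2e}")
print(f"with M:     max abs error = {err_binary:.2e}")
```

With decay gates below 1, only the $\mathbf{\Gamma}$ version should match the recurrence to floating-point precision; setting every entry of `alpha` to `1.0` makes the two masks identical, so both versions then agree with the recurrence.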

Thanks again, and looking forward to your clarification.
