@@ -44,7 +44,7 @@ def _l2_normalize(x: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
4444 Args:
4545 x: input to normalize
4646 eps: epsilon for numerical stability
47-
47+
4848 Returns:
4949 normalized x
5050 """
@@ -67,13 +67,16 @@ def _chunk_gated_delta_rule(
6767 By chunking here, we can effectively transform the purely sequential
6868 RNN recurrence into a block-parallel operation. It processes tokens in chunks
6969 and then only passes the recurrent state sequentially between chunks.
70-
70+
7171 One detail worth pointing out is that the continuous decay mask (`g`) is
72- cumulative, so
73- Applying the triangular mask *before* exponentiation is key here to prevent
74- NaNs
75- when dealing with large sequence lengths.
76-
72+ cumulative, so the upper triangle of the pairwise differences (`g[i] - g[j]`
73+ for `i < j`) can overflow. Thus, we apply the triangular mask *before*
74+ exponentiation to prevent NaNs when dealing with longer sequence lengths.
75+
76+ Args:
77+ query: (B, H, T, d_k) — already L2-normed
78+ key: (B, H, T, d_k) — already L2-normed
79+
7780 Args:
7881 query: (B, H, T, d_k) — already L2-normed
7982 key: (B, H, T, d_k) — already L2-normed
@@ -82,7 +85,7 @@ def _chunk_gated_delta_rule(
8285 beta: (B, H, T) — input gate (after sigmoid)
8386 chunk_size: chunk processing size
8487 initial_state: (B, H, d_k, d_v) or None
85-
88+
8689 Returns:
8790 output: (B, H, T, d_v)
8891 final_state: (B, H, d_k, d_v) or None
@@ -191,7 +194,7 @@ def _recurrent_gated_delta_rule_step(
191194 g: (B, H, T)
192195 beta: (B, H, T)
193196 state: (B, H, d_k, d_v)
194-
197+
195198 Returns:
196199 output: (B, H, T, d_v)
197200 new_state: (B, H, d_k, d_v)
@@ -233,7 +236,7 @@ def ragged_conv1d(
233236 state_indices: Tensor of shape `(max_reqs,)` mapping request index to state
234237 index.
235238 kernel_size: The size of the convolution kernel.
236-
239+
237240 Returns:
238241 A tuple containing:
239242 - output: The output tensor of shape `(num_tokens, dim)`.
@@ -475,7 +478,7 @@ def run_jax_gdn_attention_local(
475478 d_k: Dimension of key.
476479 d_v: Dimension of value.
477480 kernel_size: Convolution kernel size.
478-
481+
479482 Returns:
480483 A tuple containing the new states and the output.
481484 - A tuple of (new_conv_state, new_recurrent_state).
0 commit comments