Commit 6b79ac3

Fix my docstring
Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent f9f631a commit 6b79ac3

File tree: 1 file changed (+14, -11 lines)

tpu_inference/layers/vllm/ops/gdn_attention.py

Lines changed: 14 additions & 11 deletions
@@ -44,7 +44,7 @@ def _l2_normalize(x: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
     Args:
       x: input to normalize
       eps: epsilon for numerical stability
-
+
     Returns:
       normalized x
     """
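For reference, the normalization this docstring describes can be sketched as below (NumPy stand-in for the file's JAX code; the body is an assumption based only on the signature and docstring, not the actual implementation):

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Scale the last axis to unit L2 norm; `eps` keeps the division
    # numerically stable when the norm is close to zero.
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)
```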
@@ -67,13 +67,16 @@ def _chunk_gated_delta_rule(
     By chunking here, we can effectively transform the purely sequential
     RNN recurrence into a block-parallel operation. It processes tokens in chunks
     and then only passes the recurrent state sequentially between chunks.
-
+
     One detail worth pointing out is that the continuous decay mask (`g`) is
-    cumulative, so
-    Applying the triangular mask *before* exponentiation is key here to prevent
-    NaNs
-    when dealing with large sequence lengths.
-
+    cumulative, so the upper triangle of the pairwise differences (`g[i] - g[j]`
+    for `i < j`) can overflow. Thus, we apply the triangular mask *before*
+    exponentiation to prevent NaNs when dealing with longer seq lens.
+
+    Args:
+      query: (B, H, T, d_k) — already L2-normed
+      key: (B, H, T, d_k) — already L2-normed
+
     Args:
       query: (B, H, T, d_k) — already L2-normed
       key: (B, H, T, d_k) — already L2-normed
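The mask-before-exponentiation point in the new docstring can be illustrated with a small sketch (NumPy for illustration; `chunk_decay_matrix` is a hypothetical helper, not a function from this file):

```python
import numpy as np

def chunk_decay_matrix(g: np.ndarray) -> np.ndarray:
    # g: (T,) cumulative (running-sum) log-decay within one chunk.
    # Pairwise differences g[i] - g[j]: for i < j these are large and
    # positive at long sequence lengths, so exp() would overflow.
    diff = g[:, None] - g[None, :]
    tri = np.tril(np.ones((g.shape[0], g.shape[0]), dtype=bool))
    # Apply the triangular mask while still in log space, *before*
    # exponentiating: exp(-inf) == 0 exactly, and exp() never sees the
    # overflowing upper-triangle entries.
    diff = np.where(tri, diff, -np.inf)
    return np.exp(diff)
```

Masking after `exp` would instead produce `inf * 0` terms, which is where the NaNs come from at long sequence lengths.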
@@ -82,7 +85,7 @@ def _chunk_gated_delta_rule(
       beta: (B, H, T) — input gate (after sigmoid)
       chunk_size: chunk processing size
       initial_state: (B, H, d_k, d_v) or None
-
+
     Returns:
       output: (B, H, T, d_v)
       final_state: (B, H, d_k, d_v) or None
@@ -191,7 +194,7 @@ def _recurrent_gated_delta_rule_step(
       g: (B, H, T)
       beta: (B, H, T)
       state: (B, H, d_k, d_v)
-
+
     Returns:
       output: (B, H, T, d_v)
       new_state: (B, H, d_k, d_v)
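For context, a single-token step with the shapes listed above (per batch and head, so the `(B, H)` axes are dropped) might look like the following sketch. It assumes the standard published gated-delta-rule recurrence, not necessarily this file's exact implementation:

```python
import numpy as np

def gdn_recurrent_step(q, k, v, g, beta, state):
    # q, k: (d_k,), v: (d_v,), g and beta: scalars, state: (d_k, d_v).
    # Assumed recurrence:
    #   S_t = exp(g_t) * S_{t-1} + beta_t * k_t (v_t - S_{t-1}^T k_t)^T
    #   o_t = S_t^T q_t
    state = np.exp(g) * state                  # decay the old state
    delta = v - state.T @ k                    # delta-rule prediction error
    state = state + beta * np.outer(k, delta)  # rank-1 update
    out = state.T @ q                          # read out with the query
    return out, state
```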
@@ -233,7 +236,7 @@ def ragged_conv1d(
       state_indices: Tensor of shape `(max_reqs,)` mapping request index to state
         index.
       kernel_size: The size of the convolution kernel.
-
+
     Returns:
       A tuple containing:
       - output: The output tensor of shape `(num_tokens, dim)`.
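The per-request causal convolution that `ragged_conv1d` performs can be sketched for a single request as follows (NumPy; a simplified depthwise form under assumed conventions — the real kernel handles many ragged requests at once, using `state_indices` to find each request's cached state):

```python
import numpy as np

def causal_depthwise_conv1d(x, w, state):
    # x: (T, dim) tokens of one request; w: (kernel_size, dim) depthwise taps;
    # state: (kernel_size - 1, dim) trailing tokens cached from the last call.
    k = w.shape[0]
    padded = np.concatenate([state, x], axis=0)        # prepend cached context
    out = np.stack([(padded[t:t + k] * w).sum(axis=0)  # causal window ends at x[t]
                    for t in range(x.shape[0])])
    new_state = padded[-(k - 1):]                      # cache the last k-1 tokens
    return out, new_state
```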
@@ -475,7 +478,7 @@ def run_jax_gdn_attention_local(
       d_k: Dimension of key.
       d_v: Dimension of value.
       kernel_size: Convolution kernel size.
-
+
     Returns:
       A tuple containing the new states and the output.
       - A tuple of (new_conv_state, new_recurrent_state).
