
Commit 13b0937

averyhNV and claude committed
fmt: trailing newlines on sglang trace fixtures
Pre-commit's end-of-file-fixer adds the trailing newline that these 4 JSONs were missing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 69ebf27 commit 13b0937

4 files changed: 4 additions & 4 deletions
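For context, end-of-file-fixer ships with the stock pre-commit-hooks collection, so enabling it takes a one-hook entry in .pre-commit-config.yaml. A minimal sketch follows; the rev pin is illustrative and may differ from what this repo actually pins:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0  # illustrative pin; match the repo's actual pinned rev
    hooks:
      - id: end-of-file-fixer  # appends the single trailing newline these fixtures lacked

With that in place, "pre-commit run end-of-file-fixer --all-files" rewrites every offending file in one pass, which is presumably the sweep behind this commit.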


tests/trace/fi_trace_out_sglang/fused_add_rmsnorm_h3072.json

Lines changed: 1 addition & 1 deletion
@@ -56,4 +56,4 @@
     }
   },
   "reference": "@torch.no_grad()\ndef _fused_add_rmsnorm_reference(hidden_states, residual, weight):\n    \"\"\"Fused Add + RMSNorm. Epsilon is fixed at 1e-6.\"\"\"\n    EPS = 1e-6\n    x = hidden_states.to(torch.float32) + residual.to(torch.float32)\n    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS)\n    y = (x * inv_rms) * weight.to(torch.float32)\n    return y.to(hidden_states.dtype)\n"
-}
\ No newline at end of file
+}

tests/trace/fi_trace_out_sglang/gqa_paged_decode_h24_kv128_d128_ps8.json

Lines changed: 1 addition & 1 deletion
@@ -113,4 +113,4 @@
     }
   },
   "reference": "@torch.no_grad()\ndef _gqa_paged_decode_reference(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):\n    batch_size, num_qo_heads, head_dim = q.shape\n    _, page_size, num_kv_heads, _ = k_cache.shape\n\n    output = torch.zeros(\n        (batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=q.device\n    )\n    lse = torch.full(\n        (batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=q.device\n    )\n\n    gqa_ratio = num_qo_heads // num_kv_heads\n    k_cache_f32 = k_cache.to(torch.float32)\n    v_cache_f32 = v_cache.to(torch.float32)\n\n    for b in range(batch_size):\n        page_start = int(kv_indptr[b].item())\n        page_end = int(kv_indptr[b + 1].item())\n        if page_start >= page_end:\n            output[b].zero_()\n            continue\n        # kv_indices are page IDs. Gather pages first, then flatten the\n        # [num_selected_pages, page_size] axis into a single token axis.\n        page_ids = kv_indices[page_start:page_end].to(torch.long)\n        k_b = k_cache_f32[page_ids].reshape(-1, num_kv_heads, head_dim)\n        v_b = v_cache_f32[page_ids].reshape(-1, num_kv_heads, head_dim)\n        q_b = q[b].to(torch.float32)  # [num_qo_heads, head_dim]\n        for h in range(num_qo_heads):\n            kv_h = h // gqa_ratio\n            logits = torch.matmul(q_b[h], k_b[:, kv_h].T) * sm_scale\n            lse[b, h] = torch.logsumexp(logits, dim=-1) / math.log(2.0)\n            attn = torch.softmax(logits, dim=-1)\n            output[b, h] = torch.matmul(attn, v_b[:, kv_h]).to(torch.bfloat16)\n\n    return output, lse\n"
-}
\ No newline at end of file
+}

tests/trace/fi_trace_out_sglang/gqa_paged_prefill_h24_kv128_d128_ps8.json

Lines changed: 1 addition & 1 deletion
@@ -121,4 +121,4 @@
     }
   },
   "reference": "@torch.no_grad()\ndef _gqa_paged_prefill_reference(\n    q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale\n):\n    total_q, num_qo_heads, head_dim = q.shape\n    num_pages, page_size, num_kv_heads, _ = k_cache.shape\n    len_indptr = qo_indptr.shape[0]\n\n    output = torch.zeros(\n        (total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=q.device\n    )\n    lse = torch.full(\n        (total_q, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=q.device\n    )\n\n    gqa_ratio = num_qo_heads // num_kv_heads\n    q_f32 = q.to(torch.float32)\n    k_cache_f32 = k_cache.to(torch.float32)\n    v_cache_f32 = v_cache.to(torch.float32)\n\n    for b in range(len_indptr - 1):\n        q_start = int(qo_indptr[b].item())\n        q_end = int(qo_indptr[b + 1].item())\n        kv_start = int(kv_indptr[b].item())\n        kv_end = int(kv_indptr[b + 1].item())\n        if q_start >= q_end or kv_start >= kv_end:\n            continue\n        # kv_indices are page IDs. Gather pages and flatten to a token axis.\n        page_ids = kv_indices[kv_start:kv_end].to(torch.long)\n        k_b = k_cache_f32[page_ids].reshape(-1, num_kv_heads, head_dim)\n        v_b = v_cache_f32[page_ids].reshape(-1, num_kv_heads, head_dim)\n        num_kv_tokens = k_b.shape[0]\n        q_b = q_f32[q_start:q_end]\n        delta = num_kv_tokens - q_b.shape[0]\n        for q_idx in range(q_b.shape[0]):\n            max_kv = min(q_idx + 1 + delta, num_kv_tokens)\n            if max_kv <= 0:\n                continue\n            global_q = q_start + q_idx\n            for h in range(num_qo_heads):\n                kv_h = h // gqa_ratio\n                logits = torch.matmul(q_b[q_idx, h], k_b[:max_kv, kv_h].T) * sm_scale\n                lse[global_q, h] = torch.logsumexp(logits, dim=-1) / math.log(2.0)\n                attn = torch.softmax(logits, dim=-1)\n                output[global_q, h] = torch.matmul(attn, v_b[:max_kv, kv_h]).to(\n                    torch.bfloat16\n                )\n\n    return output, lse\n"
-}
\ No newline at end of file
+}

tests/trace/fi_trace_out_sglang/rmsnorm_h3072.json

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@
     }
   },
   "reference": "@torch.no_grad()\ndef _rmsnorm_reference(hidden_states, weight):\n    \"\"\"Root Mean Square Normalization. Epsilon is fixed at 1e-6.\"\"\"\n    EPS = 1e-6\n    x = hidden_states.to(torch.float32)\n    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS)\n    y = (x * inv_rms) * weight.to(torch.float32)\n    return y.to(hidden_states.dtype)\n"
-}
\ No newline at end of file
+}
