Skip to content

Commit db92271

Browse files
committed
Revert create_random_index change for now (fails unit tests)
1 parent ff63ba5 commit db92271

1 file changed

Lines changed: 7 additions & 23 deletions

File tree

keys_values/kvcache/gradient/annotation.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -172,29 +172,13 @@ def create_random_index(
172172
if dtype is None:
173173
dtype = torch.int64
174174
num = min(shape[2], length)
175-
# Batched random permutation: draw uniform random values of shape
176-
# (batch, n_heads, length), argsort along the last dim to produce
177-
# independent permutations per (b, h), then slice to `num`.
178-
# This replaces a Python `for b, h: torch.randperm(length)` loop — on CUDA
179-
# each randperm launches ~7 small kernels, so the loop produced
180-
# batch * n_heads * 7 kernel launches; this path produces a handful.
181-
if device.type == "cuda":
182-
rand_vals = torch.rand(
183-
shape[0], shape[1], length, device=device, dtype=torch.float32
184-
)
185-
# argsort returns int64; cast to the requested dtype below
186-
perms = rand_vals.argsort(dim=-1)
187-
if num < length:
188-
perms = perms[..., :num]
189-
result = perms.to(dtype=dtype)
190-
else:
191-
# CPU fallback: keep the original loop — randperm on CPU is a single
192-
# call and this path isn't perf-critical anyway.
193-
index_kwargs = dict(dtype=dtype, device=device)
194-
result = torch.empty(shape[:-1], **index_kwargs)
195-
for b in range(shape[0]):
196-
for h in range(shape[1]):
197-
result[b, h, :] = torch.randperm(length, **index_kwargs)[:num]
175+
# Keep the original loop — randperm on CPU is a single
176+
# call and this path isn't perf-critical anyway.
177+
index_kwargs = dict(dtype=dtype, device=device)
178+
result = torch.empty(shape[:-1], **index_kwargs)
179+
for b in range(shape[0]):
180+
for h in range(shape[1]):
181+
result[b, h, :] = torch.randperm(length, **index_kwargs)[:num]
198182
return expand_index(result, shape[-1])
199183

200184

0 commit comments

Comments
 (0)