@@ -172,29 +172,13 @@ def create_random_index(
172172 if dtype is None :
173173 dtype = torch .int64
174174 num = min (shape [2 ], length )
175- # Batched random permutation: draw uniform random values of shape
176- # (batch, n_heads, length), argsort along the last dim to produce
177- # independent permutations per (b, h), then slice to `num`.
178- # This replaces a Python `for b, h: torch.randperm(length)` loop — on CUDA
179- # each randperm launches ~7 small kernels, so the loop produced
180- # batch * n_heads * 7 kernel launches; this path produces a handful.
181- if device .type == "cuda" :
182- rand_vals = torch .rand (
183- shape [0 ], shape [1 ], length , device = device , dtype = torch .float32
184- )
185- # argsort returns int64; cast to the requested dtype below
186- perms = rand_vals .argsort (dim = - 1 )
187- if num < length :
188- perms = perms [..., :num ]
189- result = perms .to (dtype = dtype )
190- else :
191- # CPU fallback: keep the original loop — randperm on CPU is a single
192- # call and this path isn't perf-critical anyway.
193- index_kwargs = dict (dtype = dtype , device = device )
194- result = torch .empty (shape [:- 1 ], ** index_kwargs )
195- for b in range (shape [0 ]):
196- for h in range (shape [1 ]):
197- result [b , h , :] = torch .randperm (length , ** index_kwargs )[:num ]
175+ # Keep the original loop — randperm on CPU is a single
176+ # call and this path isn't perf-critical anyway.
177+ index_kwargs = dict (dtype = dtype , device = device )
178+ result = torch .empty (shape [:- 1 ], ** index_kwargs )
179+ for b in range (shape [0 ]):
180+ for h in range (shape [1 ]):
181+ result [b , h , :] = torch .randperm (length , ** index_kwargs )[:num ]
198182 return expand_index (result , shape [- 1 ])
199183
200184
0 commit comments