Skip to content

Commit 9572b91

Browse files
committed
check cache after populating pinned_tensors
1 parent 8ec5ba9 commit 9572b91

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/kvboost/streaming/awq_loader.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,17 +312,29 @@ def pin_layer(
312312
self,
313313
layer_idx: int,
314314
) -> dict[str, torch.Tensor]:
315+
"""Return a layer's streamed tensors as pinned-host tensors.
316+
317+
Cached: on the second call for the same layer we hand back the
318+
already-pinned tensors and skip the safetensors read entirely. The
319+
streaming scheduler calls this once per layer per token, so missing
320+
the cache pays full disk I/O on every decode step — that's a 3+
321+
second per-token regression on a 32B model.
322+
"""
315323

316324
assert self.index is not None
317325

318326
layer = self.index.layers[layer_idx]
327+
needed = [s for s in layer.tensors.values() if not s.is_resident]
319328

320-
grouped: dict[Path, list[TensorSpec]] = {}
321-
322-
for spec in layer.tensors.values():
323-
if spec.is_resident:
324-
continue
329+
# Cache hit: every needed tensor is already in _pinned_tensors.
330+
# Strict all-or-nothing — partial hits force a re-read so we don't
331+
# mix tensors from different load passes (defensive; in practice
332+
# we either pinned the whole layer or none of it).
333+
if needed and all(s.name in self._pinned_tensors for s in needed):
334+
return {s.name: self._pinned_tensors[s.name] for s in needed}
325335

336+
grouped: dict[Path, list[TensorSpec]] = {}
337+
for spec in needed:
326338
grouped.setdefault(spec.path, []).append(spec)
327339

328340
out: dict[str, torch.Tensor] = {}

0 commit comments

Comments
 (0)